Issue
I'm trying to create dataloaders that use only a specific digit from the PyTorch MNIST dataset.
I already tried to write my own Sampler, but it doesn't work, and I'm not sure I'm using the mask correctly.
class YourSampler(torch.utils.data.sampler.Sampler):
    def __init__(self, mask):
        self.mask = mask

    def __iter__(self):
        return (self.indices[i] for i in torch.nonzero(self.mask))

    def __len__(self):
        return len(self.mask)

mnist = datasets.MNIST(root=dataroot, train=True, download=True, transform = transform)
mask = [True if mnist[i][1] == 5 else False for i in range(len(mnist))]
mask = torch.tensor(mask)
sampler = YourSampler(mask)
trainloader = torch.utils.data.DataLoader(mnist, batch_size=4, sampler = sampler, shuffle=False, num_workers=2)
So far I have run into many different kinds of errors. With this implementation, it's "Stop Iteration". I feel like this should be very easy, but I can't find a simple way to do it. Thank you for your help!
Solution
Thank you for your help. After a while I figured out a solution (though it might not be the best one). The sampler yields the plain integer indices of the samples whose label is 5:
import torch
from torchvision import datasets

class YourSampler(torch.utils.data.sampler.Sampler):
    def __init__(self, mask, data_source):
        self.mask = mask
        self.data_source = data_source

    def __iter__(self):
        # Yield the integer index of every sample whose mask entry is nonzero.
        return iter([i.item() for i in torch.nonzero(self.mask)])

    def __len__(self):
        # Number of samples the sampler actually yields, not the dataset size.
        return int(self.mask.sum())

# dataroot, transform, batch_size and workers are defined elsewhere.
mnist = datasets.MNIST(root=dataroot, train=True, download=True, transform=transform)
mask = [1 if mnist[i][1] == 5 else 0 for i in range(len(mnist))]
mask = torch.tensor(mask)
sampler = YourSampler(mask, mnist)
trainloader = torch.utils.data.DataLoader(mnist, batch_size=batch_size, sampler=sampler, shuffle=False, num_workers=workers)
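As an alternative to a hand-written sampler, the same filtering can probably be done with utilities that already ship with PyTorch. The following is only a minimal sketch, not the method from the answer above: it uses a small synthetic TensorDataset as a stand-in for MNIST (an assumption, so that it runs without a download), and the names images, labels, digit, indices, subset_loader and sampler_loader are illustrative.

```python
import torch
from torch.utils.data import DataLoader, Subset, SubsetRandomSampler, TensorDataset

# Synthetic stand-in for MNIST (assumption): 100 blank 1x28x28 "images"
# whose labels cycle through the digits 0..9.
images = torch.zeros(100, 1, 28, 28)
labels = torch.arange(100) % 10
dataset = TensorDataset(images, labels)

digit = 5
# Indices of all samples carrying the wanted label, as plain Python ints.
indices = torch.nonzero(labels == digit).flatten().tolist()

# Option 1: wrap the dataset so that it only contains the selected samples.
subset_loader = DataLoader(Subset(dataset, indices), batch_size=4, shuffle=True)

# Option 2: keep the full dataset and draw only from the selected indices.
sampler_loader = DataLoader(dataset, batch_size=4, sampler=SubsetRandomSampler(indices))

# Collect the labels that the first loader actually produces.
seen_labels = set()
for _, batch_labels in subset_loader:
    seen_labels.update(batch_labels.tolist())
```

With the real dataset, the labels tensor could likely be taken from mnist.targets instead of looping over mnist[i][1], which avoids loading and transforming every image just to read its label. Both options report a loader length based on the number of selected samples rather than the full dataset size.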
Answered By - Aymeric .Bass