Issue
I am trying to perform an image classification task using the mini-imagenet dataset. The data I want to use contains a few bad data points (I am not sure why). I would like to load this data and train my model on it, skipping the bad data points completely. How do I do this? The Dataset class I am using is as follows:
```python
import os
import os.path as osp

from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class MiniImageNet(Dataset):
    def __init__(self, root, train=True,
                 transform=None,
                 index_path=None, index=None, base_sess=None):
        if train:
            setname = 'train'
        else:
            setname = 'test'
        self.root = os.path.expanduser(root)
        self.transform = transform  # overwritten below for both splits
        self.train = train  # training set or test set
        self.IMAGE_PATH = os.path.join(self.root, 'miniimagenet/images')
        self.SPLIT_PATH = os.path.join(self.root, 'miniimagenet/split')
        csv_path = osp.join(self.SPLIT_PATH, setname + '.csv')
        # Read the split CSV and drop its header row.
        with open(csv_path, 'r') as f:
            lines = [x.strip() for x in f.readlines()][1:]
        self.data = []
        self.targets = []
        self.data2label = {}
        lb = -1
        self.wnids = []
        # Give each new WordNet ID the next integer label
        # (assumes rows are grouped by class, as in the split CSVs).
        for l in lines:
            name, wnid = l.split(',')
            path = osp.join(self.IMAGE_PATH, name)
            if wnid not in self.wnids:
                self.wnids.append(wnid)
                lb += 1
            self.data.append(path)
            self.targets.append(lb)
            self.data2label[path] = lb
        self.y = self.targets
        if train:
            image_size = 84
            self.transform = transforms.Compose([
                transforms.RandomResizedCrop(image_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])])
        else:
            image_size = 84
            self.transform = transforms.Compose([
                transforms.Resize([image_size, image_size]),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        path, targets = self.data[i], self.targets[i]
        image = self.transform(Image.open(path).convert('RGB'))
        return image, targets
```
I tried wrapping the loading code in a try-except block, but instead of skipping the bad sample, `__getitem__` then returns None, and the DataLoader fails when it tries to batch it. How do I completely skip a data point in a DataLoader?
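If you want to keep the try-except approach, one common workaround (separate from the answer given here) is to let `__getitem__` return `None` for samples that fail to load and give the DataLoader a custom `collate_fn` that drops those `None`s before batching. A minimal sketch follows; `collate_skip_none` and its `base_collate` parameter are hypothetical names, and `torch.utils.data.default_collate` assumes PyTorch 1.11 or newer (older versions expose it as `torch.utils.data.dataloader.default_collate`).

```python
def collate_skip_none(batch, base_collate=None):
    """Drop samples that __getitem__ returned as None, then collate the rest.

    Returns None when every sample in the batch was bad, so the training
    loop has to skip None batches.
    """
    # Keep only the samples that loaded successfully.
    batch = [sample for sample in batch if sample is not None]
    if not batch:
        return None
    if base_collate is None:
        # Imported lazily so a custom base_collate works without PyTorch.
        from torch.utils.data import default_collate
        base_collate = default_collate
    return base_collate(batch)


# Hypothetical usage, assuming __getitem__ is changed to:
#     def __getitem__(self, i):
#         path, target = self.data[i], self.targets[i]
#         try:
#             image = self.transform(Image.open(path).convert('RGB'))
#         except OSError:  # missing, truncated or undecodable file
#             return None
#         return image, target
#
# loader = DataLoader(dataset, batch_size=64, shuffle=True,
#                     collate_fn=collate_skip_none)
# for batch in loader:
#     if batch is None:
#         continue
#     images, labels = batch
```

With this approach, batches that contained bad samples come out smaller than `batch_size`, and the bad files are re-read (and rejected) every epoch, which is one reason to filter them out once in `__init__` instead.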
Solution
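Try removing the bad samples at the end of the `__init__` function, so the DataLoader never sees them. `is_bad_data` is a placeholder for whatever check identifies a bad sample in your data; you have to write it yourself.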
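```python
# At the end of MiniImageNet.__init__:
# walk the indices backwards so that deleting an item does not shift
# the positions of the items that have not been checked yet.
# self.y is the same list object as self.targets, so it stays in sync.
for i in range(len(self.data) - 1, -1, -1):
    if is_bad_data(self.data[i], self.targets[i]):
        del self.data[i]
        del self.targets[i]
```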
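One possible sketch of the missing helper, assuming "bad" means the image file is missing, empty, or does not start with the JPEG signature bytes (the mini-imagenet images are assumed to be JPEGs). Both `is_bad_data` and `filter_bad_samples` are hypothetical helpers, not part of PyTorch. The header check is only a cheap heuristic; a stricter check could open each file with Pillow and call `Image.open(path).verify()`, at the cost of reading every image once.

```python
JPEG_MAGIC = b'\xff\xd8\xff'  # JPEG files start with these bytes


def is_bad_data(path, target):
    """Hypothetical check: True if the file is missing, empty, or not a JPEG.

    `target` is accepted only to match the call in the answer; it is unused.
    """
    try:
        with open(path, 'rb') as f:
            header = f.read(len(JPEG_MAGIC))
    except OSError:
        # Missing or unreadable file.
        return True
    # An empty or non-JPEG file yields a header that does not match.
    return header != JPEG_MAGIC


def filter_bad_samples(data, targets, is_bad=is_bad_data):
    """Return new (data, targets) lists with the bad samples removed.

    Gives the same result as the backwards `del` loop, but builds new lists
    in a single pass instead of deleting in place.
    """
    kept = [(d, t) for d, t in zip(data, targets) if not is_bad(d, t)]
    return [d for d, _ in kept], [t for _, t in kept]


# Hypothetical usage at the end of MiniImageNet.__init__:
# self.data, self.targets = filter_bad_samples(self.data, self.targets)
# self.y = self.targets  # re-point the alias, since new lists were created
```

Unlike the in-place `del` loop, this version rebinds `self.targets` to a new list, so `self.y` has to be reassigned afterwards. Note that `self.data2label` still holds entries for the removed paths in either version; that is harmless as long as it is only used for lookups of kept samples.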
Answered By - keanu