Issue
I am a beginner in PyTorch. I want to train a network on the NYU dataset, but I am getting an error.
The error happens when I use a DataLoader to load my local dataset. I want to print the data to check that the code is correct:
test=Mydataset(data_root,transforms,'image_train')
test2=DataLoader(test,batch_size=4,num_workers=0,shuffle=False)
for idx,data in enumerate(test2):
    print(idx)
Here's the rest of the code, including the Mydataset definition:
from __future__ import division,absolute_import,print_function
from PIL import Image
from torch.utils.data import DataLoader,Dataset
from torchvision.transforms import transforms
data_root='D:/AuxiliaryDocuments/NYU/'
transforms=transforms.Compose([transforms.ToPILImage(),
                               transforms.Resize(224,101),
                               transforms.ToTensor()])
filename_txt={'image_train':'image_train.txt','image_test':'image_test.txt',
              'depth_train':'depth_train.txt','depth_test':'depth_test.txt'}
class Mydataset(Dataset):
    def __init__(self,data_root,transformation,data_type):
        self.transform=transformation
        self.image_path_txt=filename_txt[data_type]
        self.sample_list=list()
        f=open(data_root+'/'+data_type+'/'+self.image_path_txt)
        lines=f.readlines()
        for line in lines:
            line=line.strip()
            line=line.replace(';','')
            self.sample_list.append(line)
        f.close()
    def __getitem__(self, index):
        item=self.sample_list[index]
        img=Image.open(item)
        if self.transform is not None:
            img=self.transform(img)
        idx=index
        return idx,img
    def __len__(self):
        return len(self.sample_list)
Solution
The error in the title is different from the one in the image (which you should have posted as text, by the way). Assuming the one from the image is correct, your problem is the following:
Your transforms pipeline begins with transforms.ToPILImage(), but the image is already a PIL image: Image.open() in your __getitem__ returns one, and ToPILImage() only accepts a tensor or a NumPy array. If you remove that transformation, the code should run just fine.
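# [...]
transforms = transforms.Compose([
    transforms.ToPILImage(),  # <<< remove this
    # Resize's 2nd positional argument is the interpolation mode, not a width;
    # pass the size as one (h, w) tuple, assuming 224x101 was intended.
    transforms.Resize((224, 101)),
    transforms.ToTensor()
])
# [...]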
class Mydataset(Dataset):
    # [...]
    def __getitem__(self, index):
        item = self.sample_list[index]
        img = Image.open(item)  # <<< this image is already a PIL image
        if self.transform is not None:
            img = self.transform(img)
        idx = index
        return idx, img
# [...]
Answered By - Berriel