Issue
I am trying to build a model with PyTorch, and I want to use a customized dataset. So I have a file dataset.py which defines a class, MyDataset, a subclass of torch.utils.data.Dataset. Here's the file.
# dataset.py
import torch
from tqdm import tqdm
import numpy as np
import re
from torch.utils.data import Dataset
from pathlib import Path

class MyDataset(Dataset):
    def __init__(self, path, size=10000):
        if not Path(path).exists():
            raise FileNotFoundError
        self.data = []
        self.load_data(path, size)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

    def load_data(self, path, size):
        # Loading data from csv files and some preparation.
        # Each sample is in the format (int_tag1, int_tag2, feature_dictionary),
        # then the sample is appended to self.data.
        pass
Then I tried to test this dataset using a DataLoader in the test file dataset_test.py:
# dataset_test.py
from torch.utils.data import DataLoader
from dataset import MyDataset

path = 'dataset/sample_train.csv'
size = 1000
dataset = MyDataset(path, size)
dataloader = DataLoader(dataset, batch_size=1000)
for v in dataloader:
    print(v)
I got the following output:
730600it [11:08, 1093.11it/s]
1000it [00:00, 20325.47it/s]
Traceback (most recent call last):
File "dataset_test.py", line 12, in <module>
for v in dataloader:
File "/home/usr/.local/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
data = self._next_data()
File "/home/usr/.local/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 561, in _next_data
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
File "/home/usr/.local/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
return self.collate_fn(data)
File "/home/usr/.local/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 84, in default_collate
return [default_collate(samples) for samples in transposed]
File "/home/usr/.local/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 84, in <listcomp>
return [default_collate(samples) for samples in transposed]
File "/home/usr/.local/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 74, in default_collate
return {key: default_collate([d[key] for d in batch]) for key in elem}
File "/home/usr/.local/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 74, in <dictcomp>
return {key: default_collate([d[key] for d in batch]) for key in elem}
File "/home/usr/.local/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 74, in <listcomp>
return {key: default_collate([d[key] for d in batch]) for key in elem}
KeyError: '210'
The first two lines are probably the output from loading the data. (I'm not sure, because I didn't write any output myself, but I am using tqdm to load the data, so I assume it's tqdm's output.)
Then I got this KeyError. I'm wondering which part should be modified. I think the dataset class itself is well written, since there's no error while reading the data from the file. Is it because the format of the samples is not right, so the DataLoader cannot load the data from the dataset properly? Is there any requirement on the format? I've read other people's code, but I didn't find anything mentioning a requirement on the format of the samples in a Dataset class.
EDIT: A single sample looks like this:
('0', '0', {'210': '9093445', '216': '9154780', '301': '9351665', '205': '4186222', '206': '8316799', '207': '8416205', '508': '9355039', '121': '3438658', '122': '3438762', '101': '31390', '124': '3438769', '125': '3438774', '127': '3438782', '128': '3864885', '129': '3864887', '150_14': '3941161', '127_14': '3812616', '109_14': '449068', '110_14': '569621'})
The first two '0's are labels, and the dictionary that follows contains the features.
Solution
As @Shai mentioned, if the keys in feature_dictionary are not the same across the samples in a batch, you get this error from the default collate_fn of DataLoader: it takes the keys of the first sample's dictionary and looks each of them up in every other sample, so any dictionary missing one of those keys raises a KeyError (you can see exactly this dict comprehension in the last frames of your traceback).
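You can reproduce the failure with a minimal, self-contained sketch (the samples below are hypothetical, shaped like the ones in the question, with deliberately mismatched keys):

from torch.utils.data import Dataset, DataLoader

class TinyDataset(Dataset):
    def __init__(self):
        # Two samples whose feature dictionaries do NOT share the same keys
        self.data = [
            ('0', '0', {'210': '9093445', '216': '9154780'}),
            ('0', '1', {'216': '9154780', '301': '9351665'}),  # no '210' here
        ]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

loader = DataLoader(TinyDataset(), batch_size=2)
for v in loader:
    # default_collate builds the batch dict from the first sample's keys,
    # then looks up '210' in the second sample, which raises KeyError: '210'
    print(v)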
As a solution, you can write a custom collate_fn as follows, and it works:
class MyDataset(Dataset):
    # ... your code ...

    def collate_fn(self, batch):
        tag1_batch = []
        tag2_batch = []
        feat_dict_batch = []
        for tag1, tag2, feat_dict in batch:
            tag1_batch.append(tag1)
            tag2_batch.append(tag2)
            feat_dict_batch.append(feat_dict)
        return tag1_batch, tag2_batch, feat_dict_batch

path = 'dataset/sample_train.csv'
size = 1000
dataset = MyDataset(path, size)
dataloader = DataLoader(dataset, batch_size=3, collate_fn=dataset.collate_fn)
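With this collate_fn, each batch comes back as three plain Python lists rather than collated tensors, so every per-sample feature dictionary keeps its own (possibly different) keys. A quick sketch of what iterating the loader looks like (the loop below is just illustrative):

for tag1_batch, tag2_batch, feat_dict_batch in dataloader:
    # tag1_batch / tag2_batch: lists of up to batch_size label strings
    # feat_dict_batch: list of the untouched per-sample feature dicts
    print(tag1_batch, tag2_batch, feat_dict_batch)

Note that nothing is converted to tensors here; if you need tensors, you can do that conversion inside collate_fn once you've decided how to align the keys.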
Answered By - The Exile