Issue
I have created a dataloader with a length of 50000. When I print its length, it reports 50000 as expected:
import typing as t

import torch

class MyDataLoader(torch.utils.data.Dataset):
    def __init__(self, data_size=50000):
        self.data_size = data_size

    def __len__(self) -> int:
        return self.data_size

    def __getitem__(self, idx) -> t.Tuple[torch.Tensor, torch.Tensor]:
        image, label = my_function()  # (has_star=True)
        return image[None], label

dl = MyDataLoader()
print(len(dl))
50000
However, when I iterate over it, it runs forever:
for j, i in enumerate(dl):
    if j % 10000 == 0:
        print(j)
10000
20000
30000
40000
50000
60000
...
How is that possible?
Solution
You have created a Dataset, not a DataLoader. Iterating directly over a map-style Dataset falls back to Python's old sequence iteration protocol, which calls __getitem__(0), __getitem__(1), ... until an IndexError is raised; __len__ is never consulted. Your __getitem__ never raises IndexError, so the loop never terminates.
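To see that mechanism in isolation, here is a minimal sketch that does not involve PyTorch at all; the class names NeverEnds and Ends are made up for illustration:

import itertools

class NeverEnds:
    def __len__(self):
        return 3

    def __getitem__(self, idx):
        return idx  # never raises IndexError, so iteration never stops

class Ends:
    def __len__(self):
        return 3

    def __getitem__(self, idx):
        if idx >= 3:
            raise IndexError  # tells the iteration protocol to stop
        return idx

print(len(NeverEnds()))                        # 3
print(list(itertools.islice(NeverEnds(), 5)))  # [0, 1, 2, 3, 4]; __len__ is ignored
print(list(Ends()))                            # [0, 1, 2]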
This should work:
import torch
from torch.utils.data import DataLoader

class MyDataset(torch.utils.data.Dataset):
    def __init__(self, data_size=50000):
        self.data_size = data_size

    def __len__(self) -> int:
        return self.data_size

    def __getitem__(self, idx):
        # print(idx)
        return idx

dataset = MyDataset()

# DataLoader's default batch_size is 1
dl = DataLoader(dataset)
print(len(dl))  # 50000
for j, i in enumerate(dl):
    if j % 10000 == 0:
        print(j)

# And with a different batch size:
dl = DataLoader(dataset, batch_size=2)
print(len(dl))  # 25000
for j, i in enumerate(dl):
    if j % 10000 == 0:
        print(j)
Note how len(dl) changes with the batch size: with the default drop_last=False, len(dl) is ceil(len(dataset) / batch_size), so it is 50000 with batch_size=1 and 25000 with batch_size=2.
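As a quick sanity check of that formula, here is a sketch that assumes the dataset defined above; the batch sizes are arbitrary:

import math

from torch.utils.data import DataLoader

# assumes `dataset` from the snippet above, with len(dataset) == 50000
for bs in (1, 2, 3):
    dl = DataLoader(dataset, batch_size=bs)
    assert len(dl) == math.ceil(len(dataset) / bs)
    print(bs, len(dl))  # 1 50000, then 2 25000, then 3 16667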
Answered By - hkchengrex