Issue
A typical custom PyTorch Dataset looks like this:

import torch

class TorchCustomDataset(torch.utils.data.Dataset):
    def __init__(self, filenames, speech_labels):
        pass

    def __len__(self):
        return 100

    def __getitem__(self, idx):
        return 1, 0
Here, with __getitem__ I can read any file and apply any pre-processing for that specific file.
But what if I want to apply some tensor-level pre-processing to the whole batch of data? Technically, it's possible to just iterate through the data loader, get each batch, and apply the pre-processing to it (see the sketch below). But how can this be done inside a custom data loader? In short, what is the __getitem__ equivalent for a DataLoader, i.e. a hook for applying an operation to the whole batch of data?
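For reference, a minimal sketch of the naive approach mentioned above, iterating over a plain DataLoader and transforming each batch in the loop; normalize_batch is a hypothetical pre-processing step, not part of the original question:

from torch.utils.data import DataLoader

def normalize_batch(features):
    # hypothetical tensor-level operation applied to the whole batch
    return (features - features.mean()) / (features.std() + 1e-8)

dataset = TorchCustomDataset(filenames=None, speech_labels=None)
loader = DataLoader(dataset, batch_size=4)
for features, labels in loader:
    features = normalize_batch(features.float())
    # ... continue with training / inference on the pre-processed batch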
Solution
You can override the collate_fn of the DataLoader: this function takes the individual items from the underlying Dataset and forms the batch. You can add your custom pre-processing at that point by modifying the collate_fn.
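As an illustration, here is a minimal sketch of a custom collate_fn; the names batch_preprocess_collate and the batch-level normalization are hypothetical examples, assuming the TorchCustomDataset defined in the question:

import torch
from torch.utils.data import DataLoader

def batch_preprocess_collate(items):
    # `items` is a list of individual samples returned by Dataset.__getitem__,
    # here (feature, label) pairs; this replaces the default collation.
    features = torch.tensor([f for f, _ in items], dtype=torch.float32)
    labels = torch.tensor([l for _, l in items])
    # batch-level pre-processing (hypothetical): normalize over the whole batch
    features = (features - features.mean()) / (features.std() + 1e-8)
    return features, labels

dataset = TorchCustomDataset(filenames=None, speech_labels=None)
loader = DataLoader(dataset, batch_size=4, collate_fn=batch_preprocess_collate)
for features, labels in loader:
    # every batch yielded here has already been pre-processed as a whole
    ...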
Answered By - Shai