Issue
While training my neural network model, I use PyTorch's DataLoader to speed up training. But instead of using one fixed batch size for every parameter update, I have a list of different batch sizes that I want the data loader to use.
Example
train_dataset = TensorDataset(x_train, y_train) # x_train.shape (8400, 4)
dataloader_train = DataLoader(train_dataset, batch_size=64) # with fixed batch size of 64
What I want is a data loader that takes its batch sizes from a list, so that they vary from batch to batch instead of being fixed:
list_batch_size = [30, 60, 110, ..., 231] # with this list's sum being equal to x_train.shape[0] (8400)
Solution
You can use a custom sampler (or batch sampler) for this.
Here's a quick proof-of-concept of a sampler that takes the custom batch sizes as an argument and returns the indices of each batch accordingly:
import torch
from torch.utils.data import Sampler

class VariableBatchSampler(Sampler):
    def __init__(self, dataset_len: int, batch_sizes: list):
        self.dataset_len = dataset_len
        self.batch_sizes = batch_sizes
        self.batch_idx = 0
        self.start_idx = 0
        self.end_idx = self.batch_sizes[self.batch_idx]

    def __iter__(self):
        # Reset the position so the sampler can be iterated again every epoch
        self.batch_idx = 0
        self.start_idx = 0
        self.end_idx = self.batch_sizes[self.batch_idx]
        return self

    def __next__(self):
        if self.start_idx >= self.dataset_len:
            raise StopIteration()

        # Indices of the current batch: [start_idx, end_idx)
        batch_indices = torch.arange(self.start_idx, self.end_idx, dtype=torch.int)
        self.start_idx = self.end_idx
        self.batch_idx += 1

        try:
            self.end_idx += self.batch_sizes[self.batch_idx]
        except IndexError:
            # Batch sizes exhausted: any remaining samples form one last batch
            self.end_idx = self.dataset_len

        return batch_indices
You can instantiate the sampler and pass it as the sampler argument when instantiating the DataLoader, e.g.:
sampler = VariableBatchSampler(dataset_len=len(train_dataset), batch_sizes=[10, 20, 30, 40])
data_loader = DataLoader(train_dataset, sampler=sampler)
Note that each element of the data_loader iterable would then contain one extra leading dimension for the batch (as the default value of batch_size in DataLoader is 1). You can either use squeeze(dim=0) to get rid of the extra dimension or, better, pass the sampler as the batch_sampler argument:
data_loader = DataLoader(train_dataset, batch_sampler=sampler)
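The same idea can also be sketched as a generator, which avoids the manual index bookkeeping and starts from the first batch on every iteration. This is only a minimal sketch, not part of the original answer: the class name ListBatchSampler is hypothetical, and it assumes the batch sizes are positive and sum exactly to the dataset length (as in the question). It yields plain lists of indices, so it needs no PyTorch import of its own.

```python
from typing import Iterator, List, Sequence


class ListBatchSampler:
    """Yields consecutive index lists whose lengths follow `batch_sizes`.

    Hypothetical helper for illustration; assumes sum(batch_sizes) == dataset_len.
    """

    def __init__(self, dataset_len: int, batch_sizes: Sequence[int]):
        if any(size <= 0 for size in batch_sizes):
            raise ValueError("batch sizes must be positive")
        if sum(batch_sizes) != dataset_len:
            raise ValueError("batch sizes must sum to the dataset length")
        self.dataset_len = dataset_len
        self.batch_sizes = list(batch_sizes)

    def __iter__(self) -> Iterator[List[int]]:
        # A fresh generator per call, so every epoch starts at index 0
        start = 0
        for size in self.batch_sizes:
            yield list(range(start, start + size))
            start += size

    def __len__(self) -> int:
        # Number of batches per epoch
        return len(self.batch_sizes)


sampler = ListBatchSampler(dataset_len=100, batch_sizes=[10, 20, 30, 40])
batch_lengths = [len(batch) for batch in sampler]
```

Since the batch_sampler argument of DataLoader accepts any iterable that yields lists of indices, an instance of this class should be usable in the same way as above, i.e. DataLoader(train_dataset, batch_sampler=sampler).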
Answered By - heemayl