Issue
I am new to PyTorch. Strangely, I cannot find anything related to this, although it seems rather simple.
I want to structure my batches with specific examples, for instance all examples in a batch having the same label, or a batch containing examples of only 2 classes.
How would I do that? It seems to me that the right place for this is the data loader, not the dataset, since the data loader is responsible for the batches and the dataset is not.
Is there a simple, minimal example?
Solution
TL;DR:
By default, DataLoader only uses a sampler, not a batch sampler. You can provide a sampler or a batch sampler; a batch sampler takes over the role of the sampler entirely (the two arguments are mutually exclusive). The sampler only yields the sequence of dataset indices, not the actual batches (batching is handled by the data loader, depending on batch_size).
To answer your initial question: working with a sampler on an iterable-style dataset doesn't seem to be possible, cf. this GitHub issue (still open). Also, read the corresponding note in pytorch/dataloader.py.
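To make the sampler/data-loader split concrete, here is a minimal sketch (using a throwaway TensorDataset purely for convenience; any map-style dataset works the same way):

import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

ds = TensorDataset(torch.arange(100, 110))

# the sampler yields dataset indices, one at a time:
print(list(SequentialSampler(ds)))   # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

# the data loader groups those indices into batches of batch_size:
for (batch,) in DataLoader(ds, batch_size=4):
    print(batch)                     # tensor([100, 101, 102, 103]), ...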
Samplers (for map-style datasets):
That aside, if you are switching to a map-style dataset, here are some details on how samplers and batch samplers work. You have access to a dataset's underlying data using indices, just like you would with a list (since torch.utils.data.Dataset implements __getitem__). In other words, your dataset elements are all dataset[i], for i in [0, len(dataset) - 1].
Here is a toy dataset:
from torch.utils.data import Dataset

class DS(Dataset):
    def __getitem__(self, index):
        # element i is simply i*10, so indices 0..9 map to 0, 10, ..., 90
        return index * 10

    def __len__(self):
        return 10
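You can index it just like a list:

>>> ds = DS()
>>> ds[3]
30
>>> len(ds)
10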
In a general use case you would just give torch.utils.data.DataLoader the arguments batch_size and shuffle. By default, shuffle is set to False, which means it will use torch.utils.data.SequentialSampler. Otherwise (if shuffle is True), torch.utils.data.RandomSampler will be used. The sampler defines how the data loader accesses the dataset (in which order it accesses it).
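You can check which sampler was picked by inspecting the data loader's sampler attribute (the exact repr, including the 0x... address, will vary):

>>> DataLoader(ds, batch_size=2).sampler
<torch.utils.data.sampler.SequentialSampler object at 0x...>
>>> DataLoader(ds, batch_size=2, shuffle=True).sampler
<torch.utils.data.sampler.RandomSampler object at 0x...>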
The above dataset (DS) has 10 elements. The indices are 0, 1, 2, 3, 4, 5, 6, 7, 8, and 9, and they map to the elements 0, 10, 20, 30, 40, 50, 60, 70, 80, and 90. So with a batch size of 2:
SequentialSampler: DataLoader(ds, batch_size=2) (implicitly shuffle=False) is identical to DataLoader(ds, batch_size=2, sampler=SequentialSampler(ds)). The data loader will deliver tensor([0, 10]), tensor([20, 30]), tensor([40, 50]), tensor([60, 70]), and tensor([80, 90]).
RandomSampler: DataLoader(ds, batch_size=2, shuffle=True) is identical to DataLoader(ds, batch_size=2, sampler=RandomSampler(ds)). The data loader will sample randomly each time you iterate through it. For instance: tensor([50, 40]), tensor([90, 80]), tensor([0, 60]), tensor([10, 20]), and tensor([30, 70]). But the sequence will be different if you iterate through the data loader a second time!
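You can see this for yourself by iterating twice over the same data loader (a sketch; the exact permutations will of course vary):

>>> dl = DataLoader(ds, batch_size=2, shuffle=True)
>>> list(dl)
[tensor([50, 40]), tensor([90, 80]), tensor([ 0, 60]), tensor([10, 20]), tensor([30, 70])]
>>> list(dl)  # a fresh shuffle on every pass
[tensor([70, 30]), tensor([20, 80]), tensor([90,  0]), tensor([10, 60]), tensor([50, 40])]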
Batch sampler
Providing batch_sampler is mutually exclusive with batch_size, shuffle, sampler, and drop_last; it takes over their role altogether. It is meant to define exactly the batches and their content. For instance:
>>> DataLoader(ds, batch_sampler=[[1,2,3], [6,5,4], [7,8], [0,9]])
This will yield tensor([10, 20, 30]), tensor([60, 50, 40]), tensor([70, 80]), and tensor([ 0, 90]).
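Note that batch_sampler accepts any iterable of index lists, so a generator works too; just keep in mind that a plain generator is exhausted after a single pass, so a second iteration over the data loader yields nothing. Use a list (as above) or an object with __iter__ if you need multiple epochs. A minimal sketch (the chunks helper is hypothetical, written here for illustration):

def chunks(indices, size):
    # yield successive fixed-size batches of indices
    for i in range(0, len(indices), size):
        yield indices[i:i + size]

dl = DataLoader(ds, batch_sampler=chunks(list(range(10)), 3))
# yields tensor([ 0, 10, 20]), tensor([30, 40, 50]), tensor([60, 70, 80]), tensor([90])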
Batch sampling by class
Let's say we want each batch to contain 2 elements (from the same class or not) and nothing more, i.e., we must ensure that a batch never contains 3 examples.
Let's say you have a dataset with four classes. Here is how I would do it: first, keep track of the dataset indices of each class.
from torch.utils.data import Dataset

class DS(Dataset):
    def __init__(self, data):
        super(DS, self).__init__()
        self.data = data
        # one list of dataset indices per class
        self.indices = [[] for _ in range(4)]
        for i, x in enumerate(data):
            if x > 0 and x % 2: self.indices[0].append(i)      # positive odd
            if x > 0 and not x % 2: self.indices[1].append(i)  # positive even
            if x < 0 and x % 2: self.indices[2].append(i)      # negative odd
            if x < 0 and not x % 2: self.indices[3].append(i)  # negative even

    def classes(self):
        return self.indices

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)
For example:
>>> ds = DS([1, 6, 7, -5, 10, -6, 8, 6, 1, -3, 9, -21, -13, 11, -2, -4, -21, 4])
Will give:
>>> ds.classes()
[[0, 2, 8, 10, 13], [1, 4, 6, 7, 17], [3, 9, 11, 12, 16], [5, 14, 15]]
Then, for the batch sampler, the easiest way is to create a list of the available class indices, with each class index repeated as many times as that class has elements.
In the dataset defined above, we have 5 items from class 0, 5 from class 1, 5 from class 2, and 3 from class 3. Therefore we want to construct the list [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3] and shuffle it. Then, from this list and the dataset's class contents (ds.classes()), we will be able to construct the batches.
import copy
import random

class Sampler():
    def __init__(self, classes):
        self.classes = classes

    def __iter__(self):
        # deep copy: we pop indices off the per-class lists below
        classes = copy.deepcopy(self.classes)
        # one entry per dataset element, holding that element's class index
        indices = [i for i, klass in enumerate(classes) for _ in klass]
        random.shuffle(indices)
        # group the shuffled class indices into pairs (batch size of 2)
        grouped = zip(*[iter(indices)] * 2)
        res = []
        for a, b in grouped:
            # draw one dataset index from each of the two classes
            res.append((classes[a].pop(), classes[b].pop()))
        return iter(res)
Note - deep copying the list is required since we're popping elements from it.
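The zip(*[iter(indices)]*2) line is the standard Python idiom for grouping a flat list into consecutive pairs: both positions of the zip pull from the same iterator, so each pair consumes two successive elements. For example:

>>> list(zip(*[iter([3, 1, 0, 2, 1, 0])] * 2))
[(3, 1), (0, 2), (1, 0)]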
A possible output of this sampler would be:
[(15, 14), (16, 17), (7, 12), (11, 6), (13, 10), (5, 4), (9, 8), (2, 0), (3, 1)]
At this point we can simply use torch.utils.data.DataLoader:
>>> dl = DataLoader(ds, batch_sampler=Sampler(ds.classes()))
Which could yield something like:
[tensor([ 4, -4]), tensor([-21, 11]), tensor([-13, 6]), tensor([9, 1]), tensor([ 8, -21]), tensor([-3, 10]), tensor([ 6, -2]), tensor([-5, 7]), tensor([-6, 1])]
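Since the class-index list has exactly one entry per dataset element and indices are popped without replacement, every element appears exactly once per epoch (assuming an even number of elements; zip silently drops a leftover element otherwise). A quick sanity check (a sketch):

>>> batches = list(Sampler(ds.classes()))
>>> sorted(i for batch in batches for i in batch)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]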
An easier approach
Here is another, easier, approach. It is not guaranteed to return every element of the dataset, since elements are drawn with replacement. On each pass over the data, it first samples class_per_batch classes, then builds every batch by drawing batch_size elements from that class subset (each element is drawn by first picking a class from the subset, then picking a data point from that class).
import random

class Sampler():
    def __init__(self, classes, class_per_batch, batch_size):
        self.classes = classes
        self.n_batches = sum(len(x) for x in classes) // batch_size
        self.class_per_batch = class_per_batch
        self.batch_size = batch_size

    def __iter__(self):
        # pick the subset of classes used for this pass over the data
        classes = random.sample(range(len(self.classes)), self.class_per_batch)
        batches = []
        for _ in range(self.n_batches):
            batch = []
            for _ in range(self.batch_size):
                # sample with replacement: a class, then an element of it
                klass = random.choice(classes)
                batch.append(random.choice(self.classes[klass]))
            batches.append(batch)
        return iter(batches)
You can try it this way:
>>> s = Sampler(ds.classes(), class_per_batch=2, batch_size=4)
>>> list(s)
[[16, 0, 0, 9], [10, 8, 11, 2], [16, 9, 16, 8], [2, 9, 2, 3]]
>>> dl = DataLoader(ds, batch_sampler=s)
>>> list(iter(dl))
[tensor([ -5, -6, -21, -13]), tensor([ -4, -4, -13, -13]), tensor([ -3, -21, -2, -5]), tensor([-3, -5, -4, -6])]
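The class subset above is drawn once per pass over the data (once per call to __iter__), which is why all the batches in the example share the same two classes. If you would rather draw a fresh class subset for every batch, a small variation (a sketch) moves the random.sample call inside the loop:

    def __iter__(self):
        batches = []
        for _ in range(self.n_batches):
            # resample the class subset for every batch
            classes = random.sample(range(len(self.classes)), self.class_per_batch)
            batch = [random.choice(self.classes[random.choice(classes)])
                     for _ in range(self.batch_size)]
            batches.append(batch)
        return iter(batches)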
Answered By - Ivan