Issue
I want to train a classifier on the ImageNet dataset (1000 classes), and I need each batch to contain 64 images from the same class, with consecutive batches drawn from different classes. So far, based on @Shai's suggestion and this post, I have:
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import numpy as np
import random
import argparse
import torch
import os
class DS(Dataset):
    def __init__(self, data, num_classes):
        super(DS, self).__init__()
        self.data = data
        self.indices = [[] for _ in range(num_classes)]
        for i, (_, class_label) in enumerate(data):
            # create a list of lists, where every sublist contains the indices of
            # the samples that belong to that class_label
            self.indices[class_label].append(i)

    def classes(self):
        return self.indices

    def __getitem__(self, index):
        return self.data[index]
class BatchSampler:
    def __init__(self, classes, batch_size):
        # classes is a list of lists where each sublist refers to a class and contains
        # the sample ids that belong to this class
        self.classes = classes
        self.n_batches = sum([len(x) for x in classes]) // batch_size
        self.min_class_size = min([len(x) for x in classes])
        self.batch_size = batch_size
        self.class_range = list(range(len(self.classes)))
        random.shuffle(self.class_range)
        assert batch_size <= self.min_class_size, \
            'batch_size must not exceed the smallest class size ({})'.format(self.min_class_size)

    def __iter__(self):
        batches = []
        for j in range(self.n_batches):
            if j < len(self.class_range):
                batch_class = self.class_range[j]
            else:
                batch_class = random.choice(self.class_range)
            # note: np.random.choice samples with replacement by default
            batches.append(np.random.choice(self.classes[batch_class], self.batch_size))
        return iter(batches)
def main():
    # Code about
    _train_dataset = DS(train_dataset, train_dataset.num_classes)
    _batch_sampler = BatchSampler(_train_dataset.classes(), batch_size=args.batch_size)
    _train_loader = DataLoader(dataset=_train_dataset, batch_sampler=_batch_sampler)
    labels = []
    for i, (inputs, _labels) in enumerate(_train_loader):
        labels.append(torch.unique(_labels).item())
        print("Unique labels: {}".format(torch.unique(_labels).item()))
    labels = set(labels)
    print('Length of traversed unique labels: {}'.format(len(labels)))
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
    parser.add_argument('--data', metavar='DIR', nargs='?', default='imagenet',
                        help='path to dataset (default: imagenet)')
    parser.add_argument('--dummy', action='store_true', help="use fake data to benchmark")
    parser.add_argument('-b', '--batch-size', default=64, type=int,
                        metavar='N',
                        help='mini-batch size (default: 64), this is the total '
                             'batch size of all GPUs on the current node when '
                             'using Data Parallel or Distributed Data Parallel')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    args = parser.parse_args()

    if args.dummy:
        print("=> Dummy data is used!")
        num_classes = 100
        train_dataset = datasets.FakeData(size=12811, image_size=(3, 224, 224),
                                          num_classes=num_classes, transform=transforms.ToTensor())
        val_dataset = datasets.FakeData(5000, (3, 224, 224), num_classes, transforms.ToTensor())
    else:
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'val')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        val_dataset = datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]))

    # Samplers are initialized to None and train_sampler will be replaced
    train_sampler, val_sampler = None, None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True, sampler=val_sampler)

    main()
which prints: Length of traversed unique labels: 100.
However, creating self.indices in the for loop takes a lot of time. Is there a more efficient way to construct this sampler?
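The expensive part is that the loop in DS.__init__ iterates over the dataset itself, so every image is decoded and transformed just to read its label. If the training set is an ImageFolder (or any torchvision dataset exposing a .targets list of class indices), the labels are available without touching the images, and the per-class index lists can be built almost instantly. A minimal sketch, assuming such a .targets attribute is present (FakeData does not provide one; the helper name is mine):

def per_class_sample_indices_from_targets(dataset, num_classes):
    # dataset.targets holds one class index per sample (available on
    # ImageFolder / DatasetFolder); no image is loaded or transformed here
    indices = [[] for _ in range(num_classes)]
    for i, class_label in enumerate(dataset.targets):
        indices[class_label].append(i)
    return indices

# usage sketch:
# per_class_sample_indices = per_class_sample_indices_from_targets(
#     train_dataset, len(train_dataset.classes))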
EDIT: yield implementation
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import numpy as np
import random
import argparse
import torch
import os
from tqdm import tqdm
import os.path
class DS(Dataset):
    def __init__(self, data, num_classes):
        super(DS, self).__init__()
        self.data = data
        self.data_len = len(data)
        indices = [[] for _ in range(num_classes)]
        for i, (_, class_label) in tqdm(enumerate(data), total=len(data), miniters=1,
                                        desc='Building class indices dataset..'):
            indices[class_label].append(i)
        self.indices = indices

    def per_class_sample_indices(self):
        return self.indices

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.data_len
class BatchSampler:
    def __init__(self, per_class_sample_indices, batch_size):
        # per_class_sample_indices is a list of lists where each sublist refers to a
        # class and contains the sample ids that belong to this class
        self.per_class_sample_indices = per_class_sample_indices
        self.n_batches = sum([len(x) for x in per_class_sample_indices]) // batch_size
        self.min_class_size = min([len(x) for x in per_class_sample_indices])
        self.batch_size = batch_size
        self.class_range = list(range(len(self.per_class_sample_indices)))
        random.shuffle(self.class_range)

    def __iter__(self):
        for j in range(self.n_batches):
            if j < len(self.class_range):
                batch_class = self.class_range[j]
            else:
                batch_class = random.choice(self.class_range)
            if self.batch_size <= len(self.per_class_sample_indices[batch_class]):
                # note: np.random.choice samples with replacement by default
                batch = np.random.choice(self.per_class_sample_indices[batch_class], self.batch_size)
            else:
                batch = self.per_class_sample_indices[batch_class]
            yield batch

    def __len__(self):
        # renamed from n_batches(): the attribute self.n_batches set in __init__ would
        # shadow a method of the same name; __len__ also makes len(loader) work
        return self.n_batches
def main():
    file_path = 'a_file_path'
    file_name = 'per_class_sample_indices.pt'
    if not os.path.exists(os.path.join(file_path, file_name)):
        print('File: {} does not exist. Creating it.'.format(file_name))
        per_class_sample_indices = DS(train_dataset, num_classes).per_class_sample_indices()
        torch.save(per_class_sample_indices, os.path.join(file_path, file_name))
    else:
        per_class_sample_indices = torch.load(os.path.join(file_path, file_name))
        print('File: {} exists. Not creating it.'.format(file_name))

    batch_sampler = BatchSampler(per_class_sample_indices,
                                 batch_size=args.batch_size)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        # batch_size=args.batch_size,
        # shuffle=(train_sampler is None),
        num_workers=args.workers,
        pin_memory=True,
        # sampler=train_sampler,
        batch_sampler=batch_sampler
    )
    # We do not use the sampler for validation
    # val_loader = torch.utils.data.DataLoader(
    #     val_dataset, batch_size=args.batch_size, shuffle=False,
    #     num_workers=args.workers, pin_memory=True, sampler=None)

    labels = []
    for i, (inputs, _labels) in enumerate(train_loader):
        labels.append(torch.unique(_labels).item())
        print("Unique labels: {}".format(torch.unique(_labels).item()))
    labels = set(labels)
    print('Length of traversed unique labels: {}'.format(len(labels)))
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
    parser.add_argument('--data', metavar='DIR', nargs='?', default='imagenet',
                        help='path to dataset (default: imagenet)')
    parser.add_argument('--dummy', action='store_true', help="use fake data to benchmark")
    parser.add_argument('-b', '--batch-size', default=64, type=int,
                        metavar='N',
                        help='mini-batch size (default: 64), this is the total '
                             'batch size of all GPUs on the current node when '
                             'using Data Parallel or Distributed Data Parallel')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    args = parser.parse_args()

    if args.dummy:
        print("=> Dummy data is used!")
        num_classes = 100
        train_dataset = datasets.FakeData(size=12811, image_size=(3, 224, 224),
                                          num_classes=num_classes, transform=transforms.ToTensor())
        val_dataset = datasets.FakeData(5000, (3, 224, 224), num_classes, transforms.ToTensor())
    else:
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'val')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        val_dataset = datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]))
        num_classes = len(train_dataset.classes)

    main()
A similar post, but for TensorFlow, can be found here.
Solution
You should write your own batch_sampler class for the DataLoader.
Answered By - Shai
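For illustration only, and not part of Shai's answer: a minimal sketch of what such a batch_sampler could look like, assuming per_class_sample_indices has been built as in the question (the class name and the round-robin scheme are my own). Anything passed as batch_sampler just needs to be an iterable that yields lists of indices; defining __len__ additionally makes len(loader) work.

import random

class ClassPureBatchSampler:
    # Every batch contains batch_size indices from a single class; classes are
    # visited round-robin in a shuffled order, so consecutive batches come from
    # different classes (except possibly at the very end of an epoch, when only
    # one class still has enough unused samples).
    def __init__(self, per_class_sample_indices, batch_size):
        self.per_class_sample_indices = per_class_sample_indices
        self.batch_size = batch_size

    def __iter__(self):
        # work on shuffled copies so each epoch sees a different order
        pools = [random.sample(idx, len(idx)) for idx in self.per_class_sample_indices]
        class_order = list(range(len(pools)))
        random.shuffle(class_order)
        yielded_any = True
        while yielded_any:
            yielded_any = False
            for c in class_order:
                if len(pools[c]) >= self.batch_size:
                    yielded_any = True
                    yield [pools[c].pop() for _ in range(self.batch_size)]

    def __len__(self):
        return sum(len(x) // self.batch_size for x in self.per_class_sample_indices)

# usage sketch:
# loader = torch.utils.data.DataLoader(
#     train_dataset,
#     batch_sampler=ClassPureBatchSampler(per_class_sample_indices, args.batch_size),
#     num_workers=args.workers, pin_memory=True)

Unlike the np.random.choice approach above, which samples with replacement, this draws each index at most once per epoch.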