Issue
I'm trying to load this dataset in PyTorch, but I'm running into a problem while loading it.
As you can see by checking the dataset, the directory looks something like this:
root/
    monet_jpg/
    monet_tfrec/
    photo_jpg/
    photo_tfrec/
I want to load the photo and Monet images into separate DataLoader variables, but my approach doesn't seem to work.
EDIT: By that I mean that both monet_ds and photo_ds return only Monet images (photo_ds should return the images from photo_jpg).
I'm trying to load the data through this code:
import torchvision.datasets as dset
import torchvision.utils as vutils
from torch.utils.data import Subset
def load_data(dataroot , image_size, batch_size, workers,ngpu,shuffle=True):
    # NOTE(review): this is the asker's original (broken) snippet, kept as
    # posted; the comments below explain why it misbehaves. `torch` and
    # `transforms` are used here but are not imported in this snippet.
    #DataLoading
    # Create the dataset
    # ImageFolder treats each sub-directory of `dataroot` as one class, so
    # monet_jpg and photo_jpg samples end up mixed in a single dataset,
    # ordered by class.
    dataset = dset.ImageFolder(root=dataroot,
                               transform=transforms.Compose([
                                   transforms.Resize(image_size),
                                   transforms.CenterCrop(image_size),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               ]))
    print(dataset.class_to_idx)
    #print(dataset.imgs)
    # BUG: range(0, 299) stops at index 298, so the sample at index 299 is
    # never included.
    monet_ds = Subset(dataset, range(0,299))
    # BUG: range(300,) is identical to range(300), i.e. indices 0..299, so
    # photo_ds covers the same leading samples as monet_ds instead of
    # starting at index 300.
    photo_ds = Subset(dataset, range(300,))
    # Create the dataloader
    # NOTE(review): the `shuffle` argument is accepted but never passed on.
    monet_ds = torch.utils.data.DataLoader(monet_ds, batch_size=batch_size,
                                           num_workers=workers)
    photo_ds = torch.utils.data.DataLoader(photo_ds, batch_size=batch_size,
                                           num_workers=workers)
    # Decide which device we want to run on
    device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
    print("Data loaded...")
    # BUG: there is no return statement, so this function returns None and
    # the caller's `monet_ds, photo_ds, device = load_data(...)` cannot
    # unpack the result.
# NOTE(review): image_size, batch_size, workers and ngpu must be defined
# earlier; they are not set anywhere in this snippet.
root = "../input/gan-getting-started"
# load_data above has no return statement, so this tuple unpacking raises
# a TypeError (None cannot be unpacked).
monet_ds, photo_ds, device = load_data(root, image_size, batch_size, workers, ngpu)
Any help loading this data correctly in PyTorch would be appreciated. Thank you.
Solution
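Your version fails because `ImageFolder` sorts its class folders alphabetically, so the first indices belong to `monet_jpg`. Also, `range(300,)` is simply `range(300)` (indices 0–299), so `photo_ds` ends up containing Monet images too. Since the Monet paintings and the photos are completely independent, it is simpler to give each folder its own dataset, so the following should work just fine: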
import os
from torchvision.datasets.folder import default_loader
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
class MonetPhotoDataset(Dataset):
    """Flat-directory image dataset that yields transformed images only.

    Unlike ``ImageFolder``, no class sub-directories are expected: every
    file directly inside ``root`` whose name ends with one of
    ``extensions`` becomes one sample. No labels are returned.

    Args:
        root: Directory containing the image files.
        transform: Optional callable applied to each loaded image.
        extensions: File-name suffix, or tuple of suffixes, to accept.
            Matching is case-insensitive, so ``.JPG`` files are included
            as well. Defaults to ``('.jpg',)``.
    """

    def __init__(self, root, transform=None, extensions=('.jpg',)):
        self.transform = transform
        # Accept a bare string too; iterating a str would split it into
        # single characters and match almost every file.
        if isinstance(extensions, str):
            extensions = (extensions,)
        # Lower-case once so each file needs a single endswith() call.
        suffixes = tuple(ext.lower() for ext in extensions)
        # Sorted so the sample order is deterministic (os.listdir order is
        # arbitrary).
        self.img_paths = sorted(
            os.path.join(root, name)
            for name in os.listdir(root)
            if name.lower().endswith(suffixes)
        )

    def __len__(self):
        """Return the number of matching image files."""
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` with torchvision's default loader, then transform it."""
        img_path = self.img_paths[idx]
        sample = default_loader(img_path)
        if self.transform is not None:
            sample = self.transform(sample)
        return sample
def load_data(dataroot, image_size, batch_size, workers, ngpu, shuffle=True):
    """Build separate DataLoaders for the Monet paintings and the photos.

    Args:
        dataroot: Dataset root containing the ``monet_jpg`` and
            ``photo_jpg`` directories.
        image_size: Size passed to both ``Resize`` and ``CenterCrop``.
        batch_size: Batch size used by both loaders.
        workers: ``num_workers`` used by both loaders.
        ngpu: Accepted for signature compatibility; not used here.
        shuffle: Whether each loader reshuffles its samples every epoch.

    Returns:
        ``(monet_dl, photo_dl, device)``, where ``device`` is ``cuda:0`` when
        CUDA is available and ``cpu`` otherwise.
    """
    # torch.device is needed below but torch is not in this snippet's
    # import list; a local import keeps the function self-contained.
    import torch

    # Shared preprocessing: resize, center-crop, convert to a tensor, then
    # normalize each of the 3 channels with mean 0.5 / std 0.5.
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # One dataset per folder, so the two domains can never be mixed.
    monet_ds = MonetPhotoDataset(root=os.path.join(dataroot, 'monet_jpg'), transform=transform)
    photo_ds = MonetPhotoDataset(root=os.path.join(dataroot, 'photo_jpg'), transform=transform)

    # Previously `shuffle` was accepted but silently ignored; honor it.
    monet_dl = DataLoader(monet_ds, batch_size=batch_size, shuffle=shuffle, num_workers=workers)
    photo_dl = DataLoader(photo_ds, batch_size=batch_size, shuffle=shuffle, num_workers=workers)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Data loaded...")
    return monet_dl, photo_dl, device
# Root directory that contains the monet_jpg/ and photo_jpg/ folders.
root = "../input/gan-getting-started"
# image_size, batch_size, workers and ngpu are expected to be defined
# earlier; they are not set in this snippet.
monet_dl, photo_dl, device = load_data(
    dataroot=root,
    image_size=image_size,
    batch_size=batch_size,
    workers=workers,
    ngpu=ngpu,
)
P.S.: I kept `load_data` because I assumed you rely on its signature elsewhere in your code; otherwise I wouldn't use it. Also, I haven't tested the code above, so expect the odd typo, but the logic is correct.
Note that this dataset returns only the images (no labels).
Answered By - Berriel
0 comments:
Post a Comment
Note: Only a member of this blog may post a comment.