Issue
I'm trying to define a PyTorch Dataset/DataLoader for an image style transfer network. I have a dataset of images grouped by style, and I want each sample from this dataset to consist of two images: one for style, the other for content. My first idea was to implement a Dataset with something like this in __init__:
    n = len(images)
    itr_style = random.sample(range(n), n)
    itr_content = random.sample(range(n), n)
and this in __getitem__:
    return (images[itr_style[index]], images[itr_content[index]])
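A note on shuffling indices: random.shuffle shuffles its argument in place and returns None, while random.sample(range(n), n) returns a new shuffled list. A small sketch of the difference:

```python
import random

random.seed(0)  # only to make this sketch reproducible
n = 5

# random.shuffle mutates its argument and returns None,
# so its result cannot be assigned.
idx = list(range(n))
result = random.shuffle(idx)
assert result is None

# random.sample returns a new shuffled list containing every index once.
itr_style = random.sample(range(n), n)
assert sorted(itr_style) == list(range(n))
```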
This approach is probably not the most efficient implementation, and I also need to make sure that:
- The two images don't come from the same style
- The dataset re-shuffles every epoch
So what is the best way to implement this Dataset?
Solution
I understand you want to build pairs of two images that come from different groups.
Assuming your images are already grouped, you can precompute every cross-group combination of image indices in __init__ and load the actual images in __getitem__.
    from typing import List

    from torch.utils.data import Dataset


    class Image:
        """Placeholder class - swap in whatever tensor/image type you use"""
        pass


    class PreloadedDataset(Dataset):
        def __init__(self, img_groups: List[List[Image]]):
            super(PreloadedDataset, self).__init__()
            self.groups = img_groups
            self.combinations = []
            # Precompute every cross-group pair of indices. Starting the
            # inner loop at group_idx1 + 1 guarantees the two images never
            # come from the same group (style), and enumerate(..., start=...)
            # keeps group_idx2 aligned with its position in self.groups.
            for group_idx1, group1 in enumerate(img_groups):
                for group_idx2, group2 in enumerate(img_groups[group_idx1 + 1:], start=group_idx1 + 1):
                    for img1 in range(len(group1)):
                        for img2 in range(len(group2)):
                            self.combinations.append((group_idx1, img1, group_idx2, img2))

        def __len__(self):
            return len(self.combinations)

        def __getitem__(self, item):
            group1, img1, group2, img2 = self.combinations[item]
            return self.groups[group1][img1], self.groups[group2][img2]
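To sanity-check the pairing logic without torch, the same index enumeration can be run on small integer lists standing in for images (a standalone sketch, not the class above):

```python
from itertools import combinations

# Two placeholder style groups of sizes 2 and 3.
groups = [[1, 2], [10, 20, 30]]

# Mirror the nested loops: every cross-group pair of image indices.
pairs = []
for g1, g2 in combinations(range(len(groups)), 2):
    for i1 in range(len(groups[g1])):
        for i2 in range(len(groups[g2])):
            pairs.append((g1, i1, g2, i2))

print(len(pairs))  # 2 * 3 = 6 pairs
assert all(g1 != g2 for g1, _, g2, _ in pairs)  # never the same style
```

As for re-shuffling every epoch: since the dataset exposes a fixed list of pairs, wrapping it in a DataLoader with shuffle=True reorders the pairs at the start of each epoch, so no extra shuffling logic is needed inside the Dataset itself.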
Answered By - minolee