Issue
I have a dataloader that returns a batch of shape torch.Size([bs, c, h, w]), where bs=4, c=1, and h = w = 128. Now I want to apply some custom transformations to the returned batch. Note that I cannot apply the transformations inside the DataLoader, as I need to feed the returned batch as-is to one network and a transformed version to another network.
More specifically, I want to apply the following transformations to the returned batch:
1. CenterCrop(100)
2. FiveCrop(64)
3. Resize(128)
4. ToTensor()
5. Normalize([0.5], [0.5])
I have created a function to achieve this as follows:
import torch
from torchvision import transforms
import torchvision.transforms.functional as TF

def get_patches(orig_img):
    # orig_img.shape = torch.Size([4, 1, 128, 128])
    images = [TF.to_pil_image(x) for x in orig_img.cpu()]
    resized_imgs = []
    for img in images:
        img = transforms.CenterCrop(100)(img)
        five_crop = transforms.FiveCrop(64)(img)
        f_crops = transforms.Lambda(lambda crops: torch.stack(
            [transforms.Normalize([0.5], [0.5])(transforms.ToTensor()(transforms.Resize(128)(crop)))
             for crop in crops]))(five_crop)
        resized_imgs.append(f_crops)
    return resized_imgs

# DataLoader code
# ...
orig_img = next(iter(dataloader))
patches = get_patches(orig_img)
The problem is that every tensor in the resulting resized_imgs list loses the batch-size dimension, i.e. resized_imgs[0].shape = torch.Size([ncrops, c, h, w]) (4-d), whereas I expect the shape to be torch.Size([bs, ncrops, c, h, w]) (5-d).
Solution
Your data loader returns a tensor of shape (bs, c, h, w). Therefore orig_img is shaped the same way, and iterating over it yields a tensor img of shape (c, h, w). Applying FiveCrop produces a tuple of five crops, which effectively adds a dimension of size 5, so after resizing, converting to tensors, normalizing, and stacking, f_crops is shaped (5, c, 128, 128). Finally, this tensor is appended to resized_imgs, the list collecting the patched images. All in all, resized_imgs contains bs elements (since orig_img.size(0) = bs), and each element is a tensor shaped (5, c, 128, 128): five patches per image, as described above.
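The shape bookkeeping above can be verified without torchvision, using plain tensor slicing. The sketch below mimics what FiveCrop does on a single (c, h, w) image (the helper name five_crop_like is hypothetical, purely for illustration):

```python
import torch

def five_crop_like(img, size):
    # img: (c, h, w); returns the four corner crops plus the center crop,
    # each of shape (c, size, size), mimicking torchvision's FiveCrop
    c, h, w = img.shape
    tl = img[:, :size, :size]
    tr = img[:, :size, w - size:]
    bl = img[:, h - size:, :size]
    br = img[:, h - size:, w - size:]
    i, j = (h - size) // 2, (w - size) // 2
    center = img[:, i:i + size, j:j + size]
    return [tl, tr, bl, br, center]

img = torch.randn(1, 100, 100)   # one image after CenterCrop(100)
crops = five_crop_like(img, 64)
stacked = torch.stack(crops)     # the extra dimension: (5, c, 64, 64)
print(stacked.shape)             # torch.Size([5, 1, 64, 64])
```

Stacking the five crops is exactly where the extra leading dimension comes from; stacking once more over the batch restores the bs dimension the question asks about.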
Another way of writing this function would be:
import torch
import torchvision.transforms as T

def get_patches(orig_img):
    # orig_img.shape = (4, 1, 128, 128)
    img_t = T.Compose([T.ToPILImage(),
                       T.CenterCrop(100),
                       T.FiveCrop(64)])
    patch_t = T.Compose([T.Resize(128),
                         T.ToTensor(),
                         T.Normalize([0.5], [0.5])])
    resized_imgs = []
    for img in orig_img:
        five_crop = img_t(img)
        f_crops = torch.stack(list(map(patch_t, five_crop)))
        resized_imgs.append(f_crops)
    return torch.stack(resized_imgs)
The last line stacks all image patches into a single tensor of shape (bs, 5, c, 128, 128).
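If the downstream network expects the usual 4-d input, the 5-d output can be flattened and restored with view. This is a minimal sketch assuming a tensor of the shape produced above (the random tensor stands in for get_patches output):

```python
import torch

bs, ncrops, c, h, w = 4, 5, 1, 128, 128
patches = torch.randn(bs, ncrops, c, h, w)  # stands in for get_patches(orig_img)

# merge the batch and crop dimensions so a standard conv net can consume it
flat = patches.view(-1, c, h, w)            # (bs * ncrops, c, h, w)
out = flat                                  # network(flat) would go here
restored = out.view(bs, ncrops, c, h, w)    # back to the 5-d layout
print(flat.shape, restored.shape)
```

This view/view round trip is the standard idiom the torchvision FiveCrop documentation suggests for handling the extra crops dimension.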
Answered By - Ivan