Issue
I am writing a simple transformation for a dataset which contains many pairs of images. As data augmentation, I want to apply a random transformation to each pair, but both images in a pair should be transformed in the same way.
For example, given a pair of two images A and B, if A is flipped horizontally, B must be flipped horizontally as well. The next pair, C and D, may be transformed differently from A and B, but C and D must be transformed in the same way as each other. I am trying to do that as follows:
import random
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
img_a = Image.open("sample_a.jpg") # note that two images have the same size
img_b = Image.open("sample_b.png")
img_c, img_d = Image.open("sample_c.jpg"), Image.open("sample_d.png")
transform = transforms.RandomChoice(
    [transforms.RandomHorizontalFlip(),
     transforms.RandomVerticalFlip()]
)
random.seed(0)
display(transform(img_a))
display(transform(img_b))
random.seed(1)
display(transform(img_c))
display(transform(img_d))
Yet the above code does not choose the same transformation for both images and, as I tested, the result depends on the number of times transform is called.
Is there any way to force transforms.RandomChoice to use the same transform when specified?
Solution
Usually a workaround is to apply the transform to the first image, retrieve the parameters of that transform, then apply a deterministic transform with those parameters to the remaining images. However, here RandomChoice does not provide an API to get the parameters of the applied transform since it involves a variable number of transforms.
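To make the "sample once, then apply deterministically" idea concrete, here is a minimal sketch. It uses toy images (nested lists) and hand-written hflip/vflip helpers as stand-ins for the functional flips in torchvision.transforms.functional. The names sample_params and apply_params are hypothetical, not part of torchvision.

```python
import random

def hflip(img):
    """Toy horizontal flip: reverse each row (stand-in for a real hflip)."""
    return [row[::-1] for row in img]

def vflip(img):
    """Toy vertical flip: reverse the row order (stand-in for a real vflip)."""
    return img[::-1]

def sample_params(rng=random):
    """Draw every random decision once per pair (hypothetical helper)."""
    return {
        "op": rng.choice(["hflip", "vflip"]),  # which flip was chosen
        "apply": rng.random() < 0.5,           # whether the flip fires at all
    }

def apply_params(img, params):
    """Deterministically replay previously sampled decisions on one image."""
    if not params["apply"]:
        return img
    return hflip(img) if params["op"] == "hflip" else vflip(img)

img_a = [[1, 2], [3, 4]]
img_b = [[5, 6], [7, 8]]

params = sample_params()            # sampled once for the pair
out_a = apply_params(img_a, params)
out_b = apply_params(img_b, params)  # same decisions as for img_a
```

Because all the randomness lives in sample_params, calling it again for the next pair gives that pair its own, independent decision.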
When a transform offers no such API, I usually override the original function.
Looking at the torchvision implementation, it's as simple as:
class RandomChoice(RandomTransforms):
    def __call__(self, img):
        t = random.choice(self.transforms)
        return t(img)
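Since random.choice runs on every call, each call consumes the next value from Python's random stream, which is probably why seeding once before two calls does not give matching choices. A small standalone sketch of that behaviour, using only the standard library (the option names are placeholders):

```python
import random

options = ["hflip", "vflip"]

# Seed once: the two draws come from different points in the stream,
# so nothing guarantees that they match.
random.seed(0)
first = random.choice(options)
second = random.choice(options)

# Reseed before each draw: both draws replay the same point in the stream.
random.seed(0)
replay_a = random.choice(options)
random.seed(0)
replay_b = random.choice(options)
```

Here replay_a and replay_b always agree, while first and second may or may not. Note this only covers Python's random module. Randomness inside the chosen transform may come from a different generator, so reseeding alone is a fragile fix.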
Here are two possible solutions.
You can either sample from the transform list in __init__ instead of in __call__:
import random
import torch
import torchvision.transforms as T

class RandomChoice(torch.nn.Module):
    def __init__(self, transforms):
        super().__init__()
        # Pick one transform at construction time and reuse it on every call.
        self.t = random.choice(transforms)

    def __call__(self, img):
        return self.t(img)
So you can do:
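# Use the custom RandomChoice defined above, not T.RandomChoice.
# p=1.0 makes the chosen flip always fire; with the default p=0.5
# each call would decide independently whether to flip.
transform = RandomChoice([
    T.RandomHorizontalFlip(p=1.0),
    T.RandomVerticalFlip(p=1.0),
])
display(transform(img_a))  # both img_a and img_b will
display(transform(img_b))  # have the same transform

# A new instance samples again, so the next pair may get a different flip.
transform = RandomChoice([
    T.RandomHorizontalFlip(p=1.0),
    T.RandomVerticalFlip(p=1.0),
])
display(transform(img_c))  # both img_c and img_d will
display(transform(img_d))  # have the same transform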
Or better yet, transform the images in batch:
import random
import torch
import torchvision.transforms as T

class RandomChoice(torch.nn.Module):
    def __init__(self, transforms):
        super().__init__()
        self.transforms = transforms

    def __call__(self, imgs):
        # Sample one transform per call and apply it to every image in the batch.
        t = random.choice(self.transforms)
        return [t(img) for img in imgs]
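Which allows you to do:
# Use the custom RandomChoice defined above, not T.RandomChoice.
# p=1.0 makes the chosen flip always fire; with the default p=0.5
# each image would decide independently whether to flip.
transform = RandomChoice([
    T.RandomHorizontalFlip(p=1.0),
    T.RandomVerticalFlip(p=1.0),
])

img_at, img_bt = transform([img_a, img_b])
display(img_at)  # both img_a and img_b will
display(img_bt)  # have the same transform

img_ct, img_dt = transform([img_c, img_d])
display(img_ct)  # both img_c and img_d will
display(img_dt)  # have the same transform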
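One caveat worth checking in your own setup: choosing the same transform object for both images does not by itself make the outcome identical if that transform makes its own random decision on each call, as a flip with probability below 1 does. The sketch below illustrates this with a hypothetical MaybeReverse class standing in for a probabilistic flip. It is not torchvision code.

```python
import random

class MaybeReverse:
    """Hypothetical stand-in for a flip that fires with probability p."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, row):
        # A fresh coin toss happens on every call.
        if random.random() < self.p:
            return row[::-1]
        return row

row = [1, 2, 3]

# Same object, p=0.5: repeated calls can disagree with each other.
maybe = MaybeReverse()
outcomes = {tuple(maybe(row)) for _ in range(200)}

# p=1.0 removes the inner randomness, so every call agrees.
always = MaybeReverse(p=1.0)
always_outcomes = {tuple(always(row)) for _ in range(200)}
```

With p=1.0 the only remaining randomness is which transform gets picked, and that is the part the custom RandomChoice shares across a pair.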
Answered By - Ivan