Issue
In PyTorch I have an RGB tensor imgA with a batch size of 256. I want to keep the green channel for the first 128 images in the batch and the red channel for the remaining 128, something like below:
imgA[:128,2,:,:] = imgA[:128,1,:,:]
imgA[128:,2,:,:] = imgA[128:,0,:,:]
imgA = imgA[:,2,:,:].unsqueeze(1)
or the same can be achieved like this:
imgA = torch.cat((imgA[:128,1,:,:].unsqueeze(1),imgA[128:,0,:,:].unsqueeze(1)),dim=0)
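Here is a quick sanity check that the two snippets above give the same result, sketched on a random tensor. The 32x32 spatial size and the names img, out_inplace and out_cat are assumptions for illustration only. The first snippet overwrites channel 2 of its input, so the sketch runs it on a clone.

```python
import torch

# Hypothetical random batch standing in for imgA: 256 RGB images of 32x32.
torch.manual_seed(0)
img = torch.rand(256, 3, 32, 32)

# Snippet 1 from the question: overwrite channel 2 in place, then keep only it.
# clone() leaves the original tensor untouched for the comparison.
a = img.clone()
a[:128, 2, :, :] = a[:128, 1, :, :]   # green for the first half
a[128:, 2, :, :] = a[128:, 0, :, :]   # red for the second half
out_inplace = a[:, 2, :, :].unsqueeze(1)

# Snippet 2 from the question: concatenate the two halves along the batch dim.
out_cat = torch.cat((img[:128, 1, :, :].unsqueeze(1),
                     img[128:, 0, :, :].unsqueeze(1)), dim=0)

print(out_inplace.shape)                  # torch.Size([256, 1, 32, 32])
print(torch.equal(out_inplace, out_cat))  # True
```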
But as I have multiple such tensors (imgA, imgB, imgC, etc.), what is the fastest way to achieve this?
Solution
An indexing-based solution can be achieved using torch.gather and repeat_interleave:
# channel 1 (green) for the first 128 images, channel 0 (red) for the last 128
select = torch.tensor([1, 0], device=imgA.device)
imgA = imgA.gather(dim=1, index=select.repeat_interleave(128, dim=0).view(256, 1, 1, 1).expand(-1, -1, *imgA.shape[-2:]))
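A minimal sketch of the gather approach applied to several images, assuming each one has shape (256, 3, H, W) with an even batch size. make_channel_index is a hypothetical helper, not part of PyTorch, and the random 32x32 inputs are stand-ins. Because the index tensor depends only on the shape, it can be built once and reused for imgA, imgB and imgC, so each image then costs a single gather call.

```python
import torch

def make_channel_index(batch, height, width, channels=(1, 0), device=None):
    """Hypothetical helper: index of shape (batch, 1, height, width) that selects
    channels[0] for the first half of the batch and channels[1] for the second.
    Assumes an even batch size."""
    select = torch.tensor(channels, device=device)        # int64, as gather requires
    return (select.repeat_interleave(batch // 2, dim=0)    # (batch,)
                  .view(batch, 1, 1, 1)                    # (batch, 1, 1, 1)
                  .expand(-1, -1, height, width))          # (batch, 1, H, W), a view

torch.manual_seed(0)
imgA, imgB, imgC = (torch.rand(256, 3, 32, 32) for _ in range(3))

# The index depends only on the shape, so build it once and reuse it.
index = make_channel_index(256, 32, 32, device=imgA.device)
outA, outB, outC = (x.gather(dim=1, index=index) for x in (imgA, imgB, imgC))

# Reference result using the torch.cat version from the question.
refA = torch.cat((imgA[:128, 1].unsqueeze(1), imgA[128:, 0].unsqueeze(1)), dim=0)
print(outA.shape)               # torch.Size([256, 1, 32, 32])
print(torch.equal(outA, refA))  # True
```

Since expand returns a view, the full-size index is never copied in memory. If all images share a shape, another option (not part of the original answer) is to torch.stack them and gather once along the channel dimension; whether that beats the loop, given the extra copy stack makes, would need to be measured.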
You can also do it with matrix multiplication (via torch.einsum) and repeat_interleave:
# select c=1 for first half and c=0 for second
select = torch.tensor([[0, 1],[1, 0],[0, 0]], dtype=imgA.dtype, device=imgA.device)
imgA = torch.einsum('cb,bchw->bhw',select.repeat_interleave(128, dim=1), imgA).unsqueeze(dim=1)
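The einsum variant can be checked against the torch.cat result from the question in the same way. As before, the random input and its 32x32 size are only assumptions for the sketch, and weights is just an illustrative name for the repeated select matrix.

```python
import torch

torch.manual_seed(0)
imgA = torch.rand(256, 3, 32, 32)

# Column 0 of `select` picks channel 1 (green), column 1 picks channel 0 (red).
select = torch.tensor([[0, 1], [1, 0], [0, 0]], dtype=imgA.dtype, device=imgA.device)
weights = select.repeat_interleave(128, dim=1)   # (3, 256): one one-hot column per image

# For each image b: out[b] = sum_c weights[c, b] * imgA[b, c]
out_einsum = torch.einsum('cb,bchw->bhw', weights, imgA).unsqueeze(dim=1)

ref = torch.cat((imgA[:128, 1].unsqueeze(1), imgA[128:, 0].unsqueeze(1)), dim=0)
print(out_einsum.shape)                 # torch.Size([256, 1, 32, 32])
print(torch.allclose(out_einsum, ref))  # True
```

Unlike gather, which only copies the selected values, this version computes a weighted sum over all three channels. It therefore does more arithmetic per pixel, and a NaN or Inf in an unselected channel would leak into the output, since 0 * inf is NaN.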
Answered By - Shai