Issue
I want to use BoolTensor indices to slice a multidimensional tensor in PyTorch. I expect that, in the indexed tensor, the parts where the indices are True are kept, while the parts where the indices are False are sliced out.
My code is:
import torch
a = torch.zeros((5, 50, 5, 50))
tr_indices = torch.zeros((50), dtype=torch.bool)
tr_indices[1:50:2] = 1
val_indices = ~tr_indices
print(a[:, tr_indices].shape)
print(a[:, tr_indices, :, val_indices].shape)
I expect a[:, tr_indices, :, val_indices] to be of shape [5, 25, 5, 25]; however, it returns [25, 5, 5]. The output is:
torch.Size([5, 25, 5, 50])
torch.Size([25, 5, 5])
I'm very confused. Can anyone explain why?
Solution
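PyTorch inherits its advanced indexing behaviour from NumPy. A boolean mask used as an index behaves like the integer positions of its True entries, and when one index expression contains several such advanced indices, they are broadcast together and paired element by element instead of being applied to each dimension independently. Here both masks have 25 True entries, so a[:, tr_indices, :, val_indices] picks the 25 pairs (1, 0), (3, 2), ..., (49, 48) from dimensions 1 and 3. Because the two advanced indices are separated by a slice, the NumPy rules place the resulting dimension of size 25 first, followed by the two sliced dimensions, which gives [25, 5, 5]. Slicing twice, so that each mask is applied on its own, should achieve your desired output: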
a[:, tr_indices][..., val_indices]
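The snippet below is a minimal sketch, assuming a recent PyTorch version, that checks this explanation on the question's own setup. It rebuilds the combined result by pairing the True positions of the two masks by hand, then confirms the shape of the two-step version. The names tr_pos, val_pos, paired, two_step and alt are illustrative only, and the index_select variant at the end is an alternative not mentioned in the original answer.

```python
import torch

# Same setup as in the question (arange instead of zeros so the comparisons are meaningful).
a = torch.arange(5 * 50 * 5 * 50, dtype=torch.float32).reshape(5, 50, 5, 50)
tr_indices = torch.zeros(50, dtype=torch.bool)
tr_indices[1:50:2] = True          # 25 True entries at odd positions
val_indices = ~tr_indices          # 25 True entries at even positions

# A boolean mask acts like the integer positions of its True entries.
tr_pos = tr_indices.nonzero(as_tuple=True)[0]    # 1, 3, ..., 49
val_pos = val_indices.nonzero(as_tuple=True)[0]  # 0, 2, ..., 48

# Combined index: the two position lists are paired element by element, and
# since a slice separates them, the paired dimension is moved to the front.
combined = a[:, tr_indices, :, val_indices]
paired = torch.stack(
    [a[:, i, :, j] for i, j in zip(tr_pos.tolist(), val_pos.tolist())]
)
print(combined.shape)                 # torch.Size([25, 5, 5])
print(torch.equal(combined, paired))  # True

# Two-step indexing: each mask selects along its own dimension independently.
two_step = a[:, tr_indices][..., val_indices]
print(two_step.shape)                 # torch.Size([5, 25, 5, 25])

# Alternative (not from the original answer): index_select on each dimension.
alt = a.index_select(1, tr_pos).index_select(3, val_pos)
print(torch.equal(two_step, alt))     # True
```

The index_select form makes the "one dimension at a time" intent explicit, at the cost of converting the masks to integer positions first.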
Answered By - iacob