Issue
For a given 2D tensor, I want to retrieve all indices where the value is 1. I expected to be able to simply use torch.nonzero(a == 1).squeeze(), which would return tensor([1, 3, 2]). However, torch.nonzero(a == 1) instead returns a 2D tensor (that's okay) with two values per row (that's not what I expected). The returned indices should then be used to index the second dimension (index 1) of a 3D tensor, again returning a 2D tensor.
import torch
a = torch.Tensor([[12, 1, 0, 0],
[4, 9, 21, 1],
[10, 2, 1, 0]])
b = torch.rand(3, 4, 8)
print('a_size', a.size())
# a_size torch.Size([3, 4])
print('b_size', b.size())
# b_size torch.Size([3, 4, 8])
idxs = torch.nonzero(a == 1)
print('idxs_size', idxs.size())
# idxs_size torch.Size([3, 2])
print(b.gather(1, idxs))
Evidently, this does not work, leading to a RuntimeError:
RuntimeError: invalid argument 4: Index tensor must have same dimensions as input tensor at C:\w\1\s\windows\pytorch\aten\src\TH/generic/THTensorEvenMoreMath.cpp:453
It seems that idxs is not what I expect it to be, nor can I use it the way I thought. idxs is
tensor([[0, 1],
[1, 3],
[2, 2]])
but reading through the documentation, I don't understand why I also get back the row indices in the resulting tensor. Now, I know I can get the correct idxs by slicing idxs[:, 1], but I still cannot use those values as indices for the 3D tensor because the same error as before is raised. Is it possible to use the 1D tensor of indices to select items across a given dimension?
Solution
You could simply split the result into its row-index and column-index columns and pass both as indices:
In [193]: idxs = torch.nonzero(a == 1)
In [194]: c = b[idxs[:, 0], idxs[:, 1]]
In [195]: c
Out[195]:
tensor([[0.3411, 0.3944, 0.8108, 0.3986, 0.3917, 0.1176, 0.6252, 0.4885],
[0.5698, 0.3140, 0.6525, 0.7724, 0.3751, 0.3376, 0.5425, 0.1062],
[0.7780, 0.4572, 0.5645, 0.5759, 0.5957, 0.2750, 0.6429, 0.1029]])
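The slicing approach can be checked with a small self-contained script. This is a sketch, not part of the original answer: it swaps torch.rand for a deterministic torch.arange tensor so the result can be compared against an explicit loop, and the names rows, cols, expected, gather_idx and c_gather are made up for illustration. It also shows one way the original gather attempt could be made to work, assuming there is exactly one match per row.

```python
import torch

a = torch.tensor([[12, 1, 0, 0],
                  [4, 9, 21, 1],
                  [10, 2, 1, 0]])
# Deterministic stand-in for torch.rand(3, 4, 8) so the result can be verified
b = torch.arange(3 * 4 * 8, dtype=torch.float32).reshape(3, 4, 8)

idxs = torch.nonzero(a == 1)          # shape (3, 2): one (row, col) pair per match
rows, cols = idxs[:, 0], idxs[:, 1]   # row indices and column indices

c = b[rows, cols]                     # picks b[row, col, :] for every pair -> shape (3, 8)

# Reference result: select b[row, col] one pair at a time
expected = torch.stack([b[r, col] for r, col in idxs.tolist()])
print(torch.equal(c, expected))

# gather needs an index with the same number of dimensions as b, so the
# 1D column indices are expanded to shape (3, 1, 8) first.
# Assumption: every row of a contains exactly one value equal to 1.
gather_idx = cols.view(-1, 1, 1).expand(-1, 1, b.size(2))
c_gather = b.gather(1, gather_idx).squeeze(1)   # shape (3, 8)
print(torch.equal(c_gather, c))
```

The row indices are needed because each column index only makes sense together with the row it was found in; indexing with both tensors pairs them up element by element.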
Alternatively, an even simpler approach, and the one I prefer, is to use torch.where() and then index directly into the tensor b:
In [196]: b[torch.where(a == 1)]
Out[196]:
tensor([[0.3411, 0.3944, 0.8108, 0.3986, 0.3917, 0.1176, 0.6252, 0.4885],
[0.5698, 0.3140, 0.6525, 0.7724, 0.3751, 0.3376, 0.5425, 0.1062],
[0.7780, 0.4572, 0.5645, 0.5759, 0.5957, 0.2750, 0.6429, 0.1029]])
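As a rough check of what torch.where() hands back here, the sketch below compares it with torch.nonzero(..., as_tuple=True) and with indexing by the boolean mask directly. The deterministic b and the variable names (mask, where_idx, via_where and so on) are illustrative assumptions, not part of the original answer.

```python
import torch

a = torch.tensor([[12, 1, 0, 0],
                  [4, 9, 21, 1],
                  [10, 2, 1, 0]])
# Deterministic stand-in for torch.rand(3, 4, 8)
b = torch.arange(3 * 4 * 8, dtype=torch.float32).reshape(3, 4, 8)

mask = a == 1

# With only a condition, torch.where returns a tuple of index tensors,
# one per dimension of the condition: (rows, cols)
where_idx = torch.where(mask)
nonzero_idx = torch.nonzero(mask, as_tuple=True)

via_where = b[where_idx]       # shape (3, 8)
via_nonzero = b[nonzero_idx]   # shape (3, 8)
via_mask = b[mask]             # boolean mask over the first two dims -> shape (3, 8)

print(where_idx)
print(torch.equal(via_where, via_nonzero), torch.equal(via_where, via_mask))
```

All three forms select the same rows of b; which one to use is mostly a matter of taste, although the index tuple is handy when the positions themselves are needed later.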
A bit more explanation of the torch.where() approach: it relies on advanced indexing, which is what happens when we index into a tensor using a tuple of sequence objects, such as a tuple of tensors, a tuple of lists, or a tuple of tuples. Called with only a condition, torch.where() returns exactly such a tuple, with one index tensor per dimension.
# some input tensor
In [207]: a
Out[207]:
tensor([[12., 1., 0., 0.],
[ 4., 9., 21., 1.],
[10., 2., 1., 0.]])
For basic indexing, we would need a tuple of integer indices:
In [212]: a[(1, 2)]
Out[212]: tensor(21.)
To achieve the same using advanced indexing, we would need a tuple of sequence objects:
# adv. indexing using a tuple of lists
In [213]: a[([1,], [2,])]
Out[213]: tensor([21.])
# adv. indexing using a tuple of tuples
In [215]: a[((1,), (2,))]
Out[215]: tensor([21.])
# adv. indexing using a tuple of tensors
In [214]: a[(torch.tensor([1,]), torch.tensor([2,]))]
Out[214]: tensor([21.])
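In these examples, the returned tensor has one dimension fewer than the input: the two indexed dimensions of the 2D tensor are replaced by a single dimension with one entry per (row, column) pair. For the same reason, indexing the first two dimensions of the 3D tensor b this way returns a 2D tensor.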
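The effect on the output shape can be seen in a short sketch. The tensors match the ones used above, except that b is replaced by a deterministic stand-in, and the variable names are chosen only for illustration. The last case goes slightly beyond the examples in the answer: with 2D index tensors, the indexed dimensions are replaced by the shape of the index tensors.

```python
import torch

a = torch.tensor([[12., 1., 0., 0.],
                  [4., 9., 21., 1.],
                  [10., 2., 1., 0.]])
# Deterministic stand-in for the 3D tensor b
b = torch.arange(3 * 4 * 8, dtype=torch.float32).reshape(3, 4, 8)

scalar = a[1, 2]                        # basic indexing: 0-dimensional tensor
one_pair = a[[1], [2]]                  # advanced indexing, one pair: shape (1,)
three_pairs = a[[0, 1, 2], [1, 3, 2]]   # three pairs: shape (3,)
rows_3d = b[[0, 1, 2], [1, 3, 2]]       # remaining dim is kept: shape (3, 8)

# 2D index tensors: the result takes the shape of the indices, here (2, 2)
grid = a[torch.tensor([[0, 1], [1, 2]]), torch.tensor([[1, 3], [3, 2]])]

for name, t in [("scalar", scalar), ("one_pair", one_pair),
                ("three_pairs", three_pairs), ("rows_3d", rows_3d), ("grid", grid)]:
    print(name, tuple(t.shape))
```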
Answered By - kmario23