Issue
Given a tensor IN of shape (A, B, C, D) and an index tensor IDX of shape (A, B, C) with torch.long values in [0, C), how can I get a tensor OUT of shape (A, B, C, D) such that:
OUT[a, b, c, :] == IN[a, b, IDX[a, b, c], :]
This is trivial without the dimensions A and B:
import torch

# C = 2, D = 3
IN = torch.arange(6).view(2, 3)
IDX = torch.tensor([0, 0])
print(IN[IDX])
# tensor([[0, 1, 2],
#         [0, 1, 2]])
Obviously, I can write a nested for loop over A and B. But surely there must be a vectorized way to do it?
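For reference, the nested loop over A and B could look like the sketch below. The function name gather_loop and the sample values are illustrative assumptions, not part of the question. Inside the loop each step is exactly the 2D case shown above, which makes this version a handy ground truth for checking a vectorized one.

```python
import torch

def gather_loop(IN: torch.Tensor, IDX: torch.Tensor) -> torch.Tensor:
    """Hypothetical slow reference: OUT[a, b, c, :] = IN[a, b, IDX[a, b, c], :]."""
    A, B, C, D = IN.shape
    OUT = torch.empty_like(IN)
    for a in range(A):
        for b in range(B):
            # IN[a, b] is a (C, D) slice; indexing it with IDX[a, b] picks rows,
            # exactly like the 2D example.
            OUT[a, b] = IN[a, b][IDX[a, b]]
    return OUT

# Illustrative data: A = 2, B = 2, C = 3, D = 2
IN = torch.arange(24).view(2, 2, 3, 2)
IDX = torch.tensor([[[2, 0, 0], [1, 1, 2]],
                    [[0, 1, 2], [2, 2, 2]]])
OUT = gather_loop(IN, IDX)
```

Here OUT[0, 0, 0] equals IN[0, 0, 2], i.e. [4, 5].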
Solution
This is the perfect use case for torch.gather. Given two 4D tensors, input (the tensor to read from) and index (the tensor containing the indices into input), calling torch.gather on dim=2 returns a tensor out with the same shape as index, such that:
out[i][j][k][l] = input[i][j][index[i][j][k][l]][l]
In other words, index selects positions along the third dimension (dim=2) of input.
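To make the rule concrete, here is a small hand-checkable sketch with made-up values. It shows that each column l picks its row independently, and that out takes the shape of index: index has only two entries along dim=2 here, so out does too.

```python
import itertools
import torch

inp = torch.tensor([[[[10, 11],
                      [20, 21],
                      [30, 31]]]])   # shape (1, 1, 3, 2)
index = torch.tensor([[[[2, 0],
                        [1, 1]]]])   # shape (1, 1, 2, 2), int64 by default

out = torch.gather(inp, dim=2, index=index)
print(out.shape)     # torch.Size([1, 1, 2, 2]), the shape of index
print(out.tolist())  # [[[[30, 11], [20, 21]]]]

# Check out[i][j][k][l] == inp[i][j][index[i][j][k][l]][l] at every position.
for i, j, k, l in itertools.product(*(range(s) for s in index.shape)):
    assert out[i, j, k, l].item() == inp[i, j, index[i, j, k, l].item(), l].item()
```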
Before applying this function, note that input and index must have the same number of dimensions. In the code below, x plays the role of IN and idx the role of IDX. Since idx is only 3D, we need to insert an additional 4th dimension and expand it to size D. We can do so with the following line:
>>> idx_ = idx[...,None].expand_as(x)
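It can help to look at that line step by step. The sketch below uses arbitrary sizes. It shows that idx[..., None] is the same as idx.unsqueeze(-1), and that expand_as returns a view with stride 0 along the new dimension, so the D copies of each index are not actually materialized in memory.

```python
import torch

A, B, C, D = 2, 3, 4, 5
x = torch.rand(A, B, C, D)
idx = torch.randint(0, C, (A, B, C))

step1 = idx[..., None]        # same as idx.unsqueeze(-1): shape (A, B, C, 1)
idx_ = step1.expand_as(x)     # shape (A, B, C, D), a view, no data copied

print(step1.shape)    # torch.Size([2, 3, 4, 1])
print(idx_.shape)     # torch.Size([2, 3, 4, 5])
print(idx_.stride())  # (12, 4, 1, 0): stride 0 along the expanded dimension

# Every position along the last dimension holds the same index as idx.
assert torch.equal(idx_[..., 0], idx)
assert torch.equal(step1, idx.unsqueeze(-1))
```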
Then call the torch.gather function:
>>> x.gather(dim=2, index=idx_)
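The two steps can be wrapped in a small helper, sketched below. The name select_along_c and the shape check are hypothetical additions, not part of the answer; the .long() call only guards against non-int64 indices, since torch.gather requires an int64 index. The result is compared against a plain loop over A and B.

```python
import torch

def select_along_c(x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: out[a, b, c, :] = x[a, b, idx[a, b, c], :]."""
    if idx.shape != x.shape[:3]:
        raise ValueError(
            f"idx shape {tuple(idx.shape)} must equal x.shape[:3] = {tuple(x.shape[:3])}"
        )
    # gather needs an int64 index with as many dimensions as x.
    idx_ = idx.long()[..., None].expand_as(x)
    return x.gather(dim=2, index=idx_)

A, B, C, D = 2, 3, 4, 5
x = torch.rand(A, B, C, D)
idx = torch.randint(0, C, (A, B, C))
out = select_along_c(x, idx)

# Reference: the nested loop over A and B, stacked back into (A, B, C, D).
expected = torch.stack(
    [torch.stack([x[a, b][idx[a, b]] for b in range(B)]) for a in range(A)]
)
assert torch.equal(out, expected)
```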
You can try out the solution with this code:
>>> import torch
>>> A = 1; B = 2; C = 3; D = 2
>>> x = torch.rand(A, B, C, D)
>>> x
tensor([[[[0.6490, 0.7670],
          [0.7847, 0.9058],
          [0.3606, 0.7843]],

         [[0.0666, 0.7306],
          [0.1923, 0.3513],
          [0.5287, 0.3680]]]])
>>> idx = torch.randint(0, C, (A, B, C))
>>> idx
tensor([[[1, 2, 2],
         [0, 0, 1]]])
>>> x.gather(dim=2, index=idx[..., None].expand_as(x))
tensor([[[[0.7847, 0.9058],
          [0.3606, 0.7843],
          [0.3606, 0.7843]],

         [[0.0666, 0.7306],
          [0.0666, 0.7306],
          [0.1923, 0.3513]]]])
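As a cross-check, the same selection can be expressed with plain advanced indexing by building index tensors for the A and B dimensions that broadcast against idx. This is an alternative technique, not part of the original answer; the sketch below simply confirms that both approaches agree element for element.

```python
import torch

A, B, C, D = 1, 2, 3, 2
x = torch.rand(A, B, C, D)
idx = torch.randint(0, C, (A, B, C))

# Index tensors for dims 0 and 1, shaped to broadcast with idx's (A, B, C).
a = torch.arange(A)[:, None, None]   # shape (A, 1, 1)
b = torch.arange(B)[None, :, None]   # shape (1, B, 1)

# Advanced indexing on the first three dims; the untouched last dim keeps all D values.
out_adv = x[a, b, idx]               # shape (A, B, C, D)

out_gather = x.gather(dim=2, index=idx[..., None].expand_as(x))
assert torch.equal(out_adv, out_gather)
```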
Answered By - Ivan