Issue
I am new to PyTorch and am still wrapping my head around how to form a proper gather statement. I have a 4D input tensor of size (1,200,61,1632), where 1632 is the time dimension. I want to index it with a tensor idx of size (4,1632), where each column of idx is the 4D position of a value I want to extract from the input tensor. So the columns of idx look like:
[0,20,30,0]
[0,150,9,1]
[0,180,100,2]
...
So that the output has size 1632. In other words, I want to do this:
output = []
for i in range(1632):
    output.append(input[idx[0,i], idx[1,i], idx[2,i], idx[3,i]])
Is this an appropriate use case for torch.gather? Looking at the documentation for gather, it says the input and index tensors must have the same shape.
Solution
Since PyTorch doesn't offer an implementation of ravel_multi_index, the ugly way of doing it is this one:
output = input[idx[0, :], idx[1, :], idx[2, :], idx[3, :]]
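To check that this one-liner matches the loop from the question, here is a minimal sketch. It assumes smaller stand-in shapes and random data (not the asker's real tensor), and the name inp stands in for the question's input. It relies on integer-tensor (advanced) indexing rather than torch.gather, so idx does not need the same shape as the input.

```python
import torch

torch.manual_seed(0)

# Smaller stand-in shape (assumption); the question uses (1, 200, 61, 1632).
shape = (1, 5, 6, 7)
n = shape[-1]
inp = torch.randn(shape)

# idx has one row per dimension of inp and one column per element to extract.
idx = torch.stack([torch.randint(0, s, (n,)) for s in shape])  # shape (4, n)

# Reference: the explicit loop from the question, stacked into a 1-D tensor.
looped = torch.stack(
    [inp[idx[0, i], idx[1, i], idx[2, i], idx[3, i]] for i in range(n)]
)

# Vectorized: pass one index tensor per dimension.
output = inp[idx[0, :], idx[1, :], idx[2, :], idx[3, :]]

# Equivalent shorthand: unpack the rows of idx into a tuple of index tensors.
output_tuple = inp[tuple(idx)]
```

Both output and output_tuple are 1-D tensors with one entry per column of idx.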
In NumPy, you could do it this way:
output = np.take(input, np.ravel_multi_index(idx, input.shape))
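The NumPy route converts each 4D position into a single offset into the flattened array. The sketch below shows that step on small random stand-in data (an assumption, not the asker's tensor), and then reproduces the same flat offsets by hand in PyTorch using row-major strides and torch.take. The hand-rolled stride computation is only an illustration, not a library function.

```python
import numpy as np
import torch

rng = np.random.default_rng(0)

# Smaller stand-in shape (assumption); the question uses (1, 200, 61, 1632).
shape = (1, 5, 6, 7)
n = shape[-1]
arr = rng.standard_normal(shape)
idx = np.stack([rng.integers(0, s, size=n) for s in shape])  # shape (4, n)

# NumPy: turn each column of idx into one offset into the flattened array.
flat = np.ravel_multi_index(idx, arr.shape)
out_np = np.take(arr, flat)

# PyTorch: build the same offsets by hand from row-major strides.
t = torch.from_numpy(arr)
tidx = torch.from_numpy(idx)
dims = torch.tensor(t.shape)
# Stride of a dimension = product of the sizes of all later dimensions.
strides = torch.cat([dims[1:].flip(0).cumprod(0).flip(0), torch.tensor([1])])
flat_t = (tidx * strides[:, None]).sum(dim=0)

# torch.take indexes the input as if it were flattened to 1-D.
out_t = torch.take(t, flat_t)
```

In practice the direct indexing form is shorter. The flat-offset form mainly shows what ravel_multi_index computes.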
Answered By - Berriel