Issue
I have two tensors, a and b, and I want to retrieve the values of b at the positions of the max values in a. That is,
max_values, indices = torch.max(a, dim=0, keepdim=True)
However, I do not know how to use the indices to retrieve the values of b. Can anybody help me solve this? Thanks a lot!!
Edit:
Sorry for not describing my problem concretely. To give a minimal example, the values of tensors a and b are:
a = torch.tensor([[1,2,4],[2,1,3]])
b = torch.tensor([[10,24,2],[23,4,5]])
If I use torch.max(a, dim=0, keepdim=True), it returns:
max: tensor([[2, 2, 4]])
indices: tensor([[1, 0, 0]])
What I want to obtain is the values of tensor b selected at the indices of the max values of a along dim=0, that is,
tensor([[23, 24, 2]])
I have tried b[indices], but the result is not what I want:
tensor([[[23,  4,  5],
         [10, 24,  2],
         [10, 24,  2]]])
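For reference, a minimal sketch reproducing the b[indices] attempt with the tensors above. Indexing a tensor with a single integer tensor selects whole rows along dim 0, so each entry of the (1, 3) index tensor is replaced by a full row of b, which gives shape (1, 3, 3) instead of picking one element per column. The variable name `rows` is only illustrative.

```python
import torch

a = torch.tensor([[1, 2, 4], [2, 1, 3]])
b = torch.tensor([[10, 24, 2], [23, 4, 5]])

# indices has shape (1, 3): tensor([[1, 0, 0]])
_, indices = torch.max(a, dim=0, keepdim=True)

# A single integer index tensor indexes dim 0 only, so every entry
# of `indices` pulls out an entire row of b (b[1], b[0], b[0]).
rows = b[indices]
print(rows.shape)     # torch.Size([1, 3, 3])
print(rows.tolist())  # [[[23, 4, 5], [10, 24, 2], [10, 24, 2]]]
```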
Solution
You can use torch.gather:
torch.gather(b, dim=0, index=indices)
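A self-contained sketch of the answer above, using the tensors from the question. keepdim=True matters here, because torch.gather expects index to have the same number of dimensions as the input; along dim=0 it picks selected[0][j] = b[indices[0][j]][j]. The advanced-indexing line at the end is an equivalent alternative added for comparison (not part of the original answer); it pairs each row index with its column index and returns a 1-D result.

```python
import torch

a = torch.tensor([[1, 2, 4], [2, 1, 3]])
b = torch.tensor([[10, 24, 2], [23, 4, 5]])

# keepdim=True keeps indices 2-D with shape (1, 3), matching b's rank,
# which torch.gather requires.
max_values, indices = torch.max(a, dim=0, keepdim=True)

# Along dim=0: selected[0][j] = b[indices[0][j]][j]
selected = torch.gather(b, dim=0, index=indices)
print(selected.tolist())  # [[23, 24, 2]]

# Alternative (for comparison): advanced indexing with explicit
# (row, column) pairs; the row dimension is consumed, so the result is 1-D.
cols = torch.arange(b.size(1))
selected_alt = b[indices.squeeze(0), cols]
print(selected_alt.tolist())  # [23, 24, 2]
```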
Answered By - GoodDeeds