Issue
I want to select one specific row from each 2D matrix in a batch (a 3D tensor), given a source tensor and an index tensor that holds the row to take from each matrix.
Input:
a = torch.FloatTensor([[[1,1,1],[2,2,2]],[[9,9,9],[5,5,5]]])
b = torch.IntTensor([1,0])
Any solution?
Expected result: [[2,2,2],[9,9,9]]
Solution
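Use advanced indexing. Pair each position along the first (batch) dimension with the matching row index from b, so row i of the result is a[i, b[i]]: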
out = a[torch.arange(a.size(0)), b]
Answered By - Yanan Wang
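A minimal runnable check of the answer, using the data from the question. One assumption: the index tensor is built with torch.tensor, which gives int64 ("long"), PyTorch's standard index dtype, instead of torch.IntTensor. The torch.gather variant is an alternative added here for comparison and is not part of the original answer.

```python
import torch

# Data from the question: a batch of two 2x3 matrices, shape [2, 2, 3]
a = torch.tensor([[[1., 1., 1.], [2., 2., 2.]],
                  [[9., 9., 9.], [5., 5., 5.]]])
# Row to pick from each matrix (int64 by default)
b = torch.tensor([1, 0])

# Advanced indexing: result[i] = a[i, b[i]]
out = a[torch.arange(a.size(0)), b]
print(out.tolist())  # [[2.0, 2.0, 2.0], [9.0, 9.0, 9.0]]

# Alternative with torch.gather: reshape b to [batch, 1, 1] and expand it
# to [batch, 1, cols] so each batch gathers one full row along dim 1,
# then squeeze away the singleton row dimension
idx = b.view(-1, 1, 1).expand(-1, 1, a.size(2))
out_gather = torch.gather(a, 1, idx).squeeze(1)
print(torch.equal(out, out_gather))  # True
```

The indexing form is shorter. The gather form can be handy when the selection already lives inside a larger gather-based pipeline.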