Issue
I have a PyTorch tensor A of size (n, m) and a list of indices of size n, such that 0 <= indices[i] < m for every entry. For each row i of A, I want to negate one element, i.e. A[i, indices[i]] *= -1, in a vectorized way. Is there an easy way to do this?
A = torch.tensor([[1,2,3],[4,5,6]])
indices = torch.tensor([1, 2])
#desired result
A = [[1,-2,3],[4,5,-6]]
Solution
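Yes. Advanced (fancy) indexing is the way to go: pass one sequence of row indices and one sequence of column indices, so that row i is paired with column indices[i]: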
import torch
A = torch.tensor([[1, 2, 3], [4, 5, 6]])
indices = torch.tensor([1, 2]).long()
A[range(A.shape[0]), indices] *= -1
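Below is a minimal, self-contained sketch of the same advanced-indexing idea, wrapped in a helper that returns a new tensor instead of modifying A in place. The name negate_selected is hypothetical (not part of PyTorch), and torch.arange is used here in place of range(...) so the row indices are created as a tensor on the same device as A.

```python
import torch

def negate_selected(A: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: return a copy of A with A[i, indices[i]] negated for every row i."""
    out = A.clone()  # work on a copy so the caller's tensor is left unchanged
    rows = torch.arange(out.shape[0], device=out.device)  # row index i for each entry
    # Pair row i with column indices[i] and negate all selected elements in one step
    out[rows, indices] *= -1
    return out

A = torch.tensor([[1, 2, 3], [4, 5, 6]])
indices = torch.tensor([1, 2])
result = negate_selected(A, indices)
print(result.tolist())  # [[1, -2, 3], [4, 5, -6]]
print(A.tolist())       # [[1, 2, 3], [4, 5, 6]] (original untouched)
```

Cloning first costs one extra copy of A; drop the clone if in-place modification is what you want, which is exactly what the one-liner above does.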
Remember that indices must be of type torch.LongTensor. If your indices are floats, you can cast them with the .long() member function.
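As a sketch of the cast described above, assuming the indices ended up stored as a float tensor (the name float_indices is only illustrative): floating-point tensors are rejected when used as indices, so they are converted to int64 with .long() before indexing.

```python
import torch

A = torch.tensor([[1, 2, 3], [4, 5, 6]])
float_indices = torch.tensor([1.0, 2.0])  # illustrative: indices that happen to be floats

# Floating-point tensors cannot be used as indices, so cast to int64 (torch.long) first
indices = float_indices.long()
print(indices.dtype)  # torch.int64

A[range(A.shape[0]), indices] *= -1
print(A.tolist())  # [[1, -2, 3], [4, 5, -6]]
```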
Answered By - Szymon Maszke