Issue
I have to admit, I'm a bit confused by the scatter* and index* operations - I'm not sure any of them do exactly what I'm looking for, which is very simple:
Given some 2-D tensor
z = tensor([[1., 1., 1., 1.],
            [1., 1., 1., 1.],
            [1., 1., 1., 1.]])
And a list (or tensor?) of 2-d indexes:
inds = tensor([[0, 0],
               [1, 1],
               [1, 2]])
I want to add a scalar to z at those indexes (and do it efficiently):
znew = z.something_add(inds, 3)
->
znew = tensor([[4., 1., 1., 1.],
               [1., 4., 4., 1.],
               [1., 1., 1., 1.]])
If I have to, I can make that scalar a tensor of whatever shape (with all elements equal to 3), but I'd rather not...
Solution
Index the tensor with two lists: the first holds the row positions and the second holds the column positions. In your example, that would be:
z[[0, 1, 1], [0, 1, 2]] += 3
torch.Tensor indexing follows NumPy. See https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#integer-array-indexing for more details.
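If the indices already live in an (N, 2) tensor like inds in the question, one way to get the two index lists is to slice it by column. The following is only a minimal sketch and not part of the original answer: the names rows, cols, dup, and zacc are made up for illustration. It also assumes you want a new tensor (znew) rather than an in-place update, so it clones z first. The last part covers a case the question does not mention, repeated index pairs, where += applies the addition only once per position while index_put_ with accumulate=True adds once per occurrence.

```python
import torch

z = torch.ones(3, 4)
inds = torch.tensor([[0, 0],
                     [1, 1],
                     [1, 2]])

# Split the (N, 2) index tensor into a row-index and a column-index tensor.
rows, cols = inds[:, 0], inds[:, 1]

# Clone first so the original z is left untouched, then add the scalar
# at the selected positions; the scalar broadcasts, no tensor of 3s needed.
znew = z.clone()
znew[rows, cols] += 3

# Illustrative extra case: the same (row, col) pair listed twice.
# With `+=` that position would only receive one addition;
# index_put_ with accumulate=True sums every occurrence instead.
dup = torch.tensor([[0, 0],
                    [0, 0]])
zacc = torch.ones(3, 4)
zacc.index_put_((dup[:, 0], dup[:, 1]), torch.tensor(3.0), accumulate=True)
```

Here znew matches the tensor asked for in the question, and zacc[0, 0] ends up as 1 + 3 + 3. If your indices are guaranteed to be unique, plain indexing with += is the simpler choice.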
Answered By - Fábio Perez