Issue
I have a 2D PyTorch tensor of shape (n, m). I want to index the second dimension using a list of indices (which can be done with torch.gather) and then also assign new values to the selected elements.
Example:
data = torch.tensor([[0,1,2], [3,4,5], [6,7,8]]) # shape (3,3)
indices = torch.tensor([1,2,1], dtype=torch.long).unsqueeze(-1) # shape (3,1)
# data tensor:
# tensor([[0, 1, 2],
# [3, 4, 5],
# [6, 7, 8]])
I want to select the specified indices per row (which would be [1, 5, 7]), but then also set these values to another number, e.g. 42.
I can select the desired columns row wise by doing:
data.gather(1, indices)
tensor([[1],
[5],
[7]])
data.gather(1, indices)[:] = 42 # **This does NOT work**, since the result of gather
# does not use the same storage as the original tensor
That is fine, but I would like to change these values and have the change also affect the data tensor.
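To make the copy behaviour concrete, here is a small check, assuming the same data and indices as above (the name picked is just illustrative). It shows that writing into the result of gather leaves the original tensor untouched:

```python
import torch

data = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
indices = torch.tensor([1, 2, 1], dtype=torch.long).unsqueeze(-1)  # shape (3, 1)

# gather returns a new tensor holding copies of the selected values
picked = data.gather(1, indices)
picked[:] = 42  # modifies only the copy

print(picked.squeeze(-1).tolist())  # values of the copy after assignment
print(data.tolist())                # original tensor, still unchanged
```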
I can do what I want to achieve using this, but it seems to be very un-pythonic:
max_index = torch.max(indices)
for i in range(0, max_index + 1):
    mask = (indices == i).nonzero(as_tuple=True)[0]
    data[mask, i] = 42
print(data)
# tensor([[ 0, 42, 2],
# [ 3, 4, 42],
# [ 6, 42, 8]])
Any hints on how to do that more elegantly?
Solution
What you are looking for is torch.Tensor.scatter_ with the value option.
Tensor.scatter_(dim, index, src, reduce=None) → Tensor
Writes all values from the tensor src into self at the indices specified in the index tensor. For each value in src, its output index is specified by its index in src for dimension != dim and by the corresponding value in index for dimension = dim.
With 2D tensors as input and dim=1, the operation is: self[i][index[i][j]] = src[i][j]
No mention of the value parameter though...
With value=42 and dim=1, this will have the following effect on data:
data[i][index[i][j]] = 42
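The two update rules above can be spelled out in plain Python on nested lists. This is only a sketch of the semantics for the 2D, dim=1 case; scatter_dim1 is a hypothetical helper written for illustration, not part of PyTorch:

```python
def scatter_dim1(target, index, src=None, value=None):
    """Emulate a 2D scatter along dim=1 on nested lists, in place.

    Hypothetical helper: writes src[i][j] (or the scalar value when
    src is None) into target[i][index[i][j]].
    """
    for i, row in enumerate(index):
        for j, col in enumerate(row):
            target[i][col] = value if src is None else src[i][j]
    return target


data = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
indices = [[1], [2], [1]]

# Scalar form: every selected position receives the same value
scatter_dim1(data, indices, value=42)
print(data)  # [[0, 42, 2], [3, 4, 42], [6, 42, 8]]

# src form: each selected position receives the matching src entry
other = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
scatter_dim1(other, indices, src=[[10], [20], [30]])
print(other)  # [[0, 10, 2], [3, 4, 20], [6, 30, 8]]
```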
Here applied in-place:
>>> data.scatter_(index=indices, dim=1, value=42)
>>> data
tensor([[ 0, 42, 2],
[ 3, 4, 42],
[ 6, 42, 8]])
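For reference, the same steps as a standalone script, assuming the data and indices from the question. It also shows an alternative based on advanced indexing with a row index built by torch.arange, which should give the same result when there is exactly one index per row; the names scattered, indexed, and rows are illustrative only:

```python
import torch

data = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
indices = torch.tensor([1, 2, 1], dtype=torch.long).unsqueeze(-1)  # shape (3, 1)

# Option 1: in-place scatter of a scalar along dim=1
scattered = data.clone()
scattered.scatter_(dim=1, index=indices, value=42)

# Option 2: advanced indexing, pairing each row number with its column index
indexed = data.clone()
rows = torch.arange(data.size(0))
indexed[rows, indices.squeeze(-1)] = 42

print(scattered)
print(torch.equal(scattered, indexed))  # whether both approaches agree
```

The clones are only there so that both options start from the same input; calling scatter_ directly on data modifies it in place, as in the answer.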
Answered By - Ivan