Issue
Suppose I have the following matrix, and its top 3 results across each row:
p = torch.randn(5, 7)
val, idx = p.topk(3, dim=-1)
I wish to assign x to the top 3 results of each row, where x is:
x = torch.randn(5, 3)
Now I know that torch.gather(p, -1, idx) returns the elements I want to replace, but I cannot assign to the result of a function call such as gather. What is the best way of getting the effect of:
torch.gather(p, -1, idx) = x
Solution
One solution is to index into p with a pair of index tensors (list-style indexing), one holding row numbers and one holding column numbers:
# create dummy indices to index the correct row (we need one value per value in idx)
row_idx = torch.arange(len(p)).unsqueeze(1).repeat(1,3)
# use flattened views
p[row_idx.view(-1),idx.view(-1)] = x.view(-1)
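The same assignment can be written without repeat() or flattening, because index tensors broadcast against each other. The following is a minimal sketch that reuses the names from the question; the seed and the final check are additions for illustration only.

```python
import torch

torch.manual_seed(0)  # seed only so the run is repeatable
p = torch.randn(5, 7)
val, idx = p.topk(3, dim=-1)
x = torch.randn(5, 3)

# A (5, 1) column of row numbers broadcasts against the (5, 3) idx,
# so every entry of idx is paired with its own row number.
row_idx = torch.arange(p.size(0)).unsqueeze(1)
p[row_idx, idx] = x

# gather now returns exactly the values that were written
assert torch.equal(torch.gather(p, -1, idx), x)
```

Since topk returns distinct column indices within each row, no position is written twice and the result does not depend on write order.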
The .view(-1) calls require row_idx, idx, and x to be contiguous (use .reshape(-1) if they are not). p itself does not have to be contiguous for the indexed assignment to work, though you may pay a small computational penalty if it is not, and I suspect any non-looping solution to this indexing task would have a similar cost.
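As an alternative that is not part of the original answer, torch.Tensor.scatter_ is the in-place write counterpart of gather: along the last dimension it stores x[i, j] at p[i, idx[i, j]]. The sketch below assumes a deliberately non-contiguous p (a transposed view) to illustrate that case, and compares the result against the indexing approach applied to a copy; the name expected is introduced here for the check only.

```python
import torch

# A non-contiguous (5, 7) view, built by transposing a (7, 5) tensor
p = torch.randn(7, 5).t()
assert not p.is_contiguous()

val, idx = p.topk(3, dim=-1)
x = torch.randn(5, 3)

# Reference result: index-tensor assignment on a copy of p
expected = p.clone()
expected[torch.arange(5).unsqueeze(1), idx] = x

# In-place write: p[i, idx[i, j]] = x[i, j]
p.scatter_(-1, idx, x)

assert torch.equal(p, expected)
```

scatter_ needs idx and x to have the same number of dimensions as p, which already holds for the output of topk here.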
Answered By - DerekG