Issue
Dear community, I have a question about tensor indexing in PyTorch. The problem is simple: given a tensor, create an index tensor that selects its maximum value in each column.
import torch as T
x = T.tensor([[0, 3, 0, 5, 9, 8, 2, 0],
              [0, 4, 9, 6, 7, 9, 1, 0]])
Given this tensor, I would like to build a boolean mask that indexes its maximum value in each column. To be specific, I do not need the maximum values themselves, torch.max(x, dim=0), nor their indices, torch.argmax(x, dim=0), but a boolean mask for indexing another tensor based on this tensor's maximum values. My ideal output would be:
# Input tensor
x
tensor([[0, 3, 0, 5, 9, 8, 2, 0],
        [0, 4, 9, 6, 7, 9, 1, 0]])
# Ideal output bool mask tensor
idx
tensor([[1, 0, 0, 0, 1, 0, 1, 1],
        [0, 1, 1, 1, 0, 1, 0, 0]])
I know that x[idx] and x.max(dim=0).values pick out the same maximum values (x[idx] returns them in row-major order rather than column order), but I am not looking for the values; I am looking for idx itself.
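A minimal sketch of that relationship, assuming `import torch` (the question's `T` alias presumably refers to `torch`). For `x[idx]` to act as a mask, `idx` has to be a `bool` tensor; the masked result is flattened in row-major order, while `x.max(dim=0).values` is in column order, so the two contain the same values but not in the same sequence.

```python
import torch

x = torch.tensor([[0, 3, 0, 5, 9, 8, 2, 0],
                  [0, 4, 9, 6, 7, 9, 1, 0]])
# The ideal mask from the question, stored as bool so that x[idx] masks
idx = torch.tensor([[1, 0, 0, 0, 1, 0, 1, 1],
                    [0, 1, 1, 1, 0, 1, 0, 0]], dtype=torch.bool)

masked = x[idx]                # 1-D result, row-major order
col_max = x.max(dim=0).values  # 1-D result, one value per column

print(masked)   # tensor([0, 9, 2, 0, 4, 9, 6, 9])
print(col_max)  # tensor([0, 4, 9, 6, 9, 9, 2, 0])

# Masking the transposed tensors walks column by column instead,
# which reproduces the column order of x.max(dim=0).values
print(torch.equal(x.T[idx.T], col_max))  # True
```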
I have built a custom solution, but it seems too complex and cumbersome, and I am sure torch has an optimized way to do this. I tried using torch.index_select with the output of x.argmax(dim=0), but that failed, so I am asking for help doing this in a vectorized, idiomatic torch way.
Solution
You can perform this operation by first extracting the column-wise index of the maximum value of your tensor with torch.argmax, setting keepdim to True:
>>> x.argmax(0, keepdim=True)
tensor([[0, 1, 1, 1, 0, 1, 0, 0]])
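Why keepdim matters, as a small sketch: torch.scatter expects its index tensor to have the same number of dimensions as the tensor it writes into, so keeping the reduced dimension yields a (1, 8) index that lines up with the (2, 8) target. Note also that columns 0 and 7 contain two equal zeros; argmax reports the first maximal index in that case, which is why those columns map to row 0.

```python
import torch

x = torch.tensor([[0, 3, 0, 5, 9, 8, 2, 0],
                  [0, 4, 9, 6, 7, 9, 1, 0]])

flat = x.argmax(dim=0)                # one row index per column, dim 0 dropped
kept = x.argmax(dim=0, keepdim=True)  # same indices, dim 0 kept with size 1

print(flat.shape)  # torch.Size([8])
print(kept.shape)  # torch.Size([1, 8])
```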
Then you can use torch.scatter to place 1s in a zero tensor at the designated indices:
>>> torch.zeros_like(x).scatter(0, x.argmax(0, keepdim=True), value=1)
tensor([[1, 0, 0, 0, 1, 0, 1, 1],
[0, 1, 1, 1, 0, 1, 0, 0]])
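The result above has the same integer dtype as x. To use it as a mask it should be converted with .bool(), since an integer tensor of 0s and 1s used as an index is treated as integer indices, not as a mask. The sketch below does that and applies the mask to a second tensor `y` (hypothetical, for illustration only). It also shows an alternative built on torch.nn.functional.one_hot, and why a plain comparison against the column maxima is not a drop-in replacement: it flags every tied maximum, whereas argmax picks only one row per column.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[0, 3, 0, 5, 9, 8, 2, 0],
                  [0, 4, 9, 6, 7, 9, 1, 0]])

# scatter approach from the answer, converted to a real boolean mask
mask = torch.zeros_like(x).scatter(0, x.argmax(0, keepdim=True), value=1).bool()

# Hypothetical second tensor of the same shape, indexed by x's column maxima
y = torch.arange(16).reshape(2, 8)
print(y[mask])  # row-major selection: values 0, 4, 6, 7, 9, 10, 11, 13

# Alternative: one-hot encode the argmax, shape (8, 2), then transpose to (2, 8)
mask_oh = F.one_hot(x.argmax(0), num_classes=x.size(0)).T.bool()
print(torch.equal(mask, mask_oh))  # True

# Comparing against the column maxima marks every tied maximum,
# so columns 0 and 7 (both zeros) end up with two True entries each
mask_eq = x == x.max(dim=0, keepdim=True).values
print(mask_eq.sum(dim=0))  # tensor([2, 1, 1, 1, 1, 1, 1, 2])
```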
Answered By - Ivan