Issue
For any 2D tensor X, how to get the mask for top K elements for each row where K is a tensor (not restricted to an int)?
Input:
tensor([[0.6607, 0.1165, 0.0278, 0.1950],
[0.0529, 0.4607, 0.2729, 0.2135],
[0.3267, 0.0902, 0.4578, 0.1253]])
Desired output for K = torch.tensor([2, 3, 1]):
tensor([[ True, False, False, True],
[ False, True, True, True],
[ False, False, True, False]])
I have tried these [1], [2], but could not get them to work.
Solution
You can use the torch.topk and torch.Tensor.scatter_ methods for this:
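import torch
K = torch.tensor([2, 3, 1])
# Fill a separate boolean mask so x is not overwritten (x is the 2D input tensor)
mask = torch.zeros_like(x, dtype=torch.bool)
for idx, k in enumerate(K):
    # Indices of the k largest values in row idx
    top_k = torch.topk(x[idx], int(k))
    mask[idx].scatter_(0, top_k.indices, True)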
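The loop above calls topk once per row. If that becomes slow for many rows, a loop-free variant is possible. The following is only a sketch and not part of the original answer: it ranks every element within its row using a double argsort and compares the rank with K. The function name topk_mask_per_row is made up for illustration.

```python
import torch

def topk_mask_per_row(x: torch.Tensor, K: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: True where an element is among the K[i] largest of row i."""
    # First argsort: column indices of each row ordered from largest to smallest value.
    order = x.argsort(dim=1, descending=True)
    # Second argsort: rank of each element within its row (0 = largest).
    ranks = order.argsort(dim=1)
    # Broadcast K as a column so that row i is compared against K[i].
    return ranks < K.unsqueeze(1)

x = torch.tensor([[0.6607, 0.1165, 0.0278, 0.1950],
                  [0.0529, 0.4607, 0.2729, 0.2135],
                  [0.3267, 0.0902, 0.4578, 0.1253]])
K = torch.tensor([2, 3, 1])
mask = topk_mask_per_row(x, K)
print(mask)
```

This sorts each full row, so it trades the Python loop for an O(n log n) sort per row. It also leaves x untouched, and each row gets exactly K[i] True entries even when values are tied.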
Answered By - Maximilian Gangloff