Issue
I want to apply the function torch.topk, but only on the non-zero elements of the tensor (i.e., zero elements should not be counted when picking the top k).
Currently I do this:
torch.topk(tensor.view(-1), k)
But this also takes the zero elements of the variable tensor into account, so zeros can show up among the k largest values it returns. How can I get the top k among the non-zero elements only?
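A small example makes the problem concrete. The tensor below is made up for illustration. When the data also contains negative numbers, the zeros outrank them and push them out of the result:

```python
import torch

# Hypothetical example tensor: two zeros mixed with positive and negative values
tensor = torch.tensor([0.0, -1.0, 3.0, 0.0, -2.0])
k = 3

# Plain topk over the flattened tensor counts the zeros as candidates too
values, indices = torch.topk(tensor.view(-1), k)
print(values)  # the two zeros take the slots that -1.0 and -2.0 should get
```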
Solution
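Filter out the zeros first and run topk only on what remains. Filtering the output of topk afterwards returns fewer than k values, and the mask indices it produces point into the top-k result rather than into the original tensor.
import torch
# flatten and collect the positions of the non-zero elements
flat = tensor.reshape(-1)
nz_idx = torch.nonzero(flat, as_tuple=True)[0]
# top k over the non-zero values only (k cannot exceed the number of non-zeros)
values, pos = torch.topk(flat[nz_idx], min(k, nz_idx.numel()))
print(values)
# map positions back to indices in the flattened tensor
indices = nz_idx[pos]
print(indices)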
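If you need the top k per row of a 2-D tensor instead of over the whole flattened tensor, the nonzero-based filter does not carry over directly, because each row may hold a different number of non-zeros. One common workaround, sketched here as an assumption rather than as the answer's own method, is to replace zeros with -inf so they sort last. Rows with fewer than k non-zeros then return -inf in the leftover slots, which a mask can flag. The helper name topk_nonzero_rows is hypothetical, and the sketch assumes a floating-point tensor that does not already contain -inf:

```python
import torch

def topk_nonzero_rows(x: torch.Tensor, k: int):
    """Hypothetical helper: top k per row of a 2-D float tensor, ignoring zeros."""
    # Zeros become -inf, so they are only picked once a row runs out of non-zeros
    masked = x.masked_fill(x == 0, float("-inf"))
    values, indices = torch.topk(masked, k, dim=1)
    # True where the slot holds a real non-zero value, False for a filled-in zero
    valid = values != float("-inf")
    return values, indices, valid

x = torch.tensor([[0.0, 5.0, -1.0, 0.0],
                  [2.0, 0.0, 0.0, 0.0]])
values, indices, valid = topk_nonzero_rows(x, 2)
print(values)
print(valid)
```

The second row has only one non-zero value, so its second slot is flagged as invalid by the mask.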
credits: this question and its answers
Answered By - Mahmood Hussain