Issue
Given a 2D PyTorch tensor, what is the most efficient way to compute, for each row, the number of largest (top-K) values whose sum is less than a given value?
Input:
tensor([[0.6607, 0.1165, 0.0278, 0.1950],
[0.0529, 0.4607, 0.2729, 0.2135],
[0.3267, 0.0902, 0.4578, 0.1253]])
Required output for the given value 0.8:
tensor([[1], #as 0.6607+0.1950 > 0.8
[2], #as 0.4607+0.2729+0.2135 > 0.8
[2]]) #as 0.4578+0.3267+0.1253 > 0.8
Solution
You can do this with a combination of sorting, a cumulative sum, and a max over a boolean mask.
First, sort the values of each row in descending order with torch.Tensor.sort:
>>> v = x.sort(dim=1, descending=True).values
>>> v
tensor([[0.6607, 0.1950, 0.1165, 0.0278],
[0.4607, 0.2729, 0.2135, 0.0529],
[0.4578, 0.3267, 0.1253, 0.0902]])
Then construct a mask on the cumulative sum of the sorted values, which you get by applying torch.cumsum:
>>> mask = torch.cumsum(v, dim=1) > .8
>>> mask
tensor([[False, True, True, True],
[False, False, True, True],
[False, False, True, True]])
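Applying torch.Tensor.max to that mask returns the index of the first True value in each row, i.e. the location of the first cumulative sum that exceeds the threshold 0.8. Because indices are zero-based, that index is also the number of top values whose sum stays below the threshold: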
>>> mask.max(1, True).indices
tensor([[1],
[2],
[2]])
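The steps above can be put together as a small function. This is only a sketch: the name count_topk_below is made up for illustration, and it counts the entries of the cumulative sum that are below the threshold instead of taking the max of the mask. That variant assumes the values are non-negative (so the cumulative sum never decreases), and it uses a strict < comparison, so it differs from the > mask above only when a running total equals the threshold exactly.

```python
import torch


def count_topk_below(x: torch.Tensor, threshold: float) -> torch.Tensor:
    """Hypothetical helper: per row, count how many of the largest values
    can be summed while the running total stays below `threshold`."""
    # Sort each row in descending order so the largest values are added first.
    v = x.sort(dim=1, descending=True).values
    # Running totals of the sorted values.
    totals = torch.cumsum(v, dim=1)
    # Assumption: values are non-negative, so the totals are non-decreasing
    # and the entries below the threshold form a prefix of each row.
    # Counting them gives the number of top values that fit.
    return (totals < threshold).sum(dim=1, keepdim=True)


x = torch.tensor([[0.6607, 0.1165, 0.0278, 0.1950],
                  [0.0529, 0.4607, 0.2729, 0.2135],
                  [0.3267, 0.0902, 0.4578, 0.1253]])
counts = count_topk_below(x, 0.8)
print(counts)
```

Counting instead of taking the max also covers a row whose total never reaches the threshold: there the mask is all False, so the index of the first maximum would typically be 0, while the count is the full row length.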
Answered By - Ivan