Issue
I have a PyTorch 2D tensor with normally distributed values. Is there a fast way to zero out the top 10% largest values of this tensor using Python?
I see two possible ways here:
- Flatten the tensor to 1D and sort it
- A non-vectorized way using native Python constructs (for loops and if statements)
But neither of these looks fast enough.
So, what is the fastest way to set the X largest values of a tensor to zero?
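For comparison, the first approach from the question (flatten and sort) does not need a Python loop. The sketch below uses `torch.argsort` on a flattened copy and zeroes the first k indices. The helper name `zero_top_fraction_sort` and its `frac` parameter are hypothetical, introduced only for illustration.

```python
import torch

def zero_top_fraction_sort(x: torch.Tensor, frac: float = 0.1) -> torch.Tensor:
    """Return a copy of x with its largest `frac` share of entries set to 0.

    Hypothetical sort-based helper: flatten, argsort descending, zero the first k.
    """
    flat = x.flatten().clone()            # 1D copy, so the input is left untouched
    k = round(flat.numel() * frac)        # number of entries to zero (nearest integer)
    if k > 0:
        idx = torch.argsort(flat, descending=True)[:k]  # positions of the k largest values
        flat[idx] = 0
    return flat.reshape(x.shape)          # restore the original (e.g. 2D) shape

x = torch.randn(100, 50)                  # 2D, normally distributed, as in the question
y = zero_top_fraction_sort(x, 0.1)        # the 500 largest of 5000 values become 0
```

This zeroes exactly k entries. `torch.topk(flat, k).indices` selects the same positions without a full sort and may be faster when k is small, though that has not been benchmarked here.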
Solution
PyTorch has a useful function, torch.quantile(), that helps a lot here.
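The solution (shown for a 1D tensor; when no dim is given, torch.quantile() computes the quantile over the flattened input, so the same code also works for a 2D tensor):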
import torch
x = torch.randn(100)
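y = torch.tensor(0.)  # new value to assign to the largest entries
split_val = torch.quantile(x, 0.9)  # 90th percentile: the top 10% lie at or above it
x = torch.where(x < split_val, x, y)  # keep values below the threshold, replace the rest with y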
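The same idea applied to the 2D case from the question is sketched below. The helper `zero_above_quantile` is a hypothetical name used only for illustration. With `dim=None` it uses one threshold for the whole tensor (as in the answer above); with an integer `dim` it uses `torch.quantile(..., dim=dim, keepdim=True)` to get a separate threshold per slice, for example per row. Note that `torch.quantile` expects a floating-point tensor.

```python
import torch

def zero_above_quantile(x: torch.Tensor, q: float = 0.9, dim=None) -> torch.Tensor:
    """Return a copy of x with entries at or above the q-quantile set to 0.

    Hypothetical helper. dim=None: one threshold for the whole (flattened) tensor.
    dim=int: a separate threshold for each slice along that dimension.
    """
    if dim is None:
        split_val = torch.quantile(x, q)  # 0-dim tensor, broadcasts against x
    else:
        # keepdim=True keeps the reduced dimension (size 1) so it broadcasts back
        split_val = torch.quantile(x, q, dim=dim, keepdim=True)
    return torch.where(x < split_val, x, torch.zeros_like(x))

x = torch.randn(100, 50)                 # 2D, normally distributed
whole = zero_above_quantile(x)           # top 10% of all 5000 values zeroed
per_row = zero_above_quantile(x, dim=1)  # top 10% of each row zeroed
```

Because `torch.quantile` interpolates linearly by default, the number of zeroed entries is the count of values above the interpolated threshold, which is roughly (not always exactly) 10% of the elements.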
Answered By - Cepera