Issue
Given a tensor
A = torch.tensor([0.0316, 0.2338, 0.2338, 0.2338, 0.0316, 0.0316, 0.0860, 0.0316, 0.0860])
containing probabilities that sum to 1 (some decimals were trimmed, but it is safe to assume the values always sum to 1), I want to sample a value from A where the value itself is its likelihood of being sampled. For instance, the likelihood of sampling 0.0316 from A is 0.0316. The sampled value should still be a tensor.
I tried using WeightedRandomSampler, but it no longer allows the selected value to be a tensor; instead it detaches it from the computation graph.
One caveat that makes this tricky: I also want to know the index of the sampled value as it appears in the tensor. That is, if I sample 0.2338, I want to know whether it came from index 1, 2, or 3 of tensor A.
Solution
Selecting with the expected probabilities can be achieved by accumulating the weights (a cumulative sum) and finding the insertion index of a random float drawn from [0, 1). The example array A is slightly adjusted so it sums exactly to 1.
import torch
A = torch.tensor([0.0316, 0.2338, 0.2338, 0.2338, 0.0316, 0.0316, 0.0860, 0.0316, 0.0862], requires_grad=True)
p = A.cumsum(0)
# tensor([0.0316, 0.2654, 0.4992, 0.7330, 0.7646, 0.7962, 0.8822, 0.9138, 1.0000], grad_fn=<CumsumBackward0>)
idx = torch.searchsorted(p, torch.rand(1))
A[idx], idx
Output
(tensor([0.2338], grad_fn=<IndexBackward0>), tensor([3]))
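As a quick sanity check that the sampled value stays attached to the computation graph, here is a minimal sketch (reusing the adjusted tensor from above; the one-hot gradient claim follows from indexing, not from anything in the original answer):

```python
import torch

A = torch.tensor([0.0316, 0.2338, 0.2338, 0.2338, 0.0316,
                  0.0316, 0.0860, 0.0316, 0.0862], requires_grad=True)
p = A.detach().cumsum(0)                  # cumulative weights; detached, since only the indexing needs grad
idx = torch.searchsorted(p, torch.rand(1))
val = A[idx]                              # still carries a grad_fn
val.backward()                            # gradient w.r.t. A is one-hot at the sampled index
```

After backward(), A.grad is 1 at the sampled index and 0 elsewhere, confirming that gradients flow through the selected value.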
This is faster than the more common approach with A.multinomial(1).
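One reason the searchsorted approach scales well is that many samples can be drawn in a single call by passing a batch of uniform floats. A sketch, again using the adjusted A (the batch size of 10000 is just for illustration):

```python
import torch

A = torch.tensor([0.0316, 0.2338, 0.2338, 0.2338, 0.0316,
                  0.0316, 0.0860, 0.0316, 0.0862])
p = A.cumsum(0)                                  # sorted cumulative weights, last entry 1.0
idx = torch.searchsorted(p, torch.rand(10000))   # 10000 sampled indices in one call
vals = A[idx]                                    # the corresponding sampled values
```

Since torch.rand produces values in [0, 1) and the last cumulative weight is 1.0, every returned index is a valid position in A.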
Sampling one element 10,000 times to check that the distribution conforms to the probabilities:
from collections import Counter

Counter(int(A.multinomial(1)) for _ in range(10000))
# 1 loop, best of 5: 233 ms per loop

# vs @HatemAli's solution
dist = torch.distributions.categorical.Categorical(probs=A)
Counter(int(dist.sample()) for _ in range(10000))
# 10 loops, best of 5: 107 ms per loop

Counter(int(torch.searchsorted(p, torch.rand(1))) for _ in range(10000))
# 10 loops, best of 5: 53.2 ms per loop
Output
Counter({0: 319,
1: 2360,
2: 2321,
3: 2319,
4: 330,
5: 299,
6: 903,
7: 298,
8: 851})
Answered By - Michael Szczesny