Issue
Imagine the following scenario:
data = torch.Tensor([0.5,0.4,1.2,1.1,0.4,0.4])
indices = torch.Tensor([0,1,1,2,2,2])
What I would like to achieve is the following: for each sample in data, compute the mean over the subset of samples that share its index in indices:
subset_means == torch.Tensor([0.5, 0.8, 0.8, 0.63, 0.63, 0.63])
I have not been able to come up with a satisfactory solution so far.
Solution
You can use Tensor.index_put with accumulate=True to accumulate the values of a tensor according to an index tensor. This sums all values that share the same index. In the following snippet, a second call with a tensor of ones counts the number of occurrences of each index, so that the means can be computed from the sums:
import torch
data = torch.tensor([0.5,0.4,1.2,1.1,0.4,0.4])
indices = torch.tensor([0,1,1,2,2,2]).to(torch.long)
# sum groups according to indices
accum = torch.zeros((indices.max()+1, )).index_put((indices,), data, accumulate=True)
# count groups according to indices
cnt = torch.zeros((indices.max()+1,)).index_put((indices,), torch.ones((1,)), accumulate=True)
# compute means and expand according to indices
subset_means = (accum / cnt)[indices]
print(subset_means)
# prints approximately tensor([0.5000, 0.8000, 0.8000, 0.6333, 0.6333, 0.6333])
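A possible alternative, offered here as a sketch and not taken from the snippet above, is torch.bincount. Called with weights it returns the per-group sums, and called without weights it returns the per-group counts. This assumes indices is a one-dimensional tensor of non-negative integers, which is what bincount requires.

```python
import torch

data = torch.tensor([0.5, 0.4, 1.2, 1.1, 0.4, 0.4])
indices = torch.tensor([0, 1, 1, 2, 2, 2])  # integer list, so dtype is already int64

# per-group sums: bin i receives the sum of data where indices == i
sums = torch.bincount(indices, weights=data)
# per-group counts: number of occurrences of each index
counts = torch.bincount(indices)

# per-group means, then expanded back to one value per sample
subset_means = (sums / counts)[indices]
print(subset_means)
```

If some index value between 0 and indices.max() never occurs, its count is zero and the division yields nan for that group, but that entry is never selected by the final indexing step, so subset_means is unaffected.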
Answered By - flawr