Issue
Say I have a tensor
tensor([[0, 1, 2, 2],
        [2, 4, 2, 4],
        [3, 4, 3, 1],
        [4, 4, 4, 3]])
and a tensor of indices
tensor([[1],
        [2],
        [1],
        [3]])
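I want to compute the mean of the rows whose index values match. In this case rows 1 and 3 (counting from 1) both have index 1, so I want their mean, while the other rows stay as they are. The output should have one row per index value, in ascending order, so the final output would be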
tensor([[1.5, 2.5, 2.5, 1.5],
        [2,   4,   2,   4  ],
        [4,   4,   4,   3  ]])
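To pin down the intended semantics (one output row per distinct label, in ascending label order), here is a minimal plain-loop sketch. It assumes b holds one integer label per row of a and uses torch.unique to find the groups; it is meant as a reference for checking results, not as an efficient solution.

```python
import torch

a = torch.tensor([[0, 1, 2, 2], [2, 4, 2, 4], [3, 4, 3, 1], [4, 4, 4, 3]])
b = torch.tensor([[1], [2], [1], [3]])

labels = b.squeeze(1)          # shape (4,): one group label per row of a
groups = torch.unique(labels)  # sorted distinct labels: tensor([1, 2, 3])

# For each label, select the rows carrying it and average them column-wise.
result = torch.stack([a[labels == g].float().mean(dim=0) for g in groups])
print(result)
```

Because torch.unique returns the labels sorted, the output rows follow ascending label order, matching the expected output above.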
Solution
You can use torch.scatter_reduce with reduce='mean' to average the rows that share an index in a single call. Alternatively, you can use it twice, once to compute the sums and once to count the summands, and then divide the sums by the counts. One detail: since PyTorch uses 0-based indexing, we need to subtract 1 from the index values:
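import torch
a = torch.tensor([[0, 1, 2, 2], [2, 4, 2, 4], [3, 4, 3, 1], [4, 4, 4, 3]])
b = torch.tensor([[1], [2], [1], [3]])
cc = torch.tensor([[1.5, 2.5, 2.5, 1.5], [2, 4, 2, 4], [4, 4, 4, 3]])  # goal
# 0-based output row for every element of a (labels shifted by 1, broadcast to a's shape)
idx = torch.broadcast_to(b - 1, a.shape)
n_groups = int(idx.max()) + 1
# PyTorch >= 1.12 signature: scatter_reduce(input, dim, index, src, reduce, *, include_self=True);
# include_self=False keeps the zero-initialized input out of the mean
c = torch.scatter_reduce(
    torch.zeros(n_groups, a.shape[1]),
    0,
    idx,
    a.float(),
    reduce='mean',
    include_self=False,
)
print(c)
print(torch.allclose(c, cc))  # True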
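The two-pass variant mentioned above (sums first, then counts) can be sketched as follows. It assumes the same a and b and uses scatter_reduce with reduce='sum' for both passes, with the PyTorch 1.12+ signature that takes src as a separate argument.

```python
import torch

a = torch.tensor([[0, 1, 2, 2], [2, 4, 2, 4], [3, 4, 3, 1], [4, 4, 4, 3]])
b = torch.tensor([[1], [2], [1], [3]])

idx = torch.broadcast_to(b - 1, a.shape)  # 0-based output row for each element of a
n_groups = int(idx.max()) + 1

# Pass 1: sum the values of each group (the zero-initialized input adds nothing).
sums = torch.zeros(n_groups, a.shape[1]).scatter_reduce(0, idx, a.float(), reduce='sum')
# Pass 2: scatter ones in the same way to count how many rows fell into each group.
counts = torch.zeros(n_groups, a.shape[1]).scatter_reduce(0, idx, torch.ones(a.shape), reduce='sum')

means = sums / counts
print(means)
```

One design note: if some label between 1 and the maximum never occurs in b, its row has a count of 0 here and the division yields nan, so you may want to filter or mask such rows.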
Answered By - flawr