Issue
I am trying to compute gradients through per-bin sums, where elements of an array are summed according to an index array, using torch.bincount. However, PyTorch does not implement the gradient of bincount with respect to its weights. A Python loop with torch.sum works, but it is too slow. Is there an efficient way to do this in PyTorch (maybe with einsum or index_add)? Looping over the indexes and adding the elements one by one would also work, but that grows the computational graph significantly and performs very poorly.
import torch
from torch import autograd
import numpy as np
tt = lambda x, grad=True: torch.tensor(x, requires_grad=grad)
inds = tt([1, 5, 7, 1], False).long()
y = tt(np.arange(4) + 0.1).float()
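# bincount sums the weights into bins: out[inds[i]] += (y * y)[i], with at least 8 bins
sum_y_section = torch.bincount(inds, y * y, minlength=8)
#sum_y_section = torch.sum(y * y)  # differentiable, but collapses everything into one value
# sum_y_section is a vector, so reduce it to a scalar before taking the gradient;
# this still fails because bincount provides no backward for its weights argument
grad = autograd.grad(sum_y_section.sum(), y, create_graph=True, allow_unused=False)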
print("sum_y_section", sum_y_section)
print("grad", grad)
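For comparison, the index_add route mentioned in the question is already differentiable with respect to its source tensor. A minimal sketch using the same data as above (the fixed size of 8 bins mirrors minlength=8 and is an assumption; unlike bincount, index_add does not grow the output automatically):

```python
import torch

# Same data as in the question: bin indices and per-element values.
inds = torch.tensor([1, 5, 7, 1])
y = (torch.arange(4, dtype=torch.float32) + 0.1).requires_grad_()

# Out-of-place index_add acts as a differentiable bincount-with-weights:
# sum_y_section[inds[i]] += (y * y)[i] for every i, into 8 zero-initialised bins.
sum_y_section = torch.zeros(8).index_add(0, inds, y * y)

# Reduce the per-bin vector to a scalar before calling autograd.grad.
(grad,) = torch.autograd.grad(sum_y_section.sum(), y)

print("sum_y_section", sum_y_section)
print("grad", grad)  # gradient of sum(y * y) with respect to y, i.e. 2 * y
```

The backward of index_add with respect to the source is just a gather of the incoming gradient at inds, so no Python loop or per-element graph nodes are involved.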
Solution
We can use scatter_reduce, a feature added in PyTorch 1.11, to compute the per-index sums in a differentiable way.
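A minimal sketch of the scatter_reduce approach. It assumes PyTorch 1.12 or later, where the signature is scatter_reduce(dim, index, src, reduce, *, include_self=True); the 1.11 prototype used a different signature, so check the documentation for your installed version.

```python
import torch

inds = torch.tensor([1, 5, 7, 1])
y = (torch.arange(4, dtype=torch.float32) + 0.1).requires_grad_()

# Assumes the PyTorch >= 1.12 signature: scatter_reduce(dim, index, src, reduce).
# Each (y * y)[i] is summed into slot inds[i]; include_self=True (the default)
# also adds the starting zeros, which leaves the sums unchanged.
sum_y_section = torch.zeros(8).scatter_reduce(0, inds, y * y, reduce="sum")

# Gradient of the total over all bins; grad_outputs weights each bin (all ones here).
(grad,) = torch.autograd.grad(
    sum_y_section, y, grad_outputs=torch.ones_like(sum_y_section)
)

print("sum_y_section", sum_y_section)
print("grad", grad)
```

Passing a different grad_outputs vector weights each bin separately, which is how you would backpropagate an arbitrary loss defined on the per-bin sums.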
Answered By - Roy