Issue
I want to apply a custom non-torch function to the final calculated loss before computing the gradients (i.e. before calling backward()). An example would be replacing torch.mean() on the loss vector with a custom pythonic, non-torch mean function. But doing so breaks the computation graph. I cannot rewrite the custom mean function using torch operators and I am at a loss as to how to do this. Any suggestions?
Solution
In PyTorch you can easily do this by inheriting from torch.autograd.Function: all you need to do is implement your custom forward() and the corresponding backward() methods. Because I don't know the function you intend to write, I'll demonstrate it by implementing the sine function in a way that works with automatic differentiation. Note that you need a way to compute the derivative of your function with respect to its input in order to implement the backward pass.
import torch

class MySin(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp):
        """Compute the forward pass of the custom function."""
        ctx.save_for_backward(inp)  # save activation for the backward pass
        return inp.sin()  # forward pass, could also be computed by any other library

    @staticmethod
    def backward(ctx, grad_out):
        """Compute the product of the output gradient with the
        Jacobian of the function evaluated at the input."""
        inp, = ctx.saved_tensors
        grad_inp = grad_out * torch.cos(inp)  # propagate gradient, could also be computed by any other library
        return grad_inp
To use it, define sin = MySin.apply and call sin on your input like any other function.
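As a quick sanity check, you can compare your backward() against numerical gradients with torch.autograd.gradcheck (it expects double-precision inputs with requires_grad=True):

import torch

sin = MySin.apply
x = torch.randn(5, dtype=torch.double, requires_grad=True)
print(torch.autograd.gradcheck(sin, (x,)))  # prints True if backward() matches the numerical gradients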
There is also another example worked out in the documentation.
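Applied to the mean from the question, a minimal sketch could look like the one below. This is my own illustration, not taken from the documentation: the np.mean call in forward() stands in for whatever non-torch computation you have, and since the derivative of the mean with respect to each input element is 1/n, backward() simply spreads the incoming gradient uniformly over the input. The name MyMean is just a placeholder.

import numpy as np
import torch

class MyMean(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp):
        ctx.inp_shape = inp.shape                    # remember the input shape for backward
        ctx.n = inp.numel()                          # number of elements, needed for the 1/n factor
        mean = np.mean(inp.detach().cpu().numpy())   # any non-torch computation can go here
        return inp.new_tensor(mean)                  # wrap the scalar result back into a tensor

    @staticmethod
    def backward(ctx, grad_out):
        # d(mean)/d(inp_i) = 1/n, so distribute the incoming gradient uniformly over the input
        return grad_out * torch.ones(ctx.inp_shape, dtype=grad_out.dtype, device=grad_out.device) / ctx.n

my_mean = MyMean.apply
loss_vec = torch.randn(10, requires_grad=True)
loss = my_mean(loss_vec)
loss.backward()
print(loss_vec.grad)  # every entry equals 1/10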
Answered By - flawr