Issue
I see PyTorch provides support for writing custom loss functions. Consider the following hinge loss:
import torch
import torch.nn as nn

class MarginRankingLossExp(nn.Module):
    def __init__(self) -> None:
        super(MarginRankingLossExp, self).__init__()

    def forward(self, input1, input2, target):
        # loss_without_reduction = max(0, -target * (input1 - input2) + margin)
        # margin is implicitly 0 here
        neg_target = -target
        input_diff = input1 - input2          # diff order matches the formula above
        mul_target_input = neg_target * input_diff
        add_margin = mul_target_input         # + margin would be added here
        zeros = torch.zeros_like(add_margin)
        loss = torch.max(add_margin, zeros)   # element-wise max(0, ...)
        return loss.mean()
This has only the constructor and forward functions defined. How does PyTorch calculate gradients for custom functions? Does it differentiate them somehow? Also, this function is non-differentiable at y = margin, but it didn't throw any error.
Solution
Your function will be differentiable by PyTorch's autograd as long as all the operators used in its logic are themselves differentiable. That is, as long as you work with torch.Tensor and built-in torch operators that implement a backward function, your custom function will be differentiable out of the box.
In a few words: during the forward pass, a computational graph is constructed on the fly. That is, for every operation you perform, the tensors needed to compute its gradient are recorded for a later backward pass. Assuming you use only differentiable operators (most operators are mathematically differentiable, and PyTorch provides the backward functionality for them), you will be able to backpropagate through the graph: from its end at the loss term, up to its leaves, the parameters and inputs. As for the kink: pointwise non-differentiability, such as max at the point where both arguments are equal, does not raise an error; autograd simply uses one of the valid subgradients there.
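As a minimal sketch of that mechanism (toy values, purely for illustration): every operation on a tensor that requires gradients is recorded, and backward() walks the recorded graph from the output back to the leaves:
>>> x = torch.tensor([1., 2.], requires_grad=True)
>>> y = (x * 3).sum()  # each operation is recorded in the graph
>>> y.backward()       # traverse the graph from y back to its leaf x
>>> x.grad             # dy/dx = 3 for each element
tensor([3., 3.])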
A very easy way to tell whether your function is differentiable by autograd is to run it on inputs that require gradient computation, then check for a grad_fn attribute on the output:
>>> x1 = torch.rand(1,10,2,2, requires_grad=True)
>>> x2 = torch.rand(1,10,2,2, requires_grad=True)
>>> y = torch.rand(1,10,2,2)
Here we can check with:
>>> MarginRankingLossExp()(x1, x2, y)
tensor(0.1045, grad_fn=<MeanBackward0>)
where you notice MeanBackward0, which refers to torch.Tensor.mean, the very last operator applied by MarginRankingLossExp.forward.
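To complete the picture, calling backward() on that output accumulates gradients into the .grad attribute of the leaf tensors that required them; a quick continuation of the snippet above:
>>> loss = MarginRankingLossExp()(x1, x2, y)
>>> loss.backward()    # walk the recorded graph back to the leaves
>>> x1.grad.shape      # gradients landed on x1 (and likewise on x2)
torch.Size([1, 10, 2, 2])
>>> y.grad is None     # y was created without requires_grad=True
True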
Answered By - Ivan