Issue
I have a custom forward implementation for a PyTorch loss. The training works well, and I've checked that loss.grad_fn is not None.

I'm trying to understand two things:

1. How can this function be differentiable when there is an if-else statement on the path from input to output?
2. Does the path from gt (the ground-truth input) to loss (the output) need to be differentiable, or only the path from pred (the prediction input)?
Here is the source code:
class FocalLoss(nn.Module):
    def __init__(self):
        super(FocalLoss, self).__init__()

    def forward(self, pred, gt):
        pos_inds = gt.eq(1).float()
        neg_inds = gt.lt(1).float()
        neg_weights = torch.pow(1 - gt, 4)

        pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
        neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds

        num_pos = pos_inds.float().sum()
        pos_loss_s = pos_loss.sum()
        neg_loss_s = neg_loss.sum()

        if num_pos == 0:
            loss = -neg_loss_s
        else:
            loss = -(pos_loss_s + neg_loss_s) / num_pos
        return loss
Solution
The if statement is not part of the computational graph. It belongs to the code used to build that graph dynamically (i.e. the forward function), but it is not itself a node in the graph. The principle to follow is to ask yourself whether you can backtrack to the leaves of the graph (tensors that have no parents in the graph, i.e. inputs and parameters) using the grad_fn callback of each node, backpropagating through the graph. You can only do so if every operator on the path is differentiable: in programming terms, each one implements a backward operation (a.k.a. grad_fn).
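To make this concrete, here is a minimal sketch (not the author's code, just an illustration) showing that a Python if only selects which operations get recorded; whichever branch runs, the recorded graph itself is fully differentiable:

```python
import torch

x = torch.tensor(2.0, requires_grad=True)

def f(x, flag):
    # The `if` runs in ordinary Python; it is not recorded in the graph.
    # Each branch builds its own, entirely differentiable, graph.
    if flag:
        return x * 3       # graph: MulBackward
    else:
        return x ** 2      # graph: PowBackward

y1 = f(x, True)
y1.backward()
g1 = x.grad.item()         # d(3x)/dx = 3

x.grad = None              # reset before the second pass
y2 = f(x, False)
y2.backward()
g2 = x.grad.item()         # d(x^2)/dx = 2x = 4
```

Autograd only ever sees the branch that actually executed, so differentiability is required of the recorded operations, not of the Python control flow around them.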
In your example, whether num_pos equals 0 or not, the resulting loss tensor will depend either on neg_loss_s alone, or on both pos_loss_s and neg_loss_s. In either case, the resulting loss tensor remains attached to the input pred:

- via one path: the neg_loss_s node, or
- via the other: the pos_loss_s and neg_loss_s nodes.

In your setup, either way, the operation is differentiable.
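You can verify this directly with a functional restatement of your forward (same operations, written as a plain function here for brevity). The loss ends up attached to pred, i.e. has a grad_fn, in both branches:

```python
import torch

def focal_loss(pred, gt):
    # Same computation as the FocalLoss module in the question.
    pos_inds = gt.eq(1).float()
    neg_inds = gt.lt(1).float()
    neg_weights = torch.pow(1 - gt, 4)
    pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds
    neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * neg_weights * neg_inds
    num_pos = pos_inds.sum()
    if num_pos == 0:
        return -neg_loss.sum()
    return -(pos_loss.sum() + neg_loss.sum()) / num_pos

# Predictions strictly inside (0, 1) so the logs are finite.
pred = torch.tensor([0.2, 0.5, 0.7, 0.9], requires_grad=True)

l1 = focal_loss(pred, torch.zeros(4))                    # num_pos == 0 branch
l2 = focal_loss(pred, torch.tensor([1., 0., 0., 0.]))    # num_pos > 0 branch
```

Both l1.grad_fn and l2.grad_fn are populated, so backpropagation to pred works regardless of which branch was taken.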
If gt is a ground-truth tensor, then it doesn't require gradients, and the path from it to the final loss doesn't need to be differentiable. This is the case in your example, where both pos_inds and neg_inds are non-differentiable because they come from boolean comparison operators (eq and lt).
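A small sketch of that last point: the comparison on gt produces a tensor with no grad_fn, yet gradients still flow to pred, because gt never required gradients in the first place:

```python
import torch

pred = torch.tensor([0.3, 0.8], requires_grad=True)
gt = torch.tensor([1.0, 0.0])    # labels: requires_grad is False by default

pos_inds = gt.eq(1).float()      # boolean comparison: non-differentiable, no grad_fn

# Simplified positive-term loss, enough to show the gradient flow.
loss = -(torch.log(pred) * pos_inds).sum()
loss.backward()
```

After backward, pred.grad is populated while pos_inds.grad_fn and gt.grad are both None: autograd never needed to differentiate the gt branch.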
Answered By - Ivan