Issue
I'm trying to use two different losses: MSELoss for some of my labels and a custom loss for the other labels. I then sum these losses before backpropagation. My model prints the same loss after every epoch, so I must be doing something wrong; I suspect my implementation is breaking PyTorch's autograd. Any help is appreciated! See the code below:
mse_loss = torch.nn.MSELoss()
...
loss1 = mse_loss(preds[:,(0,1,3)], label[:,(0,1,3)])
print("loss1", loss1)
loss2 = my_custom_loss(preds[:,2], label[:,2])
print("loss2", loss2)
print("summing losses")
loss = sum([loss1, loss2]) # tensor + float = tensor
print("loss sum", loss)
loss = torch.autograd.Variable(loss, requires_grad=True)
print("loss after Variable(loss, requires_grad=True)", loss)
These print statements yield:
loss1 tensor(4946.1221, device='cuda:0', grad_fn=<MseLossBackward0>)
loss2 34.6672
summing losses
loss sum tensor(4980.7891, device='cuda:0', grad_fn=<AddBackward0>)
loss after Variable() tensor(4980.7891, device='cuda:0', requires_grad=True)
My custom loss function is below:
def my_custom_loss(preds, label):
    angle_diff = preds - label
    # /2 to bring angle diff between -180<theta<180
    half_angle_diff = angle_diff.detach().cpu().numpy()/2
    sine_diff = np.sin(half_angle_diff)
    square_sum = np.nansum(sine_diff**2)
    return square_sum
Solution
The reason you are not backpropagating through your second loss is that it is not built from differentiable PyTorch operations. Calling detach() and converting to NumPy removes the values from the computational graph, so my_custom_loss returns a plain number (not a tensor) that autograd cannot track. You should stick with PyTorch operators instead of switching to NumPy.
Something like this will work:
def my_custom_loss(preds, label):
    half_angle_diff = (preds - label)/2
    sine_diff = torch.sin(half_angle_diff)
    square_sum = torch.nansum(sine_diff**2)
    return square_sum
You can check that your custom loss is differentiable with dummy inputs:
>>> preds = torch.rand(1,3,10,10, requires_grad=True)
>>> label = torch.rand(1,3,10,10)
>>> my_custom_loss(preds, label)
tensor(11.7584, grad_fn=<NansumBackward0>)
Notice the grad_fn attribute on the result, which shows the output tensor is attached to a computational graph, so you can backpropagate from it.
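For contrast, here is a quick sketch (using the same dummy inputs, with the original implementation renamed my_custom_loss_numpy purely for illustration) of why the NumPy version cannot work: it returns a plain NumPy scalar rather than a tensor, so there is no grad_fn and nothing for autograd to backpropagate through:

def my_custom_loss_numpy(preds, label):
    # original approach: detach() and .numpy() drop out of the autograd graph
    half_angle_diff = (preds - label).detach().cpu().numpy()/2
    sine_diff = np.sin(half_angle_diff)
    return np.nansum(sine_diff**2)  # plain NumPy scalar, not a torch.Tensor

>>> out = my_custom_loss_numpy(preds, label)
>>> torch.is_tensor(out)
False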
Additionally, you should not use Variable, as it is now deprecated. Wrapping the summed loss in Variable(loss, requires_grad=True) also creates a new leaf tensor detached from loss1 and loss2 (note how the grad_fn disappears in your print output), so gradients never reach your model's parameters.
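Putting it together, a minimal sketch of the corrected training step (assuming an optimizer named optimizer, and preds/label as in your code): with both losses now produced by PyTorch operations, you can add them directly and call backward() with no Variable wrapper:

loss1 = mse_loss(preds[:,(0,1,3)], label[:,(0,1,3)])
loss2 = my_custom_loss(preds[:,2], label[:,2])  # now a tensor with a grad_fn
loss = loss1 + loss2                            # stays attached to the graph
optimizer.zero_grad()
loss.backward()
optimizer.step()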
Answered By - Ivan