Issue
My PyTorch model outputs two images, op and pseudo_op, and I want to backpropagate only at those pixels where loss(op_i, gt_i) < loss(pseudo_op_i, gt_i), where i indexes pixels. Clearly, loss(op, gt).backward() cannot achieve this. How can I do it?
I have a rough solution in mind:
import torch

def loss(pred, target):
    """pred/target have shape Batch x Channel x Height x Width, where Channel = 1."""
    abs_diff = torch.abs(target - pred)
    l1_loss = abs_diff.mean(1, True)  # per-pixel L1 loss, keeping the channel dim
    return l1_loss

o1 = loss(op, gt)
o2 = loss(pseudo_op, gt)
o = torch.cat((o1, o2), dim=1)
value, idx = torch.min(o, dim=1)
# Now somehow use idx to generate a mask and backpropagate selectively.
Any other solution would also work, as long as it lets me backpropagate on o1 only at those pixels where o1 < o2.
Solution
You can use the relu function for this purpose, I think. Since you need to backpropagate only through o1, you first need to detach the loss o2 from the graph. A minus sign is also needed to correct the sign of the gradient.
import torch.nn as nn

# o_diff is 0 where o1 >= o2, and equal to o2 - o1 otherwise
o_diff = nn.functional.relu(o2.detach() - o1)
# The gradient of -relu(b - x) w.r.t. x is 0 where b - x < 0, and 1 otherwise
(-o_diff).sum().backward()
Here, using relu as a kind of conditional on the sign of o2 - o1 makes it very easy to zero out the gradients at the pixels where that difference is negative.
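Putting the pieces together, here is a minimal self-contained sketch of the whole recipe. The tensors op, pseudo_op, and gt are hypothetical stand-ins for your network outputs and ground truth, not code from the question:

import torch
import torch.nn as nn

# Hypothetical stand-ins: in practice op and pseudo_op come from your model.
op = torch.rand(2, 1, 4, 4, requires_grad=True)
pseudo_op = torch.rand(2, 1, 4, 4, requires_grad=True)
gt = torch.rand(2, 1, 4, 4)

o1 = torch.abs(gt - op).mean(1, True)         # per-pixel L1 loss of op
o2 = torch.abs(gt - pseudo_op).mean(1, True)  # per-pixel L1 loss of pseudo_op

# o2 is detached, so gradients flow only through o1,
# and only at pixels where o1 < o2.
(-nn.functional.relu(o2.detach() - o1)).sum().backward()

print(op.grad)         # nonzero only at pixels where o1 < o2
print(pseudo_op.grad)  # None: o2 was detached from the graph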
I need to emphasize that since o2 is detached from the graph, it is a constant with respect to your network, so it does not affect the gradient. This operation therefore achieves exactly what you need: it backpropagates d/dx(-relu(b - o1(x))), which is 0 if b < o1(x) and d/dx(o1(x)) otherwise (where b = o2 is constant).
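If you want to convince yourself numerically, the following sketch (again with hypothetical random tensors) checks that the relu trick produces the same gradient as an explicit mask built from the condition o1 < o2:

import torch
import torch.nn as nn

op = torch.rand(2, 1, 4, 4, requires_grad=True)
pseudo_op = torch.rand(2, 1, 4, 4)
gt = torch.rand(2, 1, 4, 4)

o2 = torch.abs(gt - pseudo_op).mean(1, True)

# Gradient via the relu trick.
o1 = torch.abs(gt - op).mean(1, True)
(-nn.functional.relu(o2.detach() - o1)).sum().backward()
grad_relu = op.grad.clone()
op.grad = None

# Gradient via an explicit mask on the same condition.
o1 = torch.abs(gt - op).mean(1, True)  # rebuild the graph
mask = (o1 < o2).float()
(mask * o1).sum().backward()

print(torch.allclose(grad_relu, op.grad))  # True (up to ties where o1 == o2)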
Answered By - trialNerror