Issue
When we deal with imbalanced training data (there are more negative samples and fewer positive samples), the pos_weight parameter is usually used. The expectation for pos_weight is that the model will get a higher loss when a positive sample gets the wrong label than when a negative sample does.
When I use the binary_cross_entropy_with_logits function, I found:
import torch

bce = torch.nn.functional.binary_cross_entropy_with_logits
pos_weight = torch.FloatTensor([5])

# intended as "the positive sample is predicted wrongly"
preds_pos_wrong = torch.FloatTensor([0.5, 1.5])
label_pos = torch.FloatTensor([1, 0])
loss_pos_wrong = bce(preds_pos_wrong, label_pos, pos_weight=pos_weight)

# intended as "the negative sample is predicted wrongly"
preds_neg_wrong = torch.FloatTensor([1.5, 0.5])
label_neg = torch.FloatTensor([0, 1])
loss_neg_wrong = bce(preds_neg_wrong, label_neg, pos_weight=pos_weight)
However:
>>> loss_pos_wrong
tensor(2.0359)
>>> loss_neg_wrong
tensor(2.0359)
The losses derived from the wrong positive samples and the wrong negative samples are the same, so how does pos_weight work in the loss calculation for imbalanced data?
Solution
TL;DR: both losses are identical because you are computing the same quantity: the two batches contain the same (logit, label) pairs, with the two batch elements simply switched.
Why are you getting the same loss?
I think you got confused by the usage of F.binary_cross_entropy_with_logits (you can find a more detailed documentation page with nn.BCEWithLogitsLoss). In your case, your input shape (i.e. the output of your model) is one-dimensional, which means you only have a single logit x per element, not two.
In your example you have
preds_pos_wrong = torch.FloatTensor([0.5, 1.5])
label_pos = torch.FloatTensor([1, 0])
This means your batch size is 2, and since by default the function averages the losses of the batch elements, you end up with the same result for BCE(preds_pos_wrong, label_pos) and BCE(preds_neg_wrong, label_neg). The two elements of your batch are just switched.
You can verify this very easily by not averaging the loss over the batch elements, using the reduction='none' option:
>>> import torch.nn.functional as F
>>> F.binary_cross_entropy_with_logits(preds_pos_wrong, label_pos,
        pos_weight=pos_weight, reduction='none')
tensor([2.3704, 1.7014])
>>> F.binary_cross_entropy_with_logits(preds_neg_wrong, label_neg,
        pos_weight=pos_weight, reduction='none')
tensor([1.7014, 2.3704])
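Averaging these per-element losses (the default reduction='mean') recovers the identical scalar from the question; a quick check with the same tensors:
>>> F.binary_cross_entropy_with_logits(preds_pos_wrong, label_pos,
        pos_weight=pos_weight, reduction='none').mean()
tensor(2.0359)
>>> F.binary_cross_entropy_with_logits(preds_neg_wrong, label_neg,
        pos_weight=pos_weight, reduction='none').mean()
tensor(2.0359)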
Looking into F.binary_cross_entropy_with_logits:
That being said, the formula for the binary cross-entropy is:
bce = -[y*log(sigmoid(x)) + (1-y)*log(1 - sigmoid(x))]
where y (respectively sigmoid(x)) corresponds to the positive class associated with that logit, and 1 - y (resp. 1 - sigmoid(x)) to the negative class.
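As a quick numerical sanity check of this formula on the logits from the question (no pos_weight yet):
>>> z = torch.sigmoid(torch.FloatTensor([0.5, 1.5]))
>>> y = torch.FloatTensor([1, 0])
>>> -(y*torch.log(z) + (1-y)*torch.log(1-z))
tensor([0.4741, 1.7014])
>>> F.binary_cross_entropy_with_logits(torch.FloatTensor([0.5, 1.5]), y, reduction='none')
tensor([0.4741, 1.7014])
Note that the first (positive) element is 0.4741 here; with pos_weight=5 it becomes 5 * 0.4741 ≈ 2.3704, while the second (negative) element stays at 1.7014, which is exactly the weighted output shown above.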
The documentation could be more precise on the weighting scheme for pos_weight (not to be confused with weight, which weights the different logit outputs). The idea with pos_weight, as you said, is to weigh the positive term only, not the whole term:
bce = -[w_p*y*log(sigmoid(x)) + (1-y)*log(1 - sigmoid(x))]
where w_p is the weight for the positive term, used to compensate for the positive-to-negative sample imbalance. In practice, this should be w_p = #negative/#positive.
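As a sketch of how you might set this in practice (the labels tensor below is just a hypothetical example with 5 negatives and 1 positive), you can compute pos_weight from the label counts and hand it to the nn.BCEWithLogitsLoss module:
>>> labels = torch.FloatTensor([0, 0, 0, 0, 0, 1])  # hypothetical training labels
>>> num_pos = labels.sum()
>>> num_neg = len(labels) - num_pos
>>> w_p = num_neg / num_pos
>>> w_p
tensor(5.)
>>> criterion = torch.nn.BCEWithLogitsLoss(pos_weight=w_p)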
Therefore:
>>> w_p = torch.FloatTensor([5])
>>> preds = torch.FloatTensor([0.5, 1.5])
>>> label = torch.FloatTensor([1, 0])
With the built-in loss function:
>>> F.binary_cross_entropy_with_logits(preds, label, pos_weight=w_p, reduction='none')
tensor([2.3704, 1.7014])
Compared with the manual computation:
>>> z = torch.sigmoid(preds)
>>> -(w_p*label*torch.log(z) + (1-label)*torch.log(1-z))
tensor([2.3704, 1.7014])
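To connect this back to the original question: with a single logit per sample, the fair comparison is one positive sample and one negative sample that are both wrong by the same margin. With pos_weight=5 the mistake on the positive sample is penalized five times harder, which is the intended effect (the logits ±1.5 below are just an illustrative choice):
>>> # positive sample (label 1) wrongly pushed to the negative side
>>> F.binary_cross_entropy_with_logits(torch.FloatTensor([-1.5]), torch.FloatTensor([1]), pos_weight=w_p)
tensor(8.5071)
>>> # negative sample (label 0) wrongly pushed to the positive side
>>> F.binary_cross_entropy_with_logits(torch.FloatTensor([1.5]), torch.FloatTensor([0]), pos_weight=w_p)
tensor(1.7014)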
Answered By - Ivan