Issue
I'm trying to write a focal loss for multi-label classification:
import torch
import torch.nn as nn

class FocalLoss(nn.Module):
    def __init__(self, gamma=2, alpha=0.25):
        self._gamma = gamma
        self._alpha = alpha

    def forward(self, y_true, y_pred):
        cross_entropy_loss = torch.nn.BCELoss(y_true, y_pred)
        p_t = ((y_true * y_pred) +
               ((1 - y_true) * (1 - y_pred)))
        modulating_factor = 1.0
        if self._gamma:
            modulating_factor = torch.pow(1.0 - p_t, self._gamma)
        alpha_weight_factor = 1.0
        if self._alpha is not None:
            alpha_weight_factor = (y_true * self._alpha +
                                   (1 - y_true) * (1 - self._alpha))
        focal_cross_entropy_loss = (modulating_factor * alpha_weight_factor *
                                    cross_entropy_loss)
        return focal_cross_entropy_loss.mean()
But when I run this, I get:
  File "train.py", line 82, in <module>
    loss = loss_fn(output, target)
  File "/home/bubbles/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 538, in __call__
    for hook in self._forward_pre_hooks.values():
  File "/home/bubbles/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 591, in __getattr__
    type(self).__name__, name))
AttributeError: 'FocalLoss' object has no attribute '_forward_pre_hooks'
Any suggestions would be really helpful. Thanks in advance.
Solution
You shouldn't inherit from torch.nn.Module, as it's designed for modules with learnable parameters (e.g. neural networks).
Just create a normal functor or function and you should be fine.
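A minimal sketch of the function approach (not the answerer's code), assuming `y_pred` already holds sigmoid probabilities. The name `focal_loss` is hypothetical. Two changes beyond the answer are assumptions on my part: the arguments are ordered (prediction, target) to match the call `loss_fn(output, target)`, since the question's `forward()` takes them the other way round; and the per-element BCE comes from `F.binary_cross_entropy(..., reduction="none")`, because `torch.nn.BCELoss(y_true, y_pred)` only constructs a loss module (its positional parameters are options such as `weight`) rather than computing a loss.

```python
from typing import Optional

import torch
import torch.nn.functional as F


def focal_loss(y_pred: torch.Tensor, y_true: torch.Tensor,
               gamma: float = 2.0, alpha: Optional[float] = 0.25) -> torch.Tensor:
    """Binary focal loss for multi-label targets (sketch).

    y_pred: probabilities in (0, 1), e.g. torch.sigmoid(logits).
    y_true: 0/1 float targets with the same shape as y_pred.
    """
    # Per-element BCE; reduction="none" keeps the shape so each term
    # can be reweighted before averaging.
    bce = F.binary_cross_entropy(y_pred, y_true, reduction="none")
    # p_t is the probability the model assigned to the true label.
    p_t = y_true * y_pred + (1 - y_true) * (1 - y_pred)
    # (1 - p_t) ** 0 == 1, so gamma=0 disables the modulating factor.
    modulating_factor = (1.0 - p_t) ** gamma
    if alpha is not None:
        alpha_factor = y_true * alpha + (1 - y_true) * (1 - alpha)
    else:
        alpha_factor = torch.ones_like(y_true)
    return (alpha_factor * modulating_factor * bce).mean()


# Usage sketch with made-up shapes: 4 samples, 3 labels.
logits = torch.randn(4, 3)
targets = torch.randint(0, 2, (4, 3)).float()
loss = focal_loss(torch.sigmoid(logits), targets)
```

With `gamma=0` and `alpha=None` every weight is 1, so the result reduces to the ordinary mean BCE, which makes a quick sanity check.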
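BTW, if you do inherit from it, you must call super().__init__() in your __init__() (conventionally as its first statement). That call sets up nn.Module's internal state, including the _forward_pre_hooks dictionary the error complains about.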
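A small illustration of why the missing call matters, using two hypothetical classes, `Broken` and `Fixed`. Assigning a plain attribute still works without `super().__init__()`, so, as in the question, the failure only shows up when the module is called and `nn.Module.__call__` looks up its hook dictionaries. The exact attribute named in the message may differ between PyTorch versions.

```python
import torch
import torch.nn as nn


class Broken(nn.Module):
    def __init__(self):
        # super().__init__() is missing, so nn.Module's internal
        # bookkeeping (hook dicts, _parameters, _modules, ...) never exists.
        self.gamma = 2.0

    def forward(self, x):
        return x


class Fixed(nn.Module):
    def __init__(self):
        super().__init__()  # creates the hook dicts and other internal state
        self.gamma = 2.0

    def forward(self, x):
        return x


x = torch.ones(2)
try:
    Broken()(x)  # construction succeeded; the call is what fails
except AttributeError as err:
    print("Broken:", err)
print("Fixed:", Fixed()(x))
```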
EDIT
Actually, inheriting from nn.Module might be a good idea: it lets you use the loss as part of a neural network, and it is the common pattern in PyTorch implementations and in PyTorch Lightning.
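A sketch of a corrected `FocalLoss` as an `nn.Module` subclass, keeping the question's class and attribute names. Besides adding `super().__init__()`, it swaps the `forward()` arguments to (prediction, target) to match `loss_fn(output, target)`, and computes per-element BCE with `F.binary_cross_entropy(..., reduction="none")` instead of `torch.nn.BCELoss(y_true, y_pred)`, which builds a module rather than returning a loss. These changes are my assumptions, not part of the answer. The toy model, optimizer, and tensor shapes in the usage part are made up for illustration.

```python
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """Binary focal loss for multi-label classification (sketch)."""

    def __init__(self, gamma: float = 2.0, alpha: Optional[float] = 0.25):
        super().__init__()  # required: initialises nn.Module's internal state
        self._gamma = gamma
        self._alpha = alpha

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        # y_pred: sigmoid probabilities; y_true: 0/1 float targets.
        bce = F.binary_cross_entropy(y_pred, y_true, reduction="none")
        p_t = y_true * y_pred + (1 - y_true) * (1 - y_pred)
        modulating_factor = (1.0 - p_t) ** self._gamma
        if self._alpha is not None:
            alpha_factor = y_true * self._alpha + (1 - y_true) * (1 - self._alpha)
        else:
            alpha_factor = torch.ones_like(y_true)
        return (alpha_factor * modulating_factor * bce).mean()


# Usage sketch: one training step on a toy 3-label model.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 3), nn.Sigmoid())
loss_fn = FocalLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

inputs = torch.randn(4, 8)
target = torch.randint(0, 2, (4, 3)).float()

output = model(inputs)
loss = loss_fn(output, target)  # same call shape as in train.py
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Because the loss is an `nn.Module`, it can also be registered as a submodule of a larger model. For better numerical stability one could pass logits to `F.binary_cross_entropy_with_logits` and compute `p_t` from `torch.sigmoid(logits)`; that departs from the question's probability-based interface, so the sketch keeps probabilities.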
Answered By - Szymon Maszke