Issue
I'm doing a research project where I want to create a custom loss function that depends on the targets, i.e. I want to penalize with BCEWithLogitsLoss plus a hyperparameter lambda. I only want to apply this hyperparameter if the model is not correctly detecting a class.
In more detail: I have a pretrained model that I want to retrain while freezing some of its layers. The model detects faces in images with some probability. I want to penalize a certain kind of image with a factor lambda when it is incorrectly classified (suppose the images that need that penalization have a special character in the filename, or something similar).
From the PyTorch source code:
from typing import Optional

import torch.nn.functional as F
import torch.nn.modules.loss as l
from torch import Tensor

class CustomBCEWithLogitsLoss(l._Loss):
    def __init__(self, weight: Optional[Tensor] = None, size_average=None, reduce=None,
                 reduction: str = 'mean', pos_weight: Optional[Tensor] = None) -> None:
        super().__init__(size_average, reduce, reduction)
        self.register_buffer('weight', weight)
        self.register_buffer('pos_weight', pos_weight)
        self.weight: Optional[Tensor]
        self.pos_weight: Optional[Tensor]

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        return F.binary_cross_entropy_with_logits(input, target,
                                                  self.weight,
                                                  pos_weight=self.pos_weight,
                                                  reduction=self.reduction)
Here, forward takes two tensors as input, so I don't know how to pass in the class of the images that I want to penalize with lambda. Adding lambda to the constructor is fine, but how do I do the forward pass if it only accepts tensors?
Edit:
To clarify the question: suppose I have a training/testing folder with images. The files with the character @ in the filename are the ones I want to classify correctly far more than the files without it, by a factor lambda.
How can I tell, in the regular fashion of training a model in PyTorch, that those files have to use a lambda penalization (say the loss function is lambda * BCEWithLogitsLoss) while the other ones do not? I'm using DataLoader.
Solution
You can create a custom class for your dataset, or instead build on top of an existing built-in dataset. For instance, you can use datasets.ImageFolder as a base class. The logic added on top is to identify whether the filename contains the special token, for example @, and return this information in the element produced by __getitem__. Looking at the parent __getitem__ function of datasets.DatasetFolder, a minimal working implementation could be:
class Dataset(datasets.ImageFolder):
    def __init__(self, token, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.token = token

    def __getitem__(self, index):
        path, _ = self.samples[index]      # retrieve path of instance
        match = self.token in path         # determine if there is a match
        x, y = super().__getitem__(index)  # call parent to get input and label
        return x, int(match), y
Each dataset element now consists of the input image tensor, a flag of 0 or 1 designating whether that input has the token in its filename, and the class label.
>>> ds = Dataset(token='@', root='root_to_dataset')
which you can use like any other dataset, wrapped in a DataLoader:
>>> dl = DataLoader(ds, batch_size=2)
Now, when iterating over this dataloader, you will have:
>>> for x, m, y in dl:
...     # x is the batch of images, shaped (b, c, h, w)
...     # m is the batch of {0,1} flags indicating whether each input has the pattern in its path, shaped (b,)
...     # y is the batch of labels, shaped (b,)
Now that we have this, we need to apply the lambda factor on the loss terms. However, we can't assume here that all elements in a given minibatch will follow the criteria (i.e. have the pattern in their filename), therefore we need to handle this element-wise and not in reduced form.
If you take a look at the source file for the built-in loss functions, nn.modules.loss, you will notice that all loss functions are based on a class named _Loss, which expects a reduction parameter. This will be useful for us.
First, consider switching off reduction on your loss function:
>>> bce = nn.BCEWithLogitsLoss(reduction='none') # provide additional args if necessary
Consider the mask of "matches" m, which contains a 1 when the input has the token in its filename and a 0 otherwise. Given a lambda factor lamb with which we want to weigh the elements where m=1, we can build the following coefficient for our loss term to perform the desired operation:
>>> coeff = lamb*m + 1-m
# if m=0 => coeff=1
# if m=1 => coeff=lamb
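As a sanity check on the formula (a small sketch with a made-up mask and lamb value), the coefficient is equivalent to selecting lamb or 1 with torch.where:

```python
import torch

lamb = 5.0                          # hypothetical penalization factor
m = torch.tensor([0., 1., 0., 1.])  # made-up match mask for a batch of four

coeff = lamb * m + 1 - m            # formula from above
alt = torch.where(m.bool(), torch.tensor(lamb), torch.tensor(1.0))

print(coeff.tolist())               # [1.0, 5.0, 1.0, 5.0]
print(torch.equal(coeff, alt))      # True
```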
To apply the weighting properly, we simply point-wise multiply coeff with the unreduced loss term (which is shaped (b,)):
>>> weighted = coeff*bce_loss
All in all, this would look like this:
>>> for x, m, y in dl:
...     optimizer.zero_grad()
...     y_pred = model(x)
...     bce_loss = bce(y_pred, y.float())  # unreduced, shaped (b,)
...     coeff = lamb*m + 1-m
...     bce_weighted = torch.mean(coeff*bce_loss)
...     bce_weighted.backward()
...     optimizer.step()
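The loop above can be exercised on random data. In this sketch the linear model, shapes, labels, mask, and lamb value are all placeholders, and it checks that the weighted mean matches the element-wise computation:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

model = nn.Linear(8, 1)                     # stand-in for the pretrained face model
bce = nn.BCEWithLogitsLoss(reduction='none')
lamb = 3.0                                  # hypothetical penalization factor

x = torch.randn(4, 8)                       # dummy batch of four inputs
y = torch.tensor([0., 1., 1., 0.])          # dummy binary labels
m = torch.tensor([1., 0., 0., 1.])          # dummy match mask from the dataset

y_pred = model(x).squeeze(1)                # logits, shaped (4,)
bce_loss = bce(y_pred, y)                   # unreduced loss, shaped (4,)
coeff = lamb * m + 1 - m                    # lamb where m=1, otherwise 1
bce_weighted = torch.mean(coeff * bce_loss)
bce_weighted.backward()                     # gradients now reflect the weighting

print(bce_weighted.item())
```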
Answered By - Ivan