Issue
I have created a synthesized dataset and the corresponding data loader for a binary classification problem using PyTorch. The zero class occurs about 20% of the time and the one class about 80%. When I train my model it only predicts the majority class, which makes sense because it has seen samples with label one 80% of the time.
How can I handle this imbalance after getting the data from the data loader?
Is BCELoss capable of accounting for this imbalance?
import torch
from torch.utils.data import DataLoader

class MyDataset(torch.utils.data.Dataset):
    def __init__(self, data_size=50000):
        self.data_size = data_size

    def __len__(self) -> int:
        return self.data_size

    def __getitem__(self, idx):
        # my_function() returns label=0 (~20%) or label=1 (~80%)
        data, label = my_function()
        return data, label
dataset = MyDataset()
batch_size = 1000
dl = DataLoader(dataset, batch_size=batch_size)
# model, DEVICE, and num_epochs are defined elsewhere
from tqdm import tqdm

loss_fn = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-6)

losses = []
for epoch in range(num_epochs):
    print(f"EPOCH: {epoch}")
    for data, label in tqdm(dl, total=len(dl)):
        data = data.to(DEVICE).float()
        label = label.to(DEVICE).float()
        optimizer.zero_grad()
        preds = model(data)
        prob = torch.sigmoid(preds)
        loss = loss_fn(prob, label)
        loss.backward()
        losses.append(loss.detach().cpu().numpy())
        optimizer.step()
Question:
I have read about the weight argument of BCELoss, but I am not sure whether I can use it here. As far as I can tell, it does not address class imbalance. Note that I cannot manipulate my dataset: in practice the class percentages are inherent in the data, and we cannot change them at the source.
Solution
I'm not sure what you mean by "after getting the data from the data loader", but I'll suggest anyway that you could oversample the minority class by using a WeightedRandomSampler. This makes the data loader draw, in expectation, the same number of samples from each class: every sample it returns has a 50/50 chance of being class 0 or class 1. Here is how to do it:
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, data_size=50000):
        self.data_size = data_size

    def __len__(self) -> int:
        return self.data_size

    def __getitem__(self, idx):
        # my_function() returns label=0 (~20%) or label=1 (~80%)
        data, label = my_function()
        return data, label
dataset = MyDataset()
batch_size = 1000

class_weights = [1/20, 1/80]  # inverse of the relative class frequencies (class 0: 20%, class 1: 80%)
sample_weights = [0] * len(dataset)
for idx in range(len(dataset)):
    _, label = dataset[idx]
    sample_weights[idx] = class_weights[label]

sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)
dl = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
# continue with your code...
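To sanity-check the sampler, you can draw one batch and inspect the label proportions, which should now be close to 50/50. A minimal sketch (assuming the labels come back as 0/1 values that can be collated into a tensor):

# draw a single batch and check the empirical class balance
data, label = next(iter(dl))
label = torch.as_tensor(label).float()
print(f"fraction of class 1 in this batch: {label.mean().item():.2f}")  # ~0.5 expected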
And using a weighted loss can also help. Note, however, that the weight argument of BCELoss rescales individual batch elements rather than classes, so passing two class weights to it does not do what you want here. For per-class weighting, the usual approach is BCEWithLogitsLoss with its pos_weight argument, which takes the ratio of negative to positive samples and is applied to the raw model outputs (so drop the explicit sigmoid):

# pos_weight = n_negative / n_positive = 20/80, which down-weights the majority (positive) class
loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([20 / 80]))
# in the training loop: loss = loss_fn(preds, label)  # raw logits, no sigmoid
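If you prefer to keep BCELoss on the sigmoid probabilities, an alternative is to build a per-sample weight tensor from the labels at each step and use the functional form of the loss. A small sketch, where the weights 4.0 and 1.0 are just the inverse class frequencies from above:

import torch.nn.functional as F

# class 0 (20% of samples) gets weight 4.0, class 1 (80%) gets weight 1.0
weight = 4.0 * (1 - label) + 1.0 * label
loss = F.binary_cross_entropy(prob, label, weight=weight)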
You can also combine both methods. And of course there is often the most effective method: augmentation. If it's possible for you to synthetically create more samples from your dataset, then do that. A typical example for images is to flip, rotate, or crop them to create slightly different versions. You could add that to your dataset class, for example augmenting minority-class samples with some probability, as in the sketch below.
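A minimal sketch of what that could look like inside the dataset class, assuming my_function returns an image tensor (the horizontal flip and the 80% probability are illustrative choices, not requirements):

import random
import torch

class MyAugmentedDataset(torch.utils.data.Dataset):
    def __init__(self, data_size=50000):
        self.data_size = data_size

    def __len__(self) -> int:
        return self.data_size

    def __getitem__(self, idx):
        data, label = my_function()
        # augment minority-class (label=0) samples with 80% probability
        if label == 0 and random.random() < 0.8:
            data = torch.flip(data, dims=[-1])  # horizontal flip of an image tensor
        return data, label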
Answered By - Theodor Peifer