Issue
I'm writing a toy example performing the MNIST classification. Here is the full code of my example:
import matplotlib
matplotlib.use("Agg")
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import os
from os import system, listdir
from os.path import join, isfile, isdir, dirname
def img_transform(image):
    transform = transforms.Compose([
        # transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))])
    return transform(image)

def normalize_output(img):
    img = img - img.min()
    img = img / img.max()
    return img

def save_checkpoint(state, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
class Net(nn.Module):
    """docstring for Net"""
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
data_images, data_labels = torch.load("./PATH/MNIST/processed/training.pt")
model = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-2)
epochs = 5
batch_size = 30
num_batch = int(data_images.shape[0] / batch_size)
for epoch in range(epochs):
    for batch_idx in range(num_batch):
        data = data_images[ batch_idx*batch_size : (batch_idx+1)*batch_size ].float()
        label = data_labels[ batch_idx*batch_size : (batch_idx+1)*batch_size ]
        data = img_transform(data)
        data = data.unsqueeze_(1)
        pred_score = model(data)
        loss = criterion(pred_score, label)
        loss.backward()
        optimizer.step()
        if batch_idx % 200 == 0:
            print('epoch', epoch, batch_idx, '/', num_batch, 'loss', loss.item())
            _, pred = pred_score.topk(1)
            pred = pred.t().squeeze()
            correct = pred.eq(label)
            num_correct = correct.sum(0).item()
            print('acc=', num_correct/batch_size)

dict_to_save = {
    'epoch': epochs,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
}
ckpt_file = 'a.pth.tar'
save_checkpoint(dict_to_save, ckpt_file)
print('save to ckpt_file', ckpt_file)
exit()
The code is executable with the MNIST dataset saved at ./PATH/MNIST/processed/training.pt.
However, the training does not converge: the training accuracy always stays below 0.2. What's wrong with my implementation? I have tried different learning rates and batch sizes, and it doesn't help.
Is there any other problem in my code?
Here are some of the training logs:
epoch 0 0 / 2000 loss 27.2023868560791
acc= 0.1
epoch 0 200 / 2000 loss 2.3346288204193115
acc= 0.13333333333333333
epoch 0 400 / 2000 loss 2.691042900085449
acc= 0.13333333333333333
epoch 0 600 / 2000 loss 2.6452369689941406
acc= 0.06666666666666667
epoch 0 800 / 2000 loss 2.7910964488983154
acc= 0.13333333333333333
epoch 0 1000 / 2000 loss 2.966330051422119
acc= 0.1
epoch 0 1200 / 2000 loss 3.111387014389038
acc= 0.06666666666666667
epoch 0 1400 / 2000 loss 3.1988155841827393
acc= 0.03333333333333333
Solution
I see at least four issues that impact the results you're getting:
- You need to zero the gradients on every iteration, e.g.:
optimizer.zero_grad()
loss.backward()
optimizer.step()
- You're feeding nn.CrossEntropyLoss() with F.log_softmax outputs, but it expects raw logits (it applies log-softmax internally). Remove this line from forward and return x instead:
output = F.log_softmax(x, dim=1)
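For reference, a minimal sketch of what forward could look like once that line is removed (same layers as your Net, returning the raw logits):
def forward(self, x):
    x = F.relu(self.conv1(x))
    x = F.max_pool2d(self.conv2(x), 2)
    x = torch.flatten(x, 1)
    x = F.relu(self.fc1(x))
    return self.fc2(x)  # raw logits; nn.CrossEntropyLoss applies log-softmax internally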
- You're computing the loss and accuracy only for the current batch when you print them, so the numbers are noisy and not representative. To fix it, accumulate the losses/accuracies over the batches and print the average, for example:
# During the loop
loss_value += loss.item()
# When printing:
print(loss_value/number_of_batch_losses_stored)
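A minimal sketch of that bookkeeping inside your existing loop (the names running_loss, running_correct and seen are made up here for illustration):
running_loss = 0.0   # sum of batch losses since the last print
running_correct = 0  # correct predictions since the last print
seen = 0             # examples seen since the last print
for batch_idx in range(num_batch):
    # ... forward pass, loss, backward and step as before ...
    running_loss += loss.item()
    running_correct += (pred_score.argmax(dim=1) == label).sum().item()
    seen += label.size(0)
    if (batch_idx + 1) % 200 == 0:
        print('avg loss', running_loss / 200, 'avg acc', running_correct / seen)
        running_loss, running_correct, seen = 0.0, 0, 0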
- It's not a huge problem, but I'd say this learning rate should be smaller, e.g. 1e-3.
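With Adam, for instance, that would just be:
optimizer = optim.Adam(model.parameters(), lr=1e-3)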
As a tip to improve your pipeline, it's better to use a DataLoader to load your data. Have a look at torch.utils.data to learn how to do that. Loading the batches the way you're doing is not efficient because you're not using generators. Also, MNIST is already available as torchvision.datasets.MNIST, so it'll save you some time to load the data from there.
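A minimal sketch of that pipeline, assuming you want the dataset downloaded under ./PATH (the root path here is just an example):
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.ToTensor(),                       # converts PIL images to [0, 1] float tensors
    transforms.Normalize((0.1307,), (0.3081,)),  # same MNIST statistics as in your code
])

train_set = datasets.MNIST('./PATH', train=True, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=30, shuffle=True)

for data, label in train_loader:  # each iteration yields one shuffled batch
    pass  # forward pass, loss, backward and step go here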
Answered By - André Pacheco