Issue
I'm trying to train a Convolutional Neural Network for image recognition on a training set of 1500 images with 15 categories. I've been told that with this architecture, initial weights drawn from a Gaussian distribution with mean 0 and standard deviation 0.01, initial biases set to 0, and a proper learning rate, it should reach an accuracy of around 30%.
However, it doesn't learn anything at all: the accuracy is similar to that of a random classifier, and the weights after training still follow a normal distribution. What am I doing wrong?
This is the NN
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from datetime import datetime

class simpleCNN(nn.Module):
    def __init__(self):
        super(simpleCNN,self).__init__() #initialize the model
        self.conv1=nn.Conv2d(in_channels=1,out_channels=8,kernel_size=3,stride=1) #output size is (size+2*padding-kernel)/stride+1: 64*64 --> 62*62
        self.relu1=nn.ReLU()
        self.maxpool1=nn.MaxPool2d(kernel_size=2,stride=2) #output image 62/2 --> 31*31
        self.conv2=nn.Conv2d(in_channels=8,out_channels=16,kernel_size=3,stride=1) #output image is 29*29
        self.relu2=nn.ReLU()
        self.maxpool2=nn.MaxPool2d(kernel_size=2,stride=2) #output image is 29/2 --> 14*14 (MaxPool2d rounds the size down)
        self.conv3=nn.Conv2d(in_channels=16,out_channels=32,kernel_size=3,stride=1) #output image is 12*12
        self.relu3=nn.ReLU()
        self.fc1=nn.Linear(32*12*12,15) #32 channels * 12*12 feature map (from a 64*64 input), 15 output features = 15 classes
        self.softmax = nn.Softmax(dim=1)

    def forward(self,x):
        x=self.conv1(x)
        x=self.relu1(x)
        x=self.maxpool1(x)
        x=self.conv2(x)
        x=self.relu2(x)
        x=self.maxpool2(x)
        x=self.conv3(x)
        x=self.relu3(x)
        x=x.view(-1,32*12*12) #flatten to (batch, 4608)
        x=self.fc1(x)
        x=self.softmax(x)
        return x
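The sizes in the comments above can be double-checked with plain arithmetic. The following is only a sketch in pure Python, assuming a 64*64 single-channel input (as the comments suggest); the helper names conv_out, pool_out and flattened_features are made up for illustration and are not part of PyTorch.

```python
def conv_out(size, kernel, stride=1, padding=0):
    # Output-size formula for a convolution: floor((size + 2*padding - kernel) / stride) + 1
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel=2, stride=2):
    # Same formula for max pooling without padding; the division rounds down
    return (size - kernel) // stride + 1


def flattened_features(input_size=64, last_channels=32):
    # Follow the layer order of the network: conv, pool, conv, pool, conv
    s = conv_out(input_size, 3)  # conv1
    s = pool_out(s)              # maxpool1
    s = conv_out(s, 3)           # conv2
    s = pool_out(s)              # maxpool2
    s = conv_out(s, 3)           # conv3
    return last_channels * s * s


print(flattened_features(64))  # number of inputs expected by fc1
```

If the printed value matches the first argument of nn.Linear, the x.view(...) call in forward is consistent with the convolution stack.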
The initialization:
def init_weights(m):
    if isinstance(m,nn.Conv2d) or isinstance(m,nn.Linear):
        nn.init.normal_(m.weight,0,0.01)
        nn.init.zeros_(m.bias)

model = simpleCNN()
model.apply(init_weights)
The training function:
loss_function=nn.CrossEntropyLoss()
optimizer=optim.SGD(model.parameters(),lr=0.1,momentum=0.9)

def train_one_epoch(epoch_index,loader):
    running_loss=0
    for i, data in enumerate(loader):
        inputs,labels=data #get the minibatch
        outputs=model(inputs) #forward pass
        loss=loss_function(outputs,labels) #compute loss
        running_loss+=loss.item() #sum up the loss for the minibatches processed so far
        optimizer.zero_grad() #reset gradients
        loss.backward() #compute gradient
        optimizer.step() #update weights
    return running_loss/(i+1) # average loss per minibatch
The training:
EPOCHS=20
best_validation_loss=np.inf

for epoch in range(EPOCHS):
    print('EPOCH{}:'.format(epoch+1))
    model.train(True)
    train_loss=train_one_epoch(epoch,train_loader)
    running_validation_loss=0.0
    model.eval()
    with torch.no_grad(): # Disable gradient computation and reduce memory consumption
        for i,vdata in enumerate(validation_loader):
            vinputs,vlabels=vdata
            voutputs=model(vinputs)
            vloss=loss_function(voutputs,vlabels)
            running_validation_loss+=vloss.item()
    validation_loss=running_validation_loss/(i+1)
    print('LOSS train: {} validation: {}'.format(train_loss,validation_loss))
    if validation_loss<best_validation_loss: #save the model if it's the best so far
        timestamp=datetime.now().strftime('%Y%m%d_%H%M%S')
        best_validation_loss=validation_loss
        model_path='model_{}_{}'.format(timestamp,epoch)
        torch.save(model.state_dict(),model_path)
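The loop above only reports losses, while the question is about accuracy. A possible way to measure accuracy is sketched below; evaluate_accuracy is a hypothetical helper, not something from the original post, and it assumes the loader yields (inputs, labels) batches like the loaders above. With 15 balanced classes, a random classifier would sit at about 1/15, roughly 6.7%.

```python
import torch


def evaluate_accuracy(model, loader):
    """Return the fraction of correctly classified samples in `loader`."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():  # no gradients needed for evaluation
        for inputs, labels in loader:
            outputs = model(inputs)
            # The predicted class is the index of the largest output per sample
            predictions = outputs.argmax(dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
    return correct / total
```

It could be called once per epoch, for example as evaluate_accuracy(model, validation_loader), next to the loss printout.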
With the default initialization it works a little better, but I'm supposed to reach 30% with the Gaussian one. Can you spot an issue that might be preventing it from learning? I have already tried different learning rates and momentum values.
Solution
The problem was that I was loading the images and then converting them to tensors with transforms.ToTensor(), which rescales the pixel values to the range [0,1], while the CNN was actually meant to work with values in [0,255].
With such small pixel values, a normal initialization with a small standard deviation is almost equivalent to a null initialization: the products of weights and inputs are so close to zero that almost no signal reaches the output.
So, to fix this kind of problem, you have to make sure that the pixel values are in the range [0,255].
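One way to get the pixel values back into [0,255] is to multiply the tensor produced by transforms.ToTensor() by 255. The snippet below is a minimal sketch of that idea, not necessarily what I used in my own pipeline: rescale_to_255 is a name made up here, and the random tensor only stands in for a real 64*64 grayscale image.

```python
import torch


def rescale_to_255(image: torch.Tensor) -> torch.Tensor:
    # ToTensor() gives floats in [0, 1]; multiplying by 255 restores the [0, 255] range
    return image * 255.0


# Stand-in for the output of ToTensor(): a 1x64x64 image with values in [0, 1)
fake_image = torch.rand(1, 64, 64)
rescaled = rescale_to_255(fake_image)

print(fake_image.max().item(), rescaled.max().item())
```

With torchvision, such a function can be appended to the preprocessing pipeline right after transforms.ToTensor(), for instance by wrapping it in transforms.Lambda(rescale_to_255).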
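Also, the softmax at the end of the network worsens the problem, as already pointed out: nn.CrossEntropyLoss expects raw, unnormalized scores and applies the log-softmax itself, so forward should return the output of fc1 directly.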
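To see why the extra softmax hurts, it helps to compute by hand what the loss looks like when probabilities are fed in as if they were raw scores. This is only an illustrative sketch in pure Python for a single sample with 15 classes; cross_entropy_from_logits is a helper written for this example that mirrors the formula nn.CrossEntropyLoss uses, -log(softmax(scores)[target]).

```python
import math


def cross_entropy_from_logits(logits, target):
    # -log(softmax(logits)[target]), computed with the log-sum-exp trick for stability
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_sum_exp - logits[target]


num_classes = 15

# Softmax output of an untrained network: roughly uniform probabilities
uniform = [1.0 / num_classes] * num_classes
# Best possible softmax output: probability 1 on the correct class (index 0)
perfect = [1.0] + [0.0] * (num_classes - 1)

chance_loss = cross_entropy_from_logits(uniform, 0)
best_loss_with_softmax = cross_entropy_from_logits(perfect, 0)

print(chance_loss, best_loss_with_softmax)
```

Because the softmax output is confined to [0,1], even a perfect prediction leaves the loss at about 1.82, against about 2.71 (the natural log of 15) at chance level. The loss has very little room to decrease, and the gradients are damped accordingly.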
Answered By - Alessandro Cesa