Issue
I trained my own model and later decided to continue training it. When I use the code below, the model shows a high BCELoss, as if it were an untrained model. Where is the problem? Thank you.
model_1 = SimpleCnn(n_classes) # model class
model_1.load_state_dict(torch.load('./model.pth', map_location='cuda:0'))
model_1.to(DEVICE) # torch cuda device
history = train(train_dataset, val_dataset, model=model_1, epochs=8, batch_size=16) # train function
torch.save(model_1.state_dict(), 'model_1.pth')
Solution
To continue training, you need to save not only the model's state_dict but the optimizer's as well.
That is, during training you need to save not only the model's trained weights but also the other state that training depends on. For example:
def train_function(...):
    for e in range(num_epochs):
        ...
    # when done training
    torch.save({'model': model.state_dict(),
                'opt': optimizer.state_dict(),
                'lr': lr_sched.state_dict(),  # you also need to save the state of the learning rate scheduler
                # ... there might be other things that define the "state" of your training
                }, PATH)
Then, when you resume training:
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['opt'])
lr_sched.load_state_dict(checkpoint['lr'])
... # might need to restore other things
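The save and restore steps above can be checked end to end with a small round trip. The following is only a minimal sketch under stated assumptions: the helper names save_checkpoint, load_checkpoint and build, and the extra 'epoch' entry, are hypothetical and not part of the original answer, and a tiny nn.Linear with an Adam optimizer and a StepLR scheduler stands in for the asker's SimpleCnn and training setup.

```python
import os
import tempfile

import torch
from torch import nn


def save_checkpoint(path, model, optimizer, lr_sched, epoch):
    # Hypothetical helper: bundle everything needed to resume into one dict.
    torch.save({'model': model.state_dict(),
                'opt': optimizer.state_dict(),
                'lr': lr_sched.state_dict(),
                'epoch': epoch},  # assumption: also track the last finished epoch
               path)


def load_checkpoint(path, model, optimizer, lr_sched, device='cpu'):
    # Hypothetical helper: restore all three state dicts, return the next epoch.
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['opt'])
    lr_sched.load_state_dict(checkpoint['lr'])
    return checkpoint['epoch'] + 1


def build():
    # Stand-in for the real model and training setup (assumption).
    model = nn.Linear(4, 1)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    lr_sched = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
    return model, optimizer, lr_sched


# Train for two epochs on random data, then save a checkpoint.
model, optimizer, lr_sched = build()
x, y = torch.randn(8, 4), torch.rand(8, 1)
loss_fn = nn.BCEWithLogitsLoss()
for epoch in range(2):
    optimizer.zero_grad()
    loss_fn(model(x), y).backward()
    optimizer.step()
    lr_sched.step()
path = os.path.join(tempfile.mkdtemp(), 'checkpoint.pth')
save_checkpoint(path, model, optimizer, lr_sched, epoch)

# Resume: create fresh objects first, then load the saved state into them.
new_model, new_optimizer, new_lr_sched = build()
start_epoch = load_checkpoint(path, new_model, new_optimizer, new_lr_sched)
```

Note that the optimizer and scheduler are constructed before their state is loaded, because load_state_dict fills in existing objects rather than creating them. Training would then continue from start_epoch instead of 0.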
Answered By - Shai