Issue
I'm attempting to save and load the best model with PyTorch, and I've defined my training function as follows:
def train_model(model, train_loader, test_loader, device, learning_rate=1e-1, num_epochs=200,
                checkpoint_path=None, checkpoint_every=2):
    """Train ``model`` with SGD + MultiStepLR, optionally resuming from a checkpoint.

    Args:
        model: Network to train; moved to ``device`` in place.
        train_loader: DataLoader yielding ``(inputs, labels)`` training batches.
        test_loader: DataLoader handed to ``evaluate_model`` after every epoch.
        device: Device the model and every batch are moved to.
        learning_rate: Initial SGD learning rate.
        num_epochs: Total number of epochs, counted from 0. When resuming, the
            epochs already stored in the checkpoint count toward this total.
        checkpoint_path: Optional path of a checkpoint produced by
            ``write_checkpoint``. When given, the model/optimizer/scheduler
            states are restored from it and the loop starts at the stored epoch.
        checkpoint_every: Write a checkpoint after every ``checkpoint_every``
            completed epochs. ``0`` or ``None`` disables checkpointing.

    Returns:
        The trained model.
    """
    # The training configurations were not carefully selected.
    criterion = nn.CrossEntropyLoss()
    model.to(device)
    # It seems that SGD optimizer is better than Adam optimizer for ResNet18 training on CIFAR10.
    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[65, 75], gamma=0.75, last_epoch=-1)

    start_epoch = 0
    if checkpoint_path is not None:
        # Restore the states BEFORE the loop so that range() can begin at
        # start_epoch. Reassigning the loop variable inside the loop (as the
        # original did) has no effect on where range() continues.
        model, optimizer, start_epoch, loaded_scheduler = load_checkpoint(
            model=model, optimizer=optimizer, scheduler=scheduler, filename=checkpoint_path)
        if isinstance(loaded_scheduler, dict):
            # Older load_checkpoint versions return the scheduler's raw state
            # dict instead of the scheduler object; apply it in place.
            scheduler.load_state_dict(loaded_scheduler)
        else:
            scheduler = loaded_scheduler
        # Keep optimizer buffers (e.g. momentum) on the training device.
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)

    # Evaluation before (further) training; labelled with the last completed epoch.
    model.eval()
    eval_loss, eval_accuracy = evaluate_model(model=model, test_loader=test_loader, device=device, criterion=criterion)
    print("Epoch: {:02d} Eval Loss: {:.3f} Eval Acc: {:.3f}".format(start_epoch - 1, eval_loss, eval_accuracy))

    dataset_size = len(train_loader.dataset)
    for epoch in range(start_epoch, num_epochs):
        # Training
        model.train()
        running_loss = 0.0
        running_corrects = 0
        for inputs, labels in train_loader:
            inputs = torch.as_tensor(inputs, dtype=torch.float32, device=device)
            labels = labels.to(device)
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward + backward + optimize
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # statistics (plain Python numbers, not tensors)
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels).item()
        train_loss = running_loss / dataset_size
        train_accuracy = running_corrects / dataset_size
        # Evaluation
        model.eval()
        eval_loss, eval_accuracy = evaluate_model(model=model, test_loader=test_loader, device=device, criterion=criterion)
        # Set learning rate scheduler
        scheduler.step()
        print("Epoch: {:03d} Train Loss: {:.3f} Train Acc: {:.3f} Eval Loss: {:.3f} Eval Acc: {:.3f}".format(epoch, train_loss, train_accuracy, eval_loss, eval_accuracy))
        # Checkpoint at the END of the epoch, after scheduler.step(). This
        # call stores epoch + 1, which is then exactly the next epoch to run
        # when the checkpoint is passed back in as checkpoint_path.
        if checkpoint_every and (epoch + 1) % checkpoint_every == 0:
            write_checkpoint(model=model, epoch=epoch, scheduler=scheduler, optimizer=optimizer)
    return model
I'd like to be able to load a saved model and resume training from the epoch at which it was saved.
So far I have functions that save the model, optimizer and scheduler states together with the epoch:
def write_checkpoint(model, optimizer, epoch, scheduler, directory='/content'):
    """Save model, optimizer and scheduler state for epoch ``epoch``.

    The file is written to ``<directory>/model_CP_epoch{epoch + 1}.pth``. The
    stored ``'epoch'`` entry is ``epoch + 1``: if this is called after epoch
    ``epoch`` has finished, it is the next epoch to run when resuming.

    Args:
        model: Module whose ``state_dict()`` is saved under ``'state_dict'``.
        optimizer: Optimizer whose ``state_dict()`` is saved under ``'optimizer'``.
        epoch: Index of the epoch being checkpointed.
        scheduler: LR scheduler whose ``state_dict()`` is saved under ``'scheduler'``.
        directory: Target directory (created if missing). Defaults to the
            original hard-coded ``/content``.

    Returns:
        The path of the written checkpoint file.
    """
    import os

    os.makedirs(directory, exist_ok=True)
    state = {
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
    }
    path = os.path.join(directory, f'model_CP_epoch{epoch + 1}.pth')
    torch.save(state, path)
    return path
def load_checkpoint(model, optimizer, scheduler, filename='/content/checkpoint.pth', map_location=None):
    """Restore model, optimizer and scheduler state from ``filename`` in place.

    Args:
        model: Pre-built module; its parameters are overwritten from ``'state_dict'``.
        optimizer: Pre-built optimizer; restored from ``'optimizer'``.
        scheduler: Pre-built LR scheduler; restored from ``'scheduler'`` via
            ``load_state_dict`` (the same object is returned, not a dict).
        filename: Checkpoint path. If it is not an existing file, nothing is
            loaded and ``start_epoch`` stays 0.
        map_location: Forwarded to ``torch.load``, e.g. ``'cpu'`` to load a
            GPU-saved checkpoint on a CPU-only machine. ``None`` keeps the
            default behavior.

    Returns:
        ``(model, optimizer, start_epoch, scheduler)``, where ``start_epoch``
        is the checkpoint's stored ``'epoch'`` value (0 when no file was found).
    """
    # Note: Input model & optimizer should be pre-defined. This routine only updates their states.
    start_epoch = 0
    if os.path.isfile(filename):
        print("=> loading checkpoint '{}'".format(filename))
        # torch.load unpickles data; only load checkpoints from trusted sources.
        checkpoint = torch.load(filename, map_location=map_location)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Restore the scheduler in place; assigning the raw dict would replace
        # the scheduler object and break the next scheduler.step().
        scheduler.load_state_dict(checkpoint['scheduler'])
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(filename, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(filename))
    return model, optimizer, start_epoch, scheduler
However, I can't work out how to make the training loop start at the correct epoch after loading. I'm looking for hints or ideas on how to implement just that.
Solution
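If I understand correctly, you are trying to resume training from your last saved progress with the correct epoch number.
Before calling `train_model`, load the checkpoint values, including `start_epoch`. Because `write_checkpoint` stores `epoch + 1`, this value is already the next epoch to run, provided the checkpoint was written after that epoch finished. Pass `start_epoch` into `train_model` and use it as the loop's starting point: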
for epoch in range(start_epoch, num_epochs):
Answered By - B200011011
0 comments:
Post a Comment
Note: Only a member of this blog may post a comment.