Issue
I have a Python script that trains and then tests a CNN model. After testing, I save the model weights/parameters with:
checkpoint = {'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}
torch.save(checkpoint, path + filename)
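A quick way to see what that call writes is to reload the checkpoint and print its keys. The sketch below is only illustrative. It uses a small, hypothetical nn.Sequential model in place of the question's VGG network and an in-memory io.BytesIO buffer in place of path + filename. The nesting should look the same for any model: the parameters sit one level down, under 'state_dict'.

```python
import io

import torch
import torch.nn as nn

# Hypothetical stand-in for the question's VGG-based model.
model = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3),  # 3x8x8 input -> 4x6x6 feature map
    nn.Flatten(),
    nn.Linear(4 * 6 * 6, 2),
)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Same structure as in the question: model and optimizer state wrapped in one dict.
checkpoint = {"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()}
buffer = io.BytesIO()  # in-memory file instead of path + filename
torch.save(checkpoint, buffer)
buffer.seek(0)

loaded = torch.load(buffer)
top_level_keys = sorted(loaded.keys())
param_keys = sorted(loaded["state_dict"].keys())
print(top_level_keys)  # ['optimizer', 'state_dict']
print(param_keys)      # ['0.bias', '0.weight', '2.bias', '2.weight']
```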
After saving, I immediately recreate the model with a helper function:
model_load = create_model(cnn_type="vgg", numberofclasses=len(cases))
Then I load the weights/parameters with:
model_load.load_state_dict(torch.load(filePath + filename), strict=False)
model_load.eval()
Finally, I feed this model the same testing data I used before the model was saved.
The problem is that the test results after loading do not match the test results from before saving. My hunch is that, because of strict = False, some of the parameters are not being loaded into the model. However, when I set strict = True, I receive errors. Is there a workaround for this?
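The hunch can be checked directly. load_state_dict returns a named tuple whose missing_keys and unexpected_keys fields list every mismatch, even when strict=False suppresses the error. This is a minimal sketch, again assuming a small hypothetical model (make_model stands in for create_model) and an in-memory buffer instead of a file.

```python
import io

import torch
import torch.nn as nn


def make_model() -> nn.Module:
    # Hypothetical stand-in for create_model(cnn_type="vgg", ...) from the question.
    return nn.Sequential(nn.Conv2d(3, 4, kernel_size=3), nn.Flatten(), nn.Linear(4 * 6 * 6, 2))


torch.manual_seed(0)
model = make_model()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
buffer = io.BytesIO()
torch.save({"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()}, buffer)
buffer.seek(0)

torch.manual_seed(1)  # different seed, so the fresh model gets different random weights
model_load = make_model()

# Passing the whole checkpoint: no top-level key matches a parameter name,
# so strict=False silently copies nothing.
result = model_load.load_state_dict(torch.load(buffer), strict=False)
print(result.missing_keys)             # every parameter name of model_load
print(sorted(result.unexpected_keys))  # ['optimizer', 'state_dict']

# Compare the parameters tensor by tensor.
same = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(), model_load.state_dict().values())
)
print(same)  # False: model_load still has its own random initialization
```

If missing_keys lists every parameter of the model, nothing was loaded. The reloaded model is then still running on its random initial weights.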
The error message is:
RuntimeError: Error(s) in loading state_dict for CNN:
Missing key(s) in state_dict: "linear.weight", "linear.bias", "linear2.weight", "linear2.bias", "linear3.weight", "linear3.bias".
Unexpected key(s) in state_dict: "state_dict", "optimizer".
Solution
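You are passing the whole checkpoint dictionary, which contains both your model's state and the optimizer's state, to load_state_dict. Its only top-level keys are "state_dict" and "optimizer", and neither matches a parameter name. With strict=False those mismatches are silently ignored, so no weights are loaded at all, which is why your results differ. As your error stack trace shows, you need to pass the nested model state instead: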
>>> model_state = torch.load(filePath+filename)['state_dict']
>>> model_load.load_state_dict(model_state, strict=True)
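Below is a minimal end-to-end sketch of that fix. It assumes a small hypothetical model with BatchNorm and Dropout layers (make_model stands in for create_model) and an in-memory buffer instead of a file. The sketch loads the nested model state with strict=True, optionally restores the optimizer, and checks that eval-mode outputs match before saving and after loading.

```python
import io

import torch
import torch.nn as nn


def make_model() -> nn.Module:
    # Hypothetical stand-in for create_model(cnn_type="vgg", numberofclasses=...).
    # BatchNorm and Dropout make the output depend on calling eval().
    return nn.Sequential(
        nn.Conv2d(3, 4, kernel_size=3),  # 3x8x8 input -> 4x6x6
        nn.BatchNorm2d(4),
        nn.ReLU(),
        nn.Flatten(),
        nn.Dropout(0.5),
        nn.Linear(4 * 6 * 6, 2),
    )


torch.manual_seed(0)
model = make_model()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# One dummy training step so weights, BatchNorm statistics and momentum buffers change.
inputs = torch.randn(5, 3, 8, 8)
targets = torch.randint(0, 2, (5,))
loss = nn.functional.cross_entropy(model(inputs), targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()

model.eval()
with torch.no_grad():
    before = model(inputs)

# Save as in the question (an in-memory buffer replaces path + filename).
buffer = io.BytesIO()
torch.save({"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()}, buffer)
buffer.seek(0)

# Reload: pick the nested dictionaries out of the checkpoint before loading them.
checkpoint = torch.load(buffer)
model_load = make_model()
model_load.load_state_dict(checkpoint["state_dict"], strict=True)
optimizer_load = torch.optim.SGD(model_load.parameters(), lr=0.1, momentum=0.9)
optimizer_load.load_state_dict(checkpoint["optimizer"])  # only needed to resume training
model_load.eval()

with torch.no_grad():
    after = model_load(inputs)

print(torch.allclose(before, after))  # True
```

The optimizer state only matters if you want to continue training. For testing alone, loading checkpoint["state_dict"] and calling eval() is enough.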
Answered By - Ivan