Issue
I usually use fastai (v2 or v1) for fast prototyping. Now I'd like to deploy one of my models, trained with fastai, to TorchServe.
Let's say that we have a simple model like this one:
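import torch
from fastai.vision.all import cnn_learner, accuracy, error_rate  # fastai v2 (v1: from fastai.vision import *)
from torchvision import models
# `data` (the DataLoaders) and `score` (a custom metric) are defined elsewhere
learn = cnn_learner(data,
                    models.resnet34,
                    metrics=[accuracy, error_rate, score])
# after the training
torch.save(learn.model.state_dict(), "./test1.pth")
state = torch.load("./test1.pth")
model_torch_rep = models.resnet34()
model_torch_rep.load_state_dict(state)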
I've tried many different things, always with the same result:
RuntimeError Traceback (most recent call last)
<ipython-input-284-e4dbdce23d43> in <module>
----> 1 model_torch_rep.load_state_dict(state);
/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
837 if len(error_msgs) > 0:
838 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
--> 839 self.__class__.__name__, "\n\t".join(error_msgs)))
840 return _IncompatibleKeys(missing_keys, unexpected_keys)
841
RuntimeError: Error(s) in loading state_dict for ResNet:
Missing key(s) in state_dict: "conv1.weight", "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "layer1.0.conv1.weight", "layer1.0.bn1.weight"
This happens with fastai 1.0.6 and with fastai 2.3.1 + PyTorch 1.8.1.
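A quick way to see what kind of mismatch this is: compare the key names stored in the file with the key names the target model expects. The helper below is a minimal, hypothetical sketch in plain Python; the key lists are made up for illustration and stand in for `state.keys()` and `model_torch_rep.state_dict().keys()`.

```python
def diff_state_dict_keys(saved_keys, expected_keys):
    """Return (missing, unexpected) key lists, mirroring the two
    categories reported in load_state_dict's error message."""
    saved, expected = set(saved_keys), set(expected_keys)
    missing = sorted(expected - saved)     # the model needs these, the file lacks them
    unexpected = sorted(saved - expected)  # the file has these, the model does not
    return missing, unexpected


# Hypothetical key names, for illustration only.
saved_keys = ["module.conv1.weight", "module.bn1.weight"]
expected_keys = ["conv1.weight", "bn1.weight"]

missing, unexpected = diff_state_dict_keys(saved_keys, expected_keys)
print("missing:", missing)
print("unexpected:", unexpected)
```

If every unexpected key is a missing key plus a common prefix, the fix is a rename. PyTorch can also report this without raising: as the `_IncompatibleKeys(missing_keys, unexpected_keys)` return value in the traceback shows, `load_state_dict(state, strict=False)` returns the same two lists.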
Solution
Just figured this out.
For some reason, the way you save the state_dict adds the string "module." to the start of each key in the loaded state_dict (I assume this is because you aren't using FastAI's Learner class to save the model). A "module." prefix is also exactly what torch.nn.DataParallel adds to every key of a model wrapped in it, e.g. for multi-GPU training.
Simply remove the "module." prefix from the keys and you're all good.
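import torch
from collections import OrderedDict
from fastai.vision.all import cnn_learner, accuracy, error_rate  # fastai v2 (v1: from fastai.vision import *)
from torchvision import models
# `data` (the DataLoaders) and `score` (a custom metric) are defined elsewhere
learn = cnn_learner(data,
                    models.resnet34,
                    metrics=[accuracy, error_rate, score])
# after the training
torch.save(learn.model.state_dict(), "./test1.pth")
state = torch.load("./test1.pth")
# fix dict keys: strip a leading "module." and leave all other keys untouched
# (k.partition('module.')[2] would turn keys without the prefix into '')
new_state = OrderedDict([(k[len('module.'):] if k.startswith('module.') else k, v)
                         for k, v in state.items()])
model_torch_rep = models.resnet34()
model_torch_rep.load_state_dict(new_state)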
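The key-renaming step can be factored into a small reusable helper. This is a minimal sketch (the name `strip_key_prefix` is hypothetical, not part of PyTorch or fastai). It only strips the prefix from keys that actually start with it, so a mixed or already-clean state_dict passes through safely. It uses `startswith` rather than `str.removeprefix`, because the latter only exists from Python 3.9 and the traceback above comes from Python 3.6. Plain integers stand in for tensors so the example runs without PyTorch.

```python
from collections import OrderedDict


def strip_key_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with `prefix` removed from every key
    that starts with it; all other keys are kept unchanged."""
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )


# Hypothetical keys; the integer values stand in for tensors.
state = OrderedDict([
    ("module.conv1.weight", 1),
    ("module.bn1.bias", 2),
    ("fc.weight", 3),  # no prefix: must survive as-is
])
print(list(strip_key_prefix(state).keys()))
```

With real weights, the result is passed straight to `model_torch_rep.load_state_dict(...)`. Key order is preserved, which keeps the output easy to compare against `model_torch_rep.state_dict().keys()`.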
Answered By - Benjamin Kolber