Issue
I have a subclass of torch.nn.Module whose initialiser has the following form (in class A):
    def __init__(self, additional_layer=False):
        ...
        if additional_layer:
            self.additional = nn.Sequential(nn.Linear(8, 3)).to(self.device)
        else:
            self.additional = None
        ...
    ...
I train with additional_layer=True and save the model with torch.save; the object I save is model.state_dict(). Then I load the model for inference, but I get the following error:
model.load_state_dict(best_model["my_model"])
RuntimeError: Error(s) in loading state_dict for A:
Unexpected key(s) in state_dict: "additional.0.weight"
Is using an optional field which can be None disallowed? How do I handle this properly?
Solution
This is not a problem specifically related to the value being None; you'd hit the same error with any value of the additional attribute whose structure doesn't match the saved keys. The key "additional.0.weight" encodes the module structure: additional is the attribute name, 0 is the index of the first submodule inside the nn.Sequential, and weight is a parameter of that first submodule. Loading fails whenever the model you load into has no matching parameter at that path.
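To make the key naming concrete, here is a minimal sketch of a model like the one in the question (the class name A and the Sequential(Linear(8, 3)) layer come from the question; the base layer and its sizes are assumed for illustration, and the .to(self.device) call is omitted):

```python
import torch.nn as nn

class A(nn.Module):
    def __init__(self, additional_layer=False):
        super().__init__()
        self.base = nn.Linear(4, 8)  # hypothetical base layer, for illustration only
        if additional_layer:
            # attribute "additional" -> Sequential index "0" -> parameter "weight"
            self.additional = nn.Sequential(nn.Linear(8, 3))
        else:
            self.additional = None

model = A(additional_layer=True)
print(list(model.state_dict().keys()))
# includes 'additional.0.weight' and 'additional.0.bias'
```

Each dot-separated segment of a state_dict key is an attribute name (or a Sequential index) on the path from the root module down to the parameter.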
The issue is that during training, when you initialized your model, you passed True for the additional_layer argument, i.e.:
model = YourModelClass(additional_layer=True)
hence self.additional was set to an nn.Module (an nn.Sequential, specifically). So the model object's state_dict contains the parameters of the module referred to by the self.additional attribute.
Now, when you re-initialized the model for inference, you didn't have the additional layer, since you presumably initialized the model in one of the following ways:
model = YourModelClass(additional_layer=False)
model = YourModelClass()
This time the self.additional attribute, i.e. model.additional, is None. As a result, when you call model.load_state_dict and pass it the state dict that was saved earlier during training (when the additional layer was there), it raises an exception reporting the keys for the additional attribute as unexpected, because the current model has no submodule to receive them.
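The mismatch can be reproduced with the same toy model (again a self-contained sketch; the base layer and sizes are assumed, only the additional layer mirrors the question):

```python
import torch.nn as nn

class A(nn.Module):
    def __init__(self, additional_layer=False):
        super().__init__()
        self.base = nn.Linear(4, 8)  # hypothetical base layer
        if additional_layer:
            self.additional = nn.Sequential(nn.Linear(8, 3))
        else:
            self.additional = None

# Save from a model that HAS the layer, load into one that does NOT
state = A(additional_layer=True).state_dict()
model = A()  # additional_layer defaults to False

try:
    model.load_state_dict(state)  # strict=True is the default
except RuntimeError as e:
    print(e)  # Unexpected key(s) in state_dict: "additional.0.weight", ...
```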
Assuming you have the correct conditional setup in the forward method for when self.additional is None, you can bypass the error and skip the mismatched keys by setting the strict argument to False when calling load_state_dict:
model.load_state_dict(best_model["my_model"], strict=False)
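With strict=False, load_state_dict does not raise; instead it returns a named tuple whose missing_keys and unexpected_keys fields report what was skipped, which is worth inspecting so you don't silently ignore a real mismatch. A self-contained sketch (base layer and sizes assumed, as before):

```python
import torch.nn as nn

class A(nn.Module):
    def __init__(self, additional_layer=False):
        super().__init__()
        self.base = nn.Linear(4, 8)  # hypothetical base layer
        if additional_layer:
            self.additional = nn.Sequential(nn.Linear(8, 3))
        else:
            self.additional = None

state = A(additional_layer=True).state_dict()
model = A(additional_layer=False)

# strict=False skips keys that have no matching parameter in the model
result = model.load_state_dict(state, strict=False)
print(result.unexpected_keys)  # e.g. ['additional.0.weight', 'additional.0.bias']
print(result.missing_keys)     # empty here: all of the model's own params were loaded
```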
Answered By - heemayl