Issue
I want to use a pretrained model as the encoder part of my model. Here is a version of my model:
import os

import torch
import torch.nn as nn


class MyClass(nn.Module):
    def __init__(self, pretrained=False):
        super(MyClass, self).__init__()
        # S3D_featureExtractor_multi_output is defined elsewhere in my code.
        self.encoder = S3D_featureExtractor_multi_output()
        if pretrained:
            # Copy the pretrained weights into the encoder, matching
            # parameters by their position in the two state dicts.
            weight_dict = torch.load(os.path.join('models', 'weights.pt'))
            model_dict = self.encoder.state_dict()
            list_weight_dict = list(weight_dict.items())
            list_model_dict = list(model_dict.items())
            for i in range(len(list_model_dict)):
                assert list_model_dict[i][1].shape == list_weight_dict[i][1].shape
                model_dict[list_model_dict[i][0]].copy_(weight_dict[list_weight_dict[i][0]])
            # Verify that every parameter was copied correctly.
            for i in range(len(list_model_dict)):
                assert torch.all(torch.eq(model_dict[list_model_dict[i][0]],
                                          weight_dict[list_weight_dict[i][0]].to('cpu')))
            print('Loading finished!')

    def forward(self, x):
        a, b = self.encoder(x)
        return a, b
Because I modified some parts of the code of this pretrained model, based on this post I need to pass strict=False to avoid an error. However, with the way I currently load the pretrained weights, there is no place in the code where strict=False can be applied. How can I apply it, or how can I change the way I load the pretrained model so that strict=False can be used?
Solution
strict=False is an argument you pass to the load_state_dict() method. A state_dict is simply a Python dictionary that maps each layer's parameter names to their tensors, which is what lets you save and load model weights (for more details, see https://pytorch.org/tutorials/recipes/recipes/what_is_state_dict.html).
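For illustration, here is a minimal sketch of what a state_dict looks like, using a small toy model (not the model from the question):

import torch.nn as nn

# A tiny toy model, purely for illustration.
toy = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# state_dict() maps parameter names to their tensors.
for name, tensor in toy.state_dict().items():
    print(name, tuple(tensor.shape))
# prints: 0.weight (8, 4), 0.bias (8,), 2.weight (2, 8), 2.bias (2,)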
If you use strict=False in load_state_dict, you inform PyTorch that the target model and the original model are not identical, so it only loads the weights of layers that are present in both state dicts and ignores the rest (see https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict).
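To make that behaviour concrete, here is a minimal sketch with two deliberately mismatched toy models (not the models from the question): only the parameters whose names match are copied, and the rest are reported back by the return value of load_state_dict:

# Two toy models whose layers only partially overlap.
source = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
target = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2), nn.Linear(2, 2))

# strict=False loads the matching layers and reports the rest.
result = target.load_state_dict(source.state_dict(), strict=False)
print(result.missing_keys)     # ['2.weight', '2.bias'] -> left at their initial values
print(result.unexpected_keys)  # [] -> nothing in the source was ignored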
So, you need to specify the strict argument at the point where you load the pretrained weights, which is exactly where load_state_dict is called. Since the model whose weights must be loaded is self.encoder, and since the object you load from disk is itself a state_dict, you can just do this:
loaded_weights = torch.load(os.path.join('models','weights.pt'))
self.encoder.load_state_dict(loaded_weights, strict=False)
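Folded back into the question's class, the manual copy loop in __init__ can then be replaced by that single call; a minimal sketch, assuming the same S3D_featureExtractor_multi_output encoder and weights.pt file as in the question:

class MyClass(nn.Module):
    def __init__(self, pretrained=False):
        super(MyClass, self).__init__()
        self.encoder = S3D_featureExtractor_multi_output()
        if pretrained:
            loaded_weights = torch.load(os.path.join('models', 'weights.pt'))
            # strict=False skips keys that exist on only one side
            # instead of raising an error.
            self.encoder.load_state_dict(loaded_weights, strict=False)
            print('Loading finished!')

    def forward(self, x):
        a, b = self.encoder(x)
        return a, b

If you want to verify what was skipped, inspect the missing_keys and unexpected_keys returned by load_state_dict, as shown above.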
For more details and a tutorial, see https://pytorch.org/tutorials/beginner/saving_loading_models.html.
Answered By - inarighas