Issue
Assume I have two models in PyTorch. How can I load the weights of model 2 into model 1 without saving the weights to disk?
Like this:
model1.weights = model2.weights
In TensorFlow I can do this:
variables1 = model1.trainable_variables
variables2 = model2.trainable_variables
for v1, v2 in zip(variables1, variables2):
    v1.assign(v2.numpy())
Solution
Assuming you have two instances of the same model (which must subclass nn.Module), you can use nn.Module.state_dict() and nn.Module.load_state_dict(). You can find a brief introduction to state dictionaries here.
model1.load_state_dict(model2.state_dict())
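As a minimal end-to-end sketch (the `Net` class below is a hypothetical model invented for illustration; any nn.Module subclass works the same way):

```python
import torch
import torch.nn as nn

# Hypothetical toy model; stands in for any nn.Module subclass.
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

model1 = Net()
model2 = Net()

# Copy model2's parameters and buffers into model1 in memory,
# with no file I/O involved.
model1.load_state_dict(model2.state_dict())

# The corresponding parameters are now equal.
for p1, p2 in zip(model1.parameters(), model2.parameters()):
    assert torch.equal(p1, p2)
```

Note that `state_dict()` returns references to the live tensors, so if you want the two models to remain fully independent afterwards, pass a deep copy instead, e.g. `model1.load_state_dict(copy.deepcopy(model2.state_dict()))`.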
Answered By - jodag