Issue
Let's say I wanted to multiply all parameters of a neural network in PyTorch (an instance of a class inheriting from torch.nn.Module) by 0.9. How would I do that?
Solution
Let net be an instance of a neural network nn.Module. Then, to multiply all parameters by 0.9:
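state_dict = net.state_dict()
for name, param in state_dict.items():
    # Transform the parameter as required.
    transformed_param = param * 0.9
    # Update the parameter in place. The tensors returned by state_dict()
    # share storage with the module, so this modifies net itself.
    param.copy_(transformed_param)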
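The loop above works because, by default, state_dict() returns detached tensors that share storage with the module, so copy_ writes straight into the model. Note that state_dict() also contains buffers (for example BatchNorm running statistics), not just learnable parameters. If you only want the learnable parameters, a minimal sketch is to iterate over net.parameters() and scale in place inside torch.no_grad(). The small nn.Sequential network here is a hypothetical stand-in for your own net:

```python
import torch
import torch.nn as nn

# Hypothetical example network; any nn.Module works the same way.
torch.manual_seed(0)
net = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

# Keep copies of the original values so the effect can be checked.
before = {name: p.detach().clone() for name, p in net.named_parameters()}

# Scale only the learnable parameters (buffers are not included).
# no_grad() is required because autograd rejects in-place operations
# on leaf tensors that require grad.
with torch.no_grad():
    for param in net.parameters():
        param.mul_(0.9)

for name, p in net.named_parameters():
    print(name, torch.allclose(p, before[name] * 0.9))
```

Each line should report True, and the parameters still have requires_grad set, so training can continue normally afterwards.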
If you only want to update the weights rather than every parameter:
state_dict = net.state_dict()
for name, param in state_dict.items():
    # Don't update if this is not a weight.
    if "weight" not in name:
        continue
    # Transform the parameter as required.
    transformed_param = param * 0.9
    # Update the parameter in place.
    param.copy_(transformed_param)
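The same name filter can be applied to named_parameters(), which skips buffers entirely. This is a sketch using a hypothetical network that includes a BatchNorm layer. Keep in mind the test is a plain substring match on the name, so BatchNorm's affine "weight" is scaled too, while every "bias" is left alone:

```python
import torch
import torch.nn as nn

# Hypothetical example network with a BatchNorm layer.
torch.manual_seed(0)
net = nn.Sequential(nn.Linear(4, 3), nn.BatchNorm1d(3), nn.Linear(3, 2))
before = {name: p.detach().clone() for name, p in net.named_parameters()}

with torch.no_grad():
    for name, param in net.named_parameters():
        # Substring match on the name, as in the state_dict loop;
        # this also matches BatchNorm's "1.weight".
        if "weight" not in name:
            continue
        param.mul_(0.9)

for name, p in net.named_parameters():
    changed = not torch.equal(p, before[name])
    print(name, "changed" if changed else "unchanged")
```

If you need a stricter selection (say, only nn.Linear weights), filtering on the module type via net.modules() is an alternative to matching on names.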
Answered By - the-bass