Issue
I'm trying to understand why I cannot directly overwrite the weights of a torch layer. Consider the following example:
import torch
from torch import nn
net = nn.Linear(3, 1)
weights = torch.zeros(1,3)
# Overwriting does not work
net.state_dict()["weight"] = weights # nothing happens
print(f"{net.state_dict()['weight']=}")
# But mutating does work
net.state_dict()["weight"][0] = weights # indexing works
print(f"{net.state_dict()['weight']=}")
#########
# output
# net.state_dict()['weight']=tensor([[ 0.5464, -0.4110, -0.1063]])
# net.state_dict()['weight']=tensor([[0., 0., 0.]])
I'm confused, since state_dict()["weight"] is just a torch tensor, so I feel I'm missing something really obvious here.
Solution
This is because every call to net.state_dict() first creates a new collections.OrderedDict object, stores references to the module's weight tensor(s) in it, and returns that dict:
state_dict = net.state_dict()
print(type(state_dict)) # <class 'collections.OrderedDict'>
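A quick sketch to verify this: each call returns a different dict object, but the entries still point at the same underlying tensor storage (which is why in-place edits show up in the module).

```python
import torch
from torch import nn

net = nn.Linear(3, 1)

d1 = net.state_dict()
d2 = net.state_dict()

# Each call builds a brand-new OrderedDict...
print(d1 is d2)  # False

# ...but both 'weight' entries share the same underlying memory
print(d1["weight"].data_ptr() == d2["weight"].data_ptr())  # True
```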
When you "overwrite" (it's in fact not an overwrite; it's plain Python assignment) an entry of this ordered dict, you rebind the key 'weight' of the dict to the new zeros tensor. The data in the original weight tensor is not modified; it is simply no longer referred to by that dict.
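To see that the assignment only rebinds the dict entry, compare the dict you modified with the module's actual parameter:

```python
import torch
from torch import nn

net = nn.Linear(3, 1)

sd = net.state_dict()
sd["weight"] = torch.zeros(1, 3)  # rebinds the dict key only

print(sd["weight"])  # zeros: the dict entry was replaced
print(net.weight)    # the module's parameter is untouched
```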
When you check whether the tensor is modified by:
print(f"{net.state_dict()['weight']}")
a new ordered dict, different from the one you modified, is created, so you see the unchanged weight tensor.
However, when you use indexing like this:
net.state_dict()["weight"][0] = weights # indexing works
then it's not assignment to the ordered dict anymore. Instead, the __setitem__
method of the tensor is called, which lets you access and modify the underlying memory in place. Other in-place tensor APIs such as copy_
can achieve the same result.
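For completeness, here is a sketch of two common ways to actually overwrite the weights (the torch.no_grad() context is optional here since state_dict() returns detached tensors, but it is a safe habit):

```python
import torch
from torch import nn

net = nn.Linear(3, 1)
weights = torch.zeros(1, 3)

# Option 1: copy into the tensor in place; the state_dict entry
# shares storage with net.weight, so the module sees the change
with torch.no_grad():
    net.state_dict()["weight"].copy_(weights)
print(net.weight)  # now zeros

# Option 2: modify a state dict and load it back into the module
sd = net.state_dict()
sd["weight"] = torch.ones(1, 3)
net.load_state_dict(sd)
print(net.weight)  # now ones
```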
A clear explanation of the difference between a = b
and a[:] = b
when a
is a tensor/array can be found here: https://stackoverflow.com/a/68978622/11790637
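The same distinction in miniature (a hypothetical two-variable example, not from the linked answer):

```python
import torch

a = torch.zeros(3)
b = torch.ones(3)

backup = a    # another name for the same tensor
a[:] = b      # in-place: writes into a's existing memory
print(backup) # tensor([1., 1., 1.]) — backup sees the change

a = torch.zeros(3)
backup = a
a = b         # rebinding: the name 'a' now refers to b
print(backup) # tensor([0., 0., 0.]) — original memory untouched
```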
Answered By - ihdv