Issue
For simplicity, say I want to set all parameters of a PyTorch model to the constant 72114982 with this code:
import torch

model = Net()
params = model.state_dict()
for k, v in params.items():
    params[k] = torch.full(v.shape, 72114982, dtype=torch.long)
model.load_state_dict(params)
print(model.state_dict().values())
Then the print statement shows that all values actually get set to 72114984, which is 2 off from the value I originally intended.
For simplicity, define Net as follows:
import torch.nn as nn

class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(2, 2, 2)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(2, 2)
Solution
This is an issue of data types. Model parameters are float32 tensors, so the values you load into them are cast to float32. A float32 has only a 24-bit significand, which means that between 2^26 and 2^27 it can only represent multiples of 8; 72114982 therefore rounds to the nearest representable value, 72114984.
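As a quick check of that spacing claim (this snippet is an addition, not part of the original answer), torch.nextafter shows that the next representable float32 value above 72114984 is 8 away:

import torch

# adjacent float32 values near 7.2e7 are 8 apart, so only multiples of 8
# around this magnitude can be stored exactly
x = torch.tensor(72114984.0, dtype=torch.float32)
print(torch.nextafter(x, torch.tensor(float('inf'))))  # tensor(72114992.)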
You can verify this with the following:
x = torch.tensor(72114982, dtype=torch.long)
y = x.float() # y will actually be `72114984.0`
# this returns `True` because x is cast to a float before evaluating
x == y
> tensor(True)
# for the same reason, this returns 0.
y - x
> tensor(0.)
# this returns `False` because the tensors have different values and we don't cast to float
x == y.long()
> tensor(False)
# as longs, the difference correctly evaluates to 2
y.long() - x
> tensor(2)
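If you need the exact integer to survive the load, one option (a sketch, not from the original answer, reusing the Net class defined in the question) is to convert the model to double precision before loading; float64 has a 53-bit significand, so it represents 72114982 exactly:

import torch

model = Net().double()  # parameters are now float64 instead of float32
params = model.state_dict()
for k, v in params.items():
    params[k] = torch.full(v.shape, 72114982, dtype=torch.long)
model.load_state_dict(params)

# every parameter now holds exactly 72114982.0
print(next(iter(model.state_dict().values())))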
Answered By - Karl