Issue
I'm trying to clone a tensor in pytorch and would like to also clone the tensor attributes. Here is an example:
import torch
from torch import nn
a = nn.Parameter(torch.rand(1))
a.adapt = True # define tensor attribute
b = a.clone() # clone
In the example above, I would like print(b.adapt) to return True; however, I get the following error:
Traceback (most recent call last):
  File "scratch.py", line 13, in <module>
    print(b.adapt)
AttributeError: 'Tensor' object has no attribute 'adapt'
I'm wondering why tensor attributes are not carried over by cloning, and how to preserve them.
Solution
The function torch.Tensor.clone performs a copy of the tensor's data, not a copy of the Python object. This is why the adapt attribute of a is not available on b. Additionally, the clone is recorded by autograd: b is not a leaf tensor, and its grad_fn links it back to a, so gradients computed through b flow into a.
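A minimal sketch of both points, assuming a recent PyTorch release. It checks what clone() returns and, as one possible workaround, copies the Python attributes by hand after cloning. clone_with_attrs is a hypothetical helper written for this illustration, not a PyTorch API.

```python
import torch
from torch import nn

a = nn.Parameter(torch.rand(1))
a.adapt = True  # custom Python attribute, stored in a.__dict__

b = a.clone()
print(type(b).__name__)     # Tensor: clone returns a plain tensor, not a Parameter
print(hasattr(b, "adapt"))  # False: Python attributes are not copied
print(b.grad_fn)            # a clone backward node: b is still attached to a's graph

b.sum().backward()
print(a.grad)               # the gradient flowed from b back to a


def clone_with_attrs(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: clone the data, then copy the Python attributes."""
    out = t.clone()
    for name, value in vars(t).items():
        setattr(out, name, value)  # shallow copy: values are shared, not duplicated
    return out


c = clone_with_attrs(a)
print(c.adapt)  # True
```

Because the attributes are copied shallowly, a mutable attribute value (for example a list) would be shared between a and c; wrapping each value in copy.deepcopy would avoid that if it matters.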
Answered By - Ivan