Issue
I am trying to port a Python PyTorch model to LibTorch in C++.
In Python, I have the following line of code inside a subclass of torch.nn.Module:
self.A = nn.Parameter(A)
where A is a torch.Tensor object with requires_grad=True.
What would be the equivalent of this for a torch::Tensor in a torch::nn::Module class in C++?
The autocomplete in my editor shows the classes ParameterDict, ParameterList, ParameterDictImpl and ParameterListImpl, but no Parameter. Do I need to wrap it in a list of size 1, or is there something else I'm missing? I wasn't able to find what I needed with a Google search or in the documentation, but to be honest I wasn't sure precisely what to search for.
Solution
To register a parameter (a tensor that requires gradients) with a module in LibTorch, you can use:
m.register_parameter("A", torch::ones({20, 1, 5, 5}), true);
The third argument sets requires_grad and defaults to true, so it can be omitted.
Answered By - IntegrateThis