Issue
How can I copy the parameters of one model to another in LibTorch? I know how to do it in PyTorch (Python):
net2.load_state_dict(net.state_dict())
I tried the C++ code below, which took quite a bit of work, but it did not copy the parameters from one model to the other.
I don't see a built-in option for copying the parameters of one model into another model of the same type.
#include <torch/torch.h>
#include <iostream>
using namespace torch::indexing;
torch::Device device(torch::kCUDA);
void loadstatedict(torch::nn::Module& model, torch::nn::Module& target_model) {
    torch::autograd::GradMode::set_enabled(false); // make parameters copying possible
    auto new_params = target_model.named_parameters(); // implement this
    auto params = model.named_parameters(true /*recurse*/);
    auto buffers = model.named_buffers(true /*recurse*/);
    for (auto& val : new_params) {
        auto name = val.key();
        auto* t = params.find(name);
        if (t != nullptr) {
            t->copy_(val.value());
        } else {
            t = buffers.find(name);
            if (t != nullptr) {
                t->copy_(val.value());
            }
        }
    }
}
struct Critic_Net : torch::nn::Module {
    torch::Tensor next_state_batch__sampled_action;
public:
    Critic_Net() {
        lin1 = torch::nn::Linear(3, 3);
        lin2 = torch::nn::Linear(3, 1);
        lin1->to(device);
        lin2->to(device);
    }
    torch::Tensor forward(torch::Tensor next_state_batch__sampled_action) {
        auto h = next_state_batch__sampled_action;
        h = torch::relu(lin1->forward(h));
        h = lin2->forward(h);
        return h;
    }
    torch::nn::Linear lin1{nullptr}, lin2{nullptr};
};
auto net = Critic_Net();
auto net2 = Critic_Net();
auto the_ones = torch::ones({3, 3}).to(device);
int main() {
    std::cout << net.forward(the_ones);
    std::cout << net2.forward(the_ones);
    loadstatedict(net, net2);
    std::cout << net.forward(the_ones);
    std::cout << net2.forward(the_ones);
}
Solution
Your solution with loadstatedict should work, if I understand correctly. The problem here is the same as in your previous question: nothing is registered as a parameter, buffer, or submodule, so named_parameters() and named_buffers() return empty collections and the loop has nothing to copy. Add the register_module calls and it should work fine.
This is how the class should look:
struct Critic_Net : torch::nn::Module {
public:
    Critic_Net() {
        // Each submodule needs its own unique name; registering the same name twice fails.
        lin1 = register_module("lin1", torch::nn::Linear(427, 42));
        lin2 = register_module("lin2", torch::nn::Linear(42, 286));
        lin3 = register_module("lin3", torch::nn::Linear(286, 1));
    }
    torch::Tensor forward(torch::Tensor next_state_batch__sampled_action) {
        // unchanged
    }
    torch::nn::Linear lin1{nullptr}, lin2{nullptr}, lin3{nullptr};
};
Answered By - trialNerror