Issue
I am trying to port some code from PyTorch to LibTorch.
Supposing in a struct inheriting from torch::nn::Module I have a registered sequential module like
branch1 = register_module("branch1", torch::nn::Sequential(
    torch::nn::Conv2d(torch::nn::Conv2dOptions(in_channels, branch_channels, kernel_size).padding(0)),
    torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(branch_channels)),
    torch::nn::ReLU()));
I would like to apply a weight-initialization function to each component separately, ideally with a different initialization algorithm per module type — say, a function that takes a torch::nn::Module (or a pointer to one). What is the simplest way to achieve this?
Edit: my current attempt:
#include <torch/torch.h>

using namespace std;

void init_conv(torch::nn::Conv2dImpl& conv) {
    torch::NoGradGuard noGrad;
    torch::nn::init::kaiming_normal_(conv.weight, 0.0, torch::kFanOut, torch::kReLU);
    torch::nn::init::constant_(conv.bias, 0);
}

void init_bn_2d(torch::nn::BatchNorm2dImpl& bn_2d) {
    torch::NoGradGuard noGrad;
    torch::nn::init::constant_(bn_2d.weight, 1);
    torch::nn::init::constant_(bn_2d.bias, 0);
}

void initialize_sequential(torch::nn::Sequential& seq) {
    // modules() returns shared_ptrs to the submodules. The stored types are
    // the Impl classes (Conv2dImpl, BatchNorm2dImpl, ...), not the
    // ModuleHolder wrappers (Conv2d, ...), so the casts must target the
    // Impl types.
    for (const shared_ptr<torch::nn::Module>& m : seq->modules()) {
        if (auto* c = dynamic_cast<torch::nn::Conv2dImpl*>(m.get())) {
            init_conv(*c);
        } else if (auto* bn = dynamic_cast<torch::nn::BatchNorm2dImpl*>(m.get())) {
            init_bn_2d(*bn);
        }
    }
}
Solution
I can use the apply() function on the sequential object like this:
#include <torch/torch.h>

// Callback for Module::apply(): dispatches on the concrete module type via
// typeid, then initializes whatever parameters that module owns.
void sequential_init_weights(torch::nn::Module& m) {
    if (typeid(m) == typeid(torch::nn::Conv2dImpl)) {
        // named_parameters(false): only this module's own parameters,
        // not those of its children.
        auto p = m.named_parameters(/*recurse=*/false);
        auto w = p.find("weight");
        auto b = p.find("bias");
        if (w != nullptr) torch::nn::init::kaiming_normal_(*w, 0.0, torch::kFanOut, torch::kReLU);
        if (b != nullptr) torch::nn::init::constant_(*b, 0.0);
    }
    if (typeid(m) == typeid(torch::nn::BatchNorm2dImpl)) {
        auto p = m.named_parameters(/*recurse=*/false);
        auto w = p.find("weight");
        auto b = p.find("bias");
        if (w != nullptr) torch::nn::init::constant_(*w, 1.0);
        if (b != nullptr) torch::nn::init::constant_(*b, 0.0);
    }
}
struct example_mod : torch::nn::Module {
    example_mod(int64_t in_channels, int64_t out_channels) {
        m = register_module("m", torch::nn::Sequential(
            torch::nn::Conv2d(torch::nn::Conv2dOptions(in_channels, out_channels, 1)),
            torch::nn::BatchNorm2d(torch::nn::BatchNorm2dOptions(out_channels)),
            torch::nn::ReLU()));
        // apply() visits this module and every submodule recursively.
        m->apply(sequential_init_weights);
    }
    torch::nn::Sequential m = nullptr;
};
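As a quick sanity check (this test is my addition; the channel counts 3 and 16 are arbitrary), constructing the module and printing the BatchNorm weight should give a tensor of ones:

#include <iostream>

int main() {
    example_mod net(3, 16);
    // Sequential names its children "0", "1", "2" in registration order,
    // so "1.weight" is the BatchNorm2d scale tensor.
    std::cout << net.m->named_parameters()["1.weight"] << std::endl;
    return 0;
}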
In short: write a function that dispatches on each module's typeid, uses named_parameters() to fetch the tensors you need, and passes them to the appropriate init routine. This seems to work well.
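If you would rather avoid typeid, here is an alternative sketch of the same callback (not the code above) using Module::as<>(), which returns a typed pointer to the underlying Impl, or nullptr when the type doesn't match:

// Alternative sketch: dispatch with Module::as<>() instead of typeid.
// as<torch::nn::Conv2d>() yields a Conv2dImpl* (nullptr if m is not a
// Conv2d), so weight/bias can be accessed as members directly.
void init_weights_as(torch::nn::Module& m) {
    if (auto* conv = m.as<torch::nn::Conv2d>()) {
        torch::nn::init::kaiming_normal_(conv->weight, 0.0, torch::kFanOut, torch::kReLU);
        // bias may be absent if the layer was built with bias(false)
        if (conv->bias.defined()) torch::nn::init::constant_(conv->bias, 0.0);
    } else if (auto* bn = m.as<torch::nn::BatchNorm2d>()) {
        torch::nn::init::constant_(bn->weight, 1.0);
        torch::nn::init::constant_(bn->bias, 0.0);
    }
}

Passing it to m->apply() works exactly the same way as sequential_init_weights.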
Answered By - IntegrateThis