Issue
I want to train a neural network with the help of two other neural networks, which are already trained and tested. The input of the network that I want to train is simultaneously fed to the first static network. The output of the network that I want to train is fed to the second static network. The loss shall be computed on the outputs of the static networks and propagated back to the network that is being trained.
# Initialization
import torch

# NeuralNetwork is my own model class; var_statemapper_input is a batch of training inputs
var_model_statemapper = NeuralNetwork(9, [('linear', 9), ('relu', None), ('dropout', 0.2), ('linear', 8)])
var_model_panda = NeuralNetwork(9, [('linear', 9), ('relu', None), ('dropout', 0.2), ('linear', 27)])
var_model_panda.load_state_dict(torch.load("panda.pth"))
var_model_ur5 = NeuralNetwork(8, [('linear', 8), ('relu', None), ('dropout', 0.2), ('linear', 24)])
var_model_ur5.load_state_dict(torch.load("ur5.pth"))
var_loss_function = torch.nn.MSELoss()
var_optimizer = torch.optim.Adam(var_model_statemapper.parameters(), lr=0.001)

# Forward Propagation
var_statemapper_output = var_model_statemapper(var_statemapper_input)  # network to be trained
var_panda_output = var_model_panda(var_statemapper_input)              # first static network, gets the same input
var_ur5_output = var_model_ur5(var_statemapper_output)                 # second static network, gets the mapper's output
var_train_loss = var_loss_function(var_panda_output, var_ur5_output)

# Backward Propagation
var_optimizer.zero_grad()
var_train_loss.backward()
var_optimizer.step()
You can see that "var_model_statemapper" is the network that shall be trained. The networks "var_model_panda" and "var_model_ur5" are initialized and their state_dicts are read from the corresponding ".pth" files, so these networks need to stay static. My main question is: which of the networks gets updated during the backward propagation? Just "var_model_statemapper", or all of them? And if "var_model_statemapper" isn't updated, how do I achieve this? And does PyTorch know which network to update just from the initialization of the optimizer?
Solution
Formalizing your pipeline to get a good idea of the setup:
x --- | state_mapper | --> y --- | ur5 | --> ur5_out
 \                                              |
  \                                             ↓
   \-- | panda | --> panda_out ----------- | loss_fn | --> loss
Here is what is happening with the lines you provided:
var_optimizer.zero_grad() # 0.
var_train_loss.backward() # 1.
var_optimizer.step() # 2.
0. Calling zero_grad on an optimizer clears the cached gradients of all parameters contained in that optimizer. In your case, var_optimizer is registered with the parameters of var_model_statemapper (the model that you want to optimize).
1. When you compute the loss and backpropagate on it via the backward call, the gradients propagate through the parameters of all three models.
2. Calling step on the optimizer updates only the parameters registered in the optimizer you call it on. In your case, this means var_optimizer.step() will update all parameters of var_model_statemapper alone, using the gradients computed in step 1 (namely by the backward call on var_train_loss).
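If you want to double-check which parameters var_optimizer will actually touch, here is a small sanity check (just a sketch, reusing the variable names from your snippet; it inspects the optimizer's param_groups):

registered_params = {id(p) for group in var_optimizer.param_groups for p in group['params']}
# every parameter of the state mapper is registered in the optimizer ...
assert all(id(p) in registered_params for p in var_model_statemapper.parameters())
# ... while the parameters of the two static networks are not, so step() can never modify them
assert not any(id(p) in registered_params for p in var_model_panda.parameters())
assert not any(id(p) in registered_params for p in var_model_ur5.parameters())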
All in all, your current approach will only update the parameters of var_model_statemapper. Ideally, you should additionally freeze the models var_model_panda and var_model_ur5 by setting their parameters' requires_grad flag to False. This speeds up inference and training, since their gradients won't be computed and stored during backpropagation.
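As a minimal sketch of that freezing step (reusing the variable names from your question; the eval() calls are an extra suggestion to disable the dropout layers of the static networks, not something your snippet already does):

# freeze the two pretrained networks so no gradients are computed or stored for their parameters
for var_param in var_model_panda.parameters():
    var_param.requires_grad = False
for var_param in var_model_ur5.parameters():
    var_param.requires_grad = False

# optional: switch them to eval mode so their dropout layers are turned off
var_model_panda.eval()
var_model_ur5.eval()

Note that freezing var_model_ur5 does not stop the gradient from flowing back through it to the output of var_model_statemapper; it only prevents gradient tensors from being stored for var_model_ur5's own parameters.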
Answered By - Ivan