Issue
I have 2 networks that I'm trying to update:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal
import matplotlib.pyplot as plt
from tqdm import tqdm

softplus = torch.nn.Softplus()

class Model_RL(nn.Module):
    def __init__(self):
        super(Model_RL, self).__init__()
        self.fc1 = nn.Linear(3, 20)
        self.fc2 = nn.Linear(20, 30)
        self.fc3 = nn.Linear(30, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = softplus(self.fc3(x))
        return x

class Model_FA(nn.Module):
    def __init__(self):
        super(Model_FA, self).__init__()
        self.fc1 = nn.Linear(1, 20)
        self.fc2 = nn.Linear(20, 30)
        self.fc3 = nn.Linear(30, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = softplus(self.fc3(x))
        return x

net_RL = Model_RL()
net_FA = Model_FA()
The training loop is
inps = torch.tensor([[1.0]])
y = torch.tensor(10.0)

opt_RL = optim.Adam(net_RL.parameters())
opt_FA = optim.Adam(net_FA.parameters())

baseline = 0
baseline_lr = 0.1
epochs = 100

for _ in tqdm(range(epochs)):
    for inp in inps:
        with torch.no_grad():
            net_FA(inp)
        for layer in range(3):
            out_RL = net_RL(torch.tensor([1.0, 2.0, 3.0]))
            mu, std = out_RL
            dist = Normal(mu, std)
            update_values = dist.sample()
            log_p = dist.log_prob(update_values).mean()

            out = net_FA(inp)
            reward = -torch.square((y - out))
            baseline = (1 - baseline_lr) * baseline + baseline_lr * reward
            loss_RL = - (reward - baseline) * log_p

            opt_RL.zero_grad()
            opt_FA.zero_grad()
            loss_RL.backward()
            opt_RL.step()

            out = net_FA(inp)
            loss_FA = torch.mean(torch.square(y - out))

            opt_RL.zero_grad()
            opt_FA.zero_grad()
            loss_FA.backward()
            opt_FA.step()

print("Mean: " + str(mu.detach().numpy()) + ", Goal: " + str(y))
print("Standard deviation: " + str(softplus(std).detach().numpy()) + ", Goal: 0ish")
I'm getting 2 main errors:
RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling .backward()...
And when I add retain_graph=True to both backward calls I get the following:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [30, 1]], which is output 0 of TBackward, is at version 5; expected version 4 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True)
My main question is how can I make this training work?
But intermediate questions are:
- Why is retain_graph=True needed here if I'm using a loop? From here: "there is no need to use retain_graph=True. In each loop, a new graph is created"
- Why does it seem as if retain_graph=True makes training significantly slower (if I remove the other backward call)? This doesn't really make sense to me, as in each epoch a new computational graph should be created (and not just one that is being extended).
Solution
I think the line baseline = (1 - baseline_lr) * baseline + baseline_lr * reward is causing the error, because:
- The previous state of baseline is used to get the new state of baseline.
- PyTorch will track all of these states inside a graph.
- backward will flush the graph.
- The variable baseline at time t + 1 will try to backpropagate through baseline at time t.
- But at time t + 1 the graph behind baseline at time t no longer exists.
- This leads to the error.
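To make this concrete, here is a minimal sketch (made-up names, not from the original post) of the same failure mode: a running value that keeps its autograd history ties every iteration's graph to all previous ones, so the second backward walks back into a graph that has already been freed.

import torch

w = torch.tensor(2.0, requires_grad=True)
running = torch.tensor(0.0)  # plays the role of baseline

for step in range(2):
    out = w * w                            # fresh graph each iteration
    running = 0.9 * running + 0.1 * out    # running now references this graph AND all earlier ones
    loss = out - running
    loss.backward()
    # 2nd iteration -> RuntimeError: Trying to backward through the graph a second time ...
    # Using running.detach() in the update line above avoids the error, because each
    # backward then only sees the current iteration's graph.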
Solution:
As you are not optimizing the variable baseline or anything behind baseline:
- Initialize baseline as a torch tensor.
- Detach it from the graph before updating its state.
Try this:
# initialize baseline as a torch tensor
baseline = torch.tensor(0.)
baseline_lr = 0.1
epochs = 100

for _ in tqdm(range(epochs)):
    for inp in inps:
        with torch.no_grad():
            net_FA(inp)
        for layer in range(3):
            out_RL = net_RL(torch.tensor([1.0, 2.0, 3.0]))
            mu, std = out_RL
            dist = Normal(mu, std)
            update_values = dist.sample()
            log_p = dist.log_prob(update_values).mean()

            out = net_FA(inp)
            reward = -torch.square((y - out))
            # detach baseline from the graph before updating it
            baseline = (1 - baseline_lr) * baseline.detach() + baseline_lr * reward
            loss_RL = - (reward - baseline) * log_p

            opt_RL.zero_grad()
            opt_FA.zero_grad()
            loss_RL.backward()
            opt_RL.step()

            out = net_FA(inp)
            loss_FA = torch.mean(torch.square(y - out))

            opt_RL.zero_grad()
            opt_FA.zero_grad()
            loss_FA.backward()
            opt_FA.step()
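As a side note that is not part of the answer above: an alternative way to keep the baseline out of the graph entirely is to store it as a plain Python float and update it with reward.item() (or reward.detach()). A minimal self-contained sketch with made-up names:

import torch

w = torch.tensor(1.0, requires_grad=True)
opt = torch.optim.Adam([w])

baseline = 0.0        # plain float, never enters the autograd graph
baseline_lr = 0.1

for _ in range(5):
    reward = -(10.0 - w) ** 2                                         # differentiable reward
    baseline = (1 - baseline_lr) * baseline + baseline_lr * reward.item()
    loss = -(reward - baseline)                                       # baseline acts as a constant
    opt.zero_grad()
    loss.backward()                                                   # no old graphs are ever revisited
    opt.step()

Whether the baseline is detached (as in the answer) or kept as a float makes no difference to the gradients reaching net_RL, since log_p is the only term in loss_RL that depends on its parameters; detaching only changes what the backward pass has to traverse.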
But actually I don't know why you are updating the networks 3 times for the same input.
Answered By - Girish Dattatray Hegde