Issue
I am trying to train SRGAN from scratch. I have read solutions to similar problems, but I would appreciate help debugging my own code. The exact error is:

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad()

Here is the code I am trying to train:
gen_model = Generator().to(device, non_blocking=True)
disc_model = Discriminator().to(device, non_blocking=True)
opt_gen = optim.Adam(gen_model.parameters(), lr=0.01)
opt_disc = optim.Adam(disc_model.parameters(), lr=0.01)
from torch.nn.modules.loss import BCELoss

def train_model(gen, disc):
    for epoch in range(20):
        run_loss_disc = 0
        run_loss_gen = 0
        for data in train:
            low_res, high_res = data[0].to(device, non_blocking=True, dtype=torch.float).permute(0, 3, 1, 2), data[1].to(device, non_blocking=True, dtype=torch.float).permute(0, 3, 1, 2)
            #--------Discriminator-----------------
            gen_image = gen(low_res)
            gen_image = gen_image.detach()
            disc_gen = disc(gen_image)
            disc_real = disc(high_res)
            p = nn.BCEWithLogitsLoss()
            loss_gen = p(disc_real, torch.ones_like(disc_real))
            loss_real = p(disc_gen, torch.zeros_like(disc_gen))
            loss_disc = loss_gen + loss_real
            opt_disc.zero_grad()
            loss_disc.backward()
            run_loss_disc += loss_disc
            #---------Generator--------------------
            cont_loss = vgg_loss(high_res, gen_image)
            adv_loss = 1e-3*p(disc_gen, torch.ones_like(disc_gen))
            gen_loss = cont_loss+(10^-3)*adv_loss
            opt_gen.zero_grad()
            gen_loss.backward()
            opt_disc.step()
            opt_gen.step()
            run_loss_gen += gen_loss
        print("Run Loss Discriminator: %d", run_loss_disc)
        print("Run Loss Generator: %d", run_loss_gen)

train_model(gen_model, disc_model)
Solution
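The graph behind your disc_gen value was freed by the first backward() call, as the error message says: the generator loss reuses disc_gen after loss_disc.backward() has already released its saved tensors. In addition, because gen_image is overwritten with its detached copy, no gradient can reach the generator through it.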
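To see the error in isolation, here is a minimal sketch that does not use the SRGAN code at all. The tensors w and x and the names shared, loss_a and loss_b are made up for illustration; the point is only that two losses built on the same intermediate value cannot both call backward() without retaining the graph.

```python
import torch

w = torch.randn(3, requires_grad=True)
x = torch.randn(3)

# exp() saves its output for the backward pass, so this graph holds saved tensors.
shared = (w * x).exp().sum()

# Two losses that both depend on the same intermediate value.
loss_a = shared * 2
loss_b = shared * 3

loss_a.backward()  # frees the tensors saved in the graph behind `shared`

try:
    loss_b.backward()  # walks the same, already freed, part of the graph
    raised = False
except RuntimeError as err:
    raised = True
    print(err)
```

In the question's loop, disc_gen plays the role of shared: it feeds loss_disc first and adv_loss afterwards.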
It should work if you change the discriminator part a bit:
gen_image = gen(low_res)
disc_gen = disc(gen_image.detach())
and add this at the start of the generator part:
disc_gen = disc(gen_image)
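Put together, one training step could look roughly like the sketch below. It is an illustration under assumptions, not your actual setup: the nn.Linear modules stand in for Generator and Discriminator, nn.MSELoss stands in for vgg_loss, and the tensor shapes and learning rate are arbitrary. It also calls opt_disc.step() right after the discriminator's backward pass, which is my own choice rather than something the original loop does; that way the generator's backward pass cannot leak gradients into the discriminator update.

```python
import torch
from torch import nn, optim

torch.manual_seed(0)

# Hypothetical stand-ins for the real Generator / Discriminator / vgg_loss.
gen = nn.Linear(4, 4)
disc = nn.Linear(4, 1)
content_loss = nn.MSELoss()
bce = nn.BCEWithLogitsLoss()

opt_gen = optim.Adam(gen.parameters(), lr=1e-4)
opt_disc = optim.Adam(disc.parameters(), lr=1e-4)

low_res = torch.randn(8, 4)
high_res = torch.randn(8, 4)


def train_step(low_res, high_res):
    gen_image = gen(low_res)  # keep the attached version for the generator

    # -------- Discriminator --------
    # Detach only the discriminator's input, so this pass stops at gen_image.
    disc_fake = disc(gen_image.detach())
    disc_real = disc(high_res)
    loss_disc = (bce(disc_real, torch.ones_like(disc_real))
                 + bce(disc_fake, torch.zeros_like(disc_fake)))
    opt_disc.zero_grad()
    loss_disc.backward()
    opt_disc.step()

    # -------- Generator --------
    # Fresh forward pass through the discriminator builds a new graph.
    disc_gen = disc(gen_image)
    adv_loss = bce(disc_gen, torch.ones_like(disc_gen))
    gen_loss = content_loss(gen_image, high_res) + 1e-3 * adv_loss
    opt_gen.zero_grad()
    gen_loss.backward()
    opt_gen.step()

    # .item() returns plain floats, so running totals do not keep graphs alive.
    return loss_disc.item(), gen_loss.item()


d_loss, g_loss = train_step(low_res, high_res)
print(d_loss, g_loss)
```

Two smaller things in the original loop are worth checking as well: in Python, 10^-3 is a bitwise XOR rather than a power (1e-3 is the float literal), and adv_loss is already scaled by 1e-3 before being scaled again in gen_loss.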
Answered By - dx2-66