Issue
In a simple test in PyTorch, I want to see the grad of a non-leaf tensor, so I use retain_grad():
import torch
a = torch.tensor([1.], requires_grad=True)
y = torch.zeros((10))
gt = torch.zeros((10))
y[0] = a
y[1] = y[0] * 2
y.retain_grad()
loss = torch.sum((y-gt) ** 2)
loss.backward()
print(y.grad)
It gives me the expected output:
tensor([2., 4., 0., 0., 0., 0., 0., 0., 0., 0.])
But when I call retain_grad() after y[0] is assigned and before y[1] is computed:
import torch
a = torch.tensor([1.], requires_grad=True)
y = torch.zeros((10))
gt = torch.zeros((10))
y[0] = a
y.retain_grad()
y[1] = y[0] * 2
loss = torch.sum((y-gt) ** 2)
loss.backward()
print(y.grad)
Now the output changes to:
tensor([10., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
I can't understand the result at all.
Solution
Okay, so what's going on here is really subtle.
What .retain_grad() essentially does is make a non-leaf tensor keep its gradient: after backward(), its .grad attribute gets populated just like a leaf tensor's would (by default, PyTorch populates .grad for leaf tensors only). Note that the tensor does not actually become a leaf; is_leaf stays False, it merely gains an accessible .grad.
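To make this concrete, here is a minimal sketch (the names x and z are mine, not from the question) showing that retain_grad() leaves is_leaf untouched and only makes .grad accessible:

import torch

x = torch.tensor([1.], requires_grad=True)   # leaf tensor
z = x * 2                                    # non-leaf tensor (has a grad_fn)
z.retain_grad()                              # ask autograd to keep z's gradient
z.sum().backward()

print(z.is_leaf)  # False: z is still a non-leaf tensor
print(z.grad)     # tensor([1.]): but .grad is populated anyway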
Hence, in your first example, calling y.retain_grad() after all the in-place writes simply gave the final version of y an accessible .grad attribute.
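As a quick sanity check (re-running the question's first snippet), y.grad then matches the analytic gradient of the loss, d(loss)/dy = 2 * (y - gt):

import torch

a = torch.tensor([1.], requires_grad=True)
y = torch.zeros(10)
gt = torch.zeros(10)
y[0] = a
y[1] = y[0] * 2
y.retain_grad()                    # called on the FINAL version of y
loss = torch.sum((y - gt) ** 2)
loss.backward()

print(torch.allclose(y.grad, 2 * (y - gt)))  # True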
However, in your second example, you called y.retain_grad() first and then ran another in-place operation (y[1] = y[0] * 2) on top of it, which is what caused the confusion: the in-place write creates a newer version of y in the autograd graph, while the retained gradient still refers to the older version.
y = torch.zeros((10)) # y is a leaf tensor (requires_grad=False, no history)
y[0] = a              # in-place copy; y now requires grad and becomes a non-leaf tensor
y.retain_grad()       # .grad will be retained for THIS (intermediate) version of y
y[1] = y[0] * 2       # another in-place op; autograd now tracks a NEWER version of y,
                      # but the retained gradient still refers to the version above
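One way to see that "newer version" directly is to compare y.grad_fn before and after the second in-place write; this peeks at autograd internals the question never touches, so take it purely as an illustration:

import torch

a = torch.tensor([1.], requires_grad=True)
y = torch.zeros(10)
y[0] = a
fn_before = y.grad_fn           # autograd node for the current version of y
y.retain_grad()                 # the retained gradient attaches to this version
y[1] = y[0] * 2
print(y.grad_fn is fn_before)   # False: the in-place op created a new node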
The confusing part is: after calling y.retain_grad(), the gradient that gets retained belongs to y as it was at that moment, i.e. right after y[0] = a. For that intermediate version, index 0 feeds the loss along two paths (directly, and through y[1] = y[0] * 2), while index 1 is overwritten by the later in-place assignment, so the loss never sees its original value at all.
Therefore, when calling loss.backward(), the gradient reported in y.grad is the gradient of the loss w.r.t. that retained intermediate version of y. In particular, the chain rule for its first entry picks up both paths:

∂loss/∂y[0] = 2*(y[0] - gt[0]) + ∂loss/∂y[1] * ∂y[1]/∂y[0]
            = 2*y[0] + (2*y[1]) * 2
            = 2*1 + (2*2) * 2
            = 10

while the retained y[1] never reaches the loss (it is overwritten first), so its gradient is 0. That is exactly the tensor([10., 0., 0., ...]) you observed.
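If it helps, the same result can be reproduced with no in-place operations at all by building the two versions of y as separate tensors (y_mid and y_fin are my names for the retained and final versions):

import torch

a = torch.tensor([1.], requires_grad=True)
gt = torch.zeros(10)

# y_mid plays the role of y right after `y[0] = a` (the version retain_grad saw)
y_mid = torch.cat([a, torch.zeros(9)])
y_mid.retain_grad()

# y_fin plays the role of y after `y[1] = y[0] * 2`
y_fin = torch.cat([y_mid[0:1], y_mid[0:1] * 2, y_mid[2:]])

loss = torch.sum((y_fin - gt) ** 2)
loss.backward()
print(y_mid.grad)  # tensor([10., 0., 0., 0., 0., 0., 0., 0., 0., 0.])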
Answered By - Omar AlSuwaidi