Issue
Consider the following code:
import torch
from torch import autograd

x = torch.tensor(2.0, requires_grad=True)
y = torch.square(x)
grad = autograd.grad(y, x)     # grad[0] = 4.0
x = x + grad[0]                # x is overwritten; it is now 6.0
y = torch.square(x)
grad2 = autograd.grad(y, x)    # grad2[0] = 12.0
First, we have that ∇(x^2) = 2x. In my understanding, grad2 = ∇((x + ∇(x^2))^2) = ∇((x + 2x)^2) = ∇((3x)^2) = 9∇(x^2) = 18x. As expected, grad = 4.0 = 2x, but grad2 = 12.0 = 6x, and I don't understand where that comes from. It feels as though the 3 comes from the expression I wrote, but it is not squared, and the 2 comes from the ordinary derivative. Could somebody help me understand why this is happening? Furthermore, how far back does the computational graph that stores the gradients go?
Specifically, I am coming from a meta-learning perspective, where one is interested in computing a quantity of the form ∇L(θ - α∇L(θ)) = (I - α∇^2 L(θ)) ∇L(θ - α∇L(θ)) (here the derivative is taken with respect to θ). This computation, let's call it A, therefore includes a second derivative. It is quite different from ∇_{θ - α∇L(θ)} L(θ - α∇L(θ)) = ∇_β L(β), where β = θ - α∇L(θ), which I will call B. A sketch of both is below.
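To make A and B concrete, here is how I would write the two computations for a toy scalar loss (the quartic L and the value of alpha are placeholders I made up; I am not claiming this is how autograd decides anything, it is just my mental model of the two options):

import torch
from torch import autograd

def L(theta):
    return theta ** 4  # toy loss with a non-trivial second derivative

theta = torch.tensor(1.5, requires_grad=True)
alpha = 0.1

# Computation A: differentiate through the inner gradient step, so the
# second derivative of L shows up (create_graph keeps the inner gradient
# attached to the graph).
inner_grad = autograd.grad(L(theta), theta, create_graph=True)[0]
beta = theta - alpha * inner_grad
grad_A = autograd.grad(L(beta), theta)[0]

# Computation B: treat the inner gradient as a constant, so we only get
# the plain gradient of L evaluated at beta.
inner_grad_const = autograd.grad(L(theta), theta)[0]   # no create_graph
beta_const = theta - alpha * inner_grad_const
grad_B = autograd.grad(L(beta_const), beta_const)[0]

print(grad_A, grad_B)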
Hopefully it is clear how my snippet relates to what I described in the second paragraph. My overall question is: under what circumstances does PyTorch perform computation A versus computation B when using autograd.grad? I'd appreciate any explanation that goes into technical detail about how this particular case is handled by autograd.
PS. The original code that made me wonder about this is here; in particular, lines 69 through 106, and then line 193, where they use autograd.grad. That code is even more unclear to me because they do a lot of model.clone() and so on.
If the question is unclear in any way, please let me know.
Solution
I made a few changes:
- I am not sure what torch.rand(2.0) is supposed to do. According to the text, I simply set it to 2.
- An intermediate variable z is added so that we can compute the gradient with respect to the original variable; yours is overwritten.
- Set create_graph=True to compute higher-order gradients. See https://pytorch.org/docs/stable/generated/torch.autograd.grad.html
import torch
from torch import autograd

x = torch.ones(1, requires_grad=True) * 2       # x = 2
y = torch.square(x)
grad = autograd.grad(y, x, create_graph=True)   # dy/dx = 2x = 4, kept in the graph
z = x + grad[0]                                 # z = x + 2x = 3x = 6
y = torch.square(z)                             # y = (3x)^2 = 9x^2
grad2 = autograd.grad(y, x)                     # dy/dx = 18x = 36
# yours is more like autograd.grad(y, z)
print(x)
print(grad)
print(grad2)
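With these changes grad2 comes out to 36 = 18x at x = 2, which matches the hand derivation in the question. For contrast, here is a minimal version of what your original snippet effectively computes: without create_graph the first gradient comes back detached, and the second call differentiates with respect to z itself, which only gives 2z.

import torch
from torch import autograd

x = torch.ones(1, requires_grad=True) * 2
y = torch.square(x)
grad = autograd.grad(y, x)        # default create_graph=False
print(grad[0].requires_grad)      # False: this gradient is a detached constant

z = x + grad[0]                   # z = 6; in your snippet this reassigns x
y = torch.square(z)
print(autograd.grad(y, z)[0])     # tensor([12.]) = 2 * z, the value you observed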
Answered By - hkchengrex