Issue
I have two variables, x and theta. I am trying to minimise my loss with respect to theta only, but as part of my loss function I need the derivative of a different function (f) with respect to x. This derivative itself is not relevant to the minimisation, only its output. However, when implementing this in PyTorch I get a RuntimeError.
A minimal example is as follows:
# minimal example of two different autograds
import torch
from torch.autograd.functional import jacobian
def f(theta, x):
    return torch.sum(theta * x ** 2)

def df(theta, x):
    J = jacobian(lambda x: f(theta, x), x)
    return J
# example evaluations of the autograd gradient
x = torch.tensor([1., 2.])
theta = torch.tensor([1., 1.], requires_grad=True)

# derivative should be 2*theta*x (same as the analytical result)
with torch.no_grad():
    print(df(theta, x))
    print(2 * theta * x)

which prints

tensor([2., 4.])
tensor([2., 4.])
# define some arbitrary loss as a fn of theta
loss = torch.sum(df(theta, x)**2)
loss.backward()
This gives the following error:
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
If I provide an analytic derivative (2*theta*x), it works fine:
loss = torch.sum((2*theta*x)**2)
loss.backward()
Is there a way to do this in PyTorch? Or am I limited in some way?
Let me know if anyone needs any more details.
PS
I am imagining the solution is something similar to the way that JAX does autograd, as that is what I am more familiar with. What I mean here is that in JAX I believe you would just do:
from jax import grad
df = grad(lambda x: f(theta, x))
and then df would just be a function that can be called at any point. But is PyTorch the same? Or is there some conflict within .backward() that causes this error?
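For reference, a minimal sketch of this full pattern in JAX, reusing the same f as above (the printed gradient, 8*theta*x**2, follows from the analytical derivative):

import jax
import jax.numpy as jnp

def f(theta, x):
    return jnp.sum(theta * x ** 2)

# gradient of f with respect to its second argument (x);
# df remains differentiable with respect to theta by construction
df = jax.grad(f, argnums=1)

theta = jnp.array([1., 1.])
x = jnp.array([1., 2.])

def loss(theta):
    return jnp.sum(df(theta, x) ** 2)

print(jax.grad(loss)(theta))  # [ 8. 32.]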
Solution
PyTorch's jacobian does not create a computation graph unless you explicitly ask for it with the create_graph argument:

J = jacobian(lambda x: f(theta, x), x, create_graph=True)
The documentation is quite clear about it:
create_graph (bool, optional) – If True, the Jacobian will be computed in a differentiable manner
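For completeness, a minimal sketch of the question's example with this fix applied (same f, x, and theta as above; the expected gradient 8*theta*x**2 is worked out analytically):

import torch
from torch.autograd.functional import jacobian

def f(theta, x):
    return torch.sum(theta * x ** 2)

def df(theta, x):
    # create_graph=True keeps the Jacobian attached to the autograd
    # graph, so it stays differentiable with respect to theta
    return jacobian(lambda x: f(theta, x), x, create_graph=True)

x = torch.tensor([1., 2.])
theta = torch.tensor([1., 1.], requires_grad=True)

loss = torch.sum(df(theta, x) ** 2)
loss.backward()  # no longer raises
print(theta.grad)  # tensor([ 8., 32.]), i.e. 8 * theta * x**2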
Answered By - ayandas