Issue
I am using JAX to take the gradient of a function of a matrix. For example, I have a function f(A), where A is a matrix such as A = [[a, b], [c, d]]. I want the gradient of f(A) with respect to only a, c, and d (more specifically, the lower-triangular part). How can I do that for a general NxN matrix, not just the 2x2 case?
I tried converting the regular gradient into a lower-triangular matrix, but I am not sure whether that is the same thing or whether the output is correct.
Solution
JAX does not offer any way to take the gradient with respect to individual matrix elements. There are two ways you could proceed; first, you could take the gradient with respect to the entire array and extract the elements you're interested in; for example:
import jax
import jax.numpy as jnp

def f(A):
    return (A ** 2).sum()

A = jnp.array([[1.0, 2.0], [3.0, 4.0]])
df_dA = jax.grad(f)(A)  # gradient with respect to every entry of A
print(df_dA[0, 0], df_dA[0, 1], df_dA[1, 1])
2.0 4.0 8.0
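For the lower-triangular, NxN case asked about in the question, the same idea might be sketched as follows. This assumes the same toy f as above and an arbitrary 3x3 test matrix, and it treats every entry of A as an independent variable; jnp.tril and jnp.tril_indices are used here only to select entries after the full gradient has been computed.

```python
import jax
import jax.numpy as jnp

def f(A):
    # Same toy function as above: sum of squared entries.
    return (A ** 2).sum()

N = 3  # arbitrary size chosen for illustration
A = jnp.arange(1.0, N * N + 1).reshape(N, N)

# Full N x N gradient with respect to every entry of A.
df_dA = jax.grad(f)(A)

# Option 1: keep the matrix shape and zero out everything above the diagonal.
lower_grad = jnp.tril(df_dA)

# Option 2: collect the lower-triangular entries into a flat vector,
# ordered row by row: (0,0), (1,0), (1,1), (2,0), ...
rows, cols = jnp.tril_indices(N)
lower_entries = df_dA[rows, cols]

print(lower_grad)
print(lower_entries)
```

Both results hold the same partial derivatives; the first keeps the matrix layout, while the second has length N(N+1)/2.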
Alternatively, you could split the entries of the array into individual function arguments, and then use argnums to take the gradient with respect to just the ones you're interested in:
def f(a, b, c, d):
    A = jnp.array([[a, b], [c, d]])
    return (A ** 2).sum()

# argnums selects which positional arguments to differentiate: here a, b, and d
df_da, df_db, df_dd = jax.grad(f, argnums=(0, 1, 3))(1.0, 2.0, 3.0, 4.0)
print(df_da, df_db, df_dd)
2.0 4.0 8.0
In general you'll probably find the first approach to be both easier to use in practice, and also more efficient. It does have some wasted computation, but sticking with vectorized computations will generally be a net win, especially if you're running on accelerators like GPU or TPU.
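If you do want a function whose only differentiable input is the lower-triangular part, one possible sketch is to pack those entries into a vector and rebuild the matrix inside the function. The helper name lower_triangular_grad below is hypothetical (it is not part of JAX), and the entries above the diagonal are assumed to stay fixed at their values in A. Since f is still evaluated on the full matrix, this should not be expected to be cheaper than slicing the full gradient.

```python
import jax
import jax.numpy as jnp

def f(A):
    # Toy function from the examples above.
    return (A ** 2).sum()

def lower_triangular_grad(f, A):
    """Hypothetical helper: gradient of f with respect to only the
    lower-triangular entries of A, returned as a flat vector."""
    rows, cols = jnp.tril_indices(A.shape[0])

    def f_of_lower(v):
        # Rebuild the matrix: lower triangle comes from v,
        # entries above the diagonal are held fixed at A's values.
        return f(A.at[rows, cols].set(v))

    return jax.grad(f_of_lower)(A[rows, cols])

A = jnp.array([[1.0, 2.0], [3.0, 4.0]])
g = lower_triangular_grad(f, A)
print(g)  # ordered as entries (0,0), (1,0), (1,1), i.e. a, c, d
```

For this f, the result agrees with taking jax.grad(f)(A) and reading off the lower-triangular entries, which is consistent with the first approach giving the same partial derivatives.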
Answered By - jakevdp