Jacobian diagonal computation in JAX
I was wondering how we could use JAX (https://github.com/google/jax) to compute a mapping of the derivative.
That is to say:
we have a vector x = (x_1, ..., x_n) and we want to apply (with the JAX framework) a function f to it,
where f : R^n -> R^n.
My question is: how can we easily retrieve the vector (∂f_1/∂x_1, ∂f_2/∂x_2, ..., ∂f_n/∂x_n)?
For example:
from jax import random
from jax import jacfwd, jacrev
import jax.numpy as jnp
key = random.PRNGKey(0)
key, W_key, b_key, input_key = random.split(key, 4)
W = random.normal(W_key, (10, 10))
b = random.normal(b_key, (10, ))
input = random.normal(input_key, (10, ))
One easy way to do that would be to take the diagonal of the Jacobian, but this method is very slow for high-dimensional vectors (> 10000), since the whole matrix gets computed even though I am only interested in the diagonal of the Jacobian...
def f(input):
    return jnp.dot(W, input) + b

# Build the full Jacobian with forward-mode autodiff, then keep only its diagonal.
J = jacfwd(f, argnums=0)(input)
result = jnp.diagonal(J)
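The cost is easy to see from the shape of the intermediate result (the ~400 MB figure below assumes float32 and n = 10000):

# J has shape (n, n): for n = 10000 that is 10**8 entries (~400 MB in float32),
# even though only the n diagonal entries are needed.
print(J.shape)  # (10, 10)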
As a reminder, the Jacobian matrix is the n × n matrix J with entries J_ij = ∂f_i/∂x_j, and the vector I am after is its diagonal (J_11, J_22, ..., J_nn).
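For this particular f the result is easy to check: f(x) = Wx + b is linear, so its Jacobian is the constant matrix W (for a general non-linear f no such closed form exists):

# The Jacobian of the linear map x -> W @ x + b is W itself,
# so the diagonal computed above must equal jnp.diag(W).
print(jnp.allclose(result, jnp.diag(W)))  # True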
Solution 1:[1]
There's not really a natural way to do this with JAX's transforms: you cannot simply map over the input, because in general each diagonal entry ∂f_i/∂x_i of the Jacobian depends on all of the inputs.
But given your particular function, you could compute the diagonal of the Jacobian directly by rewriting the function like this:
from jax import vmap, grad
def f_single(val, i, W=W, b=b, input=input):
    # The i-th output of f, viewed as a scalar function of the i-th input alone:
    # substitute val for input[i], then compute only row i of f.
    return jnp.dot(W[i], input.at[i].set(val)) + b[i]
idx = jnp.arange(len(input))
# equivalent to f(input)
print(vmap(f_single)(input, idx))
# [-1.5965443 -1.4081277 1.866176 -0.9789318 2.6717818 -1.0995009
# -2.3647223 3.6962256 3.3946664 2.589026 ]
# equivalent to jnp.diagonal(jacrev(f)(input))
print(vmap(grad(f_single))(input, idx))
# [-0.87553114 0.543098 2.265052 0.1403018 -1.4744948 1.4401387
# 0.4466088 0.72063404 -0.9135868 0.34965768]
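To reuse this trick, the vmap-of-grad step can be wrapped in a small helper (a minimal sketch; the name jac_diagonal and the (val, i) single-index signature are conventions assumed here, not a JAX API):

def jac_diagonal(f_single, x):
    # Map the scalar derivative of the single-index form over all indices:
    # entry i is ∂f_i/∂x_i, and the full Jacobian is never materialized.
    idx = jnp.arange(len(x))
    return vmap(grad(f_single))(x, idx)

# Matches the diagonal of the full Jacobian on the small example above:
print(jnp.allclose(jac_diagonal(f_single, input), jnp.diagonal(jacfwd(f)(input))))  # True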
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | jakevdp |
