Jacobian diagonal computation in JAX

I was wondering how we could use JAX (https://github.com/google/jax) to compute the element-wise derivatives of a vector-to-vector function.

That is to say: we have a vector x ∈ ℝⁿ and we want to apply (with the JAX framework) a function to it; we call it f and it's a function from ℝⁿ to ℝⁿ.

My question is: how can we easily retrieve the vector of element-wise derivatives (∂f₁/∂x₁, …, ∂fₙ/∂xₙ)?

For example:

from jax import random
from jax import jacfwd, jacrev
import jax.numpy as jnp

key = random.PRNGKey(0)

key, W_key, b_key, input_key = random.split(key, 4)
W = random.normal(W_key, (10, 10))
b = random.normal(b_key, (10, ))
input = random.normal(input_key, (10, ))

One easy way to do that would be to take the diagonal of the Jacobian, as in the code below, but this method is very slow for high-dimensional vectors (n > 10000), since it computes and stores the full n × n matrix. I am only interested in the diagonal of the Jacobian...

def f(input):
  return jnp.dot(W, input) + b

J = jacfwd(f, argnums=0)(input)  # materializes the full n-by-n Jacobian
result = jnp.diagonal(J)         # then keeps only its n diagonal entries
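
Since f here maps ℝⁿ to ℝⁿ, reverse mode builds the same matrix row by row, so the jacrev import above works interchangeably (a side note of mine, not part of the original question):

J_rev = jacrev(f, argnums=0)(input)  # same n-by-n matrix as jacfwd, built row by row
assert jnp.allclose(J, J_rev)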

As a reminder, the Jacobian matrix of f is J[i, j] = ∂fᵢ/∂xⱼ, and what I need is only its diagonal, J[i, i] = ∂fᵢ/∂xᵢ.

Solution 1:[1]

There's not really a natural way to do this with JAX's transforms: you cannot simply map over the inputs, because in general each diagonal entry of the Jacobian depends on all of the inputs.
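
To see why, note that each diagonal entry J[i, i] needs its own Jacobian-vector product against a basis vector, so extracting the whole diagonal costs as much as building the full Jacobian. Here is a brute-force sketch of that (my own illustration, not from the original answer; jac_diagonal is a made-up name):

from jax import jvp, vmap

def jac_diagonal(f, x):
  def diag_entry(i):
    e_i = jnp.zeros_like(x).at[i].set(1.0)  # i-th standard basis vector
    _, tangent = jvp(f, (x,), (e_i,))       # tangent = J @ e_i, the i-th column of J
    return tangent[i]                       # keep only the diagonal entry
  return vmap(diag_entry)(jnp.arange(x.shape[0]))

print(jac_diagonal(f, input))  # matches jnp.diagonal(jacfwd(f)(input))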

But given your particular function, you could compute the diagonal of the Jacobian directly by rewriting the function like this:

from jax import vmap, grad

def f_single(val, i, W=W, b=b, input=input):
  # i-th output of f as a function of the i-th input only;
  # W, b and the remaining inputs are captured as defaults.
  return jnp.dot(W[i], input.at[i].set(val)) + b[i]

idx = jnp.arange(len(input))

# equivalent to f(input)
print(vmap(f_single)(input, idx))
# [-1.5965443 -1.4081277  1.866176  -0.9789318  2.6717818 -1.0995009
#  -2.3647223  3.6962256  3.3946664  2.589026 ]

# equivalent to jnp.diagonal(jacrev(f)(input))
print(vmap(grad(f_single))(input, idx))
# [-0.87553114  0.543098    2.265052    0.1403018  -1.4744948   1.4401387
#   0.4466088   0.72063404 -0.9135868   0.34965768]
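
Since the example f is linear, its Jacobian is exactly W, so the vmap(grad(...)) result can be sanity-checked directly (my addition, not part of the original answer):

# For f(x) = W @ x + b the Jacobian is W, so its diagonal is jnp.diagonal(W)
assert jnp.allclose(vmap(grad(f_single))(input, idx), jnp.diagonal(W))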

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

[1] Solution 1: jakevdp