vmap ops.index_update in JAX

I have the code below, which uses a simple for loop, and I was wondering whether there is a way to vmap it. Here is the original code:

import numpy as np 
import jax.numpy as jnp
import jax.scipy.signal as jscp
from scipy import signal
import jax

data = np.random.rand(192,334)

a = [1,-1.086740193996892,0.649914553946275,-0.124948974636730]
b = [0.054778173164082,0.164334519492245,0.164334519492245,0.054778173164082]
impulse = signal.lfilter(b, a, [1] + [0]*99) 
impulse_20 = impulse[:20]
impulse_20 = jnp.asarray(impulse_20)

@jax.jit
def filter_jax(y):
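    # y has shape (192, 334): loop over the 334 columns (y.shape[1]),
    # not len(y), which is the number of rows (192)
    for ind in range(y.shape[1]):
        # 'full' convolution gives 192 + 20 - 1 samples; [:-19] keeps the first 192
        filtered = jscp.convolve(impulse_20, y[:, ind])[:-19]
        y = jax.ops.index_update(y, jax.ops.index[:, ind], filtered)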
    return y

jnpData = jnp.asarray(data)

# %timeit is an IPython/Jupyter magic; run this line in a notebook or IPython shell
%timeit filter_jax(jnpData).block_until_ready()
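For readers on a current JAX release: `jax.ops.index_update` has since been removed from JAX, and the same per-column loop is written with the `.at[...].set(...)` indexed-update syntax. The sketch below assumes the same filter coefficients and data shape as the question; `filter_jax_at` is a hypothetical name.

```python
import numpy as np
import jax
import jax.numpy as jnp
import jax.scipy.signal as jscp
from scipy import signal

# Filter coefficients copied from the question
a = [1, -1.086740193996892, 0.649914553946275, -0.124948974636730]
b = [0.054778173164082, 0.164334519492245, 0.164334519492245, 0.054778173164082]
# First 20 samples of the IIR filter's impulse response
impulse_20 = jnp.asarray(signal.lfilter(b, a, [1] + [0] * 99)[:20])


@jax.jit
def filter_jax_at(y):
    # Loop over columns; each column is filtered independently
    for ind in range(y.shape[1]):
        filtered = jscp.convolve(impulse_20, y[:, ind])[:-19]
        # Functional update: returns a new array with column `ind` replaced
        y = y.at[:, ind].set(filtered)
    return y


data = np.random.rand(192, 334)
out = filter_jax_at(jnp.asarray(data))
```

Because the Python loop runs under `jax.jit`, it is unrolled into 334 separate updates at trace time, which makes compilation slow even if each later call is fast.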

And here is my attempt at using vmap:

def paraUpdate(y, ind):
    return jax.ops.index_update(y, jax.ops.index[:, ind], jscp.convolve(impulse_20, y[:,ind])[:-19])

@jax.jit
def filter_jax2(y):
  ranger = range(0, len(y))
  return jax.vmap(paraUpdate, y)(ranger)

But I receive the following error:

TypeError: vmap in_axes must be an int, None, or (nested) container with those types as leaves, but got Traced<ShapedArray(float32[192,334])>with<DynamicJaxprTrace(level=0/1)>.

I'm a little confused, since the range only contains ints, so I'm not sure what's going on.
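The error comes from the call itself rather than from `range`: the second positional parameter of `jax.vmap` is `in_axes`, so `jax.vmap(paraUpdate, y)` passes the traced array `y` as `in_axes`, which is exactly what the message complains about. Since every column is filtered independently, one option (a sketch, not necessarily the fastest) is to vmap a function of a single column over axis 1 and let `out_axes=1` stack the results back into columns, so no indexed update is needed at all. `filter_column` and `filter_jax_vmap` are hypothetical names, and the coefficients are copied from the question.

```python
import numpy as np
import jax
import jax.numpy as jnp
import jax.scipy.signal as jscp
from scipy import signal

a = [1, -1.086740193996892, 0.649914553946275, -0.124948974636730]
b = [0.054778173164082, 0.164334519492245, 0.164334519492245, 0.054778173164082]
impulse_20 = jnp.asarray(signal.lfilter(b, a, [1] + [0] * 99)[:20])


def filter_column(col):
    # Filter one column; drop the last 19 samples so the output
    # has the same length as the input column
    return jscp.convolve(impulse_20, col)[:-19]


# in_axes=1: map over the columns of the input
# out_axes=1: place each filtered column back as a column of the output
filter_jax_vmap = jax.jit(jax.vmap(filter_column, in_axes=1, out_axes=1))

data = np.random.rand(192, 334)
out = filter_jax_vmap(jnp.asarray(data))
```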

Ultimately, I want to optimize this small piece of code so it runs as fast as possible.
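For speed, another option (an alternative technique, not something from the question) is to skip the per-column work entirely: reshaping the 20-tap kernel to shape `(20, 1)` makes a single 2-D `jax.scipy.signal.convolve` call filter every column down axis 0 at once. Slicing off the last 19 rows gives the same result as the loop. The sketch assumes the question's coefficients; `filter_jax_2d` is a hypothetical name.

```python
import numpy as np
import jax
import jax.numpy as jnp
import jax.scipy.signal as jscp
from scipy import signal

a = [1, -1.086740193996892, 0.649914553946275, -0.124948974636730]
b = [0.054778173164082, 0.164334519492245, 0.164334519492245, 0.054778173164082]
impulse_20 = jnp.asarray(signal.lfilter(b, a, [1] + [0] * 99)[:20])


@jax.jit
def filter_jax_2d(y):
    # A (20, 1) kernel only extends along axis 0, so the full 2-D
    # convolution filters each column independently in one call.
    full = jscp.convolve(impulse_20[:, None], y)  # shape (len(y) + 19, y.shape[1])
    return full[:-19]  # keep the first len(y) samples, as in the loop


data = np.random.rand(192, 334)
out = filter_jax_2d(jnp.asarray(data))
```

Whether this beats the vmap version depends on the backend, so it is worth timing both with the same `%timeit ... .block_until_ready()` pattern as in the question.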



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
