Handle varying shapes in JAX NumPy arrays (jit-compatible)
Important note: I need everything to be jit-compatible here, otherwise my problem is trivial :)
I have a jax numpy array such as:
import jax.numpy as jnp
a = jnp.array([1, 5, 3, 4, 5, 6, 7, 2, 9])
First, I filter it against a threshold. Let's assume I only keep values that are < 5:
a = jnp.where(a < 5, a, jnp.nan)  # x and y are positional-only in recent JAX versions
# a is now [ 1. nan 3. 4. nan nan nan 2. nan]
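For context, the NaN-masking step on its own already traces under `jax.jit`, even when the threshold is a regular (traced) argument, because its output always has the same shape as `a`. A minimal sketch, where `mask_values` is a hypothetical helper name:

```python
import jax
import jax.numpy as jnp

# Hypothetical helper; `threshold` is an ordinary traced argument, not static.
@jax.jit
def mask_values(a, threshold):
    # The result always has the same shape as `a`, so JAX can trace this
    # even though the threshold value is unknown at compile time.
    return jnp.where(a < threshold, a, jnp.nan)

a = jnp.array([1, 5, 3, 4, 5, 6, 7, 2, 9])
masked_5 = mask_values(a, 5)  # NaN wherever a >= 5
masked_3 = mask_values(a, 3)  # same compiled code, NaN wherever a >= 3
print(masked_5)
print(masked_3)
```

Only the next step, dropping the NaN entries, makes the shape data-dependent.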
I want to keep only the non-NaN values ([ 1. 3. 4. 2.]) and then use this array in further operations.
More importantly, this code runs many times during my program with a threshold that changes (i.e. it won't always be 5).
Hence, the shape of the final array changes too. This is my problem with jit compilation: I don't know how to make it jit-compatible, since the shape depends on how many elements satisfy the threshold condition.
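The failure can be reproduced with plain boolean-mask indexing (a sketch; `keep_below` is a hypothetical name). Outside of jit the mask is concrete and the call works; under `jax.jit` the mask is an abstract tracer, so JAX cannot know the output length and raises at trace time:

```python
import jax
import jax.numpy as jnp
from jax.errors import NonConcreteBooleanIndexError

a = jnp.array([1, 5, 3, 4, 5, 6, 7, 2, 9])

def keep_below(a, threshold):
    # The output length equals the number of True entries in the mask,
    # i.e. it depends on the *values* of `a`, not just its shape.
    return a[a < threshold]

# Eager mode: the mask is concrete, so this returns the four kept values.
print(keep_below(a, 5))

# jit: the mask is a tracer with unknown contents, so tracing fails.
try:
    jax.jit(keep_below)(a, 5)
except NonConcreteBooleanIndexError as err:
    print("jit failed:", type(err).__name__)
```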
Solution 1:[1]
JAX's JIT is not currently compatible with arrays of dynamic (data-dependent) shape, so there is no way to do what your question asks.
There is some experimental work in progress on handling dynamic shapes within JAX transforms like JIT (see https://github.com/google/jax/pull/9335) but I'm not certain when it will be available to use.
The usual workaround for this is to re-express your computations in terms of statically-shaped arrays with a fill value; for example, you could use something like this:
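Often the downstream operations never need the shortened array at all: reductions can be written directly against a same-shaped boolean mask. A minimal sketch, assuming the "other operations" are simple reductions such as a count, sum, and mean (`stats_below` is a hypothetical name):

```python
import jax
import jax.numpy as jnp

@jax.jit
def stats_below(a, threshold):
    # Boolean mask with the same static shape as `a`.
    mask = a < threshold
    count = jnp.sum(mask)
    # Zero out excluded entries instead of dropping them.
    total = jnp.sum(jnp.where(mask, a, 0))
    # Guard against division by zero when nothing passes the filter
    # (the mean is then reported as 0, an arbitrary choice).
    mean = total / jnp.maximum(count, 1)
    return count, total, mean

a = jnp.array([1, 5, 3, 4, 5, 6, 7, 2, 9])
print(stats_below(a, 5))  # kept values are 1, 3, 4, 2
print(stats_below(a, 8))  # same compiled function, different threshold
```

Because the threshold is traced rather than static, changing it does not trigger recompilation. For a NaN-masked array like the one in the question, `jnp.nansum` and `jnp.nanmean` give the same kind of result without an explicit mask.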
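idx, = jnp.where(a < 5, size=len(a), fill_value=len(a))  # indices of kept values, padded with len(a)
a = jnp.append(a, jnp.nan)[idx]  # index len(a) selects the appended NaN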
This creates an array with the same length as a, with the kept (non-NaN) values at the front and NaN padding at the end.
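Wrapped in a jitted function with a traced threshold, the padded approach might look like the sketch below. It uses `jnp.nonzero` with `size` and `fill_value` (equivalent to single-argument `jnp.where`) and also returns the number of valid entries so later code knows where the padding starts. The name `filter_padded` and the cast to `float32` (to make room for NaN in integer inputs) are assumptions for illustration:

```python
import jax
import jax.numpy as jnp

@jax.jit
def filter_padded(a, threshold):
    n = a.shape[0]  # static under jit
    mask = a < threshold
    # Indices of kept elements in increasing order, padded up to length n
    # with the out-of-range index n (the padding size must be static).
    idx, = jnp.nonzero(mask, size=n, fill_value=n)
    # Append one NaN so that index n selects it.
    padded = jnp.append(a.astype(jnp.float32), jnp.nan)
    values = padded[idx]
    # Number of valid (non-padding) entries at the front of `values`.
    count = jnp.sum(mask)
    return values, count

a = jnp.array([1, 5, 3, 4, 5, 6, 7, 2, 9])
values, count = filter_padded(a, 5)  # 1, 3, 4, 2 followed by NaN padding
```

Inside further jitted code, slicing `values[:count]` would again be data-dependent; instead, keep working with the full padded array and either a mask such as `jnp.arange(values.shape[0]) < count` or NaN-aware reductions.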
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | jakevdp |
