jax: batch random.choice samples, with replacement between samples but not within each sample

I'd like to pick two indices out of an array. These indices must not be the same. One such sample can be obtained with:

random.choice(next(key), num_items, (2,), replace=False)

For performance reasons, I'd like to batch the sampling:

num_samples = 100
samples = random.choice(next(key), num_items, (num_samples, 2), replace=False)

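This doesn't work because replace=False applies to the entire (num_samples, 2) output, not to each row, so all num_samples * 2 draws would have to be distinct. It raises the error: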

ValueError: Cannot take a larger sample than population when 'replace=False'
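A minimal sketch with illustrative sizes shows that the check counts every draw in the output array. A (100, 2) request from 5 items fails. A (5, 2) request from 10 items succeeds, but only because it is effectively a permutation of range(10):

```python
from jax import random

key = random.PRNGKey(0)

# replace=False is enforced over the whole output, so 100 * 2 draws
# cannot all be distinct when there are only 5 items.
try:
    random.choice(key, 5, (100, 2), replace=False)
    failed = False
except ValueError as err:
    failed = True
    print(err)

# 5 * 2 == 10 draws from 10 items: every index appears exactly once.
perm = random.choice(key, 10, (5, 2), replace=False)
print(perm)
```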

Across samples, I'd like sampling with replacement (replace=True). Within one sample, I'd like sampling without replacement (replace=False). Is there a way to do this?

The next(key) calls are just shorthand for getting a fresh key. They come from this generator:

from jax import random

def reset_key(seed=42):
    # Yields a fresh subkey on every next(key) call.
    key = random.PRNGKey(seed)
    while True:
        key, subkey = random.split(key)
        yield subkey

key = reset_key()
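For reference, the unbatched version that batching is meant to replace can be written as a Python loop over next(key). This is a sketch with assumed values (num_items = 5, num_samples = 100). Each iteration dispatches a separate random.choice call, which is the per-call overhead the question wants to avoid:

```python
import jax.numpy as jnp
from jax import random


def reset_key(seed=42):
    # Yields a fresh subkey on every next(key) call.
    key = random.PRNGKey(seed)
    while True:
        key, subkey = random.split(key)
        yield subkey


key = reset_key()
num_items = 5      # illustrative value
num_samples = 100  # illustrative value

# Unbatched baseline: one random.choice call (and one new key) per sample.
samples = jnp.stack([
    random.choice(next(key), num_items, (2,), replace=False)
    for _ in range(num_samples)
])
print(samples.shape)
```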


Solution 1:[1]

The best way to do this is to use jax.vmap to map a single-sample function over a batch of keys. For example:

from jax import random, vmap

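def sample_two(key, num_items):
  # Draw two distinct indices from range(num_items) using a single key.
  return random.choice(key, num_items, (2,), replace=False)

key = random.PRNGKey(0)
num_samples = 10
num_items = 5

# One key per sample; in_axes=(0, None) maps over the keys and
# passes the same num_items to every call.
key_array = random.split(key, num_samples)
print(vmap(sample_two, in_axes=(0, None))(key_array, num_items))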
# [[2 0]
#  [1 4]
#  [2 1]
#  [3 4]
#  [4 2]
#  [2 0]
#  [1 3]
#  [2 1]
#  [1 0]
#  [2 4]]
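If the batched sampler is called repeatedly, the vmapped function can also be compiled with jax.jit. A sketch, assuming a hypothetical helper name sample_pairs: num_items and num_samples are marked static because random.choice needs a concrete population size and random.split needs a concrete number of keys:

```python
from functools import partial

import jax
from jax import random, vmap


@partial(jax.jit, static_argnums=(1, 2))
def sample_pairs(key, num_items, num_samples):
    # One independent key per sample, then two distinct indices per key.
    keys = random.split(key, num_samples)
    draw = lambda k: random.choice(k, num_items, (2,), replace=False)
    return vmap(draw)(keys)


samples = sample_pairs(random.PRNGKey(0), 5, 100)
print(samples.shape)
```

Changing num_items or num_samples triggers a recompilation, since static arguments are part of the compilation cache key.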

For more information on jax.vmap, see Automatic Vectorization in JAX.
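For the specific case of two distinct indices, a fully vectorized alternative avoids per-sample choice calls. This is a different technique from the vmap answer and is offered only as a sketch, with a hypothetical function name. Draw the first index from range(num_items) and the second from range(num_items - 1), then shift the second up by one whenever it is at or above the first. This gives a uniform pair of distinct indices and assumes num_items >= 2:

```python
import jax.numpy as jnp
from jax import random


def sample_distinct_pairs(key, num_items, num_samples):
    key_first, key_second = random.split(key)
    # First index: uniform over 0 .. num_items - 1.
    first = random.randint(key_first, (num_samples,), 0, num_items)
    # Second index: uniform over 0 .. num_items - 2, then skip over `first`.
    second = random.randint(key_second, (num_samples,), 0, num_items - 1)
    second = second + (second >= first)
    return jnp.stack([first, second], axis=1)


samples = sample_distinct_pairs(random.PRNGKey(0), 5, 100)
print(samples[:5])
```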

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 jakevdp