jax: sample many observations from random.choice with replacement between them
I'd like to pick two indices out of an array. These indices must not be the same. One such sample can be obtained with:
random.choice(next(key), num_items, (2,), replace=False)
For performance reasons, I'd like to batch the sampling:
num_samples = 100
samples = random.choice(next(key), num_items, (num_samples, 2), replace=False)
This doesn't work: replace=False applies to the whole (num_samples, 2) batch rather than to each row, so the call raises the error:
ValueError: Cannot take a larger sample than population when 'replace=False'
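For reference, here is a minimal reproduction. It assumes num_items = 5, since the value isn't stated above; any population smaller than the 200 requested draws should fail the same way:

```python
from jax import random

num_items = 5      # assumed population size (not given in the question)
num_samples = 100

try:
    # Requests 100 * 2 = 200 distinct draws from only 5 items.
    random.choice(random.PRNGKey(0), num_items, (num_samples, 2), replace=False)
    error_message = None
except ValueError as err:
    error_message = str(err)

print(error_message)
```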
For each new sample, I'd like to have replace=True. Within one sample, I'd like to have replace=False.
Is there a way to do this?
The next(key) in my random sampling is syntactic sugar. I'm using this snippet for convenience:
def reset_key(seed=42):
    key = random.PRNGKey(seed)
    while True:
        key, subkey = random.split(key)
        yield subkey
key = reset_key()
Solution 1:[1]
The best way to do this is to write a function that draws a single sample, then use jax.vmap to map it across an array of keys, one key per sample. For example:
from jax import random, vmap
def sample_two(key, num_items):
    return random.choice(key, num_items, (2,), replace=False)
key = random.PRNGKey(0)
num_samples = 10
num_items = 5
key_array = random.split(key, num_samples)
print(vmap(sample_two, in_axes=(0, None))(key_array, num_items))
# [[2 0]
# [1 4]
# [2 1]
# [3 4]
# [4 2]
# [2 0]
# [1 3]
# [2 1]
# [1 0]
# [2 4]]
For more information on jax.vmap, see Automatic Vectorization in JAX.
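As a rough sketch, the same idea can be wrapped in a reusable helper and checked for the property asked about in the question. The name sample_pairs and the choice of 5 items and 100 samples are illustrative assumptions, not part of the original answer:

```python
import jax.numpy as jnp
from jax import random, vmap


def sample_pairs(key, num_items, num_samples):
    """Draw num_samples pairs of distinct indices from range(num_items)."""
    # One independent subkey per pair.
    keys = random.split(key, num_samples)

    # Within a pair: no replacement. Across pairs: independent draws.
    def draw_pair(k):
        return random.choice(k, num_items, (2,), replace=False)

    return vmap(draw_pair)(keys)


pairs = sample_pairs(random.PRNGKey(0), num_items=5, num_samples=100)
print(pairs.shape)  # (100, 2)

# The two indices within every pair differ.
all_distinct = bool(jnp.all(pairs[:, 0] != pairs[:, 1]))
print(all_distinct)  # True
```

With the next(key) generator from the question, the first argument could be next(key) instead of random.PRNGKey(0); the helper splits it into per-sample keys internally.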
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | jakevdp |
