Hi everybody, this is probably a stupid question, but I just can't seem to get it right. I want to use random sampling within a function that I will … Here's a MWE of what I mean:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)

def add_gaussian(inp, key):
    return inp + jax.random.normal(key, inp.shape)

add_gaussian(jnp.zeros([]), key)
# Returns -0.18471177

vectorized_add_gaussian = jax.vmap(add_gaussian, (0, None), 0)
vectorized_add_gaussian(jnp.zeros(5), key)
# Returns [ -0.18471177, -0.18471177, -0.18471177, -0.18471177, -0.18471177 ]
```

I would like the last function call to output 5 different values. One solution would obviously be to split the RNG key and …

```python
vectorized_add_gaussian = jax.vmap(add_gaussian, (0, 0), 0)
# Split the key into 5 subkeys, one per element of the batch
subkeys = jax.random.split(key, 5)
vectorized_add_gaussian(jnp.zeros(5), subkeys)
# Returns [ 0.48309308, -1.1981592 , -0.7775587 , 1.2407155 , -2.313012 ]
```

Is this the recommended way to handle the situation, or is there a more elegant way of doing it? Manual vectorization would be possible without splitting the PRNG key, so I'm wondering whether …
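For comparison, here is a minimal sketch of the manual vectorization the question alludes to, reusing the `add_gaussian` function and seed from the MWE: calling the function once on the whole batch draws all five samples from a single key, so no split is needed.

```python
import jax
import jax.numpy as jnp

def add_gaussian(inp, key):
    # One standard-normal draw per element of `inp`, all from the same key.
    return inp + jax.random.normal(key, inp.shape)

key = jax.random.PRNGKey(42)

# Manual vectorization: pass the whole batch of shape (5,) directly.
# A single key produces five different samples because the shape is (5,).
batched = add_gaussian(jnp.zeros(5), key)
print(batched)
```

These numbers come from a different part of the random stream than the split-key version, so they will not match its output; both approaches give independent standard-normal draws, just not the same ones.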
It's a good question!

The key thing to understand about JAX PRNG keys is that they are just values, and they act like all other values with respect to transformations. So, if you want each element of a `vmap` to have a different PRNG key, you should split the key and `vmap` over the split key. Indeed, that's exactly what you should expect from the semantics of `vmap`: anything else would be surprising.

This does require you to be explicit about threading PRNG state, but that's a good thing! If you did want to use the same PRNG key for each element of a `vmap`, it's clear how to do that (don't split), and if you want to use different PRNG keys for each `vmap` element, it's clear how to do that …

Aside: We've long wanted to make it harder to accidentally reuse PRNG keys, so in the future we might require that you do something explicit to say "yes, I really want to use this PRNG key multiple times" in this sort of situation.

Hope that helps!
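A minimal sketch of both options described in the answer, assuming the `add_gaussian` function from the question: broadcasting one key with `in_axes=None` reproduces the identical values, while splitting and mapping over axis 0 of the subkeys gives a different draw per element. The `noisy_steps` helper is hypothetical and only illustrates threading PRNG state explicitly: each iteration splits off a fresh key for its own use and carries the other one forward, so no key is consumed twice.

```python
import jax
import jax.numpy as jnp

def add_gaussian(inp, key):
    # One standard-normal draw per element of `inp`, using `key`.
    return inp + jax.random.normal(key, inp.shape)

key = jax.random.PRNGKey(42)
inputs = jnp.zeros(5)

# Same key for every element: in_axes=None broadcasts the key,
# so all five elements receive the same sample.
shared = jax.vmap(add_gaussian, in_axes=(0, None))(inputs, key)

# Different key per element: split first, then map over axis 0 of the keys.
subkeys = jax.random.split(key, inputs.shape[0])
independent = jax.vmap(add_gaussian, in_axes=(0, 0))(inputs, subkeys)

def noisy_steps(key, x, num_steps):
    """Hypothetical helper: add fresh per-element noise for several steps."""
    for _ in range(num_steps):
        # Carry `key` forward; consume `step_key` only in this iteration.
        key, step_key = jax.random.split(key)
        step_keys = jax.random.split(step_key, x.shape[0])
        x = jax.vmap(add_gaussian)(x, step_keys)
    # Return the unused key so the caller can keep drawing without reuse.
    return key, x

# A separate seed for the loop demo, so it does not reuse `key` above.
next_key, result = noisy_steps(jax.random.PRNGKey(0), jnp.zeros(3), num_steps=4)
print(shared)
print(independent)
print(result)
```

Returning the leftover key from `noisy_steps` is one common way to keep the state threading visible to the caller; whether that fits depends on how the surrounding code manages keys.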