
Randomness within a vmap? #7342

Answered by hawkinsp
khdlr asked this question in Q&A
Jul 21, 2021 · 1 comment · 2 replies

It's a good question!

The key thing to understand about JAX PRNG keys is that they are just values, and they behave like any other value under transformations. So, if you want each element of a vmap to use a different PRNG key, split the key and vmap over the split keys. Indeed, that's exactly what you should expect from the semantics of vmap: anything else would be surprising.
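A minimal sketch of the split-then-vmap pattern. The per-example function `add_noise` and the batch shape `(4, 3)` are hypothetical, chosen only for illustration. Because keys are ordinary values, vmapping over a batch of split keys gives the same result as an explicit Python loop over (key, example) pairs.

```python
import jax
import jax.numpy as jnp

def add_noise(key, x):
    # Hypothetical per-example function: adds Gaussian noise drawn from `key`.
    return x + jax.random.normal(key, x.shape)

key = jax.random.PRNGKey(0)
xs = jnp.zeros((4, 3))  # illustrative batch: 4 examples of size 3

# Split into one independent key per batch element.
keys = jax.random.split(key, xs.shape[0])

# vmap maps over the leading axis of both `keys` and `xs`.
batched = jax.vmap(add_noise)(keys, xs)

# Same computation written as a plain loop over (key, example) pairs.
looped = jnp.stack([add_noise(k, x) for k, x in zip(keys, xs)])

print(batched.shape)                  # (4, 3)
print(jnp.allclose(batched, looped))  # True
```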

This does require you to be explicit about threading PRNG state, but that's a good thing! If you did want to use the same PRNG key for each element of a vmap, it's clear how to do that (don't split), and if you want to use different PRNG keys for each vmap element, it's clear how to do that …
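Both choices can be expressed directly with `in_axes`. In this sketch (again using a hypothetical `add_noise` helper and an arbitrary batch shape), `in_axes=(None, 0)` broadcasts one unsplit key to every element, while mapping over split keys gives each element its own stream.

```python
import jax
import jax.numpy as jnp

def add_noise(key, x):
    # Hypothetical per-example function: adds Gaussian noise drawn from `key`.
    return x + jax.random.normal(key, x.shape)

key = jax.random.PRNGKey(0)
xs = jnp.zeros((4, 3))  # illustrative batch

# Same key for every element: don't split; in_axes=None broadcasts the key.
shared = jax.vmap(add_noise, in_axes=(None, 0))(key, xs)

# Different key per element: split, then map over the keys.
keys = jax.random.split(key, xs.shape[0])
independent = jax.vmap(add_noise, in_axes=(0, 0))(keys, xs)

print(jnp.allclose(shared[0], shared[1]))            # True: identical noise
print(jnp.allclose(independent[0], independent[1]))  # False: independent draws
```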

Answer selected by khdlr