Using jax.pmap/vmap and jax.lax.switch to execute a simulation in parallel. Is it an anti-pattern? #20916

No, I never found a better or faster solution to this. I ended up optimizing other parts to avoid calling this logic too often, and accepted that it takes a couple of hours to generate the data.

If I were to do it again, I would probably not implement this part in JAX. If you need lots of branching (e.g. via jax.lax.switch), then JAX might not be the best option.
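For anyone landing here later, a likely reason the branching hurts: when the branch index is itself batched under jax.vmap, JAX turns the switch into a select, so every branch is evaluated for every batch element and the unwanted results are discarded. Here is a minimal sketch of that behaviour. The three branch functions and the names `step` and `step_all` are made up for illustration; they are not the code from my simulation.

```python
import jax
import jax.numpy as jnp

# Hypothetical branches standing in for different simulation update rules.
def branch_a(x):
    return x + 1.0

def branch_b(x):
    return x * 2.0

def branch_c(x):
    return -x

BRANCHES = (branch_a, branch_b, branch_c)

def step(index, x):
    # Unbatched, lax.switch runs only the selected branch.
    return jax.lax.switch(index, BRANCHES, x)

# Batching over `index` means each element may pick a different branch,
# so under vmap all branches are computed and the result is selected.
batched_step = jax.jit(jax.vmap(step))

def step_all(index, x):
    # The same work written out explicitly: evaluate every branch,
    # then pick one result per element.
    results = jnp.stack([f(x) for f in BRANCHES])
    return results[index]

indices = jnp.array([0, 1, 2, 1])
xs = jnp.array([1.0, 2.0, 3.0, 4.0])

out_switch = batched_step(indices, xs)
out_all = jax.jit(jax.vmap(step_all))(indices, xs)
print(out_switch)  # [ 2.  4. -3.  8.]
```

Both versions give the same values, and with a batched index the cost scales with the sum of all branches rather than with the one branch you selected. If the branches are expensive and differ a lot, grouping inputs by branch index outside of JAX and running each group through its own jitted function may be cheaper, though I have not benchmarked that for this use case.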

Answer selected by simon-bachhuber
Category: Q&A
3 participants