
About referencing the shape of a traced array (ndim, size) and using it in an if statement #18167

Answered by jakevdp
helpingstar asked this question in Q&A

Hi - thanks for the question!

You're always free to use Python control flow on static values, and a JAX array's shape, dtype, ndim, size, and related attributes are static: they are known while the function is being traced. So jax_if_2 works because it uses Python control flow on static attributes of an array, while jax_if_1 fails because it uses Python control flow on the dynamic values stored in an array, which are not known at trace time.
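Here is a minimal sketch of the distinction, assuming the functions are wrapped in jax.jit. The function names are hypothetical illustrations, not the code from the original question. Branching on x.ndim traces fine, branching on x.sum() raises an error, and jax.lax.cond is one way to branch on a traced value instead:

```python
import jax
import jax.numpy as jnp


@jax.jit
def branch_on_ndim(x):
    # x.ndim is static (known at trace time), so a Python `if` is fine here.
    if x.ndim == 1:
        return x * 2
    return x.sum()


@jax.jit
def branch_on_value(x):
    # x.sum() is a traced value; converting it to a Python bool fails under jit.
    if x.sum() > 0:
        return x * 2
    return -x


@jax.jit
def branch_on_value_cond(x):
    # jax.lax.cond stages the branch into the compiled computation instead.
    # Both branches must return outputs with the same shape and dtype.
    return jax.lax.cond(x.sum() > 0, lambda v: v * 2, lambda v: -v, x)


x = jnp.arange(3.0)
print(branch_on_ndim(x))                 # 1-D input takes the first branch
print(branch_on_ndim(x.reshape(1, 3)))   # 2-D input takes the second branch
print(branch_on_value_cond(x))

try:
    branch_on_value(x)
except jax.errors.ConcretizationTypeError as err:
    # Recent JAX versions raise TracerBoolConversionError, a subclass of this.
    print("branch_on_value failed under jit:", type(err).__name__)
```

The failure comes from tracing: called without jax.jit, branch_on_value would run eagerly on concrete values. Also, because ndim is part of the static shape, jit traces and compiles a separate version of branch_on_ndim for each distinct input shape.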

For more background on how to think about tracing in JAX, a good place to start is How to think in JAX.