About referencing the shape of a traced array and using it in an if statement (ndim, size)
#18167
From "Python control flow + JIT", I learned what to be careful about when using `if` statements under `jit`. While keeping that in mind, I found that no error occurred for `jax_if_2` in the following code:

```python
import jax
import jax.numpy as jnp

@jax.jit
def jax_if_1(x):
    if x > 0:
        return 1

@jax.jit
def jax_if_2(x):
    if x.ndim > 0:
        return 1

test = jnp.zeros((1, 1, 1))
jax_if_2(test)
# Array(1, dtype=int32, weak_type=True)
jax_if_1(test)
# TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[1,1,1]..
```
I expected `jax_if_2` to raise an error as well. If there is any tutorial or documentation on this, please let me know.
Hi - thanks for the question!
You're always free to use Python control flow on static values, and JAX array shape, dtype, and related attributes are static. So `jax_if_2` works because you're using Python control flow on static attributes of arrays, and `jax_if_1` fails because you're using Python control flow on dynamic values in arrays.

For more background on how to think about tracing in JAX, a good place to start is How to think in JAX.
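A minimal sketch of the distinction, assuming a reasonably recent JAX release (the function names below are hypothetical, for illustration only). Inside a jitted function, `x.ndim`, `x.size`, and `x.shape` are ordinary Python values fixed at trace time, so a plain `if` on them is fine. A comparison like `x > 0` yields a traced boolean array whose value is unknown while tracing; for value-dependent logic, `jnp.where` (elementwise) or `jax.lax.cond` (scalar predicate) are the usual alternatives to a Python `if`.

```python
import jax
import jax.numpy as jnp


@jax.jit
def branch_on_static(x):
    # ndim and size are plain Python ints during tracing,
    # so a Python `if` is evaluated once, at trace time.
    if x.ndim > 0 and x.size > 0:
        return jnp.sum(x)
    return x


@jax.jit
def branch_on_value(x):
    # `x > 0` is a traced boolean array; a Python `if` on it would raise
    # TracerBoolConversionError. jnp.where selects elementwise instead.
    return jnp.where(x > 0, 1, 0)


@jax.jit
def branch_on_scalar_value(x):
    # For a scalar traced predicate, lax.cond traces both branches
    # and selects one at run time.
    return jax.lax.cond(x > 0, lambda v: v * 2.0, lambda v: -v, x)


test = jnp.zeros((1, 1, 1))
print(branch_on_static(test))                    # 0.0
print(branch_on_value(test).shape)               # (1, 1, 1)
print(branch_on_scalar_value(jnp.float32(3.0)))  # 6.0
```

Because shapes are static, calling a jitted function with an input of a new shape triggers a fresh trace, and any shape-based `if` is re-evaluated for that shape.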