Why is jnp.max() JITable? #17525
-
I have some code inside a jitted function where there is a jnp.max(X, Y) operation. Considering that a max operation can essentially be broken down into an if statement about whether X > Y or vice versa, and this conditions on the exact values of X and Y, how is this JITable? Is it guaranteed that jitting such a function will always calculate this jnp.max() operation correctly every time, or could it simply take the same index that the first compilation involved? Interested to learn how this works. Any help appreciated :)
Replies: 2 comments
-
Hi - thanks for the question!

First of all, you should make sure you're not mixing up jnp.max and jnp.maximum. jnp.max takes the maximum of an array along the specified axis:

>>> import jax
>>> import jax.numpy as jnp
>>> x = jnp.array([-2, -1, 0, 1, 2])
>>> jnp.max(x, 0)
Array(2, dtype=int32)

On the other hand, jnp.maximum computes the element-wise maximum between two input arrays (possibly after broadcasting, if the shapes don't match). For example:

>>> jnp.maximum(x, 0)
Array([0, 0, 0, 1, 2], dtype=int32)

The reason these are JIT-compatible is that they can lower to native XLA operations (reduce_max and max below), and the compiler can condition outputs on operations like x > y at runtime. You can see this by looking at the jaxpr for each:

>>> jax.make_jaxpr(lambda x: jnp.max(x, 0))(x)
{ lambda ; a:i32[5]. let b:i32[] = reduce_max[axes=(0,)] a in (b,) }
>>> jax.make_jaxpr(jnp.maximum)(x, 2)
{ lambda ; a:i32[5] b:i32[]. let
    c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    d:i32[5] = max a c
  in (d,) }

In general, you can depend on such operations to work as expected at runtime.
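To address the worry about reusing "the same index" from the first compilation, here is a minimal sketch, assuming a recent JAX release with default 32-bit types. The function name pairwise_and_peak and the trace_count counter are hypothetical, for illustration only. The Python body runs once, while jit traces it, yet each call returns value-dependent results, because the comparisons happen inside the compiled max, reduce_max and select operations rather than in Python:

```python
import jax
import jax.numpy as jnp

# Hypothetical counter: Python side effects run while jit traces the function,
# not when a cached compiled version is reused.
trace_count = 0


@jax.jit
def pairwise_and_peak(x, y):
    global trace_count
    trace_count += 1
    elementwise = jnp.maximum(x, y)      # lowers to the XLA `max` primitive
    via_select = jnp.where(x > y, x, y)  # explicit compare-and-select, also branch-free
    peak = jnp.max(x)                    # lowers to `reduce_max`
    return elementwise, via_select, peak


a = jnp.array([-2, -1, 0, 1, 2])
b = jnp.array([3, -5, 1, 0, -4])

# Same shapes and dtypes on both calls, so the second call reuses the compiled code.
e1, s1, p1 = pairwise_and_peak(a, b)
e2, s2, p2 = pairwise_and_peak(a * 10, b)

print(e1.tolist(), int(p1))  # [3, -1, 1, 1, 2] 2
print(e2.tolist(), int(p2))  # [3, -5, 1, 10, 20] 20
print(trace_count)           # 1
```

What gets baked into the cached executable is what is fixed at trace time (shapes, dtypes); comparisons on the array values themselves are evaluated again on every call.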
-
Thank you jakevdp for your very comprehensive answer. Super interesting to see. I guess jit is more powerful than I thought, since it can do some operations that involve conditioning on values. Thanks again for your time!
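The "conditioning on values" discussed in this thread only works through operations that are staged into the compiled program. As a minimal sketch (the function names are hypothetical, and a recent JAX release is assumed), a plain Python if on traced values fails under jit, while jax.lax.cond or jnp.where performs the x > y comparison inside the compiled code:

```python
import jax
import jax.numpy as jnp


def python_branch(x, y):
    # A Python `if` needs a concrete bool; under jit, `x > y` is an abstract tracer.
    if x > y:
        return x
    return y


def staged_branch(x, y):
    # lax.cond compiles both branches and chooses one at runtime from the predicate.
    return jax.lax.cond(x > y, lambda a, b: a, lambda a, b: b, x, y)


def staged_select(x, y):
    # jnp.where evaluates both inputs and selects; no Python branching involved.
    return jnp.where(x > y, x, y)


try:
    jax.jit(python_branch)(1.0, 2.0)
    python_branch_failed = False
except jax.errors.ConcretizationTypeError:
    # Recent JAX raises TracerBoolConversionError, a subclass of this error.
    python_branch_failed = True

cond_fn = jax.jit(staged_branch)
select_fn = jax.jit(staged_select)

print(python_branch_failed)                                    # True
print(float(cond_fn(1.0, 2.0)), float(cond_fn(5.0, 2.0)))      # 2.0 5.0
print(float(select_fn(1.0, 2.0)), float(select_fn(5.0, 2.0)))  # 2.0 5.0
```

jnp.where computes both candidates and then selects, which is cheap for something like a maximum; lax.cond is the usual choice when the two branches are expensive computations.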
Hi - thanks for the question!
First of all, you should make sure you're not mixing up jnp.max and jnp.maximum. jnp.max takes the maximum of an array along the specified axis. On the other hand, jnp.maximum computes the element-wise maximum between two input arrays (possibly after broadcasting, if the shapes don't match).
The reason these are JIT-compatible is that they can lower to native XLA operations. You can see this by looking at the jaxpr for each. The compiler can condition outputs on operations like x > y
at runtime, even if JAX's traci…