Why is jnp.max() JITable? #17525

Answered by jakevdp
throwaway0123456 asked this question in Q&A

Hi - thanks for the question!

First of all, you should make sure you're not mixing up jnp.max and jnp.maximum.

jnp.max takes the maximum of an array along the specified axis:

>>> import jax.numpy as jnp
>>> x = jnp.array([-2, -1, 0, 1, 2])
>>> jnp.max(x, 0)
Array(2, dtype=int32)

On the other hand, jnp.maximum computes the element-wise maximum between two input arrays (possibly after broadcasting, if the shapes don't match). For example:

>>> jnp.maximum(x, 0)
Array([0, 0, 0, 1, 2], dtype=int32)

The reason these are JIT-compatible is that they can lower to native XLA operations. You can see this by looking at the jaxpr for each. The compiler can condition outputs on operations like x > y at runtime, even if JAX's traci…
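
As a rough sketch of how one might inspect the jaxprs mentioned above, assuming a reasonably recent JAX release: the lambda wrappers and variable names below are illustrative and not from the original answer, and the exact printed form of a jaxpr varies between versions.

```python
import jax
import jax.numpy as jnp

x = jnp.array([-2, -1, 0, 1, 2])

# Trace each function and capture its jaxpr (JAX's intermediate representation).
# The lambdas are only here to fix the second argument to 0 for illustration.
jaxpr_max = jax.make_jaxpr(lambda a: jnp.max(a, 0))(x)
jaxpr_maximum = jax.make_jaxpr(lambda a: jnp.maximum(a, 0))(x)

print(jaxpr_max)      # should show a reduction primitive (reduce_max)
print(jaxpr_maximum)  # should show an element-wise max primitive

# Both functions can be wrapped in jax.jit, because neither needs
# concrete array values while it is being traced.
jit_max = jax.jit(lambda a: jnp.max(a, 0))
jit_maximum = jax.jit(lambda a: jnp.maximum(a, 0))

print(jit_max(x))
print(jit_maximum(x))
```

The jitted versions should return the same values as the un-jitted calls shown earlier, since jit changes how the computation is compiled, not what it computes.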

Answer selected by throwaway0123456