-
In theory, … EDIT: I went with sprinkling …
-
For anyone coming back to this, … For example:

```python
>>> f = lambda x: sum(i * x for i in range(256))  # common function
>>> jf = jax.jit(f)
>>> %timeit jax.jit(lambda x: sum(f(x) for _ in range(16)))(3.14)
2.79 s ± 133 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
>>> %timeit jax.jit(lambda x: sum(jf(x) for _ in range(16)))(3.14)
36.5 ms ± 9.6 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
```

Both timings include tracing and compilation, because every `%timeit` iteration wraps a fresh lambda in a new outer `jax.jit`.
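A minimal sketch of where that difference plausibly comes from, assuming a recent JAX release: it counts how often each Python body runs (a side effect that only fires while tracing) and how many equations end up in the outer jaxpr. The names `f_plain`, `f_jitted`, `outer_plain`, `outer_nested` and `trace_calls` are illustrative, not from the thread.

```python
import jax

# Python side effects only run while JAX traces a function, so this
# dict effectively counts traces of each body.
trace_calls = {"plain": 0, "jitted": 0}

def f_plain(x):
    trace_calls["plain"] += 1
    return sum(i * x for i in range(256))  # same "common function" as above

@jax.jit
def f_jitted(x):
    trace_calls["jitted"] += 1
    return sum(i * x for i in range(256))

def outer_plain(x):
    # f_plain is re-traced on every call, so its 256-term body is copied 16 times
    return sum(f_plain(x) for _ in range(16))

def outer_nested(x):
    # f_jitted is traced once; its cached jaxpr is reused for all 16 calls
    return sum(f_jitted(x) for _ in range(16))

inlined = jax.make_jaxpr(outer_plain)(3.14)
nested = jax.make_jaxpr(outer_nested)(3.14)

print(trace_calls)  # {'plain': 16, 'jitted': 1}
print(len(inlined.jaxpr.eqns), len(nested.jaxpr.eqns))
```

The nested version hands the compiler one small function called 16 times instead of thousands of unrolled operations, which is consistent with the much shorter compile time measured above.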
-
Revisiting this in 2025, I don't think this is the case anymore, given the recent move of pjit.cc into jaxlib. Notably, if I write:

```python
def f(in_vars):
    @jax.jit
    def sub_fun(vars):
        return jnp.some_op(vars)  # jnp.some_op stands in for any jnp operation

    x = jnp.some_op(in_vars)
    y = sub_fun(x)
    return y
```

and turn on verbose logging (…), you end up with many of these (function name e.g. …, annotation mine): …
You also get these errors on JAX's built-in functions, so even if you … Which, looking at Lines 641 to 657 in 8ff5d33, …

You can reproduce this behaviour with:

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_platform_name", "gpu")  # needs a GPU-enabled jaxlib

batch_dim = 8  # undefined in the original snippet; illustrative value

@jax.jit
def simple_stuff(x, y):
    z = jnp.matmul(x, y)
    z = jnp.abs(z)
    z = z**2
    return jnp.sum(z)

many_x = jnp.array(np.random.random((batch_dim, 1000, 1000)))
many_y = jnp.array(np.random.random((batch_dim, 1000, 1000)))

def do_many(x, y):
    # simple_stuff is itself jitted, so this is a nested jit inside vmap + jit
    res = simple_stuff(x, y) - 1
    return res

v_matmul = jax.jit(jax.vmap(do_many, in_axes=(0, 0)))
res = v_matmul(many_x, many_y)
```

This seems to indicate to me that true nested jit should be avoided, unless of course you also use some of the nested functions independently elsewhere, where they won't be passed Tracers as input. From https://docs.jax.dev/en/latest/faq.html#benchmarking-jax-code: …
I'd love it if someone with authority could confirm or correct my understanding of all of this!
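If you follow that conclusion, one way to restructure the reproducer is to keep the computation as a plain Python function, which is simply inlined into whatever `jit`/`vmap` trace calls it, and only wrap it in `jax.jit` for standalone callers. This is a sketch of the poster's suggestion, not an official recommendation; `simple_stuff_impl` is a hypothetical name, and the shapes are shrunk so it runs quickly on CPU.

```python
import jax
import jax.numpy as jnp
import numpy as np

def simple_stuff_impl(x, y):
    # Plain Python: inside another transformation this is traced inline,
    # so no nested jit call is recorded.
    z = jnp.matmul(x, y)
    z = jnp.abs(z)
    z = z**2
    return jnp.sum(z)

# Jitted entry point reserved for standalone (non-traced) calls.
simple_stuff = jax.jit(simple_stuff_impl)

def do_many(x, y):
    return simple_stuff_impl(x, y) - 1

v_matmul = jax.jit(jax.vmap(do_many, in_axes=(0, 0)))

batch_dim = 4  # illustrative; much smaller than the original 1000x1000 inputs
rng = np.random.default_rng(0)
many_x = jnp.asarray(rng.random((batch_dim, 64, 64)), dtype=jnp.float32)
many_y = jnp.asarray(rng.random((batch_dim, 64, 64)), dtype=jnp.float32)

res = v_matmul(many_x, many_y)
# Reference: the standalone jitted function applied one batch element at a time.
ref = jnp.stack([simple_stuff(many_x[i], many_y[i]) - 1 for i in range(batch_dim)])
```

The trade-off is that the standalone `simple_stuff` and the batched path are compiled separately, but neither contains a nested `jit` boundary.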
Right now, nested `jit` calls will be preserved as function calls in the IR that JAX generates, but will be flattened by XLA. In the future, this may not be true anymore! At some point, XLA (or another compiler) may inline less aggressively.
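To check this for a specific function, you can compare three views: the jaxpr, the StableHLO that JAX hands to XLA, and the HLO after XLA's optimization passes. This is a sketch assuming a recent JAX release with the `jax.jit(...).lower(...)` ahead-of-time API; `sub_fun` and `outer` are illustrative names, and the exact printed text (for example whether the nested equation is labelled `pjit` or `jit`) varies between versions.

```python
import jax
import jax.numpy as jnp

@jax.jit
def sub_fun(v):
    return jnp.sin(v) * 2.0

def outer(x):
    return sub_fun(jnp.cos(x)) + 1.0

x = jnp.ones((3,), dtype=jnp.float32)

# Jaxpr: the call to sub_fun is recorded as a single nested-jit equation.
jaxpr_text = str(jax.make_jaxpr(outer)(x))
print(jaxpr_text)

lowered = jax.jit(outer).lower(x)
stablehlo_text = lowered.as_text()             # IR emitted by JAX, before XLA runs
optimized_text = lowered.compile().as_text()   # HLO after XLA's optimization passes
print(stablehlo_text)
print(optimized_text)
```

In the StableHLO you would typically look for a separate private function for `sub_fun` plus a call to it; in the optimized HLO that call is usually gone, which is the flattening described above.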