
How to warm-up jax.lax.while_loop? #26780

Answered by jakevdp
ZedongPeng asked this question in Q&A

You can use ahead-of-time (AOT) compilation to separate tracing and compilation from execution for any jittable function, including one that contains jax.lax.while_loop. Here's a simple example:

In [1]: import jax

In [2]: def f(x):
   ...:     ...
   ...:     return x
   ...: 

In [3]: x = 1.0

In [4]: %timeit jax.jit(f).lower(x).compile()
158 μs ± 873 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In [5]: f_compiled = jax.jit(f).lower(x).compile()

In [6]: %timeit f_compiled(x)
3.67 μs ± 25.3 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
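The same pattern applies directly to a function built around jax.lax.while_loop. The sketch below uses a hypothetical count_steps function (not from the original question) that halves its input until it drops below 1.0 and counts the iterations. The loop is traced, lowered and compiled once up front, so later calls only execute the compiled loop.

```python
import jax
import jax.numpy as jnp
from jax import lax


def count_steps(x):
    # Hypothetical example: halve x until it drops below 1.0, counting iterations.
    def cond_fun(state):
        value, _ = state
        return value >= 1.0

    def body_fun(state):
        value, steps = state
        return value / 2.0, steps + 1

    # The carry (value, steps) must keep the same shapes and dtypes on every iteration.
    _, steps = lax.while_loop(cond_fun, body_fun, (x, jnp.int32(0)))
    return steps


x = jnp.float32(10.0)

# Warm-up: trace, lower and compile once, before any timing-sensitive code runs.
count_steps_compiled = jax.jit(count_steps).lower(x).compile()

# Subsequent calls execute the already-compiled executable with no tracing or compilation.
steps = count_steps_compiled(x)
print(int(steps))  # prints 4 (10 -> 5 -> 2.5 -> 1.25 -> 0.625)
```

Note that the compiled object is specialized to the shapes and dtypes passed to lower(); calling it with a differently shaped or typed argument raises an error instead of silently recompiling.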

Answer selected by ZedongPeng