How to warm-up jax.lax.while_loop? #26780
I’m using a jitted jax.lax.while_loop, and I want to measure the execution time of my algorithm, excluding compilation time. Currently, I run the code twice and time only the second run. Is there a smarter way to do this?
Replies: 1 comment
You can use ahead-of-time compilation to separate out tracing/compilation and execution for any function. Here's a simple example:

In [1]: import jax

In [2]: def f(x):
   ...:     ...
   ...:     return x
   ...:

In [3]: x = 1.0

In [4]: %timeit jax.jit(f).lower(x).compile()
158 μs ± 873 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In [5]: f_lowered = jax.jit(f).lower(x).compile()

In [6]: %timeit f_lowered(x)
3.67 μs ± 25.3 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
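
For the while_loop case from the question, the same pattern might look like the sketch below. The function count_to and its loop body are hypothetical stand-ins for the real algorithm, and the manual time.perf_counter timing is just one way to measure outside IPython. Because JAX dispatches work asynchronously, the sketch calls jax.block_until_ready before stopping the clock so that the measurement covers the actual execution.

```python
import time

import jax
import jax.numpy as jnp
from jax import lax


def count_to(limit):
    # Hypothetical stand-in for the real algorithm: loop until i reaches limit,
    # accumulating the sum 0 + 1 + ... + (limit - 1) along the way.
    def cond(state):
        i, _ = state
        return i < limit

    def body(state):
        i, total = state
        return i + 1, total + i

    return lax.while_loop(cond, body, (jnp.int32(0), jnp.int32(0)))


limit = jnp.int32(1000)

# Tracing, lowering, and compilation happen here, once, ahead of time.
t0 = time.perf_counter()
compiled = jax.jit(count_to).lower(limit).compile()
compile_seconds = time.perf_counter() - t0

# Execution only: the compiled object is called directly, so nothing is retraced.
t0 = time.perf_counter()
result = jax.block_until_ready(compiled(limit))
run_seconds = time.perf_counter() - t0

print(f"compile: {compile_seconds:.6f} s, run: {run_seconds:.6f} s")
```

Note that the compiled object is specialized to the shape and dtype of the argument passed to lower, so it should be called with inputs of the same shape and dtype.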