Hi all, I'm new to profiling, especially within the JAX framework.
My code contains one big loop over time, implemented as jax.lax.scan(f, ...). I'd now like to time the computation performed at each timestep to identify the bottlenecks. For this I'm looking at jax.profiler.trace with Perfetto (https://docs.jax.dev/en/latest/profiling.html).
When I wrap the profiler around the scan call, I don't see any information about the subfunctions executed inside the function being scanned over. When I instead move the profiler inside the scanned function, the profiling only happens on the first run, i.e. during compilation and caching of the scan. (Even when I run jax.lax.scan(f, ...) with an already compiled f, a second run of the scan is much faster, which suggests that the scan itself is cached as well.)
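For reference, a minimal sketch of the setup described above. The step function, its named scopes, and the array sizes are hypothetical stand-ins for the real f. The scan is wrapped in jax.jit and called once before tracing, so that (assuming compilation happens on that first call) the traced run should contain execution only. jax.named_scope labels the parts of the step, which may make them easier to pick out among the XLA ops in the trace. The linked docs pass create_perfetto_link=True to open the result in Perfetto; it is left at its default here so the sketch runs non-interactively.

```python
import tempfile

import jax
import jax.numpy as jnp


def step(carry, x):
    # Hypothetical stand-in for f. jax.named_scope attaches a name to the ops
    # traced inside it, which can help identify them in profiler views.
    with jax.named_scope("update_state"):
        carry = carry * 0.9 + jnp.sin(x)
    with jax.named_scope("compute_output"):
        y = carry ** 2
    return carry, y


@jax.jit
def run(init, xs):
    return jax.lax.scan(step, init, xs)


init = jnp.zeros(())
xs = jnp.linspace(0.0, 1.0, 1000)

# Warm-up call: tracing and compilation happen here, outside the trace.
jax.block_until_ready(run(init, xs))

log_dir = tempfile.mkdtemp()
with jax.profiler.trace(log_dir):
    final_carry, ys = run(init, xs)
    # JAX dispatches asynchronously; wait so the execution lands inside the trace.
    jax.block_until_ready((final_carry, ys))
```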
My question is therefore: when jax.profiler.trace runs during the first call of jax.lax.scan, am I seeing only raw computation times in the results, or are caching/compilation times included as well? The times displayed in Perfetto look like raw computation times (even when f is not precompiled), but I'd like to be sure.
An additional question: is it normal that a second call to the same jax.lax.scan executes much faster, even though the function being scanned over was already compiled beforehand?
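One way to separate compilation from execution when timing, sketched with the same kind of hypothetical step function: lower and compile the jitted scan ahead of time, then time two calls of the compiled executable, alongside two calls of an un-jitted jax.lax.scan. The helper names (step, run, timed) are illustrative only. Timings are wall-clock via time.perf_counter and depend on hardware and JAX version, so no particular numbers are implied.

```python
import time

import jax
import jax.numpy as jnp


def step(carry, x):
    # Hypothetical stand-in for the scanned function f.
    carry = carry * 0.9 + jnp.sin(x)
    return carry, carry ** 2


def run(init, xs):
    return jax.lax.scan(step, init, xs)


def timed(fn, *args):
    # Block on the result so the measurement includes the actual computation,
    # not just the asynchronous dispatch.
    start = time.perf_counter()
    out = jax.block_until_ready(fn(*args))
    return out, time.perf_counter() - start


init = jnp.zeros(())
xs = jnp.linspace(0.0, 1.0, 1000)

# Un-jitted scan called twice: the first call includes tracing the step and
# compiling the loop; later calls may reuse cached work.
_, t_eager_1 = timed(run, init, xs)
_, t_eager_2 = timed(run, init, xs)

# Ahead-of-time: lower and compile explicitly, so compile cost is measured
# separately from execution.
start = time.perf_counter()
compiled = jax.jit(run).lower(init, xs).compile()
compile_time = time.perf_counter() - start

out_first, t_first = timed(compiled, init, xs)
out_second, t_second = timed(compiled, init, xs)

print(f"eager: call 1 {t_eager_1:.4f}s, call 2 {t_eager_2:.4f}s")
print(f"AOT: compile {compile_time:.4f}s, run 1 {t_first:.4f}s, run 2 {t_second:.4f}s")
```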
Thanks in advance
Cedric