A generic way / tracer to estimate flops of JAX operations? #30569
-
I was recently looking into estimating the model FLOPs utilization (MFU) for a model I coded in JAX. This gave me the idea that JAX's tracing mechanism should, in principle, allow estimating (at least approximately) the FLOPs for a given model or sequence of operations. What I would be looking for is something like … Does something like this already exist? Do you think it would be possible and useful to interpret a trace this way? Or are the actual FLOPs (after compilation) too different from the naive "execution"?
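A minimal sketch of the trace-based idea, assuming only matrix multiplications (`dot_general`) are worth counting, which is usually where most of the FLOPs in a model are. It traces a function with `jax.make_jaxpr` and walks the resulting jaxpr. `estimate_flops` and `count_matmul_flops` are hypothetical helpers, not part of JAX. This counts the naive, pre-compilation work: elementwise ops are ignored, loop bodies (`scan`, `while`) are counted once rather than once per iteration, and XLA may still rewrite the program.

```python
import math

import jax
import jax.numpy as jnp


def _inner_jaxprs(value):
    """Yield any nested (Closed)Jaxpr found in an equation parameter."""
    if hasattr(value, "eqns"):  # Jaxpr or ClosedJaxpr (both expose .eqns)
        yield value
    elif isinstance(value, (tuple, list)):  # e.g. the `branches` of cond
        for v in value:
            yield from _inner_jaxprs(v)


def count_matmul_flops(jaxpr):
    """Naively count 2 * (output elements) * (contracted size) per dot_general."""
    total = 0
    for eqn in jaxpr.eqns:
        if eqn.primitive.name == "dot_general":
            (lhs_contract, _), _ = eqn.params["dimension_numbers"]
            lhs_shape = eqn.invars[0].aval.shape
            out_shape = eqn.outvars[0].aval.shape
            k = math.prod(lhs_shape[d] for d in lhs_contract)
            # One multiply and one add per contracted element of each output.
            total += 2 * math.prod(out_shape) * k
        # Recurse into nested jaxprs (jit, cond, ...). Loop bodies are
        # counted only once, not multiplied by the trip count.
        for value in eqn.params.values():
            for inner in _inner_jaxprs(value):
                total += count_matmul_flops(inner)
    return total


def estimate_flops(f, *args):
    """Trace `f` on example arguments and count its matmul FLOPs."""
    closed = jax.make_jaxpr(f)(*args)
    return count_matmul_flops(closed.jaxpr)


x = jnp.ones((4, 8))
w = jnp.ones((8, 16))
print(estimate_flops(lambda a, b: jnp.tanh(a @ b), x, w))  # 2 * 4 * 16 * 8 = 1024
```

Extending this to other primitives (convolutions, elementwise ops) means adding a branch per `eqn.primitive.name`.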
Replies: 2 comments 2 replies
-
Try using this roofline library: https://cs.opensource.google/jax/jax/+/main:jax/experimental/roofline/roofline.py;l=1?q=roofline&sq=&ss=jax%2Fjax. The tests show how to use it: https://cs.opensource.google/jax/jax/+/main:tests/roofline_test.py;l=1?q=roofline&ss=jax%2Fjax
-
Like `cost_analysis`?
https://docs.jax.dev/en/latest/aot.html#debug-information-and-analyses-when-available
https://docs.jax.dev/en/latest/jax.stages.html#jax.stages.Lowered.cost_analysis
https://docs.jax.dev/en/latest/jax.stages.html#jax.stages.Compiled.cost_analysis
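A minimal sketch of reading XLA's own estimate through the ahead-of-time API, assuming a JAX version where `Compiled.cost_analysis()` returns a dict (older versions returned a list with one dict per program, which the hypothetical `_as_dict` helper below normalizes). The keys such as `"flops"` and `"bytes accessed"` come from XLA, and which ones are present depends on the backend. Calling `lowered.cost_analysis()` instead gives the estimate before XLA's optimization passes.

```python
import jax
import jax.numpy as jnp


def matmul_tanh(a, b):
    return jnp.tanh(a @ b)


def _as_dict(analysis):
    """Normalize cost_analysis() output: older JAX returns a list of dicts."""
    if isinstance(analysis, (list, tuple)):
        return analysis[0] if analysis else {}
    return analysis or {}


x = jnp.ones((4, 8))
w = jnp.ones((8, 16))

lowered = jax.jit(matmul_tanh).lower(x, w)  # traced and lowered, not yet optimized
compiled = lowered.compile()                # optimized by XLA for the current backend

post = _as_dict(compiled.cost_analysis())
print("flops:", post.get("flops"))
print("bytes accessed:", post.get("bytes accessed"))
```

Because this runs after compilation, it reflects fusion and other XLA rewrites, so it complements a naive count taken from the trace.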