Why is my JAX profiler not displaying the full function stack? #26094
Unanswered
MelodicDrumstep asked this question in Q&A
Replies: 0 comments
Hello,
I’ve recently been working on profiling and tuning the inference procedure for the AlphaFold3 model. As someone relatively new to JAX, I’ve been experimenting with the JAX profiler as outlined in the official documentation:
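For reference, a capture of this kind can be sketched as below. This is a toy stand-in, not the AlphaFold3 code: `toy_inference`, the input shape, and the temporary log directory are hypothetical names chosen for illustration, while `jax.profiler.trace` and `block_until_ready` are the calls described in the JAX profiling documentation.

```python
import tempfile

import jax
import jax.numpy as jnp


@jax.jit
def toy_inference(x):
    # Hypothetical stand-in for a model forward pass.
    return jnp.tanh(x @ x.T).sum()


x = jnp.ones((256, 256))

# Warm-up call so that compilation happens before the trace starts.
toy_inference(x).block_until_ready()

# Hypothetical output location; any writable directory works.
log_dir = tempfile.mkdtemp(prefix="jax-trace-")

with jax.profiler.trace(log_dir):
    result = toy_inference(x)
    # JAX dispatches asynchronously, so wait for the device to finish
    # before the trace context exits.
    result.block_until_ready()
```

The call to `block_until_ready()` inside the context matters here: without it, the trace may end before the device work has completed.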
However, the results I’ve obtained have been somewhat disappointing:
The host running time is dominated by "pjit helper", which obscures the actual function calls. Additionally, the device running time is not fully displayed. I’ve reviewed the troubleshooting section in the official documentation, but my issue doesn’t seem to match any of the cases mentioned there.
For context, the process doesn’t output any error messages related to this issue. The full output is as follows:
In an attempt to better understand the issue, I conducted an experiment where I removed a JIT statement from a core function:
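The exact change is not reproduced here. As a minimal sketch of this kind of experiment, assuming a hypothetical `core_fn` in place of the real function, the same computation can be run with and without `jax.jit` so that the two variants can be profiled separately:

```python
import jax
import jax.numpy as jnp


def core_fn(x):
    # Hypothetical stand-in for the "core function" mentioned above.
    return jnp.tanh(x @ x.T).sum()


# Variant 1: the whole function is compiled and launched as one jitted call.
core_fn_jit = jax.jit(core_fn)

x = jnp.ones((64, 64))

out_jit = core_fn_jit(x).block_until_ready()

# Variant 2: no jit wrapper, so each jnp operation is dispatched
# individually from Python.
out_eager = core_fn(x).block_until_ready()
```

Both variants compute the same value; only the way the work is dispatched differs.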
The results after this change were as follows:
The host and device profiles now display more information and are more readable. However, I’m puzzled by the host profile output, which suggests that model.py:__call__ doesn’t call any further functions for an extended period (which contradicts the actual code). Additionally, in the device profile output, the kernel function elapsed time appears unusually short (though this might be normal).
Thank you in advance! Here are my profile output files:
perfetto_trace in the first case
perfetto_trace in the second case