Any way to diagnose the computational efficiency of JAX codes other than profiling? #18003
ToshiyukiBandai asked this question in Q&A
Hi all,
I am trying to improve my JAX codes for scientific computing. I compared JAX with Fortran for solving a PDE: on CPU, Fortran was about 6 times faster than JAX. Running JAX on GPU was faster than running it on CPU, but Fortran was still faster.
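One thing worth ruling out before comparing against Fortran is the timing method itself: JAX traces and compiles a jitted function on its first call, and it dispatches work asynchronously. Below is a minimal sketch of a fairer benchmark, assuming a hypothetical stand-in problem (an explicit finite-difference step for the 1D heat equation, not the actual PDE from the question); `heat_step`, `solve`, and `benchmark` are illustrative names only.

```python
import time
from functools import partial

import jax
import jax.numpy as jnp


def heat_step(u, c):
    # Hypothetical stand-in PDE: one explicit finite-difference step of the
    # 1D heat equation, with c = alpha * dt / dx**2 and fixed boundary values.
    lap = u[:-2] - 2.0 * u[1:-1] + u[2:]
    return u.at[1:-1].add(c * lap)


@partial(jax.jit, static_argnames="n_steps")
def solve(u0, c, n_steps):
    # fori_loop keeps the whole time loop inside one compiled XLA program
    # instead of launching one small JAX call per step from Python.
    return jax.lax.fori_loop(0, n_steps, lambda _, u: heat_step(u, c), u0)


def benchmark(run, repeats=5):
    # The first call traces and compiles, so time it on its own.
    t0 = time.perf_counter()
    run().block_until_ready()
    first = time.perf_counter() - t0

    # Later calls reuse the cached executable. block_until_ready() waits for
    # the result, because JAX dispatches work asynchronously.
    times = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        run().block_until_ready()
        times.append(time.perf_counter() - t0)
    return first, min(times)


u0 = jnp.zeros(1024).at[512].set(1.0)
first, steady = benchmark(lambda: solve(u0, 0.25, n_steps=1000))
print(f"first call (trace + compile + run): {first:.4f} s")
print(f"steady state (best of 5): {steady:.4f} s")
```

The steady-state number is the one to compare with Fortran. Note also that JAX uses 32-bit floats by default unless `jax_enable_x64` is enabled, so a comparison against double-precision Fortran should make the precision match explicitly.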
I suspect my JAX codes could be improved. How can I diagnose the computational efficiency of JAX codes? I tried profiling my JAX programs, but I could not figure out what was wrong with them. Are there other ways to check efficiency, for example, checking for unnecessary array copies?
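Besides the profiler, JAX can print both the program it traced (the jaxpr) and the optimized HLO that XLA actually compiled, which is where extra copies would show up. A minimal sketch, assuming a hypothetical stencil update `step` in place of the real solver:

```python
import jax
import jax.numpy as jnp

# Log a message each time a jitted function is compiled while this flag is on;
# repeated messages for the same function point to unintended recompilation.
jax.config.update("jax_log_compiles", True)


def step(u, c):
    # Hypothetical stencil update standing in for the real solver's kernel.
    lap = u[:-2] - 2.0 * u[1:-1] + u[2:]
    return u.at[1:-1].add(c * lap)


u = jnp.ones(1024)

# 1. The jaxpr: the primitive operations JAX traced, before XLA optimizes them.
jaxpr_text = str(jax.make_jaxpr(step)(u, 0.25))
print(jaxpr_text)

# 2. The optimized HLO that XLA actually runs. Look for `copy` instructions
#    and for how the operations were grouped into fusions.
compiled = jax.jit(step).lower(u, 0.25).compile()
hlo_text = compiled.as_text()
print(hlo_text)

# A rough heuristic only: count explicit copy instructions.
print("copy instructions:", hlo_text.count(" copy("))
```

If the HLO shows a large input being copied because it is updated while still alive, `jax.jit`'s `donate_argnums` lets XLA reuse the input buffer for the output, although support for donation varies by backend.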