Why is my simple Jax program so slow? #18208
Unanswered
NeilGirdhar asked this question in Q&A
Replies: 0 comments
Simple functions that just apply matrices, nonlinear functions, etc. seem to be extremely slow. I'm not sure whether that's because JAX is unable to fuse the kernels. What should I do to make this program faster?
I've attached the Perfetto trace: perfetto_trace.json.gz
Functions like `to_exp` look simple to me, so I'm surprised that they're so slow. Any help is greatly appreciated!
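
For anyone trying to reproduce this kind of slowdown, here is a minimal sketch of how a small matrix-plus-nonlinearity function might be timed with and without `jax.jit`. Everything in it is an assumption for illustration: `apply_layer`, `time_call`, and the array shapes are hypothetical stand-ins and are not taken from the program or trace above. It only shows the usual measurement pattern: warm up once so compilation is excluded, then call `block_until_ready()` so asynchronous dispatch does not hide the real run time.

```python
import time

import jax
import jax.numpy as jnp


def apply_layer(w, x):
    # Hypothetical stand-in for a "simple" function: one matrix
    # product followed by an elementwise nonlinearity.
    return jnp.tanh(w @ x)


# Compiled version; XLA is free to fuse the elementwise op here.
apply_layer_jit = jax.jit(apply_layer)

key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
w = jax.random.normal(key_w, (256, 256))  # assumed shape
x = jax.random.normal(key_x, (256,))      # assumed shape


def time_call(fn, *args, repeats=100):
    # Warm-up call: for a jitted function this includes tracing and
    # compilation, which should not be counted as steady-state cost.
    fn(*args).block_until_ready()
    start = time.perf_counter()
    for _ in range(repeats):
        # JAX dispatches asynchronously, so wait for the result.
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / repeats


eager_s = time_call(apply_layer, w, x)
jit_s = time_call(apply_layer_jit, w, x)
print(f"eager: {eager_s * 1e6:.1f} us/call, jit: {jit_s * 1e6:.1f} us/call")
```

If the jitted timing is still far above what the arithmetic should cost, the overhead is probably somewhere other than the kernel itself (for example repeated recompilation or per-call dispatch), which is the kind of thing the Perfetto trace would help confirm.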