JAX extremely slow under some circumstances (MWE included) #26429
JoaoAparicio asked this question in General
JAX takes about 30x longer than NumPy to convert Python lists into its own arrays.
I've attached a file here. The code would be:
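The attached file isn't reproduced in this text, so as a stand-in here is a minimal sketch of the kind of comparison I mean. The list contents, its size, the number of runs, and the helper names `to_numpy` and `to_jax` are all assumptions made for illustration, not the original script:

```python
import timeit

import jax.numpy as jnp
import numpy as np

# Hypothetical input: a plain Python list of floats (size chosen arbitrarily).
data = [float(i) for i in range(10_000)]


def to_numpy(values):
    """Convert a Python list to a NumPy array."""
    return np.array(values)


def to_jax(values):
    """Convert a Python list to a JAX array.

    block_until_ready() waits for JAX's asynchronous dispatch to finish,
    so the timing covers the whole conversion.
    """
    return jnp.array(values).block_until_ready()


# Warm-up call so one-time backend initialization is not included in the timing.
to_jax(data)

n_runs = 20
numpy_seconds = timeit.timeit(lambda: to_numpy(data), number=n_runs) / n_runs
jax_seconds = timeit.timeit(lambda: to_jax(data), number=n_runs) / n_runs

print(f"numpy: {numpy_seconds * 1e6:.1f} us per call")
print(f"jax:   {jax_seconds * 1e6:.1f} us per call")
print(f"ratio: {jax_seconds / numpy_seconds:.1f}x")
```

The exact ratio will depend on the machine, the backend, the JAX version, and the list size, so the numbers this prints are only indicative.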
I guess that around here this would qualify as a "micro-benchmark". However, I think that label understates the problem. If I have to call this in a loop because that's how the data arrives, and the conversion becomes significant compared to the rest of the work, it's no longer micro.
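For the loop case, one thing that might be worth trying is to let NumPy parse the list first and pass the resulting array to JAX, or to convert all the chunks in one call when they are available up front. This is only a sketch: `list_to_jax` and the `chunks` data are hypothetical, and whether it actually closes the gap is something to measure rather than assume.

```python
import jax.numpy as jnp
import numpy as np


def list_to_jax(values):
    """Let NumPy parse the Python list, then hand the array to JAX.

    float32 is used because it is JAX's default floating-point precision.
    """
    return jnp.asarray(np.asarray(values, dtype=np.float32))


# Hypothetical stand-in for data that arrives one chunk per loop iteration.
chunks = [[float(i + j) for j in range(1_000)] for i in range(5)]

# Per-iteration conversion, going through NumPy each time.
per_chunk = [list_to_jax(chunk) for chunk in chunks]

# Alternative when every chunk is available up front: convert once.
stacked = jnp.asarray(np.asarray(chunks, dtype=np.float32))

print(stacked.shape, per_chunk[0].dtype)
```

Converting once amortizes any fixed per-call overhead across all chunks, which matters most when each chunk is small.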
I understand that JAX does a lot of difficult things that NumPy can't (JAX is amazing, thank you), but at the same time, if it's going to be described as something that "can often be used as drop-in replacement" for NumPy (e.g. here), it shouldn't be 30x slower at any point...
Alternatively, it's possible I'm doing something wrong, in which case: help? :-)