I can't run this, because I don't know what … In general, though, JAX has a fixed-cost overhead for every computation, on the order of tens of milliseconds, which NumPy does not have. I suspect that if you ran your benchmarks on arrays large enough that this overhead is unimportant, you'd get a truer comparison of how JAX and NumPy runtimes scale. Also, keep in mind for this sort of micro-benchmark that JAX executes asynchronously, so you should use the …
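A minimal benchmarking sketch along the lines of this advice, assuming a plain jitted sum stands in for the question's code (which isn't shown). It discards the first call (which, for a jitted function, includes compilation), waits for JAX's asynchronous result via the returned array's `block_until_ready()` method, and repeats the measurement at several array sizes so the fixed per-call overhead can be told apart from how runtime grows with the array size. The helper name `mean_time` and the chosen sizes are illustrative only.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np

# Jit-compiled plain sum, used here as a stand-in for the question's code.
jax_sum = jax.jit(jnp.sum)


def mean_time(fn, x, repeats=20):
    """Mean wall-clock seconds per call of fn(x), excluding the first call."""
    out = fn(x)  # warm-up: for a jitted function this triggers compilation
    if hasattr(out, "block_until_ready"):
        out.block_until_ready()
    start = time.perf_counter()
    for _ in range(repeats):
        out = fn(x)
        # JAX dispatches asynchronously, so wait for the result before
        # stopping the clock. NumPy results are already computed.
        if hasattr(out, "block_until_ready"):
            out.block_until_ready()
    return (time.perf_counter() - start) / repeats


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Small arrays are dominated by per-call overhead; large ones show scaling.
    for n in (1_000, 100_000, 1_000_000):
        x_np = rng.standard_normal(n, dtype=np.float32)
        x_jax = jnp.asarray(x_np)
        print(f"n={n:>9,}  numpy={mean_time(np.sum, x_np):.2e}s  "
              f"jax={mean_time(jax_sum, x_jax):.2e}s")
```

Comparing the two columns as `n` grows, rather than at a single small size, is what separates the constant overhead from the actual computation cost.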
In Colab, I am running the following experiment to measure the speed of some basic addition computations. The code is very simple:
A sum and a conditional sum in JAX:
The same computations in NumPy:
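The original code isn't reproduced here, so the following is only a hypothetical reconstruction of a sum and a conditional sum in both libraries, assuming "conditional sum" means summing only the positive entries. The function names `jax_sums` and `numpy_sums` are made up for illustration. One point worth noting: under `@jit`, a boolean mask like `x[x > 0]` cannot be used because its output shape depends on the data, so the JAX version uses `jnp.where` instead.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def jax_sums(x):
    # Plain sum of all entries.
    total = jnp.sum(x)
    # Conditional sum: zero out non-positive entries instead of masking,
    # because x[x > 0] has a data-dependent shape that jit cannot trace.
    positive = jnp.sum(jnp.where(x > 0, x, 0))
    return total, positive


def numpy_sums(x):
    # Same computation in NumPy, where boolean masking is fine.
    return np.sum(x), np.sum(x[x > 0])


x = np.array([1.0, -2.0, 3.0, -4.0], dtype=np.float32)
print(jax_sums(jnp.asarray(x)))
print(numpy_sums(x))
```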
With the CPU-only runtime:
With the GPU accelerator runtime:
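When comparing Colab runtimes, a quick sanity check (sketched here, not part of the original question) is to confirm which backend JAX is actually using. `jax.default_backend()` and `jax.devices()` are standard JAX functions; the exact device names printed depend on the runtime.

```python
import jax

# Name of the backend JAX dispatches to by default, e.g. "cpu" or "gpu".
print(jax.default_backend())
# The devices JAX can see in this runtime.
print(jax.devices())
```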
It seems that plain NumPy wins by orders of magnitude. Why is that? I already used @jit, and the code is straightforward to run. Am I doing something wrong? Where can we get a speedup from using JAX?