-
Seemingly overnight, my JAX project (on v0.4.1) became ~2x slower in terms of operations per second when running on V100s and Quadro GP100s. The project's git commit hash is identical, and there were no changes to my conda environment. Does anyone have any ideas about what might cause this performance degradation?
-
I have also been seeing JAX get progressively slower over the last few releases, especially when working with very large arrays (on the order of a GB or more).
-
Hrm, this is weird...
@minqi Any chance of sharing a repro?
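For anyone putting together a repro, a minimal timing sketch might look like the one below. It is only an illustration: `ops_per_second` and its `warmup`/`repeats` parameters are hypothetical names, not part of JAX. The one JAX-specific detail it relies on is that JAX arrays expose a `block_until_ready()` method, which matters because JAX dispatches work asynchronously and a timer can otherwise stop before the GPU has finished.

```python
import time


def ops_per_second(fn, *args, warmup=1, repeats=10):
    """Return calls per second of fn(*args), not counting warmup calls.

    Hypothetical helper for illustration; not a JAX API.
    """
    def run():
        out = fn(*args)
        # JAX arrays have block_until_ready(); call it when present so the
        # timing covers the actual computation, not just the async dispatch.
        block = getattr(out, "block_until_ready", None)
        if callable(block):
            block()
        return out

    # Warmup calls are excluded: the first call of a jitted function
    # includes compilation time.
    for _ in range(warmup):
        run()

    start = time.perf_counter()
    for _ in range(repeats):
        run()
    elapsed = time.perf_counter() - start
    return repeats / max(elapsed, 1e-12)
```

With JAX installed, something like `ops_per_second(jax.jit(step), x)` (where `step` and `x` stand in for your own function and input) run on both the fast and the slow setup would give numbers that can be compared directly.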
-
Also, what change did you make when you noticed this slowness? Was it a jax/jaxlib upgrade, a switch from one GPU to another, or something else? If you upgraded JAX, don't forget to bump your jaxlib to 0.4.1 too!
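A quick way to see whether the installed jax and jaxlib versions line up is to read them from the environment's package metadata. The sketch below uses only the standard library (`importlib.metadata`); `installed_versions` is a hypothetical helper name, and a value of `None` simply means the package was not found in the active environment.

```python
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib")):
    """Map each package name to its installed version, or None if missing.

    Hypothetical helper for illustration; not a JAX API.
    """
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Prints a dict with the "jax" and "jaxlib" versions found in this environment.
print(installed_versions())
```

Running this inside the same conda environment as the project shows which versions Python actually resolves.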
-
Thanks for the responses! I ended up reinstalling both jax + jaxlib and now see the prior training speeds. |