GPU support for the latest version 0.9.4 #496
amithm1111 started this conversation in General
Replies: 1 comment
Hi @amithm1111, we recently removed the JAX/jaxlib version pins. Perhaps you could try again now?
I am not able to run the code on GPU. If I install the latest versions of jax and jaxlib with GPU support, GPJax throws errors. I also tried installing the JAX version that GPJax supports, jax==0.4.27 with jaxlib==0.4.27+cuda12.cudnn89. In this case, the Jupyter kernel crashes without giving any error message.
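When a Jupyter kernel dies without a message, the underlying error is often visible only when the same code runs outside the notebook. The script below is a minimal diagnostic sketch, not part of GPJax: it assumes only the Python standard library plus whatever JAX packages are installed, and the function name `installed_versions` is hypothetical. Run it from a terminal (for example `python check_env.py`) so that a traceback, or a CUDA/cuDNN error logged while JAX initialises, is printed to the terminal.

```python
"""Report installed JAX-related package versions, then initialise JAX.

Run as a plain script in a terminal rather than in Jupyter, so that any
error raised while JAX starts up is printed to the terminal.
"""
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib", "gpjax")) -> dict:
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
    try:
        import jax
    except ImportError as exc:
        print(f"could not import jax: {exc}")
    else:
        # Prints JAX/jaxlib versions, the visible devices and platform details.
        jax.print_environment_info()
```

Including this output when reporting the problem makes it easier to tell whether jax, jaxlib and gpjax versions are mismatched.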
A solution I saw in earlier discussions was to install an older version of JAX:
pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
This command no longer works.
How do I install a JAX version that is compatible with GPJax and supports GPU?
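Whichever install route is used, it helps to confirm that JAX itself sees the GPU before importing GPJax. The sketch below assumes a CUDA-enabled JAX build has been installed (for recent JAX releases on Linux, the JAX installation guide lists `pip install -U "jax[cuda12]"`; treat that command as an assumption to check against the guide for your JAX version). The function name `jax_gpu_status` is hypothetical; the check relies on `jax.devices("gpu")` raising `RuntimeError` when no GPU backend is available.

```python
"""Check whether JAX can see a GPU (a sketch; also runs on CPU-only installs)."""
import importlib.util


def jax_gpu_status() -> str:
    """Return a one-line summary of the JAX install and its GPU visibility."""
    if importlib.util.find_spec("jax") is None:
        return "jax not installed"
    import jax

    try:
        gpus = jax.devices("gpu")
    except RuntimeError:
        # Raised when no GPU backend is available, e.g. a CPU-only jaxlib.
        return (
            f"jax {jax.__version__}: no GPU found, "
            f"default backend is {jax.default_backend()!r}"
        )
    return f"jax {jax.__version__}: {len(gpus)} GPU device(s): {gpus}"


if __name__ == "__main__":
    print(jax_gpu_status())
```

If this reports a GPU but GPJax still fails, the problem is more likely the GPJax/JAX version combination than the CUDA installation itself.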