Question about JAX working with CUDA 11.2 #14352
dhruvsreenivas asked this question in General
-
Alright, I think I got it to work: I found the matching wheel online, and now everything seems to work as expected. Sorry for the inconvenience!
-
Hi,
Hope you are doing well! I'm working on a large-scale JAX project, but the cluster I'm using only has CUDA 11.2, not CUDA ≥ 11.4. I recently had to install JAX 0.4.2 and the CUDA build of jaxlib 0.4.2 because I switched to PyTorch data loading (TensorFlow data loading turned out to be quite memory-heavy, so I now use a PyTorch setup, and that switch is what led me to upgrade JAX). Previously, the following combination worked:
- jax 0.3.16
- jaxlib 0.3.15 (CUDA build; I'm still trying to find out which one)
- dm-haiku 0.0.7
- optax 0.1.3 (I believe)
- rlax 0.1.2 (I believe)
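To check whether an environment actually matches the combination listed above, a small script can compare installed versions against those pins. This is a minimal sketch using only the standard library's `importlib.metadata`; `KNOWN_GOOD`, `base_version`, and `check_pins` are hypothetical names, and the optax and rlax entries are only as reliable as the "I believe" above. It compares just the public part of each version, assuming a CUDA jaxlib wheel reports its build as a local suffix after `+`.

```python
from __future__ import annotations

from importlib.metadata import PackageNotFoundError, version

# Hypothetical pin set: the combination the question reports as working.
# The optax and rlax versions are marked "I believe" in the question.
KNOWN_GOOD = {
    "jax": "0.3.16",
    "jaxlib": "0.3.15",
    "dm-haiku": "0.0.7",
    "optax": "0.1.3",
    "rlax": "0.1.2",
}


def base_version(v: str) -> str:
    """Drop a PEP 440 local version label, e.g. '1.2.3+abc' -> '1.2.3'."""
    return v.split("+", 1)[0]


def check_pins(expected: dict[str, str]) -> dict[str, tuple[str, str | None]]:
    """Return {package: (expected, installed)} for every package that differs.

    `installed` is None when the package is not installed at all.
    """
    mismatches: dict[str, tuple[str, str | None]] = {}
    for pkg, want in expected.items():
        try:
            have: str | None = version(pkg)
        except PackageNotFoundError:
            have = None
        # Compare only the public version so a CUDA-tagged wheel still matches.
        if have is None or base_version(have) != want:
            mismatches[pkg] = (want, have)
    return mismatches


if __name__ == "__main__":
    problems = check_pins(KNOWN_GOOD)
    if not problems:
        print("All pinned packages match.")
    for pkg, (want, have) in problems.items():
        print(f"{pkg}: expected {want}, found {have or 'not installed'}")
```

Running it after each install makes it obvious when something has silently upgraded JAX or jaxlib.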
I don't know how I got it working before, but I want to restore that setup. How can I install these dependencies correctly for CUDA 11.2? I really like JAX and want to make good use of it with the NVIDIA drivers I have, but the installation instructions on the website only cover CUDA 11.4 or above. I've downgraded PyTorch in the hope that installing it will no longer auto-upgrade JAX. Thanks!
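One way to stop later installs from upgrading JAX, instead of relying on a downgraded PyTorch, is a pip constraints file. The sketch below only renders such a file from a pin dict; `PINS`, `render_constraints`, and the script name in the comment are hypothetical, and the pins are the jax/jaxlib versions listed above.

```python
from __future__ import annotations

# Hypothetical pins: the jax/jaxlib versions the question reports as working.
# dm-haiku, optax, and rlax could be added here too if they should stay fixed.
PINS = {"jax": "0.3.16", "jaxlib": "0.3.15"}


def render_constraints(pins: dict[str, str]) -> str:
    """Return one 'name==version' line per package, sorted by name."""
    return "".join(f"{name}=={ver}\n" for name, ver in sorted(pins.items()))


if __name__ == "__main__":
    # e.g. python make_constraints.py > constraints.txt  (hypothetical script name)
    print(render_constraints(PINS), end="")
```

The resulting file would then be passed to later installs, e.g. `pip install -c constraints.txt torch`. pip's `-c`/`--constraint` option only restricts which versions may be chosen; it does not install anything itself. Because `jaxlib==0.3.15` carries no local label, PEP 440 matching ignores the local part of a candidate version, so the pin should still accept a CUDA-tagged jaxlib wheel; that wheel itself still has to come from wherever it was originally obtained.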