Install and/or package jax matching cuda+cudnn versions #24042
-
How should one install jax + jaxlib with the most recent CUDA versions on Ubuntu? As of today, this is
Currently, the jax_cuda_releases page only holds versions up to 0.4.29, while I'm trying

A follow-up question, which I've already read a lot about online, is how to declare jax/jaxlib as a sufficiently flexible dependency of a platform-agnostic library (not just an application repo). I'm happy to drop Poetry for uv or plain old setuptools, but I doubt there is an easy solution for this right now. Any advice welcome! Cheers 🙏

EDIT: I've been through e.g.
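For the library question, one common pattern (a minimal sketch, not an official JAX recommendation) is to depend only on plain jax, which installs the CPU-only jaxlib on every platform, and to expose optional extras that forward to JAX's own accelerator extras, so the application installing the library chooses the CUDA stack. The library name mylib and the extra names are hypothetical; jax[cuda12] and jax[cuda12_local] are the extras described in the JAX installation guide. These lists would be passed to setuptools' setup(install_requires=..., extras_require=...), and the same mapping can be written under [project.optional-dependencies] in pyproject.toml for uv or other PEP 621 tools.

```python
# Dependency metadata for a hypothetical platform-agnostic library "mylib".
# Plain "jax" resolves to the CPU-only jaxlib wheel on every supported platform.
INSTALL_REQUIRES = ["jax"]

# Optional extras forward to JAX's own extras, so the *application* installing
# mylib decides which accelerator stack it wants (extra names are illustrative).
EXTRAS_REQUIRE = {
    # CUDA 12 + cuDNN pulled in as NVIDIA wheels from PyPI.
    "cuda12": ["jax[cuda12]"],
    # GPU plugin only; relies on a locally installed CUDA toolkit and cuDNN.
    "cuda12_local": ["jax[cuda12_local]"],
}


def setup_kwargs():
    """Keyword arguments meant to be splatted into setuptools.setup(...)."""
    return {
        "name": "mylib",
        "install_requires": list(INSTALL_REQUIRES),
        "extras_require": {k: list(v) for k, v in EXTRAS_REQUIRE.items()},
    }
```

An application would then run pip install "mylib[cuda12]" on a CUDA machine and plain pip install mylib everywhere else, while the library itself never pins a CUDA version.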
Replies: 2 comments 2 replies
-
Update: running
seemed to provide satisfactory versions. Is this robust, or the recommended approach?
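One way to check whether such an install is self-consistent is to read the installed distribution versions with the standard library's importlib.metadata, without importing JAX at all (so it still works when JAX fails to initialise). A minimal sketch, assuming the CUDA 12 GPU support lives in the jax-cuda12-plugin and jax-cuda12-pjrt distributions, as on recent PyPI releases; the helper names are hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version

# Distributions of a pip-based CUDA 12 JAX install (names assumed from current PyPI wheels).
JAX_DISTS = ("jax", "jaxlib", "jax-cuda12-plugin", "jax-cuda12-pjrt")
GPU_DISTS = ("jax-cuda12-plugin", "jax-cuda12-pjrt")


def installed_versions(names=JAX_DISTS):
    """Map each distribution name to its installed version string, or None if absent."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


def missing_gpu_dists(versions):
    """GPU plugin distributions that are not installed (JAX would then likely run on CPU)."""
    return [name for name in GPU_DISTS if versions.get(name) is None]


if __name__ == "__main__":
    report = installed_versions()
    for name, ver in report.items():
        print(f"{name:<20} {ver or 'not installed'}")
    print("missing GPU wheels:", missing_gpu_dists(report) or "none")
```

Each printed line is a distribution name followed by its version or "not installed", which is easy to paste into a reply like this one.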
-
JAX is compatible with CUDA 12.4 (see https://jax.readthedocs.io/en/latest/installation.html#pip-installation-nvidia-gpu-cuda-installed-locally-harder). However, I'm trying to understand what went wrong when you used CUDA 12.6 and just let JAX install it.
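To narrow down what went wrong, it helps to see which backend JAX actually picked at runtime. A minimal sketch using jax.default_backend() and jax.devices(); the helper name is hypothetical. By default, when no usable GPU plugin initialises, JAX falls back to the CPU backend (usually with a warning), so "cpu" here on a GPU machine points at an install or CUDA-version problem rather than at user code.

```python
def jax_backend_summary():
    """Return (backend_name, device_strings), or None if jax cannot be imported."""
    try:
        import jax
    except ImportError:
        return None
    # default_backend() is "gpu" when the CUDA plugin initialised, "cpu" after a fallback.
    return jax.default_backend(), [str(d) for d in jax.devices()]


if __name__ == "__main__":
    summary = jax_backend_summary()
    if summary is None:
        print("jax is not installed")
    else:
        backend, devices = summary
        print("backend:", backend)
        print("devices:", devices)
```

Recent JAX releases also provide jax.print_environment_info(), which prints the jax/jaxlib versions and, on GPU machines, typically the nvidia-smi output as well.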
I'll also note that jax_cuda_releases will never be updated again: we now ship wheels on PyPI.
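With jax[cuda12], the CUDA libraries themselves also come from PyPI as NVIDIA's nvidia-*-cu12 wheels (for example nvidia-cudnn-cu12) rather than from a system toolkit. A small sketch, using only importlib.metadata, that lists which of those wheels are present in the current environment; the name pattern is an assumption based on NVIDIA's current wheel naming, and the helper name is hypothetical.

```python
from importlib.metadata import distributions


def nvidia_cu12_wheels():
    """Sorted (name, version) pairs for installed distributions named nvidia-*-cu12."""
    found = set()  # a set, because one dist can be visible on several sys.path entries
    for dist in distributions():
        name = (dist.metadata["Name"] or "").lower()
        if name.startswith("nvidia-") and name.endswith("-cu12"):
            found.add((name, dist.version))
    return sorted(found)


if __name__ == "__main__":
    for name, ver in nvidia_cu12_wheels():
        print(f"{name:<32} {ver}")
```

An empty list after pip install -U "jax[cuda12]" would suggest the NVIDIA wheels were not pulled in (for example, if the cuda12_local extra was used instead).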