Question on installing Jax using CUDA #20678
-
Dear all,
After installing JAX, I ran into some problems running it on the GPU.
My environment is:
I tried to install it with the command:
and got the following output from the installation:
which indicates that it was installed successfully.
Based on this
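A successful `pip install` only shows that the packages were installed. It does not show which builds ended up in the environment. As a hedged sketch, the standard-library `importlib.metadata` module can report the installed versions of `jax`, `jaxlib`, and a CUDA plugin package. The plugin name `jax-cuda12-plugin` is an assumption: it only exists for newer, plugin-based CUDA 12 installs and will show as missing otherwise.

```python
from importlib import metadata

# Distribution names to look up. "jax-cuda12-plugin" is an assumption:
# it is only present for newer, plugin-based CUDA 12 installs of JAX.
PACKAGES = ("jax", "jaxlib", "jax-cuda12-plugin")


def installed_versions(names=PACKAGES):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```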
-
The key line in your output is this:
You should try:
For more information on GPU-compatible JAX installation, see Installation: NVIDIA GPU.
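Once a GPU-enabled build is installed, a quick sanity check is to ask JAX which backend it actually selected. The sketch below uses `jax.default_backend()` and `jax.devices()`. It assumes that a working CUDA install reports the platform name `"gpu"`, and it returns `None` instead of failing when JAX cannot be imported. `jax_backend_report` is a hypothetical helper name.

```python
import importlib.util


def jax_backend_report():
    """Return (default_backend, sorted device platforms), or None if JAX is missing."""
    if importlib.util.find_spec("jax") is None:
        return None
    import jax  # imported lazily so the check works even without JAX installed

    platforms = sorted({device.platform for device in jax.devices()})
    return jax.default_backend(), platforms


if __name__ == "__main__":
    report = jax_backend_report()
    if report is None:
        print("JAX is not installed in this environment")
    else:
        backend, platforms = report
        print(f"default backend: {backend}, device platforms: {platforms}")
        if backend != "gpu":
            # Assumption: a working CUDA install reports "gpu" here.
            print("JAX is not using a GPU; the GPU-enabled build may not be active")
```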
-
Thank you so much for your quick reply! If I need to install with …
Besides, I tried to run:
which is also based on this advice. However, I got the following error:
-
Hi, running
import jax.numpy as jp
jp.zeros((1,))
raised the following error:
2025-01-28 11:34:00.709230: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:439] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2025-01-28 11:34:00.709269: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:443] Memory usage: 734789632 bytes free, 8361738240 bytes total.
However, jp.array([1.]).device() returns:
My environment: Ubuntu 20.04, Python 3.8, a Quadro RTX 4000 GPU, and no local CUDA toolkit installed. LD_LIBRARY_PATH is not empty, but it contains no CUDA-related path. I have PyTorch with CUDA 11.8 installed.
I installed JAX with:
pip install "jax[cuda11_pip]==0.4.13" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
CUDA-related wheels in my env (pip list | grep nvidia):
nvidia-cublas-cu11 11.11.3.6
nvidia-cuda-cupti-cu11 11.8.87
nvidia-cuda-nvcc-cu11 2022.5.4
nvidia-cuda-nvcc-cu117 11.7.64
nvidia-cuda-nvrtc-cu11 11.8.89
nvidia-cuda-runtime-cu11 11.8.89
nvidia-cudnn-cu11 9.1.0.70
nvidia-cufft-cu11 10.9.0.58
nvidia-curand-cu11 10.3.0.86
nvidia-cusolver-cu11 11.4.1.48
nvidia-cusparse-cu11 11.7.5.86
nvidia-nccl-cu11 2.20.5
nvidia-nvtx-cu11 11.8.86
nvidia-pyindex 1.0.9
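The listing mixes wheels ending in `-cu11` with one ending in `-cu117` (`nvidia-cuda-nvcc-cu117`). As a rough sketch, pasted `pip list` output can be grouped by that trailing suffix to spot more than one CUDA wheel flavor in the same environment. `group_by_cuda_suffix` is a hypothetical helper, not part of pip or JAX. Reading the suffix as the CUDA flavor is an assumption based on how these wheels are named.

```python
import re
from collections import defaultdict

# Output of `pip list | grep nvidia`, copied from the post above.
PIP_LIST = """
nvidia-cublas-cu11 11.11.3.6
nvidia-cuda-cupti-cu11 11.8.87
nvidia-cuda-nvcc-cu11 2022.5.4
nvidia-cuda-nvcc-cu117 11.7.64
nvidia-cuda-nvrtc-cu11 11.8.89
nvidia-cuda-runtime-cu11 11.8.89
nvidia-cudnn-cu11 9.1.0.70
nvidia-cufft-cu11 10.9.0.58
nvidia-curand-cu11 10.3.0.86
nvidia-cusolver-cu11 11.4.1.48
nvidia-cusparse-cu11 11.7.5.86
nvidia-nccl-cu11 2.20.5
nvidia-nvtx-cu11 11.8.86
nvidia-pyindex 1.0.9
"""

# Matches a trailing "-cu<digits>" suffix such as "-cu11" or "-cu117".
SUFFIX = re.compile(r"-(cu\d+)$")


def group_by_cuda_suffix(pip_list_text):
    """Group 'name version' lines by their -cuNN suffix; lines without one are skipped."""
    groups = defaultdict(list)
    for line in pip_list_text.strip().splitlines():
        parts = line.split()
        if len(parts) != 2:
            continue
        name, version = parts
        match = SUFFIX.search(name)
        if match:
            groups[match.group(1)].append((name, version))
    return dict(groups)


if __name__ == "__main__":
    groups = group_by_cuda_suffix(PIP_LIST)
    for suffix, packages in sorted(groups.items()):
        print(suffix, [name for name, _ in packages])
    if len(groups) > 1:
        print("more than one CUDA wheel suffix is installed:", sorted(groups))
```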
OK, I think the hardware is compatible. My next guess would be that your cuDNN version is incompatible with your CUDA version. Given that you have CUDA 10 installed locally and CUDA 11 or 12 installed via pip, it's likely that there's cross-talk between the two installations. I'd make sure that your system paths are set up so that only one set of CUDA libraries is loaded at runtime.
Alternatively, you could try removing the CUDA 10 installation from your system and installing CUDA 11 or CUDA 12 locally, so that JAX runs against it and you avoid having multiple CUDA versions on your system.
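One way to act on the suggestion about system paths is to list which entries on the dynamic-loader search path look CUDA-related, and whether a local `nvcc` is on `PATH`. This is a minimal sketch: `cuda_like_entries` is a hypothetical helper, and matching on the substrings `cuda` and `cudnn` is a heuristic. The loader can also find libraries through other means (for example the `ldconfig` cache or RPATHs), which this does not inspect.

```python
import os
import shutil

CUDA_KEYWORDS = ("cuda", "cudnn")  # heuristic substrings, not an exhaustive rule


def cuda_like_entries(ld_library_path=None, keywords=CUDA_KEYWORDS):
    """Return LD_LIBRARY_PATH entries whose path contains a CUDA-related keyword."""
    if ld_library_path is None:
        ld_library_path = os.environ.get("LD_LIBRARY_PATH", "")
    # LD_LIBRARY_PATH is colon-separated on Linux; empty entries are dropped.
    entries = [entry for entry in ld_library_path.split(":") if entry]
    return [entry for entry in entries if any(k in entry.lower() for k in keywords)]


if __name__ == "__main__":
    print("CUDA-like LD_LIBRARY_PATH entries:", cuda_like_entries())
    # shutil.which returns None when no nvcc binary is found on PATH.
    print("local nvcc:", shutil.which("nvcc"))
```

An empty list together with `local nvcc: None` suggests no local toolkit is being picked up through these two routes. More than one CUDA version among the entries is the kind of mix the reply warns about.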