Has anyone successfully used HSSM/Numpyro with GPU on recent JAX versions (Linux/Windows)? #781
Replies: 12 comments
-
Hello @JamesWeiChen, thanks for letting us know. We'll try to update the NumPyro dependency for HSSM and will let you know when we are done. Thanks!
-
Hi @JamesWeiChen, I did some digging into the issue. Although we will be updating the NumPyro and PyMC dependencies in HSSM, that won't necessarily fix your problem. Both NumPyro 0.17.0 and 0.18.0 (the latest) require JAX 0.4.25, so there is not much we can do on our end to update JAX compatibility. We can help you find the best compatible version of JAX, though. For that purpose, can you let us know how you are installing HSSM, and which CUDA and cuDNN versions you have in your setup? Thanks!
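For anyone answering the question above, a minimal sketch for collecting the installed Python package versions in one go. The list of package names is an assumption about what is relevant here, and the script does not detect the system CUDA or cuDNN versions, which still need to be reported separately.

```python
from importlib import metadata

# Assumed list of distributions relevant to this thread; adjust as needed.
PACKAGES = ("hssm", "numpyro", "pymc", "jax", "jaxlib")


def installed_versions(packages=PACKAGES):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

Run it inside the same conda or virtual environment that HSSM is installed in, since the result depends on the active interpreter.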
-
Just FYI,
-
Looks like it was released just 4 hours ago. Thanks for letting us know. We'll incorporate this update. In the meantime, let us know whether it indeed solves the problem.
-
I still cannot get it to work on my computer.
Environment
I tried running JAX on the GPU using the CUDA and cuDNN installed on my system, but it failed. I then tried the cudatoolkit package within my conda environment, but that also failed. I am quite new to Linux, which might be contributing to the issue.
-
After a few hours of work, it seems the problem is that JAX does not support the RTX 5080.
-
NVIDIA does have an NGC for JAX with support for Blackwell GPUs, so the 5080 should be supported. Can you provide a bit more detail on how you got JAX installed into your conda environment? Was it a pip install or a conda install?
-
I think I tried both methods.
-
Does the GPU version of PyTorch work? Just to see whether it's a JAX-specific thing.
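A minimal sketch of such a check, assuming a CUDA-enabled PyTorch build may or may not be installed in the environment. The helper name `torch_gpu_status` is made up for illustration.

```python
def torch_gpu_status():
    """Report whether PyTorch is importable and whether it can see a CUDA GPU."""
    try:
        import torch
    except ImportError:
        # PyTorch is not installed in this environment.
        return {"installed": False, "cuda_available": False, "device": None}
    available = torch.cuda.is_available()
    # Only query the device name when CUDA is usable, otherwise the call raises.
    device = torch.cuda.get_device_name(0) if available else None
    return {"installed": True, "cuda_available": available, "device": device}


if __name__ == "__main__":
    print(torch_gpu_status())
```

If PyTorch sees the GPU but JAX does not, the driver is probably fine and the problem more likely lies in the JAX or CUDA library setup.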
-
Thanks for the suggestion. I am able to run the code on the GPU now. But I did too many things at once, so I don't know which step fixed the problem. I think it is because JAX was trying to use several different CUDA toolkits. Anyway, it works now.
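One way to look for the kind of conflict suspected above is to list every CUDA-related entry on the search paths. This is only a sketch: it inspects the `PATH` and `LD_LIBRARY_PATH` environment variables and nothing else, and the thread does not confirm that this was the actual cause. The helper names are made up for illustration.

```python
import os


def cuda_entries(path_value):
    """Return entries of a PATH-style string that mention CUDA, in order, deduplicated."""
    seen = []
    for entry in path_value.split(os.pathsep):
        if "cuda" in entry.lower() and entry not in seen:
            seen.append(entry)
    return seen


def report(env=os.environ):
    """Collect CUDA-related entries from the two search-path variables."""
    return {var: cuda_entries(env.get(var, "")) for var in ("PATH", "LD_LIBRARY_PATH")}


if __name__ == "__main__":
    for var, entries in report().items():
        print(var, entries)
```

More than one distinct toolkit directory in either list would be a hint that several CUDA installations are visible at once.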
-
Setting Up JAX + NumPyro + HSSM with CUDA 12.9 and cuDNN 9.11 on Pop!_OS
This is how I successfully configured JAX (GPU), NumPyro, and HSSM on a machine with NVIDIA RTX 5080 using CUDA 12.9 and cuDNN 9.11 on Pop!_OS (Ubuntu 22.04 base).
1. Install CUDA 12.9 (System-wide)
Add the CUDA 12.9 local repository and install the toolkit:
2. Configure Environment Variables
Add CUDA paths to
Check installation:
(Should show: Cuda compilation tools, release 12.9)
3. Install cuDNN 9.11 (System-wide)
Add the cuDNN local repository and install:
4. Create a Clean Conda Environment
Use a fresh environment to avoid conflicts with system CUDA:
5. Install JAX (CUDA 12 Support)
Install JAX with the CUDA 12 wheel (no conda CUDA packages):
6. Install NumPyro and HSSM
Ensure NumPyro is version 0.19.0 (compatible with HSSM):
7. Verify GPU Setup
Run a quick test:
Expected output:
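A check along the lines of step 7 might look like the sketch below. It is not necessarily the snippet used above, and the helper name `jax_backend_status` is made up for illustration. It relies only on `jax.default_backend()` and `jax.devices()`.

```python
def jax_backend_status():
    """Report whether JAX is importable and which backend and devices it selected."""
    try:
        import jax
    except ImportError:
        # JAX is not installed in this environment.
        return {"installed": False, "backend": None, "devices": []}
    return {
        "installed": True,
        "backend": jax.default_backend(),  # "cpu", "gpu", or "tpu"
        "devices": [str(device) for device in jax.devices()],
    }


if __name__ == "__main__":
    print(jax_backend_status())
```

On a working GPU setup the backend should be reported as `gpu` and the device list should contain a CUDA device; a `cpu` backend means JAX fell back to the CPU.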
-
Thank you so much for the info, @JamesWeiChen! I'll turn this issue into a discussion thread so more people can see this.
-
Hi all,
I’ve been trying to run HSSM models with GPU acceleration on my local machine (Linux, NVIDIA RTX 5080, CUDA 12.8). However, I ran into compatibility issues between NumPyro and the latest JAX releases:
Questions:
Any guidance or working environment specs would be greatly appreciated!