You might be seeing an effect of JAX's asynchronous dispatch. Try wrapping …
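A minimal sketch of what asynchronous dispatch means, assuming a recent JAX release (it runs on any backend, including CPU). A jitted call hands back an array before the device has finished computing it, and `block_until_ready()` waits until the result actually exists. `sum_of_matmul` is a hypothetical workload chosen only for illustration.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def sum_of_matmul(x):
    # Hypothetical workload: x @ x, reduced to a scalar.
    return (x @ x).sum()


x = jnp.ones((100, 100), dtype=jnp.float32)
sum_of_matmul(x).block_until_ready()  # first call compiles; wait for it to finish

start = time.perf_counter()
out = sum_of_matmul(x)  # dispatched; may return before the device is done
dispatch_time = time.perf_counter() - start
out.block_until_ready()  # block until the result has actually been computed
total_time = time.perf_counter() - start

print(f"dispatch: {dispatch_time:.6f}s, total: {total_time:.6f}s, value: {float(out)}")
```

Timing or inspecting GPU state without `block_until_ready()` only observes the dispatch, not the work itself; on a GPU the gap between the two timings typically grows with the size of the workload.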
using: …
I'm trying to use both PyCUDA and JAX in the same Python script, but I'm running into a frustrating issue. Everything works fine when I:
- Run PyCUDA code first (works perfectly)
- Then run JAX code (also works fine)
But when I try to go back to PyCUDA after using JAX, it crashes saying the GPU is "busy or unavailable". It's as if JAX won't let go of the GPU even after it's done.
Here is a small snippet of code I used to try to understand what was going on:
JAX version: 0.6.1 (installed with pip install -U "jax[cuda12]")
CUDA version: 12.9 (as reported by nvidia-smi, driver version 575.51.03)
Any idea how I can make JAX release the GPU?
Thank you in advance!
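One possible workaround, sketched under assumptions rather than tested against this exact setup: run the JAX part in a separate Python interpreter. When that process exits, the CUDA driver tears down its context and frees any device memory it held, so nothing from JAX is left alive when the parent process resumes its PyCUDA work. The child's environment also sets `XLA_PYTHON_CLIENT_PREALLOCATE=false`, a documented JAX setting that stops it from reserving 75% of GPU memory on first use; it only takes effect if it is set before JAX initializes, which a fresh interpreter guarantees. `run_isolated` and `JAX_SNIPPET` are hypothetical names.

```python
import os
import subprocess
import sys
from typing import Mapping, Optional

# Hypothetical JAX workload. It runs in a child interpreter, so its CUDA
# context and any GPU memory it allocated are released when the child exits.
JAX_SNIPPET = """
import jax.numpy as jnp
x = jnp.arange(10.0)
print(float((x * x).sum()))
"""


def run_isolated(code: str, extra_env: Optional[Mapping[str, str]] = None) -> str:
    """Run `code` in a fresh Python interpreter and return its stdout.

    Raises subprocess.CalledProcessError if the child exits with an error.
    """
    env = dict(os.environ)
    # Keep JAX from reserving most of the GPU memory up front. This must be
    # in the environment before JAX initializes in the child; an existing
    # value inherited from the parent is left untouched.
    env.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
    if extra_env:
        env.update(extra_env)
    completed = subprocess.run(
        [sys.executable, "-c", code],
        env=env,
        capture_output=True,
        text=True,
        check=True,
    )
    return completed.stdout
```

In the original script, the PyCUDA work would run in the parent before and after a call such as `run_isolated(JAX_SNIPPET)`. Results come back as text on stdout, so anything larger than a few numbers would need to go through a file (for example with `numpy.save`). Passing `{"XLA_PYTHON_CLIENT_ALLOCATOR": "platform"}` as `extra_env` is another documented JAX option that frees memory as soon as it is no longer needed, at a significant speed cost.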