Replies: 2 comments
-
Excuse me, is the jax installed by this method the CPU version or the CUDA 12 version?
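One way to check which backend an installed jax actually uses is to ask it at runtime. This is a minimal sketch using the public `jax.default_backend()` and `jax.devices()` calls; the helper name `describe_backend` is only illustrative. A working CUDA install with a visible GPU should report `gpu`, while a CPU-only install reports `cpu`.

```python
import jax


def describe_backend():
    """Return the default backend name and the platform of each visible device."""
    # Typically "gpu" for a working CUDA install, "cpu" when no GPU backend loaded.
    backend = jax.default_backend()
    platforms = [device.platform for device in jax.devices()]
    return backend, platforms


if __name__ == "__main__":
    backend, platforms = describe_backend()
    print(f"default backend: {backend}; devices: {platforms}")
```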
-
Please see the docs: https://docs.jax.dev/en/latest/developer.html#building-jaxlib-from-source The
-
Background
I have some applications running on `jax` installed via `pip install jax[cuda12]==0.4.34 jaxlib==0.4.34`. Recently, I encountered this issue when enabling both cuDNN and persistent caching simultaneously: `String field 'xla.gpu.CompilationResultProto.DnnCompiledGraphsEntry.value' contains invalid UTF-8 data`.
I noticed that this bug was fixed in a commit to the XLA repository on January 9, 2025, and the latest jax release after that fix was 0.5.0. So, after removing the old `jax`, I reinstalled it using `pip install jax[cuda12]==0.5.0 jaxlib==0.5.0`, and the aforementioned issue was resolved!

However, my application does not work well with jax 0.5.0: it runs slower and produces some NaN errors. So I decided to build `jaxlib` and `jax[cuda12]` myself against a local `xla` repository.

What did I do
I pulled the `jax` and `xla` repositories and switched to specific commits using the following commands:

After modifying the XLA source code to fix the bug, I read developer.md and used the following command to build `jaxlib`:

After a long wait, I received the following output:
Question
I successfully built `jaxlib`. However, this raises a question: how do I build `jax[cuda12]`? I did not find any instructions in the jax documentation for building a specific wheel with the `[cuda12]` tag. Thank you for your response!
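As far as I can tell, in recent releases `[cuda12]` is a pip extra rather than a separate wheel: it pulls in `jaxlib` together with CUDA plugin packages (distribution names such as `jax-cuda12-plugin` and `jax-cuda12-pjrt`; please verify these against the `setup.py` of your jax version). A quick, stdlib-only way to see which of these are currently installed is sketched below; `JAX_DISTS` and `installed_jax_dists` are hypothetical names.

```python
from importlib import metadata

# Assumed distribution names from recent jax releases; adjust for your version.
JAX_DISTS = ("jax", "jaxlib", "jax-cuda12-plugin", "jax-cuda12-pjrt")


def installed_jax_dists(names=JAX_DISTS):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_jax_dists().items():
        print(f"{name:20s} {version or 'not installed'}")
```

If only `jax` and `jaxlib` show up, pip has not installed any CUDA plugin packages alongside them.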