Does upgrading the JAX version affect custom CUDA code? #17539
jing-alice asked this question in Q&A
I think the issue is that you are updating the […] If you want the input and output to be the same array so you can do updates in-place, use the […] (I should note that […])
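At the JAX level, updates are functional: `x.at[...].add(...)` returns a new array and never modifies `x`. If the goal is to let XLA reuse the input's memory for the output, one documented mechanism for jitted functions is buffer donation via `donate_argnums`. This is a general illustration, not necessarily the mechanism the reply refers to, and `scatter_add_inplace` is a hypothetical name:

```python
from functools import partial

import jax
import jax.numpy as jnp

# donate_argnums=0 tells XLA it may reuse x's buffer for the result.
# Donation is only a hint: where it cannot be honored, JAX warns and
# simply does not reuse the buffer. The result is correct either way.
@partial(jax.jit, donate_argnums=0)
def scatter_add_inplace(x, idx, vals):
    return x.at[idx].add(vals)

x = jnp.zeros(4, dtype=jnp.float32)
idx = jnp.array([2, 2])
vals = jnp.array([1.5, 2.5], dtype=jnp.float32)

y = scatter_add_inplace(x, idx, vals)
# After donation, x must not be used again; keep working with y.
print(y)  # [0. 0. 4. 0.]
```

Note that donation is an XLA-level contract; a custom CUDA op still has to be told explicitly if its output is meant to share memory with an input.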
I used my own CUDA code to implement a function similar to `jax.at[].add()`. It worked with jax==0.4.8 and jaxlib==0.4.7+cuda11.cudnn86.
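For reference, the semantics a custom scatter-add kernel has to reproduce can be checked against JAX's built-in `.at[].add()`. A minimal sketch, assuming a 1-D float32 operand; the arrays `x`, `idx`, and `vals` are illustrative and not taken from the original code:

```python
import jax.numpy as jnp

# Operand and updates; index 1 appears twice on purpose.
x = jnp.zeros(5, dtype=jnp.float32)
idx = jnp.array([0, 1, 1, 3])
vals = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=jnp.float32)

# .at[].add() returns a new array and leaves x unchanged.
# Duplicate indices accumulate, so position 1 receives 2.0 + 3.0.
ref = x.at[idx].add(vals)
print(ref)  # [1. 5. 0. 4. 0.]
print(x)    # [0. 0. 0. 0. 0.]
```

Comparing the custom op's output with `ref` for inputs containing repeated indices is a quick way to catch both missing accumulation and uninitialized output memory.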
This is my code:
And the result is:
However, with jax==0.4.13 and jaxlib==0.4.13+cuda12.cudnn8, the result is incorrect and random. The following two figures show the results of running the code twice:
And this is my CUDA code:
This is how I call the CUDA code from JAX:
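One common cause of results that are correct on one JAX version and random on another (an assumption here, not a confirmed diagnosis) is a kernel that accumulates into its output buffer without first copying the operand into it. Unless an output is declared as aliasing an input, XLA gives a custom call a separately allocated result buffer whose initial contents are not guaranteed, and how buffers get assigned may change between versions. The NumPy sketch below models the safe contract on the host; `scatter_add_kernel_model` is a hypothetical name, and `np.add.at` stands in for `atomicAdd`:

```python
import numpy as np

def scatter_add_kernel_model(operand, idx, vals, out):
    """Host-side model of a scatter-add kernel with separate in/out buffers.

    `out` stands for the result buffer XLA allocates; its initial contents
    are not guaranteed, so the kernel must overwrite it before accumulating.
    """
    # Step 1: copy the operand into the output (in a real kernel, e.g. a
    # cudaMemcpyAsync on the custom call's stream).
    out[...] = operand
    # Step 2: accumulate the updates. np.add.at applies every update even
    # for repeated indices, like atomicAdd does on the GPU.
    np.add.at(out, idx, vals)
    return out

operand = np.arange(5, dtype=np.float32)
idx = np.array([0, 1, 1, 3])
vals = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)

# Simulate a result buffer that still holds stale data.
out = np.full(5, 123.0, dtype=np.float32)
print(scatter_add_kernel_model(operand, idx, vals, out))  # [1. 6. 2. 7. 4.]
```

Skipping step 1 reproduces the symptom: the output then depends on whatever the buffer held before the call, which can differ from run to run.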