TypeError: JAX encountered invalid PRNG key data: expected key_data.ndim >= 1; got ndim=0 #18150
Unanswered
SaschaFroelich asked this question in Q&A
-
Thanks for the question - can you share the full error traceback? There's not enough information in your question to know what might be going on. Also, what version of numpyro are you using?
-
I reinstalled JAX on a new machine, and code that ran before now throws the error
TypeError: JAX encountered invalid PRNG key data: expected key_data.ndim >= 1; got ndim=0
I don't fully understand the problem, since the ndim of the rng_key is 1 and the code worked with an earlier version of JAX. The only difference is that I switched from jax 0.4.14 with jaxlib 0.4.14 to jax 0.4.18 with jaxlib 0.4.18+cuda12.cudnn89. Could it have something to do with the CUDA support?
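For reference, a minimal sketch of how a key's shape can be inspected and how a 0-dimensional value ends up rejected as PRNG key data. It assumes the default PRNG implementation and a legacy raw key from jax.random.PRNGKey; raw_key_error is a hypothetical helper written only for this illustration, and whether the failing code does anything like the indexing shown here is not established.

```python
import jax


def raw_key_error(key):
    """Hypothetical helper: return the TypeError message raised when `key`
    is used as a PRNG key, or None if JAX accepts it."""
    try:
        jax.random.normal(key)
    except TypeError as err:
        return str(err)
    return None


# A legacy raw key is a small uint32 array, not a scalar.
key = jax.random.PRNGKey(0)
print(key.shape, key.dtype)  # (2,) uint32 with the default threefry implementation

# Indexing into a raw key yields a 0-dimensional uint32 value.
# Assumption for illustration only: something upstream slices the key like this.
scalar_piece = key[0]
print(scalar_piece.ndim)

print(raw_key_error(key))           # None: the full key is accepted
print(raw_key_error(scalar_piece))  # message complaining about key_data.ndim
```

Printing the shape of the key at the point where it is consumed, rather than where it is created, may help narrow down where the dimension is lost.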
Package versions
jax 0.4.18
jaxlib 0.4.18+cuda12.cudnn89
numpyro 0.13.2
python 3.10
Error Trace