Execution of replica 0 failed: INVALID_ARGUMENT #24727
jaxengodfrey asked this question in Q&A
Hello,
I am using GPU-enabled `numpyro` and `jax`. I get the error in the title (`Execution of replica 0 failed: INVALID_ARGUMENT`) after my chain finishes sampling but before `numpyro.infer.MCMC.run()` has finished gathering the results. I only get this error with large warmup and sample sizes (~100k). I think it started when I began saving a large number of deterministic variables during sampling. I have done this in the past with older versions of numpyro/jax without issue, using the same sample sizes, deterministic variables, GPU, etc.
I'm using the NUTS kernel and an 80 GB NVIDIA A100 GPU.
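Since the error only shows up with ~100k draws and many saved deterministic sites, a back-of-envelope estimate of how large the stored draws get can show whether they approach the A100's 80 GB. This is a rough sketch under stated assumptions: every saved value is a float32 scalar (4 bytes), only kept draws are counted, and the per-draw value count is a hypothetical placeholder, not a figure from this question. `stored_sample_bytes` is a hypothetical helper, not a numpyro function.

```python
def stored_sample_bytes(num_samples, values_per_draw, num_chains=1, bytes_per_value=4):
    """Rough size of all stored draws: draws x chains x values per draw x bytes.

    Assumes every saved site is stored as float32 (4 bytes per value) and ignores
    any temporary buffers created while the results are gathered.
    """
    return num_samples * num_chains * values_per_draw * bytes_per_value


if __name__ == "__main__":
    # Hypothetical figures: 100k kept draws, 50k saved scalar values per draw.
    size = stored_sample_bytes(100_000, 50_000)
    print(f"{size / 1e9:.1f} GB")  # 20.0 GB
```

With those placeholder numbers the draws alone take 20 GB; about 200,000 saved values per draw would already reach 80 GB on its own. Whether memory pressure is actually behind this INVALID_ARGUMENT error is not established here.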
Because I don't know much about how `numpyro` uses `jax` under the hood, I'm not sure how I could isolate this issue within `jax` itself to make troubleshooting easier. Any ideas/suggestions would be appreciated!
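One way to take numpyro out of the picture is to check whether JAX alone can build a single stacked array of the same shape on the GPU. This sketch assumes, without confirming it from numpyro's source, that gathering draws amounts to stacking one row of saved values per iteration into a `(num_draws, num_values)` array. `collect_draws` is a hypothetical helper that does this with `jax.lax.scan` on random data instead of a real model.

```python
import jax
import jax.numpy as jnp


def collect_draws(num_draws: int, num_deterministic: int, seed: int = 0):
    """Stack one fake row of "deterministic" values per iteration into a
    (num_draws, num_deterministic) float32 array using lax.scan.

    Hypothetical stand-in for sample collection; it does not call numpyro.
    """

    def step(key, _):
        key, subkey = jax.random.split(key)
        row = jax.random.normal(subkey, (num_deterministic,), dtype=jnp.float32)
        return key, row  # carry the key forward, emit one row per iteration

    _, stacked = jax.lax.scan(step, jax.random.PRNGKey(seed), xs=None, length=num_draws)
    return stacked


if __name__ == "__main__":
    # Start small; shapes must be static under jit, hence static_argnums.
    out = jax.jit(collect_draws, static_argnums=(0, 1))(1_000, 50)
    out.block_until_ready()
    print(out.shape, out.nbytes)
```

To use it, raise `num_draws` toward 100,000 and `num_deterministic` toward the number of values actually saved per draw. If this pure-JAX version fails with the same INVALID_ARGUMENT message, that would suggest the problem can be reproduced without a numpyro model.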