Batching rule for 'celerite2_factor' not implemented #28769
Thanks for the question! This isn't really a JAX issue, and the issue you linked is the right place to comment. celerite2 isn't very actively maintained these days. As the developer of both, I'd recommend tinygp, which can use the same algorithms and already supports batching.
I am using the celerite2 package with NumPyro. When I try to run AIES MCMC, I get the following error:
```
lib/python3.11/site-packages/celerite2/jax/ops.py", line 39, in factor
    d, W, S = factor_p.bind(t, c, a, U, V)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
NotImplementedError: Batching rule for 'celerite2_factor' not implemented
```
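The error is raised by `jax.vmap`: assuming the AIES walkers are run with `chain_method="vectorized"` (as in NumPyro's ensemble-sampler examples), the log density is evaluated under `vmap`, and every primitive it reaches must have a registered batching rule. A minimal reproduction, using a hypothetical primitive `toy_factor` as a stand-in for `celerite2_factor` (it is not celerite2's actual op):

```python
import jax
import jax.numpy as jnp

try:  # newer JAX exposes Primitive here
    from jax.extend.core import Primitive
except ImportError:  # older JAX
    from jax.core import Primitive
from jax.core import ShapedArray

# Hypothetical stand-in for an op like 'celerite2_factor': it only knows
# how to handle a single, unbatched 1-D input.
toy_factor_p = Primitive("toy_factor")
toy_factor_p.def_impl(lambda a: jnp.cumsum(a))
toy_factor_p.def_abstract_eval(lambda a: ShapedArray(a.shape, a.dtype))


def toy_factor(a):
    return toy_factor_p.bind(a)


print(toy_factor(jnp.arange(3.0)))  # an unbatched call works

try:
    # One row per "walker": vmap looks up a batching rule and finds none.
    jax.vmap(toy_factor)(jnp.ones((4, 3)))
except NotImplementedError as err:
    print(err)
```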
I understand this needs to be added to the celerite2 package (it may be what exoplanet-dev/celerite2#11 refers to).
Any advice on how to go about writing the rule would be welcome!
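For a custom primitive, a batching rule is a function registered in `jax.interpreters.batching.primitive_batchers`. It receives the batched inputs and the axis along which each one is batched, and returns the output with its batch axis. The sketch below uses the same hypothetical `toy_factor` primitive (not celerite2's real op, which takes `t, c, a, U, V`) and the simplest strategy: move the batch axis to the front and loop over it with `jax.lax.map`. This is correct but evaluates the walkers sequentially. The `mlir.lower_fun` lowering is a placeholder for illustration, whereas a compiled extension would lower to its own XLA custom call.

```python
import jax
import jax.numpy as jnp
from jax.interpreters import batching, mlir

try:  # newer JAX exposes Primitive here
    from jax.extend.core import Primitive
except ImportError:  # older JAX
    from jax.core import Primitive
from jax.core import ShapedArray


# Hypothetical primitive: its "kernel" only handles one 1-D input at a time.
def _toy_factor_impl(a):
    return jnp.cumsum(a)


toy_factor_p = Primitive("toy_factor")
toy_factor_p.def_impl(_toy_factor_impl)
toy_factor_p.def_abstract_eval(lambda a: ShapedArray(a.shape, a.dtype))
# Lowering so the primitive can run inside jit and lax.map (placeholder only).
mlir.register_lowering(
    toy_factor_p, mlir.lower_fun(_toy_factor_impl, multiple_results=False)
)


def toy_factor(a):
    return toy_factor_p.bind(a)


def _toy_factor_batch(batched_args, batch_dims):
    """Batching rule: vmap passes the inputs and each input's batch axis."""
    (a,), (bdim,) = batched_args, batch_dims
    a = jnp.moveaxis(a, bdim, 0)      # put the batch axis first
    out = jax.lax.map(toy_factor, a)  # apply the unbatched op to each slice
    return out, 0                     # the output is batched along axis 0


batching.primitive_batchers[toy_factor_p] = _toy_factor_batch

x = jnp.arange(12.0).reshape(4, 3)
print(jax.vmap(toy_factor)(x))  # same values as jnp.cumsum(x, axis=1)
print(jax.jit(jax.vmap(toy_factor, in_axes=1, out_axes=1))(x.T))
```

For a multi-input op like `celerite2_factor`, the rule would also need to handle inputs that are not batched (their entry in `batch_dims` is `batching.not_mapped`), for example by broadcasting them along the batch axis before looping. A faster rule would hand the whole batch to a kernel that natively supports batches.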