Differentiate through an ordinary differential equation with jax.jet? #26332
DavidLanders95 asked this question in Q&A
I have been trying to obtain higher-order partial derivatives of the output y_out of an ODE solve with respect to the input parameters y_in, in order to get the Taylor expansion coefficients, and jet (jax.experimental.jet) seems like a reasonable tool for this. I would rather avoid repeated calls to jax.jacobian: I want higher-order partials, and compiling a jitted function that nests jax.jacobian to higher orders is quite slow, so using jet seems almost necessary. However, I cannot get jet to work through JAX's odeint solver (jax.experimental.ode.odeint) or through a solver from diffrax. It seems that the while loops, or other control flow that adaptive ODE solvers typically use, are not compatible with jet. Has anyone tried this before and found a workaround?
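For illustration, here is a minimal sketch of one possible workaround, assuming a fixed step size is acceptable. It is not how odeint or diffrax work internally. The idea is to write the solver with a plain Python loop, which is unrolled while tracing, so jet only ever sees arithmetic primitives and never a while loop. The ODE dy/dt = -y², the helper names `vector_field`, `rk4_solve` and `solution_and_derivatives`, and the choice of 20 RK4 steps are all hypothetical choices made for this example. The ODE is picked because its closed-form solution y(T) = y0 / (1 + T y0) makes the derivatives easy to check.

```python
import jax
import jax.numpy as jnp
from jax.experimental.jet import jet


def vector_field(y):
    # Hypothetical example ODE: dy/dt = -y**2, written as a product so only
    # basic arithmetic primitives (neg, mul) are traced.
    return -y * y


def rk4_solve(y0, t_final=1.0, num_steps=20):
    # Hypothetical fixed-step classical RK4 solver. The Python for-loop is
    # unrolled while tracing, so no lax.while_loop (or other control-flow
    # primitive) reaches jet, only adds and multiplies.
    dt = t_final / num_steps
    y = y0
    for _ in range(num_steps):
        k1 = vector_field(y)
        k2 = vector_field(y + 0.5 * dt * k1)
        k3 = vector_field(y + 0.5 * dt * k2)
        k4 = vector_field(y + dt * k3)
        y = y + (dt / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
    return y


@jax.jit
def solution_and_derivatives(y0):
    # Perturb the initial condition along the curve y_in(s) = y0 + s.
    # jet takes derivatives of the input curve with respect to s (not Taylor
    # coefficients divided by k!), so the input series is (1, 0, 0).
    order = 3
    input_series = [jnp.ones_like(y0)] + [jnp.zeros_like(y0)] * (order - 1)
    y_out, series_out = jet(rk4_solve, (y0,), (input_series,))
    # series_out[k - 1] is the k-th derivative of the RK4 output with respect
    # to y_in; dividing by k! gives the Taylor expansion coefficients.
    return y_out, series_out


y0 = jnp.array(0.5, dtype=jnp.float32)
y_out, (d1, d2, d3) = solution_and_derivatives(y0)

# Closed-form values for comparison, with T = 1:
# y = y0/(1+y0), y' = 1/(1+y0)^2, y'' = -2/(1+y0)^3, y''' = 6/(1+y0)^4
exact = (
    y0 / (1 + y0),
    1 / (1 + y0) ** 2,
    -2 / (1 + y0) ** 3,
    6 / (1 + y0) ** 4,
)
print("jet:  ", y_out, d1, d2, d3)
print("exact:", *exact)
```

The trade-off is that, because the loop is unrolled, the traced program (and its compile time) grows linearly with `num_steps`, and there is no adaptive step-size control or error estimate. Also, for a vector-valued y_in each jet call gives derivatives along a single direction, so recovering individual mixed partials would take several calls with different directions.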