Can you put together a reproducible example of what you're seeing?
-
I am trying to create a custom JVP rule for a Newton-Raphson root-finding algorithm.
Since I use linear solvers and a host of other external stuff inside newton_solver_fn(), I do the following:
As I understand it, JAX should not look inside "_solve()".
But when I call jax.grad(), it raises this error:
"ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop with dynamic start/stop values"
which I believe comes from inside "_solve()".
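Since the original code is not shown, here is a hypothetical minimal version of the setup described above, which may help as the reproducible example asked for. The problem (find x with A @ x + x**3 = p) and the names residual, jac_x, iterative_solve (standing in for "_solve()", here a Jacobi iteration in a lax.while_loop) and newton_solve (standing in for newton_solver_fn()) are all assumptions for illustration. With this sketch, forward mode via jax.jvp runs, while jax.grad is expected to fail with the while_loop reverse-mode ValueError.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical problem: find x such that A @ x + x**3 = p.
A = jnp.array([[4.0, 1.0], [1.0, 3.0]])


def residual(x, p):
    return A @ x + x**3 - p


def jac_x(x):
    # Jacobian of the residual with respect to x.
    return A + jnp.diag(3.0 * x**2)


def iterative_solve(J, b, tol=1e-6, max_iter=100):
    # Hypothetical stand-in for "_solve()": Jacobi iteration inside a
    # lax.while_loop whose trip count depends on the residual.
    d = jnp.diag(J)
    R = J - jnp.diag(d)

    def cond(state):
        x, it = state
        return (jnp.linalg.norm(J @ x - b) > tol) & (it < max_iter)

    def body(state):
        x, it = state
        return (b - R @ x) / d, it + 1

    x, _ = lax.while_loop(cond, body, (jnp.zeros_like(b), 0))
    return x


@jax.custom_jvp
def newton_solve(p):
    # Hypothetical stand-in for newton_solver_fn(): Newton iterations.
    def cond(state):
        x, it = state
        return (jnp.linalg.norm(residual(x, p)) > 1e-6) & (it < 50)

    def body(state):
        x, it = state
        return x + iterative_solve(jac_x(x), -residual(x, p)), it + 1

    x, _ = lax.while_loop(cond, body, (jnp.zeros_like(p), 0))
    return x


@newton_solve.defjvp
def newton_solve_jvp(primals, tangents):
    (p,), (p_dot,) = primals, tangents
    x = newton_solve(p)
    # Implicit function theorem: d(residual)/dp = -I here, so
    # jac_x(x) @ x_dot = p_dot. The tangent p_dot is the RHS of the
    # while_loop-based solve.
    x_dot = iterative_solve(jac_x(x), p_dot)
    return x, x_dot


p = jnp.array([1.0, 2.0])
# Forward mode runs: the tangent only flows forward through the loop.
x, x_dot = jax.jvp(newton_solve, (p,), (jnp.array([1.0, 0.0]),))
# Reverse mode is expected to fail on the while_loop in iterative_solve.
try:
    jax.grad(lambda q: jnp.sum(newton_solve(q)))(p)
except ValueError as err:
    print("jax.grad failed:", err)
```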
Why is this happening? Is there something that I am missing?
P.S. The error only happens when I use the JVP tangent (which is an abstract tracer) as the RHS. If I use some other RHS vector instead, the code runs (though, of course, the gradient is wrong).
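A likely explanation, offered as an assumption rather than a confirmed answer: JAX does trace the custom JVP rule itself, and for jax.grad it transposes the rule's tangent computation. When the tangent is the RHS of a while_loop-based solver, that loop becomes part of the linear map JAX must transpose, and while_loop has no transpose rule, which would match the P.S. One commonly used workaround is to wrap the tangent-side solve in jax.lax.custom_linear_solve, which tells JAX the result is linear in the RHS and supplies a transpose_solve, so the loop is only evaluated, never transposed. The sketch below uses a small hypothetical problem (find x with A @ x + x**3 = p); residual, jac_x, iterative_solve, linear_solve and newton_solve are illustrative names, not from the original post.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical problem: find x such that A @ x + x**3 = p.
A = jnp.array([[4.0, 1.0], [1.0, 3.0]])


def residual(x, p):
    return A @ x + x**3 - p


def jac_x(x):
    return A + jnp.diag(3.0 * x**2)


def iterative_solve(J, b, tol=1e-6, max_iter=100):
    # Hypothetical while_loop-based solver (Jacobi iteration). Below it is
    # only evaluated, never differentiated or transposed.
    d = jnp.diag(J)
    R = J - jnp.diag(d)

    def cond(state):
        x, it = state
        return (jnp.linalg.norm(J @ x - b) > tol) & (it < max_iter)

    def body(state):
        x, it = state
        return (b - R @ x) / d, it + 1

    x, _ = lax.while_loop(cond, body, (jnp.zeros_like(b), 0))
    return x


def linear_solve(J, b):
    # Declares x = J^{-1} b as a linear solve in b. Reverse mode uses
    # transpose_solve (applies J^{-T}) instead of transposing the loop.
    return lax.custom_linear_solve(
        lambda v: J @ v,
        b,
        solve=lambda _matvec, rhs: iterative_solve(J, rhs),
        transpose_solve=lambda _vecmat, rhs: iterative_solve(J.T, rhs),
    )


@jax.custom_jvp
def newton_solve(p):
    def cond(state):
        x, it = state
        return (jnp.linalg.norm(residual(x, p)) > 1e-6) & (it < 50)

    def body(state):
        x, it = state
        return x + iterative_solve(jac_x(x), -residual(x, p)), it + 1

    x, _ = lax.while_loop(cond, body, (jnp.zeros_like(p), 0))
    return x


@newton_solve.defjvp
def newton_solve_jvp(primals, tangents):
    (p,), (p_dot,) = primals, tangents
    x = newton_solve(p)
    # Same implicit-function tangent, but through custom_linear_solve.
    x_dot = linear_solve(jac_x(x), p_dot)
    return x, x_dot


p = jnp.array([1.0, 2.0])
x = newton_solve(p)
g = jax.grad(lambda q: jnp.sum(newton_solve(q)))(p)
```

With this structure the gradient of sum(x) should equal J^{-T} applied to a vector of ones, where J = jac_x(x). Writing a custom_vjp rule directly, or using jax.lax.custom_root with a tangent_solve built on custom_linear_solve, are other options along the same lines.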