Does JAX automatically adjust the tangent output in custom_jvp? #16871
-
Hello everyone, I am trying to write a custom Jacobian-vector product rule for the Newton method. At first I used `custom_vjp`, but I switched to `custom_jvp` so that I could use both forward- and reverse-mode automatic differentiation. While implementing it, I ran into some interesting behavior. In the example below, I want to compute the derivative
Then, the output was
This result was unexpected to me; I expected an error instead.
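The original snippet and its printed output are not reproduced here, but a hypothetical minimal setup of the kind described, where the JVP rule's tangent output never uses the incoming tangent `x_dot`, shows the same behavior: `jax.grad` quietly returns zero instead of raising an error. The `newton_solver` below is a stand-in, not the original code.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the setup described above: a hand-rolled Newton
# solve for x**2 - 2 = 0 with a custom JVP whose tangent output never uses
# the incoming tangent x_dot.
@jax.custom_jvp
def newton_solver(x_guess):
    x = x_guess
    for _ in range(10):
        x = x - (x**2 - 2.0) / (2.0 * x)  # Newton step for f(x) = x**2 - 2
    return x

@newton_solver.defjvp
def newton_solver_jvp(primals, tangents):
    (x_guess,), (x_dot,) = primals, tangents
    y = newton_solver(x_guess)
    # The tangent output depends only on the primal value, not on x_dot.
    return y, jnp.zeros_like(y)

print(jax.jvp(newton_solver, (1.0,), (1.0,)))  # (~1.41421356, 0.0)
print(jax.grad(newton_solver)(1.0))            # 0.0 -- no error is raised
```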
-
Thanks for the question! When you define a `custom_jvp`, you are implicitly defining the VJP as well, via automatic transposition of the custom JVP function. Because the output of `newton_solver_jvp` has no dependence on the tangent of `x_guess`, the automatic transposition makes this effectively equivalent to a VJP rule returning `None`, and so the autodiff machinery returns a zero gradient in both cases. Does that make sense?
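To make the equivalence concrete: under the hypothetical `newton_solver` above, the reverse rule obtained by transposing the custom JVP behaves like a hand-written `custom_vjp` whose backward pass maps every output cotangent to a zero input cotangent. The sketch below illustrates that equivalence; it is not a depiction of JAX internals.

```python
import jax
import jax.numpy as jnp

# Same hypothetical solver, but with an explicit custom_vjp whose backward
# pass returns a zero cotangent for x_guess.
@jax.custom_vjp
def newton_solver2(x_guess):
    x = x_guess
    for _ in range(10):
        x = x - (x**2 - 2.0) / (2.0 * x)
    return x

def newton_solver2_fwd(x_guess):
    return newton_solver2(x_guess), None      # no residuals needed

def newton_solver2_bwd(residuals, g):
    return (jnp.zeros_like(g),)               # zero cotangent for x_guess

newton_solver2.defvjp(newton_solver2_fwd, newton_solver2_bwd)

print(jax.grad(newton_solver2)(1.0))          # 0.0, same as the custom_jvp version
```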
-
This might be coming in too late to be useful -- but heads-up that we now have a Newton implementation available in Optimistix, which might save you some implementation time. (This library is the latest entry in the Equinox scientific ecosystem, offering nonlinear optimisation.)
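A rough usage sketch, based on my reading of the Optimistix documentation (the names `Newton`, `root_find`, and `Solution.value` are assumptions to check against the current API):

```python
import jax
import jax.numpy as jnp
import optimistix as optx

def residual(y, args):
    # Root-finding problem: y**2 - args = 0.
    return y**2 - args

solver = optx.Newton(rtol=1e-8, atol=1e-8)
sol = optx.root_find(residual, solver, jnp.array(1.0), args=2.0)
print(sol.value)  # ~1.41421356

# As I understand it, the solve is differentiated via the implicit function
# theorem, so no hand-written custom_jvp rule is needed:
print(jax.grad(lambda c: optx.root_find(residual, solver, jnp.array(1.0), args=c).value)(2.0))
# ~0.3535534 == 1 / (2 * sqrt(2))
```

As I understand it, differentiating through the solve this way sidesteps the need for a custom rule like the one discussed above.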