Does JAX automatically adjust the tangent output in custom_jvp? #16871

Answered by jakevdp
ToshiyukiBandai asked this question in Q&A
Thanks for the question!

When you define a custom_jvp, you are implicitly defining the VJP as well, because JAX derives it by automatically transposing the custom JVP function. The tangent output of newton_solver_jvp has no dependence on the tangent of x_guess, so the automatic transposition is effectively equivalent to a VJP rule that returns None for x_guess. The autodiff machinery therefore returns a zero gradient with respect to x_guess in both cases.
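To make this concrete, here is a minimal sketch of the behavior. The original newton_solver from the question is not reproduced here, so the residual `f`, the fixed iteration count, and the body of the solver below are hypothetical stand-ins. Only the names newton_solver and newton_solver_jvp follow the discussion.

```python
import jax

def f(x, a):
    # Hypothetical residual whose root is x = sqrt(a).
    return x**2 - a

@jax.custom_jvp
def newton_solver(a, x_guess):
    # Plain Newton iteration starting from x_guess (assumed form).
    x = x_guess
    for _ in range(20):
        x = x - f(x, a) / (2.0 * x)
    return x

@newton_solver.defjvp
def newton_solver_jvp(primals, tangents):
    a, x_guess = primals
    a_dot, _ = tangents  # the tangent of x_guess is deliberately unused
    x = newton_solver(a, x_guess)
    # Implicit function theorem: 2*x*x_dot - a_dot = 0.
    x_dot = a_dot / (2.0 * x)
    return x, x_dot

# Reverse mode works by transposing the JVP rule above.
grad_a, grad_guess = jax.grad(newton_solver, argnums=(0, 1))(4.0, 1.0)
print(grad_a, grad_guess)
```

Because `x_dot` is built only from `a_dot`, the transposed rule has nothing to send back to `x_guess`, so its gradient should come out as zero while the gradient with respect to `a` is nonzero.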

Does that make sense?

Answer selected by ToshiyukiBandai