JAX Jacobian with upstream gradient? #18092

Closed. Answered by patrick-kidger.
karan-dalal asked this question in Q&A

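I think `_, vjp_fn = jax.vjp(f, params, ...)` will do what you want. You can then pass the dL/dz cotangents to `vjp_fn`, and it will compute the cotangents dL/dparams (plus one cotangent for each other input to `f`) without ever forming the full Jacobian dz/dparams.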
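
A minimal sketch of that pattern, assuming a hypothetical one-layer function `f(params, x)` with `params = (W, b)` (the function, names, and shapes are illustrative, not from the question). The random `dL_dz` stands in for whatever upstream gradient you already have; the last lines check the VJP result against an explicitly materialized Jacobian, which is only practical for tiny shapes like these.

```python
import jax
import jax.numpy as jnp

# Hypothetical layer z = f(params, x); names and shapes are illustrative only.
def f(params, x):
    W, b = params
    return jnp.tanh(x @ W + b)

k_w, k_x, k_g = jax.random.split(jax.random.PRNGKey(0), 3)
W = jax.random.normal(k_w, (3, 4))
b = jnp.zeros(4)
x = jax.random.normal(k_x, (5, 3))
params = (W, b)

# Forward pass: returns the primal output z and a function that maps
# output cotangents (dL/dz) to cotangents for each input of f.
z, vjp_fn = jax.vjp(f, params, x)

# Upstream gradient dL/dz (random here purely for illustration).
# It must have the same shape and dtype as z.
dL_dz = jax.random.normal(k_g, z.shape)

# vjp_fn returns one cotangent per primal input: (dL/dparams, dL/dx).
dL_dparams, dL_dx = vjp_fn(dL_dz)
dL_dW, dL_db = dL_dparams

# Sanity check against the explicit Jacobian dz/dparams:
# J_W has shape (5, 4, 3, 4) and J_b has shape (5, 4, 4).
J_W, J_b = jax.jacrev(f, argnums=0)(params, x)
# Contract dL/dz over the two output axes of each Jacobian.
dL_dW_ref = jnp.tensordot(dL_dz, J_W, axes=2)
dL_db_ref = jnp.tensordot(dL_dz, J_b, axes=2)
print(jnp.allclose(dL_dW, dL_dW_ref, atol=1e-4),
      jnp.allclose(dL_db, dL_db_ref, atol=1e-4))
```

If you don't need dL/dx, you can close over `x` instead, e.g. `jax.vjp(lambda p: f(p, x), params)`, in which case `vjp_fn` returns a one-element tuple holding only dL/dparams.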

Replies: 2 comments 1 reply

Answer selected by jakevdp
Category: Q&A
Labels: none
Participants: 2