JAX Jacobian with upstream gradient? #18092
Answered by patrick-kidger, Oct 15, 2023
Replies: 2 comments 1 reply
@jakevdp any help would be great!
Answer selected by jakevdp
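I think
_, vjp_fn = jax.vjp(f, params, ...)
will do what you want. You can then pass the dL/dz cotangents to vjp_fn, and it will compute the cotangents dL/dparams. vjp_fn returns a tuple with one cotangent per primal argument given to jax.vjp, so dL/dparams is its first element.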
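A minimal sketch of this approach, assuming a hypothetical model f(params, x) where params is a (W, b) tuple for a single tanh layer. The names f, params, x and dL_dz are illustrative and not from the original thread. The upstream gradient dL_dz must have the same shape and dtype as the output z. The final lines compare the result against an explicit Jacobian built with jax.jacrev, only to show that the VJP equals "upstream gradient times Jacobian"; the VJP itself never materializes that Jacobian.

```python
import jax
import jax.numpy as jnp

# Hypothetical model z = tanh(x @ W + b); stands in for the user's f.
def f(params, x):
    W, b = params
    return jnp.tanh(x @ W + b)

key_w, key_x, key_g = jax.random.split(jax.random.PRNGKey(0), 3)
params = (jax.random.normal(key_w, (3, 2)), jnp.zeros(2))
x = jax.random.normal(key_x, (4, 3))

# Forward pass; vjp_fn keeps whatever residuals the backward pass needs.
z, vjp_fn = jax.vjp(f, params, x)

# Upstream gradient dL/dz (e.g. handed over from another framework).
# Hypothetical random values; it must match z in shape and dtype.
dL_dz = jax.random.normal(key_g, z.shape)

# One cotangent per primal input: (dL/dparams, dL/dx).
dL_dparams, dL_dx = vjp_fn(dL_dz)
dL_dW, dL_db = dL_dparams

# Sanity check only: contract dL/dz with the full Jacobian dz/dparams.
# jacrev returns a pytree shaped like params; each leaf has shape z.shape + leaf.shape.
J_W, J_b = jax.jacrev(f)(params, x)
dL_dW_explicit = jnp.tensordot(dL_dz, J_W, axes=2)  # shape (3, 2)
dL_db_explicit = jnp.tensordot(dL_dz, J_b, axes=2)  # shape (2,)

print(jnp.allclose(dL_dW, dL_dW_explicit, atol=1e-4))
print(jnp.allclose(dL_db, dL_db_explicit, atol=1e-4))
```

Because vjp_fn is an ordinary function, it can be called repeatedly with different upstream gradients without re-running the forward pass, and it can be wrapped in jax.jit if needed.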