jax.grad with multiple return values #15025

Answered by jakevdp
halfroad asked this question in General

Thanks for the question! There are two issues here with how you're using jax.grad.

  1. jax.grad only works with functions that return a single scalar output. Passing has_aux=True lets you use a function that returns a second, auxiliary output, which will not be differentiated (see the docs). If you want to compute the derivatives of a function with multiple outputs in one pass, one way is to use a more general transform such as jax.jacobian.
  2. Passing argnums=(0, 1) computes df/d{x, y}, which is useful in some cases but does not appear to be what you expected. It sounds like you want to compute (df/dx, df/dy), which requires two gradient evaluations.
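As a rough illustration of both points, here is a minimal sketch built on a hypothetical two-input, two-output function `f`. It is not the function from the original question, which isn't shown here, and the helper `out0` and the evaluation point are also made up. It assumes the goal is either every partial derivative at once, or one derivative per output (first output with respect to x, second with respect to y).

```python
import jax
import jax.numpy as jnp

# Hypothetical function (not from the original question):
# two scalar inputs, two scalar outputs.
def f(x, y):
    return x**2 * y, jnp.sin(x) + y**3

x, y = 1.0, 2.0

# Point 1a: has_aux=True differentiates only the first output;
# the second output is passed through untouched as auxiliary data.
grad_out0_dx, aux = jax.grad(f, has_aux=True)(x, y)
# grad_out0_dx is d(x**2 * y)/dx = 2*x*y; aux is sin(x) + y**3

# Point 1b: jax.jacobian differentiates every output in one call.
# With argnums=(0, 1) the result is nested as jac[output_index][arg_index].
jac = jax.jacobian(f, argnums=(0, 1))(x, y)

# Point 2: argnums=(0, 1) on a single-output function gives the gradient
# of that one output with respect to both arguments, as a tuple.
def out0(x, y):  # hypothetical helper selecting the first output
    return f(x, y)[0]

dout0_dx, dout0_dy = jax.grad(out0, argnums=(0, 1))(x, y)

# If the goal is instead one derivative per output (d out0/dx and
# d out1/dy), that takes two separate gradient evaluations.
dout0_dx_only = jax.grad(lambda x, y: f(x, y)[0], argnums=0)(x, y)
dout1_dy_only = jax.grad(lambda x, y: f(x, y)[1], argnums=1)(x, y)
```

In this sketch jax.jacobian computes all four partials, while the two separate jax.grad calls compute only the two that were asked for; which approach fits better depends on how many output/argument combinations you actually need.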

With these two pieces in mind, I'd probably do something l…

Answer selected by halfroad