Hi folks, does anybody know the underlying logic of differentiating a function with multiple return values? I am trying to differentiate the function in the following code:
The printed output is as follows:
I expected the output to look something like (648, … Also, if I would like to compute the partial derivatives (48, 540) of f(x, y) = 4 * x ** 3 + 5 * y ** 4 instead, how should I set argnums = (0, 1), or what other changes are needed? Thank you in advance.
Thanks for the question! There are two issues here with how you're using jax.grad:

1. jax.grad only works with single-output functions. Passing has_aux allows you to use a function that returns a second, auxiliary output which will not be differentiated (see the docs). If you want to compute the derivative of a function with multiple outputs in one pass, one way is to use a more general transformation such as jax.jacobian.
2. argnums=(0, 1) computes df/d{x, y}, which is useful in some cases but does not appear to be what you expected. It sounds like you want to compute (df/dx, df/dy), which requires two gradient evaluations.

With these two pieces in mind, I'd probably do something like the following to compute the results you were looking for:

    df1_dx, df2_dx = jax.jacobian(function, argnums=0)(x, y)
    df1_dy, df2_dy = jax.jacobian(function, argnums=1)(x, y)
    result = ((df1_dx, df1_dy), (df2_dx, df2_dy))