Skip to content

Differentiating with respect to closure-introduced variables #17318

Answered by jakevdp
VIVelev asked this question in Q&A

If you don't want to restructure your closure definitions, you can compute the derivative by differentiating a lambda that builds the closure and immediately calls it:

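```python
import jax
import numpy as np
# Rx is the closure factory defined in the question (not shown here).
result = jax.jacfwd(lambda theta: Rx(theta)(0.0, 1.0, 0.0))(np.pi / 2)
print(result)
# [ 0.000000e+00 -1.000000e+00 -4.371139e-08]
```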
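The question's `Rx` is not reproduced in this thread, so the self-contained sketch below assumes a hypothetical `Rx(theta)` that returns a function rotating the point `(x, y, z)` by `theta` about the x-axis. That assumption matches the printed result: rotating `(0, 1, 0)` gives `(0, cos θ, sin θ)`, whose derivative `(0, -sin θ, cos θ)` is about `(0, -1, -4.37e-08)` at θ = π/2 in float32.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the question's Rx: a closure factory returning
# a function that rotates (x, y, z) by angle theta about the x-axis.
def Rx(theta):
  c, s = jnp.cos(theta), jnp.sin(theta)
  def rotate(x, y, z):
    return jnp.array([x, c * y - s * z, s * y + c * z])
  return rotate

theta0 = jnp.pi / 2
# Fix the point, build the closure inside the lambda, differentiate w.r.t. theta.
result = jax.jacfwd(lambda theta: Rx(theta)(0.0, 1.0, 0.0))(theta0)

# Analytic check: d/dtheta (0, cos t, sin t) = (0, -sin t, cos t).
expected = jnp.array([0.0, -jnp.sin(theta0), jnp.cos(theta0)])
print(result)
print(jnp.allclose(result, expected, atol=1e-6))
```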

There's no built-in JAX transform to do autodiff of a function that returns a function, but you could use a short wrapper to define this behavior if you wish:

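```python
def fun_jacfwd(fun, *args, **kwargs):
  """Forward-mode derivative of closure factory `fun`, evaluated at `args`.
  Extra **kwargs (e.g. argnums) are forwarded to jax.jacfwd."""
  def df(*df_args, **df_kwargs):
    # Differentiate the closure's output w.r.t. the factory's arguments.
    return jax.jacfwd(lambda *a: fun(*a)(*df_args, **df_kwargs), **kwargs)(*args)
  return df
```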

```python
Rx_prime = fun_jacfwd(Rx, np.pi / 2)
print(Rx_prime(0.0, 1.0, 0.0))
# [-0.000000e+00 -1.000000e+00 -4.371139e-08]
```
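Because `fun_jacfwd` forwards its `**kwargs` to `jax.jacfwd`, options such as `argnums` pass through as well. A minimal sketch, assuming a hypothetical two-parameter closure factory `affine(a, b)` (not part of the original question), differentiating with respect to both closed-over parameters at once:

```python
import jax
import jax.numpy as jnp

# Same wrapper as in the answer: kwargs such as argnums go to jax.jacfwd.
def fun_jacfwd(fun, *args, **kwargs):
  def df(*df_args, **df_kwargs):
    return jax.jacfwd(lambda *a: fun(*a)(*df_args, **df_kwargs), **kwargs)(*args)
  return df

# Hypothetical closure factory with two closed-over parameters.
def affine(a, b):
  return lambda x: a * x + b

# Derivatives w.r.t. a and b at (a, b) = (2.0, 3.0), for the closure input x = 5.0.
d_affine = fun_jacfwd(affine, 2.0, 3.0, argnums=(0, 1))
da, db = d_affine(5.0)
print(da, db)  # d/da = x = 5.0, d/db = 1.0
```

With a tuple `argnums`, `jax.jacfwd` returns one derivative per selected argument, so `d_affine` yields a pair rather than a single array.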

Answer selected by VIVelev
Category: Q&A
3 participants