So, I've got the following example:

```python
import jax
import jax.numpy as jnp
import numpy as np


def Rx(theta):
    """Rotation around the x-axis by angle theta"""
    ct = jnp.cos(theta)
    st = jnp.sin(theta)

    def f(x, y, z):
        return jnp.array([x, ct * y - st * z, st * y + ct * z])

    return f


Rx(np.pi / 2)(
    0.0, 1.0, 0.0
)  # => Array([ 0.000000e+00, -4.371139e-08, 1.000000e+00], dtype=float32)

# Take the derivative of rotation around the x-axis w.r.t. theta at [0, 1, 0]:
jax.jacfwd(Rx)(np.pi / 2)(
    jnp.array([0.0, 1.0, 0.0])
)  # => TypeError: Value <function Rx.<locals>.f at 0x127a16520> with type <class 'function'> is not a valid JAX type
```

Can I make this work? I would like to differentiate the rotation w.r.t. theta, or in general, apply the
The problem is that `Rx` returns a function, and `jax.jacfwd` can only differentiate functions whose outputs are arrays (or pytrees of arrays), which is what the `TypeError` is telling you. I slightly modified the code to make it work by passing the point coordinates as extra arguments:

```python
import jax
import jax.numpy as jnp


def Rx(theta, x, y, z):
    """Rotation around the x-axis by angle theta"""
    ct = jnp.cos(theta)
    st = jnp.sin(theta)
    out = jnp.array([x, ct * y - st * z, st * y + ct * z])
    return out


print(Rx(jnp.pi / 2, 0.0, 1.0, 0.0))
# => Array([ 0.000000e+00, -4.371139e-08, 1.000000e+00], dtype=float32)

# Take the derivative of rotation around the x-axis w.r.t. theta (argnums=0) at [0, 1, 0]:
print(jax.jacfwd(Rx, argnums=0)(jnp.pi / 2, 0.0, 1.0, 0.0))
```

Let me know if this helps.
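As a sanity check (not part of the original reply), the result of `jax.jacfwd` can be compared against the derivative worked out by hand: differentiating `[x, cos(θ)·y − sin(θ)·z, sin(θ)·y + cos(θ)·z]` with respect to θ gives `[0, −sin(θ)·y − cos(θ)·z, cos(θ)·y − sin(θ)·z]`. The helper `dRx_dtheta` below is a hypothetical name introduced for this comparison, and the `jax.vmap` over several angles is an extra illustration, assuming you want the derivative at many values of theta at once.

```python
import jax
import jax.numpy as jnp


def Rx(theta, x, y, z):
    """Rotation around the x-axis by angle theta (same as in the reply above)."""
    ct = jnp.cos(theta)
    st = jnp.sin(theta)
    return jnp.array([x, ct * y - st * z, st * y + ct * z])


def dRx_dtheta(theta, x, y, z):
    """Hypothetical helper: hand-derived d(Rx)/d(theta), for comparison only."""
    ct = jnp.cos(theta)
    st = jnp.sin(theta)
    return jnp.array([0.0, -st * y - ct * z, ct * y - st * z])


theta = jnp.pi / 2
auto = jax.jacfwd(Rx, argnums=0)(theta, 0.0, 1.0, 0.0)
manual = dRx_dtheta(theta, 0.0, 1.0, 0.0)
print(jnp.allclose(auto, manual))  # True

# Batch over angles: map over theta only (in_axes=0), keep the point fixed (None).
thetas = jnp.linspace(0.0, 2 * jnp.pi, 8)
batched = jax.vmap(jax.jacfwd(Rx, argnums=0), in_axes=(0, None, None, None))(
    thetas, 0.0, 1.0, 0.0
)
batched_manual = jax.vmap(dRx_dtheta, in_axes=(0, None, None, None))(
    thetas, 0.0, 1.0, 0.0
)
print(batched.shape)  # one derivative vector per angle
```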
If you don't want to restructure your closure definitions, you can compute the derivative like this:

```python
# Rx here is the closure-returning version from the question.
result = jax.jacfwd(lambda theta: Rx(theta)(0.0, 1.0, 0.0))(np.pi / 2)
print(result)
# [ 0.000000e+00 -1.000000e+00 -4.371139e-08]
```

There's no built-in JAX transform to do autodiff of a function that returns a function, but you could use a short wrapper to define this behavior if you wish:

```python
def fun_jacfwd(fun, *args, **kwargs):
    # args are the outer function's arguments (e.g. theta);
    # kwargs (e.g. argnums) are forwarded to jax.jacfwd.
    def df(*df_args, **df_kwargs):
        # df_args/df_kwargs are the arguments of the returned inner function.
        return jax.jacfwd(lambda *a: fun(*a)(*df_args, **df_kwargs), **kwargs)(*args)
    return df


Rx_prime = fun_jacfwd(Rx, np.pi / 2)
print(Rx_prime(0.0, 1.0, 0.0))
# [-0.000000e+00 -1.000000e+00 -4.371139e-08]
```
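Another option, not taken from the replies above: if what you ultimately want is the derivative of the rotation itself rather than of one rotated point, you can write the rotation as a 3×3 matrix. `jax.jacfwd` of a function that maps a scalar to a (3, 3) array returns another (3, 3) array, dR/dθ, which you can then apply to any point. `Rx_matrix` is a hypothetical name, and this sketch assumes the standard x-axis rotation matrix, which matches the component formulas used in `Rx` throughout this thread.

```python
import jax
import jax.numpy as jnp


def Rx_matrix(theta):
    """Hypothetical helper: rotation around the x-axis as a 3x3 matrix."""
    ct = jnp.cos(theta)
    st = jnp.sin(theta)
    return jnp.array([
        [1.0, 0.0, 0.0],
        [0.0, ct, -st],
        [0.0, st, ct],
    ])


# For a scalar input, jacfwd returns an array with the output's shape:
# the element-wise derivative d(Rx_matrix)/d(theta).
dRx_matrix = jax.jacfwd(Rx_matrix)

theta = jnp.pi / 2
dR = dRx_matrix(theta)  # shape (3, 3)

# The derivative of a rotated point p is dR @ p.
p = jnp.array([0.0, 1.0, 0.0])
print(dR @ p)

# dR is a plain array, so it can be applied to many points at once:
# row i of the result is dR @ points[i].
points = jnp.array([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 2.0, 3.0]])
print(points @ dR.T)
```

The trade-off is that the matrix form only covers linear maps like rotations, whereas the closure wrappers above work for any function the closure returns.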