Hi! I am trying to add a new primitive to JAX and find myself stuck on the following issues.

Q1) What is the easiest way to do this? Using custom_jvp or custom_vjp seems to require using existing JAX primitives.

Q2) I also wrote the JVP in C++, and it works, but I do not understand how to write the transpose rule. Does the transpose rule even make sense if the JVP is a C++ function?

Q3) Also, is it okay if I cannot JIT it for the moment?

Thank you!
-
Thanks for the questions!

Q1) Right, custom_jvp/vjp aren't the right tool for this job. You want to define your own Primitive. This tutorial is the best documentation on how to do that. A few more comments about this below.

Q2) An opaque-to-JAX C++ JVP rule for your primitive can make sense; it's just another primitive, where some inputs and outputs are linear. See the example below.

Q3) Yes. If you want to JIT it, you can set up a CustomCall, as described in the tutorial linked above.

Here's an example which may make sense. In it, I'll use ordinary NumPy imported as onp to perform some operations which are totally opaque to JAX; as far as JAX is concerned, those are just calls into opaque foreign functions.

```python
import jax
import numpy as onp

# Make a Primitive
foo_p = jax.core.Primitive('foo')

def foo(x):
  return foo_p.bind(x)

# Give it an impl rule, which just accepts as input a raw array-like value (no
# tracers or anything) and applies opaque foreign functions to produce a raw
# array-like value output.
@foo_p.def_impl
def foo_impl(x_arr):
  return onp.sin(onp.sin(x_arr))

# Give it a JVP rule, which trivially just calls another primitive we'll define.
from jax.interpreters import ad

def foo_jvp(primals, tangents):
  (x,), (xdot,) = primals, tangents
  y = foo(x)
  y_dot = foo_jvp_p.bind(x, xdot)
  return y, y_dot
ad.primitive_jvps[foo_p] = foo_jvp

foo_jvp_p = jax.core.Primitive('foo_jvp')

# We could define an impl rule for foo_jvp_p, and thus get first-order
# forward-mode AD working. But if we only care about reverse-mode, we actually
# don't need one; instead we need an abstract eval rule and a transpose rule.
# The transpose rule can do the full VJP calculation, and can itself call an
# opaque primitive.
@foo_jvp_p.def_abstract_eval
def foo_jvp_abstract_eval(x_aval, x_dot_aval):
  y_dot_aval = jax.core.ShapedArray(x_dot_aval.shape, x_dot_aval.dtype)
  return y_dot_aval

def foo_jvp_transpose(y_bar, x, x_dot_dummy):
  assert ad.is_undefined_primal(x_dot_dummy)  # just a dummy input
  x_bar = foo_vjp_p.bind(x, y_bar)  # y_bar aka y_grad
  return None, x_bar  # None for nonlinear primal input x
ad.primitive_transposes[foo_jvp_p] = foo_jvp_transpose

# Finally, let's write the vjp rule as a primitive.
foo_vjp_p = jax.core.Primitive('foo_vjp')

@foo_vjp_p.def_impl
def foo_vjp_impl(x, y_bar):
  return y_bar * onp.cos(onp.sin(x)) * onp.cos(x)

###
# Let's test it!
print(jax.grad(foo)(3.))
```

We're working on making the FFI story simpler, to reduce boilerplate. As part of that, we should probably make it easier to handle this common case of wanting to write a primal computation in C++ as well as a VJP rule. cc @sharadmv @zhangqiaorjc
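As the comment in the example says, an impl rule for foo_jvp_p would also enable first-order forward-mode AD. Here's a minimal sketch of what that could look like; the NumPy expression for d/dx sin(sin(x)) stands in for the opaque C++ JVP you would actually call, and foo_jvp_impl is just an illustrative name.

```python
# A minimal sketch: give foo_jvp_p an impl rule so jax.jvp on foo also works
# eagerly. In a real setup this body would call the opaque C++ JVP; plain
# NumPy stands in for it here.
@foo_jvp_p.def_impl
def foo_jvp_impl(x, x_dot):
  # d/dx sin(sin(x)) = cos(sin(x)) * cos(x)
  return onp.cos(onp.sin(x)) * onp.cos(x) * x_dot

# Forward mode now works too (outside of jit), and the tangent agrees with
# the reverse-mode gradient printed above:
print(jax.jvp(foo, (3.,), (1.,)))
```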
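On Q3: when you do want jit to work, foo_p itself will also need an abstract eval rule (jit tracing only sees shapes and dtypes, not values) plus a lowering that emits the CustomCall into your C++ kernel, as in the tutorial. Here's a minimal sketch of just the abstract eval half, assuming the same shape/dtype-preserving behavior as the example; the CustomCall lowering itself isn't shown, and foo_abstract_eval is an illustrative name.

```python
# A minimal sketch of the abstract eval rule that jit tracing needs for foo_p.
# A lowering that emits the CustomCall to the C++ kernel is still required on
# top of this (see the tutorial linked above).
@foo_p.def_abstract_eval
def foo_abstract_eval(x_aval):
  return jax.core.ShapedArray(x_aval.shape, x_aval.dtype)
```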