-
I would like to implement a custom JVP and VJP, including higher-order derivatives and batching support, for a linear function for which JAX's automatically generated backward pass performs very poorly. Due to the (currently) limited scope of [remainder of this sentence is truncated in the source]
import jax
from jax import core
from jax import numpy as jnp
from jax._src import api
from jax.interpreters import mlir
# The custom JAX primitive named "los"; its impl, abstract-eval and lowering
# rules are attached to this object via def_impl / def_abstract_eval /
# mlir.register_lowering.
los_p = core.Primitive("los")


def los(*args, **params):
    """User-facing entry point: bind the ``los`` primitive to ``args``."""
    return los_p.bind(*args, **params)
def los_impl(x, y, z):
    """Concrete (eager) implementation of the ``los`` primitive.

    Computes the elementwise expression ``x * y + z`` with ``jax.numpy`` ops,
    so the result is a JAX array even when the inputs are Python scalars.
    """
    scaled = jnp.multiply(x, y)
    return jnp.add(scaled, z)
def los_abstract_eval(xs, ys, zs):
    """Abstract evaluation rule for ``los``.

    Instead of hand-deriving broadcasting and dtype promotion, shape-evaluate
    the concrete implementation on the abstract inputs. The resulting shape
    and dtype are then packaged as a ``ShapedArray``.
    """
    out_struct = jax.eval_shape(los_impl, xs, ys, zs)
    out_shape, out_dtype = out_struct.shape, out_struct.dtype
    return core.ShapedArray(out_shape, out_dtype)
def los_lowering(ctx, xc, yc, zc):
    """Lowering rule for ``los``: build MLIR by lowering ``los_impl``.

    ``mlir.lower_fun`` converts a Python implementation into a lowering rule.
    Reusing ``los_impl`` keeps the compiled path consistent with the eager
    implementation.

    Args:
      ctx: The lowering-rule context supplied by JAX.
      xc, yc, zc: MLIR values for the three operands.

    Returns:
      The MLIR result values produced by the lowered implementation.
      ``multiple_results=False`` because ``los_impl`` returns one array.
    """
    # A lowering rule must return its result values. The previous version
    # lowered a separate jitted computation and then fell off the end,
    # returning None.
    lowered = mlir.lower_fun(los_impl, multiple_results=False)
    return lowered(ctx, xc, yc, zc)
# Register the primal implementation and abstract-evaluation rule with JAX.
los_p.def_impl(los_impl)
los_p.def_abstract_eval(los_abstract_eval)
# Register the lowering rule with JAX.
mlir.register_lowering(los_p, los_lowering)

# Smoke tests. The first exercises eager evaluation (the impl rule). The
# second exercises jit, which needs the abstract-eval and lowering rules.
assert los(1, 2, 3) == 5
assert api.jit(lambda x, y, z: los(x, y, z))(2., 10., 3.) == 23.
Beta Was this translation helpful? Give feedback.
Replies: 1 comment 1 reply
-
The easiest way to create a lowering rule that calls a Python impl is to use the `mlir.lower_fun` utility:

```python
los_lowering = mlir.lower_fun(los_impl, multiple_results=False)
```
Beta Was this translation helpful? Give feedback.
The easiest way to create a lowering rule that calls a Python impl is to use the `mlir.lower_fun` utility: