How to write a custom Jaxpr interpreter that returns a Pytree? #28800
I have been following this tutorial on writing custom Jaxpr interpreters, and I don't really understand how to make a custom function transformation return a Pytree instead of a plain list. As I understand it, this is roughly how one could implement a custom transformation:

```python
import jax
from functools import wraps

def transformation(fun):
    @wraps(fun)
    def wrapped(*args, **kwargs):
        # Trace fun into a ClosedJaxpr for these example inputs
        closed_jaxpr = jax.make_jaxpr(fun)(*args, **kwargs)
        # eval_jaxpr returns a flat list of outputs, so the dict structure is lost here
        out = jax.core.eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)
        return out
    return wrapped

def f(x):
    y = x + 1
    return {"x": x, "y": y}

print(f(1.0))  # {'x': 1.0, 'y': 2.0}
print(transformation(f)(1.0))  # [1.0, 2.0]
```

The simplest way would be to just unflatten everything at the end, but for that I would need the output structure, which I don't think I can get from the Jaxpr. Does anyone know how this could be done?
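You want `jax.make_jaxpr(..., return_shape=True)`. With that flag, the function returned by `make_jaxpr` gives you a pair: the ClosedJaxpr, plus a pytree with the same structure as `fun`'s output whose leaves describe the shape and dtype of each output. Take the tree structure of that second element and use it to unflatten the flat list of outputs you get from evaluating the jaxpr.

If you want a complete reference point for the above, `eqx.filter_closure_convert` gives an example of forming a jaxpr and then evaluating it: https://github.com/patrick-kidger/equinox/blob/972ff6a02c5251c6953f5116f38f9842bd585180/equinox/_ad.py#L589-L712 It's a bit wordy, as it also handles the general case of static metadata etc., but you can see the overall structure: evaluate the jaxpr, then unflatten the tree. (There are probably similar examples inside JAX's own codebase as well, but I don't know them as well.)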
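A minimal sketch of the question's `transformation` with the `return_shape=True` approach applied. It assumes a JAX version that still exposes `jax.core.eval_jaxpr`, as the question's code does. Flattening the inputs with `jax.tree_util.tree_leaves((args, kwargs))` is my own addition, on the assumption that the jaxpr's input variables follow that flattening order; `g` is a made-up function used only to show pytree inputs and a keyword argument.

```python
from functools import wraps

import jax


def transformation(fun):
    @wraps(fun)
    def wrapped(*args, **kwargs):
        # return_shape=True also returns a pytree mirroring fun's output,
        # whose leaves describe each output's shape and dtype.
        closed_jaxpr, out_shape = jax.make_jaxpr(fun, return_shape=True)(*args, **kwargs)
        # Assumption: the jaxpr's invars line up with the flattened leaves of (args, kwargs).
        flat_args = jax.tree_util.tree_leaves((args, kwargs))
        # eval_jaxpr returns a flat list with one entry per output leaf.
        flat_out = jax.core.eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.consts, *flat_args)
        # Rebuild the original output structure from the shape pytree.
        out_tree = jax.tree_util.tree_structure(out_shape)
        return jax.tree_util.tree_unflatten(out_tree, flat_out)

    return wrapped


def f(x):
    y = x + 1
    return {"x": x, "y": y}


# Hypothetical example with pytree inputs and a keyword argument.
def g(params, scale=2.0):
    return (params["a"] * scale, [params["b"] - 1.0])


out = transformation(f)(1.0)
print(out)  # a dict with keys 'x' and 'y' instead of a flat list

out2 = transformation(g)({"a": 1.0, "b": 3.0}, scale=3.0)
print(out2)  # a tuple whose second element is a one-item list
```

Because the output structure comes from the same trace that produced the jaxpr, there is no need to call the function a second time (for example with `jax.eval_shape`) just to recover it.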