
How to write a custom Jaxpr interpreter that returns a Pytree? #28800

Answered by patrick-kidger
huterguier asked this question in Q&A

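You want jax.make_jaxpr(..., return_shape=True). With that flag, make_jaxpr returns a pair: the ClosedJaxpr, plus a pytree with the same structure as the function's output whose leaves are jax.ShapeDtypeStruct objects. After evaluating the jaxpr to get a flat list of outputs, you can rebuild the original pytree with jax.tree_util.tree_unflatten(jax.tree_util.tree_structure(out_shape), flat_outputs).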
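A small sketch of what return_shape=True gives you. The function f and its dict/tuple output are illustrative assumptions. The jaxpr itself only has a flat list of outvars, while out_shape carries the tree structure needed to put them back together.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    # The output is a nested pytree (a dict containing a tuple), not a flat tuple.
    return {"sum": x + y, "pair": (x * y, jnp.sin(x))}


x = jnp.ones(3)
y = jnp.arange(3.0)

# return_shape=True: returns (ClosedJaxpr, pytree of jax.ShapeDtypeStruct).
closed_jaxpr, out_shape = jax.make_jaxpr(f, return_shape=True)(x, y)

# The jaxpr only knows about flat outputs...
print(len(closed_jaxpr.jaxpr.outvars))  # 3

# ...while out_shape mirrors f's output structure, so its treedef can be
# used later to unflatten whatever flat outputs an interpreter produces.
out_tree = jax.tree_util.tree_structure(out_shape)
print(out_tree.num_leaves)  # 3
print(out_shape["sum"])
```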

If you want a complete reference point for the above, eqx.filter_closure_convert gives an example of forming a jaxpr and then evaluating it: https://github.com/patrick-kidger/equinox/blob/972ff6a02c5251c6953f5116f38f9842bd585180/equinox/_ad.py#L589-L712 It's a bit wordy, as it also handles the general case of static metadata etc., but you can see the overall structure: evaluate the jaxpr, then unflatten the tree.
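A minimal sketch of that overall structure with a hand-written interpreter loop. It loosely follows the pattern from JAX's "Writing custom Jaxpr interpreters" tutorial and is not Equinox's implementation. The helper name eval_jaxpr_as_pytree is hypothetical, and it assumes every equation's primitive can simply be re-bound with its params (higher-order primitives with nested jaxprs may need more care). The import location of Literal has moved between JAX versions, hence the fallback.

```python
import jax
import jax.numpy as jnp

try:
    from jax.extend.core import Literal  # newer JAX versions
except ImportError:
    from jax.core import Literal  # older JAX versions


def eval_jaxpr_as_pytree(fun, *args):
    """Hypothetical helper: trace fun, interpret its jaxpr, return a pytree."""
    closed_jaxpr, out_shape = jax.make_jaxpr(fun, return_shape=True)(*args)
    jaxpr = closed_jaxpr.jaxpr
    env = {}

    def read(v):
        # Literals (e.g. the 2.0 in `x * 2.0`) carry their value inline.
        return v.val if isinstance(v, Literal) else env[v]

    def write(v, val):
        env[v] = val

    # Bind closed-over constants, then the flattened positional arguments
    # (make_jaxpr's invars follow the leaf order of the args pytree).
    for var, const in zip(jaxpr.constvars, closed_jaxpr.consts):
        write(var, const)
    for var, leaf in zip(jaxpr.invars, jax.tree_util.tree_leaves(args)):
        write(var, leaf)

    # Evaluate equations in order; a custom interpreter would intercept
    # or replace specific primitives here instead of re-binding them.
    for eqn in jaxpr.eqns:
        invals = [read(v) for v in eqn.invars]
        outvals = eqn.primitive.bind(*invals, **eqn.params)
        if not eqn.primitive.multiple_results:
            outvals = [outvals]
        for var, val in zip(eqn.outvars, outvals):
            write(var, val)

    flat_out = [read(v) for v in jaxpr.outvars]
    # Rebuild the original output structure from the return_shape pytree.
    out_tree = jax.tree_util.tree_structure(out_shape)
    return jax.tree_util.tree_unflatten(out_tree, flat_out)


# Usage: a function returning a nested pytree.
def f(x, y):
    return {"sum": x + y, "pair": (x * 2.0, jnp.sin(y))}


x = jnp.ones(3)
y = jnp.arange(3.0)
out = eval_jaxpr_as_pytree(f, x, y)
print(out["pair"][0])
```

The unflatten step is the only part that depends on return_shape=True; the loop in the middle is where the actual custom interpretation goes.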

(There are probably similar examples inside JAX's own codebase too, but I'm less familiar with them.)

Answer selected by huterguier