Targeting JAX with a domain-specific compiler #30141
Hello, I would like to use JAX as the target for the compiler of a domain-specific language, and I would like to know the best way to do that. To give an example, an expression in the DSL of interest (called UFL) could be $\int f \, v \, \mathrm{d}x$. In code this would look something like:

    f = Coefficient(...)
    v = TestFunction(...)
    expr = f * v * dx

We would like to be able to generate JAX code from this, which would look something like:

    def compiled_expr(indata, outdata, gather_map, scatter_map):
        indata_packed = indata[gather_map]
        # 'f' represents a collection of JAX operations such as einsum, pointwise additions,
        # and transcendental functions like sin and cos
        outdata_packed = f(indata_packed)
        return outdata.at[scatter_map].add(outdata_packed)

Currently the only approach I can think of for doing this is to construct the above function as a string, at which point one can do:

    namespace = {}
    # exec rather than eval: a def is a statement, not an expression
    exec("""def compiled_expr(...):
        ...""", namespace)
    jitted = jax.jit(namespace["compiled_expr"])

but it would be much nicer if we could construct the expression directly in the IR itself. For example, instead of producing …
Replies: 1 comment 1 reply
This looks really interesting! We don't have any way to programmatically build an IR in a step-by-step fashion like this, but I wonder if it would be possible to do so via normal JAX operations in your context? The supported way to build an expression tree (i.e. a jaxpr) is by wrapping a function with an abstract evaluator like `jax.make_jaxpr`. Could you perhaps use this approach with the routine that is meant to generate the JAX expression?
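
As a rough illustration of that suggestion, here is a minimal sketch in which the compiler emits Python closures over `jax.numpy` operations instead of source strings, and `jax.make_jaxpr` then traces the result into a jaxpr. The tuple-based `expr_tree` format and the `lower` and `make_kernel` helpers are hypothetical stand-ins for a real DSL's IR and lowering pass; only `jax.make_jaxpr`, `jax.jit`, `jnp.sin`, and the `.at[...].add` update come from JAX itself.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a DSL expression tree: nested tuples of
# ("op", *children). A real compiler would walk its own IR instead.
expr_tree = ("mul", ("sin", ("var",)), ("var",))


def lower(node):
    """Return a Python function that evaluates `node` with JAX operations."""
    op, *args = node
    if op == "var":
        return lambda x: x
    if op == "sin":
        (inner,) = map(lower, args)
        return lambda x: jnp.sin(inner(x))
    if op == "mul":
        lhs, rhs = map(lower, args)
        return lambda x: lhs(x) * rhs(x)
    raise ValueError(f"unknown op: {op}")


def make_kernel(tree):
    """Wrap the lowered expression in the gather/compute/scatter pattern."""
    f = lower(tree)

    def compiled_expr(indata, outdata, gather_map, scatter_map):
        indata_packed = indata[gather_map]
        outdata_packed = f(indata_packed)
        return outdata.at[scatter_map].add(outdata_packed)

    return compiled_expr


kernel = make_kernel(expr_tree)

# Small illustrative inputs (made up for this sketch).
indata = jnp.arange(4.0)
outdata = jnp.zeros(3)
gather_map = jnp.array([0, 1, 3])
scatter_map = jnp.array([0, 0, 2])

# Trace the closure into a jaxpr without building any source strings.
jaxpr = jax.make_jaxpr(kernel)(indata, outdata, gather_map, scatter_map)
print(jaxpr)

# The same function can be handed straight to jax.jit.
result = jax.jit(kernel)(indata, outdata, gather_map, scatter_map)
```

Because tracing only sees the JAX operations that the closures execute, the Python-level tree walk happens once at trace time and does not appear in the jaxpr. Whether this fits a given DSL depends on whether its lowering can be expressed as ordinary Python calls.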