Hi, this is a bit of a meandering "open discussion" question, but it includes some technical questions about how I should think about the compilation model of JAX/XLA.
I ask because I've been curious for quite some time about exposing program transformations as library code for new programming languages (it started with a fascination with the way Julia performs type inference for optimization). Obviously, JAX is a practical tool: it's focused on accelerating programs that fit a specific picture (the one compatible with XLA), and on reaching as wide an audience as possible (people writing numerical Python). Setting that focus on practicality aside, one question I've often asked myself is: what if you took a principled IR (some dialect of MLIR), potentially one supporting higher-order features, and built on top of it the interpreter-as-transformer idiom expressed in …?
I think there are some interesting language design benefits if the compiler middleware described above is constructed from first principles. For one, JAX is obviously focused on XLA, but pulling out the interpreter idiom means your new toolchain doesn't necessarily need to be. For example, I could imagine lowering analyses that identify when code is compatible with the XLA dialect, versus when it must fall back on LLVM (or another runtime-specific) toolchain.
I'm also curious how much this approach "recovers Julia", modulo the incredible engineering required to support dynamic dispatch and specialization at runtime. Also: how much benefit would there be to implementing the interpreter/transformer stack in a low-level language, rather than relying on Python's interpreter to run it? Is tracing a bad bottleneck? Is XLA a significantly worse one?
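To make the "interpreter-as-transformer" idea concrete, here is a minimal sketch over a toy Jaxpr-like IR. All names (Eqn, Program, Tracer, trace, interpret, EVAL_RULES, JVP_RULES, lowerable_to) are hypothetical and are not JAX's API. Tracing records primitive applications into an IR; each "transformation" is then a different table of per-primitive rules run over the same IR, and a lowering analysis is just a check that every primitive has a rule for the target. One simplification to note: this stages the whole program first and interprets it afterwards, whereas JAX roughly composes transformations by nesting interpreters during tracing. Constants and non-traced operands are not supported here.

```python
import math
from dataclasses import dataclass, field

# Toy Jaxpr-like IR (hypothetical): each equation applies a named primitive
# to input variables and binds one output variable.
@dataclass
class Eqn:
    prim: str
    inputs: list
    output: str

@dataclass
class Program:
    params: list
    eqns: list = field(default_factory=list)
    result: str = ""

class Tracer:
    """Stands in for a value; records primitive applications instead of computing them."""
    def __init__(self, prog, name):
        self.prog, self.name = prog, name

    def emit(self, prim, *args):
        out = f"v{len(self.prog.eqns)}"
        self.prog.eqns.append(Eqn(prim, [a.name for a in args], out))
        return Tracer(self.prog, out)

    def __add__(self, other):
        return self.emit("add", self, other)

    def __mul__(self, other):
        return self.emit("mul", self, other)

def sin(x):
    return x.emit("sin", x)

def trace(f, nargs):
    """Stage f out into a Program by calling it on Tracers."""
    prog = Program(params=[f"a{i}" for i in range(nargs)])
    prog.result = f(*[Tracer(prog, p) for p in prog.params]).name
    return prog

# Plain evaluation: one rule per primitive.
EVAL_RULES = {
    "add": lambda x, y: x + y,
    "mul": lambda x, y: x * y,
    "sin": math.sin,
}
# Forward-mode JVP: the same primitives, but rules act on (primal, tangent) pairs.
JVP_RULES = {
    "add": lambda x, y: (x[0] + y[0], x[1] + y[1]),
    "mul": lambda x, y: (x[0] * y[0], x[0] * y[1] + x[1] * y[0]),
    "sin": lambda x: (math.sin(x[0]), math.cos(x[0]) * x[1]),
}

def interpret(prog, rules, args):
    """Run the IR with a given rule table; the choice of table is the 'transformation'."""
    env = dict(zip(prog.params, args))
    for eqn in prog.eqns:
        env[eqn.output] = rules[eqn.prim](*(env[v] for v in eqn.inputs))
    return env[prog.result]

def lowerable_to(prog, supported_prims):
    """Toy lowering analysis: can every primitive be handled by this backend?"""
    return all(eqn.prim in supported_prims for eqn in prog.eqns)

prog = trace(lambda x, y: sin(x * y) + x, 2)
print(interpret(prog, EVAL_RULES, [2.0, 3.0]))               # sin(x*y) + x at (2, 3)
print(interpret(prog, JVP_RULES, [(2.0, 1.0), (3.0, 0.0)]))  # (value, d/dx at (2, 3))
print(lowerable_to(prog, {"add", "mul"}))                     # False: "sin" has no rule
```

The point of the design is that adding a transformation means adding a rule table, not touching the tracer, and a "can this go to backend X, or must it fall back?" decision becomes a simple pass over the staged IR.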
Jaxpr has been repurposed for use cases beyond AD and XLA. Pallas, for example, is a lowering from Jaxpr to Triton IR.
I think that, although the set of JAX primitives was designed without non-XLA lowerings in mind, it's easy enough to extend that set to enable alternative lowerings.
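Extending the primitive set for an alternative lowering can be sketched with a per-backend registry of lowering rules, loosely analogous to how JAX attaches lowering rules to each primitive. Everything here (Eqn, register_lowering, lower, the "xla_like" and "triton_like" backend names, and the load_block primitive) is hypothetical, and the emitted strings only imitate target syntax; this is not JAX's or Pallas's actual API.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple

@dataclass(frozen=True)
class Eqn:
    prim: str
    inputs: Tuple[str, ...]
    output: str

# A lowering rule turns one equation (inputs, output name) into a line of target code.
LoweringRule = Callable[[Tuple[str, ...], str], str]
LOWERINGS: Dict[str, Dict[str, LoweringRule]] = {}

def register_lowering(backend: str, prim: str):
    """Decorator registering a rule as the lowering of `prim` on `backend`."""
    def deco(rule: LoweringRule) -> LoweringRule:
        LOWERINGS.setdefault(backend, {})[prim] = rule
        return rule
    return deco

# An existing primitive gets one rule per backend.
@register_lowering("xla_like", "add")
def _add_xla(ins, out):
    return f"{out} = add({ins[0]}, {ins[1]})"

@register_lowering("triton_like", "add")
def _add_triton(ins, out):
    return f"{out} = {ins[0]} + {ins[1]}"

# A hypothetical new primitive that only has a non-XLA lowering.
@register_lowering("triton_like", "load_block")
def _load_block_triton(ins, out):
    return f"{out} = tl.load({ins[0]})"

def lower(eqns: List[Eqn], backend: str) -> List[str]:
    """Lower every equation for `backend`, failing if any primitive lacks a rule."""
    rules = LOWERINGS.get(backend, {})
    missing = sorted({e.prim for e in eqns if e.prim not in rules})
    if missing:
        raise NotImplementedError(f"no {backend} lowering for {missing}")
    return [rules[e.prim](e.inputs, e.output) for e in eqns]

prog = [Eqn("load_block", ("ptr",), "x"), Eqn("add", ("x", "x"), "y")]
print("\n".join(lower(prog, "triton_like")))  # x = tl.load(ptr) / y = x + x
```

Calling lower(prog, "xla_like") raises NotImplementedError, because load_block has no rule for that backend. Adding a primitive only requires registering rules for the backends that can support it; programs that avoid it still lower everywhere.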