Possible improvement for translation from JAXPR to MLIR #26667
Hello, I was wondering if anyone has profiled the time it takes to convert JAXPR to MLIR. I recently noticed a commit that hashes JAXPR literals and reuses them. Cherry-picking this commit onto an old JAX version showed a decent improvement in conversion time.

Since JAXPR is functional, it seems this could be extended to JAXPR equations: if the primitive and operands are the same, the same equation could be reused in multiple places (i.e., common subexpression elimination). Instead of applying CSE at the MLIR stage, would it be advantageous to simply generate less JAXPR? That is, rather than spending compile time on CSE in MLIR, tracing could generate better code up front, which would save time both in the MLIR translation and in the MLIR CSE pass (there would be less work to do, although since that pass is C++ the saving there may be negligible).

Similarly, in the JAXPR to MLIR translation layer, before generating a new operation one could check whether an equivalent one already exists in the current context. E.g., instead of:
tracing could generate:
Of course, there are trade-offs, such as keeping all hashed equations around and checking for equivalence before generating each new one.
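To make the reuse idea concrete, here is a minimal sketch in plain Python. It does not use JAX internals: `Var`, `Eqn`, and `Builder` are hypothetical names invented for illustration, and the sketch assumes every primitive is pure and every operand is hashable. It only shows the general technique (looking up an equation by its primitive and operands before emitting it), not how JAX implements anything.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Var:
    """Hypothetical stand-in for a traced variable."""
    name: str


@dataclass(frozen=True)
class Eqn:
    """Hypothetical stand-in for one equation: out = prim(*inputs)."""
    prim: str
    inputs: tuple
    out: Var


class Builder:
    """Toy equation builder that reuses equivalent equations.

    Assumption: all primitives are pure, so two equations with the same
    primitive and the same operands always produce the same value.
    """

    def __init__(self):
        self.eqns = []      # equations actually emitted, in order
        self._cache = {}    # (prim, inputs) -> output Var of an earlier equation
        self._counter = 0

    def emit(self, prim, *inputs):
        key = (prim, inputs)
        cached = self._cache.get(key)
        if cached is not None:
            # An equivalent equation already exists: reuse its output.
            return cached
        out = Var(f"v{self._counter}")
        self._counter += 1
        self.eqns.append(Eqn(prim, inputs, out))
        self._cache[key] = out
        return out


b = Builder()
x = Var("x")
a = b.emit("sin", x)
c = b.emit("sin", x)     # same primitive and operand: no new equation
d = b.emit("mul", a, c)  # both operands refer to the single sin result
```

After this runs, `b.eqns` holds two equations instead of three. The cache is also where the trade-off mentioned above shows up: it grows with every distinct equation, and each `emit` pays for a hash and a lookup. Equations that carry side effects would need to be excluded from this kind of reuse.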
Replies: 1 comment
I'm marking this as resolved via 618460e