I'm just wondering whether we can ignore computations that are unrelated to the output of a jitted function?
corresponds to the following jaxpr
The computations about the variable |
Thanks for the question! The short answer is: use jit, and unnecessary computations will be pruned by XLA as part of the compilation process. Note that because this happens at the XLA level, the dead code elimination (DCE) will not be reflected in the printed jaxpr. There's been some discussion in the past about also doing DCE at the jaxpr level; you can read more at #4818.

If you want to see the compiled HLO with the dead code eliminated, you can do something like this (if foo takes arguments, pass them to lower()):

print(jax.jit(foo).lower().compile().as_text())
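The pruning described in this answer can be checked with a small, self-contained example. This is a sketch assuming a recent JAX release where jax.make_jaxpr, jax.jit(...).lower(...).compile() and Compiled.as_text() are available. The function foo below is a hypothetical stand-in, not the one from the question: it computes jnp.sin(x) * 2.0 but only returns x + 1.0. The jaxpr still lists the sin and mul operations, while the optimized HLO printed after compilation should no longer contain a sine op. Depending on the JAX version, some of this pruning may also happen before XLA sees the program; either way, the traced jaxpr keeps the dead code.

```python
import jax
import jax.numpy as jnp


def foo(x):
    # Hypothetical example: `unused` never contributes to the return value,
    # so it is dead code from the compiler's point of view.
    unused = jnp.sin(x) * 2.0
    return x + 1.0


x = jnp.arange(3.0)

# The jaxpr is produced by tracing, before compilation, so it still
# contains the `sin` and `mul` equations for `unused`.
jaxpr_text = str(jax.make_jaxpr(foo)(x))
print(jaxpr_text)

# Lower with a concrete argument (foo takes one), compile, and print the
# optimized HLO; the dead `sine` op should no longer appear here.
compiled_text = jax.jit(foo).lower(x).compile().as_text()
print(compiled_text)

# The jitted result only depends on the live computation.
print(jax.jit(foo)(x))
```

Comparing the two printouts is a quick way to confirm that removing unused work is left to the compiler rather than to the jaxpr trace.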