
How to prune the computations unrelated to the output? #12712

Answered by jakevdp
llan-ml asked this question in General

Thanks for the question! The short answer is: use `jit`, and computations that do not contribute to the output will be pruned by XLA as part of the compilation process. Note that because this happens at the XLA level, the dead code elimination (DCE) will not be reflected in the printed jaxpr. There has been some discussion in the past about also doing DCE at the jaxpr level; you can read more at #4818.

If you want to see the compiled HLO with the dead code eliminated, you can do something like this:

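```python
import jax

# lower() takes the same (example) arguments you would pass to foo
print(jax.jit(foo).lower().compile().as_text())
```

Output (truncated):

```
%region_0.7 (Arg_0.8: f32[], Arg_1.9: f32[]) -> f32[] {
  %Arg_0.8 = f32[] parameter(0)
  %Arg_1.9 = f32[] parameter(1)
  ROOT %add.10 = f32[] add(f32[] %Arg_0.8, f32[] %Arg_…
```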
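A minimal, self-contained sketch of the same check follows. It assumes a toy `foo` (hypothetical, not the function from the question) that computes `jnp.sin(x)` but never returns it. The traced jaxpr should still contain the unused `sin`, while the compiled HLO (printed after XLA's optimization passes) should not contain any `sine` op. The string check on the HLO text is a rough heuristic, not an official API for detecting DCE.

```python
import jax
import jax.numpy as jnp


def foo(x):
    # Hypothetical example: this value is computed but never used,
    # so it is dead code with respect to the output.
    _unused = jnp.sin(x) * 2.0
    # Only this value reaches the output.
    return jnp.sum(x)


x = jnp.arange(4.0)

# The jaxpr is a trace of the Python function, so it still records the
# unused sin/mul operations (no DCE at this level).
print(jax.make_jaxpr(foo)(x))

# The compiled HLO is produced after XLA's optimization passes, which
# include dead code elimination, so the sin should be gone.
compiled = jax.jit(foo).lower(x).compile()
hlo_text = compiled.as_text()
print(hlo_text)
print("sine op present in compiled HLO:", "sine" in hlo_text)

print(jax.jit(foo)(x))
```

Comparing the two printouts is a quick way to confirm that pruning happens during compilation rather than during tracing.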

Answer selected by llan-ml