[jaxprs] Hoist large constants as arguments during lowering #30180

Open
wants to merge 1 commit into
base: main

gnecula
Collaborator

@gnecula gnecula commented Jul 14, 2025

After #29882, non-scalar constants appear as core.Literal in Jaxpr.eqns.invars, and ClosedJaxprs are mostly empty of consts. While this simplifies the Jaxprs, there are several problems with the current lowering, where we serialize the constants as HLO constants:

  • if the constant is on a device, it gets copied to the host and is embedded in the lowered HLO. This increases memory usage, and if the constant was sharded on multiple devices it will now become replicated. Several large Google-internal tests are OOMing.
  • if a constant appears in multiple places in the Jaxpr, we will make copies.
  • if the constants are inlined in the HLO, XLA starts to constant-fold them, and we see warnings due to very slow constant folding.
  • the constant folding changes the numerical behavior compared to the current lowering strategy.

With the change in this PR, during lowering we hoist the constants out of the Jaxpr (the new function jaxpr_hoisted_consts), pass them as arguments to the top-level function, and thread them to the other functions that need them. We do not yet hoist constants for AOT, because we don't want to change the semantics of lower to produce both a lowering and a list of constants.
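To make the idea concrete, here is a minimal pure-Python sketch of hoisting on a toy IR. It is not the PR's `jaxpr_hoisted_consts`: the names `ToyJaxpr`, `Eqn`, `Literal`, `is_large`, `hoist_consts` and `evaluate` are all hypothetical, and "non-scalar" is approximated as "list or tuple". It only illustrates the shape of the transformation: large literals become leading arguments, and a constant that appears in several equations is hoisted once.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class Literal:
    """Toy stand-in for a constant embedded directly in an equation's inputs."""
    val: Any


@dataclass
class Eqn:
    op: str        # "add" or "mul" in this sketch
    invars: list   # each entry is a variable name (str) or a Literal
    outvar: str


@dataclass
class ToyJaxpr:
    invars: list   # argument names
    eqns: list
    outvar: str


def is_large(val) -> bool:
    # Assumption for this sketch: anything list-like counts as non-scalar.
    return isinstance(val, (list, tuple))


def hoist_consts(jaxpr):
    """Return (new_jaxpr, consts); large literals become leading arguments."""
    consts = []
    names = {}  # id(value) -> argument name, so a shared constant is hoisted once
    new_eqns = []
    for eqn in jaxpr.eqns:
        new_invars = []
        for v in eqn.invars:
            if isinstance(v, Literal) and is_large(v.val):
                key = id(v.val)
                if key not in names:
                    names[key] = f"c{len(consts)}"
                    consts.append(v.val)
                new_invars.append(names[key])
            else:
                new_invars.append(v)
        new_eqns.append(Eqn(eqn.op, new_invars, eqn.outvar))
    const_names = [f"c{i}" for i in range(len(consts))]
    return ToyJaxpr(const_names + jaxpr.invars, new_eqns, jaxpr.outvar), consts


def evaluate(jaxpr, *args):
    """Interpret the toy IR elementwise over Python lists."""
    env = dict(zip(jaxpr.invars, args))

    def read(v):
        return v.val if isinstance(v, Literal) else env[v]

    for eqn in jaxpr.eqns:
        a, b = (read(v) for v in eqn.invars)
        if eqn.op == "add":
            env[eqn.outvar] = [x + y for x, y in zip(a, b)]
        elif eqn.op == "mul":
            env[eqn.outvar] = [x * y for x, y in zip(a, b)]
        else:
            raise ValueError(f"unknown op: {eqn.op}")
    return env[jaxpr.outvar]


big = [1.0, 2.0, 3.0]
original = ToyJaxpr(
    invars=["x"],
    eqns=[Eqn("add", ["x", Literal(big)], "t"),
          Eqn("mul", ["t", Literal(big)], "y")],
    outvar="y",
)
hoisted, consts = hoist_consts(original)
```

In this toy, `hoisted` takes the single hoisted constant as its first argument followed by `x`, and `evaluate(hoisted, *consts, x)` computes the same result as `evaluate(original, x)`. Deduplicating by object identity is a choice made for the sketch; the PR does not say how shared constants are detected.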

I tried several places in the call stack at which to hoist the constants as arguments and update the avals, shardings, and layouts to reflect the newly added arguments. I found that it is easiest to do this fairly high up the call stack, in pxla.lower_sharding_computation, because the new avals/shardings/layouts are then passed down to the lowering functions from there and are also stored in MeshComputation. I had to add several arguments to the lowering functions: avals (previously read from the Jaxpr) and num_hoisted_consts. The latter makes some sense as an argument because it specifies the calling convention expected of the lowered functions.

This change has no effect unless JAX_USE_SIMPLIFIED_JAXPR_CONSTANTS=True. If the flag is False (still the default), then jaxpr_hoisted_consts returns the empty list, and num_hoisted_consts is 0 throughout this change, so the code behavior should not change.
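The gating described above can be sketched in plain Python. This is an illustration under assumptions, not JAX's code: `use_simplified_consts`, `hoisted_consts_for` and `call_lowered` are hypothetical helpers, and the way the environment variable is parsed here (accepting `1` or `true`) is a guess, not JAX's actual flag parsing. The point is only that an empty list of hoisted constants leaves the calling convention unchanged.

```python
import os


def use_simplified_consts(env=os.environ) -> bool:
    # Assumption: treat "1"/"true" (any case) as enabled; the default is off.
    value = env.get("JAX_USE_SIMPLIFIED_JAXPR_CONSTANTS", "False")
    return value.lower() in ("1", "true")


def hoisted_consts_for(literals, env=os.environ):
    """Toy stand-in: constants to hoist, or [] when the flag is off."""
    if not use_simplified_consts(env):
        return []
    return list(literals)


def call_lowered(fn, num_hoisted_consts, consts, args):
    # Calling convention in this sketch: the first num_hoisted_consts
    # positional arguments are the hoisted constants, then the user arguments.
    if len(consts) != num_hoisted_consts:
        raise ValueError("constant count does not match calling convention")
    return fn(*consts, *args)


literals = [[1.0, 2.0]]
off = hoisted_consts_for(literals, env={})
on = hoisted_consts_for(
    literals, env={"JAX_USE_SIMPLIFIED_JAXPR_CONSTANTS": "True"})

# Flag off: nothing is hoisted, so the function is called exactly as before.
result_off = call_lowered(lambda x: x + 1, len(off), off, [41])

# Flag on: the constant arrives as a leading argument.
result_on = call_lowered(lambda c, x: x + sum(c), len(on), on, [1.0])
```

With the flag off, `off` is empty and `call_lowered` passes only the user arguments, which mirrors the statement that num_hoisted_consts is 0 and behavior is unchanged.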

There are a few more tests to fix before we can flip the default for JAX_USE_SIMPLIFIED_JAXPR_CONSTANTS, to be done in several smaller follow-up PRs.

@gnecula gnecula self-assigned this Jul 14, 2025
@gnecula gnecula added the pull ready Ready for copybara import and testing label Jul 14, 2025
@gnecula gnecula changed the title [jaxprs] The next step in ClosedJaxpr simplification [jaxprs] Hoist large constants as arguments during lowering Jul 14, 2025
@gnecula gnecula force-pushed the jaxpr_consts_pass branch 21 times, most recently from eeb5384 to 497ffb4 Compare July 18, 2025 02:38
@gnecula gnecula force-pushed the jaxpr_consts_pass branch from 497ffb4 to d06561c Compare July 23, 2025 01:52
@gnecula gnecula requested a review from yashk2810 July 23, 2025 02:03
@gnecula gnecula force-pushed the jaxpr_consts_pass branch 4 times, most recently from e3c9e2b to 773d519 Compare July 23, 2025 19:01
@gnecula gnecula force-pushed the jaxpr_consts_pass branch 6 times, most recently from 90cb44d to 1200679 Compare July 24, 2025 18:20
Now after jax-ml#29882 non-scalar constants appear as `core.Literal`
in `Jaxpr.eqns.invars`, and `ClosedJaxpr`s are mostly empty of `consts`.
Before this change the lowering would serialize the constant as
an HLO constant. There are several problems with this:

  * if the constant is on a device, it gets copied to the host and
  is embedded in the lowered HLO. This increases memory usage,
  and if the constant was sharded on multiple devices it will
  now become replicated.
  * if a constant appears in multiple places in the Jaxprs, today
  we will make copies.
  * if the constants are inlined in the HLO the XLA constant folding
  starts to constant-fold them, and we see warnings of very slow
  constant folding.
  * the constant folding changes the numerical behavior compared
  to current lowering strategy.

With this change we hoist the constants out of the Jaxpr (the
new function `jaxpr_hoisted_consts`), we pass them as arguments to
the top-level function and we thread them to other functions
that need them. We add new arguments to functions to represent
the hoisted constants.

This change has no effect, unless JAX_USE_SIMPLIFIED_JAXPR_CONSTANTS=True.

This is still an experiment.
@gnecula gnecula force-pushed the jaxpr_consts_pass branch from 1200679 to 4ce6830 Compare July 24, 2025 18:23