[jaxprs] Hoist large constants as arguments during lowering #30180
+597
−173
After #29882, non-scalar constants appear as `core.Literal` in `Jaxpr.eqns.invars`, and `ClosedJaxpr` are mostly empty of `consts`. While this simplifies the Jaxprs, there are several problems with the current lowering, where we serialize the constants as HLO constants.

With the change in this PR, during lowering we hoist the constants out of the Jaxpr (the new function `jaxpr_hoisted_consts`), pass them as arguments to the top-level function, and thread them to the other functions that need them. We do not yet hoist constants for AOT because we don't want to change the semantics of `lower` to produce a lowering and a list of constants.

I tried several places in the call stack to hoist the constants as arguments and to update the avals, shardings, and layouts to reflect the newly added arguments. I found it easiest to do this fairly high up the call stack, in `pxla.lower_sharding_computation`, because the new avals/shardings/layouts are then passed down to the lowering functions from there, and are also put in `MeshComputation`. I had to add several arguments to lowering functions to pass `avals` (previously they were picked from the Jaxpr) and `num_hoisted_consts`. The latter makes some sense to have as an argument because it specifies the calling convention desired from the lowered functions.

This change has no effect unless `JAX_USE_SIMPLIFIED_JAXPR_CONSTANTS=True`. If the flag is `False` (still the default), then `jaxpr_hoisted_consts` returns the empty list, and throughout this change `num_hoisted_consts` is 0. The code behavior should not change.

There are a few more tests to fix before we can flip the default for `JAX_USE_SIMPLIFIED_JAXPR_CONSTANTS`, to be done in several smaller follow-up PRs.
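To make the hoisting idea concrete, here is a minimal sketch in plain Python over a toy IR. It is not the JAX implementation: `Program`, `Eqn`, `Literal`, `hoist_consts`, and `evaluate` are all hypothetical names invented for this illustration, and placing the hoisted constants before the original parameters is an assumption, not something this description states. It only shows the general shape: non-scalar literals are pulled out of the equations, become extra parameters, and the caller passes them alongside the regular arguments. With hoisting disabled, the list of constants is empty and the program is unchanged.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class Literal:
    value: Any  # a float (scalar) or a list of floats (non-scalar)


@dataclass(frozen=True)
class Eqn:
    op: str        # "add" or "mul", applied elementwise
    inputs: tuple  # each input is a variable name (str) or a Literal
    out: str       # name of the variable this equation defines


@dataclass(frozen=True)
class Program:
    params: tuple  # parameter names, in calling-convention order
    eqns: tuple
    result: str


def hoist_consts(prog, enabled=True):
    """Toy hoisting pass: returns (new_prog, consts).

    Non-scalar literals are replaced by fresh parameters. Assumption of
    this sketch: the new parameters are placed first.
    """
    if not enabled:
        return prog, []  # nothing hoisted, so the count of hoisted consts is 0
    consts, new_eqns = [], []
    for eqn in prog.eqns:
        new_inputs = []
        for inp in eqn.inputs:
            if isinstance(inp, Literal) and isinstance(inp.value, list):
                new_inputs.append(f"_const{len(consts)}")
                consts.append(inp.value)
            else:
                new_inputs.append(inp)  # variables and scalar literals stay inline
        new_eqns.append(Eqn(eqn.op, tuple(new_inputs), eqn.out))
    const_params = tuple(f"_const{i}" for i in range(len(consts)))
    return Program(const_params + prog.params, tuple(new_eqns), prog.result), consts


def evaluate(prog, *args):
    """Interpret the toy program; scalars broadcast against lists."""
    if len(args) != len(prog.params):
        raise TypeError(f"expected {len(prog.params)} arguments, got {len(args)}")
    env = dict(zip(prog.params, args))
    ops = {"add": lambda a, b: a + b, "mul": lambda a, b: a * b}
    for eqn in prog.eqns:
        a, b = (i.value if isinstance(i, Literal) else env[i] for i in eqn.inputs)
        n = len(a) if isinstance(a, list) else len(b)  # at least one operand is a list
        a = a if isinstance(a, list) else [a] * n
        b = b if isinstance(b, list) else [b] * n
        env[eqn.out] = [ops[eqn.op](x, y) for x, y in zip(a, b)]
    return env[prog.result]


# y = x * [1, 2, 3]; z = y + 2
orig = Program(
    params=("x",),
    eqns=(
        Eqn("mul", ("x", Literal([1.0, 2.0, 3.0])), "y"),
        Eqn("add", ("y", Literal(2.0)), "z"),
    ),
    result="z",
)
```

A caller would then run `hoisted, consts = hoist_consts(orig)` and invoke `evaluate(hoisted, *consts, x)`. Here `len(consts)` plays the role that a count like `num_hoisted_consts` plays in the description above: it tells the callee how many leading arguments are hoisted constants rather than user inputs.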