Skip to content

Commit 773d519

Browse files
committed
[jaxprs] The next step in ClosedJaxpr simplification
Now, after #29882, non-scalar constants appear as `core.Literal` in `Jaxpr.eqns.invars`, and `ClosedJaxpr`s are mostly empty of `consts`. Before this change the lowering would serialize each such constant as an HLO constant. There are several problems with this:
* if the constant is on a device, it gets copied to the host and is embedded in the lowered HLO. This increases memory usage, and if the constant was sharded on multiple devices it will now become replicated.
* if a constant appears in multiple places in the Jaxprs, today we will make copies.
* if the constants are inlined in the HLO, the XLA constant folding starts to constant-fold them, and we see warnings of very slow constant folding.
* the constant folding changes the numerical behavior compared to the current lowering strategy.

With this change we hoist the constants out of the Jaxpr (see the new function `jaxpr_hoisted_consts`), we pass them as arguments to the top-level function, and we thread them to the other functions that need them. We add new arguments to functions to represent the hoisted constants. This change has no effect unless JAX_USE_SIMPLIFIED_JAXPR_CONSTANTS=True. This is still an experiment.
1 parent fd8a79f commit 773d519

23 files changed

+580
-169
lines changed

jax/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1235,6 +1235,7 @@ pytype_strict_library(
12351235
":sharding_impls",
12361236
":source_info_util",
12371237
":state_types",
1238+
":typing",
12381239
":util",
12391240
":xla",
12401241
":xla_bridge",

jax/_src/ad_checkpoint.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -783,7 +783,7 @@ def _remat_translation_using_opt_barrier(*args, jaxpr: core.Jaxpr):
783783

784784

785785
def _remat_lowering(
786-
ctx,
786+
ctx: mlir.LoweringRuleContext,
787787
*args,
788788
jaxpr: core.Jaxpr,
789789
prevent_cse: bool,
@@ -801,7 +801,8 @@ def _remat_lowering(
801801
jaxpr_args = args
802802
outs, tokens_out = mlir.jaxpr_subcomp(
803803
ctx.module_context, jaxpr, ctx.name_stack.extend('checkpoint'),
804-
ctx.tokens_in, (), *jaxpr_args, dim_var_values=ctx.dim_var_values)
804+
ctx.tokens_in, (), *jaxpr_args, dim_var_values=ctx.dim_var_values,
805+
const_lowering=ctx.const_lowering)
805806
ctx.set_tokens_out(tokens_out)
806807
return outs
807808

jax/_src/api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1690,8 +1690,8 @@ def cache_miss(*args, **kwargs):
16901690
with core.take_current_trace() as trace:
16911691
try:
16921692
if isinstance(trace, core.EvalTrace):
1693-
execute = pxla.xla_pmap_impl_lazy(p.flat_fun, *p.flat_args, **params)
1694-
out = execute(*p.flat_args)
1693+
execute, hoisted_consts = pxla.xla_pmap_impl_lazy(p.flat_fun, *p.flat_args, **params)
1694+
out = execute(*hoisted_consts, *p.flat_args)
16951695
else:
16961696
out = pxla.xla_pmap_p.bind_with_trace(trace, (p.flat_fun, *p.flat_args), params)
16971697
except api_util.InternalFloatingPointError as e:

jax/_src/core.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959
from jax._src.lib import jax_jit
6060
from jax._src.lib import xla_client
6161
from jax._src import traceback_util
62-
from jax._src.typing import Array, DimSize, Shape
62+
from jax._src.typing import Array, ArrayLike, DimSize, Shape
6363
from jax._src import typing
6464
from jax._src import xla_metadata_lib
6565

@@ -548,6 +548,33 @@ def is_literalable(x: Any) -> bool:
548548
return (not np.shape(x) or config.use_simplified_jaxpr_constants.value)
549549
return False
550550

551+
@partial(weakref_lru_cache, trace_context_in_key=False)
552+
def jaxpr_hoisted_consts(jaxpr: Jaxpr) -> list[ArrayLike]:
553+
# The non-scalar constants in core.Literal, in the entire Jaxpr,
554+
# uniquified by id. These will be hoisted as arguments to the functions
555+
# in which they appear.
556+
if not config.use_simplified_jaxpr_constants.value:
557+
return []
558+
consts_by_id: dict[int, Any] = {}
559+
for v in jaxpr.outvars:
560+
if type(v) is Literal and np.shape(v.val): # type: ignore
561+
consts_by_id[id(v)] = v.val # type: ignore
562+
563+
for eqn in jaxpr.eqns:
564+
for v in eqn.invars:
565+
if type(v) is Literal and np.shape(v.val): # type: ignore
566+
consts_by_id[id(v)] = v.val # type: ignore
567+
consts_by_id.update({id(v): v
568+
for v in eqn_params_hoisted_consts(eqn.params)})
569+
return list(consts_by_id.values())
570+
571+
def eqn_params_hoisted_consts(params) -> list[ArrayLike]:
572+
consts_by_id: dict[int, Any] = {}
573+
for j in jaxprs_in_params(params):
574+
consts_by_id.update({id(v): v for v in jaxpr_hoisted_consts(j)})
575+
576+
return list(consts_by_id.values())
577+
551578
Atom = Union[Var, Literal]
552579

553580
class Primitive:

jax/_src/custom_derivatives.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -437,11 +437,13 @@ def _custom_jvp_call_typecheck(_, *in_avals, call_jaxpr, jvp_jaxpr_fun,
437437
return call_jaxpr.out_avals, call_jaxpr.effects
438438
core.custom_typechecks[custom_jvp_call_p] = _custom_jvp_call_typecheck
439439

440-
def _custom_jvp_vjp_call_lowering(ctx, *args, call_jaxpr, **_):
440+
def _custom_jvp_vjp_call_lowering(ctx: mlir.LoweringRuleContext, *args,
441+
call_jaxpr: core.ClosedJaxpr, **_):
441442
consts = mlir._ir_consts(call_jaxpr.consts)
442443
out, tokens = mlir.jaxpr_subcomp(ctx.module_context, call_jaxpr.jaxpr,
443444
ctx.name_stack, ctx.tokens_in, consts,
444-
*args, dim_var_values=ctx.dim_var_values)
445+
*args, dim_var_values=ctx.dim_var_values,
446+
const_lowering=ctx.const_lowering)
445447
ctx.set_tokens_out(tokens)
446448
return out
447449
mlir.register_lowering(custom_jvp_call_p, _custom_jvp_vjp_call_lowering)

jax/_src/export/_export.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,7 @@ def do_export(*args_specs, **kwargs_specs) -> Exported:
617617
_private_parameters=mlir.LoweringParameters(
618618
override_lowering_rules=override_lowering_rules,
619619
for_export=True,
620+
hoist_constants_as_args=False,
620621
export_ignore_forward_compatibility=config.export_ignore_forward_compatibility.value))
621622
return _export_lowered(
622623
lowered, traced.jaxpr, traced.fun_name,
@@ -954,15 +955,16 @@ def is_token(typ, attrs):
954955
host_callbacks=[], module=wrapped_module, context=context,
955956
lowering_parameters=mlir.LoweringParameters(
956957
global_constant_computation=True,
957-
for_export=True,
958+
for_export=True, hoist_constants_as_args=False,
958959
export_ignore_forward_compatibility=config.export_ignore_forward_compatibility.value,
959960
))
960961
ctx = mlir.LoweringRuleContext(
961962
module_context=module_context,
962963
name_stack=source_info_util.new_name_stack(), traceback=None,
963964
primitive=None,
964965
avals_in=args_avals_flat, avals_out=None,
965-
tokens_in=mlir.TokenSet(), tokens_out=None)
966+
tokens_in=mlir.TokenSet(), tokens_out=None,
967+
const_lowering={})
966968
# We compute dim_values from the array arguments.
967969
new_main_op_array_args = new_main_op.arguments[-nr_array_args:]
968970
if shape_poly.all_dim_vars(args_avals_flat):

jax/_src/frozen_dict.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,6 @@ def __iter__(self) -> Iterator[K]:
4747

4848
def __len__(self) -> int:
4949
return len(self._d)
50+
51+
def get(self, key: K) -> V | None: # type: ignore
52+
return self._d.get(key, None)

0 commit comments

Comments
 (0)