Skip to content

Commit 497ffb4

Browse files
committed
[jaxprs] The next step in ClosedJaxpr simplification
After #29882, non-scalar constants appear as `core.Literal` in `Jaxpr.eqns.invars`, and `ClosedJaxpr`s are mostly empty of `consts`. Before this change, the lowering would serialize each such constant as an HLO constant. There are several problems with this:
* if the constant is on a device, it gets copied to the host and is embedded in the lowered HLO. This increases memory usage, and if the constant was sharded on multiple devices it will now become replicated.
* if a constant appears in multiple places in the Jaxprs, today we will make copies.
* if the constants are inlined in the HLO, the XLA constant folding starts to constant-fold them, and we see warnings of very slow constant folding.
* the constant folding changes the numerical behavior compared to the current lowering strategy.
With this change we hoist the constants out of the Jaxpr (see the new function `jaxpr_hoisted_consts`), we pass them as arguments to the top-level function, and we thread them to the other functions that need them. We add new arguments to functions to represent the hoisted constants. This change has no effect unless JAX_USE_SIMPLIFIED_JAXPR_CONSTANTS=True. This is still an experiment.
1 parent 8020564 commit 497ffb4

24 files changed

+567
-165
lines changed

jax/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1214,6 +1214,7 @@ pytype_strict_library(
12141214
":sharding_impls",
12151215
":source_info_util",
12161216
":state_types",
1217+
":typing",
12171218
":util",
12181219
":xla",
12191220
":xla_bridge",

jax/_src/ad_checkpoint.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -779,7 +779,7 @@ def _remat_translation_using_opt_barrier(*args, jaxpr: core.Jaxpr):
779779

780780

781781
def _remat_lowering(
782-
ctx,
782+
ctx: mlir.LoweringRuleContext,
783783
*args,
784784
jaxpr: core.Jaxpr,
785785
prevent_cse: bool,
@@ -797,7 +797,8 @@ def _remat_lowering(
797797
jaxpr_args = args
798798
outs, tokens_out = mlir.jaxpr_subcomp(
799799
ctx.module_context, jaxpr, ctx.name_stack.extend('checkpoint'),
800-
ctx.tokens_in, (), *jaxpr_args, dim_var_values=ctx.dim_var_values)
800+
ctx.tokens_in, (), *jaxpr_args, dim_var_values=ctx.dim_var_values,
801+
const_lowering=ctx.const_lowering)
801802
ctx.set_tokens_out(tokens_out)
802803
return outs
803804

jax/_src/api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1689,8 +1689,8 @@ def cache_miss(*args, **kwargs):
16891689
with core.take_current_trace() as trace:
16901690
try:
16911691
if isinstance(trace, core.EvalTrace):
1692-
execute = pxla.xla_pmap_impl_lazy(p.flat_fun, *p.flat_args, **params)
1693-
out = execute(*p.flat_args)
1692+
execute, hoisted_consts = pxla.xla_pmap_impl_lazy(p.flat_fun, *p.flat_args, **params)
1693+
out = execute(*hoisted_consts, *p.flat_args)
16941694
else:
16951695
out = pxla.xla_pmap_p.bind_with_trace(trace, (p.flat_fun, *p.flat_args), params)
16961696
except api_util.InternalFloatingPointError as e:

jax/_src/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1109,7 +1109,7 @@ def _validate_jax_pjrt_client_create_options(new_val):
11091109

11101110
use_simplified_jaxpr_constants = bool_state(
11111111
name='jax_use_simplified_jaxpr_constants',
1112-
default=False,
1112+
default=True,
11131113
help=('Enable a simplification of the handling of closed-over constants '
11141114
'in Jaxpr. The value `True` enables the new behavior. '
11151115
'This flag will exist only briefly, while we transition '

jax/_src/core.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
from jax._src.lib import jax_jit
5858
from jax._src.lib import xla_client
5959
from jax._src import traceback_util
60-
from jax._src.typing import Array, DimSize, Shape
60+
from jax._src.typing import Array, ArrayLike, DimSize, Shape
6161
from jax._src import typing
6262
from jax._src import xla_metadata as xla_metadata_lib
6363

@@ -546,6 +546,33 @@ def is_literalable(x: Any) -> bool:
546546
return (not np.shape(x) or config.use_simplified_jaxpr_constants.value)
547547
return False
548548

549+
@partial(weakref_lru_cache, trace_context_in_key=False)
550+
def jaxpr_hoisted_consts(jaxpr: Jaxpr) -> list[ArrayLike]:
551+
# The non-scalar constants in core.Literal, in the entire Jaxpr,
552+
# uniquified by id. These will be hoisted as arguments to the functions
553+
# in which they appear.
554+
if not config.use_simplified_jaxpr_constants.value:
555+
return []
556+
consts_by_id: dict[int, Any] = {}
557+
for v in jaxpr.outvars:
558+
if type(v) is Literal and np.shape(v.val): # type: ignore
559+
consts_by_id[id(v)] = v.val # type: ignore
560+
561+
for eqn in jaxpr.eqns:
562+
for v in eqn.invars:
563+
if type(v) is Literal and np.shape(v.val): # type: ignore
564+
consts_by_id[id(v)] = v.val # type: ignore
565+
consts_by_id.update({id(v): v
566+
for v in eqn_params_hoisted_consts(eqn.params)})
567+
return list(consts_by_id.values())
568+
569+
def eqn_params_hoisted_consts(params) -> list[ArrayLike]:
570+
consts_by_id: dict[int, Any] = {}
571+
for j in jaxprs_in_params(params):
572+
consts_by_id.update({id(v): v for v in jaxpr_hoisted_consts(j)})
573+
574+
return list(consts_by_id.values())
575+
549576
Atom = Union[Var, Literal]
550577

551578
class Primitive:

jax/_src/custom_derivatives.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -437,11 +437,13 @@ def _custom_jvp_call_typecheck(_, *in_avals, call_jaxpr, jvp_jaxpr_fun,
437437
return call_jaxpr.out_avals, call_jaxpr.effects
438438
core.custom_typechecks[custom_jvp_call_p] = _custom_jvp_call_typecheck
439439

440-
def _custom_jvp_vjp_call_lowering(ctx, *args, call_jaxpr, **_):
440+
def _custom_jvp_vjp_call_lowering(ctx: mlir.LoweringRuleContext, *args,
441+
call_jaxpr: core.ClosedJaxpr, **_):
441442
consts = mlir._ir_consts(call_jaxpr.consts)
442443
out, tokens = mlir.jaxpr_subcomp(ctx.module_context, call_jaxpr.jaxpr,
443444
ctx.name_stack, ctx.tokens_in, consts,
444-
*args, dim_var_values=ctx.dim_var_values)
445+
*args, dim_var_values=ctx.dim_var_values,
446+
const_lowering=ctx.const_lowering)
445447
ctx.set_tokens_out(tokens)
446448
return out
447449
mlir.register_lowering(custom_jvp_call_p, _custom_jvp_vjp_call_lowering)

jax/_src/export/_export.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -619,6 +619,7 @@ def do_export(*args_specs, **kwargs_specs) -> Exported:
619619
_private_parameters=mlir.LoweringParameters(
620620
override_lowering_rules=override_lowering_rules,
621621
for_export=True,
622+
hoist_constants_as_args=False,
622623
export_ignore_forward_compatibility=config.export_ignore_forward_compatibility.value))
623624
return _export_lowered(
624625
lowered, traced.jaxpr, traced.fun_name,
@@ -963,15 +964,16 @@ def is_token(typ, attrs):
963964
host_callbacks=[], module=wrapped_module, context=context,
964965
lowering_parameters=mlir.LoweringParameters(
965966
global_constant_computation=True,
966-
for_export=True,
967+
for_export=True, hoist_constants_as_args=False,
967968
export_ignore_forward_compatibility=config.export_ignore_forward_compatibility.value,
968969
))
969970
ctx = mlir.LoweringRuleContext(
970971
module_context=module_context,
971972
name_stack=source_info_util.new_name_stack(), traceback=None,
972973
primitive=None,
973974
avals_in=args_avals_flat, avals_out=None,
974-
tokens_in=mlir.TokenSet(), tokens_out=None)
975+
tokens_in=mlir.TokenSet(), tokens_out=None,
976+
const_lowering={})
975977
# We compute dim_values from the array arguments.
976978
new_main_op_array_args = new_main_op.arguments[-nr_array_args:]
977979
if shape_poly.all_dim_vars(args_avals_flat):

jax/_src/frozen_dict.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,6 @@ def __iter__(self) -> Iterator[K]:
4747

4848
def __len__(self) -> int:
4949
return len(self._d)
50+
51+
def get(self, key: K) -> V | None: # type: ignore
52+
return self._d.get(key, None)

0 commit comments

Comments
 (0)