Skip to content

[jaxprs] Hoist large constants as arguments during lowering #30180

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions jax/_src/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -953,6 +953,7 @@ pytype_strict_library(
":sharding_impls",
":source_info_util",
":state_types",
":typing",
":util",
":xla",
":xla_bridge",
Expand Down
5 changes: 3 additions & 2 deletions jax/_src/ad_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,7 +783,7 @@ def _remat_translation_using_opt_barrier(*args, jaxpr: core.Jaxpr):


def _remat_lowering(
ctx,
ctx: mlir.LoweringRuleContext,
*args,
jaxpr: core.Jaxpr,
prevent_cse: bool,
Expand All @@ -801,7 +801,8 @@ def _remat_lowering(
jaxpr_args = args
outs, tokens_out = mlir.jaxpr_subcomp(
ctx.module_context, jaxpr, ctx.name_stack.extend('checkpoint'),
ctx.tokens_in, (), *jaxpr_args, dim_var_values=ctx.dim_var_values)
ctx.tokens_in, (), *jaxpr_args, dim_var_values=ctx.dim_var_values,
const_lowering=ctx.const_lowering)
ctx.set_tokens_out(tokens_out)
return outs

Expand Down
4 changes: 2 additions & 2 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1690,8 +1690,8 @@ def cache_miss(*args, **kwargs):
with core.take_current_trace() as trace:
try:
if isinstance(trace, core.EvalTrace):
execute = pxla.xla_pmap_impl_lazy(p.flat_fun, *p.flat_args, **params)
out = execute(*p.flat_args)
execute, hoisted_consts = pxla.xla_pmap_impl_lazy(p.flat_fun, *p.flat_args, **params)
out = execute(*hoisted_consts, *p.flat_args)
else:
out = pxla.xla_pmap_p.bind_with_trace(trace, (p.flat_fun, *p.flat_args), params)
except api_util.InternalFloatingPointError as e:
Expand Down
29 changes: 28 additions & 1 deletion jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
from jax._src.lib import jax_jit
from jax._src.lib import xla_client
from jax._src import traceback_util
from jax._src.typing import Array, DimSize, Shape
from jax._src.typing import Array, ArrayLike, DimSize, Shape
from jax._src import typing
from jax._src import xla_metadata_lib

Expand Down Expand Up @@ -548,6 +548,33 @@ def is_literalable(x: Any) -> bool:
return (not np.shape(x) or config.use_simplified_jaxpr_constants.value)
return False

@partial(weakref_lru_cache, trace_context_in_key=False)
def jaxpr_hoisted_consts(jaxpr: Jaxpr) -> list[ArrayLike]:
# The non-scalar constants in core.Literal, in the entire Jaxpr,
# uniquified by id. These will be hoisted as arguments to the functions
# in which they appear.
if not config.use_simplified_jaxpr_constants.value:
return []
consts_by_id: dict[int, Any] = {}
for v in jaxpr.outvars:
if type(v) is Literal and np.shape(v.val): # type: ignore
consts_by_id[id(v)] = v.val # type: ignore

for eqn in jaxpr.eqns:
for v in eqn.invars:
if type(v) is Literal and np.shape(v.val): # type: ignore
consts_by_id[id(v)] = v.val # type: ignore
consts_by_id.update({id(v): v
for v in eqn_params_hoisted_consts(eqn.params)})
return list(consts_by_id.values())

def eqn_params_hoisted_consts(params) -> list[ArrayLike]:
consts_by_id: dict[int, Any] = {}
for j in jaxprs_in_params(params):
consts_by_id.update({id(v): v for v in jaxpr_hoisted_consts(j)})

return list(consts_by_id.values())

Atom = Union[Var, Literal]

class Primitive:
Expand Down
6 changes: 4 additions & 2 deletions jax/_src/custom_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,11 +437,13 @@ def _custom_jvp_call_typecheck(_, *in_avals, call_jaxpr, jvp_jaxpr_fun,
return call_jaxpr.out_avals, call_jaxpr.effects
core.custom_typechecks[custom_jvp_call_p] = _custom_jvp_call_typecheck

def _custom_jvp_vjp_call_lowering(ctx, *args, call_jaxpr, **_):
def _custom_jvp_vjp_call_lowering(ctx: mlir.LoweringRuleContext, *args,
call_jaxpr: core.ClosedJaxpr, **_):
consts = mlir._ir_consts(call_jaxpr.consts)
out, tokens = mlir.jaxpr_subcomp(ctx.module_context, call_jaxpr.jaxpr,
ctx.name_stack, ctx.tokens_in, consts,
*args, dim_var_values=ctx.dim_var_values)
*args, dim_var_values=ctx.dim_var_values,
const_lowering=ctx.const_lowering)
ctx.set_tokens_out(tokens)
return out
mlir.register_lowering(custom_jvp_call_p, _custom_jvp_vjp_call_lowering)
Expand Down
6 changes: 4 additions & 2 deletions jax/_src/export/_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,7 @@ def do_export(*args_specs, **kwargs_specs) -> Exported:
_private_parameters=mlir.LoweringParameters(
override_lowering_rules=override_lowering_rules,
for_export=True,
hoist_constants_as_args=False,
export_ignore_forward_compatibility=config.export_ignore_forward_compatibility.value))
return _export_lowered(
lowered, traced.jaxpr, traced.fun_name,
Expand Down Expand Up @@ -954,15 +955,16 @@ def is_token(typ, attrs):
host_callbacks=[], module=wrapped_module, context=context,
lowering_parameters=mlir.LoweringParameters(
global_constant_computation=True,
for_export=True,
for_export=True, hoist_constants_as_args=False,
export_ignore_forward_compatibility=config.export_ignore_forward_compatibility.value,
))
ctx = mlir.LoweringRuleContext(
module_context=module_context,
name_stack=source_info_util.new_name_stack(), traceback=None,
primitive=None,
avals_in=args_avals_flat, avals_out=None,
tokens_in=mlir.TokenSet(), tokens_out=None)
tokens_in=mlir.TokenSet(), tokens_out=None,
const_lowering={})
# We compute dim_values from the array arguments.
new_main_op_array_args = new_main_op.arguments[-nr_array_args:]
if shape_poly.all_dim_vars(args_avals_flat):
Expand Down
3 changes: 3 additions & 0 deletions jax/_src/frozen_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,6 @@ def __iter__(self) -> Iterator[K]:

def __len__(self) -> int:
return len(self._d)

def get(self, key: K) -> V | None: # type: ignore
return self._d.get(key, None)
Loading
Loading