Skip to content

Commit 0d8eb45

Browse files
yashk2810 authored and jax authors committed
Remove the sharding and layout checks for non-DCE'd arguments during AOT safe call.
This is because the tracing, lowering, and compilation caches do not register a miss when the sharding or layout of a DCE'd argument changes and the argument is passed again to a jitted function. The same is not true for avals, so the aval check still exists.

PiperOrigin-RevId: 623375760
1 parent 987e4b4 commit 0d8eb45

File tree

3 files changed

+22
-52
lines changed

3 files changed

+22
-52
lines changed

jax/_src/interpreters/pxla.py

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1989,10 +1989,8 @@ def are_all_shardings_default_mem_kind(da_object, shardings):
19891989

19901990

19911991
class AllArgsInfo(NamedTuple):
1992-
"""Avals, shardings, layouts and debug_info for all arguments prior to DCE."""
1992+
"""Avals and debug_info for all arguments prior to DCE."""
19931993
in_avals: Sequence[core.ShapedArray]
1994-
in_shardings: Any
1995-
in_layouts: Any
19961994
debug_info: core.JaxprDebugInfo | None
19971995

19981996

@@ -2038,8 +2036,7 @@ def lower_sharding_computation(
20382036
auto_spmd_lowering = check_if_any_auto(
20392037
it.chain.from_iterable([in_shardings, out_shardings])) # type: ignore
20402038

2041-
all_args_info = AllArgsInfo(global_in_avals, in_shardings, in_layouts,
2042-
closed_jaxpr.jaxpr.debug_info)
2039+
all_args_info = AllArgsInfo(global_in_avals, closed_jaxpr.jaxpr.debug_info)
20432040

20442041
(closed_jaxpr, global_in_avals, global_out_avals, donated_invars,
20452042
kept_var_idx, name_stack) = _dce_jaxpr(
@@ -3013,28 +3010,22 @@ def xla_extension_executable(self):
30133010
return self.xla_executable
30143011

30153012
def call(self, *args):
3013+
args_after_dce = [a for i, a in enumerate(args) if i in self._kept_var_idx]
30163014
if self._all_args_info is None:
3017-
kept_args = [a for i, a in enumerate(args) if i in self._kept_var_idx]
3015+
kept_args = args_after_dce
30183016
ref_avals = self.in_avals
3019-
in_shardings = self._in_shardings
3020-
in_layouts = self._in_layouts
30213017
debug_info = None
30223018
else:
30233019
kept_args = args
30243020
ref_avals = self._all_args_info.in_avals
3025-
iter_in_shardings = iter(self._in_shardings)
3026-
in_shardings = [next(iter_in_shardings) if i in self._kept_var_idx else s
3027-
for i, s in enumerate(self._all_args_info.in_shardings)]
3028-
iter_in_layouts = iter(self._in_layouts)
3029-
in_layouts = [next(iter_in_layouts) if i in self._kept_var_idx else s
3030-
for i, s in enumerate(self._all_args_info.in_layouts)]
30313021
debug_info = self._all_args_info.debug_info
30323022

3033-
arg_avals = map(xla.abstractify, kept_args)
3034-
check_arg_avals_for_call(ref_avals, arg_avals, debug_info)
3023+
all_arg_avals = map(xla.abstractify, kept_args)
3024+
check_arg_avals_for_call(ref_avals, all_arg_avals, debug_info)
30353025
# Check the GDA sharding and the input sharding.
3036-
check_array_xla_sharding_layout_match(kept_args, in_shardings,
3037-
in_layouts, debug_info)
3026+
check_array_xla_sharding_layout_match(
3027+
args_after_dce, self._in_shardings, self._in_layouts, debug_info,
3028+
self._kept_var_idx)
30383029
return self.unsafe_call(*args) # pylint: disable=not-callable
30393030

30403031
def input_shardings(self) -> Sequence[sharding_impls.XLACompatibleSharding]:
@@ -3163,16 +3154,22 @@ def check_device_backend_on_shardings(shardings) -> bool:
31633154

31643155

31653156
def check_array_xla_sharding_layout_match(
3166-
args, in_xla_shardings: Sequence[sharding_impls.XLACompatibleSharding],
3157+
args_after_dce,
3158+
in_xla_shardings: Sequence[sharding_impls.XLACompatibleSharding],
31673159
in_xla_layouts: Sequence[DeviceLocalLayout],
3168-
jaxpr_debug_info: core.JaxprDebugInfo | None) -> None:
3160+
jaxpr_debug_info: core.JaxprDebugInfo | None,
3161+
kept_var_idx: set[int]) -> None:
31693162
from jax._src.array import ArrayImpl
3170-
arg_names = ([''] * len(args) if jaxpr_debug_info is None else
3171-
jaxpr_debug_info.arg_names)
3163+
# jaxpr_debug_info.arg_names are before DCE, so need to DCE them.
3164+
arg_names = (
3165+
[""] * len(args_after_dce) if jaxpr_debug_info is None
3166+
else [a for i, a in enumerate(jaxpr_debug_info.arg_names) # type: ignore
3167+
if i in kept_var_idx]
3168+
)
31723169
errors = []
31733170
num_errors = 5
3174-
for arg, xs, xl, name in safe_zip(args, in_xla_shardings, in_xla_layouts,
3175-
arg_names):
3171+
for arg, xs, xl, name in safe_zip(
3172+
args_after_dce, in_xla_shardings, in_xla_layouts, arg_names):
31763173
if not isinstance(arg, ArrayImpl):
31773174
continue
31783175
if is_unspecified_or_auto(xs):
@@ -3200,7 +3197,6 @@ def check_array_xla_sharding_layout_match(
32003197

32013198
if (xla_extension_version >= 249 and not db_xs and arg._committed and
32023199
arg.layout.device_local_layout is not None and xl is not None and
3203-
not isinstance(xl, AutoLayout) and
32043200
arg.layout.device_local_layout != xl):
32053201
errors.append(
32063202
("Got input layout(s) that compiled object was called with: "

jax/_src/pjit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1486,7 +1486,7 @@ def _pjit_call_impl_python(
14861486
if compiled._auto_spmd_lowering and config.enable_checks.value:
14871487
pxla.check_array_xla_sharding_layout_match(
14881488
args, compiled._in_shardings, compiled._in_layouts,
1489-
jaxpr.jaxpr.debug_info)
1489+
jaxpr.jaxpr.debug_info, compiled._kept_var_idx)
14901490
if config.distributed_debug.value:
14911491
# Defensively only perform fingerprint logic if debug logging is enabled
14921492
# NOTE(skyewm): I didn't benchmark this

tests/pjit_test.py

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4368,32 +4368,6 @@ def f(x, y):
43684368
' compiled'):
43694369
g(x, y2)
43704370

4371-
def test_aot_error_on_dced_shardings_mismatch(self):
4372-
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
4373-
shape = (8, 2)
4374-
np_inp = np.arange(math.prod(shape)).reshape(shape)
4375-
4376-
x = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
4377-
y1 = jax.device_put(np_inp, NamedSharding(mesh, P('x')))
4378-
y2 = jax.device_put(np_inp, NamedSharding(mesh, P('y')))
4379-
4380-
@jax.jit
4381-
def f(x, y):
4382-
return x + 1
4383-
4384-
f_out1 = f(x, y1)
4385-
f(x, y2)
4386-
4387-
g = f.lower(x, y1).compile()
4388-
g_out1 = g(x, y1)
4389-
self.assertArraysEqual(f_out1, g_out1)
4390-
4391-
with self.assertRaisesRegex(
4392-
ValueError,
4393-
r"Compiled object called with input sharding.*does not match the "
4394-
r"sharding.*the computation was compiled with"):
4395-
g(x, y2)
4396-
43974371
def test_dce_no_array(self):
43984372
mesh = jtu.create_global_mesh((2,), ('x',))
43994373
arr = jax.device_put(np.arange(8.), NamedSharding(mesh, P('x')))

0 commit comments

Comments
 (0)