Commit 0e092a7

yashk2810 authored and jax authors committed
Expose .layout on jax.Array. Also add checks in the AOT path to make sure that the input Array's layout matches the layout given to jax.jit.
PiperOrigin-RevId: 618050680
1 parent 73aadbf commit 0e092a7
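
For orientation, a minimal sketch of the two user-visible behaviors this commit adds (hedged: the function and array values are illustrative, and this assumes a runtime with layout support; backends without it return None from `.layout`):

import jax
import jax.numpy as jnp

x = jnp.ones((4, 2))

# New: concrete arrays expose the layout chosen for them. On runtimes that
# have not implemented layouts (e.g. Pathways, per the TODO below), this
# returns None instead of raising.
print(x.layout)

# The AOT path now checks input layouts the same way it checks shardings:
# calling a compiled object with a committed argument whose layout differs
# from the layout it was compiled for raises a ValueError.
compiled = jax.jit(lambda a: a * 2).lower(x).compile()
compiled(x)  # OK: x's layout matches the compiled executable's expectation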

8 files changed (+202 −76 lines)

jax/_src/array.py

Lines changed: 15 additions & 1 deletion
@@ -34,10 +34,12 @@
 from jax._src import dispatch
 from jax._src import dtypes
 from jax._src import errors
+from jax._src import layout
 from jax._src import profiler
 from jax._src import tree_util
 from jax._src import xla_bridge
 from jax._src.lib import xla_client as xc
+from jax._src.lib import xla_extension as xe
 from jax._src.interpreters import mlir
 from jax._src.interpreters import pxla
 from jax._src.interpreters import xla
@@ -527,6 +529,18 @@ def addressable_shards(self) -> Sequence[Shard]:
       out.append(Shard(_get_device(a), self.sharding, self.shape, a))
     return out
 
+  @property
+  def layout(self):
+    # TODO(yashkatariya): Remove the try;except when pathways supports layouts.
+    try:
+      return layout.SpecifiedLayout(self._pjrt_layout)
+    except xe.XlaRuntimeError as e:
+      msg, *_ = e.args
+      if type(msg) is str and msg.startswith("UNIMPLEMENTED"):
+        return None
+      else:
+        raise
+
   @property
   def global_shards(self) -> Sequence[Shard]:
     """Returns list of all `Shard`s of the Array across all devices.
@@ -637,7 +651,7 @@ def _value(self) -> np.ndarray:
 ArrayImpl = use_cpp_class(xc.ArrayImpl)(ArrayImpl)
 
 
-# explicitly set to be unhashable. Same as what device_array.py does.
+# explicitly set to be unhashable.
 setattr(ArrayImpl, "__hash__", None)
 setattr(ArrayImpl, "__array_priority__", 100)
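
The new property uses a probe-and-fallback pattern: query the runtime, translate only an "UNIMPLEMENTED" failure into None, and re-raise everything else. A standalone sketch of that pattern, with a plain RuntimeError standing in for xe.XlaRuntimeError:

def probe_layout(query):
  # Returns the queried layout, or None when the runtime reports that
  # layouts are unimplemented; any other failure propagates unchanged.
  try:
    return query()
  except RuntimeError as e:  # stand-in for xe.XlaRuntimeError
    msg, *_ = e.args
    if isinstance(msg, str) and msg.startswith("UNIMPLEMENTED"):
      return None
    raise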

jax/_src/interpreters/mlir.py

Lines changed: 5 additions & 5 deletions
@@ -46,7 +46,7 @@
 from jax._src import xla_bridge as xb
 from jax._src.interpreters import partial_eval as pe
 from jax._src.interpreters import xla
-from jax._src.layout import XLACompatibleLayout, LayoutRequest
+from jax._src.layout import AutoLayout, SpecifiedLayout
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension
 from jax._src.lib import xla_extension_version
@@ -834,10 +834,10 @@ def _to_physical_op_sharding(
   return sharding._to_xla_hlo_sharding(aval.ndim).to_proto()  # type: ignore
 
 
-def _to_xla_layout(layout: XLACompatibleLayout | None | LayoutRequest) -> str | None:
+def _to_xla_layout(layout: SpecifiedLayout | None | AutoLayout) -> str | None:
   if layout is None:
     return "default"
-  if isinstance(layout, LayoutRequest):
+  if isinstance(layout, AutoLayout):
     return "auto"
   return layout._to_xla_layout()

@@ -862,8 +862,8 @@ def lower_jaxpr_to_module(
     replicated_args: Sequence[bool] | None = None,
     arg_shardings: Sequence[XLACompatibleSharding | None] | None = None,
     result_shardings: Sequence[XLACompatibleSharding | None] | None = None,
-    in_layouts: Sequence[XLACompatibleLayout | None | LayoutRequest] | None = None,
-    out_layouts: Sequence[XLACompatibleLayout | None | LayoutRequest] | None = None,
+    in_layouts: Sequence[SpecifiedLayout | None | AutoLayout] | None = None,
+    out_layouts: Sequence[SpecifiedLayout | None | AutoLayout] | None = None,
     arg_names: Sequence[str | None] | None = None,
     result_names: Sequence[str | None] | None = None,
     num_replicas: int = 1,
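
_to_xla_layout reduces to a three-way mapping. A self-contained sketch that mirrors it (class shapes copied loosely from the diff; "default" and "auto" are the strings the XLA lowering consumes):

class AutoLayout:
  def __repr__(self):
    return "AUTO"

class SpecifiedLayout:
  def __init__(self, layout_str):
    self._layout_str = layout_str
  def _to_xla_layout(self):
    return self._layout_str

def to_xla_layout(layout):
  if layout is None:
    return "default"              # keep the backend's default layout
  if isinstance(layout, AutoLayout):
    return "auto"                 # ask the compiler to pick one
  return layout._to_xla_layout()  # serialize the user-specified layout

assert to_xla_layout(None) == "default"
assert to_xla_layout(AutoLayout()) == "auto"
assert to_xla_layout(SpecifiedLayout("{1,0}")) == "{1,0}"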

jax/_src/interpreters/pxla.py

Lines changed: 44 additions & 18 deletions
@@ -60,7 +60,7 @@
 from jax._src.interpreters import partial_eval as pe
 from jax._src.interpreters import mlir
 from jax._src.interpreters import xla
-from jax._src.layout import XLACompatibleLayout, SpecifiedLayout, LayoutRequest
+from jax._src.layout import SpecifiedLayout, AutoLayout
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension_version
 from jax._src.lib.mlir import ir
@@ -1985,13 +1985,14 @@ def are_all_shardings_default_mem_kind(da_object, shardings):
     return False
   return True
 
-MaybeLayout = Sequence[Union[XLACompatibleLayout, LayoutRequest, None]]
+MaybeLayout = Sequence[Union[SpecifiedLayout, AutoLayout, None]]
 
 
 class AllArgsInfo(NamedTuple):
   """Avals, shardings, layouts and debug_info for all arguments prior to DCE."""
   in_avals: Sequence[core.ShapedArray]
   in_shardings: Any
+  in_layouts: Any
   debug_info: core.JaxprDebugInfo | None

@@ -2023,7 +2024,7 @@ def lower_sharding_computation(
   auto_spmd_lowering = check_if_any_auto(
       it.chain.from_iterable([in_shardings, out_shardings]))  # type: ignore
 
-  all_args_info = AllArgsInfo(global_in_avals, in_shardings,
+  all_args_info = AllArgsInfo(global_in_avals, in_shardings, in_layouts,
                               closed_jaxpr.jaxpr.debug_info)
 
   (closed_jaxpr, global_in_avals, global_out_avals, donated_invars,
@@ -2227,8 +2228,6 @@ def lower_mesh_computation(
   out_jaxpr_avals = fun_or_jaxpr.out_avals
   consts = fun_or_jaxpr.consts
 
-  all_args_info = AllArgsInfo(global_in_avals, in_shardings, jaxpr.debug_info)
-
   assert len(out_shardings) == len(out_jaxpr_avals)
   if spmd_lowering:
     global_out_avals = out_jaxpr_avals
@@ -2319,7 +2318,7 @@ def lower_mesh_computation(
       in_layouts=(None,) * len(global_in_avals),
       out_layouts=(None,) * len(global_out_avals),
       shape_poly_state=lowering_result.shape_poly_state,
-      all_args_info=all_args_info)
+      all_args_info=None)
 
 
 class MeshComputation(stages.XlaLowering):
   _hlo: ir.Module | None
@@ -2599,7 +2598,7 @@ def _get_layouts_from_executable(
     if isinstance(i, SpecifiedLayout):
       if i != x:
         raise AssertionError(
-            f"Unexpected XLA layout override: (XLA) {x} != {i} (User sharding)")
+            f"Unexpected XLA layout override: (XLA) {x} != {i} (User layout)")
       new_in_layouts.append(i)
     else:
       new_in_layouts.append(x)
@@ -2610,7 +2609,7 @@
     if isinstance(o, SpecifiedLayout):
       if o != x:
         raise AssertionError(
-            f"Unexpected XLA layout override: (XLA) {x} != {o} (User sharding)")
+            f"Unexpected XLA layout override: (XLA) {x} != {o} (User layout)")
       new_out_layouts.append(o)
     else:
       new_out_layouts.append(x)
@@ -3016,19 +3015,24 @@ def call(self, *args):
       kept_args = [a for i, a in enumerate(args) if i in self._kept_var_idx]
       ref_avals = self.in_avals
       in_shardings = self._in_shardings
+      in_layouts = self._in_layouts
       debug_info = None
     else:
       kept_args = args
       ref_avals = self._all_args_info.in_avals
       iter_in_shardings = iter(self._in_shardings)
       in_shardings = [next(iter_in_shardings) if i in self._kept_var_idx else s
                       for i, s in enumerate(self._all_args_info.in_shardings)]
+      iter_in_layouts = iter(self._in_layouts)
+      in_layouts = [next(iter_in_layouts) if i in self._kept_var_idx else s
+                    for i, s in enumerate(self._all_args_info.in_layouts)]
       debug_info = self._all_args_info.debug_info
 
     arg_avals = map(xla.abstractify, kept_args)
     check_arg_avals_for_call(ref_avals, arg_avals, debug_info)
     # Check the GDA sharding and the input sharding.
-    check_gda_or_array_xla_sharding_match(kept_args, in_shardings, debug_info)
+    check_array_xla_sharding_layout_match(kept_args, in_shardings,
+                                          in_layouts, debug_info)
     return self.unsafe_call(*args)  # pylint: disable=not-callable
 
   def input_shardings(self) -> Sequence[sharding_impls.XLACompatibleSharding]:
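
The call-path bookkeeping above re-expands the post-DCE layouts back to the original argument positions via self._kept_var_idx, exactly as it already did for shardings. A standalone sketch of that expansion (names illustrative):

def expand_kept(kept_values, kept_idx, originals):
  # Walk the kept (post-DCE) values in order, placing each at its original
  # argument position; dropped positions keep their pre-DCE entry.
  it = iter(kept_values)
  return [next(it) if i in kept_idx else orig
          for i, orig in enumerate(originals)]

# Args 0 and 2 survived DCE; arg 1 retains its pre-DCE record.
assert expand_kept(["K0", "K2"], {0, 2}, ["A", "B", "C"]) == ["K0", "B", "K2"]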
@@ -3184,15 +3188,17 @@ def check_device_backend_on_shardings(shardings) -> bool:
   return False
 
 
-def check_gda_or_array_xla_sharding_match(
+def check_array_xla_sharding_layout_match(
     args, in_xla_shardings: Sequence[sharding_impls.XLACompatibleSharding],
+    in_xla_layouts: Sequence[SpecifiedLayout],
     jaxpr_debug_info: core.JaxprDebugInfo | None) -> None:
   from jax._src.array import ArrayImpl
   arg_names = ([''] * len(args) if jaxpr_debug_info is None else
                jaxpr_debug_info.arg_names)
   errors = []
   num_errors = 5
-  for arg, xs, name in safe_zip(args, in_xla_shardings, arg_names):
+  for arg, xs, xl, name in safe_zip(args, in_xla_shardings, in_xla_layouts,
+                                    arg_names):
     if not isinstance(arg, ArrayImpl):
       continue
     if is_unspecified_or_auto(xs):
@@ -3205,27 +3211,47 @@
     # Raise memory kind mismatch error even if the arg is uncommitted.
     if arg.sharding.memory_kind != xs.memory_kind:
       errors.append(
-          "Got input sharding(s) that compiled object was called with: "
+          ("Got input sharding(s) that compiled object was called with: "
           f"{arg.sharding} and sharding(s) the computation was compiled "
-          f"with: {xs} for arg {name} with shape: {arg.aval.str_short()}")
+          f"with: {xs} for arg {name} with shape: {arg.aval.str_short()}",
+          'sharding'))
 
     if (not db_xs and arg._committed and
         not op_shardings.are_op_shardings_equal(
             arg.sharding._to_xla_hlo_sharding(arg.ndim),
             xs._to_xla_hlo_sharding(arg.ndim))):
       errors.append(
-          "Got input sharding(s) that compiled object was called with: "
+          ("Got input sharding(s) that compiled object was called with: "
          f"{arg.sharding} and sharding(s) the computation was compiled "
-          f"with: {xs} for arg {name} with shape: {arg.aval.str_short()}")
+          f"with: {xs} for arg {name} with shape: {arg.aval.str_short()}",
+          'sharding'))
+
+    # TODO(yashkatariya): Remove `arg.layout is not None` check after pathways
+    # supports layout on Array.
+    if (xla_extension_version >= 249 and not db_xs and arg._committed and
+        arg.layout is not None and arg.layout != xl):
+      errors.append(
+          ("Got input layout(s) that compiled object was called with: "
+          f"{arg.layout} and layout(s) the computation was compiled "
+          f"with: {xl} for arg {name} with shape: {arg.aval.str_short()}",
+          'layout'))
 
   if errors:
-    str_errors = '\n'.join(errors[:num_errors])
+    first_errors, error_kinds = unzip2(errors[:num_errors])
+    str_errors = '\n'.join(first_errors)
+    if all(k == 'sharding' for k in error_kinds):
+      kind_str = r'sharding(s)'
+    elif all(k == 'layout' for k in error_kinds):
+      kind_str = 'layout(s)'
+    else:
+      kind_str = 'sharding(s) and layout(s)'
     num_mismatch_str = (
         f'the {len(errors)} mismatches' if len(errors) < num_errors else
         f"{num_errors} mismatches out of {len(errors)}")
     raise ValueError(
-        "Compiled object called with input sharding(s) does not match the "
-        "sharding(s) the computation was compiled with. "
+        f"Compiled object called with input {kind_str} does "
+        f"not match the {kind_str} the computation was "
+        "compiled with. "
         f"Here are {num_mismatch_str}:\n{str_errors}")
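
The aggregation at the end tags each mismatch with a kind so the summary names exactly what differed. A minimal sketch of that selection (the two-line unzip2 is a stand-in for jax's utility of the same name):

def unzip2(pairs):
  xs, ys = zip(*pairs) if pairs else ((), ())
  return tuple(xs), tuple(ys)

def mismatch_kind_str(errors):
  # errors: (message, kind) pairs, kind in {'sharding', 'layout'}.
  _, kinds = unzip2(errors)
  if all(k == 'sharding' for k in kinds):
    return 'sharding(s)'
  if all(k == 'layout' for k in kinds):
    return 'layout(s)'
  return 'sharding(s) and layout(s)'

assert mismatch_kind_str([("m1", 'layout')]) == 'layout(s)'
assert mismatch_kind_str([("m1", 'layout'), ("m2", 'sharding')]) == 'sharding(s) and layout(s)'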

jax/_src/layout.py

Lines changed: 6 additions & 23 deletions
@@ -14,8 +14,6 @@
 
 from __future__ import annotations
 
-import re
-
 from jax._src.lib import xla_client as xc

@@ -24,16 +22,10 @@ class Layout:
   pass
 
 
-class XLACompatibleLayout(Layout):
-
-  def _to_xla_layout(self) -> str:
-    raise NotImplementedError("Subclasses should implement this method.")
-
-
-class SpecifiedLayout(XLACompatibleLayout):
-  layout: xc.Layout
+class SpecifiedLayout(Layout):
+  layout: xc.PjRtLayout
 
-  def __init__(self, layout: xc.Layout):
+  def __init__(self, layout: xc.PjRtLayout):
     self._layout = layout
     self._layout_str = str(self._layout)

@@ -51,19 +43,10 @@ def __eq__(self, other):
   def _to_xla_layout(self) -> str:
     return self._layout_str
 
-  @property
-  def _minor_to_major(self):
-    m = re.search("{([0-9,]*):", str(self))
-    assert m is not None
-    m2m_str = m.group(1)
-    if m2m_str == "":
-      return ()
-    return tuple(int(x) for x in m2m_str.split(","))
-
 
-class LayoutRequest:
+class AutoLayout:
 
   def __repr__(self):
-    return "Request a layout from the compiler"
+    return "AUTO"
 
-AUTO = LayoutRequest()
+AUTO = AutoLayout()
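
After this refactor the module exposes two concrete classes: SpecifiedLayout, wrapping an xc.PjRtLayout, and the AutoLayout sentinel AUTO. A brief hedged sketch (this is an internal module at this revision, not a public API):

from jax._src.layout import AUTO, AutoLayout

# AUTO is the module-level sentinel for "let the compiler choose"; its repr
# is now the terse "AUTO" rather than the old descriptive sentence.
assert isinstance(AUTO, AutoLayout)
assert repr(AUTO) == "AUTO"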

jax/_src/pjit.py

Lines changed: 26 additions & 3 deletions
@@ -435,6 +435,7 @@ def lower(*args, **kwargs):
     try:
       in_shardings = _resolve_in_shardings(
           args_flat, params['in_shardings'], params['out_shardings'], mesh)
+      in_layouts_flat = _resolve_in_layouts(args_flat, in_layouts_flat)
       lowering = _pjit_lower(
           params['jaxpr'], in_shardings, params['out_shardings'],
           params['resource_env'], params['donated_invars'], params['name'],
@@ -1130,7 +1131,6 @@ def unpack(key):
   p("explanation unavailable! please open an issue at https://github.com/google/jax")
   return done()
 
-
 @partial(lu.cache, explain=explain_tracing_cache_miss)
 def _create_pjit_jaxpr(fun, in_type, debug_info, out_paths, ignored_inline):
   del ignored_inline  # just for explain_cache_miss
@@ -1264,6 +1264,28 @@ def pjit_check_aval_sharding(
 pjit_p.multiple_results = True
 
 
+def _resolve_in_layouts(args, jit_in_layouts):
+  resolved_in_layouts = []
+  for arg, jit_in_l in safe_zip(args, jit_in_layouts):
+    arg_layout, committed = (
+        (arg.layout, getattr(arg, '_committed', True))
+        if getattr(arg, 'layout', None) is not None else (None, False))
+    if jit_in_l is None:
+      if committed:
+        resolved_in_layouts.append(arg_layout)
+      else:
+        resolved_in_layouts.append(None)
+    else:
+      if committed and arg_layout != jit_in_l:
+        raise ValueError('Layout passed to jit does not match the layout '
+                         'on the respective arg. '
+                         f'Got pjit layout: {jit_in_l},\n'
+                         f'arg layout: {arg_layout} for '
+                         f'arg shape: {shaped_abstractify(arg).str_short()}')
+      resolved_in_layouts.append(jit_in_l)
+  return tuple(resolved_in_layouts)
+
+
 def _resolve_in_shardings(
     args, pjit_in_shardings: Sequence[PjitSharding],
     out_shardings: Sequence[PjitSharding],
@@ -1387,8 +1409,9 @@ def _pjit_call_impl_python(
   _most_recent_pjit_call_executable.weak_key_dict[jaxpr] = compiled
   # This check is expensive so only do it if enable_checks is on.
   if compiled._auto_spmd_lowering and config.enable_checks.value:
-    pxla.check_gda_or_array_xla_sharding_match(args, compiled._in_shardings,
-                                               jaxpr.jaxpr.debug_info)
+    pxla.check_array_xla_sharding_layout_match(
+        args, compiled._in_shardings, compiled._in_layouts,
+        jaxpr.jaxpr.debug_info)
   if config.distributed_debug.value:
     # Defensively only perform fingerprint logic if debug logging is enabled
     # NOTE(skyewm): I didn't benchmark this
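
_resolve_in_layouts implements a small per-argument decision table: with no layout passed to jit, a committed argument contributes its own layout (uncommitted arguments contribute None); with a layout passed to jit, a committed argument must match it or a ValueError is raised, and the jit layout wins otherwise. A self-contained sketch under that reading (names illustrative):

def resolve_layout(jit_layout, arg_layout, committed):
  # Mirrors the per-argument branch structure of _resolve_in_layouts.
  if jit_layout is None:
    return arg_layout if committed else None
  if committed and arg_layout != jit_layout:
    raise ValueError("Layout passed to jit does not match the layout "
                     "on the respective arg.")
  return jit_layout

assert resolve_layout(None, "L0", committed=True) == "L0"
assert resolve_layout(None, "L0", committed=False) is None
assert resolve_layout("L1", "L1", committed=True) == "L1"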
