
Commit cd79e71

Committed by: jax authors

Reverts 0e092a7
PiperOrigin-RevId: 618127324

1 parent c2d9528 · commit cd79e71

8 files changed: +76 −202 lines changed

jax/_src/array.py

Lines changed: 1 addition & 15 deletions
```diff
@@ -34,12 +34,10 @@
 from jax._src import dispatch
 from jax._src import dtypes
 from jax._src import errors
-from jax._src import layout
 from jax._src import profiler
 from jax._src import tree_util
 from jax._src import xla_bridge
 from jax._src.lib import xla_client as xc
-from jax._src.lib import xla_extension as xe
 from jax._src.interpreters import mlir
 from jax._src.interpreters import pxla
 from jax._src.interpreters import xla
@@ -529,18 +527,6 @@ def addressable_shards(self) -> Sequence[Shard]:
       out.append(Shard(_get_device(a), self.sharding, self.shape, a))
     return out
 
-  @property
-  def layout(self):
-    # TODO(yashkatariya): Remove the try;except when pathways supports layouts.
-    try:
-      return layout.SpecifiedLayout(self._pjrt_layout)
-    except xe.XlaRuntimeError as e:
-      msg, *_ = e.args
-      if type(msg) is str and msg.startswith("UNIMPLEMENTED"):
-        return None
-      else:
-        raise
-
   @property
   def global_shards(self) -> Sequence[Shard]:
     """Returns list of all `Shard`s of the Array across all devices.
@@ -651,7 +637,7 @@ def _value(self) -> np.ndarray:
 ArrayImpl = use_cpp_class(xc.ArrayImpl)(ArrayImpl)
 
 
-# explicitly set to be unhashable.
+# explicitly set to be unhashable. Same as what device_array.py does.
 setattr(ArrayImpl, "__hash__", None)
 setattr(ArrayImpl, "__array_priority__", 100)
 
```
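The restored comment above refers to a standard Python idiom: assigning `None` to a class's `__hash__` attribute marks its instances unhashable. A minimal, self-contained sketch of the same pattern (the `Unhashable` class here is illustrative, not part of JAX):

```python
# Setting __hash__ to None is Python's standard way to mark a class
# unhashable; hash() then raises TypeError for its instances.
class Unhashable:
  pass

setattr(Unhashable, "__hash__", None)

try:
  hash(Unhashable())
except TypeError as err:
  print(err)  # unhashable type: 'Unhashable'
```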

jax/_src/interpreters/mlir.py

Lines changed: 5 additions & 5 deletions
```diff
@@ -46,7 +46,7 @@
 from jax._src import xla_bridge as xb
 from jax._src.interpreters import partial_eval as pe
 from jax._src.interpreters import xla
-from jax._src.layout import AutoLayout, SpecifiedLayout
+from jax._src.layout import XLACompatibleLayout, LayoutRequest
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension
 from jax._src.lib import xla_extension_version
@@ -834,10 +834,10 @@ def _to_physical_op_sharding(
   return sharding._to_xla_hlo_sharding(aval.ndim).to_proto()  # type: ignore
 
 
-def _to_xla_layout(layout: SpecifiedLayout | None | AutoLayout) -> str | None:
+def _to_xla_layout(layout: XLACompatibleLayout | None | LayoutRequest) -> str | None:
   if layout is None:
     return "default"
-  if isinstance(layout, AutoLayout):
+  if isinstance(layout, LayoutRequest):
     return "auto"
   return layout._to_xla_layout()
 
@@ -862,8 +862,8 @@ def lower_jaxpr_to_module(
     replicated_args: Sequence[bool] | None = None,
     arg_shardings: Sequence[XLACompatibleSharding | None] | None = None,
     result_shardings: Sequence[XLACompatibleSharding | None] | None = None,
-    in_layouts: Sequence[SpecifiedLayout | None | AutoLayout] | None = None,
-    out_layouts: Sequence[SpecifiedLayout | None | AutoLayout] | None = None,
+    in_layouts: Sequence[XLACompatibleLayout | None | LayoutRequest] | None = None,
+    out_layouts: Sequence[XLACompatibleLayout | None | LayoutRequest] | None = None,
     arg_names: Sequence[str | None] | None = None,
     result_names: Sequence[str | None] | None = None,
     num_replicas: int = 1,
```
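For readers following the rename, here is a hedged sketch of the three-way dispatch `_to_xla_layout` performs after this revert: `None` means "use XLA's default layout", a `LayoutRequest` means "let the compiler pick", and anything else is a concrete layout that serializes itself. The classes below are stand-ins for illustration, not the real `jax._src.layout` types.

```python
# Stand-in classes; the real ones live in jax._src.layout.
class LayoutRequest:
  pass

class ConcreteLayout:
  def _to_xla_layout(self) -> str:
    # A plausible XLA layout string; the exact format is backend-defined.
    return "{1,0}"

def to_xla_layout(layout) -> str:
  if layout is None:
    return "default"   # XLA picks its default layout
  if isinstance(layout, LayoutRequest):
    return "auto"      # compiler chooses a layout
  return layout._to_xla_layout()  # user-specified layout string

print(to_xla_layout(None))              # default
print(to_xla_layout(LayoutRequest()))   # auto
print(to_xla_layout(ConcreteLayout()))  # {1,0}
```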

jax/_src/interpreters/pxla.py

Lines changed: 18 additions & 44 deletions
```diff
@@ -60,7 +60,7 @@
 from jax._src.interpreters import partial_eval as pe
 from jax._src.interpreters import mlir
 from jax._src.interpreters import xla
-from jax._src.layout import SpecifiedLayout, AutoLayout
+from jax._src.layout import XLACompatibleLayout, SpecifiedLayout, LayoutRequest
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension_version
 from jax._src.lib.mlir import ir
@@ -1985,14 +1985,13 @@ def are_all_shardings_default_mem_kind(da_object, shardings):
       return False
   return True
 
-MaybeLayout = Sequence[Union[SpecifiedLayout, AutoLayout, None]]
+MaybeLayout = Sequence[Union[XLACompatibleLayout, LayoutRequest, None]]
 
 
 class AllArgsInfo(NamedTuple):
   """Avals, shardings, layouts and debug_info for all arguments prior to DCE."""
   in_avals: Sequence[core.ShapedArray]
   in_shardings: Any
-  in_layouts: Any
   debug_info: core.JaxprDebugInfo | None
 
 
@@ -2024,7 +2023,7 @@ def lower_sharding_computation(
   auto_spmd_lowering = check_if_any_auto(
       it.chain.from_iterable([in_shardings, out_shardings]))  # type: ignore
 
-  all_args_info = AllArgsInfo(global_in_avals, in_shardings, in_layouts,
+  all_args_info = AllArgsInfo(global_in_avals, in_shardings,
                               closed_jaxpr.jaxpr.debug_info)
 
   (closed_jaxpr, global_in_avals, global_out_avals, donated_invars,
@@ -2228,6 +2227,8 @@ def lower_mesh_computation(
   out_jaxpr_avals = fun_or_jaxpr.out_avals
   consts = fun_or_jaxpr.consts
 
+  all_args_info = AllArgsInfo(global_in_avals, in_shardings, jaxpr.debug_info)
+
   assert len(out_shardings) == len(out_jaxpr_avals)
   if spmd_lowering:
     global_out_avals = out_jaxpr_avals
@@ -2318,7 +2319,7 @@
         in_layouts=(None,) * len(global_in_avals),
         out_layouts=(None,) * len(global_out_avals),
         shape_poly_state=lowering_result.shape_poly_state,
-        all_args_info=None)
+        all_args_info=all_args_info)
 
 class MeshComputation(stages.XlaLowering):
   _hlo: ir.Module | None
@@ -2598,7 +2599,7 @@ def _get_layouts_from_executable(
     if isinstance(i, SpecifiedLayout):
       if i != x:
         raise AssertionError(
-            f"Unexpected XLA layout override: (XLA) {x} != {i} (User layout)")
+            f"Unexpected XLA layout override: (XLA) {x} != {i} (User sharding)")
       new_in_layouts.append(i)
     else:
       new_in_layouts.append(x)
@@ -2609,7 +2610,7 @@
     if isinstance(o, SpecifiedLayout):
       if o != x:
         raise AssertionError(
-            f"Unexpected XLA layout override: (XLA) {x} != {o} (User layout)")
+            f"Unexpected XLA layout override: (XLA) {x} != {o} (User sharding)")
       new_out_layouts.append(o)
     else:
       new_out_layouts.append(x)
@@ -3015,24 +3016,19 @@ def call(self, *args):
       kept_args = [a for i, a in enumerate(args) if i in self._kept_var_idx]
       ref_avals = self.in_avals
       in_shardings = self._in_shardings
-      in_layouts = self._in_layouts
       debug_info = None
     else:
       kept_args = args
       ref_avals = self._all_args_info.in_avals
       iter_in_shardings = iter(self._in_shardings)
       in_shardings = [next(iter_in_shardings) if i in self._kept_var_idx else s
                       for i, s in enumerate(self._all_args_info.in_shardings)]
-      iter_in_layouts = iter(self._in_layouts)
-      in_layouts = [next(iter_in_layouts) if i in self._kept_var_idx else s
-                    for i, s in enumerate(self._all_args_info.in_layouts)]
       debug_info = self._all_args_info.debug_info
 
     arg_avals = map(xla.abstractify, kept_args)
     check_arg_avals_for_call(ref_avals, arg_avals, debug_info)
     # Check the GDA sharding and the input sharding.
-    check_array_xla_sharding_layout_match(kept_args, in_shardings,
-                                          in_layouts, debug_info)
+    check_gda_or_array_xla_sharding_match(kept_args, in_shardings, debug_info)
     return self.unsafe_call(*args)  # pylint: disable=not-callable
 
   def input_shardings(self) -> Sequence[sharding_impls.XLACompatibleSharding]:
@@ -3188,17 +3184,15 @@ def check_device_backend_on_shardings(shardings) -> bool:
   return False
 
 
-def check_array_xla_sharding_layout_match(
+def check_gda_or_array_xla_sharding_match(
     args, in_xla_shardings: Sequence[sharding_impls.XLACompatibleSharding],
-    in_xla_layouts: Sequence[SpecifiedLayout],
     jaxpr_debug_info: core.JaxprDebugInfo | None) -> None:
   from jax._src.array import ArrayImpl
   arg_names = ([''] * len(args) if jaxpr_debug_info is None else
                jaxpr_debug_info.arg_names)
   errors = []
   num_errors = 5
-  for arg, xs, xl, name in safe_zip(args, in_xla_shardings, in_xla_layouts,
-                                    arg_names):
+  for arg, xs, name in safe_zip(args, in_xla_shardings, arg_names):
     if not isinstance(arg, ArrayImpl):
       continue
     if is_unspecified_or_auto(xs):
@@ -3211,47 +3205,27 @@
     # Raise memory kind mismatch error even if the arg is uncommitted.
     if arg.sharding.memory_kind != xs.memory_kind:
       errors.append(
-          ("Got input sharding(s) that compiled object was called with: "
+          "Got input sharding(s) that compiled object was called with: "
           f"{arg.sharding} and sharding(s) the computation was compiled "
-          f"with: {xs} for arg {name} with shape: {arg.aval.str_short()}",
-          'sharding'))
+          f"with: {xs} for arg {name} with shape: {arg.aval.str_short()}")
 
     if (not db_xs and arg._committed and
         not op_shardings.are_op_shardings_equal(
             arg.sharding._to_xla_hlo_sharding(arg.ndim),
             xs._to_xla_hlo_sharding(arg.ndim))):
       errors.append(
-          ("Got input sharding(s) that compiled object was called with: "
+          "Got input sharding(s) that compiled object was called with: "
          f"{arg.sharding} and sharding(s) the computation was compiled "
-          f"with: {xs} for arg {name} with shape: {arg.aval.str_short()}",
-          'sharding'))
-
-    # TODO(yashkatariya): Remove `arg.layout is not None` check after pathways
-    # supports layout on Array.
-    if (xla_extension_version >= 249 and not db_xs and arg._committed and
-        arg.layout is not None and arg.layout != xl):
-      errors.append(
-          ("Got input layout(s) that compiled object was called with: "
-          f"{arg.layout} and layout(s) the computation was compiled "
-          f"with: {xl} for arg {name} with shape: {arg.aval.str_short()}",
-          'layout'))
+          f"with: {xs} for arg {name} with shape: {arg.aval.str_short()}")
 
   if errors:
-    first_errors, error_kinds = unzip2(errors[:num_errors])
-    str_errors = '\n'.join(first_errors)
-    if all(k == 'sharding' for k in error_kinds):
-      kind_str = r'sharding(s)'
-    elif all(k == 'layout' for k in error_kinds):
-      kind_str = 'layout(s)'
-    else:
-      kind_str = 'sharding(s) and layout(s)'
+    str_errors = '\n'.join(errors[:num_errors])
     num_mismatch_str = (
         f'the {len(errors)} mismatches' if len(errors) < num_errors else
         f"{num_errors} mismatches out of {len(errors)}")
     raise ValueError(
-        f"Compiled object called with input {kind_str} does "
-        f"not match the {kind_str} the computation was "
-        "compiled with. "
+        "Compiled object called with input sharding(s) does not match the "
+        "sharding(s) the computation was compiled with. "
         f"Here are {num_mismatch_str}:\n{str_errors}")
 
 
```
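The reverted `check_gda_or_array_xla_sharding_match` goes back to collecting plain error strings and truncating the report at five mismatches. A minimal sketch of that aggregation pattern, with fabricated messages standing in for real sharding mismatches:

```python
def report_mismatches(errors, num_errors=5):
  # Mirror of the truncation logic above: show at most num_errors messages,
  # but say how many mismatches there were in total.
  str_errors = "\n".join(errors[:num_errors])
  num_mismatch_str = (
      f"the {len(errors)} mismatches" if len(errors) < num_errors else
      f"{num_errors} mismatches out of {len(errors)}")
  raise ValueError(
      "Compiled object called with input sharding(s) does not match the "
      "sharding(s) the computation was compiled with. "
      f"Here are {num_mismatch_str}:\n{str_errors}")

try:
  report_mismatches([f"mismatch for arg x{i}" for i in range(7)])
except ValueError as e:
  print(e)  # reports "5 mismatches out of 7" plus the first five messages
```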

jax/_src/layout.py

Lines changed: 23 additions & 6 deletions
```diff
@@ -14,6 +14,8 @@
 
 from __future__ import annotations
 
+import re
+
 from jax._src.lib import xla_client as xc
 
 
@@ -22,10 +24,16 @@ class Layout:
   pass
 
 
-class SpecifiedLayout(Layout):
-  layout: xc.PjRtLayout
+class XLACompatibleLayout(Layout):
+
+  def _to_xla_layout(self) -> str:
+    raise NotImplementedError("Subclasses should implement this method.")
+
+
+class SpecifiedLayout(XLACompatibleLayout):
+  layout: xc.Layout
 
-  def __init__(self, layout: xc.PjRtLayout):
+  def __init__(self, layout: xc.Layout):
     self._layout = layout
     self._layout_str = str(self._layout)
 
@@ -43,10 +51,19 @@ def __eq__(self, other):
   def _to_xla_layout(self) -> str:
     return self._layout_str
 
+  @property
+  def _minor_to_major(self):
+    m = re.search("{([0-9,]*):", str(self))
+    assert m is not None
+    m2m_str = m.group(1)
+    if m2m_str == "":
+      return ()
+    return tuple(int(x) for x in m2m_str.split(","))
+
 
-class AutoLayout:
+class LayoutRequest:
 
   def __repr__(self):
-    return "AUTO"
+    return "Request a layout from the compiler"
 
-AUTO = AutoLayout()
+AUTO = LayoutRequest()
```
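The `_minor_to_major` property added back above parses the minor-to-major dimension order out of an XLA layout string with a regex. A standalone check of that parsing, using illustrative layout strings (the `T(8,128)` tiling suffix is just an example of what XLA may append):

```python
import re

def minor_to_major(layout_str: str):
  # Same regex as SpecifiedLayout._minor_to_major: grab the digits and commas
  # between "{" and ":" and turn them into a tuple of ints.
  m = re.search("{([0-9,]*):", layout_str)
  assert m is not None
  m2m_str = m.group(1)
  if m2m_str == "":
    return ()
  return tuple(int(x) for x in m2m_str.split(","))

print(minor_to_major("{1,0:T(8,128)}"))  # (1, 0): row-major for a 2-D array
print(minor_to_major("{0,1:T(8,128)}"))  # (0, 1): column-major
print(minor_to_major("{:T(256)}"))       # (): a rank-0 layout
```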

jax/_src/pjit.py

Lines changed: 3 additions & 26 deletions
```diff
@@ -435,7 +435,6 @@ def lower(*args, **kwargs):
     try:
       in_shardings = _resolve_in_shardings(
           args_flat, params['in_shardings'], params['out_shardings'], mesh)
-      in_layouts_flat = _resolve_in_layouts(args_flat, in_layouts_flat)
       lowering = _pjit_lower(
           params['jaxpr'], in_shardings, params['out_shardings'],
           params['resource_env'], params['donated_invars'], params['name'],
@@ -1131,6 +1130,7 @@ def unpack(key):
     p("explanation unavailable! please open an issue at https://github.com/google/jax")
   return done()
 
+
 @partial(lu.cache, explain=explain_tracing_cache_miss)
 def _create_pjit_jaxpr(fun, in_type, debug_info, out_paths, ignored_inline):
   del ignored_inline  # just for explain_cache_miss
@@ -1264,28 +1264,6 @@ def pjit_check_aval_sharding(
 pjit_p.multiple_results = True
 
 
-def _resolve_in_layouts(args, jit_in_layouts):
-  resolved_in_layouts = []
-  for arg, jit_in_l in safe_zip(args, jit_in_layouts):
-    arg_layout, committed = (
-        (arg.layout, getattr(arg, '_committed', True))
-        if getattr(arg, 'layout', None) is not None else (None, False))
-    if jit_in_l is None:
-      if committed:
-        resolved_in_layouts.append(arg_layout)
-      else:
-        resolved_in_layouts.append(None)
-    else:
-      if committed and arg_layout != jit_in_l:
-        raise ValueError('Layout passed to jit does not match the layout '
-                         'on the respective arg. '
-                         f'Got pjit layout: {jit_in_l},\n'
-                         f'arg sharding: {arg_layout} for '
-                         f'arg shape: {shaped_abstractify(arg).str_short()}')
-      resolved_in_layouts.append(jit_in_l)
-  return tuple(resolved_in_layouts)
-
-
 def _resolve_in_shardings(
     args, pjit_in_shardings: Sequence[PjitSharding],
     out_shardings: Sequence[PjitSharding],
@@ -1409,9 +1387,8 @@ def _pjit_call_impl_python(
   _most_recent_pjit_call_executable.weak_key_dict[jaxpr] = compiled
   # This check is expensive so only do it if enable_checks is on.
   if compiled._auto_spmd_lowering and config.enable_checks.value:
-    pxla.check_array_xla_sharding_layout_match(
-        args, compiled._in_shardings, compiled._in_layouts,
-        jaxpr.jaxpr.debug_info)
+    pxla.check_gda_or_array_xla_sharding_match(args, compiled._in_shardings,
+                                               jaxpr.jaxpr.debug_info)
   if config.distributed_debug.value:
     # Defensively only perform fingerprint logic if debug logging is enabled
     # NOTE(skyewm): I didn't benchmark this
```
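The deleted `_resolve_in_layouts` merged each argument's own layout with any layout passed to `jit`. A hedged re-sketch of that rule over plain values, so the removed behavior is easy to see (the function and argument names here are illustrative, not the real pjit internals):

```python
def resolve_in_layouts(args, jit_in_layouts):
  # args: (layout, committed) pairs standing in for jax.Array arguments.
  resolved = []
  for (arg_layout, committed), jit_l in zip(args, jit_in_layouts):
    if jit_l is None:
      # No layout given to jit: a committed arg contributes its own layout.
      resolved.append(arg_layout if committed else None)
    else:
      # jit specified a layout: it must agree with a committed arg's layout.
      if committed and arg_layout != jit_l:
        raise ValueError(f"Layout passed to jit ({jit_l}) does not match "
                         f"the layout on the respective arg ({arg_layout}).")
      resolved.append(jit_l)
  return tuple(resolved)

print(resolve_in_layouts([("{1,0}", True)], [None]))     # ('{1,0}',)
print(resolve_in_layouts([("{1,0}", False)], [None]))    # (None,)
print(resolve_in_layouts([("{1,0}", True)], ["{1,0}"]))  # ('{1,0}',)
```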
