Skip to content

Commit 92326db

Browse files
yashk2810jax authors
authored andcommitted
Expose Layout(device_local_layout, sharding) class allowing users to specify layouts of Arrays.
Users should be able to load checkpoints with the layout that the `train_step` specifies via device_put. Note: This currently only works on TPU. PiperOrigin-RevId: 621668247
1 parent 24517ca commit 92326db

File tree

10 files changed

+168
-37
lines changed

10 files changed

+168
-37
lines changed

jax/BUILD

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -722,7 +722,11 @@ pytype_strict_library(
722722
pytype_strict_library(
723723
name = "layout",
724724
srcs = ["_src/layout.py"],
725-
deps = ["//jax/_src/lib"],
725+
deps = [
726+
":sharding",
727+
":sharding_impls",
728+
"//jax/_src/lib",
729+
],
726730
)
727731

728732
pytype_strict_library(

jax/_src/api.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
from jax._src.sharding import Sharding
7171
from jax._src.sharding_impls import (PmapSharding, TransferToMemoryKind,
7272
XLACompatibleSharding)
73+
from jax._src.layout import Layout
7374
from jax._src.traceback_util import api_boundary
7475
from jax._src import tree_util
7576
from jax._src.util import unzip2, safe_map, safe_zip, wrap_name, wraps
@@ -2461,8 +2462,8 @@ def _check_sharding(x, s):
24612462

24622463
def device_put(
24632464
x,
2464-
device: None | xc.Device | Sharding | Any | TransferToMemoryKind = None,
2465-
*, src: None | xc.Device | Sharding | Any | TransferToMemoryKind = None):
2465+
device: None | xc.Device | Sharding | Layout | Any | TransferToMemoryKind = None,
2466+
*, src: None | xc.Device | Sharding | Layout | Any | TransferToMemoryKind = None):
24662467
"""Transfers ``x`` to ``device``.
24672468
24682469
Args:

jax/_src/array.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -531,13 +531,13 @@ def addressable_shards(self) -> Sequence[Shard]:
531531

532532
@property
533533
def layout(self):
534-
# TODO(yashkatariya): Remove the try;except when pathways supports layouts.
535534
try:
536-
return layout.DeviceLocalLayout(self._pjrt_layout)
535+
return layout.Layout(layout.DeviceLocalLayout(self._pjrt_layout),
536+
self.sharding)
537537
except xe.XlaRuntimeError as e:
538538
msg, *_ = e.args
539539
if type(msg) is str and msg.startswith("UNIMPLEMENTED"):
540-
return None
540+
return layout.Layout(None, self.sharding)
541541
else:
542542
raise
543543

jax/_src/dispatch.py

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
from jax._src.sharding_impls import (
5252
PmapSharding, SingleDeviceSharding, NamedSharding, XLACompatibleSharding,
5353
GSPMDSharding, TransferToMemoryKind)
54+
from jax._src.layout import Layout, DeviceLocalLayout
5455

5556

5657
JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration"
@@ -380,25 +381,9 @@ def _mcjax_reshard(x, target_sharding):
380381
pxla._get_and_check_device_assignment.fn = _orig_get_and_check_device_assignment
381382

382383

383-
def _device_put_impl(
384-
x,
385-
device: Device | Sharding | None = None,
386-
src: Device | Sharding | None = None):
384+
def _device_put_sharding_impl(x, aval, device):
387385
from jax._src import array
388386

389-
if (isinstance(device, TransferToMemoryKind) or
390-
isinstance(src, TransferToMemoryKind)):
391-
raise ValueError(
392-
"TransferToMemoryKind argument to jax.device_put can only be used"
393-
" inside jax.jit. If you are using device_put outside jax.jit, then"
394-
" please provide a concrete Sharding with memory_kind.")
395-
396-
try:
397-
aval = xla.abstractify(x)
398-
except TypeError as err:
399-
raise TypeError(
400-
f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err
401-
402387
if isinstance(device, Sharding):
403388
s = device
404389
if getattr(x, 'sharding', None) == s and getattr(x, '_committed', False):
@@ -435,6 +420,43 @@ def _device_put_impl(
435420
if device is None else device)
436421
return _put_x(x, sh, aval, device is not None)
437422

423+
def _device_put_impl(
424+
x,
425+
device: Device | Sharding | Layout | None = None,
426+
src: Device | Sharding | Layout | None = None):
427+
if (isinstance(device, TransferToMemoryKind) or
428+
isinstance(src, TransferToMemoryKind)):
429+
raise ValueError(
430+
"TransferToMemoryKind argument to jax.device_put can only be used"
431+
" inside jax.jit. If you are using device_put outside jax.jit, then"
432+
" please provide a concrete Sharding with memory_kind.")
433+
434+
try:
435+
aval = xla.abstractify(x)
436+
except TypeError as err:
437+
raise TypeError(
438+
f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err
439+
440+
if isinstance(device, Layout):
441+
l = device
442+
dll = l.device_local_layout
443+
x_dll = x.layout.device_local_layout if hasattr(x, 'layout') else None
444+
if (not isinstance(l.sharding, Sharding) or
445+
not isinstance(dll, (DeviceLocalLayout, type(None)))):
446+
raise ValueError(
447+
"sharding and device_local_layout in `Layout` instance should be"
448+
f" concrete. Got layout: {l}")
449+
if getattr(x, 'layout', None) == l and getattr(x, '_committed', False):
450+
return x
451+
if x_dll is None and dll is None:
452+
return _device_put_sharding_impl(x, aval, l.sharding)
453+
# TODO(yashkatariya): Pass layout to out_shardings directly and remove
454+
# out_layouts from lower.
455+
return api.jit(_identity_fn, out_shardings=l.sharding).lower(
456+
x, _out_layouts=dll).compile()(x)
457+
458+
return _device_put_sharding_impl(x, aval, device)
459+
438460

439461
device_put_p = core.Primitive('device_put')
440462
device_put_p.def_impl(_device_put_impl)

jax/_src/interpreters/pxla.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
from jax._src.interpreters import partial_eval as pe
6161
from jax._src.interpreters import mlir
6262
from jax._src.interpreters import xla
63-
from jax._src.layout import DeviceLocalLayout, AutoLayout
63+
from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout
6464
from jax._src.lib import xla_client as xc
6565
from jax._src.lib import xla_extension_version
6666
from jax._src.lib.mlir import ir
@@ -2624,7 +2624,8 @@ def _get_layouts_from_executable(
26242624
if isinstance(i, DeviceLocalLayout):
26252625
if i != x:
26262626
raise AssertionError(
2627-
f"Unexpected XLA layout override: (XLA) {x} != {i} (User layout)")
2627+
f"Unexpected XLA layout override: (XLA) {x} != {i} (User input"
2628+
" layout)")
26282629
new_in_layouts.append(i)
26292630
else:
26302631
new_in_layouts.append(x)
@@ -2635,7 +2636,8 @@ def _get_layouts_from_executable(
26352636
if isinstance(o, DeviceLocalLayout):
26362637
if o != x:
26372638
raise AssertionError(
2638-
f"Unexpected XLA layout override: (XLA) {x} != {o} (User layout)")
2639+
f"Unexpected XLA layout override: (XLA) {x} != {o} (User output"
2640+
" layout)")
26392641
new_out_layouts.append(o)
26402642
else:
26412643
new_out_layouts.append(x)
@@ -3072,10 +3074,12 @@ def output_shardings(self) -> Sequence[sharding_impls.XLACompatibleSharding]:
30723074
return self._out_shardings
30733075

30743076
def input_layouts(self):
3075-
return self._in_layouts
3077+
return [Layout(l, s)
3078+
for l, s in safe_zip(self._in_layouts, self._in_shardings)]
30763079

30773080
def output_layouts(self):
3078-
return self._out_layouts
3081+
return [Layout(l, s)
3082+
for l, s in safe_zip(self._out_layouts, self._out_shardings)]
30793083

30803084
def create_cpp_call(self, no_kwargs, in_tree, out_tree):
30813085
if not (isinstance(self.unsafe_call, ExecuteReplicated) and
@@ -3254,7 +3258,8 @@ def check_array_xla_sharding_layout_match(
32543258
'sharding'))
32553259

32563260
if (xla_extension_version >= 249 and not db_xs and arg._committed and
3257-
arg.layout is not None and xl is not None and arg.layout != xl):
3261+
arg.layout.device_local_layout is not None and xl is not None and
3262+
arg.layout.device_local_layout != xl):
32583263
errors.append(
32593264
("Got input layout(s) that compiled object was called with: "
32603265
f"{arg.layout} and layout(s) the computation was compiled "

jax/_src/layout.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@
1414

1515
from __future__ import annotations
1616

17+
from typing import Union
18+
19+
from jax._src.sharding import Sharding
20+
from jax._src.sharding_impls import AUTO as AutoSharding, is_auto
1721
from jax._src.lib import xla_client as xc
1822

1923

@@ -45,3 +49,39 @@ def __repr__(self):
4549
return "AUTO"
4650

4751
AUTO = AutoLayout()
52+
53+
54+
LayoutOptions = Union[DeviceLocalLayout, None, AutoLayout]
55+
ShardingOptions = Union[Sharding, None, AutoSharding]
56+
57+
58+
class Layout:
59+
__slots__ = ['device_local_layout', 'sharding']
60+
61+
def __init__(self, device_local_layout: LayoutOptions,
62+
sharding: ShardingOptions):
63+
# If layout is concrete and sharding is not, error.
64+
if (isinstance(device_local_layout, DeviceLocalLayout) and
65+
(sharding is None or is_auto(sharding))):
66+
raise ValueError(
67+
'Sharding has to be concrete when layout is of type'
68+
f' {type(device_local_layout)}. Please pass a'
69+
' `jax.sharding.NamedSharding`, `jax.sharding.PositionalSharding` or'
70+
' `jax.sharding.SingleDeviceSharding` to the sharding argument. Got'
71+
f' sharding {sharding}'
72+
)
73+
self.device_local_layout = device_local_layout
74+
self.sharding = sharding
75+
76+
def __repr__(self):
77+
return (f'Layout(device_local_layout={self.device_local_layout},'
78+
f' sharding={self.sharding})')
79+
80+
def __hash__(self):
81+
return hash((self.device_local_layout, self.sharding))
82+
83+
def __eq__(self, other):
84+
if not isinstance(other, Layout):
85+
return False
86+
return (self.device_local_layout == other.device_local_layout and
87+
self.sharding == other.sharding)

jax/_src/pjit.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
SingleDeviceSharding, PmapSharding, AUTO, UNSPECIFIED, UnspecifiedValue,
6868
ParsedPartitionSpec, SpecSync, get_single_pspec, is_auto, is_unspecified,
6969
is_unspecified_or_auto, prepare_axis_resources, parse_flatten_op_sharding)
70+
from jax._src.layout import Layout, LayoutOptions
7071
from jax._src.state import discharge as state_discharge, RefEffect
7172
from jax._src.traceback_util import api_boundary
7273
from jax._src.tree_util import (
@@ -437,6 +438,7 @@ def lower(*args, **kwargs):
437438
args_flat, params['in_shardings'], params['out_shardings'], mesh)
438439
in_layouts_flat = _resolve_in_layouts(
439440
args_flat, in_layouts_flat, in_shardings)
441+
out_layouts_flat = _resolve_out_layouts(out_layouts_flat)
440442
lowering = _pjit_lower(
441443
params['jaxpr'], in_shardings, params['out_shardings'],
442444
params['resource_env'], params['donated_invars'], params['name'],
@@ -1268,8 +1270,10 @@ def _resolve_in_layouts(args, jit_in_layouts, jit_in_shardings):
12681270
resolved_in_layouts = []
12691271
for arg, jit_in_l in safe_zip(args, jit_in_layouts):
12701272
arg_layout, committed = (
1271-
(arg.layout, getattr(arg, '_committed', True))
1273+
(arg.layout.device_local_layout, getattr(arg, '_committed', True))
12721274
if getattr(arg, 'layout', None) is not None else (None, False))
1275+
jit_in_l = (jit_in_l.device_local_layout
1276+
if isinstance(jit_in_l, Layout) else jit_in_l)
12731277
if jit_in_l is None:
12741278
if committed:
12751279
resolved_in_layouts.append(arg_layout)
@@ -1286,6 +1290,14 @@ def _resolve_in_layouts(args, jit_in_layouts, jit_in_shardings):
12861290
return tuple(resolved_in_layouts)
12871291

12881292

1293+
def _resolve_out_layouts(out_layouts: Sequence[Layout]
1294+
) -> Sequence[LayoutOptions]:
1295+
# TODO(yashkatariya): Remove the if condition when all layouts come via the
1296+
# `layout.Layout` API.
1297+
return tuple(o.device_local_layout if isinstance(o, Layout) else o
1298+
for o in out_layouts)
1299+
1300+
12891301
def _resolve_in_shardings(
12901302
args, pjit_in_shardings: Sequence[PjitSharding],
12911303
out_shardings: Sequence[PjitSharding],

jax/_src/stages.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
from jax._src import tree_util
4545
from jax._src.tree_util import tree_unflatten, keystr
4646
from jax._src import util
47-
from jax._src.layout import DeviceLocalLayout
47+
from jax._src.layout import Layout
4848
from jax._src.interpreters import mlir
4949
from jax._src.lib.mlir import ir
5050
from jax._src.lib import xla_client as xc
@@ -513,7 +513,7 @@ def output_shardings(self): # PyTree[sharding.XLACompatibleSharding]
513513

514514
def _input_layouts(self):
515515
layouts_flat = self._executable.input_layouts()
516-
assert all(isinstance(l, DeviceLocalLayout) for l in layouts_flat)
516+
assert all(isinstance(l, Layout) for l in layouts_flat)
517517
# Some input layouts got DCE'd
518518
if self.in_tree.num_leaves > len(layouts_flat):
519519
iter_layouts_flat = iter(layouts_flat)
@@ -523,7 +523,7 @@ def _input_layouts(self):
523523

524524
def _output_layouts(self):
525525
layouts_flat = self._executable.output_layouts()
526-
assert all(isinstance(l, DeviceLocalLayout) for l in layouts_flat)
526+
assert all(isinstance(l, Layout) for l in layouts_flat)
527527
return tree_util.tree_unflatten(self.out_tree, layouts_flat) # pytype: disable=attribute-error
528528

529529
@staticmethod

jax/experimental/layout.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,5 @@
1515
from jax._src.layout import (
1616
DeviceLocalLayout as DeviceLocalLayout,
1717
AUTO as AUTO,
18+
Layout as Layout
1819
)

tests/layout_test.py

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@
2020

2121
import jax
2222
import jax.numpy as jnp
23-
from jax.sharding import NamedSharding, PartitionSpec as P
23+
from jax.sharding import NamedSharding, PartitionSpec as P, SingleDeviceSharding
2424
from jax._src import config
2525
from jax._src import layout
26+
from jax._src.layout import Layout
2627
from jax._src import test_util as jtu
2728
from jax._src.util import safe_zip
2829
from jax._src import xla_bridge
@@ -115,7 +116,7 @@ def init(x, y):
115116
self.assertEqual(init_count[0], 1)
116117

117118
self.assertEqual(init_out[0].layout, init_compiled._output_layouts()[0])
118-
self.assertEqual(init_out[1].layout, init_compiled._output_layouts()[0])
119+
self.assertEqual(init_out[1].layout, init_compiled._output_layouts()[1])
119120

120121
with jtu.count_aot_jit_cpp_cache_miss() as apply_count:
121122
apply_out = compiled_apply(*init_out)
@@ -223,8 +224,10 @@ def f(x, y, z, a, b, c):
223224
self.assertArraysEqual(out1, out3)
224225
self.assertArraysEqual(out2, out4)
225226

226-
# TODO(yashkatariya, frostig): Also use the arg_layouts to create an Array
227-
# and then pass that back into compiled.
227+
arrs = [jax.device_put(i, l) for i, l in zip(inps, arg_layouts)]
228+
out5, out6 = jax.jit(f)(*arrs)
229+
self.assertArraysEqual(out1, out5)
230+
self.assertArraysEqual(out2, out6)
228231

229232
def test_aot_layout_mismatch(self):
230233
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
@@ -259,6 +262,49 @@ def test_cpu_default_backend_layout(self):
259262
jax.jit(jnp.dot, backend=jax.default_backend()).lower(
260263
out_cpu, out_cpu).compile() # doesn't crash
261264

265+
def test_device_put_concrete_layout(self):
266+
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
267+
shape = (8, 128)
268+
np_inp = np.arange(math.prod(shape)).reshape(shape)
269+
s = NamedSharding(mesh, P('x', 'y'))
270+
arr = jax.device_put(np_inp, s)
271+
272+
compiled = jax.jit(
273+
lambda x: x * 2).lower(arr, _out_layouts=layout.AUTO).compile()
274+
col = compiled._output_layouts()
275+
276+
out = jax.device_put(np_inp, col)
277+
self.assertEqual(out.layout, col)
278+
self.assertArraysEqual(out, np_inp)
279+
for s in out.addressable_shards:
280+
self.assertEqual(out.layout.device_local_layout,
281+
s.data.layout.device_local_layout)
282+
283+
def test_device_put_non_concrete_layout_error(self):
284+
np_inp = np.arange(16).reshape(8, 2)
285+
286+
l1 = Layout(layout.AUTO, SingleDeviceSharding(jax.devices()[0]))
287+
with self.assertRaisesRegex(
288+
ValueError, 'sharding and device_local_layout.*should be concrete'):
289+
jax.device_put(np_inp, l1)
290+
291+
l2 = Layout(layout.AUTO, None)
292+
with self.assertRaisesRegex(
293+
ValueError, 'sharding and device_local_layout.*should be concrete'):
294+
jax.device_put(np_inp, l2)
295+
296+
l3 = Layout(None, SingleDeviceSharding(jax.devices()[0]))
297+
out = jax.device_put(np_inp, l3)
298+
self.assertArraysEqual(out, np_inp)
299+
self.assertTrue(out._committed)
300+
301+
def invalid_layout_spec(self):
302+
x = np.arange(8)
303+
compiled = jax.jit(lambda x: x).lower(x).compile()
304+
with self.assertRaisesRegex(
305+
ValueError, 'Sharding has to be concrete when layout.*'):
306+
Layout(compiled._output_layouts()[0], None)
307+
262308

263309
if __name__ == '__main__':
264310
absltest.main(testLoader=jtu.JaxTestLoader())

0 commit comments

Comments
 (0)