Skip to content

Commit 5cbb26f

Browse files
yashk2810 authored and jax authors
committed
Make device_local_layout and sharding optional in Layout. Also only accept Layout class to _in_layouts and _out_layouts.
This is in preparation to get `jax.jit` to accept `Layout`.

PiperOrigin-RevId: 621697750
1 parent d790c88 commit 5cbb26f

File tree

5 files changed

+38
-23
lines changed

5 files changed

+38
-23
lines changed

jax/_src/dispatch.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -441,19 +441,21 @@ def _device_put_impl(
441441
l = device
442442
dll = l.device_local_layout
443443
x_dll = x.layout.device_local_layout if hasattr(x, 'layout') else None
444+
if dll is None and l.sharding is None:
445+
return _device_put_sharding_impl(x, aval, l.sharding)
444446
if (not isinstance(l.sharding, Sharding) or
445447
not isinstance(dll, (DeviceLocalLayout, type(None)))):
446448
raise ValueError(
447449
"sharding and device_local_layout in `Layout` instance should be"
448-
f" concrete. Got layout: {l}")
450+
f" concrete. Got layout: {l} for input {aval.str_short()}")
449451
if getattr(x, 'layout', None) == l and getattr(x, '_committed', False):
450452
return x
451453
if x_dll is None and dll is None:
452454
return _device_put_sharding_impl(x, aval, l.sharding)
453455
# TODO(yashkatariya): Pass layout to out_shardings directly and remove
454456
# out_layouts from lower.
455457
return api.jit(_identity_fn, out_shardings=l.sharding).lower(
456-
x, _out_layouts=dll).compile()(x)
458+
x, _out_layouts=l).compile()(x)
457459

458460
return _device_put_sharding_impl(x, aval, device)
459461

jax/_src/layout.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ def _to_xla_layout(self) -> str:
5858
class Layout:
5959
__slots__ = ['device_local_layout', 'sharding']
6060

61-
def __init__(self, device_local_layout: LayoutOptions,
62-
sharding: ShardingOptions):
61+
def __init__(self, device_local_layout: LayoutOptions = None,
62+
sharding: ShardingOptions = None):
6363
# If layout is concrete and sharding is not, error.
6464
if (isinstance(device_local_layout, DeviceLocalLayout) and
6565
(sharding is None or is_auto(sharding))):
@@ -70,6 +70,19 @@ def __init__(self, device_local_layout: LayoutOptions,
7070
' `jax.sharding.SingleDeviceSharding` to the sharding argument. Got'
7171
f' sharding {sharding}'
7272
)
73+
if not isinstance(
74+
device_local_layout, (DeviceLocalLayout, type(None), AutoLayout)):
75+
raise ValueError(
76+
'Invalid value received for the device_local_layout argument.'
77+
' Expected values are `None`, `DeviceLocalLayout.AUTO` or an instance'
78+
f' of `DeviceLocalLayout`. Got {device_local_layout}')
79+
if not isinstance(
80+
sharding, (Sharding, type(None), AutoSharding)):
81+
raise ValueError(
82+
'Invalid value received for the sharding argument. Expected values'
83+
' are `None`, `pjit.AUTO` or an instance of `jax.Sharding`. Got'
84+
f' {sharding}')
85+
7386
self.device_local_layout = device_local_layout
7487
self.sharding = sharding
7588

jax/_src/pjit.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -425,8 +425,8 @@ def lower(*args, **kwargs):
425425
lowering_parameters = kwargs.pop(
426426
'_experimental_lowering_parameters', mlir.LoweringParameters())
427427
# TODO(yashkatariya): Remove this when it's added on jit.
428-
in_layouts = kwargs.pop('_in_layouts', None)
429-
out_layouts = kwargs.pop('_out_layouts', None)
428+
in_layouts = kwargs.pop('_in_layouts', Layout())
429+
out_layouts = kwargs.pop('_out_layouts', Layout())
430430
(args_flat, flat_global_in_avals, params, in_tree, out_tree,
431431
donated_invars, in_layouts_flat, out_layouts_flat,
432432
arg_names, ()) = _infer_params(
@@ -1272,8 +1272,7 @@ def _resolve_in_layouts(args, jit_in_layouts, jit_in_shardings):
12721272
arg_layout, committed = (
12731273
(arg.layout.device_local_layout, getattr(arg, '_committed', True))
12741274
if getattr(arg, 'layout', None) is not None else (None, False))
1275-
jit_in_l = (jit_in_l.device_local_layout
1276-
if isinstance(jit_in_l, Layout) else jit_in_l)
1275+
jit_in_l = None if jit_in_l is None else jit_in_l.device_local_layout
12771276
if jit_in_l is None:
12781277
if committed:
12791278
resolved_in_layouts.append(arg_layout)
@@ -1293,9 +1292,8 @@ def _resolve_in_layouts(args, jit_in_layouts, jit_in_shardings):
12931292
def _resolve_out_layouts(out_layouts: Sequence[Layout]
12941293
) -> Sequence[LayoutOptions]:
12951294
# TODO(yashkatariya): Remove the if condition when all layouts come via the
1296-
# `layout.Layout` API.
1297-
return tuple(o.device_local_layout if isinstance(o, Layout) else o
1298-
for o in out_layouts)
1295+
# `layout.Layout` API or handle this properly when layout is on jit.
1296+
return tuple(None if o is None else o.device_local_layout for o in out_layouts)
12991297

13001298

13011299
def _resolve_in_shardings(

jax/_src/stages.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -518,7 +518,7 @@ def _input_layouts(self):
518518
if self.in_tree.num_leaves > len(layouts_flat):
519519
iter_layouts_flat = iter(layouts_flat)
520520
layouts_flat = [next(iter_layouts_flat) if i in self._executable._kept_var_idx
521-
else None for i in range(self.in_tree.num_leaves)]
521+
else Layout() for i in range(self.in_tree.num_leaves)]
522522
return tree_util.tree_unflatten(self.in_tree, layouts_flat) # pytype: disable=attribute-error
523523

524524
def _output_layouts(self):

tests/layout_test.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def init(x, y):
8989
sds2 = jax.ShapeDtypeStruct(np_inp2.shape, np_inp2.dtype, sharding=s2)
9090

9191
lowered_apply = jax.jit(apply).lower(
92-
sds1, sds2, _in_layouts=DLL.AUTO, _out_layouts=DLL.AUTO)
92+
sds1, sds2, _in_layouts=Layout(DLL.AUTO), _out_layouts=Layout(DLL.AUTO))
9393
compiled_apply = lowered_apply.compile()
9494

9595
arg_layouts, kw_layouts = compiled_apply._input_layouts()
@@ -158,8 +158,8 @@ def f(x):
158158
self.assertArraysEqual(out, np_inp.T)
159159
self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'y', 'x')))
160160

161-
compiled_auto = jax.jit(f).lower(sds, _in_layouts=DLL.AUTO,
162-
_out_layouts=DLL.AUTO).compile()
161+
compiled_auto = jax.jit(f).lower(sds, _in_layouts=Layout(DLL.AUTO),
162+
_out_layouts=Layout(DLL.AUTO)).compile()
163163
self.assertTupleEqual(
164164
extract_minor_to_major(compiled_auto._input_layouts()[0][0]), (2, 1, 0))
165165
self.assertTupleEqual(
@@ -176,7 +176,7 @@ def f(x):
176176
return x.T
177177

178178
compiled = jax.jit(f).lower(
179-
arr, _in_layouts=None, _out_layouts=DLL.AUTO).compile()
179+
arr, _in_layouts=Layout(), _out_layouts=Layout(DLL.AUTO)).compile()
180180
self.assertTupleEqual(
181181
extract_minor_to_major(compiled._input_layouts()[0][0]), (1, 0))
182182
self.assertTupleEqual(
@@ -194,7 +194,8 @@ def test_sharding_and_layouts(self):
194194
s = NamedSharding(mesh, P('x', 'y'))
195195

196196
compiled = jax.jit(lambda x: x.T, in_shardings=s, out_shardings=s).lower(
197-
np_inp, _in_layouts=DLL.AUTO, _out_layouts=DLL.AUTO).compile()
197+
np_inp, _in_layouts=Layout(DLL.AUTO),
198+
_out_layouts=Layout(DLL.AUTO)).compile()
198199
out = compiled(np_inp)
199200
self.assertTupleEqual(
200201
extract_minor_to_major(compiled._input_layouts()[0][0]), (1, 0))
@@ -209,8 +210,8 @@ def f(x, y, z, a, b, c):
209210

210211
shape = (8, 2)
211212
inps = [np.arange(math.prod(shape)).reshape(shape)] * 6
212-
compiled = jax.jit(f).lower(*inps, _in_layouts=DLL.AUTO,
213-
_out_layouts=DLL.AUTO).compile()
213+
compiled = jax.jit(f).lower(*inps, _in_layouts=Layout(DLL.AUTO),
214+
_out_layouts=Layout(DLL.AUTO)).compile()
214215
arg_layouts, _ = compiled._input_layouts()
215216
out1, out2 = compiled(*inps)
216217

@@ -243,10 +244,11 @@ def f(x):
243244
with self.assertRaisesRegex(
244245
ValueError,
245246
'Layout passed to jit does not match the layout on the respective arg'):
246-
jax.jit(f).lower(arr, _in_layouts=DLL.AUTO)
247+
jax.jit(f).lower(arr, _in_layouts=Layout(DLL.AUTO))
247248

248249
compiled = jax.jit(f).lower(
249-
sds, _in_layouts=DLL.AUTO, _out_layouts=DLL.AUTO).compile()
250+
sds, _in_layouts=Layout(DLL.AUTO),
251+
_out_layouts=Layout(DLL.AUTO)).compile()
250252

251253
with self.assertRaisesRegex(
252254
ValueError,
@@ -269,7 +271,7 @@ def test_device_put_concrete_layout(self):
269271
arr = jax.device_put(np_inp, s)
270272

271273
compiled = jax.jit(
272-
lambda x: x * 2).lower(arr, _out_layouts=DLL.AUTO).compile()
274+
lambda x: x * 2).lower(arr, _out_layouts=Layout(DLL.AUTO)).compile()
273275
col = compiled._output_layouts()
274276

275277
out = jax.device_put(np_inp, col)
@@ -287,7 +289,7 @@ def test_device_put_non_concrete_layout_error(self):
287289
ValueError, 'sharding and device_local_layout.*should be concrete'):
288290
jax.device_put(np_inp, l1)
289291

290-
l2 = Layout(DLL.AUTO, None)
292+
l2 = Layout(DLL.AUTO)
291293
with self.assertRaisesRegex(
292294
ValueError, 'sharding and device_local_layout.*should be concrete'):
293295
jax.device_put(np_inp, l2)

0 commit comments

Comments
 (0)