Skip to content

Commit 90401d5

Browse files
yashk2810jax authors
authored andcommitted
Accept layout on ShapeDtypeStruct on the sharding argument. DeviceLocalLayout.AUTO is not allowed on SDS.
PiperOrigin-RevId: 624982814
1 parent b729300 commit 90401d5

File tree

3 files changed

+45
-10
lines changed

3 files changed

+45
-10
lines changed

jax/_src/api.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@
7070
from jax._src.sharding import Sharding
7171
from jax._src.sharding_impls import (PmapSharding, TransferToMemoryKind,
7272
XLACompatibleSharding)
73-
from jax._src.layout import Layout
73+
from jax._src.layout import Layout, AutoLayout
7474
from jax._src.traceback_util import api_boundary
7575
from jax._src import tree_util
7676
from jax._src.util import unzip2, safe_map, safe_zip, wrap_name, wraps
@@ -2710,22 +2710,34 @@ class ShapeDtypeStruct:
27102710
named_shape: (optional) a dictionary representing a named shape
27112711
sharding: (optional) a :class:`jax.Sharding` object
27122712
"""
2713-
__slots__ = ["shape", "dtype", "named_shape", "sharding"]
2713+
__slots__ = ["shape", "dtype", "named_shape", "sharding", "_dll"]
2714+
27142715
def __init__(self, shape, dtype, named_shape=None, sharding=None):
27152716
self.shape = tuple(shape)
27162717
if dtype is None:
27172718
raise ValueError("ShapeDtypeStruct: dtype must be specified.")
27182719
self.dtype = dtype if dtypes.issubdtype(dtype, dtypes.extended) else np.dtype(dtype)
2719-
if sharding is not None and not isinstance(sharding, Sharding):
2720+
if sharding is not None and not isinstance(sharding, (Sharding, Layout)):
27202721
raise ValueError(
2721-
"sharding should be an instance of `jax.sharding.Sharding`. "
2722-
f"Got {sharding} of type {type(sharding)}.")
2723-
self.sharding = sharding
2722+
"sharding should be an instance of `jax.sharding.Sharding` or"
2723+
f" `jax.experimental.layout.Layout`. Got {sharding} of type"
2724+
f" {type(sharding)}.")
2725+
if (isinstance(sharding, Layout) and
2726+
isinstance(sharding.device_local_layout, AutoLayout)):
2727+
raise TypeError(
2728+
"`DeviceLocalLayout.AUTO` cannot be used in place of a device-local"
2729+
f" layout in a `ShapeDtypeStruct`. Got {sharding}")
2730+
self.sharding = sharding.sharding if isinstance(sharding, Layout) else sharding
2731+
self._dll = sharding.device_local_layout if isinstance(sharding, Layout) else None
27242732
self.named_shape = {} if named_shape is None else dict(named_shape)
27252733

27262734
size = property(lambda self: math.prod(self.shape))
27272735
ndim = property(lambda self: len(self.shape))
27282736

2737+
@property
2738+
def layout(self):
2739+
return Layout(self._dll, self.sharding)
2740+
27292741
def __len__(self):
27302742
try:
27312743
return self.shape[0]
@@ -2735,28 +2747,31 @@ def __len__(self):
27352747
def __repr__(self):
27362748
ns = f", named_shape={self.named_shape}" if self.named_shape else ""
27372749
sh = f", sharding={self.sharding}" if self.sharding is not None else ""
2750+
l = f", layout={self.layout}" if self._dll is not None else ""
27382751
return (f"{type(self).__name__}(shape={self.shape}, "
2739-
f"dtype={self.dtype.name}{ns}{sh})")
2752+
f"dtype={self.dtype.name}{ns}{sh}{l})")
27402753

27412754
__str__ = __repr__
27422755

27432756
def __eq__(self, other):
27442757
if not isinstance(other, ShapeDtypeStruct):
27452758
return False
27462759
else:
2747-
return ((other.shape, other.dtype, other.named_shape, other.sharding) ==
2748-
(self.shape, self.dtype, self.named_shape, self.sharding))
2760+
return ((other.shape, other.dtype, other.named_shape, other.sharding, other.layout) ==
2761+
(self.shape, self.dtype, self.named_shape, self.sharding, self.layout))
27492762

27502763
def __hash__(self):
27512764
# TODO(frostig): avoid the conversion from dict by addressing
27522765
# https://github.com/google/jax/issues/8182
27532766
named = frozenset(self.named_shape.items())
2754-
return hash((self.shape, self.dtype, named, self.sharding))
2767+
return hash((self.shape, self.dtype, named, self.sharding, self.layout))
2768+
27552769

27562770
core.pytype_aval_mappings[ShapeDtypeStruct] = (
27572771
lambda x: ShapedArray(x.shape, dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True),
27582772
weak_type=False, named_shape=x.named_shape))
27592773

2774+
27602775
@api_boundary
27612776
def eval_shape(fun: Callable, *args, **kwargs):
27622777
"""Compute the shape/dtype of ``fun`` without any FLOPs.

jax/_src/pjit.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,7 @@ def lower(*args, **kwargs):
485485
def eval_shape(*args, **kwargs):
486486
_, _, params, _, out_tree, _, _, _ = _infer_params(jit_info, args, kwargs)
487487
out_s = [None if is_unspecified(s) else s for s in params['out_shardings']]
488+
# TODO(yashkatariya): Add `Layout` to SDS.
488489
out = [api.ShapeDtypeStruct(x.shape, x.dtype, x.named_shape, sharding=s)
489490
for x, s in zip(params['jaxpr'].out_avals, out_s)]
490491
return tree_unflatten(out_tree, out)

tests/layout_test.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,25 @@ def invalid_layout_spec(self):
325325
ValueError, 'Sharding has to be concrete when layout.*'):
326326
Layout(compiled.output_layouts()[0], None)
327327

328+
def test_layout_on_sds(self):
329+
mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
330+
s = NamedSharding(mesh, P('x', 'y'))
331+
np_inp = np.arange(16).reshape(8, 2)
332+
arr = jax.device_put(np_inp, s)
333+
334+
out_layout = jax.jit(jnp.sin, out_shardings=Layout(DLL.AUTO)).lower(
335+
arr).compile().output_layouts()
336+
337+
sds = jax.ShapeDtypeStruct(arr.shape, arr.dtype, sharding=out_layout)
338+
arg_layout, _ = jax.jit(lambda x: x * 2).lower(sds).compile().input_layouts()
339+
self.assertEqual(arg_layout[0], out_layout)
340+
341+
with self.assertRaisesRegex(
342+
TypeError,
343+
'DeviceLocalLayout.AUTO` cannot be used in place of a device-local'
344+
' layout in a `ShapeDtypeStruct`'):
345+
jax.ShapeDtypeStruct(arr.shape, arr.dtype, sharding=Layout(DLL.AUTO))
346+
328347

329348
if __name__ == '__main__':
330349
absltest.main(testLoader=jtu.JaxTestLoader())

0 commit comments

Comments
 (0)