Skip to content

Commit 5f467b9

Browse files
yashk2810jax authors
authored andcommitted
Propagate sharding of inputs to full_like that are capable of carrying sharding as an attribute.
Fixes #20390 PiperOrigin-RevId: 618202319
1 parent bed4f65 commit 5f467b9

File tree

2 files changed

+17
-8
lines changed

2 files changed

+17
-8
lines changed

jax/_src/lax/lax.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1251,7 +1251,6 @@ def full(shape: Shape, fill_value: ArrayLike, dtype: DTypeLike | None = None, *,
12511251
# if needed?
12521252
if (sharding is not None and not isinstance(sharding, PmapSharding) and
12531253
isinstance(fill_value, array.ArrayImpl)):
1254-
12551254
broadcast_shape = sharding.shard_shape(shape)
12561255
shard = broadcast(fill_value, broadcast_shape)
12571256
return array.make_array_from_callback(shape, sharding, lambda _: shard)
@@ -1415,17 +1414,15 @@ def full_like(x: ArrayLike | DuckTypedArray,
14151414
if dtypes.issubdtype(dtype, dtypes.extended):
14161415
return dtype._rules.full(fill_shape, fill_value, dtype) # type: ignore[union-attr]
14171416

1417+
# If `x` has a sharding but no `_committed` attribute
1418+
# (in case of ShapeDtypeStruct), default it to True.
14181419
use_x_sharding = (
14191420
sharding is None and
1420-
isinstance(x, array.ArrayImpl) and
1421-
not weak_type and x._committed and
1422-
# NB: consider reusng x.sharding for mismatched shapes
1423-
# if x is replicated or single device.
1424-
fill_shape == x.shape)
1421+
hasattr(x, 'sharding') and getattr(x, '_committed', True) and
1422+
not weak_type and fill_shape == x.shape) # type: ignore
14251423
if use_x_sharding:
1426-
assert isinstance(x, array.ArrayImpl) # makes pytype happy.
14271424
# TODO(yashkatariya): Use shard_alike in tracing_mode once it is supported.
1428-
sharding = x.sharding
1425+
sharding = x.sharding # type: ignore
14291426
val = full(fill_shape, _convert_element_type(fill_value, dtype, weak_type),
14301427
sharding=sharding)
14311428
return val

tests/pjit_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1824,6 +1824,18 @@ def f(tree):
18241824
for s in out4.addressable_shards:
18251825
self.assertArraysEqual(s.data, input_data)
18261826

1827+
def test_sds_full_like(self):
1828+
# https://github.com/google/jax/issues/20390
1829+
mesh = jtu.create_global_mesh((2, 1), ('x', 'y'))
1830+
s = NamedSharding(mesh, P('x', 'y'))
1831+
x = jax.ShapeDtypeStruct((4, 4), jnp.float32, sharding=s)
1832+
y = jnp.zeros_like(x)
1833+
z = jnp.zeros_like(x, device=y.sharding)
1834+
1835+
self.assertEqual(x.sharding, s)
1836+
self.assertEqual(y.sharding, s)
1837+
self.assertEqual(z.sharding, s)
1838+
18271839
def test_in_axis_resources_mismatch_error(self):
18281840
global_input_shape = (8, 2)
18291841
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))

0 commit comments

Comments
 (0)