Commit 2c85ca6

yashk2810 authored and jax authors committed
If callback returns a fully replicated global array, return it as is.
Also take the batched_device_put fast path for non-jax.Array values, since slicing can return arrays on multiple devices, which batched_device_put doesn't support.

PiperOrigin-RevId: 624763603
1 parent: 4a6ee78
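
A minimal sketch of the new fast path, mirroring the test added in tests/array_test.py below (it assumes a runtime with 8 devices to form the (4, 2) mesh):

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumption: 8 devices are available, arranged into a (4, 2) mesh.
mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
sharding = NamedSharding(mesh, P())  # P() with no axes => fully replicated
arr = jax.device_put(np.arange(16).reshape(8, 2), sharding)

# The callback returns a fully replicated global array whose device
# assignment matches the requested sharding, so it is now returned as-is
# instead of being sliced apart and re-assembled with device_put.
out = jax.make_array_from_callback(arr.shape, sharding, lambda idx: arr[idx])
```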

3 files changed: +33 lines, −12 lines

jax/_src/array.py (10 additions, 12 deletions)
@@ -695,13 +695,8 @@ def make_array_from_callback(
   >>> arr.addressable_data(0).shape
   (4, 2)
   """
-  has_device_assignment = False
   if sharding.is_fully_replicated:
-    if isinstance(sharding, XLACompatibleSharding):
-      devices = list(sharding._addressable_device_assignment)
-      has_device_assignment = True
-    else:
-      devices = list(sharding.addressable_devices)
+    devices = list(sharding._internal_device_list.addressable_device_list)  # type: ignore
     per_device_values = [data_callback((slice(None),) * len(shape))] * len(devices)
   else:
     device_to_index_map = sharding.addressable_devices_indices_map(shape)
@@ -716,13 +711,11 @@ def make_array_from_callback(
   first_value = xla.canonicalize_dtype(per_device_values[0])
   aval = core.ShapedArray(shape, first_value.dtype, weak_type=False)
 
-  # TODO(yashkatariya): Look into taking this path for non-fully replicated
-  # shardings too.
-  if (sharding.is_fully_replicated and has_device_assignment and
-      not dtypes.issubdtype(aval.dtype, dtypes.extended)):
+  # first value can be numpy array, python scalar, etc.
+  if (sharding.is_fully_replicated and not isinstance(first_value, ArrayImpl)
+      and not dtypes.issubdtype(aval.dtype, dtypes.extended)):
     # Do this check outside because `batched_device_put` won't do these checks
-    # like ArrayImpl. This is a fast path for fully replicated arrays with
-    # xla compatible sharding.
+    # like ArrayImpl.
     if shape != first_value.shape:
       raise ValueError(
           f"Expected shard shape {shape} doesn't match the single device "
@@ -731,6 +724,11 @@ def make_array_from_callback(
     return pxla.batched_device_put(
         aval, sharding, per_device_values, devices, committed=True)
 
+  if (sharding.is_fully_replicated and isinstance(first_value, ArrayImpl) and
+      first_value.is_fully_replicated and
+      first_value.sharding._device_assignment == devices):
+    return first_value
+
   arrays = api.device_put(per_device_values, devices)
   if dtypes.issubdtype(aval.dtype, dtypes.extended):
     return aval.dtype._rules.make_sharded_array(aval, sharding, arrays,
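
Note that the batched_device_put fast path is now taken only for non-jax.Array values (NumPy arrays, Python scalars, etc.), because slicing a jax.Array can yield results spread across multiple devices, which batched_device_put doesn't support. A minimal sketch of the non-jax.Array case (assumes at least one local device):

```python
import numpy as np
import jax
from jax.sharding import SingleDeviceSharding

data = np.arange(8.0)
# SingleDeviceSharding is fully replicated; data[idx] below is a NumPy array,
# not a jax.Array, so the batched_device_put fast path handles the transfer.
sharding = SingleDeviceSharding(jax.devices()[0])
out = jax.make_array_from_callback(data.shape, sharding, lambda idx: data[idx])
```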

tests/array_test.py (18 additions, 0 deletions)
@@ -794,6 +794,24 @@ def test_shards_have_correct_dtype(self, dtype):
     for shard in x.addressable_shards:
       self.assertEqual(shard.data.dtype, dtype)
 
+  def test_make_array_from_callback_global_array(self):
+    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
+    sharding = jax.sharding.NamedSharding(mesh, P())
+    np_inp = np.arange(16).reshape(8, 2)
+    arr = jax.device_put(np_inp, sharding)
+
+    out = jax.make_array_from_callback(np_inp.shape, sharding,
+                                       lambda idx: arr[idx])
+    self.assertArraysEqual(out, arr)
+    self.assertEqual(out.sharding, sharding)
+
+    sharding2 = NamedSharding(mesh, P('x', 'y'))
+    arr2 = jax.device_put(np_inp, sharding2)
+    out2 = jax.make_array_from_callback(np_inp.shape, sharding2,
+                                        lambda idx: arr2[idx])
+    self.assertArraysEqual(out2, arr2)
+    self.assertEqual(out2.sharding, sharding2)
+
 
 class ShardingTest(jtu.JaxTestCase):
 

tests/pjit_test.py (5 additions, 0 deletions)
@@ -4062,6 +4062,7 @@ def __init__(self, devices):
     if xla_extension_version >= 235:
       super().__init__()
     self._devices = devices
+    self._internal_device_list = xc.DeviceList(tuple(self._devices))
 
   @property
   def device_set(self):
@@ -4073,6 +4074,10 @@ def devices_indices_map(self, global_shape):
   def shard_shape(self, global_shape):
     return global_shape
 
+  @property
+  def memory_kind(self):
+    return None
+
   @property
   def is_fully_replicated(self):
     return True
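
The tests/pjit_test.py change updates a hand-rolled Sharding subclass so it satisfies the attributes the new path reads: an _internal_device_list built with xc.DeviceList, and a memory_kind property. A minimal sketch of such a class, mirroring the test's custom sharding (the class name and the xla_client import are assumptions, not part of the commit):

```python
import jax
from jax._src.lib import xla_client as xc

class FullyReplicatedSharding(jax.sharding.Sharding):  # hypothetical name
  """A toy sharding that replicates the full value on every device."""

  def __init__(self, devices):
    super().__init__()
    self._devices = tuple(devices)
    # Read by make_array_from_callback's fully replicated branch above.
    self._internal_device_list = xc.DeviceList(self._devices)

  @property
  def device_set(self):
    return set(self._devices)

  def devices_indices_map(self, global_shape):
    # Every device holds the entire array.
    return {d: (slice(None),) * len(global_shape) for d in self._devices}

  def shard_shape(self, global_shape):
    return global_shape

  @property
  def memory_kind(self):
    return None

  @property
  def is_fully_replicated(self):
    return True
```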
