@@ -695,13 +695,8 @@ def make_array_from_callback(
695
695
>>> arr.addressable_data(0).shape
696
696
(4, 2)
697
697
"""
698
- has_device_assignment = False
699
698
if sharding .is_fully_replicated :
700
- if isinstance (sharding , XLACompatibleSharding ):
701
- devices = list (sharding ._addressable_device_assignment )
702
- has_device_assignment = True
703
- else :
704
- devices = list (sharding .addressable_devices )
699
+ devices = list (sharding ._internal_device_list .addressable_device_list ) # type: ignore
705
700
per_device_values = [data_callback ((slice (None ),) * len (shape ))] * len (devices )
706
701
else :
707
702
device_to_index_map = sharding .addressable_devices_indices_map (shape )
@@ -716,13 +711,11 @@ def make_array_from_callback(
716
711
first_value = xla .canonicalize_dtype (per_device_values [0 ])
717
712
aval = core .ShapedArray (shape , first_value .dtype , weak_type = False )
718
713
719
- # TODO(yashkatariya): Look into taking this path for non-fully replicated
720
- # shardings too.
721
- if (sharding .is_fully_replicated and has_device_assignment and
722
- not dtypes .issubdtype (aval .dtype , dtypes .extended )):
714
+ # first value can be numpy array, python scalar, etc.
715
+ if (sharding .is_fully_replicated and not isinstance (first_value , ArrayImpl )
716
+ and not dtypes .issubdtype (aval .dtype , dtypes .extended )):
723
717
# Do this check outside because `batched_device_put` won't do these checks
724
- # like ArrayImpl. This is a fast path for fully replicated arrays with
725
- # xla compatible sharding.
718
+ # like ArrayImpl.
726
719
if shape != first_value .shape :
727
720
raise ValueError (
728
721
f"Expected shard shape { shape } doesn't match the single device "
@@ -731,6 +724,11 @@ def make_array_from_callback(
731
724
return pxla .batched_device_put (
732
725
aval , sharding , per_device_values , devices , committed = True )
733
726
727
+ if (sharding .is_fully_replicated and isinstance (first_value , ArrayImpl ) and
728
+ first_value .is_fully_replicated and
729
+ first_value .sharding ._device_assignment == devices ):
730
+ return first_value
731
+
734
732
arrays = api .device_put (per_device_values , devices )
735
733
if dtypes .issubdtype (aval .dtype , dtypes .extended ):
736
734
return aval .dtype ._rules .make_sharded_array (aval , sharding , arrays ,
0 commit comments