File tree Expand file tree Collapse file tree 1 file changed +3
-7
lines changed Expand file tree Collapse file tree 1 file changed +3
-7
lines changed Original file line number Diff line number Diff line change @@ -886,15 +886,11 @@ def _hashable_index(idx):
886
886
return tree_util .tree_map (
887
887
lambda x : (x .start , x .stop ) if type (x ) == slice else x , idx )
888
888
889
- # The fast path is handled directly in shard_args().
889
+
890
890
def shard_sharded_device_array_slow_path (x , devices , indices , sharding ):
891
891
candidates = defaultdict (list )
892
- if isinstance (x , ArrayImpl ):
893
- bufs = [buf .data for buf in x .addressable_shards ]
894
- arr_indices = tuple (x .sharding .devices_indices_map (x .shape ).values ())
895
- else :
896
- bufs = x .device_buffers
897
- arr_indices = x .indices
892
+ bufs = [buf .data for buf in x .addressable_shards ]
893
+ arr_indices = tuple (x .sharding .devices_indices_map (x .shape ).values ())
898
894
for buf , idx in safe_zip (bufs , arr_indices ):
899
895
candidates [_hashable_index (idx )].append (buf )
900
896
You can’t perform that action at this time.
0 commit comments