Skip to content

Commit 70dca30

Browse files
yashk2810jax authors
authored andcommitted
Remove the dead code now that jax.Array is the only array we have
PiperOrigin-RevId: 624390245
1 parent ee8ce0f commit 70dca30

File tree

1 file changed

+3
-7
lines changed

1 file changed

+3
-7
lines changed

jax/_src/array.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -886,15 +886,11 @@ def _hashable_index(idx):
886886
return tree_util.tree_map(
887887
lambda x: (x.start, x.stop) if type(x) == slice else x, idx)
888888

889-
# The fast path is handled directly in shard_args().
889+
890890
def shard_sharded_device_array_slow_path(x, devices, indices, sharding):
891891
candidates = defaultdict(list)
892-
if isinstance(x, ArrayImpl):
893-
bufs = [buf.data for buf in x.addressable_shards]
894-
arr_indices = tuple(x.sharding.devices_indices_map(x.shape).values())
895-
else:
896-
bufs = x.device_buffers
897-
arr_indices = x.indices
892+
bufs = [buf.data for buf in x.addressable_shards]
893+
arr_indices = tuple(x.sharding.devices_indices_map(x.shape).values())
898894
for buf, idx in safe_zip(bufs, arr_indices):
899895
candidates[_hashable_index(idx)].append(buf)
900896

0 commit comments

Comments
 (0)