Skip to content

Commit d0819ae

Browse files
author
jax authors
committed
remove unnecessary if statement
PiperOrigin-RevId: 617653292
1 parent d6f074b commit d0819ae

File tree

1 file changed

+1
-4
lines changed

1 file changed

+1
-4
lines changed

jax/experimental/array_serialization/serialization.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -260,10 +260,7 @@ async def _write_array(shard):
260260
else:
261261
await write_future.commit
262262

263-
if isinstance(arr_inp, array.ArrayImpl):
264-
local_shards = arr_inp.addressable_shards
265-
else:
266-
local_shards = arr_inp.addressable_shards
263+
local_shards = arr_inp.addressable_shards
267264
future_write_state = jax.tree_util.tree_map(_write_array, local_shards)
268265
return await asyncio.gather(*future_write_state)
269266

0 commit comments

Comments
 (0)