We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent d6f074b commit d0819aeCopy full SHA for d0819ae
jax/experimental/array_serialization/serialization.py
@@ -260,10 +260,7 @@ async def _write_array(shard):
260
else:
261
await write_future.commit
262
263
- if isinstance(arr_inp, array.ArrayImpl):
264
- local_shards = arr_inp.addressable_shards
265
- else:
266
+ local_shards = arr_inp.addressable_shards
267
future_write_state = jax.tree_util.tree_map(_write_array, local_shards)
268
return await asyncio.gather(*future_write_state)
269
0 commit comments