33
33
from jax ._src import distributed
34
34
from jax ._src import sharding
35
35
from jax ._src import sharding_impls
36
+ from jax ._src .layout import Layout , DeviceLocalLayout as DLL
36
37
from jax ._src import typing
37
38
from jax ._src import util
38
39
from jax ._src .lib import xla_extension as xe
@@ -306,14 +307,22 @@ def estimate_read_memory_footprint(t: ts.TensorStore,
306
307
307
308
308
309
async def async_deserialize (
309
- in_sharding : sharding_impls .XLACompatibleSharding ,
310
+ in_sharding : sharding_impls .XLACompatibleSharding | Layout ,
310
311
tensorstore_spec : ts .Spec | dict [str , Any ],
311
312
global_shape : Sequence [int ] | None = None ,
312
313
dtype = None ,
313
314
byte_limiter : _LimitInFlightBytes | None = None ,
314
315
context = TS_CONTEXT ,
315
316
assume_metadata : bool = False ,
316
317
):
318
+ in_sharding = (in_sharding .sharding if isinstance (in_sharding , Layout ) else # type: ignore
319
+ in_sharding )
320
+ if not isinstance (in_sharding , sharding_impls .XLACompatibleSharding ):
321
+ raise ValueError (
322
+ 'sharding passed to deserialization should be specified, concrete and'
323
+ f' an instance of `jax.XLACompatibleSharding`. Got { in_sharding } ' )
324
+ dll = (in_sharding .device_local_layout if isinstance (in_sharding , Layout )
325
+ else None )
317
326
t = await ts .open (
318
327
tensorstore_spec ,
319
328
open = True ,
@@ -340,7 +349,8 @@ async def cb(index: array.Index, device: jax.Device):
340
349
# Cast while reloading on process to avoid 2 copies on device if the
341
350
# casting is done on device.
342
351
out = out .astype (dtype )
343
- result = jax .device_put (out , device )
352
+ result = jax .device_put (
353
+ out , Layout (dll , jax .sharding .SingleDeviceSharding (device )))
344
354
if byte_limiter is not None :
345
355
# NB: `out` actually might not be ready for garbage collection by the
346
356
# time we call release_bytes . Thus peak memory usage still might grow
@@ -358,7 +368,7 @@ async def cb(index: array.Index, device: jax.Device):
358
368
return await create_async_array_from_callback (tuple (shape ), in_sharding , cb )
359
369
360
370
361
- def run_deserialization (shardings : Sequence [sharding .Sharding ],
371
+ def run_deserialization (shardings : Sequence [sharding .Sharding | Layout ],
362
372
tensorstore_specs : Sequence [dict [str , Any ]],
363
373
global_shapes : Sequence [array .Shape ] | None = None ,
364
374
dtypes : Sequence [typing .DTypeLike ] | None = None ,
@@ -596,7 +606,7 @@ def serialize_with_paths(self, arrays: Sequence[jax.Array],
596
606
tspecs = jax .tree .map (get_tensorstore_spec , paths )
597
607
self .serialize (arrays , tspecs , on_commit_callback = on_commit_callback )
598
608
599
- def deserialize (self , shardings : Sequence [sharding .Sharding ],
609
+ def deserialize (self , shardings : Sequence [sharding .Sharding | Layout ],
600
610
tensorstore_specs : Sequence [dict [str , Any ]],
601
611
global_shapes : Sequence [array .Shape ] | None = None ,
602
612
dtypes : Sequence [typing .DTypeLike ] | None = None ,
0 commit comments