
Commit 5ce7dca

yashk2810 authored and jax authors committed
Add support for loading checkpoints with a given layout to the array serialization library
PiperOrigin-RevId: 624596358
1 parent 70dca30 commit 5ce7dca


2 files changed: +56 -4 lines changed


jax/experimental/array_serialization/serialization.py

Lines changed: 14 additions & 4 deletions

@@ -33,6 +33,7 @@
 from jax._src import distributed
 from jax._src import sharding
 from jax._src import sharding_impls
+from jax._src.layout import Layout, DeviceLocalLayout as DLL
 from jax._src import typing
 from jax._src import util
 from jax._src.lib import xla_extension as xe
@@ -306,14 +307,22 @@ def estimate_read_memory_footprint(t: ts.TensorStore,


 async def async_deserialize(
-    in_sharding: sharding_impls.XLACompatibleSharding,
+    in_sharding: sharding_impls.XLACompatibleSharding | Layout,
     tensorstore_spec: ts.Spec | dict[str, Any],
     global_shape: Sequence[int] | None = None,
     dtype=None,
     byte_limiter: _LimitInFlightBytes | None = None,
     context=TS_CONTEXT,
     assume_metadata: bool = False,
 ):
+  in_sharding = (in_sharding.sharding if isinstance(in_sharding, Layout) else  # type: ignore
+                 in_sharding)
+  if not isinstance(in_sharding, sharding_impls.XLACompatibleSharding):
+    raise ValueError(
+        'sharding passed to deserialization should be specified, concrete and'
+        f' an instance of `jax.XLACompatibleSharding`. Got {in_sharding}')
+  dll = (in_sharding.device_local_layout if isinstance(in_sharding, Layout)
+         else None)
   t = await ts.open(
       tensorstore_spec,
       open=True,
@@ -340,7 +349,8 @@ async def cb(index: array.Index, device: jax.Device):
       # Cast while reloading on process to avoid 2 copies on device if the
       # casting is done on device.
       out = out.astype(dtype)
-    result = jax.device_put(out, device)
+    result = jax.device_put(
+        out, Layout(dll, jax.sharding.SingleDeviceSharding(device)))
     if byte_limiter is not None:
       # NB: `out` actually might not be ready for garbage collection by the
       # time we call release_bytes. Thus peak memory usage still might grow
@@ -358,7 +368,7 @@ async def cb(index: array.Index, device: jax.Device):
   return await create_async_array_from_callback(tuple(shape), in_sharding, cb)


-def run_deserialization(shardings: Sequence[sharding.Sharding],
+def run_deserialization(shardings: Sequence[sharding.Sharding | Layout],
                         tensorstore_specs: Sequence[dict[str, Any]],
                         global_shapes: Sequence[array.Shape] | None = None,
                         dtypes: Sequence[typing.DTypeLike] | None = None,
@@ -596,7 +606,7 @@ def serialize_with_paths(self, arrays: Sequence[jax.Array],
     tspecs = jax.tree.map(get_tensorstore_spec, paths)
     self.serialize(arrays, tspecs, on_commit_callback=on_commit_callback)

-  def deserialize(self, shardings: Sequence[sharding.Sharding],
+  def deserialize(self, shardings: Sequence[sharding.Sharding | Layout],
                   tensorstore_specs: Sequence[dict[str, Any]],
                   global_shapes: Sequence[array.Shape] | None = None,
                   dtypes: Sequence[typing.DTypeLike] | None = None,
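For orientation, a condensed usage sketch of the new code path (this is not part of the diff): the compiler-chosen Layout returned by output_layouts() is passed straight to run_deserialization in place of a bare sharding, and each shard is then placed with that device-local layout. The mesh shape, the 8-device TPU assumption, and the checkpoint path below are illustrative only.

# Sketch, assuming a TPU backend with 8 devices and an already written
# checkpoint at the hypothetical path below.
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.layout import Layout, DeviceLocalLayout as DLL
from jax.experimental.array_serialization import serialization

mesh = Mesh(np.asarray(jax.devices()).reshape(4, 2), ('x', 'y'))
arr = jax.device_put(np.arange(32).reshape(8, 4),
                     NamedSharding(mesh, P('x', 'y')))

# Let the compiler pick a layout for a transposed result; output_layouts()
# returns a Layout bundling the device-local layout with the sharding.
out_layout = jax.jit(lambda x: x.T, out_shardings=Layout(DLL.AUTO)).lower(
    arr).compile().output_layouts()

# run_deserialization now accepts Layout entries in place of bare shardings.
tspecs = [serialization.get_tensorstore_spec('/tmp/ckpt/first')]  # hypothetical path
out, = serialization.run_deserialization([out_layout], tspecs)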

jax/experimental/array_serialization/serialization_test.py

Lines changed: 42 additions & 0 deletions

@@ -16,6 +16,7 @@
 import asyncio
 import math
 from functools import partial
+import re
 import os
 import pathlib
 import tracemalloc as tm
@@ -28,6 +29,7 @@
 from jax.sharding import NamedSharding, GSPMDSharding
 from jax.sharding import PartitionSpec as P
 from jax.experimental.array_serialization import serialization
+from jax.experimental.layout import Layout, DeviceLocalLayout as DLL
 import numpy as np
 import tensorstore as ts

@@ -45,6 +47,13 @@ def tearDownModule():
   prev_xla_flags()


+pattern = re.compile(r"\{(.*?):")
+
+def extract_minor_to_major(l):
+  match = re.search(pattern, str(l))
+  return tuple(int(i) for i in match.groups()[0].split(','))
+
+
 class CheckpointTest(jtu.JaxTestCase):

   def _on_commit_callback(self, temp_ckpt_dir, final_ckpt_dir):
@@ -411,5 +420,38 @@ def test_maybe_cloud_storage(self):
     }
     self.assertTrue(serialization.is_remote_storage(nested_tspec))

+  def test_load_with_layout(self):
+    if not jtu.test_device_matches(['tpu']):
+      self.skipTest('Layouts are only supported on TPUs')
+
+    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
+    np_inp = np.arange(32).reshape(8, 4)
+    s = NamedSharding(mesh, P('x', 'y'))
+    arr = jax.device_put(np_inp, s)
+
+    out_layout = jax.jit(lambda x: x.T, out_shardings=Layout(DLL.AUTO)).lower(
+        arr).compile().output_layouts()
+    self.assertEqual(extract_minor_to_major(arr.layout),
+                     extract_minor_to_major(out_layout)[::-1])
+
+    ckpt_dir = pathlib.Path(self.create_tempdir('ckpt').full_path)
+    ckpt_path = pathlib.Path(self.create_tempdir(f'{ckpt_dir}/first').full_path)
+    tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, [ckpt_path])
+
+    manager = serialization.GlobalAsyncCheckpointManager()
+    manager.serialize(
+        [arr], tspecs,
+        on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir))
+    manager.wait_until_finished()
+
+    out, = serialization.run_deserialization([out_layout], tspecs)
+
+    self.assertEqual(out.layout, out_layout)
+    self.assertIsInstance(out, array.ArrayImpl)
+    self.assertArraysEqual(out, np_inp)
+    for s in out.addressable_shards:
+      self.assertArraysEqual(s.data, np_inp[s.index])
+
+
 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())
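The diff also widens the manager-level deserialize signature to accept Layout entries. Below is a condensed round trip mirroring what the new test exercises, but through GlobalAsyncCheckpointManager; it is a sketch, not part of the commit. The checkpoint directory, the no-op commit callback, and the 8-device TPU mesh are placeholder assumptions.

# Sketch: serialize, then read back with a compiler-chosen (transposed) layout.
import pathlib
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental.layout import Layout, DeviceLocalLayout as DLL
from jax.experimental.array_serialization import serialization

mesh = Mesh(np.asarray(jax.devices()).reshape(4, 2), ('x', 'y'))
arr = jax.device_put(np.arange(32).reshape(8, 4),
                     NamedSharding(mesh, P('x', 'y')))

ckpt_path = pathlib.Path('/tmp/ckpt/first')  # placeholder location
tspecs = [serialization.get_tensorstore_spec(str(ckpt_path))]

manager = serialization.GlobalAsyncCheckpointManager()
manager.serialize([arr], tspecs, on_commit_callback=lambda: None)
manager.wait_until_finished()

# Request a non-default (here: transposed) layout when reading back; the
# manager forwards the Layout entries to run_deserialization.
out_layout = jax.jit(lambda x: x.T, out_shardings=Layout(DLL.AUTO)).lower(
    arr).compile().output_layouts()
out, = manager.deserialize([out_layout], tspecs)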
