Skip to content

Commit b6e985f

Browse files
jianlijianlijax authors
authored andcommitted
Add int4 test to ArrayImpl.
PiperOrigin-RevId: 614778550
1 parent 778933d commit b6e985f

File tree

1 file changed

+47
-0
lines changed

1 file changed

+47
-0
lines changed

jax/experimental/array_serialization/serialization_test.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,53 @@ def cb1(index):
243243
for l in m2.addressable_shards:
244244
self.assertArraysEqual(l.data, global_input_data1.astype('float32'))
245245

246+
def test_checkpointing_with_int4(self):
247+
global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
248+
global_input_shape = (8, 2)
249+
num = math.prod(global_input_shape)
250+
251+
global_input_data = np.arange(num, dtype=jax.numpy.int8).reshape(global_input_shape)
252+
def cb(index):
253+
return global_input_data[index]
254+
arr = array.make_array_from_callback(
255+
global_input_shape, NamedSharding(global_mesh, P('x', 'y')), cb)
256+
ckpt_dir = pathlib.Path(self.create_tempdir('first').full_path)
257+
258+
ckpt_paths = [str(ckpt_dir)]
259+
tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths)
260+
261+
manager = serialization.GlobalAsyncCheckpointManager()
262+
manager.serialize(
263+
[arr], tspecs,
264+
on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir))
265+
manager.wait_until_finished()
266+
267+
ds = NamedSharding(jtu.create_global_mesh((4, 2), ('x', 'y')), P('x', 'y'))
268+
269+
target_dtype = jax.numpy.dtype('int4')
270+
m1, = serialization.run_deserialization([ds], tspecs, [(12, 2)],
271+
[target_dtype])
272+
273+
# values bigger than 7 are converted properly.
274+
expected_data = {
275+
0: jax.numpy.array([[0], [2], [4]], dtype=target_dtype),
276+
1: jax.numpy.array([[1], [3], [5]], dtype=target_dtype),
277+
2: jax.numpy.array([[6], [8], [10]], dtype=target_dtype),
278+
3: jax.numpy.array([[7], [9], [11]], dtype=target_dtype),
279+
4: jax.numpy.array([[12], [14], [0]], dtype=target_dtype),
280+
5: jax.numpy.array([[13], [15], [0]], dtype=target_dtype),
281+
6: jax.numpy.array([[0], [0], [0]], dtype=target_dtype),
282+
7: jax.numpy.array([[0], [0], [0]], dtype=target_dtype),
283+
}
284+
285+
for l in m1.addressable_shards:
286+
self.assertArraysEqual(np.asarray(l.data), expected_data[l.device.id])
287+
288+
new_ds = GSPMDSharding.get_replicated(list(global_mesh.devices.flat))
289+
m2, = serialization.run_deserialization([new_ds], tspecs, [(8, 2)], [target_dtype])
290+
for l in m2.addressable_shards:
291+
self.assertArraysEqual(l.data, global_input_data.astype(target_dtype))
292+
246293
def test_checkpointing_scalar_jax_array(self):
247294
global_mesh = jtu.create_global_mesh((2,), ('x'))
248295
global_input_shape = ()

0 commit comments

Comments
 (0)