@@ -243,6 +243,53 @@ def cb1(index):
243
243
for l in m2 .addressable_shards :
244
244
self .assertArraysEqual (l .data , global_input_data1 .astype ('float32' ))
245
245
246
+ def test_checkpointing_with_int4 (self ):
247
+ global_mesh = jtu .create_global_mesh ((2 , 2 ), ('x' , 'y' ))
248
+ global_input_shape = (8 , 2 )
249
+ num = math .prod (global_input_shape )
250
+
251
+ global_input_data = np .arange (num , dtype = jax .numpy .int8 ).reshape (global_input_shape )
252
+ def cb (index ):
253
+ return global_input_data [index ]
254
+ arr = array .make_array_from_callback (
255
+ global_input_shape , NamedSharding (global_mesh , P ('x' , 'y' )), cb )
256
+ ckpt_dir = pathlib .Path (self .create_tempdir ('first' ).full_path )
257
+
258
+ ckpt_paths = [str (ckpt_dir )]
259
+ tspecs = jax .tree_util .tree_map (serialization .get_tensorstore_spec , ckpt_paths )
260
+
261
+ manager = serialization .GlobalAsyncCheckpointManager ()
262
+ manager .serialize (
263
+ [arr ], tspecs ,
264
+ on_commit_callback = partial (self ._on_commit_callback , ckpt_dir , ckpt_dir ))
265
+ manager .wait_until_finished ()
266
+
267
+ ds = NamedSharding (jtu .create_global_mesh ((4 , 2 ), ('x' , 'y' )), P ('x' , 'y' ))
268
+
269
+ target_dtype = jax .numpy .dtype ('int4' )
270
+ m1 , = serialization .run_deserialization ([ds ], tspecs , [(12 , 2 )],
271
+ [target_dtype ])
272
+
273
+ # values bigger than 7 are converted properly.
274
+ expected_data = {
275
+ 0 : jax .numpy .array ([[0 ], [2 ], [4 ]], dtype = target_dtype ),
276
+ 1 : jax .numpy .array ([[1 ], [3 ], [5 ]], dtype = target_dtype ),
277
+ 2 : jax .numpy .array ([[6 ], [8 ], [10 ]], dtype = target_dtype ),
278
+ 3 : jax .numpy .array ([[7 ], [9 ], [11 ]], dtype = target_dtype ),
279
+ 4 : jax .numpy .array ([[12 ], [14 ], [0 ]], dtype = target_dtype ),
280
+ 5 : jax .numpy .array ([[13 ], [15 ], [0 ]], dtype = target_dtype ),
281
+ 6 : jax .numpy .array ([[0 ], [0 ], [0 ]], dtype = target_dtype ),
282
+ 7 : jax .numpy .array ([[0 ], [0 ], [0 ]], dtype = target_dtype ),
283
+ }
284
+
285
+ for l in m1 .addressable_shards :
286
+ self .assertArraysEqual (np .asarray (l .data ), expected_data [l .device .id ])
287
+
288
+ new_ds = GSPMDSharding .get_replicated (list (global_mesh .devices .flat ))
289
+ m2 , = serialization .run_deserialization ([new_ds ], tspecs , [(8 , 2 )], [target_dtype ])
290
+ for l in m2 .addressable_shards :
291
+ self .assertArraysEqual (l .data , global_input_data .astype (target_dtype ))
292
+
246
293
def test_checkpointing_scalar_jax_array (self ):
247
294
global_mesh = jtu .create_global_mesh ((2 ,), ('x' ))
248
295
global_input_shape = ()
0 commit comments