1
+ # Copyright 2024 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import unittest
17
+ import numpy as np
18
+ import jax
19
+ import jax .numpy as jnp
20
+ from flax import nnx
21
+ import orbax .checkpoint as ocp
22
+
23
+
24
+ class SimpleModel (nnx .Module ):
25
+
26
+ def __init__ (self , rngs ):
27
+ self .layer1 = nnx .Linear (2 , 5 , rngs = rngs )
28
+ self .layer2 = nnx .Linear (5 , 3 , rngs = rngs )
29
+
30
+ def __call__ (self , x ):
31
+ for layer in [self .layer1 , self .layer2 ]:
32
+ x = layer (x )
33
+ return x
34
+
35
+
36
+ class NNXOrbaxTest (unittest .TestCase ):
37
+
38
+ def setUp (self ):
39
+ options = ocp .CheckpointManagerOptions (
40
+ create = True , max_to_keep = 3 , keep_period = 2 , step_prefix = 'test' )
41
+ self .manager = ocp .CheckpointManager (
42
+ ocp .test_utils .erase_and_create_empty ('/tmp/test-checkpoint/' ),
43
+ ocp .Checkpointer (ocp .PyTreeCheckpointHandler ()),
44
+ options = options
45
+ )
46
+
47
+ def test_nnx_orbax_checkpoint (self ):
48
+ key = jax .random .key (1701 )
49
+ x = jax .random .normal (key , (1 , 2 ))
50
+ y = jnp .ones ((1 , 3 ))
51
+
52
+ model = SimpleModel (nnx .Rngs (0 ))
53
+
54
+ # Create the checkpoint
55
+ graphdef , state = nnx .split (model )
56
+ abstract_state = jax .tree .map (ocp .utils .to_shape_dtype_struct , state )
57
+
58
+ restore_args = ocp .checkpoint_utils .construct_restore_args (abstract_state )
59
+ self .manager .save (0 , state )
60
+ self .manager .wait_until_finished ()
61
+
62
+ restored_state = self .manager .restore (
63
+ 0 , items = abstract_state , restore_kwargs = {'restore_args' : restore_args })
64
+ restored_model = nnx .merge (graphdef , restored_state )
65
+
66
+ self .assertEqual (type (model ), type (restored_model ))
67
+ jax .tree .map (np .testing .assert_array_equal , state , restored_state )
68
+
69
+ if __name__ == '__main__' :
70
+ unittest .main ()
0 commit comments