@@ -45,28 +45,31 @@ def get_keras_data():
4545 return (x_train , y_train ), (x_test , y_test )
4646
4747
48- def test_keras_v2 (script_mode : bool = False , eager_mode : bool = True ):
49- """ Works as intended. """
48+ def helper_test_keras_v2 (script_mode : bool = False , eager_mode : bool = True ):
49+ """ Test the default ZCC behavior of saving losses and metrics in eager and non-eager modes. """
5050 smd .del_hook ()
51-
51+ tf . keras . backend . clear_session ()
5252 if not eager_mode :
5353 tf .compat .v1 .disable_eager_execution ()
5454 with SagemakerSimulator () as sim :
5555 model = get_keras_model_v2 ()
5656 (x_train , y_train ), (x_test , y_test ) = get_keras_data ()
5757
58- model .compile (
59- loss = "sparse_categorical_crossentropy" ,
60- optimizer = tf .keras .optimizers .RMSprop (),
61- metrics = ["accuracy" ],
62- )
58+ opt = tf .keras .optimizers .RMSprop ()
6359 if script_mode :
6460 hook = smd .KerasHook (out_dir = sim .out_dir , export_tensorboard = True )
61+ opt = hook .wrap_optimizer (opt )
62+ model .compile (
63+ loss = "sparse_categorical_crossentropy" , optimizer = opt , metrics = ["accuracy" ]
64+ )
6565 history = model .fit (
6666 x_train , y_train , batch_size = 64 , epochs = 5 , validation_split = 0.2 , callbacks = [hook ]
6767 )
6868 test_scores = model .evaluate (x_test , y_test , verbose = 2 , callbacks = [hook ])
6969 else :
70+ model .compile (
71+ loss = "sparse_categorical_crossentropy" , optimizer = opt , metrics = ["accuracy" ]
72+ )
7073 history = model .fit (x_train , y_train , batch_size = 64 , epochs = 5 , validation_split = 0.2 )
7174 test_scores = model .evaluate (x_test , y_test , verbose = 2 )
7275
@@ -77,6 +80,106 @@ def test_keras_v2(script_mode: bool = False, eager_mode: bool = True):
7780 trial = smd .create_trial (path = sim .out_dir )
7881 assert len (trial .steps ()) > 0 , "Nothing saved at any step."
7982 assert len (trial .tensor_names ()) > 0 , "Tensors were not saved."
83+ assert len (trial .tensor_names (collection = "losses" )) > 0
84+
85+
def helper_test_keras_v2_json_config(
    json_file_contents, script_mode: bool = False, eager_mode: bool = True
):
    """Train and evaluate the Keras v2 model under a custom hook JSON config.

    The global smdebug hook and the Keras session are reset first, then a
    ``SagemakerSimulator`` is entered with ``json_file_contents``. After
    ``fit``/``evaluate`` the active hook is fetched with ``smd.get_hook()``,
    closed, and the trial at ``sim.out_dir`` is checked for saved steps and for
    tensors in the ``weights`` and ``losses`` collections (plus ``gradients``
    when eager execution is disabled).

    Args:
        json_file_contents: Hook configuration JSON text handed to
            ``SagemakerSimulator``.
        script_mode: When true, the hook is built explicitly with
            ``smd.KerasHook.create_from_json_file()``, wraps the optimizer and
            is passed as a callback. When false, no hook is created here and
            ``smd.get_hook()`` is still expected to return one afterwards.
        eager_mode: When false, ``tf.compat.v1.disable_eager_execution()`` is
            called before the model is built.

    Raises:
        AssertionError: If no hook is registered, nothing was saved, or one of
            the expected collections is empty.
    """
    smd.del_hook()
    tf.keras.backend.clear_session()
    if not eager_mode:
        tf.compat.v1.disable_eager_execution()

    with SagemakerSimulator(json_file_contents=json_file_contents) as sim:
        model = get_keras_model_v2()
        (x_train, y_train), (x_test, y_test) = get_keras_data()

        opt = tf.keras.optimizers.RMSprop()
        if script_mode:
            hook = smd.KerasHook.create_from_json_file()
            # Wrap before compile() so the model is compiled with the wrapped optimizer.
            opt = hook.wrap_optimizer(opt)

        # Compilation is the same in both modes; only the optimizer object differs.
        model.compile(
            loss="sparse_categorical_crossentropy", optimizer=opt, metrics=["accuracy"]
        )

        if script_mode:
            model.fit(
                x_train, y_train, batch_size=64, epochs=5, validation_split=0.2, callbacks=[hook]
            )
            model.evaluate(x_test, y_test, verbose=2, callbacks=[hook])
        else:
            # No callbacks are passed here; the hook is looked up via smd.get_hook() below.
            model.fit(x_train, y_train, epochs=5, batch_size=64, validation_split=0.2)
            model.evaluate(x_test, y_test, verbose=2)

        hook = smd.get_hook()
        assert hook, "No smdebug hook was registered."
        hook.close()

        # Check that hook created and tensors saved
        trial = smd.create_trial(path=sim.out_dir)
        assert len(trial.steps()) > 0, "Nothing saved at any step."
        assert len(trial.tensor_names()) > 0, "Tensors were not saved."
        if not eager_mode:
            # Gradients are only asserted when eager execution is disabled.
            assert (
                len(trial.tensor_names(collection="gradients")) > 0
            ), "No tensors saved in the gradients collection."
        assert (
            len(trial.tensor_names(collection="weights")) > 0
        ), "No tensors saved in the weights collection."
        assert (
            len(trial.tensor_names(collection="losses")) > 0
        ), "No tensors saved in the losses collection."
127+
128+
def test_keras_v2_default(script_mode: bool = False, eager_mode: bool = True):
    """Exercise the default ZCC behavior by delegating to ``helper_test_keras_v2``."""
    helper_test_keras_v2(script_mode, eager_mode)
132+
133+
def test_keras_v2_multi_collections(script_mode: bool = False, eager_mode: bool = True):
    """Run the JSON-config helper with several collections listed in the hook JSON.

    The config requests the gradients, weights, losses, biases and
    optimizer_variables collections with a save interval of 2.
    """
    hook_config = """
    {
        "S3OutputPath": "s3://sagemaker-test",
        "LocalPath": "/opt/ml/output/tensors",
        "HookParameters" : {
            "save_interval": "2",
            "include_workers": "all"
        },
        "CollectionConfigurations": [
            {
                "CollectionName": "gradients"
            },
            {
                "CollectionName": "weights"
            },
            {
                "CollectionName": "losses"
            },
            {
                "CollectionName": "biases"
            },
            {
                "CollectionName": "optimizer_variables"
            }
        ]
    }
    """
    helper_test_keras_v2_json_config(
        hook_config, script_mode=script_mode, eager_mode=eager_mode
    )
166+
167+
def test_keras_v2_save_all(script_mode: bool = False, eager_mode: bool = True):
    """Run the JSON-config helper with ``save_all`` enabled for steps 0 through 3."""
    hook_config = """
    {
        "S3OutputPath": "s3://sagemaker-test",
        "LocalPath": "/opt/ml/output/tensors",
        "HookParameters" : {
            "save_steps": "0,1,2,3",
            "save_all": true
        }
    }
    """
    helper_test_keras_v2_json_config(
        hook_config, script_mode=script_mode, eager_mode=eager_mode
    )
80183
81184
82185if __name__ == "__main__" :
@@ -88,6 +191,11 @@ def test_keras_v2(script_mode: bool = False, eager_mode: bool = True):
88191 script_mode = args .script_mode
89192
90193 # eager mode
91- test_keras_v2 (script_mode = script_mode )
194+ test_keras_v2_default (script_mode )
195+ test_keras_v2_multi_collections (script_mode )
196+ test_keras_v2_save_all (script_mode )
197+
92198 # non-eager mode
93- test_keras_v2 (script_mode = script_mode , eager_mode = False )
199+ test_keras_v2_default (script_mode , eager_mode = False )
200+ test_keras_v2_multi_collections (script_mode , eager_mode = False )
201+ test_keras_v2_save_all (script_mode , eager_mode = False )
0 commit comments