33
44# Standard Library
55import os
6+ import time
67
78# Third Party
89import pytest
910import tensorflow .compat .v2 as tf
1011import tensorflow_datasets as tfds
1112from tensorflow .python .client import device_lib
13+ from tests .core .utils import verify_files
1214from tests .tensorflow2 .utils import is_tf_2_2
1315from tests .tensorflow .utils import create_trial_fast_refresh
1416
@@ -119,6 +121,10 @@ def scale(image, label):
119121 )
120122
121123 hooks .append (hook )
124+ scalars_to_be_saved = dict ()
125+ ts = time .time ()
126+ scalars_to_be_saved ["scalar/foobar" ] = (ts , steps )
127+ hook .save_scalar ("foobar" , 1 , sm_metric = True , timestamp = ts )
122128
123129 if steps is None :
124130 steps = ["train" ]
@@ -131,7 +137,7 @@ def scale(image, label):
131137 model .predict (train_dataset , steps = 4 , callbacks = hooks , verbose = 0 )
132138
133139 smd .get_hook ().close ()
134- return strategy
140+ return strategy , scalars_to_be_saved
135141
136142
137143def exhaustive_check (trial_dir , include_workers = "one" , eager = True ):
@@ -144,7 +150,7 @@ def exhaustive_check(trial_dir, include_workers="one", eager=True):
144150 CollectionKeys .METRICS ,
145151 CollectionKeys .OPTIMIZER_VARIABLES ,
146152 ]
147- strategy = train_model (
153+ strategy , _ = train_model (
148154 trial_dir ,
149155 include_collections = include_collections ,
150156 steps = ["train" , "eval" , "predict" , "train" ],
@@ -158,9 +164,11 @@ def exhaustive_check(trial_dir, include_workers="one", eager=True):
158164 if include_workers == "all" :
159165 assert len (tr .workers ()) == strategy .num_replicas_in_sync
160166 if eager :
161- assert len (tr .tensor_names ()) == (6 + 1 + 2 + 5 if is_tf_2_2 () else 6 + 1 + 3 + 5 )
162- # 6 weights, 1 loss, 3 metrics, 5 optimizer variables for Tf 2.1
163- # 6 weights, 1 loss, 2 metrics, 5 optimizer variables for Tf 2.2
167+ assert len (tr .tensor_names ()) == (
168+ 6 + 1 + 2 + 5 + 1 if is_tf_2_2 () else 6 + 1 + 3 + 5 + 1
169+ )
170+ # 6 weights, 1 loss, 3 metrics, 5 optimizer variables for Tf 2.1, 1 scalar
171+ # 6 weights, 1 loss, 2 metrics, 5 optimizer variables for Tf 2.2, 1 scalar
164172 else :
165173 assert len (tr .tensor_names ()) == (6 + 6 + 1 + 3 + strategy .num_replicas_in_sync * 3 + 5 )
166174 else :
@@ -235,20 +243,21 @@ def test_tf_keras(out_dir, tf_eager_mode, include_workers="all"):
235243@pytest .mark .slow
236244@pytest .mark .parametrize ("workers" , ["one" , "all" ])
237245def test_save_all (out_dir , tf_eager_mode , workers ):
238- strategy = train_model (
246+ save_config = SaveConfig (save_steps = [5 ])
247+ strategy , saved_scalars = train_model (
239248 out_dir ,
240249 include_collections = None ,
241250 save_all = True ,
242- save_config = SaveConfig ( save_steps = [ 5 ]) ,
251+ save_config = save_config ,
243252 steps = ["train" ],
244253 eager = tf_eager_mode ,
245254 include_workers = workers ,
246255 )
247256 tr = create_trial_fast_refresh (out_dir )
248257 print (tr .tensor_names ())
249258 if tf_eager_mode :
250- assert len (tr .tensor_names ()) == (6 + 2 + 1 + 5 if is_tf_2_2 () else 6 + 3 + 1 + 5 )
251- # weights, metrics, losses, optimizer variables
259+ assert len (tr .tensor_names ()) == (6 + 2 + 1 + 5 + 1 if is_tf_2_2 () else 6 + 3 + 1 + 5 + 1 )
260+ # weights, metrics, losses, optimizer variables, scalar
252261 else :
253262 assert (
254263 len (tr .tensor_names ())
@@ -266,6 +275,7 @@ def test_save_all(out_dir, tf_eager_mode, workers):
266275 assert len (tr .tensor (tname ).workers (0 )) == (
267276 1 if workers == "one" else strategy .num_replicas_in_sync
268277 )
278+ verify_files (out_dir , save_config , saved_scalars )
269279
270280
271281@pytest .mark .slow
@@ -350,7 +360,7 @@ def test_include_regex(out_dir, tf_eager_mode, workers):
350360 include_workers = workers ,
351361 )
352362 hook .get_collection ("custom_coll" ).include ("dense" )
353- strategy = train_model (out_dir , hook = hook , steps = ["train" ], eager = tf_eager_mode )
363+ strategy , _ = train_model (out_dir , hook = hook , steps = ["train" ], eager = tf_eager_mode )
354364
355365 tr = create_trial_fast_refresh (out_dir )
356366 tnames = tr .tensor_names (collection = "custom_coll" )
@@ -378,7 +388,7 @@ def test_include_regex_opt_var(out_dir, tf_eager_mode, workers):
378388 include_workers = workers ,
379389 )
380390 hook .get_collection ("custom_optimizer_variables" ).include ("Adam" )
381- strategy = train_model (out_dir , hook = hook , steps = ["train" ], eager = tf_eager_mode )
391+ strategy , _ = train_model (out_dir , hook = hook , steps = ["train" ], eager = tf_eager_mode )
382392
383393 tr = create_trial_fast_refresh (out_dir )
384394 tnames = tr .tensor_names (collection = "custom_optimizer_variables" )
@@ -411,11 +421,11 @@ def test_clash_with_tb_callback(out_dir):
411421 add_callbacks = ["tensorboard" ],
412422 )
413423 tr = create_trial_fast_refresh (out_dir )
414- assert len (tr .tensor_names ()) == (9 if is_tf_2_2 () else 10 )
424+ assert len (tr .tensor_names ()) == (10 if is_tf_2_2 () else 11 )
415425
416426
417427def test_one_device (out_dir , tf_eager_mode ):
418- strategy = train_model (
428+ strategy , _ = train_model (
419429 out_dir ,
420430 include_collections = [
421431 CollectionKeys .WEIGHTS ,
0 commit comments