Skip to content

Commit b1b3bac

Browse files
authored
Use executing_eagerly_outside_functions for global execution context (#178)
Fixing a problem with TF 2.1 training where gradients are not emitted.
1 parent 01c96b2 commit b1b3bac

File tree

5 files changed

+184
-14
lines changed

5 files changed

+184
-14
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def build_package(version):
4747
name="smdebug",
4848
version=version,
4949
long_description="\n".join(DOCLINES[1:]),
50-
long_description_content_type="text/x-rst",
50+
long_description_content_type="text/markdown",
5151
author="AWS DeepLearning Team",
5252
description=DOCLINES[0],
5353
url="https://github.com/awslabs/sagemaker-debugger",

smdebug/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.7.0"
1+
__version__ = "0.7.1"

smdebug/tensorflow/base_hook.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
# Third Party
77
import tensorflow.compat.v1 as tf
88
from tensorflow.python.distribute.distribute_lib import _DefaultDistributionStrategy
9+
from tensorflow.python.framework import ops
910

1011
# First Party
1112
from smdebug.core.collection import DEFAULT_TF_COLLECTIONS
@@ -420,7 +421,11 @@ def set_gradients(self, gradients=None, gradients_and_variables=None):
420421
# TF 2.x doesn't provide gradient/optimizer variable names and values by default.
421422
# Skipping set_gradients and set_optimizer_variables for TF 2.x until there is
422423
# support to pass names and values from TF side.
423-
if is_tf_version_2x() and tf.executing_eagerly():
424+
425+
# From TF 2.2, executing_eagerly_outside_functions() can be used as
426+
# ops.executing_eagerly_outside_functions() or tf.compat.v1.executing_eagerly_outside_functions().
427+
# In TF 2.1, however, only ops.executing_eagerly_outside_functions() is available, so that form is used here.
428+
if is_tf_version_2x() and ops.executing_eagerly_outside_functions():
424429
return
425430
if self._gradients_set is False:
426431
if gradients is not None:
@@ -441,7 +446,11 @@ def set_optimizer_variables(self, optimizer_variables):
441446
# TF 2.x doesn't provide gradient/optimizer variable names and values by default.
442447
# Skipping set_gradients and set_optimizer_variables for TF 2.x until there is
443448
# support to pass names and values from TF side.
444-
if is_tf_version_2x() and tf.executing_eagerly():
449+
450+
# From TF 2.2, executing_eagerly_outside_functions() can be used as
451+
# ops.executing_eagerly_outside_functions() or tf.compat.v1.executing_eagerly_outside_functions().
452+
# In TF 2.1, however, only ops.executing_eagerly_outside_functions() is available, so that form is used here.
453+
if is_tf_version_2x() and ops.executing_eagerly_outside_functions():
445454
return
446455
# since this is done for each variable at a time for keras, not checking if set already
447456
self.collection_manager.get(CollectionKeys.OPTIMIZER_VARIABLES).add_for_mode(

tests/tensorflow2/test_keras.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import smdebug.tensorflow as smd
1717
from smdebug.core.access_layer import has_training_ended
1818
from smdebug.core.collection import CollectionKeys
19+
from smdebug.core.json_config import CONFIG_FILE_PATH_ENV_STR
1920
from smdebug.core.reduction_config import ALLOWED_NORMS, ALLOWED_REDUCTIONS
2021
from smdebug.exceptions import TensorUnavailableForStep
2122
from smdebug.tensorflow import ReductionConfig, SaveConfig
@@ -238,3 +239,55 @@ def test_weights_collections(out_dir, tf_eager_mode):
238239
assert len(trial.tensor_names(collection=CollectionKeys.WEIGHTS)) == 2
239240
assert len(trial.tensor_names(collection=CollectionKeys.LOSSES)) == 1
240241
assert len(trial.tensor_names(collection=CollectionKeys.METRICS)) == 3
242+
243+
244+
@pytest.mark.slow
245+
def test_include_collections(out_dir, tf_eager_mode):
246+
include_collections = [
247+
CollectionKeys.WEIGHTS,
248+
CollectionKeys.BIASES,
249+
CollectionKeys.GRADIENTS,
250+
CollectionKeys.LOSSES,
251+
CollectionKeys.OUTPUTS,
252+
CollectionKeys.METRICS,
253+
CollectionKeys.OPTIMIZER_VARIABLES,
254+
]
255+
save_config = SaveConfig(save_interval=3)
256+
hook = smd.KerasHook(
257+
out_dir,
258+
save_config=save_config,
259+
include_collections=include_collections,
260+
reduction_config=ReductionConfig(norms=ALLOWED_NORMS, reductions=ALLOWED_REDUCTIONS),
261+
)
262+
helper_keras_fit(out_dir, hook=hook, steps=["train", "eval", "predict"], eager=tf_eager_mode)
263+
264+
trial = smd.create_trial(path=out_dir)
265+
# gradients and optimizer variables can't be saved in TF 2.x eager mode
266+
if tf_eager_mode:
267+
assert len(trial.tensor_names()) == 8
268+
else:
269+
assert len(trial.tensor_names()) == 18
270+
assert len(trial.tensor_names(collection=CollectionKeys.GRADIENTS)) == 4
271+
assert len(trial.tensor_names(collection=CollectionKeys.OPTIMIZER_VARIABLES)) == 5
272+
assert len(trial.tensor_names(collection=CollectionKeys.BIASES)) == 2
273+
assert len(trial.tensor_names(collection=CollectionKeys.WEIGHTS)) == 2
274+
assert len(trial.tensor_names(collection=CollectionKeys.LOSSES)) == 1
275+
assert len(trial.tensor_names(collection=CollectionKeys.METRICS)) == 3
276+
277+
278+
@pytest.mark.slow
279+
def test_hook_from_json(out_dir, tf_eager_mode, monkeypatch):
280+
monkeypatch.setenv(
281+
CONFIG_FILE_PATH_ENV_STR,
282+
"tests/tensorflow/hooks/test_json_configs/test_collection_defaults.json",
283+
)
284+
hook = smd.KerasHook.create_from_json_file()
285+
helper_keras_fit(out_dir, hook=hook, steps=["train"], eager=tf_eager_mode)
286+
287+
trial = smd.create_trial(path=out_dir)
288+
# can't save gradients in TF 2.x
289+
assert len(trial.tensor_names()) == 6
290+
assert len(trial.tensor_names(collection=CollectionKeys.BIASES)) == 0
291+
assert len(trial.tensor_names(collection=CollectionKeys.WEIGHTS)) == 2
292+
assert len(trial.tensor_names(collection=CollectionKeys.LOSSES)) == 1
293+
assert len(trial.tensor_names(collection=CollectionKeys.METRICS)) == 3

tests/zero_code_change/tensorflow2_integration_tests.py

Lines changed: 118 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -45,28 +45,31 @@ def get_keras_data():
4545
return (x_train, y_train), (x_test, y_test)
4646

4747

48-
def test_keras_v2(script_mode: bool = False, eager_mode: bool = True):
49-
""" Works as intended. """
48+
def helper_test_keras_v2(script_mode: bool = False, eager_mode: bool = True):
49+
""" Test the default ZCC behavior of saving losses and metrics in eager and non-eager modes."""
5050
smd.del_hook()
51-
51+
tf.keras.backend.clear_session()
5252
if not eager_mode:
5353
tf.compat.v1.disable_eager_execution()
5454
with SagemakerSimulator() as sim:
5555
model = get_keras_model_v2()
5656
(x_train, y_train), (x_test, y_test) = get_keras_data()
5757

58-
model.compile(
59-
loss="sparse_categorical_crossentropy",
60-
optimizer=tf.keras.optimizers.RMSprop(),
61-
metrics=["accuracy"],
62-
)
58+
opt = tf.keras.optimizers.RMSprop()
6359
if script_mode:
6460
hook = smd.KerasHook(out_dir=sim.out_dir, export_tensorboard=True)
61+
opt = hook.wrap_optimizer(opt)
62+
model.compile(
63+
loss="sparse_categorical_crossentropy", optimizer=opt, metrics=["accuracy"]
64+
)
6565
history = model.fit(
6666
x_train, y_train, batch_size=64, epochs=5, validation_split=0.2, callbacks=[hook]
6767
)
6868
test_scores = model.evaluate(x_test, y_test, verbose=2, callbacks=[hook])
6969
else:
70+
model.compile(
71+
loss="sparse_categorical_crossentropy", optimizer=opt, metrics=["accuracy"]
72+
)
7073
history = model.fit(x_train, y_train, batch_size=64, epochs=5, validation_split=0.2)
7174
test_scores = model.evaluate(x_test, y_test, verbose=2)
7275

@@ -77,6 +80,106 @@ def test_keras_v2(script_mode: bool = False, eager_mode: bool = True):
7780
trial = smd.create_trial(path=sim.out_dir)
7881
assert len(trial.steps()) > 0, "Nothing saved at any step."
7982
assert len(trial.tensor_names()) > 0, "Tensors were not saved."
83+
assert len(trial.tensor_names(collection="losses")) > 0
84+
85+
86+
def helper_test_keras_v2_json_config(
87+
json_file_contents, script_mode: bool = False, eager_mode: bool = True
88+
):
89+
""" Tests ZCC with custom hook configs """
90+
smd.del_hook()
91+
tf.keras.backend.clear_session()
92+
if not eager_mode:
93+
tf.compat.v1.disable_eager_execution()
94+
with SagemakerSimulator(json_file_contents=json_file_contents) as sim:
95+
model = get_keras_model_v2()
96+
(x_train, y_train), (x_test, y_test) = get_keras_data()
97+
98+
opt = tf.keras.optimizers.RMSprop()
99+
if script_mode:
100+
hook = smd.KerasHook.create_from_json_file()
101+
opt = hook.wrap_optimizer(opt)
102+
model.compile(
103+
loss="sparse_categorical_crossentropy", optimizer=opt, metrics=["accuracy"]
104+
)
105+
history = model.fit(
106+
x_train, y_train, batch_size=64, epochs=5, validation_split=0.2, callbacks=[hook]
107+
)
108+
test_scores = model.evaluate(x_test, y_test, verbose=2, callbacks=[hook])
109+
else:
110+
model.compile(
111+
loss="sparse_categorical_crossentropy", optimizer=opt, metrics=["accuracy"]
112+
)
113+
history = model.fit(x_train, y_train, epochs=5, batch_size=64, validation_split=0.2)
114+
test_scores = model.evaluate(x_test, y_test, verbose=2)
115+
116+
hook = smd.get_hook()
117+
assert hook
118+
hook.close()
119+
# Check that the hook was created and tensors were saved
120+
trial = smd.create_trial(path=sim.out_dir)
121+
assert len(trial.steps()) > 0, "Nothing saved at any step."
122+
assert len(trial.tensor_names()) > 0, "Tensors were not saved."
123+
if not eager_mode:
124+
assert len(trial.tensor_names(collection="gradients")) > 0
125+
assert len(trial.tensor_names(collection="weights")) > 0
126+
assert len(trial.tensor_names(collection="losses")) > 0
127+
128+
129+
def test_keras_v2_default(script_mode: bool = False, eager_mode: bool = True):
130+
# Test default ZCC behavior
131+
helper_test_keras_v2(script_mode=script_mode, eager_mode=eager_mode)
132+
133+
134+
def test_keras_v2_multi_collections(script_mode: bool = False, eager_mode: bool = True):
135+
# Test multiple collections included in hook json
136+
json_file_contents = """
137+
{
138+
"S3OutputPath": "s3://sagemaker-test",
139+
"LocalPath": "/opt/ml/output/tensors",
140+
"HookParameters" : {
141+
"save_interval": "2",
142+
"include_workers": "all"
143+
},
144+
"CollectionConfigurations": [
145+
{
146+
"CollectionName": "gradients"
147+
},
148+
{
149+
"CollectionName": "weights"
150+
},
151+
{
152+
"CollectionName": "losses"
153+
},
154+
{
155+
"CollectionName": "biases"
156+
},
157+
{
158+
"CollectionName": "optimizer_variables"
159+
}
160+
]
161+
}
162+
"""
163+
helper_test_keras_v2_json_config(
164+
script_mode=script_mode, eager_mode=eager_mode, json_file_contents=json_file_contents
165+
)
166+
167+
168+
def test_keras_v2_save_all(script_mode: bool = False, eager_mode: bool = True):
169+
# Test the save_all option set through the hook config
170+
json_file_contents = """
171+
{
172+
"S3OutputPath": "s3://sagemaker-test",
173+
"LocalPath": "/opt/ml/output/tensors",
174+
"HookParameters" : {
175+
"save_steps": "0,1,2,3",
176+
"save_all": true
177+
}
178+
}
179+
"""
180+
helper_test_keras_v2_json_config(
181+
script_mode=script_mode, eager_mode=eager_mode, json_file_contents=json_file_contents
182+
)
80183

81184

82185
if __name__ == "__main__":
@@ -88,6 +191,11 @@ def test_keras_v2(script_mode: bool = False, eager_mode: bool = True):
88191
script_mode = args.script_mode
89192

90193
# eager mode
91-
test_keras_v2(script_mode=script_mode)
194+
test_keras_v2_default(script_mode)
195+
test_keras_v2_multi_collections(script_mode)
196+
test_keras_v2_save_all(script_mode)
197+
92198
# non-eager mode
93-
test_keras_v2(script_mode=script_mode, eager_mode=False)
199+
test_keras_v2_default(script_mode, eager_mode=False)
200+
test_keras_v2_multi_collections(script_mode, eager_mode=False)
201+
test_keras_v2_save_all(script_mode, eager_mode=False)

0 commit comments

Comments
 (0)