Skip to content

Commit bb8f4b9

Browse files
authored
test concat layers (#367)
1 parent 1c10565 commit bb8f4b9

File tree

2 files changed

+48
-5
lines changed

2 files changed

+48
-5
lines changed

smdebug/tensorflow/keras.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -556,11 +556,13 @@ def _save_layer_input_and_outputs(self):
556556
if self._is_collection_being_saved_for_step(CollectionKeys.LAYERS)
557557
else set()
558558
)
559-
if hasattr(tensor, "numpy"):
560-
self._save_tensor_to_file(export_name, tensor.numpy(), input_collection)
561-
else:
559+
t = tensor[0] if isinstance(tensor, list) and len(tensor) else tensor
560+
if hasattr(t, "numpy") is False:
562561
self.logger.warning("cannot save layer values during forward pass with tf.function")
563562
continue
563+
else:
564+
self._save_tensor_to_file(export_name, tensor, input_collection)
565+
564566
# Save Output
565567
tensor = self.saved_layers[layer_name].layer_output
566568
export_name = get_export_name_for_keras(layer_name, tensor_type="output", tensor=tensor)
@@ -570,8 +572,11 @@ def _save_layer_input_and_outputs(self):
570572
if self._is_collection_being_saved_for_step(CollectionKeys.LAYERS)
571573
else set()
572574
)
573-
if hasattr(tensor, "numpy"):
574-
self._save_tensor_to_file(export_name, tensor.numpy(), output_collection)
575+
t = tensor[0] if isinstance(tensor, list) and len(tensor) else tensor
576+
if hasattr(t, "numpy") is False:
577+
self.logger.warning("cannot save layer values during forward pass with tf.function")
578+
else:
579+
self._save_tensor_to_file(export_name, tensor, output_collection)
575580

576581
def _save_tensors_post_step(self, batch, logs):
577582
# some tensors available as value from within hook are saved here
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Third Party
2+
import numpy as np
3+
from tensorflow.keras.layers import Concatenate, Dense
4+
from tensorflow.python.keras.models import Model
5+
6+
# First Party
7+
import smdebug.tensorflow as smd
8+
from smdebug.trials import create_trial
9+
10+
11+
class MyModel(Model):
12+
def __init__(self):
13+
super(MyModel, self).__init__()
14+
self.con = Concatenate()
15+
self.dense = Dense(10, activation="relu")
16+
17+
def call(self, x):
18+
x = self.con([x, x])
19+
return self.dense(x)
20+
21+
22+
def test_multiple_inputs(out_dir):
23+
my_model = MyModel()
24+
hook = smd.KerasHook(
25+
out_dir, save_all=True, save_config=smd.SaveConfig(save_steps=[0], save_interval=1)
26+
)
27+
28+
hook.register_model(my_model)
29+
x_train = np.random.random((1000, 20))
30+
y_train = np.random.random((1000, 1))
31+
my_model.compile(optimizer="Adam", loss="mse", run_eagerly=True)
32+
my_model.fit(x_train, y_train, epochs=1, steps_per_epoch=1, callbacks=[hook])
33+
34+
trial = create_trial(path=out_dir)
35+
tnames = sorted(trial.tensor_names(collection=smd.CollectionKeys.LAYERS))
36+
assert "concatenate" in tnames[0]
37+
assert len(trial.tensor(tnames[0]).value(0)) == 2
38+
assert trial.tensor(tnames[0]).shape(0) == (2, 1000, 20)

0 commit comments

Comments
 (0)