Skip to content

Commit 4baba1e

Browse files
authored
Raise error if inputs are not connected with output in functional model (#20705)
* Raise error if inputs are not connected with output in functional model * Fix Failing test case for unconnected inputs/outputs * fix formatting issue
1 parent a69952e commit 4baba1e

File tree

3 files changed

+51
-8
lines changed

3 files changed

+51
-8
lines changed

keras/src/backend/tensorflow/saved_model_test.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -191,24 +191,37 @@ def call(self, inputs):
191191
def test_multi_input_model(self):
192192
input_1 = layers.Input(shape=(3,))
193193
input_2 = layers.Input(shape=(5,))
194-
model = models.Model([input_1, input_2], [input_1, input_2])
195-
path = os.path.join(self.get_temp_dir(), "my_keras_model")
196194

197-
tf.saved_model.save(model, path)
198-
restored_model = tf.saved_model.load(path)
195+
y1 = layers.Dense(1)(input_1)
196+
y2 = layers.Dense(1)(input_2)
197+
layer_2 = layers.Dense(1, activation="relu")
198+
output_1 = layer_2(y1)
199+
output_2 = layer_2(y2)
200+
model = models.Model([input_1, input_2], [output_1, output_2])
201+
199202
input_arr_1 = np.random.random((1, 3)).astype("float32")
200203
input_arr_2 = np.random.random((1, 5)).astype("float32")
201204

202-
outputs = restored_model.signatures["serving_default"](
205+
model = models.Model([input_1, input_2], [output_1, output_2])
206+
path = os.path.join(self.get_temp_dir(), "my_keras_model")
207+
outputs_1 = model(
208+
inputs=[
209+
tf.convert_to_tensor(input_arr_1, dtype=tf.float32),
210+
tf.convert_to_tensor(input_arr_2, dtype=tf.float32),
211+
],
212+
)
213+
tf.saved_model.save(model, path)
214+
restored_model = tf.saved_model.load(path)
215+
216+
outputs_2 = restored_model.signatures["serving_default"](
203217
inputs=tf.convert_to_tensor(input_arr_1, dtype=tf.float32),
204218
inputs_1=tf.convert_to_tensor(input_arr_2, dtype=tf.float32),
205219
)
206-
207220
self.assertAllClose(
208-
input_arr_1, outputs["output_0"], rtol=1e-4, atol=1e-4
221+
outputs_1[0], outputs_2["output_0"], rtol=1e-4, atol=1e-4
209222
)
210223
self.assertAllClose(
211-
input_arr_2, outputs["output_1"], rtol=1e-4, atol=1e-4
224+
outputs_1[1], outputs_2["output_1"], rtol=1e-4, atol=1e-4
212225
)
213226

214227
def test_multi_input_custom_model_and_layer(self):

keras/src/ops/function.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,12 @@ def __init__(self, inputs, outputs, name=None):
8181
self._nodes_by_depth = nodes_by_depth
8282
self._operations = operations
8383
self._operations_by_depth = operations_by_depth
84+
for input in self._inputs:
85+
if (
86+
input._keras_history.operation
87+
and not input._keras_history.operation._outbound_nodes
88+
):
89+
raise ValueError("`inputs` not connected to `outputs`")
8490

8591
@property
8692
def operations(self):

keras/src/ops/function_test.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from keras.src.layers import Dense
88
from keras.src.layers import Input
99
from keras.src.models import Model
10+
from keras.src.models import Sequential
1011
from keras.src.ops import function
1112
from keras.src.ops import numpy as knp
1213

@@ -142,3 +143,26 @@ def test_function_with_empty_inputs(self):
142143
ValueError, "`inputs` argument cannot be empty"
143144
):
144145
_ = function.Function(inputs=[], outputs=x)
146+
147+
def test_function_with_unconnected_inputs(self):
148+
model_1 = Sequential(
149+
[
150+
Input(shape=(6,)),
151+
Dense(3, activation="sigmoid"),
152+
]
153+
)
154+
model_2 = Sequential(
155+
[
156+
Input(shape=(3,)),
157+
Dense(2, activation="sigmoid"),
158+
],
159+
)
160+
with self.assertRaisesRegex(
161+
ValueError, "`inputs` not connected to `outputs`"
162+
):
163+
_ = Model(Input(shape=(6,)), model_2(model_1(Input(shape=(6,)))))
164+
165+
with self.assertRaisesRegex(
166+
ValueError, "`inputs` not connected to `outputs`"
167+
):
168+
_ = Model(model_1(Input(shape=(6,))), model_2(Input(shape=(3,))))

0 commit comments

Comments
 (0)