Skip to content

Commit 8333ef4

Browse files
authored
Fix functional dict inputs to support optional ones (#21030)
* Fix functional dict inputs to support optional ones * Add unit test for optional dict inputs * Fix unit test formatting
1 parent 4baba1e commit 8333ef4

File tree

2 files changed

+23
-9
lines changed

2 files changed

+23
-9
lines changed

keras/src/models/functional.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -347,15 +347,15 @@ def shape_with_no_batch_size(x):
347347
x[0] = None
348348
return tuple(x)
349349

350-
def make_spec_for_tensor(x):
350+
def make_spec_for_tensor(x, name=None):
351351
optional = False
352352
if isinstance(x._keras_history[0], InputLayer):
353353
if x._keras_history[0].optional:
354354
optional = True
355355
return InputSpec(
356356
shape=shape_with_no_batch_size(x.shape),
357357
allow_last_axis_squeeze=True,
358-
name=x._keras_history[0].name,
358+
name=x._keras_history[0].name if name is None else name,
359359
optional=optional,
360360
)
361361

@@ -367,13 +367,7 @@ def make_spec_for_tensor(x):
367367
# Case where `_nested_inputs` is a plain dict of Inputs.
368368
names = sorted(self._inputs_struct.keys())
369369
return [
370-
InputSpec(
371-
shape=shape_with_no_batch_size(
372-
self._inputs_struct[name].shape
373-
),
374-
allow_last_axis_squeeze=True,
375-
name=name,
376-
)
370+
make_spec_for_tensor(self._inputs_struct[name], name=name)
377371
for name in names
378372
]
379373
return None # Deeply nested dict: skip checks.

keras/src/models/functional_test.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -574,6 +574,26 @@ def compute_output_shape(self, x_shape):
574574
self.assertAllClose(out, np.ones((2, 2)))
575575
# Note: it's not intended to work in symbolic mode (yet).
576576

577+
def test_optional_dict_inputs(self):
578+
class OptionalInputLayer(layers.Layer):
579+
def call(self, x, y=None):
580+
if y is not None:
581+
return x + y
582+
return x
583+
584+
def compute_output_shape(self, x_shape):
585+
return x_shape
586+
587+
i1 = Input((2,), name="input1")
588+
i2 = Input((2,), name="input2", optional=True)
589+
outputs = OptionalInputLayer()(i1, i2)
590+
model = Model({"input1": i1, "input2": i2}, outputs)
591+
592+
# Eager test
593+
out = model({"input1": np.ones((2, 2)), "input2": None})
594+
self.assertAllClose(out, np.ones((2, 2)))
595+
# Note: it's not intended to work in symbolic mode (yet).
596+
577597
def test_warning_for_mismatched_inputs_structure(self):
578598
def is_input_warning(w):
579599
return str(w.message).startswith(

0 commit comments

Comments
 (0)