File tree Expand file tree Collapse file tree 2 files changed +23
-9
lines changed Expand file tree Collapse file tree 2 files changed +23
-9
lines changed Original file line number Diff line number Diff line change @@ -347,15 +347,15 @@ def shape_with_no_batch_size(x):
347
347
x [0 ] = None
348
348
return tuple (x )
349
349
350
- def make_spec_for_tensor (x ):
350
+ def make_spec_for_tensor (x , name = None ):
351
351
optional = False
352
352
if isinstance (x ._keras_history [0 ], InputLayer ):
353
353
if x ._keras_history [0 ].optional :
354
354
optional = True
355
355
return InputSpec (
356
356
shape = shape_with_no_batch_size (x .shape ),
357
357
allow_last_axis_squeeze = True ,
358
- name = x ._keras_history [0 ].name ,
358
+ name = x ._keras_history [0 ].name if name is None else name ,
359
359
optional = optional ,
360
360
)
361
361
@@ -367,13 +367,7 @@ def make_spec_for_tensor(x):
367
367
# Case where `_nested_inputs` is a plain dict of Inputs.
368
368
names = sorted (self ._inputs_struct .keys ())
369
369
return [
370
- InputSpec (
371
- shape = shape_with_no_batch_size (
372
- self ._inputs_struct [name ].shape
373
- ),
374
- allow_last_axis_squeeze = True ,
375
- name = name ,
376
- )
370
+ make_spec_for_tensor (self ._inputs_struct [name ], name = name )
377
371
for name in names
378
372
]
379
373
return None # Deeply nested dict: skip checks.
Original file line number Diff line number Diff line change @@ -574,6 +574,26 @@ def compute_output_shape(self, x_shape):
574
574
self .assertAllClose (out , np .ones ((2 , 2 )))
575
575
# Note: it's not intended to work in symbolic mode (yet).
576
576
577
+ def test_optional_dict_inputs (self ):
578
+ class OptionalInputLayer (layers .Layer ):
579
+ def call (self , x , y = None ):
580
+ if y is not None :
581
+ return x + y
582
+ return x
583
+
584
+ def compute_output_shape (self , x_shape ):
585
+ return x_shape
586
+
587
+ i1 = Input ((2 ,), name = "input1" )
588
+ i2 = Input ((2 ,), name = "input2" , optional = True )
589
+ outputs = OptionalInputLayer ()(i1 , i2 )
590
+ model = Model ({"input1" : i1 , "input2" : i2 }, outputs )
591
+
592
+ # Eager test
593
+ out = model ({"input1" : np .ones ((2 , 2 )), "input2" : None })
594
+ self .assertAllClose (out , np .ones ((2 , 2 )))
595
+ # Note: it's not intended to work in symbolic mode (yet).
596
+
577
597
def test_warning_for_mismatched_inputs_structure (self ):
578
598
def is_input_warning (w ):
579
599
return str (w .message ).startswith (
You can’t perform that action at this time.
0 commit comments