@@ -45,6 +45,28 @@ def get_model(type="sequential", input_shape=(10,), layer_list=None):
         return models.Model(inputs=input, outputs=output)
     elif type == "subclass":
         return CustomModel(layer_list)
+    elif type == "lstm":
+        # https://github.com/keras-team/keras/issues/21390
+        inputs = layers.Input((4, 10))
+        x = layers.Bidirectional(
+            layers.LSTM(
+                10,
+                kernel_initializer="he_normal",
+                return_sequences=True,
+                kernel_regularizer=None,
+            ),
+            merge_mode="sum",
+        )(inputs)
+        outputs = layers.Bidirectional(
+            layers.LSTM(
+                10,
+                kernel_initializer="he_normal",
+                return_sequences=True,
+                kernel_regularizer=None,
+            ),
+            merge_mode="concat",
+        )(x)
+        return models.Model(inputs=inputs, outputs=outputs)


 @pytest.mark.skipif(
@@ -57,13 +79,19 @@ def get_model(type="sequential", input_shape=(10,), layer_list=None):
 @pytest.mark.skipif(testing.jax_uses_gpu(), reason="Leads to core dumps on CI")
 class ExportONNXTest(testing.TestCase):
     @parameterized.named_parameters(
-        named_product(model_type=["sequential", "functional", "subclass"])
+        named_product(
+            model_type=["sequential", "functional", "subclass", "lstm"]
+        )
     )
     def test_standard_model_export(self, model_type):
         temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
         model = get_model(model_type)
         batch_size = 3 if backend.backend() != "torch" else 1
-        ref_input = np.random.normal(size=(batch_size, 10)).astype("float32")
+        if model_type == "lstm":
+            ref_input = np.random.normal(size=(batch_size, 4, 10))
+        else:
+            ref_input = np.random.normal(size=(batch_size, 10))
+        ref_input = ref_input.astype("float32")
         ref_output = model(ref_input)

         onnx.export_onnx(model, temp_filepath)
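For reference, here is a minimal standalone sketch of the Bidirectional LSTM model the new "lstm" branch builds and the export round trip the test exercises (keras-team/keras#21390). The test calls the internal onnx.export_onnx helper shown above; this sketch assumes the public Keras 3 model.export(..., format="onnx") entry point instead, and the output file name is illustrative only:

import numpy as np
from keras import layers, models

# Same Bidirectional LSTM stack as the new "lstm" branch in get_model():
# a summed bidirectional layer followed by a concatenated one.
inputs = layers.Input((4, 10))
x = layers.Bidirectional(
    layers.LSTM(10, return_sequences=True), merge_mode="sum"
)(inputs)
outputs = layers.Bidirectional(
    layers.LSTM(10, return_sequences=True), merge_mode="concat"
)(x)
model = models.Model(inputs=inputs, outputs=outputs)

# Reference forward pass on an input matching the (batch, 4, 10) shape the
# test generates; ref_output is what the exported ONNX model's output would
# be compared against.
ref_input = np.random.normal(size=(3, 4, 10)).astype("float32")
ref_output = model(ref_input)
model.export("exported_model.onnx", format="onnx")  # assumed public API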