Skip to content

Commit 6c41fc8

Browse files
Fix torch LSTM onnx export issue. (#21434)
1 parent cdcd5b3 commit 6c41fc8

File tree

3 files changed

+33
-2
lines changed

3 files changed

+33
-2
lines changed

keras/src/backend/torch/rnn.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,8 @@ def compute_masked_output(mask_t, flat_out, flat_mask):
228228
elif isinstance(input_length, torch.Tensor):
229229
if go_backwards:
230230
max_len = torch.max(input_length, dim=0)
231+
if isinstance(max_len, torch.return_types.max):
232+
max_len = max_len[0]
231233
rev_input_length = torch.subtract(max_len - 1, input_length)
232234

233235
def masking_fn(time):

keras/src/export/onnx_test.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,28 @@ def get_model(type="sequential", input_shape=(10,), layer_list=None):
4545
return models.Model(inputs=input, outputs=output)
4646
elif type == "subclass":
4747
return CustomModel(layer_list)
48+
elif type == "lstm":
49+
# https://github.com/keras-team/keras/issues/21390
50+
inputs = layers.Input((4, 10))
51+
x = layers.Bidirectional(
52+
layers.LSTM(
53+
10,
54+
kernel_initializer="he_normal",
55+
return_sequences=True,
56+
kernel_regularizer=None,
57+
),
58+
merge_mode="sum",
59+
)(inputs)
60+
outputs = layers.Bidirectional(
61+
layers.LSTM(
62+
10,
63+
kernel_initializer="he_normal",
64+
return_sequences=True,
65+
kernel_regularizer=None,
66+
),
67+
merge_mode="concat",
68+
)(x)
69+
return models.Model(inputs=inputs, outputs=outputs)
4870

4971

5072
@pytest.mark.skipif(
@@ -57,13 +79,19 @@ def get_model(type="sequential", input_shape=(10,), layer_list=None):
5779
@pytest.mark.skipif(testing.jax_uses_gpu(), reason="Leads to core dumps on CI")
5880
class ExportONNXTest(testing.TestCase):
5981
@parameterized.named_parameters(
60-
named_product(model_type=["sequential", "functional", "subclass"])
82+
named_product(
83+
model_type=["sequential", "functional", "subclass", "lstm"]
84+
)
6185
)
6286
def test_standard_model_export(self, model_type):
6387
temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
6488
model = get_model(model_type)
6589
batch_size = 3 if backend.backend() != "torch" else 1
66-
ref_input = np.random.normal(size=(batch_size, 10)).astype("float32")
90+
if model_type == "lstm":
91+
ref_input = np.random.normal(size=(batch_size, 4, 10))
92+
else:
93+
ref_input = np.random.normal(size=(batch_size, 10))
94+
ref_input = ref_input.astype("float32")
6795
ref_output = model(ref_input)
6896

6997
onnx.export_onnx(model, temp_filepath)

requirements-common.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,5 @@ dm_tree
2323
coverage!=7.6.5 # 7.6.5 breaks CI
2424
# for onnx_test.py
2525
onnxruntime
26+
onnxscript # Needed by TorchDynamo-based ONNX exporter
2627
openvino

0 commit comments

Comments
 (0)