Commit 9775fa5

Author: Yu Cong <congyc@amazon.com>

Add masked LSTM support

Signed-off-by: Yu Cong <congyc@amazon.com>

Parent: bc677a1
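
For context, the user-facing flow this commit enables looks roughly like the sketch below: a mask-producing Keras layer (Masking, or Embedding with mask_zero=True) feeds an LSTM, and conversion turns the propagated mask into a real sequence_lens input on the resulting ONNX LSTM. The model shape, opset, and input name here are illustrative assumptions, not taken from the commit.

import numpy as np
import tensorflow as tf
import tf2onnx

# Masking flags all-zero timesteps; the LSTM consumes the propagated mask.
model = tf.keras.models.Sequential([
    tf.keras.layers.Input(shape=(4, 5), dtype="float32"),
    tf.keras.layers.Masking(mask_value=0.),
    tf.keras.layers.LSTM(units=5),
])

# Post-padded batch with sequence lengths [4, 2, 0].
x_val = np.random.uniform(size=[3, 4, 5]).astype(np.float32)
x_val[1, 2:, :] = 0.
x_val[2, :, :] = 0.

# Convert; the masked LSTM maps to an ONNX LSTM with sequence_lens set.
model_proto, _ = tf2onnx.convert.from_keras(
    model,
    input_signature=[tf.TensorSpec((None, 4, 5), tf.float32, name="input")],
    opset=13)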

File tree: 3 files changed, +320 −26 lines

tests/test_lstm.py (+207 lines)
@@ -793,5 +793,212 @@ def func(x):
             return tf.identity(y[0], name="output")
         self.run_test_case(func, {"input:0": x_val}, [], ["output:0"], rtol=1e-05, atol=1e-06)

+    @check_tf_min_version("2.0")
+    @skip_tf_versions("2.1", "Bug in TF 2.1")
+    def test_keras_masked_lstm_embedding_unidirectional(self):
+        for go_backwards in [True, False]:
+            for return_sequences in [True, False]:
+                timesteps = 4
+                # Note: masked LSTM only supports post-padded input after conversion
+                # test case sequence_lens = [4, 2, 0]
+                x_val = np.array([
+                    [1, 2, 3, 4],
+                    [5, 6, 0, 0],
+                    [0, 0, 0, 0]
+                ], dtype=np.int32)
+
+                model_in = tf.keras.layers.Input((timesteps,), dtype="int32")
+                x_embedding = tf.keras.layers.Embedding(
+                    input_dim=10,
+                    output_dim=5,
+                    mask_zero=True,
+                    embeddings_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=41),
+                )(model_in)
+
+                # RNN layer inherits the mask propagated from the embedding layer above
+                model_out = tf.keras.layers.LSTM(
+                    units=5,
+                    go_backwards=go_backwards,
+                    return_sequences=return_sequences,
+                    return_state=True,
+                    kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=42),
+                    bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=43),
+                    recurrent_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=44),
+                )(x_embedding)
+                model = tf.keras.models.Model(inputs=model_in, outputs=model_out)
+
+                def func(x):
+                    y = model(x)
+                    if return_sequences:
+                        return (
+                            # skipping output Y when return_sequences=True due to inconsistent
+                            # ORT and TF behaviors: https://sim.amazon.com/issues/NEMORT-1712
+                            tf.identity(y[1], name="output_yh"),
+                            tf.identity(y[2], name="output_yc"))
+                    return (
+                        tf.identity(y[0], name="output_y"),
+                        tf.identity(y[1], name="output_yh"),
+                        tf.identity(y[2], name="output_yc"))
+
+                output_list = ["output_yh:0", "output_yc:0"] if return_sequences \
+                    else ["output_y:0", "output_yh:0", "output_yc:0"]
+                self.run_test_case(func, {"input:0": x_val}, [], output_list, rtol=1e-05, atol=1e-06)
+
+    @check_tf_min_version("2.0")
+    @skip_tf_versions("2.1", "Bug in TF 2.1")
+    def test_keras_masked_lstm_embedding_bidirectional(self):
+        for return_sequences in [False, True]:
+            timesteps = 4
+            # Note: masked LSTM only supports post-padded input after conversion
+            # test case sequence_lens = [4, 2, 0]
+            x_val = np.array([
+                [1, 2, 3, 4],
+                [5, 6, 0, 0],
+                [0, 0, 0, 0]
+            ], dtype=np.int32)
+
+            model_in = tf.keras.layers.Input((timesteps,), dtype="int32")
+            x_embedding = tf.keras.layers.Embedding(
+                input_dim=10,
+                output_dim=5,
+                mask_zero=True,
+                embeddings_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=41),
+            )(model_in)
+
+            # RNN layer inherits the mask propagated from the embedding layer above
+            lstm_layer = tf.keras.layers.LSTM(
+                units=5,
+                go_backwards=False,
+                return_sequences=return_sequences,
+                return_state=True,
+                kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=42),
+                bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=43),
+                recurrent_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=44),
+            )
+            model_out = tf.keras.layers.Bidirectional(lstm_layer)(x_embedding)
+            model = tf.keras.models.Model(inputs=model_in, outputs=model_out)
+
+            def func(x):
+                y = model(x)
+                if return_sequences:
+                    return (
+                        # skipping output Y when return_sequences=True due to inconsistent
+                        # ORT and TF behaviors: https://sim.amazon.com/issues/NEMORT-1712
+                        tf.identity(y[1], name="output_yh_f"),
+                        tf.identity(y[2], name="output_yc_f"),
+                        tf.identity(y[3], name="output_yh_r"),
+                        tf.identity(y[4], name="output_yc_r"))
+                return (
+                    tf.identity(y[0], name="output_y_concat"),
+                    tf.identity(y[1], name="output_yh_f"),
+                    tf.identity(y[2], name="output_yc_f"),
+                    tf.identity(y[3], name="output_yh_r"),
+                    tf.identity(y[4], name="output_yc_r"))
+
+            output_list = ["output_yh_f:0", "output_yc_f:0", "output_yh_r:0", "output_yc_r:0"] if return_sequences \
+                else ["output_y_concat:0", "output_yh_f:0", "output_yc_f:0", "output_yh_r:0", "output_yc_r:0"]
+
+            # the single BiLSTM is translated into two forward LSTMs
+            self.run_test_case(func, {"input:0": x_val}, [], output_list, rtol=1e-05, atol=1e-06,
+                               require_lstm_count=2)
+
+    @check_tf_min_version("2.0")
+    @skip_tf_versions("2.1", "Bug in TF 2.1")
+    def test_keras_masked_lstm_unidirectional(self):
+        for go_backwards in [True, False]:
+            for return_sequences in [True, False]:
+                batch_size, timesteps, feat = 3, 4, 5
+                in_shape = (timesteps, feat)
+                x_val = np.random.uniform(size=[batch_size, timesteps, feat]).astype(np.float32)
+                # Note: masked LSTM only supports post-padded input after conversion
+                # test case sequence_lens = [4, 2, 0]
+                x_val[1, 2:, :] = 0.
+                x_val[2, :, :] = 0.
+
+                model_in = tf.keras.layers.Input(shape=in_shape, dtype="float32")
+                x_masked = tf.keras.layers.Masking(mask_value=0.)(model_in)
+
+                # RNN layer inherits the mask propagated from the Masking layer above
+                model_out = tf.keras.layers.LSTM(
+                    units=5,
+                    go_backwards=go_backwards,
+                    return_sequences=return_sequences,
+                    return_state=True,
+                    kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=42),
+                    bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=43),
+                    recurrent_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=44),
+                )(x_masked)
+                model = tf.keras.models.Model(inputs=model_in, outputs=model_out)
+
+                def func(x):
+                    y = model(x)
+                    if return_sequences:
+                        return (
+                            # skipping output Y when return_sequences=True due to inconsistent
+                            # ORT and TF behaviors: https://sim.amazon.com/issues/NEMORT-1712
+                            tf.identity(y[1], name="output_yh"),
+                            tf.identity(y[2], name="output_yc"))
+                    return (
+                        tf.identity(y[0], name="output_y"),
+                        tf.identity(y[1], name="output_yh"),
+                        tf.identity(y[2], name="output_yc"))
+
+                output_list = ["output_yh:0", "output_yc:0"] if return_sequences \
+                    else ["output_y:0", "output_yh:0", "output_yc:0"]
+                self.run_test_case(func, {"input:0": x_val}, [], output_list, rtol=1e-05, atol=1e-06)
+
+    @check_tf_min_version("2.0")
+    @skip_tf_versions("2.1", "Bug in TF 2.1")
+    def test_keras_masked_lstm_bidirectional(self):
+        for return_sequences in [False, True]:
+            batch_size, timesteps, feat = 3, 4, 5
+            in_shape = (timesteps, feat)
+            x_val = np.random.uniform(size=[batch_size, timesteps, feat]).astype(np.float32)
+            # Note: masked LSTM only supports post-padded input after conversion
+            # test case sequence_lens = [4, 2, 0]
+            x_val[1, 2:, :] = 0.
+            x_val[2, :, :] = 0.
+
+            model_in = tf.keras.layers.Input(shape=in_shape, dtype="float32")
+            x_masked = tf.keras.layers.Masking(mask_value=0.)(model_in)
+
+            # RNN layer inherits the mask propagated from the Masking layer above
+            lstm_layer = tf.keras.layers.LSTM(
+                units=5,
+                go_backwards=False,
+                return_sequences=return_sequences,
+                return_state=True,
+                kernel_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=42),
+                bias_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=43),
+                recurrent_initializer=tf.random_uniform_initializer(0.0, 1.0, seed=44),
+            )
+            model_out = tf.keras.layers.Bidirectional(lstm_layer)(x_masked)
+            model = tf.keras.models.Model(inputs=model_in, outputs=model_out)
+
+            def func(x):
+                y = model(x)
+                if return_sequences:
+                    return (
+                        # skipping output Y when return_sequences=True due to inconsistent
+                        # ORT and TF behaviors: https://sim.amazon.com/issues/NEMORT-1712
+                        tf.identity(y[1], name="output_yh_f"),
+                        tf.identity(y[2], name="output_yc_f"),
+                        tf.identity(y[3], name="output_yh_r"),
+                        tf.identity(y[4], name="output_yc_r"))
+                return (
+                    tf.identity(y[0], name="output_y_concat"),
+                    tf.identity(y[1], name="output_yh_f"),
+                    tf.identity(y[2], name="output_yc_f"),
+                    tf.identity(y[3], name="output_yh_r"),
+                    tf.identity(y[4], name="output_yc_r"))
+
+            output_list = ["output_yh_f:0", "output_yc_f:0", "output_yh_r:0", "output_yc_r:0"] if return_sequences \
+                else ["output_y_concat:0", "output_yh_f:0", "output_yc_f:0", "output_yh_r:0", "output_yc_r:0"]
+
+            # the single BiLSTM is translated into two forward LSTMs
+            self.run_test_case(func, {"input:0": x_val}, [], output_list, rtol=1e-05, atol=1e-06,
+                               require_lstm_count=2)
+
+
 if __name__ == '__main__':
     unittest_main()
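
The note repeated in these tests, that the converted masked LSTM only supports post-padded input, can be made concrete with a small numpy sketch (the helper name is ours, not part of the commit): it derives per-sequence lengths such as [4, 2, 0] from all-zero timesteps and checks that every valid step precedes the padding.

import numpy as np

def sequence_lengths(batch):
    # batch: [batch_size, timesteps, feat]; a timestep counts as padding
    # iff every feature is zero, matching Masking(mask_value=0.)
    valid = np.any(batch != 0, axis=-1)      # [batch_size, timesteps] bool
    lengths = valid.sum(axis=-1)             # valid steps per sequence
    # post-padded <=> running count of valid steps at step t equals min(t + 1, length)
    steps = np.arange(1, batch.shape[1] + 1)
    post_padded = bool(np.all(valid.cumsum(axis=-1) ==
                              np.minimum(steps, lengths[:, None])))
    return lengths, post_padded

# For x_val from the Masking tests above this yields ([4, 2, 0], True);
# a pre-padded row such as [0, 0, 5, 6] would make the check fail.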

tf2onnx/onnx_opset/tensor.py (+19 −9 lines)
@@ -2260,15 +2260,25 @@ def version_10(cls, ctx, node, **kwargs):
         const_axis_name = utils.make_name(f'const_{axis}')
         const_axis = ctx.make_const(name=const_axis_name, np_val=np.array([axis], dtype=np.int64))

-        # Add a Constant node (seq_len) for ReverseSequence.
-        # Index 1 for the shape should not return 0, since rank(input) >= 2
-        input_shape = ctx.make_node("Shape", [inputs[-1]], op_name_scope=rv2_node_name)
-        batch_size = ctx.make_node("Gather", [input_shape.output[0], const_one.output[0]],
-                                   op_name_scope=rv2_node_name)
-        axis_dim = ctx.make_node("Gather", [input_shape_node.output[0], const_axis.output[0]],
-                                 op_name_scope=rv2_node_name)
-        seq_array = ctx.make_node("Expand", [axis_dim.output[0], batch_size.output[0]])
-        inputs.append(seq_array.output[0])
+        # Add sequence_lens as a ReverseSequence input
+        has_sequence_lens = node.get_attr_value("has_sequence_lens", False)
+        if not has_sequence_lens:
+            # open-source impl: fill in dummy sequence_lens based on the input shape
+            # Add a Constant node (seq_len) for ReverseSequence.
+            # Index 1 for the shape should not return 0, since rank(input) >= 2
+            input_shape = ctx.make_node("Shape", [inputs[-1]], op_name_scope=rv2_node_name)
+            batch_size = ctx.make_node("Gather", [input_shape.output[0], const_one.output[0]],
+                                       op_name_scope=rv2_node_name)
+            axis_dim = ctx.make_node("Gather", [input_shape_node.output[0], const_axis.output[0]],
+                                     op_name_scope=rv2_node_name)
+            seq_array = ctx.make_node("Expand", [axis_dim.output[0], batch_size.output[0]])
+            inputs.append(seq_array.output[0])
+        else:
+            # masked backward LSTM:
+            # sequence_lens is appended to ReverseV2's inputs by lstm_tf2_rewriter
+            # to keep the tensor post-padded after the reverse
+            seq_lens_casted = ctx.make_node("Cast", [node.input[-1]], attr={'to': TensorProto.INT64}).output[0]
+            inputs.append(seq_lens_casted)

         # Add a ReverseSequence node.
