@@ -118,12 +118,14 @@ def parse_bidirectional_layer(keras_layer, input_names, input_shapes, data_reade
118
118
assert keras_layer ['class_name' ] == 'Bidirectional'
119
119
120
120
rnn_forward_layer = keras_layer ['config' ]['layer' ]
121
+ swapped_order = False
121
122
if keras_layer ['config' ].get ('backward_layer' ):
122
123
rnn_backward_layer = keras_layer ['config' ]['backward_layer' ]
123
124
if rnn_forward_layer ['config' ]['go_backwards' ]:
124
125
temp_layer = rnn_forward_layer .copy ()
125
126
rnn_forward_layer = rnn_backward_layer .copy ()
126
127
rnn_backward_layer = temp_layer
128
+ swapped_order = True
127
129
else :
128
130
rnn_backward_layer = rnn_forward_layer
129
131
@@ -138,6 +140,7 @@ def parse_bidirectional_layer(keras_layer, input_names, input_shapes, data_reade
138
140
layer ['inputs' ] = input_names
139
141
140
142
layer ['direction' ] = 'bidirectional'
143
+ layer ['swapped_order' ] = swapped_order
141
144
layer ['return_sequences' ] = rnn_forward_layer ['config' ]['return_sequences' ]
142
145
layer ['return_state' ] = rnn_forward_layer ['config' ]['return_state' ]
143
146
layer ['time_major' ] = rnn_forward_layer ['config' ]['time_major' ] if 'time_major' in rnn_forward_layer ['config' ] else False
@@ -171,14 +174,17 @@ def parse_bidirectional_layer(keras_layer, input_names, input_shapes, data_reade
171
174
cell_name = 'simple_rnn'
172
175
else :
173
176
cell_name = rnn_layer ['class_name' ].lower ()
177
+ temp_dir = direction
178
+ if swapped_order :
179
+ temp_dir = 'backward' if direction == 'forward' else 'forward'
174
180
layer [f'{ direction } _weight_data' ], layer [f'{ direction } _recurrent_weight_data' ], layer [f'{ direction } _bias_data' ] = (
175
181
get_weights_data (
176
182
data_reader ,
177
183
layer ['name' ],
178
184
[
179
- f'{ direction } _{ rnn_layer_name } /{ cell_name } _cell/kernel' ,
180
- f'{ direction } _{ rnn_layer_name } /{ cell_name } _cell/recurrent_kernel' ,
181
- f'{ direction } _{ rnn_layer_name } /{ cell_name } _cell/bias' ,
185
+ f'{ temp_dir } _{ rnn_layer_name } /{ cell_name } _cell/kernel' ,
186
+ f'{ temp_dir } _{ rnn_layer_name } /{ cell_name } _cell/recurrent_kernel' ,
187
+ f'{ temp_dir } _{ rnn_layer_name } /{ cell_name } _cell/bias' ,
182
188
],
183
189
)
184
190
)
0 commit comments