@@ -133,12 +133,10 @@ def parse_bidirectional_layer(keras_layer, input_names, input_shapes, data_reade
133
133
134
134
layer = {}
135
135
layer ['name' ] = keras_layer ['config' ]['name' ]
136
- layer ['forward_layer' ] = parse_default_keras_layer (rnn_forward_layer , input_names )
137
- layer ['backward_layer' ] = parse_default_keras_layer (rnn_backward_layer , input_names )
138
- layer ['class_name' ] = (
139
- 'Bidirectional' + layer ['forward_layer' ]['class_name' ]
140
- ) # TODO: to be changed if we ever implement different
141
- # architecture for forward and backward layer
136
+ layer ['class_name' ] = keras_layer ['class_name' ]
137
+ if input_names is not None :
138
+ layer ['inputs' ] = input_names
139
+
142
140
layer ['direction' ] = 'bidirectional'
143
141
layer ['return_sequences' ] = rnn_forward_layer ['config' ]['return_sequences' ]
144
142
layer ['return_state' ] = rnn_forward_layer ['config' ]['return_state' ]
@@ -151,40 +149,52 @@ def parse_bidirectional_layer(keras_layer, input_names, input_shapes, data_reade
151
149
assert keras_layer ['config' ]['merge_mode' ] in merge_modes
152
150
layer ['merge_mode' ] = keras_layer ['config' ]['merge_mode' ]
153
151
154
- for direction , rnn_layer in [('forward_layer' , rnn_forward_layer ), ('backward_layer' , rnn_backward_layer )]:
152
+ for direction , rnn_layer in [('forward' , rnn_forward_layer ), ('backward' , rnn_backward_layer )]:
153
+
154
+ layer [f'{ direction } _name' ] = rnn_layer ['config' ]['name' ]
155
+ layer [f'{ direction } _class_name' ] = rnn_layer ['class_name' ]
156
+
157
+ layer [f'{ direction } _data_format' ] = rnn_layer ['config' ].get ('data_format' , 'channels_last' )
158
+
159
+ if 'activation' in rnn_layer ['config' ]:
160
+ layer [f'{ direction } _activation' ] = rnn_layer ['config' ]['activation' ]
161
+ if 'epsilon' in rnn_layer ['config' ]:
162
+ layer [f'{ direction } _epsilon' ] = rnn_layer ['config' ]['epsilon' ]
163
+ if 'use_bias' in rnn_layer ['config' ]:
164
+ layer [f'{ direction } _use_bias' ] = rnn_layer ['config' ]['use_bias' ]
155
165
156
166
if 'SimpleRNN' not in rnn_layer ['class_name' ]:
157
- layer [direction ][ 'recurrent_activation ' ] = rnn_layer ['config' ]['recurrent_activation' ]
167
+ layer [f' { direction } _recurrent_activation ' ] = rnn_layer ['config' ]['recurrent_activation' ]
158
168
159
169
rnn_layer_name = rnn_layer ['config' ]['name' ]
160
170
if 'SimpleRNN' in layer ['class_name' ]:
161
171
cell_name = 'simple_rnn'
162
172
else :
163
173
cell_name = rnn_layer ['class_name' ].lower ()
164
- layer [direction ][ 'weight_data' ], layer [direction ][ 'recurrent_weight_data' ], layer [direction ][ 'bias_data ' ] = (
174
+ layer [f' { direction } _weight_data' ], layer [f' { direction } _recurrent_weight_data' ], layer [f' { direction } _bias_data ' ] = (
165
175
get_weights_data (
166
176
data_reader ,
167
177
layer ['name' ],
168
178
[
169
- f'{ direction [: - 6 ] } _{ rnn_layer_name } /{ cell_name } _cell/kernel' ,
170
- f'{ direction [: - 6 ] } _{ rnn_layer_name } /{ cell_name } _cell/recurrent_kernel' ,
171
- f'{ direction [: - 6 ] } _{ rnn_layer_name } /{ cell_name } _cell/bias' ,
179
+ f'{ direction } _{ rnn_layer_name } /{ cell_name } _cell/kernel' ,
180
+ f'{ direction } _{ rnn_layer_name } /{ cell_name } _cell/recurrent_kernel' ,
181
+ f'{ direction } _{ rnn_layer_name } /{ cell_name } _cell/bias' ,
172
182
],
173
183
)
174
184
)
175
185
176
186
if 'GRU' in rnn_layer ['class_name' ]:
177
- layer [direction ][ 'apply_reset_gate ' ] = 'after' if rnn_layer ['config' ]['reset_after' ] else 'before'
187
+ layer [f' { direction } _apply_reset_gate ' ] = 'after' if rnn_layer ['config' ]['reset_after' ] else 'before'
178
188
179
189
# biases array is actually a 2-dim array of arrays (bias + recurrent bias)
180
190
# both arrays have shape: n_units * 3 (z, r, h_cand)
181
- biases = layer [direction ][ 'bias_data ' ]
182
- layer [direction ][ 'bias_data ' ] = biases [0 ]
183
- layer [direction ][ 'recurrent_bias_data ' ] = biases [1 ]
191
+ biases = layer [f' { direction } _bias_data ' ]
192
+ layer [f' { direction } _bias_data ' ] = biases [0 ]
193
+ layer [f' { direction } _recurrent_bias_data ' ] = biases [1 ]
184
194
185
- layer [direction ][ 'n_states ' ] = rnn_layer ['config' ]['units' ]
195
+ layer [f' { direction } _n_states ' ] = rnn_layer ['config' ]['units' ]
186
196
187
- layer ['n_out' ] = layer ['forward_layer' ][ 'n_states' ] + layer ['backward_layer' ][ 'n_states ' ]
197
+ layer ['n_out' ] = layer ['forward_n_states' ] + layer ['backward_n_states ' ]
188
198
189
199
if layer ['return_sequences' ]:
190
200
output_shape = [input_shapes [0 ][0 ], layer ['n_timesteps' ], layer ['n_out' ]]
0 commit comments