@@ -116,15 +116,19 @@ def parse_time_distributed_layer(keras_layer, input_names, input_shapes, data_reader):
 @keras_handler('Bidirectional')
 def parse_bidirectional_layer(keras_layer, input_names, input_shapes, data_reader):
     assert keras_layer['class_name'] == 'Bidirectional'
-
+
     rnn_layer = keras_layer['config']['layer']
     assert rnn_layer['class_name'] in rnn_layers or rnn_layer['class_name'][1:] in rnn_layers

     layer = parse_default_keras_layer(rnn_layer, input_names)
     layer['name'] = keras_layer['config']['name']
-    layer['class_name'] = 'B' + layer['class_name']
+    layer['class_name'] = 'Bidirectional' + layer['class_name']
     layer['direction'] = 'bidirectional'

+    # TODO Should we handle different architectures for forward and backward layer?
+    if keras_layer['config'].get('backward_layer'):
+        raise Exception('Different architectures between forward and backward layers are not supported by hls4ml')
+
     layer['return_sequences'] = rnn_layer['config']['return_sequences']
     layer['return_state'] = rnn_layer['config']['return_state']

@@ -147,19 +151,28 @@ def parse_bidirectional_layer(keras_layer, input_names, input_shapes, data_reader):
     if keras_layer['config']['merge_mode'] == 'concat':
         layer['n_out'] *= 2

+    rnn_layer_name = rnn_layer['config']['name']
     if 'SimpleRNN' in layer['class_name']:
         cell_name = 'simple_rnn'
     else:
         cell_name = rnn_layer['class_name'].lower()
     layer['weight_data'], layer['recurrent_weight_data'], layer['bias_data'] = get_weights_data(
-        data_reader, layer['name'], [f'{cell_name}_cell/kernel',
-                                     f'{cell_name}_cell/recurrent_kernel',
-                                     f'{cell_name}_cell/bias']
+        data_reader,
+        layer['name'],
+        [
+            f'forward_{rnn_layer_name}/{cell_name}_cell/kernel',
+            f'forward_{rnn_layer_name}/{cell_name}_cell/recurrent_kernel',
+            f'forward_{rnn_layer_name}/{cell_name}_cell/bias',
+        ],
     )
     layer['weight_b_data'], layer['recurrent_weight_b_data'], layer['bias_b_data'] = get_weights_data(
-        data_reader, layer['name'], [f'{cell_name}_cell/kernel',
-                                     f'{cell_name}_cell/recurrent_kernel',
-                                     f'{cell_name}_cell/bias']
+        data_reader,
+        layer['name'],
+        [
+            f'backward_{rnn_layer_name}/{cell_name}_cell/kernel',
+            f'backward_{rnn_layer_name}/{cell_name}_cell/recurrent_kernel',
+            f'backward_{rnn_layer_name}/{cell_name}_cell/bias',
+        ],
     )

     if 'GRU' in layer['class_name']:
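
For context (not part of the commit): the direction-prefixed paths mirror how Keras itself namespaces the variables of a Bidirectional wrapper, which is why the old code, passing the same f'{cell_name}_cell/...' paths to both get_weights_data calls, could only return identical data for the forward and backward weights. A minimal sketch of the naming convention, assuming TF 2.x-style Keras (the 'bidir'/'lstm' names are illustrative, and exact paths can differ across Keras versions):

# Illustrative only: print the variable paths Keras creates for a
# Bidirectional-wrapped RNN (assumes TF 2.x; 'bidir'/'lstm' are arbitrary names).
import tensorflow as tf

model = tf.keras.Sequential(
    [
        tf.keras.Input(shape=(8, 4)),
        tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(16, name='lstm'), name='bidir'),
    ]
)

# Expected output (modulo version differences): each direction gets its own
# forward_<name>/backward_<name> namespace around the cell's weights, e.g.
#   bidir/forward_lstm/lstm_cell/kernel:0
#   bidir/forward_lstm/lstm_cell/recurrent_kernel:0
#   bidir/forward_lstm/lstm_cell/bias:0
#   bidir/backward_lstm/lstm_cell/kernel:0
#   ...
for w in model.weights:
    print(w.name)

Matching these prefixed paths is what lets the two get_weights_data calls return distinct tensors for the forward and backward directions instead of the same data twice.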