@@ -117,75 +117,74 @@ def parse_time_distributed_layer(keras_layer, input_names, input_shapes, data_re
 def parse_bidirectional_layer(keras_layer, input_names, input_shapes, data_reader):
     assert keras_layer['class_name'] == 'Bidirectional'

-    rnn_layer = keras_layer['config']['layer']
-    assert rnn_layer['class_name'] in rnn_layers or rnn_layer['class_name'][1:] in rnn_layers
-
-    layer = parse_default_keras_layer(rnn_layer, input_names)
-    layer['name'] = keras_layer['config']['name']
-    layer['class_name'] = 'Bidirectional' + layer['class_name']
-    layer['direction'] = 'bidirectional'
-
-    # TODO Should we handle different architectures for forward and backward layer?
+    rnn_forward_layer = keras_layer['config']['layer']
     if keras_layer['config'].get('backward_layer'):
-        raise Exception('Different architectures between forward and backward layers are not supported by hls4ml')
-
-    layer['return_sequences'] = rnn_layer['config']['return_sequences']
-    layer['return_state'] = rnn_layer['config']['return_state']
-
-    if 'SimpleRNN' not in layer['class_name']:
-        layer['recurrent_activation'] = rnn_layer['config']['recurrent_activation']
+        rnn_backward_layer = keras_layer['config']['backward_layer']
+        if rnn_forward_layer['config']['go_backwards']:
+            temp_layer = rnn_forward_layer.copy()
+            rnn_forward_layer = rnn_backward_layer.copy()
+            rnn_backward_layer = temp_layer
+    else:
+        rnn_backward_layer = rnn_forward_layer

-    layer['time_major'] = rnn_layer['config']['time_major'] if 'time_major' in rnn_layer['config'] else False
+    assert (rnn_forward_layer['class_name'] in rnn_layers or rnn_forward_layer['class_name'][1:] in rnn_layers) and (
+        rnn_backward_layer['class_name'] in rnn_layers or rnn_backward_layer['class_name'][1:] in rnn_layers
+    )

+    layer = {}
+    layer['name'] = keras_layer['config']['name']
+    layer['forward_layer'] = parse_default_keras_layer(rnn_forward_layer, input_names)
+    layer['backward_layer'] = parse_default_keras_layer(rnn_backward_layer, input_names)
+    layer['class_name'] = (
+        'Bidirectional' + layer['forward_layer']['class_name']
+    )  # TODO: to be changed if we ever implement different
+    # architecture for forward and backward layer
+    layer['direction'] = 'bidirectional'
+    layer['return_sequences'] = rnn_forward_layer['config']['return_sequences']
+    layer['return_state'] = rnn_forward_layer['config']['return_state']
+    layer['time_major'] = rnn_forward_layer['config']['time_major'] if 'time_major' in rnn_forward_layer['config'] else False
     # TODO Should we handle time_major?
     if layer['time_major']:
         raise Exception('Time-major format is not supported by hls4ml')
-
     layer['n_timesteps'] = input_shapes[0][1]
     layer['n_in'] = input_shapes[0][2]
-
     assert keras_layer['config']['merge_mode'] in merge_modes
     layer['merge_mode'] = keras_layer['config']['merge_mode']

-    layer['n_out'] = rnn_layer['config']['units']
-    if keras_layer['config']['merge_mode'] == 'concat':
-        layer['n_out'] *= 2
+    for direction, rnn_layer in [('forward_layer', rnn_forward_layer), ('backward_layer', rnn_backward_layer)]:

-    rnn_layer_name = rnn_layer['config']['name']
-    if 'SimpleRNN' in layer['class_name']:
-        cell_name = 'simple_rnn'
-    else:
-        cell_name = rnn_layer['class_name'].lower()
-    layer['weight_data'], layer['recurrent_weight_data'], layer['bias_data'] = get_weights_data(
-        data_reader,
-        layer['name'],
-        [
-            f'forward_{rnn_layer_name}/{cell_name}_cell/kernel',
-            f'forward_{rnn_layer_name}/{cell_name}_cell/recurrent_kernel',
-            f'forward_{rnn_layer_name}/{cell_name}_cell/bias',
-        ],
-    )
-    layer['weight_b_data'], layer['recurrent_weight_b_data'], layer['bias_b_data'] = get_weights_data(
-        data_reader,
-        layer['name'],
-        [
-            f'backward_{rnn_layer_name}/{cell_name}_cell/kernel',
-            f'backward_{rnn_layer_name}/{cell_name}_cell/recurrent_kernel',
-            f'backward_{rnn_layer_name}/{cell_name}_cell/bias',
-        ],
-    )
-
-    if 'GRU' in layer['class_name']:
-        layer['apply_reset_gate'] = 'after' if rnn_layer['config']['reset_after'] else 'before'
+        if 'SimpleRNN' not in rnn_layer['class_name']:
+            layer[direction]['recurrent_activation'] = rnn_layer['config']['recurrent_activation']

-        # biases array is actually a 2-dim array of arrays (bias + recurrent bias)
-        # both arrays have shape: n_units * 3 (z, r, h_cand)
-        biases = layer['bias_data']
-        biases_b = layer['bias_b_data']
-        layer['bias_data'] = biases[0]
-        layer['recurrent_bias_data'] = biases[1]
-        layer['bias_b_data'] = biases_b[0]
-        layer['recurrent_bias_b_data'] = biases_b[1]
+        rnn_layer_name = rnn_layer['config']['name']
+        if 'SimpleRNN' in layer['class_name']:
+            cell_name = 'simple_rnn'
+        else:
+            cell_name = rnn_layer['class_name'].lower()
+        layer[direction]['weight_data'], layer[direction]['recurrent_weight_data'], layer[direction]['bias_data'] = (
+            get_weights_data(
+                data_reader,
+                layer['name'],
+                [
+                    f'{direction[:-6]}_{rnn_layer_name}/{cell_name}_cell/kernel',
+                    f'{direction[:-6]}_{rnn_layer_name}/{cell_name}_cell/recurrent_kernel',
+                    f'{direction[:-6]}_{rnn_layer_name}/{cell_name}_cell/bias',
+                ],
+            )
+        )
+
+        if 'GRU' in rnn_layer['class_name']:
+            layer[direction]['apply_reset_gate'] = 'after' if rnn_layer['config']['reset_after'] else 'before'
+
+            # biases array is actually a 2-dim array of arrays (bias + recurrent bias)
+            # both arrays have shape: n_units * 3 (z, r, h_cand)
+            biases = layer[direction]['bias_data']
+            layer[direction]['bias_data'] = biases[0]
+            layer[direction]['recurrent_bias_data'] = biases[1]
+
+        layer[direction]['n_states'] = rnn_layer['config']['units']
+
+    layer['n_out'] = layer['forward_layer']['n_states'] + layer['backward_layer']['n_states']

     if layer['return_sequences']:
         output_shape = [input_shapes[0][0], layer['n_timesteps'], layer['n_out']]
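For readers of this hunk, a minimal sketch (not part of the patch) of how the per-direction weight paths are derived: the loop key `direction` is `'forward_layer'` or `'backward_layer'`, so `direction[:-6]` strips the `_layer` suffix to get the HDF5 group prefix. The inner layer name `'lstm'` below is a hypothetical example standing in for `rnn_layer['config']['name']`.

# Illustrative sketch only, not part of the diff: reproduces the weight-path
# naming used in the loop above for a hypothetical inner LSTM named 'lstm'.
for direction in ('forward_layer', 'backward_layer'):
    prefix = direction[:-6]  # 'forward_layer' -> 'forward', 'backward_layer' -> 'backward'
    rnn_layer_name = 'lstm'  # hypothetical rnn_layer['config']['name']
    cell_name = 'lstm'       # rnn_layer['class_name'].lower() for an LSTM
    print(f'{prefix}_{rnn_layer_name}/{cell_name}_cell/kernel')
# prints:
#   forward_lstm/lstm_cell/kernel
#   backward_lstm/lstm_cell/kernel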