Skip to content

Commit edf7cdf

Browse files
author
Enrico Lupi
committed
ADD parsing for general bidirectional layer
1 parent d2c6cc0 commit edf7cdf

File tree

1 file changed

+28
-18
lines changed

1 file changed

+28
-18
lines changed

hls4ml/converters/keras/recurrent.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -133,12 +133,10 @@ def parse_bidirectional_layer(keras_layer, input_names, input_shapes, data_reade
133133

134134
layer = {}
135135
layer['name'] = keras_layer['config']['name']
136-
layer['forward_layer'] = parse_default_keras_layer(rnn_forward_layer, input_names)
137-
layer['backward_layer'] = parse_default_keras_layer(rnn_backward_layer, input_names)
138-
layer['class_name'] = (
139-
'Bidirectional' + layer['forward_layer']['class_name']
140-
) # TODO: to be changed if we ever implement different
141-
# architecture for forward and backward layer
136+
layer['class_name'] = keras_layer['class_name']
137+
if input_names is not None:
138+
layer['inputs'] = input_names
139+
142140
layer['direction'] = 'bidirectional'
143141
layer['return_sequences'] = rnn_forward_layer['config']['return_sequences']
144142
layer['return_state'] = rnn_forward_layer['config']['return_state']
@@ -151,40 +149,52 @@ def parse_bidirectional_layer(keras_layer, input_names, input_shapes, data_reade
151149
assert keras_layer['config']['merge_mode'] in merge_modes
152150
layer['merge_mode'] = keras_layer['config']['merge_mode']
153151

154-
for direction, rnn_layer in [('forward_layer', rnn_forward_layer), ('backward_layer', rnn_backward_layer)]:
152+
for direction, rnn_layer in [('forward', rnn_forward_layer), ('backward', rnn_backward_layer)]:
153+
154+
layer[f'{direction}_name'] = rnn_layer['config']['name']
155+
layer[f'{direction}_class_name'] = rnn_layer['class_name']
156+
157+
layer[f'{direction}_data_format'] = rnn_layer['config'].get('data_format', 'channels_last')
158+
159+
if 'activation' in rnn_layer['config']:
160+
layer[f'{direction}_activation'] = rnn_layer['config']['activation']
161+
if 'epsilon' in rnn_layer['config']:
162+
layer[f'{direction}_epsilon'] = rnn_layer['config']['epsilon']
163+
if 'use_bias' in rnn_layer['config']:
164+
layer[f'{direction}_use_bias'] = rnn_layer['config']['use_bias']
155165

156166
if 'SimpleRNN' not in rnn_layer['class_name']:
157-
layer[direction]['recurrent_activation'] = rnn_layer['config']['recurrent_activation']
167+
layer[f'{direction}_recurrent_activation'] = rnn_layer['config']['recurrent_activation']
158168

159169
rnn_layer_name = rnn_layer['config']['name']
160170
if 'SimpleRNN' in layer['class_name']:
161171
cell_name = 'simple_rnn'
162172
else:
163173
cell_name = rnn_layer['class_name'].lower()
164-
layer[direction]['weight_data'], layer[direction]['recurrent_weight_data'], layer[direction]['bias_data'] = (
174+
layer[f'{direction}_weight_data'], layer[f'{direction}_recurrent_weight_data'], layer[f'{direction}_bias_data'] = (
165175
get_weights_data(
166176
data_reader,
167177
layer['name'],
168178
[
169-
f'{direction[:-6]}_{rnn_layer_name}/{cell_name}_cell/kernel',
170-
f'{direction[:-6]}_{rnn_layer_name}/{cell_name}_cell/recurrent_kernel',
171-
f'{direction[:-6]}_{rnn_layer_name}/{cell_name}_cell/bias',
179+
f'{direction}_{rnn_layer_name}/{cell_name}_cell/kernel',
180+
f'{direction}_{rnn_layer_name}/{cell_name}_cell/recurrent_kernel',
181+
f'{direction}_{rnn_layer_name}/{cell_name}_cell/bias',
172182
],
173183
)
174184
)
175185

176186
if 'GRU' in rnn_layer['class_name']:
177-
layer[direction]['apply_reset_gate'] = 'after' if rnn_layer['config']['reset_after'] else 'before'
187+
layer[f'{direction}_apply_reset_gate'] = 'after' if rnn_layer['config']['reset_after'] else 'before'
178188

179189
# biases array is actually a 2-dim array of arrays (bias + recurrent bias)
180190
# both arrays have shape: n_units * 3 (z, r, h_cand)
181-
biases = layer[direction]['bias_data']
182-
layer[direction]['bias_data'] = biases[0]
183-
layer[direction]['recurrent_bias_data'] = biases[1]
191+
biases = layer[f'{direction}_bias_data']
192+
layer[f'{direction}_bias_data'] = biases[0]
193+
layer[f'{direction}_recurrent_bias_data'] = biases[1]
184194

185-
layer[direction]['n_states'] = rnn_layer['config']['units']
195+
layer[f'{direction}_n_states'] = rnn_layer['config']['units']
186196

187-
layer['n_out'] = layer['forward_layer']['n_states'] + layer['backward_layer']['n_states']
197+
layer['n_out'] = layer['forward_n_states'] + layer['backward_n_states']
188198

189199
if layer['return_sequences']:
190200
output_shape = [input_shapes[0][0], layer['n_timesteps'], layer['n_out']]

0 commit comments

Comments
 (0)