Skip to content

Commit dd4f220

Browse files
author
Enrico Lupi
committed
ADD support for reverse order layers
1 parent 4ed22c9 commit dd4f220

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

hls4ml/converters/keras/recurrent.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,12 +118,14 @@ def parse_bidirectional_layer(keras_layer, input_names, input_shapes, data_reade
118118
assert keras_layer['class_name'] == 'Bidirectional'
119119

120120
rnn_forward_layer = keras_layer['config']['layer']
121+
swapped_order = False
121122
if keras_layer['config'].get('backward_layer'):
122123
rnn_backward_layer = keras_layer['config']['backward_layer']
123124
if rnn_forward_layer['config']['go_backwards']:
124125
temp_layer = rnn_forward_layer.copy()
125126
rnn_forward_layer = rnn_backward_layer.copy()
126127
rnn_backward_layer = temp_layer
128+
swapped_order = True
127129
else:
128130
rnn_backward_layer = rnn_forward_layer
129131

@@ -138,6 +140,7 @@ def parse_bidirectional_layer(keras_layer, input_names, input_shapes, data_reade
138140
layer['inputs'] = input_names
139141

140142
layer['direction'] = 'bidirectional'
143+
layer['swapped_order'] = swapped_order
141144
layer['return_sequences'] = rnn_forward_layer['config']['return_sequences']
142145
layer['return_state'] = rnn_forward_layer['config']['return_state']
143146
layer['time_major'] = rnn_forward_layer['config']['time_major'] if 'time_major' in rnn_forward_layer['config'] else False
@@ -171,14 +174,17 @@ def parse_bidirectional_layer(keras_layer, input_names, input_shapes, data_reade
171174
cell_name = 'simple_rnn'
172175
else:
173176
cell_name = rnn_layer['class_name'].lower()
177+
temp_dir = direction
178+
if swapped_order:
179+
temp_dir = 'backward' if direction == 'forward' else 'forward'
174180
layer[f'{direction}_weight_data'], layer[f'{direction}_recurrent_weight_data'], layer[f'{direction}_bias_data'] = (
175181
get_weights_data(
176182
data_reader,
177183
layer['name'],
178184
[
179-
f'{direction}_{rnn_layer_name}/{cell_name}_cell/kernel',
180-
f'{direction}_{rnn_layer_name}/{cell_name}_cell/recurrent_kernel',
181-
f'{direction}_{rnn_layer_name}/{cell_name}_cell/bias',
185+
f'{temp_dir}_{rnn_layer_name}/{cell_name}_cell/kernel',
186+
f'{temp_dir}_{rnn_layer_name}/{cell_name}_cell/recurrent_kernel',
187+
f'{temp_dir}_{rnn_layer_name}/{cell_name}_cell/bias',
182188
],
183189
)
184190
)

0 commit comments

Comments
 (0)