Skip to content

Commit b65c730

Browse files
author
Enrico Lupi
committed
FIX n_out in case of merge_mode != concat
1 parent 070fdc2 commit b65c730

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

hls4ml/converters/keras/recurrent.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,10 @@ def parse_bidirectional_layer(keras_layer, input_names, input_shapes, data_reade
200200

201201
layer[f'{direction}_n_states'] = rnn_layer['config']['units']
202202

203-
layer['n_out'] = layer['forward_n_states'] + layer['backward_n_states']
203+
if layer['merge_mode'] == 'concat':
204+
layer['n_out'] = layer['forward_n_states'] + layer['backward_n_states']
205+
else:
206+
layer['n_out'] = layer['forward_n_states']
204207

205208
if layer['return_sequences']:
206209
output_shape = [input_shapes[0][0], layer['n_timesteps'], layer['n_out']]

0 commit comments

Comments
 (0)