
Commit 0246dae

Author: Enrico Lupi
FIX weight name and ADD backward layer architecture check
1 parent: a9546c7

1 file changed

hls4ml/converters/keras/recurrent.py

Lines changed: 21 additions & 8 deletions
@@ -116,15 +116,19 @@ def parse_time_distributed_layer(keras_layer, input_names, input_shapes, data_reader):
 @keras_handler('Bidirectional')
 def parse_bidirectional_layer(keras_layer, input_names, input_shapes, data_reader):
     assert keras_layer['class_name'] == 'Bidirectional'
-
+
     rnn_layer = keras_layer['config']['layer']
     assert rnn_layer['class_name'] in rnn_layers or rnn_layer['class_name'][1:] in rnn_layers
 
     layer = parse_default_keras_layer(rnn_layer, input_names)
     layer['name'] = keras_layer['config']['name']
-    layer['class_name'] = 'B' + layer['class_name']
+    layer['class_name'] = 'Bidirectional' + layer['class_name']
     layer['direction'] = 'bidirectional'
 
+    # TODO Should we handle different architectures for forward and backward layer?
+    if keras_layer['config'].get('backward_layer'):
+        raise Exception('Different architectures between forward and backward layers are not supported by hls4ml')
+
     layer['return_sequences'] = rnn_layer['config']['return_sequences']
     layer['return_state'] = rnn_layer['config']['return_state']
 
@@ -147,19 +151,28 @@ def parse_bidirectional_layer(keras_layer, input_names, input_shapes, data_reader):
     if keras_layer['config']['merge_mode'] == 'concat':
         layer['n_out'] *= 2
 
+    rnn_layer_name = rnn_layer['config']['name']
     if 'SimpleRNN' in layer['class_name']:
         cell_name = 'simple_rnn'
     else:
         cell_name = rnn_layer['class_name'].lower()
     layer['weight_data'], layer['recurrent_weight_data'], layer['bias_data'] = get_weights_data(
-        data_reader, layer['name'], [f'{cell_name}_cell/kernel',
-                                     f'{cell_name}_cell/recurrent_kernel',
-                                     f'{cell_name}_cell/bias']
+        data_reader,
+        layer['name'],
+        [
+            f'forward_{rnn_layer_name}/{cell_name}_cell/kernel',
+            f'forward_{rnn_layer_name}/{cell_name}_cell/recurrent_kernel',
+            f'forward_{rnn_layer_name}/{cell_name}_cell/bias',
+        ],
     )
     layer['weight_b_data'], layer['recurrent_weight_b_data'], layer['bias_b_data'] = get_weights_data(
-        data_reader, layer['name'], [f'{cell_name}_cell/kernel',
-                                     f'{cell_name}_cell/recurrent_kernel',
-                                     f'{cell_name}_cell/bias']
+        data_reader,
+        layer['name'],
+        [
+            f'backward_{rnn_layer_name}/{cell_name}_cell/kernel',
+            f'backward_{rnn_layer_name}/{cell_name}_cell/recurrent_kernel',
+            f'backward_{rnn_layer_name}/{cell_name}_cell/bias',
+        ],
    )
 
     if 'GRU' in layer['class_name']:
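
For context, the forward_/backward_ prefixes added to the weight paths match how TensorFlow 2.x Keras names the variables created by the Bidirectional wrapper. The following is a minimal sketch, assuming TensorFlow 2.x; the layer names 'lstm_1' and 'bidirectional' are illustrative choices, not taken from this commit, and exact names can differ across Keras versions.

# Minimal sketch, assuming TensorFlow 2.x Keras. The names 'lstm_1' and
# 'bidirectional' are hypothetical, chosen only for this illustration.
import tensorflow as tf

model = tf.keras.Sequential(
    [
        tf.keras.Input(shape=(4, 3)),
        tf.keras.layers.Bidirectional(
            tf.keras.layers.LSTM(8, return_sequences=True, name='lstm_1'),
            merge_mode='concat',
            name='bidirectional',
        ),
    ]
)

# The wrapper stores its variables under 'forward_<rnn name>' and
# 'backward_<rnn name>', which is the naming the fixed get_weights_data()
# paths follow, e.g.:
#   bidirectional/forward_lstm_1/lstm_cell/kernel:0
#   bidirectional/backward_lstm_1/lstm_cell/recurrent_kernel:0
for w in model.weights:
    print(w.name)

Relatedly, passing an explicit backward_layer to Bidirectional puts a 'backward_layer' entry in the serialized layer config, which appears to be what the new architecture check keys on before raising.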
