
Commit d2c6cc0

Author: Enrico Lupi (committed)
ADD parse general bidirectional layer with possibly different architectures
1 parent 7428af7 commit d2c6cc0

File tree: 1 file changed (+55, −56 lines)


hls4ml/converters/keras/recurrent.py

Lines changed: 55 additions & 56 deletions
@@ -117,75 +117,74 @@ def parse_time_distributed_layer(keras_layer, input_names, input_shapes, data_reader):
 def parse_bidirectional_layer(keras_layer, input_names, input_shapes, data_reader):
     assert keras_layer['class_name'] == 'Bidirectional'
 
-    rnn_layer = keras_layer['config']['layer']
-    assert rnn_layer['class_name'] in rnn_layers or rnn_layer['class_name'][1:] in rnn_layers
-
-    layer = parse_default_keras_layer(rnn_layer, input_names)
-    layer['name'] = keras_layer['config']['name']
-    layer['class_name'] = 'Bidirectional' + layer['class_name']
-    layer['direction'] = 'bidirectional'
-
-    # TODO Should we handle different architectures for forward and backward layer?
+    rnn_forward_layer = keras_layer['config']['layer']
     if keras_layer['config'].get('backward_layer'):
-        raise Exception('Different architectures between forward and backward layers are not supported by hls4ml')
-
-    layer['return_sequences'] = rnn_layer['config']['return_sequences']
-    layer['return_state'] = rnn_layer['config']['return_state']
-
-    if 'SimpleRNN' not in layer['class_name']:
-        layer['recurrent_activation'] = rnn_layer['config']['recurrent_activation']
+        rnn_backward_layer = keras_layer['config']['backward_layer']
+        if rnn_forward_layer['config']['go_backwards']:
+            temp_layer = rnn_forward_layer.copy()
+            rnn_forward_layer = rnn_backward_layer.copy()
+            rnn_backward_layer = temp_layer
+    else:
+        rnn_backward_layer = rnn_forward_layer
 
-    layer['time_major'] = rnn_layer['config']['time_major'] if 'time_major' in rnn_layer['config'] else False
+    assert (rnn_forward_layer['class_name'] in rnn_layers or rnn_forward_layer['class_name'][1:] in rnn_layers) and (
+        rnn_backward_layer['class_name'] in rnn_layers or rnn_backward_layer['class_name'][1:] in rnn_layers
+    )
 
+    layer = {}
+    layer['name'] = keras_layer['config']['name']
+    layer['forward_layer'] = parse_default_keras_layer(rnn_forward_layer, input_names)
+    layer['backward_layer'] = parse_default_keras_layer(rnn_backward_layer, input_names)
+    layer['class_name'] = (
+        'Bidirectional' + layer['forward_layer']['class_name']
+    )  # TODO: to be changed if we ever implement different
+    # architecture for forward and backward layer
+    layer['direction'] = 'bidirectional'
+    layer['return_sequences'] = rnn_forward_layer['config']['return_sequences']
+    layer['return_state'] = rnn_forward_layer['config']['return_state']
+    layer['time_major'] = rnn_forward_layer['config']['time_major'] if 'time_major' in rnn_forward_layer['config'] else False
     # TODO Should we handle time_major?
     if layer['time_major']:
         raise Exception('Time-major format is not supported by hls4ml')
-
     layer['n_timesteps'] = input_shapes[0][1]
     layer['n_in'] = input_shapes[0][2]
-
     assert keras_layer['config']['merge_mode'] in merge_modes
     layer['merge_mode'] = keras_layer['config']['merge_mode']
 
-    layer['n_out'] = rnn_layer['config']['units']
-    if keras_layer['config']['merge_mode'] == 'concat':
-        layer['n_out'] *= 2
+    for direction, rnn_layer in [('forward_layer', rnn_forward_layer), ('backward_layer', rnn_backward_layer)]:
 
-    rnn_layer_name = rnn_layer['config']['name']
-    if 'SimpleRNN' in layer['class_name']:
-        cell_name = 'simple_rnn'
-    else:
-        cell_name = rnn_layer['class_name'].lower()
-    layer['weight_data'], layer['recurrent_weight_data'], layer['bias_data'] = get_weights_data(
-        data_reader,
-        layer['name'],
-        [
-            f'forward_{rnn_layer_name}/{cell_name}_cell/kernel',
-            f'forward_{rnn_layer_name}/{cell_name}_cell/recurrent_kernel',
-            f'forward_{rnn_layer_name}/{cell_name}_cell/bias',
-        ],
-    )
-    layer['weight_b_data'], layer['recurrent_weight_b_data'], layer['bias_b_data'] = get_weights_data(
-        data_reader,
-        layer['name'],
-        [
-            f'backward_{rnn_layer_name}/{cell_name}_cell/kernel',
-            f'backward_{rnn_layer_name}/{cell_name}_cell/recurrent_kernel',
-            f'backward_{rnn_layer_name}/{cell_name}_cell/bias',
-        ],
-    )
-
-    if 'GRU' in layer['class_name']:
-        layer['apply_reset_gate'] = 'after' if rnn_layer['config']['reset_after'] else 'before'
+        if 'SimpleRNN' not in rnn_layer['class_name']:
+            layer[direction]['recurrent_activation'] = rnn_layer['config']['recurrent_activation']
 
-        # biases array is actually a 2-dim array of arrays (bias + recurrent bias)
-        # both arrays have shape: n_units * 3 (z, r, h_cand)
-        biases = layer['bias_data']
-        biases_b = layer['bias_b_data']
-        layer['bias_data'] = biases[0]
-        layer['recurrent_bias_data'] = biases[1]
-        layer['bias_b_data'] = biases_b[0]
-        layer['recurrent_bias_b_data'] = biases_b[1]
+        rnn_layer_name = rnn_layer['config']['name']
+        if 'SimpleRNN' in layer['class_name']:
+            cell_name = 'simple_rnn'
+        else:
+            cell_name = rnn_layer['class_name'].lower()
+        layer[direction]['weight_data'], layer[direction]['recurrent_weight_data'], layer[direction]['bias_data'] = (
+            get_weights_data(
+                data_reader,
+                layer['name'],
+                [
+                    f'{direction[:-6]}_{rnn_layer_name}/{cell_name}_cell/kernel',
+                    f'{direction[:-6]}_{rnn_layer_name}/{cell_name}_cell/recurrent_kernel',
+                    f'{direction[:-6]}_{rnn_layer_name}/{cell_name}_cell/bias',
+                ],
+            )
+        )
+
+        if 'GRU' in rnn_layer['class_name']:
+            layer[direction]['apply_reset_gate'] = 'after' if rnn_layer['config']['reset_after'] else 'before'
+
+            # biases array is actually a 2-dim array of arrays (bias + recurrent bias)
+            # both arrays have shape: n_units * 3 (z, r, h_cand)
+            biases = layer[direction]['bias_data']
+            layer[direction]['bias_data'] = biases[0]
+            layer[direction]['recurrent_bias_data'] = biases[1]
+
+        layer[direction]['n_states'] = rnn_layer['config']['units']
+
+    layer['n_out'] = layer['forward_layer']['n_states'] + layer['backward_layer']['n_states']
 
     if layer['return_sequences']:
        output_shape = [input_shapes[0][0], layer['n_timesteps'], layer['n_out']]
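
For context, here is a minimal sketch (not part of this commit) of the kind of Keras model the updated parser is meant to accept, assuming the standard tf.keras Bidirectional API; the layer sizes and names are illustrative only:

# Hypothetical example, assuming the standard tf.keras API: a Bidirectional wrapper
# whose forward and backward layers have different architectures, which the old
# parser rejected with an exception.
from tensorflow.keras.layers import GRU, LSTM, Bidirectional, Input
from tensorflow.keras.models import Model

inputs = Input(shape=(8, 4))  # (n_timesteps, n_in)
outputs = Bidirectional(
    LSTM(16, return_sequences=True),  # forward layer: 16 states
    backward_layer=GRU(12, return_sequences=True, go_backwards=True),  # backward layer: 12 states
    merge_mode='concat',  # with 'concat', n_out = 16 + 12 = 28, as computed by the new parser
)(inputs)
model = Model(inputs, outputs)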
