
Commit 4ed22c9

Author: Enrico Lupi

ADD Bidirectional layers support

1 parent d882310 commit 4ed22c9

File tree: 8 files changed (+380, -813 lines)


hls4ml/backends/fpga/fpga_backend.py

Lines changed: 13 additions & 18 deletions
@@ -14,8 +14,7 @@
     Activation,
     BatchNormalization,
     BatchNormOnnx,
-    BidirectionalGRU,
-    BidirectionalLSTM,
+    Bidirectional,
     Conv,
     Conv1D,
     Conv2D,
@@ -70,8 +69,7 @@ def __init__(self, name):
             SimpleRNN,
             LSTM,
             GRU,
-            BidirectionalLSTM,
-            BidirectionalGRU,
+            Bidirectional,
             Dot,
             Conv,
             MatMul,
@@ -217,34 +215,30 @@ def get_layer_mult_size(self, layer):
             n_out = layer.get_attr('n_filt')
             return n_in, n_out

-        if 'BidirectionalLSTM' in layer.class_name:
-            n_in = layer.get_attr('n_in')
-            n_out = layer.get_attr('n_out') * 2  # /2*4
-            n_in_recr = layer.get_attr('n_out') // 2
-            n_out_recr = n_out
-            return n_in, n_out, n_in_recr, n_out_recr
-
         if 'LSTM' in layer.class_name:
             n_in = layer.get_attr('n_in')
             n_out = layer.get_attr('n_out') * 4
             n_in_recr = layer.get_attr('n_out')
             n_out_recr = n_out
             return n_in, n_out, n_in_recr, n_out_recr

-        if 'BidirectionalGRU' in layer.class_name:
-            n_in = layer.get_attr('n_in')
-            n_out = layer.get_attr('n_out') // 2 * 3
-            n_in_recr = layer.get_attr('n_out') // 2
-            n_out_recr = n_out
-            return n_in, n_out, n_in_recr, n_out_recr
-
         if 'GRU' in layer.class_name:
             n_in = layer.get_attr('n_in')
             n_out = layer.get_attr('n_out') * 3
             n_in_recr = layer.get_attr('n_out')
             n_out_recr = n_out
             return n_in, n_out, n_in_recr, n_out_recr

+        if 'Bidirectional' in layer.class_name:
+            result = []
+            for d in ['forward', 'backward']:
+                n_in = layer.get_attr('n_in')
+                n_out = layer.get_attr(f'{d}_n_states') * 3
+                n_in_recr = layer.get_attr(f'{d}_n_states')
+                n_out_recr = n_out
+                result.append((n_in, n_out, n_in_recr, n_out_recr))
+            return result
+
         raise Exception(f'Cannot get mult size for layer {layer.name} ({layer.class_name})')

     def get_valid_reuse_factors(self, n_in, n_out):
@@ -295,6 +289,7 @@ def set_closest_reuse_factor(self, layer, n_in, n_out, attribute='reuse_factor',
         if not include_max_rf:
             valid_rf.pop()
         chosen_rf = layer.get_attr(attribute)
+        print("\n\nREuse factor:", chosen_rf, "\n\n")
         if chosen_rf not in valid_rf:
             closest_rf = self.get_closest_reuse_factor(valid_rf, chosen_rf)
             valid_rf_str = ','.join(map(str, valid_rf))
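
The new Bidirectional branch of get_layer_mult_size() no longer returns a single (n_in, n_out, n_in_recr, n_out_recr) tuple: it returns a list with one tuple per direction, which callers then index by direction. A minimal standalone sketch of that behaviour (the attribute values and the GRU-style factor of 3 are illustrative assumptions, not part of the commit):

# Standalone sketch, not hls4ml code: mimics the new Bidirectional branch.
def bidirectional_mult_sizes(attrs):
    result = []
    for d in ['forward', 'backward']:
        n_in = attrs['n_in']
        n_out = attrs[f'{d}_n_states'] * 3  # GRU-style sizing, as written in the diff
        n_in_recr = attrs[f'{d}_n_states']
        n_out_recr = n_out
        result.append((n_in, n_out, n_in_recr, n_out_recr))
    return result

# Hypothetical attributes for a Bidirectional(GRU(8)) layer seeing 16 inputs per time step.
sizes = bidirectional_mult_sizes({'n_in': 16, 'forward_n_states': 8, 'backward_n_states': 8})
for i, d in enumerate(['forward', 'backward']):
    n_in, n_out, n_in_recr, n_out_recr = sizes[i]  # same unpacking as init_bidirectional() below
    print(d, n_in, n_out, n_in_recr, n_out_recr)   # forward 16 24 8 24 / backward 16 24 8 24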

hls4ml/backends/vivado/passes/recurrent_templates.py

Lines changed: 159 additions & 123 deletions
Large diffs are not rendered by default.

hls4ml/backends/vivado/passes/resource_strategy.py

Lines changed: 6 additions & 10 deletions
@@ -3,8 +3,7 @@
 from hls4ml.model.layers import (
     GRU,
     LSTM,
-    BidirectionalGRU,
-    BidirectionalLSTM,
+    Bidirectional,
     Conv1D,
     Conv2D,
     Dense,
@@ -18,9 +17,7 @@ class ApplyResourceStrategy(OptimizerPass):
     '''Transposes the weights to use the dense_resource matrix multiply routine'''

     def match(self, node):
-        node_matches = isinstance(
-            node, (Dense, Conv1D, SeparableConv1D, Conv2D, SeparableConv2D, LSTM, GRU, BidirectionalLSTM, BidirectionalGRU)
-        )
+        node_matches = isinstance(node, (Dense, Conv1D, SeparableConv1D, Conv2D, SeparableConv2D, LSTM, GRU, Bidirectional))
         is_resource_strategy = node.get_attr('strategy', '').lower() in ['resource', 'resource_unrolled']
         already_transformed = node.get_attr('_weights_transposed', False) is True
         return node_matches and is_resource_strategy and not already_transformed
@@ -48,11 +45,10 @@ def transform(self, model, node):
             node.weights['pointwise'].data = np.transpose(
                 node.weights['pointwise'].data, axes=[3, 0, 1, 2]
             )  # (H,W,C,F) => (F,H,W,C)
-        elif isinstance(node, (BidirectionalLSTM, BidirectionalGRU)):
-            node.weights['weight'].data = np.transpose(node.weights['weight'].data)
-            node.weights['recurrent_weight'].data = np.transpose(node.weights['recurrent_weight'].data)
-            node.weights['weight_b'].data = np.transpose(node.weights['weight_b'].data)
-            node.weights['recurrent_weight_b'].data = np.transpose(node.weights['recurrent_weight_b'].data)
+        elif isinstance(node, (Bidirectional)):
+            for d in ['forward', 'backward']:
+                node.weights[f'{d}_weight'].data = np.transpose(node.weights[f'{d}_weight'].data)
+                node.weights[f'{d}_recurrent_weight'].data = np.transpose(node.weights[f'{d}_recurrent_weight'].data)
         elif isinstance(node, (LSTM, GRU)):
            node.weights['weight'].data = np.transpose(node.weights['weight'].data)
            node.weights['recurrent_weight'].data = np.transpose(node.weights['recurrent_weight'].data)
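
With the unified layer, the pass transposes the same two matrices per direction under 'forward_'/'backward_' prefixed names instead of the old unsuffixed/_b pairs. A small numpy sketch with made-up shapes (only the naming convention and the transpose come from the diff):

import numpy as np

# Hypothetical (n_in, n_out) and (n_states, n_out) shapes, for illustration only.
weights = {
    'forward_weight': np.zeros((16, 24)),
    'forward_recurrent_weight': np.zeros((8, 24)),
    'backward_weight': np.zeros((16, 24)),
    'backward_recurrent_weight': np.zeros((8, 24)),
}

for d in ['forward', 'backward']:
    weights[f'{d}_weight'] = np.transpose(weights[f'{d}_weight'])
    weights[f'{d}_recurrent_weight'] = np.transpose(weights[f'{d}_recurrent_weight'])

print(weights['forward_weight'].shape)  # (24, 16): transposed for the dense_resource routine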

hls4ml/backends/vivado/vivado_backend.py

Lines changed: 59 additions & 3 deletions
@@ -12,8 +12,7 @@
 from hls4ml.model.layers import (
     GRU,
     LSTM,
-    BidirectionalGRU,
-    BidirectionalLSTM,
+    Bidirectional,
     Conv1D,
     Conv2D,
     Dense,
@@ -48,7 +47,7 @@ def __init__(self):

     def _register_layer_attributes(self):
         # Add RNN-specific attributes, recurrent_reuse_factor and static implementation
-        rnn_layers = [SimpleRNN, LSTM, GRU, BidirectionalLSTM, BidirectionalGRU]
+        rnn_layers = [SimpleRNN, LSTM, GRU]

         for layer in rnn_layers:
             attrs = self.attribute_map.get(layer, [])
@@ -60,6 +59,24 @@ def _register_layer_attributes(self):
             attrs.append(TypeAttribute('table', default=FixedPrecisionType(18, 8), description=descriptions.table_type))
             self.attribute_map[layer] = attrs

+        bidir_rnn_layers = [Bidirectional]
+        for layer in bidir_rnn_layers:
+            attrs = self.attribute_map.get(layer, [])
+            attrs.append(ConfigurableAttribute('forward_reuse_factor', default=1, description=descriptions.reuse_factor))
+            attrs.append(ConfigurableAttribute('backward_reuse_factor', default=1, description=descriptions.reuse_factor))
+            attrs.append(
+                ConfigurableAttribute('forward_recurrent_reuse_factor', default=1, description=descriptions.reuse_factor)
+            )
+            attrs.append(
+                ConfigurableAttribute('backward_recurrent_reuse_factor', default=1, description=descriptions.reuse_factor)
+            )
+            attrs.append(
+                ConfigurableAttribute('static', value_type=bool, default=True, description=descriptions.recurrent_static)
+            )
+            attrs.append(ConfigurableAttribute('table_size', default=1024, description=descriptions.table_size))
+            attrs.append(TypeAttribute('table', default=FixedPrecisionType(18, 8), description=descriptions.table_type))
+            self.attribute_map[layer] = attrs
+
         # Add ParallelizationFactor to Conv1D/2D
         pf_layers = [
             Conv1D,
@@ -657,6 +674,45 @@ def init_time_distributed(self, layer):
             warn(f'Cannot unroll time step loop in layer "{layer.name}" while using "io_stream".')
             loop_mode = 'off'
         layer.set_attr('time_step_loop_parallelism', loop_mode)
+
+    @layer_optimizer(Bidirectional)
+    def init_bidirectional(self, layer):
+        reuse_factor = layer.model.config.get_reuse_factor(layer)
+
+        for i, d in enumerate(['forward', 'backward']):
+            layer.set_attr(f'{d}_reuse_factor', reuse_factor)
+            layer.set_attr(f'{d}_recurrent_reuse_factor', reuse_factor)
+
+            if layer.model.config.is_resource_strategy(layer):
+                n_in, n_out, n_in_recr, n_out_recr = self.get_layer_mult_size(layer)[i]
+                self.set_closest_reuse_factor(layer, n_in, n_out, attribute=f'{d}_reuse_factor')
+                self.set_closest_reuse_factor(layer, n_in_recr, n_out_recr, attribute=f'{d}_recurrent_reuse_factor')
+                layer.set_attr('strategy', 'resource')
+
+            elif layer.model.config.get_strategy(layer).lower() == 'resource_unrolled':
+                use_resource_instead = False
+                if layer.get_attr('reuse_factor', 1) == 1:
+                    print(
+                        f'Unrolled resource strategy cannot be combined with reuse factor 1 in layer "{layer.name} ({d})". '
+                        'Using "resource" strategy instead.'
+                    )
+                    use_resource_instead = True
+
+                n_in, n_out, n_in_recr, n_out_recr = self.get_layer_mult_size(layer)[i]
+                if use_resource_instead:
+                    self.set_closest_reuse_factor(layer, n_in, n_out, attribute=f'{d}_reuse_factor')
+                    self.set_closest_reuse_factor(layer, n_in_recr, n_out_recr, attribute=f'{d}_recurrent_reuse_factor')
+                    layer.set_attr('strategy', 'resource')
+                else:
+                    self.set_closest_reuse_factor(layer, n_in, n_out, attribute=f'{d}_reuse_factor', include_max_rf=False)
+                    self.set_closest_reuse_factor(
+                        layer, n_in_recr, n_out_recr, attribute=f'{d}_recurrent_reuse_factor', include_max_rf=False
+                    )
+                    layer.set_attr('strategy', 'resource_unrolled')
+            else:
+                layer.set_attr('strategy', 'latency')
+
+        layer.set_attr('index_t', NamedType(f'layer{layer.index}_index', IntegerPrecisionType(width=1, signed=False)))

     @layer_optimizer(GarNet)
     def init_garnet(self, layer):
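
The attributes registered above are filled in by init_bidirectional() from the layer's configured reuse factor and strategy, so from the user side a Bidirectional layer is tuned like any other RNN layer. A hedged end-to-end sketch (the toy model, the layer name 'bidir' and the output directory are assumptions, not taken from the commit):

import hls4ml
from tensorflow.keras.layers import LSTM, Bidirectional, Input
from tensorflow.keras.models import Model

# Toy Keras model with a Bidirectional(LSTM) wrapper.
inp = Input(shape=(10, 16))
out = Bidirectional(LSTM(8), name='bidir')(inp)
model = Model(inp, out)

config = hls4ml.utils.config_from_keras_model(model, granularity='name')
config['LayerName']['bidir']['Strategy'] = 'Resource'
config['LayerName']['bidir']['ReuseFactor'] = 4  # copied into forward_/backward_(recurrent_)reuse_factor

hls_model = hls4ml.converters.convert_from_keras_model(
    model, hls_config=config, backend='Vivado', output_dir='hls_bidir_prj'
)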

hls4ml/converters/keras_v2_to_hls.py

Lines changed: 1 addition & 1 deletion
@@ -241,7 +241,7 @@ def parse_keras_model(model_arch, reader):
         'HGQ>UnaryLUT',
     ]
     # Recurrent layers
-    recurrent_layers = ['SimpleRNN', 'LSTM', 'GRU', 'QSimpleRNN', 'QLSTM', 'QGRU', 'BidirectionalLSTM', 'BidirectionalGRU']
+    recurrent_layers = ['SimpleRNN', 'LSTM', 'GRU', 'QSimpleRNN', 'QLSTM', 'QGRU', 'Bidirectional']
     # All supported layers
     supported_layers = get_supported_keras_layers() + skip_layers
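
parse_keras_model() matches layers by their Keras class name, so a keras.layers.Bidirectional wrapper is now picked up as a single 'Bidirectional' node rather than the former 'BidirectionalLSTM'/'BidirectionalGRU' split. A quick check of the class name the converter sees (the toy model is an assumption for illustration):

from tensorflow.keras.layers import GRU, Bidirectional, Input
from tensorflow.keras.models import Model

inp = Input(shape=(10, 16))
out = Bidirectional(GRU(8))(inp)
model = Model(inp, out)

# Keras reports the wrapper itself, matching the new entry in recurrent_layers.
print(model.layers[1].__class__.__name__)  # 'Bidirectional'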

hls4ml/model/layers.py

Lines changed: 1 addition & 67 deletions
@@ -1402,40 +1402,6 @@ def initialize(self):
             self.add_weights_variable(name='recurrent_bias', var_name='br{index}', data=recurrent_bias)


-class BidirectionalLSTM(LSTM):
-    _expected_attributes = [
-        WeightAttribute('weight_b'),
-        WeightAttribute('bias_b'),
-        WeightAttribute('recurrent_weight_b'),
-        WeightAttribute('recurrent_bias_b'),
-        TypeAttribute('weight_b'),
-        TypeAttribute('bias_b'),
-        TypeAttribute('recurrent_weight_b'),
-        TypeAttribute('recurrent_bias_b'),
-        ChoiceAttribute('merge_mode', ['sum', 'mul', 'concat', 'ave'], configurable=False, default='concat'),
-    ]
-
-    def initialize(self):
-        super().initialize()
-
-        # Add backward layer parameters
-        # weights
-        self.add_weights_variable(name='weight_b', var_name='w_b{index}')
-
-        # recurrent weights
-        self.add_weights_variable(name='recurrent_weight_b', var_name='wr_b{index}')
-
-        # biases
-        self.add_weights_variable(name='bias_b', var_name='b_b{index}')
-
-        if "pytorch" in self.attributes.keys():
-            self.add_weights_variable(name='recurrent_bias_b', var_name='br_b{index}')
-        else:
-            recurrent_weight_b = self.get_attr('recurrent_weight_b_data')
-            recurrent_bias_b = np.zeros(recurrent_weight_b.shape[1])
-            self.add_weights_variable(name='recurrent_bias_b', var_name='br_b{index}', data=recurrent_bias_b)
-
-
 class GRU(Layer):
     _expected_attributes = [
         Attribute('n_out'),
@@ -1509,34 +1475,6 @@ def initialize(self):

         self.add_output_variable(shape, dims)

-
-class BidirectionalGRU(GRU):
-    _expected_attributes = [
-        WeightAttribute('weight_b'),
-        WeightAttribute('bias_b'),
-        WeightAttribute('recurrent_weight_b'),
-        WeightAttribute('recurrent_bias_b'),
-        TypeAttribute('weight_b'),
-        TypeAttribute('bias_b'),
-        TypeAttribute('recurrent_weight_b'),
-        TypeAttribute('recurrent_bias_b'),
-        ChoiceAttribute('merge_mode', ['sum', 'mul', 'concat', 'ave'], configurable=False, default='concat'),
-    ]
-
-    def initialize(self):
-        super().initialize()
-
-        # Add backward layer parameters
-        # weights
-        self.add_weights_variable(name='weight_b', var_name='w_b{index}')
-
-        # recurrent weights
-        self.add_weights_variable(name='recurrent_weight_b', var_name='wr_b{index}')
-
-        # biases
-        self.add_weights_variable(name='bias_b', var_name='b_b{index}')
-        self.add_weights_variable(name='recurrent_bias_b', var_name='br_b{index}')
-

 class Bidirectional(Layer):
     _expected_attributes = [
@@ -1609,9 +1547,7 @@ def initialize(self):
                     name=f'{dir}_recurrent_bias', var_name=(f'br_{dir[0]}_' + '{index}'), data=recurrent_bias
                 )
             else:
-                self.add_weights_variable(
-                    name=f'{dir}_recurrent_bias', var_name=(f'br_{dir[0]}_' + '{index}'), data=recurrent_bias
-                )
+                self.add_weights_variable(name=f'{dir}_recurrent_bias', var_name=(f'br_{dir[0]}_' + '{index}'))


 class GarNet(Layer):
@@ -1902,8 +1838,6 @@ def initialize(self):
     'SimpleRNN': SimpleRNN,
     'LSTM': LSTM,
     'GRU': GRU,
-    'BidirectionalLSTM': BidirectionalLSTM,
-    'BidirectionalGRU': BidirectionalGRU,
     'Bidirectional': Bidirectional,
     'QSimpleRNN': SimpleRNN,
     'QLSTM': LSTM,
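
In the retained Bidirectional class, each weight is registered under a direction-prefixed attribute name, and the generated HLS variable name is built from the first letter of the direction plus a deferred '{index}' placeholder, as in var_name=(f'br_{dir[0]}_' + '{index}') above. A small sketch of what that pattern expands to (the layer index 3 is an arbitrary example):

# The '{index}' part is kept literal so hls4ml can substitute the layer index later.
for direction in ['forward', 'backward']:
    var_name = f'br_{direction[0]}_' + '{index}'
    print(var_name, '->', var_name.format(index=3))
# br_f_{index} -> br_f_3
# br_b_{index} -> br_b_3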

hls4ml/model/optimizer/passes/infer_precision.py

Lines changed: 4 additions & 3 deletions
@@ -81,7 +81,7 @@ def _infer_precision(self, node, types_to_infer):
         if node_class in ['Embedding']:
             return self._infer_embedding_precision(node, types_to_infer)

-        if node_class in ['SimpleRNN', 'LSTM', 'GRU', 'BidirectionalLSTM', 'BidirectionalGRU']:
+        if node_class in ['SimpleRNN', 'LSTM', 'GRU', 'Bidirectional']:
             return self._infer_rnn_precision(node, types_to_infer)

         if node_class in ['ParametrizedActivation']:
@@ -554,8 +554,9 @@ def _infer_rnn_precision(self, node, types_to_infer):

         # for now just do the weights and leave the rest for the default catch
         rnn_weights = ('weight', 'bias', 'recurrent_weight', 'recurrent_bias')
-        if node.attributes['direction'] == 'bidirectional':
-            rnn_weights += ('weight_b', 'bias_b', 'recurrent_weight_b', 'recurrent_bias_b')
+        if node.class_name == 'Bidirectional':
+            rnn_weights = [direction + '_' + weight for direction in ['forward', 'backward'] for weight in rnn_weights]
+
         for weightvar in rnn_weights:
             if f'{weightvar}_t' in types_to_infer:
                 self._infer_default_type(node, f'{weightvar}_t')
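
For a Bidirectional node the new comprehension simply expands the base tuple into direction-prefixed names, so the existing per-weight type inference loop needs no other changes. What it evaluates to:

rnn_weights = ('weight', 'bias', 'recurrent_weight', 'recurrent_bias')
rnn_weights = [direction + '_' + weight for direction in ['forward', 'backward'] for weight in rnn_weights]
print(rnn_weights)
# ['forward_weight', 'forward_bias', 'forward_recurrent_weight', 'forward_recurrent_bias',
#  'backward_weight', 'backward_bias', 'backward_recurrent_weight', 'backward_recurrent_bias']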

0 commit comments
