
Commit 5eef679

Author: Enrico Lupi
FIX bidirectional layers name

1 parent: 4c3d26e

File tree

3 files changed (+15 additions, -21 deletions)

hls4ml/backends/vivado/passes/recurrent_templates.py

Lines changed: 6 additions & 10 deletions
@@ -1,10 +1,6 @@
 from hls4ml.backends.backend import get_backend
 from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate
-<<<<<<< HEAD
-from hls4ml.model.layers import GRU, LSTM, BLSTM, BGRU, TimeDistributed
-=======
-from hls4ml.model.layers import BGRU, BLSTM, GRU, LSTM
->>>>>>> d2d3b452 (ADD fixes)
+from hls4ml.model.layers import GRU, LSTM, BidirectionalLSTM, BidirectionalGRU, TimeDistributed
 
 # recurrent multiplication template
 

@@ -247,7 +243,7 @@ def format(self, node):
 
 class BidirectionalRecurrentConfigTemplate(LayerConfigTemplate):
     def __init__(self):
-        super().__init__((BLSTM, BGRU))
+        super().__init__((BidirectionalLSTM, BidirectionalGRU))
         self.template = bidir_recr_config_template
         self.act_template = activ_config_template
         self.recr_act_template = recr_activ_config_template
@@ -275,11 +271,11 @@ def format(self, node):
         params['static'] = 'true' if node.attributes['static'] else 'false'
         params['pytorch'] = 'true' if node.get_attr('pytorch', False) else 'false'
         params['recr_type'] = node.class_name.lower()
-        params['RECR_TYPE'] = node.class_name[1:]
+        params['RECR_TYPE'] = node.class_name[13:]
 
-        if node.class_name == 'BLSTM':
+        if node.class_name == 'BidirectionalLSTM':
             n_recr_mult = 4
-        else:  # BGRU
+        else:  # BidirectionalGRU
             n_recr_mult = 3
 
         recr_config = self.template.format(**params)
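Note (not part of the patch): the RECR_TYPE slice moves from [1:] to [13:] because the prefix being stripped grows from the single letter 'B' to the full word 'Bidirectional', which is 13 characters long. A quick standalone check of that assumption:

# len('Bidirectional') == 13, so slicing at index 13 leaves the core cell type.
for name in ('BidirectionalLSTM', 'BidirectionalGRU'):
    print(name, '->', name[13:])   # BidirectionalLSTM -> LSTM, BidirectionalGRU -> GRU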
@@ -458,7 +454,7 @@ def format(self, node):
 
 class BidirectionalRecurrentFunctionTemplate(FunctionCallTemplate):
     def __init__(self):
-        super().__init__((BLSTM, BGRU), include_header=recr_include_list)
+        super().__init__((BidirectionalLSTM, BidirectionalGRU), include_header=recr_include_list)
 
     def format(self, node):
         params = self._default_function_params(node)

hls4ml/backends/vivado/vivado_backend.py

Lines changed: 3 additions & 5 deletions
@@ -12,6 +12,8 @@
 from hls4ml.model.layers import (
     GRU,
     LSTM,
+    BidirectionalGRU,
+    BidirectionalLSTM,
     Conv1D,
     Conv2D,
     Dense,
@@ -46,11 +48,7 @@ def __init__(self):
 
     def _register_layer_attributes(self):
         # Add RNN-specific attributes, recurrent_reuse_factor and static implementation
-        rnn_layers = [
-            SimpleRNN,
-            LSTM,
-            GRU,
-        ]
+        rnn_layers = [SimpleRNN, LSTM, GRU, BidirectionalLSTM, BidirectionalGRU]
 
         for layer in rnn_layers:
             attrs = self.attribute_map.get(layer, [])
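For orientation, the loop this hunk feeds follows the usual attribute-registration pattern: every class in rnn_layers gets extra entries appended to its list in attribute_map, and the bidirectional classes now take part in that loop. The sketch below is a self-contained illustration of the pattern only; the stub classes and string placeholders stand in for the real hls4ml layer classes and Attribute objects.

# Illustrative sketch, not the real backend code.
class SimpleRNN: ...
class LSTM: ...
class GRU: ...
class BidirectionalLSTM(LSTM): ...
class BidirectionalGRU(GRU): ...

attribute_map = {}
rnn_layers = [SimpleRNN, LSTM, GRU, BidirectionalLSTM, BidirectionalGRU]

for layer in rnn_layers:
    attrs = attribute_map.get(layer, [])
    # the real backend appends RNN-specific Attribute objects here, e.g. the
    # recurrent reuse factor and the static-implementation flag; strings stand in
    attrs += ['recurrent_reuse_factor', 'static']
    attribute_map[layer] = attrs

print(attribute_map[BidirectionalLSTM])   # ['recurrent_reuse_factor', 'static']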

hls4ml/model/layers.py

Lines changed: 6 additions & 6 deletions
@@ -1402,8 +1402,8 @@ def initialize(self):
         self.add_weights_variable(name='recurrent_bias', var_name='br{index}', data=recurrent_bias)
 
 
-class BLSTM(LSTM):
-    _expected_attributes = LSTM._expected_attributes + [
+class BidirectionalLSTM(LSTM):
+    _expected_attributes = [
         WeightAttribute('weight_b'),
         WeightAttribute('bias_b'),
         WeightAttribute('recurrent_weight_b'),
@@ -1510,8 +1510,8 @@ def initialize(self):
         self.add_output_variable(shape, dims)
 
 
-class BGRU(GRU):
-    _expected_attributes = GRU._expected_attributes + [
+class BidirectionalGRU(GRU):
+    _expected_attributes = [
         WeightAttribute('weight_b'),
         WeightAttribute('bias_b'),
         WeightAttribute('recurrent_weight_b'),
@@ -1826,8 +1826,8 @@ def initialize(self):
     'SimpleRNN': SimpleRNN,
     'LSTM': LSTM,
     'GRU': GRU,
-    'BLSTM': BLSTM,
-    'BGRU': BGRU,
+    'BidirectionalLSTM': BidirectionalLSTM,
+    'BidirectionalGRU': BidirectionalGRU,
     'QSimpleRNN': SimpleRNN,
     'QLSTM': LSTM,
     'QGRU': GRU,
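As a closing illustration, the snippet below mimics how such a name-to-class map is typically consumed and why the rename stays safe for dispatch: BidirectionalLSTM and BidirectionalGRU still subclass LSTM and GRU (as the class definitions above show), so checks keyed on the base classes keep matching. Stub classes are used here, not the real hls4ml ones.

# Stub classes, for illustration only (the real ones live in hls4ml.model.layers).
class LSTM: ...
class GRU: ...
class BidirectionalLSTM(LSTM): ...
class BidirectionalGRU(GRU): ...

layer_map = {
    'LSTM': LSTM,
    'GRU': GRU,
    'BidirectionalLSTM': BidirectionalLSTM,
    'BidirectionalGRU': BidirectionalGRU,
}

# A converter looks the class up by the string name it finds in the model...
layer_cls = layer_map['BidirectionalLSTM']
# ...and issubclass/isinstance checks keyed on the base class still succeed.
assert issubclass(layer_cls, LSTM)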
