Skip to content

Commit d882310

Browse files
author
Enrico Lupi
committed
ADD general bidirectional wrapper
1 parent edf7cdf commit d882310

File tree

2 files changed

+380
-2
lines changed

2 files changed

+380
-2
lines changed

hls4ml/model/layers.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1538,6 +1538,82 @@ def initialize(self):
15381538
self.add_weights_variable(name='recurrent_bias_b', var_name='br_b{index}')
15391539

15401540

1541+
class Bidirectional(Layer):
    """Generic bidirectional RNN wrapper layer.

    Holds a forward and a backward recurrent cell, each with its own
    kernel/recurrent/bias weights and activation attributes. The per-direction
    cell type is read from the ``{direction}_class_name`` attribute
    ('LSTM' or another RNN type such as 'GRU').
    """

    _expected_attributes = [
        Attribute('n_out'),
        Attribute('return_sequences', value_type=bool, default=False),
        Attribute('return_state', value_type=bool, default=False),
        Attribute('pass_initial_states', value_type=bool, default=False),
        Attribute('time_major', value_type=bool, default=False),
        Attribute('forward_activation', value_type=str),
        Attribute('forward_recurrent_activation', value_type=str),
        WeightAttribute('forward_weight'),
        WeightAttribute('forward_bias'),
        WeightAttribute('forward_recurrent_weight'),
        WeightAttribute('forward_recurrent_bias'),
        TypeAttribute('forward_weight'),
        TypeAttribute('forward_bias'),
        TypeAttribute('forward_recurrent_weight'),
        TypeAttribute('forward_recurrent_bias'),
        Attribute('backward_activation', value_type=str),
        Attribute('backward_recurrent_activation', value_type=str),
        WeightAttribute('backward_weight'),
        WeightAttribute('backward_bias'),
        WeightAttribute('backward_recurrent_weight'),
        WeightAttribute('backward_recurrent_bias'),
        TypeAttribute('backward_weight'),
        TypeAttribute('backward_bias'),
        TypeAttribute('backward_recurrent_weight'),
        TypeAttribute('backward_recurrent_bias'),
    ]

    def initialize(self):
        """Create the layer's output variables and per-direction weight variables."""
        # Main output: the full sequence when return_sequences is set,
        # otherwise only the last time step.
        if self.attributes['return_sequences']:
            shape = [self.attributes['n_timesteps'], self.attributes['n_out']]
            dims = [f'N_TIME_STEPS_{self.index}', f'N_OUT_{self.index}']
        else:
            shape = [self.attributes['n_out']]
            dims = [f'N_OUT_{self.index}']

        self.add_output_variable(shape, dims)

        if self.attributes['return_state']:
            # Extra outputs for the final hidden (h) and cell (c) states.
            state_shape = [self.attributes['n_out']]
            state_dims = [f'N_OUT_{self.index}']
            self.add_output_variable(
                state_shape, state_dims, out_name=self.outputs[1], var_name='layer{index}_h', type_name='layer{index}_h_t'
            )
            self.add_output_variable(
                state_shape, state_dims, out_name=self.outputs[2], var_name='layer{index}_c', type_name='layer{index}_c_t'
            )

        # NOTE: renamed loop variable from `dir` to `direction` — `dir` shadows the builtin.
        for direction in ['forward', 'backward']:
            # kernel weights
            self.add_weights_variable(name=f'{direction}_weight', var_name=(f'w_{direction[0]}_' + '{index}'))

            # recurrent weights
            recurrent_weight = self.get_attr(f'{direction}_recurrent_weight_data')
            self.add_weights_variable(
                name=f'{direction}_recurrent_weight', var_name=(f'wr_{direction[0]}_' + '{index}'), data=recurrent_weight
            )

            # biases
            self.add_weights_variable(name=f'{direction}_bias', var_name=(f'b_{direction[0]}_' + '{index}'))

            if self.attributes[f'{direction}_class_name'] == 'LSTM':
                if "pytorch" in self.attributes.keys():
                    # PyTorch LSTMs provide an explicit recurrent bias; read it from
                    # the '{direction}_recurrent_bias_data' attribute.
                    self.add_weights_variable(
                        name=f'{direction}_recurrent_bias', var_name=(f'br_{direction[0]}_' + '{index}')
                    )
                else:
                    # Keras LSTMs have no separate recurrent bias; use zeros of
                    # matching width so downstream code sees a consistent variable.
                    recurrent_bias = np.zeros(recurrent_weight.shape[1])
                    self.add_weights_variable(
                        name=f'{direction}_recurrent_bias',
                        var_name=(f'br_{direction[0]}_' + '{index}'),
                        data=recurrent_bias,
                    )
            else:
                # BUG FIX: `recurrent_bias` was previously referenced here without
                # being assigned on this path — a NameError on the first non-LSTM
                # direction, or a stale value silently carried over from a previous
                # loop iteration. Define it explicitly, mirroring the branch above.
                recurrent_bias = np.zeros(recurrent_weight.shape[1])
                self.add_weights_variable(
                    name=f'{direction}_recurrent_bias',
                    var_name=(f'br_{direction[0]}_' + '{index}'),
                    data=recurrent_bias,
                )
1615+
1616+
15411617
class GarNet(Layer):
15421618
ref_impl = False
15431619

@@ -1828,6 +1904,7 @@ def initialize(self):
18281904
'GRU': GRU,
18291905
'BidirectionalLSTM': BidirectionalLSTM,
18301906
'BidirectionalGRU': BidirectionalGRU,
1907+
'Bidirectional': Bidirectional,
18311908
'QSimpleRNN': SimpleRNN,
18321909
'QLSTM': LSTM,
18331910
'QGRU': GRU,

0 commit comments

Comments
 (0)