Bidirectional RNN layer support for Keras frontend and Vitis backend #1310

Open · wants to merge 29 commits into base: main

Commits (29)
f975066
ADD parsing for bidirectional RNN layers
Apr 3, 2025
8121281
Implement bidirectional rnn layers
Apr 16, 2025
fc1e950
ADD fixes
Apr 25, 2025
f47bb5a
FIX resource strategy
May 16, 2025
c469793
FIX infer precision for bidirectional rnn
May 16, 2025
4c3d26e
FIX eliminate activation after bidirectional rnn
May 16, 2025
5eef679
FIX bidirectional layers name
May 16, 2025
a9546c7
ADD tests for bidirectional layer
May 16, 2025
0246dae
FIX weight name and ADD backward layer architecture check
May 16, 2025
7428af7
FIX static and non-static Bidirectional layers
May 16, 2025
d2c6cc0
ADD parse general bidirectional layer with possibly different archite…
May 16, 2025
edf7cdf
ADD parsing for general bidirectional layer
May 19, 2025
d882310
ADD general bidirectional wrapper
May 28, 2025
4ed22c9
ADD Bidirectional layers support
Jun 11, 2025
dd4f220
ADD support for reverse order layers
Jun 11, 2025
de803b7
ADD feature check for merge mode and layers order
Jun 11, 2025
070fdc2
ADD io type feature check
Jun 11, 2025
b65c730
FIX n_out in case of merge_mode != concat
Jun 12, 2025
b55cd04
ADD pytest for Bidirectional layer
Jun 12, 2025
e8fae54
FIX possible directions for LSTM and GRU
Jun 12, 2025
a1500e4
FIX spelling mistake
Jun 12, 2025
1c16616
FIX order
Jun 12, 2025
2fc981c
FIX remove unused import
Jun 12, 2025
48f4fe2
FIX blank space
Jun 13, 2025
64ab715
Merge branch 'main' into vivado_bidir_general
enlupi Jun 23, 2025
1b919c8
Merge branch 'main' into vivado_bidir_general
JanFSchulte Jun 27, 2025
20ff35e
Merge branch 'fastmachinelearning:main' into vivado_bidir_general
enlupi Jul 14, 2025
734d42f
RM old comments
enlupi Jul 15, 2025
7c128f7
MV check for out-of-order layers from passes to parsing
enlupi Jul 15, 2025
13 changes: 13 additions & 0 deletions hls4ml/backends/fpga/fpga_backend.py
@@ -14,6 +14,7 @@
Activation,
BatchNormalization,
BatchNormOnnx,
Bidirectional,
Conv,
Conv1D,
Conv2D,
@@ -68,6 +69,7 @@ def __init__(self, name):
SimpleRNN,
LSTM,
GRU,
Bidirectional,
Dot,
Conv,
MatMul,
@@ -232,6 +234,16 @@ def get_layer_mult_size(self, layer):
n_out_recr = n_out
return n_in, n_out, n_in_recr, n_out_recr

if 'Bidirectional' in layer.class_name:
result = []
for d in ['forward', 'backward']:
n_in = layer.get_attr('n_in')
n_out = layer.get_attr(f'{d}_n_states') * 3
n_in_recr = layer.get_attr(f'{d}_n_states')
n_out_recr = n_out
result.append((n_in, n_out, n_in_recr, n_out_recr))
return result

raise Exception(f'Cannot get mult size for layer {layer.name} ({layer.class_name})')

def get_valid_reuse_factors(self, n_in, n_out):
@@ -282,6 +294,7 @@ def set_closest_reuse_factor(self, layer, n_in, n_out, attribute='reuse_factor',
if not include_max_rf:
valid_rf.pop()
chosen_rf = layer.get_attr(attribute)
if chosen_rf not in valid_rf:
closest_rf = self.get_closest_reuse_factor(valid_rf, chosen_rf)
valid_rf_str = ','.join(map(str, valid_rf))
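
For reference, the Bidirectional branch of get_layer_mult_size above returns one (n_in, n_out, n_in_recr, n_out_recr) tuple per direction instead of a single tuple, so callers are expected to iterate over the list. A minimal sketch, where `backend` and `bidir_node` are hypothetical stand-ins for a backend instance and a parsed Bidirectional node:

# Minimal sketch: consuming the per-direction mult sizes of a Bidirectional layer.
# `backend` and `bidir_node` are placeholders; the attribute names they rely on
# ('n_in', 'forward_n_states', 'backward_n_states') follow the hunk above.
sizes = backend.get_layer_mult_size(bidir_node)
for direction, (n_in, n_out, n_in_recr, n_out_recr) in zip(('forward', 'backward'), sizes):
    print(f'{direction}: n_in={n_in}, n_out={n_out}, n_in_recr={n_in_recr}, n_out_recr={n_out_recr}')
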
38 changes: 38 additions & 0 deletions hls4ml/backends/vitis/passes/feature_check.py
@@ -49,3 +49,41 @@ def transform(self, model, node):
f'WARNING: "ResourceUnrolled" strategy in "{node.name}" ({node.class_name}) may have unexpected II in'
'Vitis backend.\nVerify that the final design satisfies the latency/II constraints.'
)


class ValidateBidirectionalMergeMode(OptimizerPass):
_unrolled_layer_cls = ['Bidirectional']

def match(self, node):
is_bidirectional_rnn_layer = (
len([layer_cls for layer_cls in self._unrolled_layer_cls if layer_cls in node.class_name]) > 0
)
is_merge_mode_not_concat = node.get_attr('merge_mode', 'concat') != 'concat'

return is_bidirectional_rnn_layer and is_merge_mode_not_concat

def transform(self, model, node):
merge_mode = node.get_attr('merge_mode', 'concat')
print(
f'WARNING: "{merge_mode}" merge mode in "{node.name}" ({node.class_name}) is not supported in Vitis backend. '
'Switching to "concat" merge mode.'
)
node.set_attr('merge_mode', 'concat')
Review comment (Contributor): Why are we doing this here instead of just doing it during the parsing in the converter?


class ValidateBidirectionalIoType(OptimizerPass):
_unrolled_layer_cls = ['Bidirectional']

def match(self, node):
is_bidirectional_rnn_layer = (
len([layer_cls for layer_cls in self._unrolled_layer_cls if layer_cls in node.class_name]) > 0
)
is_layer_io_type_stream = node.model.config.config['IOType'] != 'io_parallel'

return is_bidirectional_rnn_layer and is_layer_io_type_stream

def transform(self, model, node):
raise Exception(
f'WARNING: "{node.model.config.config["IOType"]}" IO Type is not supported in Vitis backend '
f'for "{node.name}" ({node.class_name}). Please use "io_parallel".'
)
3 changes: 3 additions & 0 deletions hls4ml/backends/vitis/vitis_backend.py
@@ -28,6 +28,9 @@ def _register_flows(self):
'vitis:validate_conv_implementation',
'vitis:validate_resource_strategy',
'vitis:validate_resource_unrolled_strategy',
'vitis:validate_bidirectional_merge_mode',
'vitis:validate_bidirectional_layer_order',
'vitis:validate_bidirectional_io_type',
]
validation_flow = register_flow('validation', validation_passes, requires=['vivado:init_layers'], backend=self.name)

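
The three validation passes registered above imply that, on the Vitis backend, a Bidirectional layer must keep the default 'concat' merge mode and be compiled with io_parallel. A minimal end-to-end sketch using the usual hls4ml conversion entry points; the layer sizes, toy input, and output directory are illustrative assumptions, not taken from this PR:

import numpy as np
from tensorflow.keras.layers import LSTM, Bidirectional, Input
from tensorflow.keras.models import Model

import hls4ml

# Toy Keras model: Bidirectional LSTM with the default 'concat' merge mode.
inputs = Input(shape=(8, 4))
outputs = Bidirectional(LSTM(16), merge_mode='concat')(inputs)
model = Model(inputs, outputs)

config = hls4ml.utils.config_from_keras_model(model, granularity='name')
hls_model = hls4ml.converters.convert_from_keras_model(
    model,
    hls_config=config,
    backend='Vitis',
    io_type='io_parallel',  # io_stream would be rejected by ValidateBidirectionalIoType
    output_dir='hls4ml_prj_bidir',  # placeholder path
)
hls_model.compile()
y_hls = hls_model.predict(np.random.rand(1, 8, 4).astype(np.float32))
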
214 changes: 213 additions & 1 deletion hls4ml/backends/vivado/passes/recurrent_templates.py
@@ -1,6 +1,6 @@
from hls4ml.backends.backend import get_backend
from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate
from hls4ml.model.layers import GRU, LSTM, TimeDistributed
from hls4ml.model.layers import GRU, LSTM, Bidirectional, TimeDistributed

# recurrent multiplication template

@@ -86,10 +86,52 @@
static const bool pytorch_order = {pytorch};
}};\n"""

# Bidirectional templates

single_config_template = """struct config{index} : nnet::single_layer_config {{
typedef {accum_t.name} accum_t;
typedef {weight_t.name} weight_t; // Matrix
typedef {recurrent_weight_t.name} recurrent_weight_t; // Matrix
typedef {bias_t.name} bias_t; // Vector
typedef {recurrent_bias_t.name} recurrent_bias_t; // Vector
typedef {config_mult_t1} mult_config1;
typedef {config_mult_t2} mult_config2;
typedef {recr_act_t} ACT_CONFIG_{RECR_TYPE};
template<class x_T, class y_T, class config_T>
using activation_recr = nnet::activation::{recurrent_activation}<x_T, y_T, config_T>;
typedef {act_t} ACT_CONFIG_T;
template<class x_T, class y_T, class config_T>
using activation = nnet::activation::{activation}<x_T, y_T, config_T>;
static const unsigned n_in = {n_in};
static const unsigned n_state = {n_state};
static const unsigned n_mult = {n_mult};
static const bool pytorch_order = {pytorch};
}};\n"""

bidirectional_config_template = """struct config{index} : nnet::bidirectional_config {{
typedef {forward_t} FORWARD_CONFIG;
template<class x_T, class y_T, typename config_T, bool backward>
using RNNfunc_forward = nnet::{forward_layer}<x_T, y_T, config_T, backward>;
typedef {backward_t} BACKWARD_CONFIG;
template<class x_T, class y_T, typename config_T, bool backward>
using RNNfunc_backward = nnet::{backward_layer}<x_T, y_T, config_T, backward>;
static const unsigned n_in = {n_in};
static const unsigned n_out = {n_out};
static const unsigned n_sequence = {n_sequence};
static const unsigned n_sequence_out = {n_sequence_out};
static const unsigned io_type = nnet::{strategy};
static const unsigned reuse_factor = {reuse};
static const bool store_weights_in_bram = false;
static const bool use_static = {static};
static const bool pytorch_order = {pytorch};
}};\n"""

recr_function_template = 'nnet::{recr_type}_stack<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {wr}, {b}, {br});'
recr_function_template_initial_states_lstm = 'nnet::{recr_type}_stack<{input_t}, {input2_t}, {input3_t}, {output_t}, {config}>({input}, {input2}, {input3}, {output}, {w}, {wr}, {b}, {br});' # noqa: E501
recr_function_template_initial_states_gru = 'nnet::{recr_type}_stack<{input_t}, {input2_t}, {output_t}, {config}>({input}, {input2}, {output}, {w}, {wr}, {b}, {br});' # noqa: E501

bidirectional_function_template = 'nnet::bidirectional_stack<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {wr}, {b}, {br}, {w_b}, {wr_b}, {b_b}, {br_b});' # noqa: E501

recr_include_list = ['nnet_utils/nnet_recurrent.h']


@@ -207,6 +249,153 @@ def format(self, node):
return mult_config1 + '\n' + mult_config2 + '\n' + recr_act_config + '\n' + act_config + '\n' + recr_config


class BidirectionalConfigTemplate(LayerConfigTemplate):
def __init__(self):
super().__init__(Bidirectional)
self.template = bidirectional_config_template
self.layer_template = single_config_template
self.act_template = activ_config_template
self.recr_act_template = recr_activ_config_template
self.mult1_template = recr_mult_config_template_1
self.mult2_template = recr_mult_config_template_2

def format(self, node):

# ----- Bidirectional Layer Config -----#
params = self._default_config_params(node)

params['n_in'] = node.get_input_variable().dim_names[1]
params['n_sequence'] = node.get_input_variable().dim_names[0]
if node.get_attr('return_sequences'):
params['n_sequence_out'] = node.get_output_variable().dim_names[0]
else:
params['n_sequence_out'] = 1
params['n_out'] = node.get_attr('n_out')
params['strategy'] = node.get_attr('strategy')
params['static'] = 'true' if node.attributes['static'] else 'false'
params['pytorch'] = 'true' if node.get_attr('pytorch', False) else 'false'
params['forward_t'] = f'config{node.index}_forward'
params['backward_t'] = f'config{node.index}_backward'
params['forward_layer'] = node.get_attr('forward_class_name').lower() + '_class'
params['backward_layer'] = node.get_attr('backward_class_name').lower() + '_class'
if node.attributes['static']:
params['forward_layer'] += '_static'
params['backward_layer'] += '_static'

recr_config = self.template.format(**params)

# ----- Forward and Backward Layers Config -----#
result = ''
for d in ['forward', 'backward']:
if node.get_attr(f'{d}_class_name') == 'LSTM':
n_recr_mult = 4
else: # GRU
n_recr_mult = 3

# ----- Layer Config -----#
layer_params = self._default_config_params(node)
layer_params['n_in'] = params['n_in']
layer_params['pytorch'] = params['pytorch']
layer_params['n_state'] = node.get_attr(f'{d}_n_states')
layer_params['n_mult'] = 4
if node.get_attr(f'{d}_class_name').lower() == 'gru':
layer_params['n_mult'] = 3
layer_params['config_mult_t1'] = f'config{node.index}_1_{d[0]}'
layer_params['config_mult_t2'] = f'config{node.index}_2_{d[0]}'
layer_params['recr_act_t'] = '{}_config{}_recr'.format(
node.get_attr(f'{d}_recurrent_activation'), str(node.index) + f'_{d[0]}'
)
layer_params['act_t'] = '{}_config{}'.format(node.get_attr(f'{d}_activation'), str(node.index) + f'_{d[0]}')
layer_params['RECR_TYPE'] = node.get_attr(f'{d}_class_name')

layer_params['weight_t'] = layer_params[f'{d}_weight_t']
layer_params['recurrent_weight_t'] = layer_params[f'{d}_recurrent_weight_t']
layer_params['bias_t'] = layer_params[f'{d}_bias_t']
layer_params['recurrent_bias_t'] = layer_params[f'{d}_recurrent_bias_t']
layer_params['activation'] = layer_params[f'{d}_activation']
layer_params['recurrent_activation'] = layer_params[f'{d}_recurrent_activation']

layer_params['index'] = str(node.index) + f'_{d}'

layer_config = self.layer_template.format(**layer_params)

# ----- Activations Config -----#
act_params = self._default_config_params(node)
recr_act_params = self._default_config_params(node)

act_params['type'] = node.get_attr(f'{d}_activation')
recr_act_params['type'] = node.get_attr(f'{d}_recurrent_activation')
act_params['index'] = str(node.index) + f'_{d[0]}'
recr_act_params['index'] = str(node.index) + f'_{d[0]}'
act_params['n_in'] = node.get_attr(f'{d}_n_states')
recr_act_params['n_in'] = node.get_attr(f'{d}_n_states') * (n_recr_mult - 1)

act_config = self.act_template.format(**act_params)
recr_act_config = self.recr_act_template.format(**recr_act_params)

# ----- Mult Config -----#
mult_params1 = self._default_config_params(node)
mult_params2 = self._default_config_params(node)

mult_params1['n_in'] = node.get_input_variable().shape[1]
mult_params1['n_out'] = node.get_attr(f'{d}_n_states') * n_recr_mult
mult_params1['product_type'] = get_backend('vivado').product_type(
node.get_input_variable().type.precision, node.get_weights(f'{d}_weight').type.precision
)
mult_params1['reuse'] = params['reuse']
mult_params1['index'] = str(node.index) + f'_1_{d[0]}'
mult_params1['nzeros'] = node.get_weights(f'{d}_weight').nzeros
mult_params1['nonzeros'] = node.get_weights(f'{d}_weight').nonzeros

mult_params1['bias_t'] = mult_params1[f'{d}_bias_t']
mult_params1['weight_t'] = mult_params1[f'{d}_weight_t']
mult_params2['recurrent_bias_t'] = mult_params2[f'{d}_recurrent_bias_t']
mult_params2['recurrent_weight_t'] = mult_params2[f'{d}_recurrent_weight_t']

namespace = params['namespace']

if node.get_attr('strategy').lower() == 'latency':
mult_params1['dense_function'] = 'nnet::DenseLatency'
elif node.get_attr('strategy').lower() == 'resource':
if int(mult_params1[f'{d}_reuse_factor']) <= int(mult_params1['n_in']):
mult_params1['dense_function'] = 'nnet::DenseResource_rf_leq_nin'
else:
mult_params1['dense_function'] = 'nnet::DenseResource_rf_gt_nin_rem0'
# The 3rd case is never used
elif node.get_attr('strategy').lower() == 'resource_unrolled':
mult_params1['dense_function'] = f'{namespace}::dense_resource_unrolled_{node.index}_1'

mult_params2['n_in'] = node.get_attr(f'{d}_n_states')
mult_params2['n_out'] = node.get_attr(f'{d}_n_states') * n_recr_mult
mult_params2['product_type'] = get_backend('vivado').product_type(
node.get_input_variable().type.precision, node.get_weights(f'{d}_recurrent_weight').type.precision
)
mult_params2['reuse'] = node.attributes[f'{d}_recurrent_reuse_factor']
mult_params2['index'] = str(node.index) + f'_2_{d[0]}'
mult_params2['nzeros'] = node.get_weights(f'{d}_recurrent_weight').nzeros
mult_params2['nonzeros'] = node.get_weights(f'{d}_recurrent_weight').nonzeros

if node.get_attr('strategy').lower() == 'latency':
mult_params2['dense_function'] = 'nnet::DenseLatency'
elif node.get_attr('strategy').lower() == 'resource':
if int(mult_params2[f'{d}_reuse_factor']) <= int(mult_params2['n_in']):
mult_params2['dense_function'] = 'nnet::DenseResource_rf_leq_nin'
else:
mult_params2['dense_function'] = 'nnet::DenseResource_rf_gt_nin_rem0'
# The 3rd case is never used
elif node.get_attr('strategy').lower() == 'resource_unrolled':
mult_params2['dense_function'] = f'{namespace}::dense_resource_unrolled_{node.index}_2'

mult_config1 = self.mult1_template.format(**mult_params1)
mult_config2 = self.mult2_template.format(**mult_params2)

result += (
mult_config1 + '\n' + mult_config2 + '\n' + recr_act_config + '\n' + act_config + '\n' + layer_config + '\n'
)

return result + recr_config


class RecurrentFunctionTemplate(FunctionCallTemplate):
def __init__(self):
super().__init__((LSTM, GRU), include_header=recr_include_list)
@@ -239,6 +428,29 @@ def format(self, node):
return template.format(**params)


class BidirectionalFunctionTemplate(FunctionCallTemplate):
def __init__(self):
super().__init__((Bidirectional), include_header=recr_include_list)

def format(self, node):
params = self._default_function_params(node)

# TO DO: Add initial states functions for pytorch settings

params['w'] = node.get_weights('forward_weight').name
params['b'] = node.get_weights('forward_bias').name
params['wr'] = node.get_weights('forward_recurrent_weight').name
params['br'] = node.get_weights('forward_recurrent_bias').name
params['w_b'] = node.get_weights('backward_weight').name
params['b_b'] = node.get_weights('backward_bias').name
params['wr_b'] = node.get_weights('backward_recurrent_weight').name
params['br_b'] = node.get_weights('backward_recurrent_bias').name

template = bidirectional_function_template

return template.format(**params)


time_distributed_config_template = """struct config{index} : nnet::time_distributed_config {{
static const unsigned dim = {dim};

18 changes: 15 additions & 3 deletions hls4ml/backends/vivado/passes/resource_strategy.py
@@ -1,17 +1,25 @@
import numpy as np

from hls4ml.model.layers import GRU, LSTM, Conv1D, Conv2D, Dense, SeparableConv1D, SeparableConv2D
from hls4ml.model.layers import (
GRU,
LSTM,
Bidirectional,
Conv1D,
Conv2D,
Dense,
SeparableConv1D,
SeparableConv2D,
)
from hls4ml.model.optimizer import OptimizerPass


class ApplyResourceStrategy(OptimizerPass):
'''Transposes the weights to use the dense_resource matrix multiply routine'''

def match(self, node):
node_matches = isinstance(node, (Dense, Conv1D, SeparableConv1D, Conv2D, SeparableConv2D, LSTM, GRU))
node_matches = isinstance(node, (Dense, Conv1D, SeparableConv1D, Conv2D, SeparableConv2D, LSTM, GRU, Bidirectional))
is_resource_strategy = node.get_attr('strategy', '').lower() in ['resource', 'resource_unrolled']
already_transformed = node.get_attr('_weights_transposed', False) is True

return node_matches and is_resource_strategy and not already_transformed

def transform(self, model, node):
@@ -37,6 +45,10 @@ def transform(self, model, node):
node.weights['pointwise'].data = np.transpose(
node.weights['pointwise'].data, axes=[3, 0, 1, 2]
) # (H,W,C,F) => (F,H,W,C)
elif isinstance(node, (Bidirectional)):
for d in ['forward', 'backward']:
node.weights[f'{d}_weight'].data = np.transpose(node.weights[f'{d}_weight'].data)
node.weights[f'{d}_recurrent_weight'].data = np.transpose(node.weights[f'{d}_recurrent_weight'].data)
elif isinstance(node, (LSTM, GRU)):
node.weights['weight'].data = np.transpose(node.weights['weight'].data)
node.weights['recurrent_weight'].data = np.transpose(node.weights['recurrent_weight'].data)
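
The Bidirectional branch above only runs when the node's strategy is 'resource' or 'resource_unrolled'. Continuing the conversion sketch shown after the vitis_backend.py diff, a per-layer configuration along these lines would trigger the per-direction transposition of weight and recurrent_weight; the layer name 'bidirectional' is an assumed Keras auto-generated name:

import hls4ml

# Hedged sketch: select the Resource strategy for the Bidirectional layer so that
# ApplyResourceStrategy matches it and transposes the forward_/backward_ weight and
# recurrent_weight tensors. 'bidirectional' is an assumed Keras auto-generated layer name.
config = hls4ml.utils.config_from_keras_model(model, granularity='name')
config['LayerName']['bidirectional']['Strategy'] = 'Resource'
config['LayerName']['bidirectional']['ReuseFactor'] = 4
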