diff --git a/hls4ml/backends/fpga/fpga_backend.py b/hls4ml/backends/fpga/fpga_backend.py index f0b603ab24..f196ce30ce 100644 --- a/hls4ml/backends/fpga/fpga_backend.py +++ b/hls4ml/backends/fpga/fpga_backend.py @@ -14,6 +14,7 @@ Activation, BatchNormalization, BatchNormOnnx, + Bidirectional, Conv, Conv1D, Conv2D, @@ -68,6 +69,7 @@ def __init__(self, name): SimpleRNN, LSTM, GRU, + Bidirectional, Dot, Conv, MatMul, @@ -232,6 +234,16 @@ def get_layer_mult_size(self, layer): n_out_recr = n_out return n_in, n_out, n_in_recr, n_out_recr + if 'Bidirectional' in layer.class_name: + result = [] + for d in ['forward', 'backward']: + n_in = layer.get_attr('n_in') + n_out = layer.get_attr(f'{d}_n_states') * 3 + n_in_recr = layer.get_attr(f'{d}_n_states') + n_out_recr = n_out + result.append((n_in, n_out, n_in_recr, n_out_recr)) + return result + raise Exception(f'Cannot get mult size for layer {layer.name} ({layer.class_name})') def get_valid_reuse_factors(self, n_in, n_out): diff --git a/hls4ml/backends/vitis/passes/feature_check.py b/hls4ml/backends/vitis/passes/feature_check.py index a38f6581f6..48f87168bc 100644 --- a/hls4ml/backends/vitis/passes/feature_check.py +++ b/hls4ml/backends/vitis/passes/feature_check.py @@ -49,3 +49,41 @@ def transform(self, model, node): f'WARNING: "ResourceUnrolled" strategy in "{node.name}" ({node.class_name}) may have unexpected II in' 'Vitis backend.\nVerify that the final design satisfies the latency/II constraints.' ) + + +class ValidateBidirectionalMergeMode(OptimizerPass): + _unrolled_layer_cls = ['Bidirectional'] + + def match(self, node): + is_bidirectional_rnn_layer = ( + len([layer_cls for layer_cls in self._unrolled_layer_cls if layer_cls in node.class_name]) > 0 + ) + is_merge_mode_not_concat = node.get_attr('merge_mode', 'concat') != 'concat' + + return is_bidirectional_rnn_layer and is_merge_mode_not_concat + + def transform(self, model, node): + merge_mode = node.get_attr('merge_mode', 'concat') + print( + f'WARNING: "{merge_mode}" merge mode in "{node.name}" ({node.class_name}) is not supported in Vitis backend. ' + 'Switching to "concat" merge mode.' + ) + node.set_attr('merge_mode', 'concat') + + +class ValidateBidirectionalIoType(OptimizerPass): + _unrolled_layer_cls = ['Bidirectional'] + + def match(self, node): + is_bidirectional_rnn_layer = ( + len([layer_cls for layer_cls in self._unrolled_layer_cls if layer_cls in node.class_name]) > 0 + ) + is_layer_io_type_stream = node.model.config.config['IOType'] != 'io_parallel' + + return is_bidirectional_rnn_layer and is_layer_io_type_stream + + def transform(self, model, node): + raise Exception( + f'WARNING: "{node.model.config.config["IOType"]}" IO Type is not supported in Vitis backend ' + f'for "{node.name}" ({node.class_name}). Please use "io_parallel".' 
+ ) diff --git a/hls4ml/backends/vitis/vitis_backend.py b/hls4ml/backends/vitis/vitis_backend.py index f72818a279..58fa7f2bc6 100644 --- a/hls4ml/backends/vitis/vitis_backend.py +++ b/hls4ml/backends/vitis/vitis_backend.py @@ -28,6 +28,9 @@ def _register_flows(self): 'vitis:validate_conv_implementation', 'vitis:validate_resource_strategy', 'vitis:validate_resource_unrolled_strategy', + 'vitis:validate_bidirectional_merge_mode', + 'vitis:validate_bidirectional_layer_order', + 'vitis:validate_bidirectional_io_type', ] validation_flow = register_flow('validation', validation_passes, requires=['vivado:init_layers'], backend=self.name) diff --git a/hls4ml/backends/vivado/passes/recurrent_templates.py b/hls4ml/backends/vivado/passes/recurrent_templates.py index 6934e82e4e..6f03d674ad 100644 --- a/hls4ml/backends/vivado/passes/recurrent_templates.py +++ b/hls4ml/backends/vivado/passes/recurrent_templates.py @@ -1,6 +1,6 @@ from hls4ml.backends.backend import get_backend from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate -from hls4ml.model.layers import GRU, LSTM, TimeDistributed +from hls4ml.model.layers import GRU, LSTM, Bidirectional, TimeDistributed # recurrent multiplication template @@ -86,10 +86,52 @@ static const bool pytorch_order = {pytorch}; }};\n""" +# Bidirectional templates + +single_config_template = """struct config{index} : nnet::single_layer_config {{ + typedef {accum_t.name} accum_t; + typedef {weight_t.name} weight_t; // Matrix + typedef {recurrent_weight_t.name} recurrent_weight_t; // Matrix + typedef {bias_t.name} bias_t; // Vector + typedef {recurrent_bias_t.name} recurrent_bias_t; // Vector + typedef {config_mult_t1} mult_config1; + typedef {config_mult_t2} mult_config2; + typedef {recr_act_t} ACT_CONFIG_{RECR_TYPE}; + template + using activation_recr = nnet::activation::{recurrent_activation}; + typedef {act_t} ACT_CONFIG_T; + template + using activation = nnet::activation::{activation}; + static const unsigned n_in = {n_in}; + static const unsigned n_state = {n_state}; + static const unsigned n_mult = {n_mult}; + static const bool pytorch_order = {pytorch}; +}};\n""" + +bidirectional_config_template = """struct config{index} : nnet::bidirectional_config {{ + typedef {forward_t} FORWARD_CONFIG; + template + using RNNfunc_forward = nnet::{forward_layer}; + typedef {backward_t} BACKWARD_CONFIG; + template + using RNNfunc_backward = nnet::{backward_layer}; + static const unsigned n_in = {n_in}; + static const unsigned n_out = {n_out}; + static const unsigned n_sequence = {n_sequence}; + static const unsigned n_sequence_out = {n_sequence_out}; + static const unsigned io_type = nnet::{strategy}; + static const unsigned reuse_factor = {reuse}; + static const bool store_weights_in_bram = false; + static const bool use_static = {static}; + static const bool pytorch_order = {pytorch}; +}};\n""" + recr_function_template = 'nnet::{recr_type}_stack<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {wr}, {b}, {br});' recr_function_template_initial_states_lstm = 'nnet::{recr_type}_stack<{input_t}, {input2_t}, {input3_t}, {output_t}, {config}>({input}, {input2}, {input3}, {output}, {w}, {wr}, {b}, {br});' # noqa: E501 recr_function_template_initial_states_gru = 'nnet::{recr_type}_stack<{input_t}, {input2_t}, {output_t}, {config}>({input}, {input2}, {output}, {w}, {wr}, {b}, {br});' # noqa: E501 +bidirectional_function_template = 'nnet::bidirectional_stack<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {wr}, {b}, {br}, {w_b}, {wr_b}, {b_b}, 
{br_b});' # noqa: E501 + recr_include_list = ['nnet_utils/nnet_recurrent.h'] @@ -207,6 +249,153 @@ def format(self, node): return mult_config1 + '\n' + mult_config2 + '\n' + recr_act_config + '\n' + act_config + '\n' + recr_config +class BidirectionalConfigTemplate(LayerConfigTemplate): + def __init__(self): + super().__init__(Bidirectional) + self.template = bidirectional_config_template + self.layer_template = single_config_template + self.act_template = activ_config_template + self.recr_act_template = recr_activ_config_template + self.mult1_template = recr_mult_config_template_1 + self.mult2_template = recr_mult_config_template_2 + + def format(self, node): + + # ----- Bidirectional Layer Config -----# + params = self._default_config_params(node) + + params['n_in'] = node.get_input_variable().dim_names[1] + params['n_sequence'] = node.get_input_variable().dim_names[0] + if node.get_attr('return_sequences'): + params['n_sequence_out'] = node.get_output_variable().dim_names[0] + else: + params['n_sequence_out'] = 1 + params['n_out'] = node.get_attr('n_out') + params['strategy'] = node.get_attr('strategy') + params['static'] = 'true' if node.attributes['static'] else 'false' + params['pytorch'] = 'true' if node.get_attr('pytorch', False) else 'false' + params['forward_t'] = f'config{node.index}_forward' + params['backward_t'] = f'config{node.index}_backward' + params['forward_layer'] = node.get_attr('forward_class_name').lower() + '_class' + params['backward_layer'] = node.get_attr('backward_class_name').lower() + '_class' + if node.attributes['static']: + params['forward_layer'] += '_static' + params['backward_layer'] += '_static' + + recr_config = self.template.format(**params) + + # ----- Forward and Backward Layers Config -----# + result = '' + for d in ['forward', 'backward']: + if node.get_attr(f'{d}_class_name') == 'LSTM': + n_recr_mult = 4 + else: # GRU + n_recr_mult = 3 + + # ----- Layer Config -----# + layer_params = self._default_config_params(node) + layer_params['n_in'] = params['n_in'] + layer_params['pytorch'] = params['pytorch'] + layer_params['n_state'] = node.get_attr(f'{d}_n_states') + layer_params['n_mult'] = 4 + if node.get_attr(f'{d}_class_name').lower() == 'gru': + layer_params['n_mult'] = 3 + layer_params['config_mult_t1'] = f'config{node.index}_1_{d[0]}' + layer_params['config_mult_t2'] = f'config{node.index}_2_{d[0]}' + layer_params['recr_act_t'] = '{}_config{}_recr'.format( + node.get_attr(f'{d}_recurrent_activation'), str(node.index) + f'_{d[0]}' + ) + layer_params['act_t'] = '{}_config{}'.format(node.get_attr(f'{d}_activation'), str(node.index) + f'_{d[0]}') + layer_params['RECR_TYPE'] = node.get_attr(f'{d}_class_name') + + layer_params['weight_t'] = layer_params[f'{d}_weight_t'] + layer_params['recurrent_weight_t'] = layer_params[f'{d}_recurrent_weight_t'] + layer_params['bias_t'] = layer_params[f'{d}_bias_t'] + layer_params['recurrent_bias_t'] = layer_params[f'{d}_recurrent_bias_t'] + layer_params['activation'] = layer_params[f'{d}_activation'] + layer_params['recurrent_activation'] = layer_params[f'{d}_recurrent_activation'] + + layer_params['index'] = str(node.index) + f'_{d}' + + layer_config = self.layer_template.format(**layer_params) + + # ----- Activations Config -----# + act_params = self._default_config_params(node) + recr_act_params = self._default_config_params(node) + + act_params['type'] = node.get_attr(f'{d}_activation') + recr_act_params['type'] = node.get_attr(f'{d}_recurrent_activation') + act_params['index'] = str(node.index) + f'_{d[0]}' + 
recr_act_params['index'] = str(node.index) + f'_{d[0]}' + act_params['n_in'] = node.get_attr(f'{d}_n_states') + recr_act_params['n_in'] = node.get_attr(f'{d}_n_states') * (n_recr_mult - 1) + + act_config = self.act_template.format(**act_params) + recr_act_config = self.recr_act_template.format(**recr_act_params) + + # ----- Mult Config -----# + mult_params1 = self._default_config_params(node) + mult_params2 = self._default_config_params(node) + + mult_params1['n_in'] = node.get_input_variable().shape[1] + mult_params1['n_out'] = node.get_attr(f'{d}_n_states') * n_recr_mult + mult_params1['product_type'] = get_backend('vivado').product_type( + node.get_input_variable().type.precision, node.get_weights(f'{d}_weight').type.precision + ) + mult_params1['reuse'] = params['reuse'] + mult_params1['index'] = str(node.index) + f'_1_{d[0]}' + mult_params1['nzeros'] = node.get_weights(f'{d}_weight').nzeros + mult_params1['nonzeros'] = node.get_weights(f'{d}_weight').nonzeros + + mult_params1['bias_t'] = mult_params1[f'{d}_bias_t'] + mult_params1['weight_t'] = mult_params1[f'{d}_weight_t'] + mult_params2['recurrent_bias_t'] = mult_params2[f'{d}_recurrent_bias_t'] + mult_params2['recurrent_weight_t'] = mult_params2[f'{d}_recurrent_weight_t'] + + namespace = params['namespace'] + + if node.get_attr('strategy').lower() == 'latency': + mult_params1['dense_function'] = 'nnet::DenseLatency' + elif node.get_attr('strategy').lower() == 'resource': + if int(mult_params1[f'{d}_reuse_factor']) <= int(mult_params1['n_in']): + mult_params1['dense_function'] = 'nnet::DenseResource_rf_leq_nin' + else: + mult_params1['dense_function'] = 'nnet::DenseResource_rf_gt_nin_rem0' + # The 3rd case is never used + elif node.get_attr('strategy').lower() == 'resource_unrolled': + mult_params1['dense_function'] = f'{namespace}::dense_resource_unrolled_{node.index}_1' + + mult_params2['n_in'] = node.get_attr(f'{d}_n_states') + mult_params2['n_out'] = node.get_attr(f'{d}_n_states') * n_recr_mult + mult_params2['product_type'] = get_backend('vivado').product_type( + node.get_input_variable().type.precision, node.get_weights(f'{d}_recurrent_weight').type.precision + ) + mult_params2['reuse'] = node.attributes[f'{d}_recurrent_reuse_factor'] + mult_params2['index'] = str(node.index) + f'_2_{d[0]}' + mult_params2['nzeros'] = node.get_weights(f'{d}_recurrent_weight').nzeros + mult_params2['nonzeros'] = node.get_weights(f'{d}_recurrent_weight').nonzeros + + if node.get_attr('strategy').lower() == 'latency': + mult_params2['dense_function'] = 'nnet::DenseLatency' + elif node.get_attr('strategy').lower() == 'resource': + if int(mult_params2[f'{d}_reuse_factor']) <= int(mult_params2['n_in']): + mult_params2['dense_function'] = 'nnet::DenseResource_rf_leq_nin' + else: + mult_params2['dense_function'] = 'nnet::DenseResource_rf_gt_nin_rem0' + # The 3rd case is never used + elif node.get_attr('strategy').lower() == 'resource_unrolled': + mult_params2['dense_function'] = f'{namespace}::dense_resource_unrolled_{node.index}_2' + + mult_config1 = self.mult1_template.format(**mult_params1) + mult_config2 = self.mult2_template.format(**mult_params2) + + result += ( + mult_config1 + '\n' + mult_config2 + '\n' + recr_act_config + '\n' + act_config + '\n' + layer_config + '\n' + ) + + return result + recr_config + + class RecurrentFunctionTemplate(FunctionCallTemplate): def __init__(self): super().__init__((LSTM, GRU), include_header=recr_include_list) @@ -239,6 +428,29 @@ def format(self, node): return template.format(**params) +class 
BidirectionalFunctionTemplate(FunctionCallTemplate): + def __init__(self): + super().__init__((Bidirectional), include_header=recr_include_list) + + def format(self, node): + params = self._default_function_params(node) + + # TO DO: Add initial states functions for pytorch settings + + params['w'] = node.get_weights('forward_weight').name + params['b'] = node.get_weights('forward_bias').name + params['wr'] = node.get_weights('forward_recurrent_weight').name + params['br'] = node.get_weights('forward_recurrent_bias').name + params['w_b'] = node.get_weights('backward_weight').name + params['b_b'] = node.get_weights('backward_bias').name + params['wr_b'] = node.get_weights('backward_recurrent_weight').name + params['br_b'] = node.get_weights('backward_recurrent_bias').name + + template = bidirectional_function_template + + return template.format(**params) + + time_distributed_config_template = """struct config{index} : nnet::time_distributed_config {{ static const unsigned dim = {dim}; diff --git a/hls4ml/backends/vivado/passes/resource_strategy.py b/hls4ml/backends/vivado/passes/resource_strategy.py index 0c06190f30..407adc9d03 100644 --- a/hls4ml/backends/vivado/passes/resource_strategy.py +++ b/hls4ml/backends/vivado/passes/resource_strategy.py @@ -1,6 +1,15 @@ import numpy as np -from hls4ml.model.layers import GRU, LSTM, Conv1D, Conv2D, Dense, SeparableConv1D, SeparableConv2D +from hls4ml.model.layers import ( + GRU, + LSTM, + Bidirectional, + Conv1D, + Conv2D, + Dense, + SeparableConv1D, + SeparableConv2D, +) from hls4ml.model.optimizer import OptimizerPass @@ -8,10 +17,9 @@ class ApplyResourceStrategy(OptimizerPass): '''Transposes the weights to use the dense_resource matrix multiply routine''' def match(self, node): - node_matches = isinstance(node, (Dense, Conv1D, SeparableConv1D, Conv2D, SeparableConv2D, LSTM, GRU)) + node_matches = isinstance(node, (Dense, Conv1D, SeparableConv1D, Conv2D, SeparableConv2D, LSTM, GRU, Bidirectional)) is_resource_strategy = node.get_attr('strategy', '').lower() in ['resource', 'resource_unrolled'] already_transformed = node.get_attr('_weights_transposed', False) is True - return node_matches and is_resource_strategy and not already_transformed def transform(self, model, node): @@ -37,6 +45,10 @@ def transform(self, model, node): node.weights['pointwise'].data = np.transpose( node.weights['pointwise'].data, axes=[3, 0, 1, 2] ) # (H,W,C,F) => (F,H,W,C) + elif isinstance(node, (Bidirectional)): + for d in ['forward', 'backward']: + node.weights[f'{d}_weight'].data = np.transpose(node.weights[f'{d}_weight'].data) + node.weights[f'{d}_recurrent_weight'].data = np.transpose(node.weights[f'{d}_recurrent_weight'].data) elif isinstance(node, (LSTM, GRU)): node.weights['weight'].data = np.transpose(node.weights['weight'].data) node.weights['recurrent_weight'].data = np.transpose(node.weights['recurrent_weight'].data) diff --git a/hls4ml/backends/vivado/vivado_backend.py b/hls4ml/backends/vivado/vivado_backend.py index 43280ac934..285eab06cc 100644 --- a/hls4ml/backends/vivado/vivado_backend.py +++ b/hls4ml/backends/vivado/vivado_backend.py @@ -12,6 +12,7 @@ from hls4ml.model.layers import ( GRU, LSTM, + Bidirectional, Conv1D, Conv2D, Dense, @@ -45,11 +46,7 @@ def __init__(self): def _register_layer_attributes(self): # Add RNN-specific attributes, recurrent_reuse_factor and static implementation - rnn_layers = [ - SimpleRNN, - LSTM, - GRU, - ] + rnn_layers = [SimpleRNN, LSTM, GRU] for layer in rnn_layers: attrs = self.attribute_map.get(layer, []) @@ -61,6 +58,24 
@@ def _register_layer_attributes(self): attrs.append(TypeAttribute('table', default=FixedPrecisionType(18, 8), description=descriptions.table_type)) self.attribute_map[layer] = attrs + bidir_rnn_layers = [Bidirectional] + for layer in bidir_rnn_layers: + attrs = self.attribute_map.get(layer, []) + attrs.append(ConfigurableAttribute('forward_reuse_factor', default=1, description=descriptions.reuse_factor)) + attrs.append(ConfigurableAttribute('backward_reuse_factor', default=1, description=descriptions.reuse_factor)) + attrs.append( + ConfigurableAttribute('forward_recurrent_reuse_factor', default=1, description=descriptions.reuse_factor) + ) + attrs.append( + ConfigurableAttribute('backward_recurrent_reuse_factor', default=1, description=descriptions.reuse_factor) + ) + attrs.append( + ConfigurableAttribute('static', value_type=bool, default=True, description=descriptions.recurrent_static) + ) + attrs.append(ConfigurableAttribute('table_size', default=1024, description=descriptions.table_size)) + attrs.append(TypeAttribute('table', default=FixedPrecisionType(18, 8), description=descriptions.table_type)) + self.attribute_map[layer] = attrs + # Add ParallelizationFactor to Conv1D/2D pf_layers = [ Conv1D, @@ -672,6 +687,45 @@ def init_time_distributed(self, layer): loop_mode = 'off' layer.set_attr('time_step_loop_parallelism', loop_mode) + @layer_optimizer(Bidirectional) + def init_bidirectional(self, layer): + reuse_factor = layer.model.config.get_reuse_factor(layer) + + for i, d in enumerate(['forward', 'backward']): + layer.set_attr(f'{d}_reuse_factor', reuse_factor) + layer.set_attr(f'{d}_recurrent_reuse_factor', reuse_factor) + + if layer.model.config.is_resource_strategy(layer): + n_in, n_out, n_in_recr, n_out_recr = self.get_layer_mult_size(layer)[i] + self.set_closest_reuse_factor(layer, n_in, n_out, attribute=f'{d}_reuse_factor') + self.set_closest_reuse_factor(layer, n_in_recr, n_out_recr, attribute=f'{d}_recurrent_reuse_factor') + layer.set_attr('strategy', 'resource') + + elif layer.model.config.get_strategy(layer).lower() == 'resource_unrolled': + use_resource_instead = False + if layer.get_attr('reuse_factor', 1) == 1: + print( + f'Unrolled resource strategy cannot be combined with reuse factor 1 in layer "{layer.name} ({d})". ' + 'Using "resource" strategy instead.' 
+ ) + use_resource_instead = True + + n_in, n_out, n_in_recr, n_out_recr = self.get_layer_mult_size(layer)[i] + if use_resource_instead: + self.set_closest_reuse_factor(layer, n_in, n_out, attribute=f'{d}_reuse_factor') + self.set_closest_reuse_factor(layer, n_in_recr, n_out_recr, attribute=f'{d}_recurrent_reuse_factor') + layer.set_attr('strategy', 'resource') + else: + self.set_closest_reuse_factor(layer, n_in, n_out, attribute=f'{d}_reuse_factor', include_max_rf=False) + self.set_closest_reuse_factor( + layer, n_in_recr, n_out_recr, attribute=f'{d}_recurrent_reuse_factor', include_max_rf=False + ) + layer.set_attr('strategy', 'resource_unrolled') + else: + layer.set_attr('strategy', 'latency') + + layer.set_attr('index_t', NamedType(f'layer{layer.index}_index', IntegerPrecisionType(width=1, signed=False))) + @layer_optimizer(GarNet) def init_garnet(self, layer): reuse_factor = layer.attributes['reuse_factor'] diff --git a/hls4ml/converters/keras/recurrent.py b/hls4ml/converters/keras/recurrent.py index 9f98b33f76..f27a970c54 100644 --- a/hls4ml/converters/keras/recurrent.py +++ b/hls4ml/converters/keras/recurrent.py @@ -18,6 +18,7 @@ def parse_rnn_layer(keras_layer, input_names, input_shapes, data_reader): assert keras_layer['class_name'] in rnn_layers or keras_layer['class_name'][1:] in rnn_layers layer = parse_default_keras_layer(keras_layer, input_names) + layer['direction'] = 'forward' layer['return_sequences'] = keras_layer['config']['return_sequences'] layer['return_state'] = keras_layer['config']['return_state'] @@ -109,4 +110,107 @@ def parse_time_distributed_layer(keras_layer, input_names, input_shapes, data_re layer['output_shape'] = output_shape[1:] # Remove the batch dimension layer['n_time_steps'] = output_shape[1] + +@keras_handler('Bidirectional') +def parse_bidirectional_layer(keras_layer, input_names, input_shapes, data_reader): + assert keras_layer['class_name'] == 'Bidirectional' + + rnn_forward_layer = keras_layer['config']['layer'] + swapped_order = False + if keras_layer['config'].get('backward_layer'): + rnn_backward_layer = keras_layer['config']['backward_layer'] + if rnn_forward_layer['config']['go_backwards']: + temp_layer = rnn_forward_layer.copy() + rnn_forward_layer = rnn_backward_layer.copy() + rnn_backward_layer = temp_layer + swapped_order = True + print( + f'WARNING: The selected order for forward and backward layers in "{keras_layer['config']['name']}" ' + f'({keras_layer['class_name']}) is not supported in Vitis backend. Switching to forward layer first, backward layer last.' + ) + else: + rnn_backward_layer = rnn_forward_layer + + assert (rnn_forward_layer['class_name'] in rnn_layers or rnn_forward_layer['class_name'][1:] in rnn_layers) and ( + rnn_backward_layer['class_name'] in rnn_layers or rnn_backward_layer['class_name'][1:] in rnn_layers + ) + + layer = {} + layer['name'] = keras_layer['config']['name'] + layer['class_name'] = keras_layer['class_name'] + if input_names is not None: + layer['inputs'] = input_names + + layer['direction'] = 'bidirectional' + layer['return_sequences'] = rnn_forward_layer['config']['return_sequences'] + layer['return_state'] = rnn_forward_layer['config']['return_state'] + layer['time_major'] = rnn_forward_layer['config']['time_major'] if 'time_major' in rnn_forward_layer['config'] else False + # TODO Should we handle time_major? 
+ if layer['time_major']: + raise Exception('Time-major format is not supported by hls4ml') + layer['n_timesteps'] = input_shapes[0][1] + layer['n_in'] = input_shapes[0][2] + layer['merge_mode'] = keras_layer['config']['merge_mode'] + + for direction, rnn_layer in [('forward', rnn_forward_layer), ('backward', rnn_backward_layer)]: + + layer[f'{direction}_name'] = rnn_layer['config']['name'] + layer[f'{direction}_class_name'] = rnn_layer['class_name'] + + layer[f'{direction}_data_format'] = rnn_layer['config'].get('data_format', 'channels_last') + + if 'activation' in rnn_layer['config']: + layer[f'{direction}_activation'] = rnn_layer['config']['activation'] + if 'epsilon' in rnn_layer['config']: + layer[f'{direction}_epsilon'] = rnn_layer['config']['epsilon'] + if 'use_bias' in rnn_layer['config']: + layer[f'{direction}_use_bias'] = rnn_layer['config']['use_bias'] + + if 'SimpleRNN' not in rnn_layer['class_name']: + layer[f'{direction}_recurrent_activation'] = rnn_layer['config']['recurrent_activation'] + + rnn_layer_name = rnn_layer['config']['name'] + if 'SimpleRNN' in layer['class_name']: + cell_name = 'simple_rnn' + else: + cell_name = rnn_layer['class_name'].lower() + temp_dir = direction + if swapped_order: + temp_dir = 'backward' if direction == 'forward' else 'forward' + layer[f'{direction}_weight_data'], layer[f'{direction}_recurrent_weight_data'], layer[f'{direction}_bias_data'] = ( + get_weights_data( + data_reader, + layer['name'], + [ + f'{temp_dir}_{rnn_layer_name}/{cell_name}_cell/kernel', + f'{temp_dir}_{rnn_layer_name}/{cell_name}_cell/recurrent_kernel', + f'{temp_dir}_{rnn_layer_name}/{cell_name}_cell/bias', + ], + ) + ) + + if 'GRU' in rnn_layer['class_name']: + layer[f'{direction}_apply_reset_gate'] = 'after' if rnn_layer['config']['reset_after'] else 'before' + + # biases array is actually a 2-dim array of arrays (bias + recurrent bias) + # both arrays have shape: n_units * 3 (z, r, h_cand) + biases = layer[f'{direction}_bias_data'] + layer[f'{direction}_bias_data'] = biases[0] + layer[f'{direction}_recurrent_bias_data'] = biases[1] + + layer[f'{direction}_n_states'] = rnn_layer['config']['units'] + + if layer['merge_mode'] == 'concat': + layer['n_out'] = layer['forward_n_states'] + layer['backward_n_states'] + else: + layer['n_out'] = layer['forward_n_states'] + + if layer['return_sequences']: + output_shape = [input_shapes[0][0], layer['n_timesteps'], layer['n_out']] + else: + output_shape = [input_shapes[0][0], layer['n_out']] + + if layer['return_state']: + raise Exception('"return_state" of {} layer is not yet supported.') + return layer, output_shape diff --git a/hls4ml/converters/keras_v2_to_hls.py b/hls4ml/converters/keras_v2_to_hls.py index b042e64f14..daa9fc5575 100644 --- a/hls4ml/converters/keras_v2_to_hls.py +++ b/hls4ml/converters/keras_v2_to_hls.py @@ -241,7 +241,7 @@ def parse_keras_model(model_arch, reader): 'HGQ>UnaryLUT', ] # Recurrent layers - recurrent_layers = ['SimpleRNN', 'LSTM', 'GRU', 'QSimpleRNN', 'QLSTM', 'QGRU'] + recurrent_layers = ['SimpleRNN', 'LSTM', 'GRU', 'QSimpleRNN', 'QLSTM', 'QGRU', 'Bidirectional'] # All supported layers supported_layers = get_supported_keras_layers() + skip_layers diff --git a/hls4ml/model/layers.py b/hls4ml/model/layers.py index b6cd446e58..f928fa8dbf 100644 --- a/hls4ml/model/layers.py +++ b/hls4ml/model/layers.py @@ -1403,7 +1403,7 @@ class LSTM(Layer): Attribute('return_sequences', value_type=bool, default=False), Attribute('return_state', value_type=bool, default=False), Attribute('pass_initial_states', 
value_type=bool, default=False), - ChoiceAttribute('direction', ['forward', 'backward'], default='forward'), + ChoiceAttribute('direction', ['forward', 'backward'], configurable=False, default='forward'), Attribute('time_major', value_type=bool, default=False), WeightAttribute('weight'), WeightAttribute('bias'), @@ -1460,9 +1460,9 @@ class GRU(Layer): Attribute('return_sequences', value_type=bool, default=False), Attribute('return_state', value_type=bool, default=False), Attribute('pass_initial_states', value_type=bool, default=False), - ChoiceAttribute('direction', ['forward', 'backward'], default='forward'), + ChoiceAttribute('direction', ['forward', 'backward'], configurable=False, default='forward'), Attribute('time_major', value_type=bool, default=False), - ChoiceAttribute('apply_reset_gate', ['before', 'after'], default='after'), + ChoiceAttribute('apply_reset_gate', ['before', 'after'], configurable=False, default='after'), WeightAttribute('weight'), WeightAttribute('bias'), WeightAttribute('recurrent_weight'), @@ -1526,6 +1526,80 @@ def initialize(self): self.add_output_variable(shape, dims) +class Bidirectional(Layer): + _expected_attributes = [ + Attribute('n_out'), + Attribute('return_sequences', value_type=bool, default=False), + Attribute('return_state', value_type=bool, default=False), + Attribute('pass_initial_states', value_type=bool, default=False), + Attribute('time_major', value_type=bool, default=False), + Attribute('forward_activation', value_type=str), + Attribute('forward_recurrent_activation', value_type=str), + WeightAttribute('forward_weight'), + WeightAttribute('forward_bias'), + WeightAttribute('forward_recurrent_weight'), + WeightAttribute('forward_recurrent_bias'), + TypeAttribute('forward_weight'), + TypeAttribute('forward_bias'), + TypeAttribute('forward_recurrent_weight'), + TypeAttribute('forward_recurrent_bias'), + Attribute('backward_activation', value_type=str), + Attribute('backward_recurrent_activation', value_type=str), + WeightAttribute('backward_weight'), + WeightAttribute('backward_bias'), + WeightAttribute('backward_recurrent_weight'), + WeightAttribute('backward_recurrent_bias'), + TypeAttribute('backward_weight'), + TypeAttribute('backward_bias'), + TypeAttribute('backward_recurrent_weight'), + TypeAttribute('backward_recurrent_bias'), + ] + + def initialize(self): + if self.attributes['return_sequences']: + shape = [self.attributes['n_timesteps'], self.attributes['n_out']] + dims = [f'N_TIME_STEPS_{self.index}', f'N_OUT_{self.index}'] + else: + shape = [self.attributes['n_out']] + dims = [f'N_OUT_{self.index}'] + + self.add_output_variable(shape, dims) + + if self.attributes['return_state']: + state_shape = [self.attributes['n_out']] + state_dims = [f'N_OUT_{self.index}'] + self.add_output_variable( + state_shape, state_dims, out_name=self.outputs[1], var_name='layer{index}_h', type_name='layer{index}_h_t' + ) + self.add_output_variable( + state_shape, state_dims, out_name=self.outputs[2], var_name='layer{index}_c', type_name='layer{index}_c_t' + ) + + for dir in ['forward', 'backward']: + # weights + self.add_weights_variable(name=f'{dir}_weight', var_name=(f'w_{dir[0]}_' + '{index}')) + + # recurrent weights + recurrent_weight = self.get_attr(f'{dir}_recurrent_weight_data') + self.add_weights_variable( + name=f'{dir}_recurrent_weight', var_name=(f'wr_{dir[0]}_' + '{index}'), data=recurrent_weight + ) + + # biases + self.add_weights_variable(name=f'{dir}_bias', var_name=(f'b_{dir[0]}_' + '{index}')) + + if self.attributes[f'{dir}_class_name'] 
== 'LSTM': + if "pytorch" in self.attributes.keys(): + self.add_weights_variable(name=f'{dir}_recurrent_bias', var_name=(f'br_{dir[0]}_' + '{index}')) + else: + recurrent_bias = np.zeros(recurrent_weight.shape[1]) + self.add_weights_variable( + name=f'{dir}_recurrent_bias', var_name=(f'br_{dir[0]}_' + '{index}'), data=recurrent_bias + ) + else: + self.add_weights_variable(name=f'{dir}_recurrent_bias', var_name=(f'br_{dir[0]}_' + '{index}')) + + class GarNet(Layer): ref_impl = False @@ -1816,6 +1890,7 @@ def initialize(self): 'SimpleRNN': SimpleRNN, 'LSTM': LSTM, 'GRU': GRU, + 'Bidirectional': Bidirectional, 'QSimpleRNN': SimpleRNN, 'QLSTM': LSTM, 'QGRU': GRU, diff --git a/hls4ml/model/optimizer/passes/infer_precision.py b/hls4ml/model/optimizer/passes/infer_precision.py index 919bc0c3c2..c048be99d4 100644 --- a/hls4ml/model/optimizer/passes/infer_precision.py +++ b/hls4ml/model/optimizer/passes/infer_precision.py @@ -81,7 +81,7 @@ def _infer_precision(self, node, types_to_infer): if node_class in ['Embedding']: return self._infer_embedding_precision(node, types_to_infer) - if node_class in ['SimpleRNN', 'LSTM', 'GRU']: + if node_class in ['SimpleRNN', 'LSTM', 'GRU', 'Bidirectional']: return self._infer_rnn_precision(node, types_to_infer) if node_class in ['ParametrizedActivation']: @@ -553,7 +553,11 @@ def _infer_rnn_precision(self, node, types_to_infer): inferred_types = [] # for now just do the weights and leave the rest for the default catch - for weightvar in ('weight', 'bias', 'recurrent_weight', 'recurrent_bias'): + rnn_weights = ('weight', 'bias', 'recurrent_weight', 'recurrent_bias') + if node.class_name == 'Bidirectional': + rnn_weights = [direction + '_' + weight for direction in ['forward', 'backward'] for weight in rnn_weights] + + for weightvar in rnn_weights: if f'{weightvar}_t' in types_to_infer: self._infer_default_type(node, f'{weightvar}_t') node.weights[weightvar].update_precision(node.types[f'{weightvar}_t'].precision) diff --git a/hls4ml/templates/vivado/nnet_utils/nnet_recurrent.h b/hls4ml/templates/vivado/nnet_utils/nnet_recurrent.h index 618767dcb5..042a0325ee 100644 --- a/hls4ml/templates/vivado/nnet_utils/nnet_recurrent.h +++ b/hls4ml/templates/vivado/nnet_utils/nnet_recurrent.h @@ -9,6 +9,8 @@ namespace nnet { +// Struct for the LSTM template + struct lstm_config { // Internal data type definitions typedef float weight_t; @@ -35,6 +37,7 @@ struct lstm_config { template using activation_recr = nnet::activation::relu; template using activation = nnet::activation::relu; }; + // Long Short term Memory NN (LSTM) // Resources: // https://github.com/nicodjimenez/lstm/blob/master/lstm.py @@ -111,7 +114,7 @@ void lstm(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[CONFIG } } -template +template void lstm_static(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[CONFIG_T::n_state], res_T s_newstate[CONFIG_T::n_state], typename CONFIG_T::weight_t param[CONFIG_T::n_state * 4 * CONFIG_T::n_in], @@ -189,6 +192,33 @@ void lstm_static(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate } } +template class lstm_class { + public: + static void apply(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_total[2 * CONFIG_T::n_state], + typename CONFIG_T::weight_t param[CONFIG_T::n_state * 4 * CONFIG_T::n_in], + typename CONFIG_T::recurrent_weight_t param_r[CONFIG_T::n_state * 4 * CONFIG_T::n_state], + typename CONFIG_T::bias_t param_b[CONFIG_T::n_state * 4], + typename CONFIG_T::recurrent_bias_t param_br[CONFIG_T::n_state * 4]) { + res_T 
*h_newstate = h_total; + res_T *s_newstate = h_newstate + CONFIG_T::n_state; + nnet::lstm(reset_state, data, h_newstate, s_newstate, param, param_r, param_b, param_br); + }; +}; + +template class lstm_class_static { + public: + static void apply(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_total[2 * CONFIG_T::n_state], + typename CONFIG_T::weight_t param[CONFIG_T::n_state * 4 * CONFIG_T::n_in], + typename CONFIG_T::recurrent_weight_t param_r[CONFIG_T::n_state * 4 * CONFIG_T::n_state], + typename CONFIG_T::bias_t param_b[CONFIG_T::n_state * 4], + typename CONFIG_T::recurrent_bias_t param_br[CONFIG_T::n_state * 4]) { + res_T *h_newstate = h_total; + res_T *s_newstate = h_newstate + CONFIG_T::n_state; + nnet::lstm_static(reset_state, data, h_newstate, s_newstate, param, param_r, + param_b, param_br); + }; +}; + template void lstm_stack(data_T data[CONFIG_T::n_sequence * CONFIG_T::n_in], res_T res[CONFIG_T::n_sequence_out * CONFIG_T::n_state], typename CONFIG_T::weight_t param[CONFIG_T::n_state * 4 * CONFIG_T::n_in], @@ -435,15 +465,14 @@ void gru(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[CONFIG_ } } -template +template void gru_static(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[CONFIG_T::n_state], typename CONFIG_T::weight_t param[CONFIG_T::n_state * 3 * CONFIG_T::n_in], typename CONFIG_T::recurrent_weight_t param_zr[CONFIG_T::n_state * 3 * CONFIG_T::n_state], typename CONFIG_T::bias_t param_b[CONFIG_T::n_state * 3], typename CONFIG_T::recurrent_bias_t param_br[CONFIG_T::n_state * 3]) { - // Initialize the state variable -- will maintain state between function calls - static res_T h_state[CONFIG_T::n_state]; + // Initialize the state variable -- will maintain state between function calls typename CONFIG_T::accum_t tmpres[CONFIG_T::n_state * 3]; typename CONFIG_T::accum_t tmpres_state_zr[CONFIG_T::n_state * 3]; typename CONFIG_T::accum_t tmpres_state_h[CONFIG_T::n_state]; @@ -519,6 +548,26 @@ void gru_static(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_newstate[ } } +template struct gru_class { + static void apply(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_state[CONFIG_T::n_state], + typename CONFIG_T::weight_t param[CONFIG_T::n_state * 3 * CONFIG_T::n_in], + typename CONFIG_T::recurrent_weight_t param_zr[CONFIG_T::n_state * 3 * CONFIG_T::n_state], + typename CONFIG_T::bias_t param_b[CONFIG_T::n_state * 3], + typename CONFIG_T::recurrent_bias_t param_br[CONFIG_T::n_state * 3]) { + nnet::gru(reset_state, data, h_state, param, param_zr, param_b, param_br); + }; +}; + +template struct gru_class_static { + static void apply(bool reset_state, data_T data[CONFIG_T::n_in], res_T h_state[CONFIG_T::n_state], + typename CONFIG_T::weight_t param[CONFIG_T::n_state * 3 * CONFIG_T::n_in], + typename CONFIG_T::recurrent_weight_t param_zr[CONFIG_T::n_state * 3 * CONFIG_T::n_state], + typename CONFIG_T::bias_t param_b[CONFIG_T::n_state * 3], + typename CONFIG_T::recurrent_bias_t param_br[CONFIG_T::n_state * 3]) { + nnet::gru_static(reset_state, data, h_state, param, param_zr, param_b, param_br); + }; +}; + template void gru_stack(data_T data[CONFIG_T::n_sequence * CONFIG_T::n_in], res_T res[CONFIG_T::n_sequence_out * CONFIG_T::n_state], typename CONFIG_T::weight_t param[CONFIG_T::n_state * 3 * CONFIG_T::n_in], @@ -654,6 +703,132 @@ void gru_stack(hls::stream &data_stream, hls::stream &res_stream, } } +// Struct for the Bidirectional template + +struct single_layer_config { + // Internal data type definitions + typedef float weight_t; + 
typedef float recurrent_weight_t; + typedef float bias_t; + typedef float recurrent_bias_t; + typedef float accum_t; + + // Layer Sizes + static const unsigned n_in = 2; + static const unsigned n_state = 2; + static const unsigned n_mult = 3; + static const unsigned table_size = 1024; + + template using activation_recr = nnet::activation::relu; + template using activation = nnet::activation::relu; +}; + +struct bidirectional_config { + // Layer Sizes + static const unsigned n_in = 2; + static const unsigned n_parts = 20; + static const unsigned n_out = 2; + static const unsigned table_size = 1024; + + // Resource reuse info + static const unsigned io_type = io_parallel; + static const unsigned reuse_factor = 1; + static const unsigned n_zeros = 0; + static const bool store_weights_in_bram = false; + static const bool use_static = true; + + // Layers info + + template + using RNNfunc_forward = nnet::lstm_class; + template + using RNNfunc_backward = nnet::lstm_class; +}; + +template +void bidirectional_stack( + data_T data[CONFIG_T::n_sequence * CONFIG_T::n_in], res_T res[CONFIG_T::n_sequence_out * CONFIG_T::n_out], + typename CONFIG_T::FORWARD_CONFIG::weight_t + param[CONFIG_T::FORWARD_CONFIG::n_state * CONFIG_T::FORWARD_CONFIG::n_mult * CONFIG_T::n_in], + typename CONFIG_T::FORWARD_CONFIG::recurrent_weight_t + param_r[CONFIG_T::FORWARD_CONFIG::n_state * CONFIG_T::FORWARD_CONFIG::n_mult * CONFIG_T::FORWARD_CONFIG::n_state], + typename CONFIG_T::FORWARD_CONFIG::bias_t param_b[CONFIG_T::FORWARD_CONFIG::n_state * CONFIG_T::FORWARD_CONFIG::n_mult], + typename CONFIG_T::FORWARD_CONFIG::recurrent_bias_t + param_br[CONFIG_T::FORWARD_CONFIG::n_state * CONFIG_T::FORWARD_CONFIG::n_mult], + typename CONFIG_T::BACKWARD_CONFIG::weight_t + param_back[CONFIG_T::BACKWARD_CONFIG::n_state * CONFIG_T::BACKWARD_CONFIG::n_mult * CONFIG_T::n_in], + typename CONFIG_T::BACKWARD_CONFIG::recurrent_weight_t + param_r_back[CONFIG_T::BACKWARD_CONFIG::n_state * CONFIG_T::BACKWARD_CONFIG::n_mult * + CONFIG_T::BACKWARD_CONFIG::n_state], + typename CONFIG_T::BACKWARD_CONFIG::bias_t + param_b_back[CONFIG_T::BACKWARD_CONFIG::n_state * CONFIG_T::BACKWARD_CONFIG::n_mult], + typename CONFIG_T::BACKWARD_CONFIG::recurrent_bias_t + param_br_back[CONFIG_T::BACKWARD_CONFIG::n_state * CONFIG_T::BACKWARD_CONFIG::n_mult]) { + + res_T h_newstate[(CONFIG_T::FORWARD_CONFIG::n_mult - 2) * CONFIG_T::FORWARD_CONFIG::n_state]; + res_T h_newstate_back[(CONFIG_T::BACKWARD_CONFIG::n_mult - 2) * CONFIG_T::BACKWARD_CONFIG::n_state]; + data_T data_in[CONFIG_T::n_in]; + data_T data_in_back[CONFIG_T::n_in]; + bool reset_state = true; + + #pragma HLS ARRAY_PARTITION variable=h_newstate complete + #pragma HLS ARRAY_PARTITION variable=h_newstate_back complete + + for (int ii = 0; ii < (CONFIG_T::FORWARD_CONFIG::n_mult - 2) * CONFIG_T::FORWARD_CONFIG::n_state; ii++) { + #pragma HLS UNROLL + h_newstate[ii] = 0; + } + for (int ii = 0; ii < (CONFIG_T::BACKWARD_CONFIG::n_mult - 2) * CONFIG_T::BACKWARD_CONFIG::n_state; ii++) { + #pragma HLS UNROLL + h_newstate_back[ii] = 0; + } + + for (int iloop = 0; iloop < CONFIG_T::n_sequence; iloop++) { + for (int j = 0; j < CONFIG_T::n_in; j++) { + #pragma HLS UNROLL + data_in[j] = data[j + iloop * CONFIG_T::n_in]; + data_in_back[j] = data[j + (CONFIG_T::n_sequence - iloop - 1) * CONFIG_T::n_in]; + } + + CONFIG_T::template RNNfunc_forward::apply( + reset_state, data_in, h_newstate, param, param_r, param_b, param_br); + CONFIG_T::template RNNfunc_backward::apply( + reset_state, data_in_back, h_newstate_back, param_back, 
param_r_back, param_b_back, param_br_back); + + if (CONFIG_T::n_sequence_out > 1) { + for (int i = (CONFIG_T::FORWARD_CONFIG::n_state + CONFIG_T::BACKWARD_CONFIG::n_state) * iloop, j = 0; + i < (CONFIG_T::FORWARD_CONFIG::n_state + CONFIG_T::BACKWARD_CONFIG::n_state) * iloop + + CONFIG_T::FORWARD_CONFIG::n_state; + i++, j++) { + #pragma HLS UNROLL + res[i] = h_newstate[j]; + } + for (int i = (CONFIG_T::FORWARD_CONFIG::n_state + CONFIG_T::BACKWARD_CONFIG::n_state) * + (CONFIG_T::n_sequence - iloop) - + CONFIG_T::BACKWARD_CONFIG::n_state, + j = 0; + i < + (CONFIG_T::FORWARD_CONFIG::n_state + CONFIG_T::BACKWARD_CONFIG::n_state) * (CONFIG_T::n_sequence - iloop); + i++, j++) { + #pragma HLS UNROLL + res[i] = h_newstate_back[j]; + } + } + reset_state = false; + } + + if (CONFIG_T::n_sequence_out == 1) { + for (int i = 0; i < (CONFIG_T::FORWARD_CONFIG::n_state); i++) { + #pragma HLS UNROLL + res[i] = h_newstate[i]; + } + for (int i = 0; i < (CONFIG_T::BACKWARD_CONFIG::n_state); i++) { + #pragma HLS UNROLL + res[i + CONFIG_T::FORWARD_CONFIG::n_state] = h_newstate_back[i]; + } + } +} + } // namespace nnet #endif diff --git a/test/pytest/test_rnn.py b/test/pytest/test_rnn.py index d2303669fe..daa95ac0b6 100644 --- a/test/pytest/test_rnn.py +++ b/test/pytest/test_rnn.py @@ -2,29 +2,92 @@ import numpy as np import pytest -from tensorflow.keras.layers import GRU, LSTM, Input, SimpleRNN +from tensorflow.keras.layers import GRU, LSTM, Bidirectional, Input, SimpleRNN from tensorflow.keras.models import Model, Sequential import hls4ml test_root_path = Path(__file__).parent -rnn_layers = [SimpleRNN, LSTM, GRU] +rnn_layers = [SimpleRNN, LSTM, GRU, Bidirectional] -@pytest.mark.parametrize('rnn_layer', rnn_layers) -@pytest.mark.parametrize('return_sequences', [True, False]) -def test_rnn_parsing(rnn_layer, return_sequences): +def create_model_parsing(rnn_layer, return_sequences): time_steps = 3 input_size = 8 input_shape = (time_steps, input_size) model_input = Input(shape=input_shape) - model_output = rnn_layer(64, return_sequences=return_sequences)(model_input) + if rnn_layer.__name__ != 'Bidirectional': + model_output = rnn_layer(64, return_sequences=return_sequences)(model_input) + else: + forward_layer = LSTM(37, return_sequences=return_sequences) + bacwkard_layer = GRU(27, return_sequences=return_sequences, go_backwards=True) + model_output = rnn_layer(forward_layer, backward_layer=bacwkard_layer)(model_input) model = Model(model_input, model_output) model.compile(optimizer='adam', loss='mse') + return model + + +def compare_attributes(hls_layer, keras_layer): + assert hls_layer.class_name == keras_layer.__class__.__name__ + assert hls_layer.get_input_variable().shape == list(keras_layer.input_shape)[1:] # Ignore the batch size + assert hls_layer.get_output_variable().shape == list(keras_layer.output_shape)[1:] # Ignore the batch size + if keras_layer.__class__.__name__ != 'Bidirectional': + assert hls_layer.attributes['n_out'] == keras_layer.units + assert hls_layer.attributes['activation'] == keras_layer.activation.__name__ + if 'recurrent_activation' in hls_layer.attributes: # SimpleRNN doesn't have this + assert hls_layer.attributes['recurrent_activation'] == keras_layer.recurrent_activation.__name__ + else: + assert hls_layer.attributes['merge_mode'] == keras_layer.merge_mode + n_out = 0 + for inner_layer, direction in [(keras_layer.forward_layer, 'forward'), (keras_layer.backward_layer, 'backward')]: + assert hls_layer.attributes[f'{direction}_n_states'] == inner_layer.units + if 
hls_layer.attributes['merge_mode'] == 'concat': + n_out += inner_layer.units + else: + n_out = inner_layer.units + assert hls_layer.attributes[f'{direction}_activation'] == inner_layer.activation.__name__ + if f'{direction}_recurrent_activation' in hls_layer.attributes: # SimpleRNN doesn't have this + assert hls_layer.attributes[f'{direction}_recurrent_activation'] == inner_layer.recurrent_activation.__name__ + assert hls_layer.attributes['n_out'] == n_out + + +def compare_weights(hls_weights, keras_weights, keras_layer): + def comparison(hls_weights, keras_weights, class_name): + assert hls_weights[0].data.shape == keras_weights[0].shape + assert hls_weights[1].data.shape == keras_weights[1].shape + if class_name == 'GRU': + # GRU has both bias and recurrent bias + assert hls_weights[2].data.shape == keras_weights[2][0].shape + assert hls_weights[3].data.shape == keras_weights[2][1].shape + else: + # LSTM and SimpleRNN only have bias + assert hls_weights[2].data.shape == keras_weights[2].shape + + np.testing.assert_array_equal(hls_weights[0].data, keras_weights[0]) + np.testing.assert_array_equal(hls_weights[1].data, keras_weights[1]) + if class_name == 'GRU': + np.testing.assert_array_equal(hls_weights[2].data, keras_weights[2][0]) + np.testing.assert_array_equal(hls_weights[3].data, keras_weights[2][1]) + else: + np.testing.assert_array_equal(hls_weights[2].data, keras_weights[2]) + + if keras_layer.__class__.__name__ != 'Bidirectional': + comparison(hls_weights, keras_weights, keras_layer.__class__.__name__) + else: + for i, inner_layer in enumerate([keras_layer.forward_layer, keras_layer.backward_layer]): + comparison(hls_weights[4 * i : 4 * (i + 1)], keras_weights[3 * i : 3 * (i + 1)], inner_layer.__class__.__name__) + + +@pytest.mark.parametrize('rnn_layer', rnn_layers) +@pytest.mark.parametrize('return_sequences', [True, False]) +def test_rnn_parsing(rnn_layer, return_sequences): + + model = create_model_parsing(rnn_layer, return_sequences) + config = hls4ml.utils.config_from_keras_model(model, granularity='name', backend='Vivado') prj_name = f'hls4mlprj_rnn_{rnn_layer.__class__.__name__.lower()}_seq_{int(return_sequences)}' output_dir = str(test_root_path / prj_name) @@ -34,35 +97,56 @@ def test_rnn_parsing(rnn_layer, return_sequences): keras_layer = model.layers[1] # Basic sanity check, I/O, activations - assert hls_layer.class_name == rnn_layer.__name__ - assert hls_layer.attributes['n_out'] == keras_layer.units - assert hls_layer.attributes['activation'] == keras_layer.activation.__name__ - if 'recurrent_activation' in hls_layer.attributes: # SimpleRNN doesn't have this - assert hls_layer.attributes['recurrent_activation'] == keras_layer.recurrent_activation.__name__ - assert hls_layer.get_input_variable().shape == list(input_shape) - assert hls_layer.get_output_variable().shape == model_output.shape.as_list()[1:] # Ignore the batch size + compare_attributes(hls_layer, keras_layer) # Compare weights hls_weights = list(hls_layer.get_weights()) # [weights, recurrent_weights, bias, recurrent_bias] - rnn_weights = keras_layer.get_weights() # [weights, recurrent_weights, bias] - - assert hls_weights[0].data.shape == rnn_weights[0].shape - assert hls_weights[1].data.shape == rnn_weights[1].shape - if 'gru' in rnn_layer.__name__.lower(): - # GRU has both bias and recurrent bias - assert hls_weights[2].data.shape == rnn_weights[2][0].shape - assert hls_weights[3].data.shape == rnn_weights[2][1].shape - else: - # LSTM and SimpleRNN only have bias - assert hls_weights[2].data.shape == 
rnn_weights[2].shape - - np.testing.assert_array_equal(hls_weights[0].data, rnn_weights[0]) - np.testing.assert_array_equal(hls_weights[1].data, rnn_weights[1]) - if 'gru' in rnn_layer.__name__.lower(): - np.testing.assert_array_equal(hls_weights[2].data, rnn_weights[2][0]) - np.testing.assert_array_equal(hls_weights[3].data, rnn_weights[2][1]) + keras_weights = keras_layer.get_weights() # [weights, recurrent_weights, bias] + compare_weights(hls_weights, keras_weights, keras_layer) + + +def create_model_accuracy(rnn_layer, return_sequences): + # Subtract 0.5 to include negative values + input_shape = (12, 8) + X = np.random.rand(50, *input_shape) - 0.5 + + layer_name = rnn_layer.__name__ + model = Sequential() + model.add(Input(shape=input_shape)) + if layer_name != 'Bidirectional': + test_layer = rnn_layer( + units=32, + input_shape=input_shape, + kernel_initializer='lecun_uniform', + recurrent_initializer='lecun_uniform', + bias_initializer='lecun_uniform', + return_sequences=return_sequences, + name=layer_name, + ) else: - np.testing.assert_array_equal(hls_weights[2].data, rnn_weights[2]) + test_layer = Bidirectional( + LSTM( + units=15, + input_shape=input_shape, + kernel_initializer='lecun_uniform', + recurrent_initializer='lecun_uniform', + bias_initializer='lecun_uniform', + return_sequences=return_sequences, + ), + backward_layer=GRU( + units=17, + input_shape=input_shape, + kernel_initializer='lecun_uniform', + recurrent_initializer='lecun_uniform', + bias_initializer='lecun_uniform', + return_sequences=return_sequences, + go_backwards=True, + ), + name=layer_name, + ) + model.add(test_layer) + model.compile() + return model, X @pytest.mark.parametrize( @@ -92,47 +176,37 @@ def test_rnn_parsing(rnn_layer, return_sequences): (GRU, 'Vitis', 'io_stream', 'latency'), (GRU, 'Quartus', 'io_stream', 'resource'), (GRU, 'oneAPI', 'io_stream', 'resource'), + (Bidirectional, 'Vivado', 'io_parallel', 'resource'), + (Bidirectional, 'Vivado', 'io_parallel', 'latency'), + (Bidirectional, 'Vitis', 'io_parallel', 'resource'), + (Bidirectional, 'Vitis', 'io_parallel', 'latency'), ], ) @pytest.mark.parametrize('return_sequences', [True, False]) @pytest.mark.parametrize('static', [True, False]) def test_rnn_accuracy(rnn_layer, return_sequences, backend, io_type, strategy, static): - # Subtract 0.5 to include negative values - input_shape = (12, 8) - X = np.random.rand(50, *input_shape) - 0.5 - layer_name = rnn_layer.__name__ - keras_model = Sequential() - keras_model.add( - rnn_layer( - units=32, - input_shape=input_shape, - kernel_initializer='lecun_uniform', - recurrent_initializer='lecun_uniform', - bias_initializer='lecun_uniform', - return_sequences=return_sequences, - name=layer_name, - ) - ) - keras_model.compile() + + model, X = create_model_accuracy(rnn_layer, return_sequences) default_precision = 'ap_fixed<32, 16>' if backend in ['Vivado', 'Vitis'] else 'ac_fixed<32, 16, true>' hls_config = hls4ml.utils.config_from_keras_model( - keras_model, granularity='name', default_precision=default_precision, backend=backend + model, granularity='name', default_precision=default_precision, backend=backend ) hls_config['LayerName'][layer_name]['static'] = static hls_config['LayerName'][layer_name]['Strategy'] = strategy prj_name = ( - f'hls4mlprj_rnn_accuracy_{layer_name}_static_{int(static)}_ret_seq_{int(return_sequences)}_' - f'{backend}_{io_type}_{strategy}' + 'hls4mlprj_rnn_accuracy_' + + f'{layer_name}_static_{int(static)}_ret_seq_{int(return_sequences)}_' + + f'{backend}_{io_type}_{strategy}' ) 
output_dir = str(test_root_path / prj_name) hls_model = hls4ml.converters.convert_from_keras_model( - keras_model, hls_config=hls_config, output_dir=output_dir, backend=backend, io_type=io_type + model, hls_config=hls_config, output_dir=output_dir, backend=backend, io_type=io_type ) hls_model.compile() - keras_prediction = keras_model.predict(X) + keras_prediction = model.predict(X) hls_prediction = hls_model.predict(X) np.testing.assert_allclose(hls_prediction.flatten(), keras_prediction.flatten(), rtol=0.0, atol=5e-2)
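For reference, below is a minimal end-to-end usage sketch of the Bidirectional support added by this diff, modeled directly on the updated test_rnn.py. It is not part of the change itself: the layer name 'bidir', the output directory, and the choice of backend/strategy are illustrative assumptions. Note that the Vitis validation passes above only accept io_parallel and the 'concat' merge mode.

# Sketch (not part of this PR): converting a Keras Bidirectional(LSTM/GRU) layer with hls4ml.
# Layer name, output_dir, backend and strategy below are illustrative assumptions.
import numpy as np
from tensorflow.keras.layers import GRU, LSTM, Bidirectional, Input
from tensorflow.keras.models import Model

import hls4ml

inputs = Input(shape=(12, 8))
outputs = Bidirectional(
    LSTM(15, return_sequences=False),
    backward_layer=GRU(17, return_sequences=False, go_backwards=True),
    name='bidir',  # hypothetical layer name used for per-layer config below
)(inputs)
model = Model(inputs, outputs)
model.compile()

config = hls4ml.utils.config_from_keras_model(
    model, granularity='name', default_precision='ap_fixed<32, 16>', backend='Vitis'
)
# Per-layer knobs exercised by the new test: static recurrence and strategy selection.
config['LayerName']['bidir']['static'] = True
config['LayerName']['bidir']['Strategy'] = 'resource'

# Only io_parallel is supported for Bidirectional in the Vitis backend (see ValidateBidirectionalIoType).
hls_model = hls4ml.converters.convert_from_keras_model(
    model, hls_config=config, output_dir='hls4mlprj_bidirectional_sketch', backend='Vitis', io_type='io_parallel'
)
hls_model.compile()

# Quick numerical check against Keras, mirroring the tolerance used in test_rnn_accuracy.
X = np.random.rand(50, 12, 8) - 0.5
np.testing.assert_allclose(hls_model.predict(X).flatten(), model.predict(X).flatten(), rtol=0.0, atol=5e-2)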