
Commit 870012d

Merge remote-tracking branch 'upstream/main' into conv_tr_parallel
2 parents: 2156a93 + 563c84c


58 files changed (+3596 / -823 lines)

.github/PULL_REQUEST_TEMPLATE.md

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ Note: Please delete options that are not relevant.
 
 ## Checklist
 
-- [ ] I have read the [guidelines for contributing](https://github.com/fastmachinelearning/hls4ml/blob/master/CONTRIBUTING.md).
+- [ ] I have read the [guidelines for contributing](https://github.com/fastmachinelearning/hls4ml/blob/main/CONTRIBUTING.md).
 - [ ] I have commented my code, particularly in hard-to-understand areas.
 - [ ] I have made corresponding changes to the documentation.
 - [ ] My changes generate no new warnings.

.github/workflows/build-sphinx.yml

Lines changed: 2 additions & 2 deletions
@@ -2,7 +2,7 @@ name: build-sphinx
 on:
   push:
     branches:
-      - master
+      - main
 
 jobs:
   build:
@@ -30,4 +30,4 @@ jobs:
       with:
         branch: gh-pages
         directory: gh-pages
-        github_token: ${{ secrets.PERSONAL_TOKEN }}
+        github_token: ${{ secrets.PERSONAL_TOKEN }}

hls4ml/backends/fpga/fpga_types.py

Lines changed: 11 additions & 0 deletions
@@ -258,6 +258,13 @@ def definition_cpp(self, name_suffix='', as_reference=False):
         else: # Declaration
             return 'hls::stream<{type}> {name}{suffix}("{name}")'.format(type=self.type.name, name=self.name, suffix=name_suffix)
 
+class QuartusStreamVariableDefinition(VariableDefinition):
+    def definition_cpp(self, name_suffix='', as_reference=False):
+        if as_reference: # Function parameter
+            return 'stream<{type}> &{name}{suffix}'.format(type=self.type.name, name=self.name, suffix=name_suffix)
+        else: # Declaration
+            return 'stream<{type}> {name}{suffix}'.format(type=self.type.name, name=self.name, suffix=name_suffix)
+
 class StreamVariableConverter(object):
     def __init__(self, type_converter, prefix, definition_cls):
         self.type_converter = type_converter
@@ -280,6 +287,10 @@ class VivadoStreamVariableConverter(StreamVariableConverter):
     def __init__(self, type_converter):
         super().__init__(type_converter=type_converter, prefix='Vivado', definition_cls=VivadoStreamVariableDefinition)
 
+class QuartusStreamVariableConverter(StreamVariableConverter):
+    def __init__(self, type_converter):
+        super().__init__(type_converter=type_converter, prefix='Quartus', definition_cls=QuartusStreamVariableDefinition)
+
 #endregion
 
 #region InplaceVariable
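For a quick sanity check, the strings the new QuartusStreamVariableDefinition emits can be reproduced by hand; the layer name and type below are hypothetical:

# Minimal sketch (hypothetical layer name/type) of the strings
# QuartusStreamVariableDefinition.definition_cpp() produces.
ref_tpl  = 'stream<{type}> &{name}{suffix}'   # as_reference=True  -> function parameter
decl_tpl = 'stream<{type}> {name}{suffix}'    # as_reference=False -> local declaration
print(ref_tpl.format(type='layer2_t', name='layer2_out', suffix=''))   # stream<layer2_t> &layer2_out
print(decl_tpl.format(type='layer2_t', name='layer2_out', suffix=''))  # stream<layer2_t> layer2_out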

hls4ml/backends/vivado/passes/clone.py renamed to hls4ml/backends/fpga/passes/clone.py

Lines changed: 0 additions & 1 deletion
@@ -3,7 +3,6 @@
 from hls4ml.model.optimizer import OptimizerPass
 
 from hls4ml.model.layers import Layer, register_layer
-from hls4ml.backends import get_backend
 from hls4ml.backends.template import FunctionCallTemplate
 
 class Clone(Layer):

hls4ml/backends/quartus/passes/core_templates.py

Lines changed: 3 additions & 3 deletions
@@ -36,7 +36,7 @@
 
 dense_function_template = 'nnet::dense_{strategy}<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});'
 
-dense_include_list = ['nnet_utils/nnet_dense.h', 'nnet_utils/nnet_dense_compressed.h']
+dense_include_list = ['nnet_utils/nnet_dense.h', 'nnet_utils/nnet_dense_compressed.h', 'nnet_utils/nnet_dense_stream.h']
 
 class DenseConfigTemplate(LayerConfigTemplate):
     def __init__(self):
@@ -80,7 +80,7 @@ def format(self, node):
 
 batchnorm_function_template = 'nnet::normalize<{input_t}, {output_t}, {config}>({input}, {output}, {scale}, {bias});'
 
-batchnorm_include_list = ['nnet_utils/nnet_batchnorm.h']
+batchnorm_include_list = ['nnet_utils/nnet_batchnorm.h', 'nnet_utils/nnet_batchnorm_stream.h']
 
 class BatchNormalizationConfigTemplate(LayerConfigTemplate):
     def __init__(self):
@@ -130,7 +130,7 @@ def format(self, node):
 activ_function_template = 'nnet::{activation}<{input_t}, {output_t}, {config}>({input}, {output});'
 param_activ_function_template = 'nnet::{activation}<{input_t}, {output_t}, {config}>({input}, {param}, {output});'
 
-activ_include_list = ['nnet_utils/nnet_activation.h']
+activ_include_list = ['nnet_utils/nnet_activation.h', 'nnet_utils/nnet_activation_stream.h']
 
 class ActivationConfigTemplate(LayerConfigTemplate):
     def __init__(self):
Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
+from hls4ml.backends.backend import get_backend
+from hls4ml.model.layers import Concatenate, Dot, Merge
+from hls4ml.backends.template import LayerConfigTemplate, FunctionCallTemplate
+
+# TODO - Very similar to vivado/merge_templates.py - only difference is on line 67: get_backend('vivado').product_type(inp1.type.precision, inp2.type.precision)
+# TODO - Look into ways of sharing similar passes across many backends in a shared folder through inheritance and overriding.
+
+# Merge templates
+merge_config_template = """struct config{index} : nnet::merge_config {{
+    static const unsigned n_elem = {n_elem};
+}};\n"""
+
+merge_function_template = 'nnet::{merge}<{input1_t}, {input2_t}, {output_t}, {config}>({input1}, {input2}, {output});'
+merge_include_list = ['nnet_utils/nnet_merge.h', 'nnet_utils/nnet_merge_stream.h']
+
+class MergeConfigTemplate(LayerConfigTemplate):
+    def __init__(self):
+        super().__init__(Merge)
+        self.template = merge_config_template
+
+    def format(self, node):
+        params = self._default_config_params(node)
+        params['n_elem'] = node.get_input_variable(node.inputs[0]).size_cpp()
+
+        return self.template.format(**params)
+
+class MergeFunctionTemplate(FunctionCallTemplate):
+    def __init__(self):
+        super().__init__((Merge, Concatenate, Dot), include_header=merge_include_list)
+        self.template = merge_function_template
+
+    def format(self, node):
+        params = {}
+        params['merge'] = node.get_attr('op').lower()
+        params['config'] = 'config{}'.format(node.index)
+        params['input1_t'] = node.get_input_variable(node.inputs[0]).type.name
+        params['input2_t'] = node.get_input_variable(node.inputs[1]).type.name
+        params['output_t'] = node.get_output_variable().type.name
+        params['input1'] = node.get_input_variable(node.inputs[0]).name
+        params['input2'] = node.get_input_variable(node.inputs[1]).name
+        params['output'] = node.get_output_variable().name
+
+        return self.template.format(**params)
+
+
+# Dot templates
+dot_config_template = """struct config{index} : nnet::dot_config {{
+    static const unsigned n_in = {n_in};
+    static const unsigned n_out = {n_out};
+
+    static const unsigned reuse_factor = {reuse};
+
+    typedef {accum_t.name} accum_t;
+
+    template<class x_T, class y_T>
+    using product = nnet::product::{product_type}<x_T, y_T>;
+}};\n"""
+
+class DotConfigTemplate(LayerConfigTemplate):
+    def __init__(self):
+        super().__init__(Dot)
+        self.template = dot_config_template
+
+    def format(self, node):
+        inp1 = node.get_input_variable(node.inputs[0])
+        inp2 = node.get_input_variable(node.inputs[1])
+        params = self._default_config_params(node)
+        params['n_out'] = 1
+        params['n_in'] = inp1.shape[0]
+        params['product_type'] = get_backend('quartus').product_type(inp1.type.precision, inp2.type.precision)
+
+        return self.template.format(**params)
+
+
+# Concatenate templates
+concat_config_template = """struct config{index} : nnet::concat_config {{
+    static const unsigned n_elem1_0 = {n_elem1_0};
+    static const unsigned n_elem1_1 = {n_elem1_1};
+    static const unsigned n_elem1_2 = {n_elem1_2};
+    static const unsigned n_elem2_0 = {n_elem2_0};
+    static const unsigned n_elem2_1 = {n_elem2_1};
+    static const unsigned n_elem2_2 = {n_elem2_2};
+
+    static const int axis = {axis};
+}};\n"""
+
+class ConcatenateConfigTemplate(LayerConfigTemplate):
+    def __init__(self):
+        super().__init__(Concatenate)
+        self.template = concat_config_template
+
+    def format(self, node):
+        params = self._default_config_params(node)
+        for i in range(3):
+            params.setdefault('n_elem1_{}'.format(i), 0)
+            params.setdefault('n_elem2_{}'.format(i), 0)
+        inp1 = node.get_input_variable(node.inputs[0])
+        inp2 = node.get_input_variable(node.inputs[1])
+        for i, (s1, s2) in enumerate(zip(inp1.shape, inp2.shape)):
+            params['n_elem1_{}'.format(i)] = s1
+            params['n_elem2_{}'.format(i)] = s2
+
+        return self.template.format(**params)
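As an illustration of how these templates are instantiated, the snippet below formats dot_config_template (the string defined in the file above) with made-up values; the index, widths, precision and product type are all hypothetical:

# Rough sketch (all values hypothetical) of what dot_config_template expands to
# for a Dot layer joining two 16-element inputs.
from types import SimpleNamespace
print(dot_config_template.format(
    index=4, n_in=16, n_out=1, reuse=1,
    accum_t=SimpleNamespace(name='ac_fixed<33,17,true>'),
    product_type='mult'))
# struct config4 : nnet::dot_config {
#     static const unsigned n_in = 16;
#     static const unsigned n_out = 1;
#
#     static const unsigned reuse_factor = 1;
#
#     typedef ac_fixed<33,17,true> accum_t;
#
#     template<class x_T, class y_T>
#     using product = nnet::product::mult<x_T, y_T>;
# };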
Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
+from hls4ml.backends.backend import get_backend
+from hls4ml.model.layers import GRU
+from hls4ml.backends.template import LayerConfigTemplate, FunctionCallTemplate
+
+recurrent_include_list = ['nnet_utils/nnet_recurrent.h', 'nnet_utils/nnet_recurrent_stream.h']
+
+# Shared Matrix Multiplication Template (Dense)
+recr_mult_config_template = '''struct config{index}_mult : nnet::dense_config {{
+    static const unsigned n_in = {n_in};
+    static const unsigned n_out = {n_out};
+
+    static const unsigned rf_pad = {rfpad};
+    static const unsigned bf_pad = {bfpad};
+    static const unsigned reuse_factor = {reuse};
+    static const unsigned reuse_factor_rounded = reuse_factor + rf_pad;
+    static const unsigned block_factor = DIV_ROUNDUP(n_in*n_out, reuse_factor);
+    static const unsigned block_factor_rounded = block_factor + bf_pad;
+    static const unsigned multiplier_factor = MIN(n_in, reuse_factor);
+    static const unsigned multiplier_limit = DIV_ROUNDUP(n_in*n_out, multiplier_factor);
+    static const unsigned multiplier_scale = multiplier_limit/n_out;
+    typedef {accum_t.name} accum_t;
+    typedef {bias_t.name} bias_t;
+    typedef {weight_t.name} weight_t;
+
+    template<class x_T, class y_T>
+    using product = nnet::product::{product_type}<x_T, y_T>;
+}};\n'''
+
+# Activation Template
+activ_config_template = '''struct {type}_config{index} : nnet::activ_config {{
+    static const unsigned n_in = {n_in};
+    static const unsigned table_size = {table_size};
+    static const unsigned io_type = nnet::{iotype};
+    static const unsigned reuse_factor = {reuse};
+}};\n'''
+
+# GRU Template
+gru_config_template = '''struct config{index} : nnet::gru_config {{
+    static const unsigned n_in = {n_in};
+    static const unsigned n_out = {n_out};
+    static const unsigned n_units = {n_units};
+    static const unsigned n_timesteps = {n_timesteps};
+    static const unsigned n_outputs = {n_outputs};
+    static const bool return_sequences = {return_sequences};
+
+    typedef {accum_t.name} accum_t;
+    typedef {weight_t.name} weight_t;
+    typedef {bias_t.name} bias_t;
+
+    typedef {config_mult_x} mult_config_x;
+    typedef {config_mult_h} mult_config_h;
+
+    typedef {act_t} ACT_CONFIG_T;
+    template<class x_T, class y_T, class config_T>
+    using activation = nnet::activation::{activation}<x_T, y_T, config_T>;
+
+    typedef {act_recurrent_t} ACT_CONFIG_RECURRENT_T;
+    template<class x_T, class y_T, class config_T>
+    using activation_recr = nnet::activation::{recurrent_activation}<x_T, y_T, config_T>;
+
+    static const unsigned reuse_factor = {reuse};
+    static const bool store_weights_in_bram = false;
+}};\n'''
+
+gru_function_template = 'nnet::gru<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {wr}, {b}, {br});'
+
+class GRUConfigTemplate(LayerConfigTemplate):
+    def __init__(self):
+        super().__init__(GRU)
+        self.gru_template = gru_config_template
+        self.act_template = activ_config_template
+        self.recr_act_template = activ_config_template
+        self.mult_x_template = recr_mult_config_template
+        self.mult_h_template = recr_mult_config_template
+
+    def format(self, node):
+        # Input has shape (n_timesteps, inp_dimensionality)
+        # Output / hidden units has shape (1 if !return_sequences else n_timesteps, n_units)
+        params = self._default_config_params(node)
+        params['n_units'] = node.get_attr('n_out')
+        params['n_outputs'] = node.get_attr('n_timesteps') if node.get_attr('return_sequences', False) else '1'
+        params['return_sequences'] = 'true' if node.get_attr('return_sequences', False) else 'false'
+        params['config_mult_x'] = 'config{}_x_mult'.format(node.index)
+        params['config_mult_h'] = 'config{}_h_mult'.format(node.index)
+        params['act_t'] = '{}_config{}'.format(node.get_attr('activation'), str(node.index) + '_act')
+        params['act_recurrent_t'] = '{}_config{}'.format(node.get_attr('recurrent_activation'), str(node.index) + '_rec_act')
+        gru_config = self.gru_template.format(**params)
+
+        # Activation is on the candidate hidden state, dimensionality (1, n_units)
+        act_params = self._default_config_params(node)
+        act_params['type'] = node.get_attr('activation')
+        act_params['n_in'] = node.get_attr('n_out')
+        act_params['index'] = str(node.index) + '_act'
+        act_config = self.act_template.format(**act_params)
+
+        # Recurrent activation is on the reset and update gates (therefore x2), dimensionality (1, n_units)
+        recr_act_params = self._default_config_params(node)
+        recr_act_params['type'] = node.get_attr('recurrent_activation')
+        recr_act_params['n_in'] = str(node.get_attr('n_out')) + ' * 2'
+        recr_act_params['index'] = str(node.index) + '_rec_act'
+        recr_act_config = self.recr_act_template.format(**recr_act_params)
+
+        # Multiplication config for matrix multiplications of type Wx (reset, update and candidate states)
+        mult_params_x = self._default_config_params(node)
+        mult_params_x['n_in'] = node.get_attr('n_in')
+        mult_params_x['n_out'] = str(node.get_attr('n_out')) + ' * 3'
+        mult_params_x['product_type'] = get_backend('quartus').product_type(node.get_input_variable().type.precision, node.get_weights('weight').type.precision)
+        mult_params_x['index'] = str(node.index) + '_x'
+        mult_config_x = self.mult_x_template.format(**mult_params_x)
+
+        # Multiplication config for matrix multiplications of type Wh (reset, update and candidate states)
+        mult_params_h = self._default_config_params(node)
+        mult_params_h['n_in'] = node.get_attr('n_out')
+        mult_params_h['n_out'] = str(node.get_attr('n_out')) + ' * 3'
+        mult_params_h['reuse_factor'] = params['recurrent_reuse_factor']
+        mult_params_h['product_type'] = get_backend('quartus').product_type(node.get_input_variable().type.precision, node.get_weights('recurrent_weight').type.precision)
+        mult_params_h['index'] = str(node.index) + '_h'
+        mult_config_h = self.mult_h_template.format(**mult_params_h)
+
+        return mult_config_x + '\n' + mult_config_h + '\n' + recr_act_config + '\n' + act_config + '\n' + gru_config
+
+class GRUFunctionTemplate(FunctionCallTemplate):
+    def __init__(self):
+        super().__init__(GRU, include_header=recurrent_include_list)
+        self.template = gru_function_template
+
+    def format(self, node):
+        params = self._default_function_params(node)
+        params['w'] = node.get_weights('weight').name
+        params['b'] = node.get_weights('bias').name
+        params['wr'] = node.get_weights('recurrent_weight').name
+        params['br'] = node.get_weights('recurrent_bias').name
+        return self.template.format(**params)
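To make the output of GRUConfigTemplate.format() easier to follow, the sketch below lists the struct names it stitches together for a hypothetical GRU node (index 3, activation='tanh', recurrent_activation='sigmoid'):

# Sketch (hypothetical node attributes) of the config struct names emitted above.
index = 3
emitted = [
    'config{}_x_mult'.format(index),           # dense config for the Wx multiplications
    'config{}_h_mult'.format(index),           # dense config for the Wh multiplications
    'sigmoid_config{}_rec_act'.format(index),  # reset/update gate activation config
    'tanh_config{}_act'.format(index),         # candidate-state activation config
    'config{}'.format(index),                  # top-level nnet::gru_config
]
print('\n'.join(emitted))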
Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+import numpy as np
+from hls4ml.model.optimizer import OptimizerPass
+from hls4ml.model.layers import Dense, GRU
+
+class ApplyResourceStrategy(OptimizerPass):
+    ''' Transposes the weights to use the dense_resource matrix multiply routine '''
+    def match(self, node):
+        node_matches = isinstance(node, (Dense, GRU))
+        is_resource_strategy = True  # node.get_attr('strategy', '').lower() == 'resource' ... Quartus only supports resource strategy
+        already_transformed = node.get_attr('_weights_transposed', False) == True
+        return node_matches and is_resource_strategy and not already_transformed
+
+    def transform(self, model, node):
+        if isinstance(node, Dense) and not node.model.config.get_compression(node):
+            rf = node.get_attr('reuse_factor')
+            bf = int((node.attributes['n_in']*node.attributes['n_out'])/rf)
+            bf_rounded = int(pow(2, np.ceil(np.log2(bf))))
+            rf_rounded = int(pow(2, np.ceil(np.log2(rf))))
+
+            node.weights['weight'].data = np.transpose(node.weights['weight'].data).flatten()
+
+            if (node.attributes['n_in']*node.attributes['n_out'] > 2048 and rf_rounded != rf):
+                node.set_attr('rfpad', rf_rounded-rf)
+                node.set_attr('bfpad', bf_rounded-bf)
+
+                temp = np.empty([bf_rounded, rf_rounded])
+                for i in range(rf_rounded):
+                    for j in range(bf_rounded):
+                        if (i < rf and j < bf):
+                            w_index = i + rf * j
+                            temp[j][i] = node.weights['weight'].data[w_index]
+                        else:
+                            temp[j][i] = 0
+                node.weights['weight'].data = temp.flatten()
+                node.weights['weight'].data_length = node.weights['weight'].data.size
+
+        elif isinstance(node, GRU):
+            node.weights['weight'].data = np.transpose(node.weights['weight'].data)
+            node.weights['recurrent_weight'].data = np.transpose(node.weights['recurrent_weight'].data)
+
+        else:
+            raise Exception('Unexpected layer {} with resource strategy'.format(node.class_name))
+
+        node.set_attr('_weights_transposed', True)
+        return False
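The padding applied by this pass rounds the reuse and block factors up to powers of two; a worked example with hypothetical layer sizes:

# Worked example (hypothetical layer sizes) of the power-of-two padding applied above.
import numpy as np
n_in, n_out, rf = 64, 64, 24                      # 64x64 Dense, reuse_factor 24
bf = int((n_in * n_out) / rf)                     # block factor = 170
rf_rounded = int(pow(2, np.ceil(np.log2(rf))))    # 32
bf_rounded = int(pow(2, np.ceil(np.log2(bf))))    # 256
print(rf_rounded - rf, bf_rounded - bf)           # rfpad = 8, bfpad = 86
# Padding kicks in because n_in*n_out = 4096 > 2048 and rf is not a power of two.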

hls4ml/backends/quartus/passes/transform_types.py

Lines changed: 3 additions & 4 deletions
@@ -1,14 +1,14 @@
 
 from hls4ml.model.optimizer import GlobalOptimizerPass
 from hls4ml.model.types import InplaceVariable
-from hls4ml.backends.fpga.fpga_types import ACTypeConverter, QuartusArrayVariableConverter, HLSTypeConverter, QuartusInplaceVariableConverter, QuartusStructMemberVariableConverter, StaticWeightVariableConverter
-
+from hls4ml.backends.fpga.fpga_types import ACTypeConverter, QuartusArrayVariableConverter, HLSTypeConverter, QuartusInplaceVariableConverter, QuartusStreamVariableConverter, QuartusStructMemberVariableConverter, StaticWeightVariableConverter
 
 class TransformTypes(GlobalOptimizerPass):
     def __init__(self):
         self.type_converter = HLSTypeConverter(precision_converter=ACTypeConverter())
         self.array_var_converter = QuartusArrayVariableConverter(type_converter=self.type_converter)
         self.struct_var_converter = QuartusStructMemberVariableConverter(type_converter=self.type_converter)
+        self.stream_var_converter = QuartusStreamVariableConverter(type_converter=self.type_converter)
         self.weight_var_converter = StaticWeightVariableConverter(type_converter=self.type_converter)
         self.inplace_var_converter = QuartusInplaceVariableConverter(type_converter=self.type_converter)
 
@@ -18,9 +18,8 @@ def transform(self, model, node):
         for out_name, var in node.variables.items():
             if isinstance(var, InplaceVariable):
                 new_var = self.inplace_var_converter.convert(var, io_type)
-
             if io_type == 'io_stream':
-                raise Exception('Streaming IO is not supported in Quartus.')
+                new_var = self.stream_var_converter.convert(var)
             elif io_type == 'io_parallel':
                 if node.name in node.model.inputs:
                     new_var = self.struct_var_converter.convert(var, pragma='hls_register', struct_name='inputs')
