Commit 0ea246c

Refactor matrix-multiplication kernel as a function pointer

1 parent f1a238d · commit 0ea246c

15 files changed: +362 −239 lines changed
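This commit replaces the per-config `resource_implementation` flag and `dense_unrolled` alias with a single compile-time "function pointer": each matrix-multiplication config struct now carries a `kernel` member alias naming the implementation to run, and every implementation (DenseLatency, the two DenseResource variants, and the generated dense_unrolled_<index> classes) exposes the same static dense() entry point via the DenseKernel base. A minimal sketch of the pattern follows; the DenseKernel/DenseLatency names and the `kernel` alias appear in the diff below, while the config values, the loop body, and the dispatching wrapper are illustrative assumptions.

```cpp
// Sketch only: how a config-selected kernel class acts as a compile-time
// function pointer. Not code from this commit.
template <class data_T, class res_T, typename CONFIG_T>
class DenseKernel {
  public:
    // Shared interface: every dense implementation provides static dense().
    static void dense(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out],
                      typename CONFIG_T::weight_t weights[CONFIG_T::n_in * CONFIG_T::n_out],
                      typename CONFIG_T::bias_t biases[CONFIG_T::n_out]);
};

template <class data_T, class res_T, typename CONFIG_T>
class DenseLatency : public DenseKernel<data_T, res_T, CONFIG_T> {
  public:
    static void dense(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out],
                      typename CONFIG_T::weight_t weights[CONFIG_T::n_in * CONFIG_T::n_out],
                      typename CONFIG_T::bias_t biases[CONFIG_T::n_out]) {
        // Stand-in multiply-accumulate; the real implementation lives in the
        // nnet_utils headers and carries HLS pragmas.
        for (unsigned j = 0; j < CONFIG_T::n_out; j++) {
            typename CONFIG_T::accum_t acc = biases[j];
            for (unsigned i = 0; i < CONFIG_T::n_in; i++)
                acc += data[i] * weights[j * CONFIG_T::n_in + i];
            res[j] = acc;
        }
    }
};

// A generated config selects the implementation once, at compile time
// (values invented for the example).
struct config_example {
    static const unsigned n_in = 16;
    static const unsigned n_out = 8;
    typedef float accum_t;
    typedef float weight_t;
    typedef float bias_t;
    template <class data_T, class res_T, class CONFIG_T>
    using kernel = DenseLatency<data_T, res_T, CONFIG_T>;
};

// Call sites dispatch through the alias instead of branching on a flag.
template <class data_T, class res_T, typename CONFIG_T>
void dense(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out],
           typename CONFIG_T::weight_t weights[CONFIG_T::n_in * CONFIG_T::n_out],
           typename CONFIG_T::bias_t biases[CONFIG_T::n_out]) {
    CONFIG_T::template kernel<data_T, res_T, CONFIG_T>::dense(data, res, weights, biases);
}
```

Swapping strategies then means regenerating one `using kernel = ...` line rather than threading a runtime flag through every template.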

hls4ml/backends/fpga/passes/codegen.py

Lines changed: 97 additions & 70 deletions

```diff
@@ -2,7 +2,7 @@
 
 import numpy as np
 
-from hls4ml.model.layers import Conv1D, Conv2D, Dense
+from hls4ml.model.layers import GRU, LSTM, Conv1D, Conv2D, Dense
 from hls4ml.model.optimizer import OptimizerPass
 from hls4ml.model.types import Source
 
@@ -60,8 +60,8 @@ class GenerateUnrolledDenseResource(OptimizerPass):
 
     def match(self, node):
         # Only apply to layers use that use Dense Matrix Multiplication
-        # TODO - Extend (& test) for Conv1D / Separable Conv / Depthwise Conv / Recurrent layers
-        layers_with_dense = (Dense, Conv2D)
+        # TODO - Extend (& test) for Separable Conv / Depthwise Conv / Recurrent layers
+        layers_with_dense = (Dense, Conv1D, Conv2D, LSTM, GRU)
 
         # Unrolled Dense mimicks the hardware implementation of Resource strategy -> apply after Resource optimizer
         weights_transposed = node.get_attr('_weights_transposed', False)
@@ -70,23 +70,43 @@ def match(self, node):
         rf_gt_one = node.get_attr('reuse_factor', 1) > 1
 
         # User requested unrolled implementation of Dense
-        is_unrolled = node.get_attr('dense_resource_implementation', 'standard') == 'unrolled'
+        is_unrolled = node.get_attr('strategy', 'latency') == 'unrolled'
 
         return isinstance(node, layers_with_dense) and weights_transposed and rf_gt_one and is_unrolled
 
     def transform(self, model, node):
-        code_str = self.__generate_unrolled_dense_resource(model, node)
-        node.set_attr('unrolled_dense_resource_codegen', Source(code_str))
+        if isinstance(node, (LSTM, GRU)):
+            n_in, n_out, n_in_recr, n_out_recr = node.model.config.backend.get_layer_mult_size(node)
 
-    def __generate_unrolled_dense_resource(self, model, node):
+            reuse_factor = node.get_attr('reuse_factor')
+            weights = node.weights['weight']
+            code_str = self._generate_unrolled_function(n_in, n_out, reuse_factor, weights, str(node.index) + '_1')
+            node.set_attr('unrolled_dense_resource_codegen_1', Source(code_str))
+
+            recr_reuse_factor = node.get_attr('recurrent_reuse_factor')
+            recr_weights = node.weights['recurrent_weight']
+            code_str = self._generate_unrolled_function(
+                n_in_recr, n_out_recr, recr_reuse_factor, recr_weights, str(node.index) + '_2'
+            )
+            node.set_attr('unrolled_dense_resource_codegen_2', Source(code_str))
+
+        else:
+            n_in, n_out = node.model.config.backend.get_layer_mult_size(node)
+            reuse_factor = node.get_attr('reuse_factor')
+            weights = node.weights['weight']
+
+            code_str = self._generate_unrolled_function(n_in, n_out, reuse_factor, weights, node.index)
+            node.set_attr('unrolled_dense_resource_codegen', Source(code_str))
+
+    def _generate_unrolled_function(self, n_in, n_out, reuse_factor, weights, function_suffix):
         """
         Generate a C++ function that mimics the Dense Resource implementation.
 
         The HLS compiler produces suboptimal designs for Dense Resource when the weights processed by the same DSP are zero.
-        Latency strategy can optimize zero mutiplications
+        Latency strategy can optimize zero multiplications
         Resource strategy, on the other hand, cannot.
         When all the weights in the same BRAM block are zero, Vivado is unable to optimize it
-        With this (and additional TCL scripts) zero BRAM are optimised
+        With this (and additional TCL scripts) zero BRAM are optimized
 
         Args:
             node: Layer to generate code for
@@ -96,61 +116,58 @@ def __generate_unrolled_dense_resource(self, model, node):
 
         # Variable instantiation and function pragmas
        generated_code = (
-            "template<class data_T, class res_T, typename CONFIG_T>\n"
-            "class dense_unrolled_{index} : public DenseResourceUnrolled<data_T, res_T, CONFIG_T> {{\n"
-            "    public:\n"
-            "    static void dense_unrolled(\n"
-            "        data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out],\n"
-            "        typename CONFIG_T::weight_t weights[CONFIG_T::n_in * CONFIG_T::n_out],\n"
-            "        typename CONFIG_T::bias_t biases[CONFIG_T::n_out]\n"
-            "    ) {{\n"
-            "        #pragma HLS pipeline II=CONFIG_T::reuse_factor\n"
-            "\n"
-            "        constexpr int block_factor = DIV_ROUNDUP(CONFIG_T::n_in * CONFIG_T::n_out, CONFIG_T::reuse_factor);\n"
-            "        #pragma HLS function_instantiate variable=weights,biases\n"
-            "        #pragma HLS ARRAY_RESHAPE variable=weights block factor=block_factor\n"
-            "        #pragma HLS RESOURCE variable=weights core=ROM_nP_BRAM\n"
-            "        #pragma HLS ARRAY_PARTITION variable=biases complete\n"
-            "\n"
-            "        typename CONFIG_T::accum_t acc[CONFIG_T::n_out];\n"
-            "        #pragma HLS ARRAY_PARTITION variable=acc complete\n"
-            "\n"
-            "    InitAccum:\n"
-            "        for (int i = 0; i < CONFIG_T::n_out; i++) {{\n"
-            "            #pragma HLS UNROLL\n"
-            "            acc[i] = (typename CONFIG_T::accum_t) biases[i];\n"
-            "        }}\n"
-            "\n"
-        ).format(index=node.index)
+            'template<class data_T, class res_T, typename CONFIG_T>\n'
+            'class dense_unrolled_{suffix} : public DenseKernel<data_T, res_T, CONFIG_T> {{\n'
+            '    public:\n'
+            '    static void dense(\n'
+            '        data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_out],\n'
+            '        typename CONFIG_T::weight_t weights[CONFIG_T::n_in * CONFIG_T::n_out],\n'
+            '        typename CONFIG_T::bias_t biases[CONFIG_T::n_out]\n'
+            '    ) {{\n'
+            '        #pragma HLS pipeline II=CONFIG_T::reuse_factor\n'
+            '\n'
+            '        constexpr int block_factor = DIV_ROUNDUP(CONFIG_T::n_in * CONFIG_T::n_out, CONFIG_T::reuse_factor);\n'
+            '        #pragma HLS function_instantiate variable=weights,biases\n'
+            '        #pragma HLS ARRAY_RESHAPE variable=weights block factor=block_factor\n'
+            '        #pragma HLS RESOURCE variable=weights core=ROM_nP_BRAM\n'
+            '        #pragma HLS ARRAY_PARTITION variable=biases complete\n'
+            '\n'
+            '        typename CONFIG_T::accum_t acc[CONFIG_T::n_out];\n'
+            '        #pragma HLS ARRAY_PARTITION variable=acc complete\n'
+            '\n'
+            '    InitAccum:\n'
+            '        for (int i = 0; i < CONFIG_T::n_out; i++) {{\n'
+            '            #pragma HLS UNROLL\n'
+            '            acc[i] = (typename CONFIG_T::accum_t) biases[i];\n'
+            '        }}\n'
+            '\n'
+        ).format(suffix=function_suffix)
 
         # Unrolled multiplication, according to the three cases
-        n_in, n_out = node.model.config.backend.get_layer_mult_size(node)
-        reuse_factor = node.get_attr('reuse_factor')
-        weights = node.weights['weight']
         if reuse_factor <= n_in:
-            mult_code = self.__generate_unrolled_mult_code_rf_leq_nin(n_in, n_out, reuse_factor, weights)
+            mult_code = self._generate_unrolled_mult_code_rf_leq_nin(n_in, n_out, reuse_factor, weights)
         elif reuse_factor > n_in and reuse_factor % n_in == 0:
-            mult_code = self.__generate_unrolled_mult_code_rf_gt_nin_rem0(n_in, n_out, reuse_factor, weights)
+            mult_code = self._generate_unrolled_mult_code_rf_gt_nin_rem0(n_in, n_out, reuse_factor, weights)
         else:
             # This case shouldn't happen if my understanding of RF is correct
             # The function fpga_backend._validate_reuse_factor() has assertion rf % n_in == 0 or rf < n_in
             raise Exception('Not implemented...')
 
         # Write output
-        generated_code += mult_code + "\n"
+        generated_code += mult_code + '\n'
         generated_code += (
-            "    Result:\n"
-            "    for (int i = 0; i < CONFIG_T::n_out; i++) {\n"
-            "        #pragma HLS UNROLL\n"
-            "        res[i] = cast<data_T, res_T, CONFIG_T>(acc[i]);\n"
-            "    }\n"
-            "    }\n"
-            "};\n"
+            '    Result:\n'
+            '        for (int i = 0; i < CONFIG_T::n_out; i++) {\n'
+            '            #pragma HLS UNROLL\n'
+            '            res[i] = cast<data_T, res_T, CONFIG_T>(acc[i]);\n'
+            '        }\n'
+            '    }\n'
+            '};\n'
         )
 
         return generated_code
 
-    def __generate_unrolled_mult_code_rf_leq_nin(self, n_in, n_out, reuse_factor, weights):
+    def _generate_unrolled_mult_code_rf_leq_nin(self, n_in, n_out, reuse_factor, weights):
         # Function constants
         mult_factor = min(n_in, reuse_factor)
         block_factor = int(math.ceil(n_in * n_out / reuse_factor))
@@ -162,24 +179,29 @@ def __generate_unrolled_mult_code_rf_leq_nin(self, n_in, n_out, reuse_factor, weights):
         # The new shape is (parallel_mult, reuse_factor)
         zeros = np.sum(~weights.data.reshape(block_factor, reuse_factor).any(1))
 
+        # Used to pad the code to make it human-readable
+        indent = '    '
+
         # Generate unrolled multiplications
-        mult_code = f"\t\t#pragma HLS ALLOCATION operation instances=mul limit={mult_limit - zeros}\n"
-        mult_code += "\t\tMULT: {\n"
-        mult_code += "\t\t\t#pragma HLS protocol\n"
+        mult_code = f'{indent*2}#pragma HLS ALLOCATION operation instances=mul limit={mult_limit - zeros}\n'
+        mult_code += f'{indent*2}MULT: {{\n'
+        mult_code += f'{indent*3}#pragma HLS protocol\n'
 
         for ir in range(reuse_factor):
             acc_step = 0
             out_index = 0
             w_index = ir
             in_index = ir
 
-            mult_code += f"\t\t\tM{ir}: {{\n"
+            mult_code += f'{indent*3}M{ir}: {{\n'
             for _ in range(block_factor):
                 if weights.data.flatten()[w_index] != 0:
-                    mult_code += f"\t\t\t\tacc[{out_index}] += \
-                        static_cast<typename CONFIG_T::accum_t>\
-                        (CONFIG_T::template product<data_T, typename CONFIG_T::weight_t>::\
-                        product(data[{in_index}], weights[{w_index}]));\n"
+                    mult_code += (
+                        f'{indent*4}acc[{out_index}] += '
+                        'static_cast<typename CONFIG_T::accum_t>'
+                        '(CONFIG_T::template product<data_T, typename CONFIG_T::weight_t>::'
+                        f'product(data[{in_index}], weights[{w_index}]));\n'
+                    )
 
                 w_index += reuse_factor
                 in_index += reuse_factor
@@ -191,13 +213,13 @@ def __generate_unrolled_mult_code_rf_leq_nin(self, n_in, n_out, reuse_factor, weights):
             else:
                 acc_step += 1
 
-            mult_code += "\t\t\t}\n"
+            mult_code += f'{indent*3}}}\n'
 
-        mult_code += "\t\t}\n"
+        mult_code += f'{indent*2}}}\n'
 
         return mult_code
 
-    def __generate_unrolled_mult_code_rf_gt_nin_rem0(self, n_in, n_out, reuse_factor, weights):
+    def _generate_unrolled_mult_code_rf_gt_nin_rem0(self, n_in, n_out, reuse_factor, weights):
         # Function constants
         mult_factor = min(n_in, reuse_factor)
         block_factor = int(math.ceil(n_in * n_out / reuse_factor))
@@ -208,6 +230,9 @@ def __generate_unrolled_mult_code_rf_gt_nin_rem0(self, n_in, n_out, reuse_factor, weights):
         # The new shape is (parallel_mult, reuse_factor)
         zeros = np.sum(~weights.data.reshape(block_factor, reuse_factor).any(1))
 
+        # Used to pad the code to make it human-readable
+        indent = '    '
+
         # Generate out indices
         outidx = [0] * reuse_factor
         outstep = 0
@@ -221,32 +246,34 @@ def __generate_unrolled_mult_code_rf_gt_nin_rem0(self, n_in, n_out, reuse_factor, weights):
         in_index = 0
 
         # Generate unrolled multiplications
-        mult_code = f"\t\t#pragma HLS ALLOCATION operation instances=mul limit={mult_limit - zeros}\n"
-        mult_code += "\t\tMULT: {\n"
-        mult_code += "\t\t\t#pragma HLS protocol\n"
+        mult_code = f'{indent*2}#pragma HLS ALLOCATION operation instances=mul limit={mult_limit - zeros}\n'
+        mult_code += f'{indent*2}MULT: {{\n'
+        mult_code += f'{indent*3}#pragma HLS protocol\n'
 
         for ir in range(reuse_factor):
             w_index = ir
             out_index = outidx[ir]
 
-            mult_code += f"\t\t\tM{ir}: {{\n"
+            mult_code += f'{indent*3}M{ir}: {{\n'
             for _ in range(block_factor):
                 if weights.data.flatten()[w_index] != 0:
-                    mult_code += f"\t\t\t\tacc[{int(out_index)}] += \
-                        static_cast<typename CONFIG_T::accum_t>\
-                        (CONFIG_T::template product<data_T, typename CONFIG_T::weight_t>::\
-                        product(data[{in_index}], weights[{w_index}]));\n"
+                    mult_code += (
+                        f'{indent*4}acc[{int(out_index)}] += '
+                        'static_cast<typename CONFIG_T::accum_t>'
+                        '(CONFIG_T::template product<data_T, typename CONFIG_T::weight_t>::'
+                        f'product(data[{in_index}], weights[{w_index}]));\n'
+                    )
 
                 w_index += reuse_factor
                 if w_index > n_in * n_out:
                     break
                 out_index += outscale
-            mult_code += "\t\t\t}\n"
+            mult_code += f'{indent*3}}}\n'
 
             in_index += 1
             if in_index >= n_in:
                 in_index = 0
 
-        mult_code += "\t\t}\n"
+        mult_code += f'{indent*2}}}\n'
 
         return mult_code
```
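For LSTM/GRU layers the pass above now emits two generated classes per layer (suffixes `_1` and `_2`, covering the input and recurrent matrix multiplications); other layers get one. Below is a compilable toy sketch, with invented sizes and plain C++ standing in for the nnet:: helpers and HLS pragmas, of what one generated body computes for n_in = 2, n_out = 2, reuse_factor = 2 when a weight is zero: the product statement is simply never emitted, which is what lets the zeroed DSPs and BRAM blocks be pruned downstream.

```cpp
// Toy rendering of a generated unrolled kernel body (grouping and the
// transposed layout weights[out * n_in + in] are assumptions for this
// example; this is not output taken verbatim from the commit).
#include <cstdio>

struct toy_config {
    static const unsigned n_in = 2;
    static const unsigned n_out = 2;
    static const unsigned reuse_factor = 2;
    typedef float accum_t;
};

template <class CONFIG_T>
void dense_unrolled_toy(const float data[CONFIG_T::n_in], float res[CONFIG_T::n_out],
                        const float weights[CONFIG_T::n_in * CONFIG_T::n_out],
                        const float biases[CONFIG_T::n_out]) {
    typename CONFIG_T::accum_t acc[CONFIG_T::n_out];
    for (unsigned i = 0; i < CONFIG_T::n_out; i++)
        acc[i] = biases[i];  // InitAccum

    // RF = 2 splits the 4 products into groups M0/M1 of
    // block_factor = n_in * n_out / RF = 2 statements each.
    // M0:
    acc[0] += static_cast<typename CONFIG_T::accum_t>(data[0] * weights[0]);
    acc[1] += static_cast<typename CONFIG_T::accum_t>(data[0] * weights[2]);
    // M1: weights[1] == 0 in this example, so the generator skips
    // "acc[0] += data[1] * weights[1]" entirely.
    acc[1] += static_cast<typename CONFIG_T::accum_t>(data[1] * weights[3]);

    for (unsigned i = 0; i < CONFIG_T::n_out; i++)
        res[i] = acc[i];  // Result (the real code routes through cast<>)
}

int main() {
    const float data[2] = {1.0f, 2.0f};
    const float weights[4] = {0.5f, 0.0f, -1.0f, 0.25f};  // weights[1] == 0
    const float biases[2] = {0.0f, 0.0f};
    float res[2];
    dense_unrolled_toy<toy_config>(data, res, weights, biases);
    std::printf("res = {%g, %g}\n", res[0], res[1]);  // res = {0.5, -0.5}
    return 0;
}
```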

hls4ml/backends/vivado/passes/convolution_templates.py

Lines changed: 26 additions & 30 deletions

```diff
@@ -17,14 +17,13 @@
     static const unsigned n_out = {n_out};
     static const unsigned reuse_factor = {reuse};
     static const unsigned strategy = nnet::{strategy};
-    static const unsigned resource_implementation = nnet::{dense_resource_implementation};
-    template<class data_T, class res_T, class CONFIG_T>
-    using dense_unrolled = nnet::{unrolled_function}<data_T, res_T, CONFIG_T>;
     static const unsigned n_zeros = {nzeros};
     static const unsigned multiplier_limit = DIV_ROUNDUP(n_in * n_out, reuse_factor) - n_zeros / reuse_factor;
     typedef {accum_t.name} accum_t;
     typedef {bias_t.name} bias_t;
     typedef {weight_t.name} weight_t;
+    template<class data_T, class res_T, class CONFIG_T>
+    using kernel = nnet::{dense_function}<data_T, res_T, CONFIG_T>;
     template<class x_T, class y_T>
     using product = nnet::product::{product_type}<x_T, y_T>;
 }};\n"""
@@ -49,9 +48,6 @@
     static const bool store_weights_in_bram = false;
     static const unsigned strategy = nnet::{strategy};
     static const nnet::conv_implementation implementation = nnet::conv_implementation::{implementation};
-    static const unsigned resource_implementation = nnet::{dense_resource_implementation};
-    template<class data_T, class res_T, class CONFIG_T>
-    using dense_unrolled = nnet::{unrolled_function}<data_T, res_T, CONFIG_T>;
     static const unsigned min_width = {min_width};
     static const ap_uint<filt_width> pixels[min_width];
     static const unsigned n_partitions = {n_partitions};
@@ -96,8 +92,6 @@ def format(self, node):
             params['fill_fn'] = f'fill_buffer_{node.index}'
         else:
             params['fill_fn'] = 'FillConv1DBuffer'
-        # TODO - Extend unrolled Dense Resource to Conv1D
-        params['unrolled_function'] = 'DenseResourceUnrolled'
 
         conv_config = self.template.format(**params)
 
@@ -108,8 +102,18 @@ def format(self, node):
         mult_params['product_type'] = get_backend('vivado').product_type(
             node.get_input_variable().type.precision, node.get_weights('weight').type.precision
         )
-        # TODO - Extend unrolled Dense Resource to Conv1D
-        mult_params['unrolled_function'] = 'DenseResourceUnrolled'
+
+        if node.get_attr('strategy').lower() == 'latency':
+            mult_params['dense_function'] = 'DenseLatency'
+        elif node.get_attr('strategy').lower() == 'resource':
+            if int(mult_params['reuse_factor']) <= int(mult_params['n_in']):
+                mult_params['dense_function'] = 'DenseResource_rf_leq_nin'
+            else:
+                mult_params['dense_function'] = 'DenseResource_rf_gt_nin_rem0'
+            # The 3rd case is never used
+        elif node.get_attr('strategy').lower() == 'unrolled':
+            mult_params['dense_function'] = f'dense_unrolled_{node.index}'
+
         mult_config = self.mult_template.format(**mult_params)
 
         return mult_config + '\n' + conv_config
@@ -160,9 +164,6 @@ def __init__(self):
     static const bool store_weights_in_bram = false;
     static const unsigned strategy = nnet::{strategy};
     static const nnet::conv_implementation implementation = nnet::conv_implementation::{implementation};
-    static const unsigned resource_implementation = nnet::{dense_resource_implementation};
-    template<class data_T, class res_T, class CONFIG_T>
-    using dense_unrolled = nnet::{unrolled_function}<data_T, res_T, CONFIG_T>;
     static const unsigned min_height = {min_height};
     static const unsigned min_width = {min_width};
     static const ap_uint<filt_height * filt_width> pixels[min_height * min_width];
@@ -217,15 +218,6 @@ def format(self, node):
         else:
             params['fill_fn'] = 'FillConv2DBuffer'
 
-        if (
-            node.get_attr('dense_resource_implementation', 'standard') == 'unrolled'
-            and node.get_attr('strategy').lower() == 'resource'
-            and node.get_attr('reuse_factor') > 1
-        ):
-            params['unrolled_function'] = f'dense_unrolled_{node.index}'
-        else:
-            params['unrolled_function'] = 'DenseResourceUnrolled'
-
         conv_config = self.template.format(**params)
 
         mult_params = self._default_config_params(node)
@@ -235,14 +227,18 @@ def format(self, node):
         mult_params['product_type'] = get_backend('vivado').product_type(
             node.get_input_variable().type.precision, node.get_weights('weight').type.precision
         )
-        if (
-            node.get_attr('dense_resource_implementation', 'standard') == 'unrolled'
-            and node.get_attr('strategy').lower() == 'resource'
-            and node.get_attr('reuse_factor') > 1
-        ):
-            mult_params['unrolled_function'] = f'dense_unrolled_{node.index}'
-        else:
-            mult_params['unrolled_function'] = 'DenseResourceUnrolled'
+
+        if node.get_attr('strategy').lower() == 'latency':
+            mult_params['dense_function'] = 'DenseLatency'
+        elif node.get_attr('strategy').lower() == 'resource':
+            if int(mult_params['reuse_factor']) <= int(mult_params['n_in']):
+                mult_params['dense_function'] = 'DenseResource_rf_leq_nin'
+            else:
+                mult_params['dense_function'] = 'DenseResource_rf_gt_nin_rem0'
+            # The 3rd case is never used
+        elif node.get_attr('strategy').lower() == 'unrolled':
+            mult_params['dense_function'] = f'dense_unrolled_{node.index}'
+
         mult_config = self.mult_template.format(**mult_params)
 
         return mult_config + '\n' + conv_config
```
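With the flag and the dedicated `dense_unrolled` alias gone, selecting an implementation is just a matter of which name the backend writes into `{dense_function}`. A formatted multiplier config for a hypothetical layer with index 4 under the Resource strategy with RF ≤ n_in would then read roughly as below (sizes and ap_fixed types invented for the example):

```cpp
// Illustrative rendering of the updated mult template; not verbatim output.
struct config4_mult {
    static const unsigned n_in = 72;
    static const unsigned n_out = 16;
    static const unsigned reuse_factor = 8;
    static const unsigned strategy = nnet::resource;
    static const unsigned n_zeros = 0;
    static const unsigned multiplier_limit = DIV_ROUNDUP(n_in * n_out, reuse_factor) - n_zeros / reuse_factor;
    typedef ap_fixed<16, 6> accum_t;
    typedef ap_fixed<16, 6> bias_t;
    typedef ap_fixed<16, 6> weight_t;
    template<class data_T, class res_T, class CONFIG_T>
    using kernel = nnet::DenseResource_rf_leq_nin<data_T, res_T, CONFIG_T>;
    template<class x_T, class y_T>
    using product = nnet::product::mult<x_T, y_T>;
};
```

The conv wrapper then instantiates `config4_mult::kernel<...>::dense(...)` without needing to know which variant was chosen.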
