
Commit 329c2d8

Merge remote-tracking branch 'upstream/main' into sep_to_dw_point

2 parents: 13b6dbb + 5bcba68


51 files changed: +1673 / -309 lines

.gitlab-ci.yml

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ generator:
   stage: generate
   image: python:3.8-alpine
   variables:
-    N_TESTS_PER_YAML: 5
+    N_TESTS_PER_YAML: 4
   tags:
     - k8s-default
   before_script:

.pre-commit-config.yaml

Lines changed: 3 additions & 3 deletions
@@ -2,7 +2,7 @@ exclude: (^hls4ml\/templates\/(vivado|quartus)\/(ap_types|ac_types)\/|^test/pyte

 repos:
 - repo: https://github.com/psf/black
-  rev: 24.4.2
+  rev: 24.8.0
   hooks:
   - id: black
     language_version: python3
@@ -30,7 +30,7 @@ repos:
     args: ["--profile", "black", --line-length=125]

 - repo: https://github.com/asottile/pyupgrade
-  rev: v3.16.0
+  rev: v3.17.0
   hooks:
   - id: pyupgrade
     args: ["--py36-plus"]
@@ -41,7 +41,7 @@ repos:
   - id: setup-cfg-fmt

 - repo: https://github.com/pycqa/flake8
-  rev: 7.1.0
+  rev: 7.1.1
   hooks:
   - id: flake8
     exclude: docs/conf.py

hls4ml/backends/catapult/passes/pointwise.py

Lines changed: 0 additions & 5 deletions
@@ -1,7 +1,5 @@
 from copy import copy

-import numpy as np
-
 from hls4ml.backends.catapult.passes.convolution_templates import (
     Conv1DConfigTemplate,
     Conv1DFunctionTemplate,
@@ -78,9 +76,6 @@ def match(self, node):
     def transform(self, model, node):
         dim = node.__class__.__name__[-2:]  # '1D' or '2D'
         pw_node = model.make_node('PointwiseConv' + dim, node.name, copy(node.attributes), node.inputs.copy())
-        if len(node.weights['weight'].data.shape) == 2:  # This can happen if we assign weights of Dense layer to 1x1 Conv2D
-            expand_axis = tuple(range(int(dim[0])))
-            pw_node.weights['weight'].data = np.expand_dims(node.weights['weight'].data, axis=expand_axis)
         pw_node.weights['bias'].data = node.weights['bias'].data
         # Set strategy to ensure lowercase string is passed to the template
         if model.config.is_resource_strategy(pw_node):
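For context, the branch deleted here (and mirrored in the Quartus pointwise pass below) padded Dense-style 2D weight arrays up to the pointwise-convolution kernel rank. A minimal standalone sketch of that reshaping, with shapes invented purely for illustration:

import numpy as np

# Hypothetical Dense weights of shape (n_chan, n_filt) assigned to a 1x1 Conv2D
weights_2d = np.ones((16, 32))

dim = '2D'                               # last two characters of the layer class name
expand_axis = tuple(range(int(dim[0])))  # (0, 1) for '2D', (0,) for '1D'
weights_4d = np.expand_dims(weights_2d, axis=expand_axis)

print(weights_4d.shape)  # (1, 1, 16, 32): a 1x1 kernel wrapping the original matrix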
Lines changed: 107 additions & 0 deletions (new file)

@@ -0,0 +1,107 @@
import numpy as np

from hls4ml.backends import Backend
from hls4ml.backends.template import FunctionCallTemplate
from hls4ml.model.layers import Layer
from hls4ml.model.optimizer import OptimizerPass
from hls4ml.model.optimizer.passes.hgq_proxy_model import FixedPointQuantizer, UnaryLUT
from hls4ml.model.types import Source


def to_apfixed(k, b, i, RND, SAT):
    u = 'u' if k == 0 else ''
    return f'ap_{u}fixed<{b},{i},AP_{RND},AP_{SAT}>'


def to_acfixed(k, b, i, RND, SAT):
    k = 'false' if k == 0 else 'true'
    return f'ac_fixed<{b},{i},{k},AC_{RND},AC_{SAT}>'


def generate_mask_fn(
    name: str, shape: tuple[int, ...], k: np.ndarray, b: np.ndarray, i: np.ndarray, RND: str, SAT: str, backend: str
) -> str:
    """Generate heterogenous quantization mask function, ONLY works for IOType=io_parallel"""
    assert k.shape[0] == b.shape[0] == i.shape[0] == 1
    assert backend.lower() in ('quartus', 'vivado', 'vitis'), f'Backend {backend} not tested'
    Ks, Bs, Is = k[0], b[0], i[0]
    Ks, Bs, Is = np.broadcast_to(Ks, shape), np.broadcast_to(Bs, shape), np.broadcast_to(Is, shape)
    Ks, Bs, Is = Ks.ravel(), Bs.ravel(), Is.ravel()
    masks = []
    to_fixed = to_acfixed if backend.lower() == 'quartus' else to_apfixed
    for idx, (k, b, i) in enumerate(zip(Ks, Bs, Is)):
        if b == 0:
            fn = f'out[{idx}] = 0;'
        else:
            fn = f'out[{idx}] = {to_fixed(k, b, i, RND, SAT)}(inp[{idx}]);'
        masks.append(f'    {fn}')
    body = "\n".join(masks)
    mask_fn = f'''
template<typename input_t, typename output_t>
void {name}(input_t *inp, output_t *out) {{
    #pragma HLS INLINE

{body}
}}
'''
    return mask_fn


class ProcessFixedPointQuantizerLayer(OptimizerPass):
    def match(self, node: Layer):
        return isinstance(node, FixedPointQuantizer)

    def transform(self, model, node: FixedPointQuantizer):
        if node.fusible:
            model.remove_node(node, rewire=True)
            return True

        if model.config.config['IOType'] != 'io_parallel':
            raise NotImplementedError('Heterogenous quantization for activations is only supported with IOType=io_parallel')

        backend = model.config.config['Backend']

        name = node.name

        assert node.mask_kbi is not None
        k, b, i = node.mask_kbi
        RND = node.RND
        SAT = node.SAT
        mask_fn: str = generate_mask_fn(name, node.get_input_variable().shape, k, b, i, RND, SAT, backend)

        node.set_attr('mask_fn_codegen', Source(mask_fn))


class ProcessFixedPointQuantizerCall(FunctionCallTemplate):
    def __init__(self):
        super().__init__(FixedPointQuantizer, include_header=[])
        self.template = 'nnet::{name}<{input_t}, {output_t}>({input}, {output});'

    def format(self, node):
        params = self._default_function_params(node)

        return self.template.format(**params)


class ProcessUnaryLUTCall(FunctionCallTemplate):
    def __init__(self):
        super().__init__(UnaryLUT, include_header=[])
        self.template = 'nnet::unary_lut<{input_t}, {output_t}, {config}>({input}, {output}, {table});'
        self.include_header = [
            'nnet_utils/nnet_activation.h',
            'nnet_utils/nnet_activation_stream.h',
        ]

    def format(self, node):
        params = self._default_function_params(node)
        node.attributes['result_t'].precision = node.attributes['table_t'].precision
        params['config'] = f'unary_lut_config{node.index}'
        params['table'] = node.get_weights('table').name

        return self.template.format(**params)


def register_hgq_proxy_model(backend: Backend):
    backend.register_pass('process_fixed_point_quantizer_layer', ProcessFixedPointQuantizerLayer)
    backend.register_template(ProcessFixedPointQuantizerCall)
    backend.register_template(ProcessUnaryLUTCall)
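A minimal sketch of how generate_mask_fn could be exercised; the mask values, tensor shape, and function name below are invented for illustration, not taken from the commit:

import numpy as np

# Per-element masks for a 2-element tensor: element 0 keeps an unsigned
# 8-bit / 3-integer-bit quantizer, element 1 has zero bits and is masked to 0.
k = np.array([[0, 0]])   # signedness per element
b = np.array([[8, 0]])   # total bits per element
i = np.array([[3, 0]])   # integer bits per element

cpp = generate_mask_fn('quantizer_mask', (2,), k, b, i, 'RND', 'SAT', backend='vivado')

# The returned string is a C++ function along the lines of:
#   template<typename input_t, typename output_t>
#   void quantizer_mask(input_t *inp, output_t *out) {
#       #pragma HLS INLINE
#       out[0] = ap_ufixed<8,3,AP_RND,AP_SAT>(inp[0]);
#       out[1] = 0;
#   }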

hls4ml/backends/quartus/passes/core_templates.py

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,7 @@
 from hls4ml.backends.backend import get_backend
 from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate
 from hls4ml.model.layers import Activation, BatchNormalization, Dense, HardActivation, ParametrizedActivation, PReLU, Softmax
+from hls4ml.model.optimizer.passes.hgq_proxy_model import UnaryLUT

 # Dense templates

@@ -152,7 +153,7 @@ def format(self, node):

 class ActivationConfigTemplate(LayerConfigTemplate):
     def __init__(self):
-        super().__init__((Activation, ParametrizedActivation, PReLU))
+        super().__init__((Activation, ParametrizedActivation, PReLU, UnaryLUT))
         self.template = activ_config_template

     def format(self, node):

hls4ml/backends/quartus/passes/pointwise.py

Lines changed: 0 additions & 5 deletions
@@ -1,7 +1,5 @@
 from copy import copy

-import numpy as np
-
 from hls4ml.backends.fpga.fpga_layers import PointwiseConv1D, PointwiseConv2D
 from hls4ml.backends.quartus.passes.convolution_templates import (
     Conv1DConfigTemplate,
@@ -86,9 +84,6 @@ def transform(self, model, node):
         pw_node = model.make_node(
             'PointwiseConv' + dim, node.name, copy(node.attributes), node.inputs.copy(), outputs=node.outputs.copy()
         )
-        if len(node.weights['weight'].data.shape) == 2:  # This can happen if we assign weights of Dense layer to 1x1 Conv2D
-            expand_axis = tuple(range(int(dim[0])))
-            pw_node.weights['weight'].data = np.expand_dims(node.weights['weight'].data, axis=expand_axis)
         pw_node.weights['bias'].data = node.weights['bias'].data
         model.replace_node(node, pw_node)

hls4ml/backends/quartus/passes/recurrent_templates.py

Lines changed: 10 additions & 1 deletion
@@ -66,6 +66,7 @@
     using activation_recr = nnet::activation::{recurrent_activation}<x_T, y_T, config_T>;

     static const unsigned reuse_factor = {reuse};
+    static const unsigned pytorch_order = {pytorch};
     static const bool store_weights_in_bram = false;
 }};\n'''

@@ -92,6 +93,7 @@ def format(self, node):
         params['config_mult_h'] = f'config{node.index}_h_mult'
         params['act_t'] = '{}_config{}'.format(node.get_attr('activation'), str(node.index) + '_act')
         params['act_recurrent_t'] = '{}_config{}'.format(node.get_attr('recurrent_activation'), str(node.index) + '_rec_act')
+        params['pytorch'] = 'true' if node.get_attr('pytorch', False) else 'false'
         gru_config = self.gru_template.format(**params)

         # Activation is on candidate hidden state, dimensionality (1, n_units)
@@ -256,6 +258,9 @@ def format(self, node):
 }};\n"""

 simple_rnn_function_template = 'nnet::simple_rnn<{input_t}, {output_t}, {config}>({input}, {output}, {weights});'
+simple_rnn_pytorch_function_template = (
+    'nnet::simple_rnn_pytorch<{input_t}, {output_t}, {config}>({input}, {output}, {weights});'
+)


 class SimpleRNNConfigTemplate(LayerConfigTemplate):
@@ -301,5 +306,9 @@ def __init__(self):

     def format(self, node):
         params = self._default_function_params(node)
-        params['weights'] = 'w{0}, wr{0}, b{0}'.format(str(node.index))
+        if node.get_attr('pytorch', False):
+            self.template = simple_rnn_pytorch_function_template
+            params['weights'] = 'w{0}, wr{0}, b{0}, br{0}'.format(str(node.index))
+        else:
+            params['weights'] = 'w{0}, wr{0}, b{0}'.format(str(node.index))
         return self.template.format(**params)
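For illustration, the two SimpleRNN call templates differ only in the entry point and in the extra recurrent bias passed for PyTorch-trained layers. A hedged sketch of the strings the updated format() would produce; the placeholder type and variable names are assumptions, in practice they come from _default_function_params(node):

params = {'input_t': 'input_t', 'output_t': 'result_t', 'config': 'config4',
          'input': 'data', 'output': 'res'}

# Keras-ordered layer: single bias vector
params['weights'] = 'w{0}, wr{0}, b{0}'.format(4)
print(simple_rnn_function_template.format(**params))
# nnet::simple_rnn<input_t, result_t, config4>(data, res, w4, wr4, b4);

# PyTorch-ordered layer: recurrent bias br<N> is passed as well
params['weights'] = 'w{0}, wr{0}, b{0}, br{0}'.format(4)
print(simple_rnn_pytorch_function_template.format(**params))
# nnet::simple_rnn_pytorch<input_t, result_t, config4>(data, res, w4, wr4, b4, br4);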

hls4ml/backends/quartus/quartus_backend.py

Lines changed: 24 additions & 2 deletions
@@ -1,5 +1,6 @@
 import os
 from contextlib import contextmanager
+from warnings import warn

 import numpy as np

@@ -73,6 +74,7 @@ def _register_flows(self):
             'quartus:inplace_stream_flatten',
             'quartus:skip_softmax',
             'quartus:fix_softmax_table_size',
+            'quartus:process_fixed_point_quantizer_layer',
             'infer_precision_types',
         ]
         optimization_flow = register_flow('optimize', optimization_passes, requires=[init_flow], backend=self.name)
@@ -265,7 +267,17 @@ def init_conv1d(self, layer):
         n_in, n_out = self.get_layer_mult_size(layer)
         self.set_target_reuse_factor(layer)
         self.set_closest_reuse_factor(layer, n_in, n_out)
-        layer.set_attr('parallelization', layer.model.config.get_layer_config_value(layer, 'ParallelizationFactor', 1))
+
+        # Not overriding user parallelization factor, if already set and user has not specified a value
+        user_pf = layer.model.config.get_layer_config_value(layer, 'ParallelizationFactor', None)
+        layer_pf = layer.get_attr('parallelization_factor', None)
+        chosen_pf = user_pf or layer_pf or 1
+        if user_pf is not None and layer_pf is not None:
+            if user_pf != layer_pf:
+                warn(
+                    f'For layer {layer.name}, parallelization factor of {layer_pf} is defined in the proxy-model, but is overridden by the user to {user_pf}.'  # noqa: E501
+                )
+        layer.set_attr('parallelization', chosen_pf)

         # impl_filt_width determines the filter size post-Winograd transformation
         layer.set_attr('impl_filt_width', layer.get_attr('filt_width'))
@@ -295,7 +307,17 @@ def init_conv2d(self, layer):
         n_in, n_out = self.get_layer_mult_size(layer)
         self.set_target_reuse_factor(layer)
         self.set_closest_reuse_factor(layer, n_in, n_out)
-        layer.set_attr('parallelization', layer.model.config.get_layer_config_value(layer, 'ParallelizationFactor', 1))
+
+        # Not overriding user parallelization factor, if already set and user has not specified a value
+        user_pf = layer.model.config.get_layer_config_value(layer, 'ParallelizationFactor', None)
+        layer_pf = layer.get_attr('parallelization_factor', None)
+        chosen_pf = user_pf or layer_pf or 1
+        if user_pf is not None and layer_pf is not None:
+            if user_pf != layer_pf:
+                warn(
+                    f'For layer {layer.name}, parallelization factor of {layer_pf} is defined in the proxy-model, but is overridden by the user to {user_pf}.'  # noqa: E501
+                )
+        layer.set_attr('parallelization', chosen_pf)

         # impl_filt_width & impl_filt_height determine the filter size post-Winograd transformation
         layer.set_attr('impl_filt_height', layer.get_attr('filt_height'))
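The resolution order in both init_conv1d and init_conv2d is: the user's ParallelizationFactor from the layer config, then a parallelization_factor attribute already carried by the layer (e.g. set by a proxy model), then 1; if both are present and disagree, the user value wins and a warning is emitted. A hedged sketch of pinning the factor per layer from the user side; the model object and layer name are placeholders:

import hls4ml

# Assuming an existing Keras model `model`; per-layer granularity exposes LayerName keys
config = hls4ml.utils.config_from_keras_model(model, granularity='name')

# Hypothetical layer name; this user value takes precedence over any
# parallelization_factor attached to the layer by a proxy model
config['LayerName']['conv2d_1']['ParallelizationFactor'] = 4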

hls4ml/backends/vitis/vitis_backend.py

Lines changed: 30 additions & 1 deletion
@@ -34,15 +34,44 @@ def _register_flows(self):
         self._default_flow = register_flow('ip', None, requires=ip_flow_requirements, backend=self.name)

     def create_initial_config(
-        self, part='xcvu13p-flga2577-2-e', clock_period=5, clock_uncertainty='27%', io_type='io_parallel', **_
+        self,
+        part='xcvu13p-flga2577-2-e',
+        clock_period=5,
+        clock_uncertainty='27%',
+        io_type='io_parallel',
+        namespace=None,
+        write_weights_txt=True,
+        write_tar=False,
+        **_,
     ):
+        """Create initial configuration of the Vitis backend.
+
+        Args:
+            part (str, optional): The FPGA part to be used. Defaults to 'xcvu13p-flga2577-2-e'.
+            clock_period (int, optional): The clock period. Defaults to 5.
+            clock_uncertainty (str, optional): The clock uncertainty. Defaults to 27%.
+            io_type (str, optional): Type of implementation used. One of
+                'io_parallel' or 'io_stream'. Defaults to 'io_parallel'.
+            namespace (str, optional): If defined, place all generated code within a namespace. Defaults to None.
+            write_weights_txt (bool, optional): If True, writes weights to .txt files which speeds up compilation.
+                Defaults to True.
+            write_tar (bool, optional): If True, compresses the output directory into a .tar.gz file. Defaults to False.
+
+        Returns:
+            dict: initial configuration.
+        """
         config = {}

         config['Part'] = part if part is not None else 'xcvu13p-flga2577-2-e'
         config['ClockPeriod'] = clock_period if clock_period is not None else 5
         config['ClockUncertainty'] = clock_uncertainty if clock_uncertainty is not None else '27%'
         config['IOType'] = io_type if io_type is not None else 'io_parallel'
         config['HLSConfig'] = {}
+        config['WriterConfig'] = {
+            'Namespace': namespace,
+            'WriteWeightsTxt': write_weights_txt,
+            'WriteTar': write_tar,
+        }

         return config
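A minimal usage sketch of the extended signature, assuming a backend instance obtained through hls4ml's backend registry; the project-specific values are placeholders:

from hls4ml.backends.backend import get_backend

backend = get_backend('Vitis')
cfg = backend.create_initial_config(
    part='xcvu13p-flga2577-2-e',
    clock_period=5,
    io_type='io_parallel',
    namespace='my_prj',       # place all generated code inside `namespace my_prj`
    write_weights_txt=True,   # weights written to .txt files, speeding up compilation
    write_tar=False,          # skip the .tar.gz archive of the output directory
)
# cfg['WriterConfig'] == {'Namespace': 'my_prj', 'WriteWeightsTxt': True, 'WriteTar': False}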

hls4ml/backends/vivado/passes/core_templates.py

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,7 @@
 from hls4ml.backends.backend import get_backend
 from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate
 from hls4ml.model.layers import Activation, BatchNormalization, Dense, HardActivation, ParametrizedActivation, PReLU, Softmax
+from hls4ml.model.optimizer.passes.hgq_proxy_model import UnaryLUT

 # Dense templates

@@ -144,7 +145,7 @@ def format(self, node):

 class ActivationConfigTemplate(LayerConfigTemplate):
     def __init__(self):
-        super().__init__((Activation, ParametrizedActivation, PReLU))
+        super().__init__((Activation, ParametrizedActivation, PReLU, UnaryLUT))
         self.template = activ_config_template

     def format(self, node):
