Commit 10f648c

Merge remote-tracking branch 'upstream/main' into hw_opt_p2
2 parents 2ed0865 + 2898ab2 commit 10f648c

File tree: 76 files changed, +2743 / -657 lines changed


.gitlab-ci.yml

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ generator:
   stage: generate
   image: python:3.8-alpine
   variables:
-    N_TESTS_PER_YAML: 5
+    N_TESTS_PER_YAML: 4
   tags:
     - k8s-default
   before_script:

.pre-commit-config.yaml

Lines changed: 3 additions & 3 deletions
@@ -2,7 +2,7 @@ exclude: (^hls4ml\/templates\/(vivado|quartus)\/(ap_types|ac_types)\/|^test/pyte

 repos:
   - repo: https://github.com/psf/black
-    rev: 24.4.2
+    rev: 24.8.0
     hooks:
       - id: black
         language_version: python3
@@ -30,7 +30,7 @@ repos:
         args: ["--profile", "black", --line-length=125]

   - repo: https://github.com/asottile/pyupgrade
-    rev: v3.15.2
+    rev: v3.17.0
     hooks:
       - id: pyupgrade
         args: ["--py36-plus"]
@@ -41,7 +41,7 @@ repos:
       - id: setup-cfg-fmt

   - repo: https://github.com/pycqa/flake8
-    rev: 7.0.0
+    rev: 7.1.1
     hooks:
       - id: flake8
         exclude: docs/conf.py

Jenkinsfile

Lines changed: 4 additions & 3 deletions
@@ -1,7 +1,7 @@
 pipeline {
     agent {
         docker {
-            image 'vivado-el7:3'
+            image 'vivado-alma9:1'
             args '-v /data/Xilinx:/data/Xilinx'
         }
     }
@@ -14,8 +14,9 @@ pipeline {
         steps {
             dir(path: 'test') {
                 sh '''#!/bin/bash --login
-                conda activate hls4ml-py38
-                pip install tensorflow pyparsing
+                conda activate hls4ml-py310
+                conda install -y jupyterhub pydot graphviz pytest pytest-cov
+                pip install pytest-randomly jupyter onnx>=1.4.0 matplotlib pandas seaborn pydigitalwavetools==1.1 pyyaml tensorflow==2.14 qonnx torch git+https://github.com/google/qkeras.git pyparsing
                 pip install -U ../ --user
                 ./convert-keras-models.sh -x -f keras-models.txt
                 pip uninstall hls4ml -y'''

hls4ml/backends/catapult/passes/pointwise.py

Lines changed: 0 additions & 5 deletions
@@ -1,7 +1,5 @@
 from copy import copy

-import numpy as np
-
 from hls4ml.backends.catapult.passes.convolution_templates import (
     Conv1DConfigTemplate,
     Conv1DFunctionTemplate,
@@ -78,9 +76,6 @@ def match(self, node):
     def transform(self, model, node):
         dim = node.__class__.__name__[-2:]  # '1D' or '2D'
         pw_node = model.make_node('PointwiseConv' + dim, node.name, copy(node.attributes), node.inputs.copy())
-        if len(node.weights['weight'].data.shape) == 2:  # This can happen if we assign weights of Dense layer to 1x1 Conv2D
-            expand_axis = tuple(range(int(dim[0])))
-            pw_node.weights['weight'].data = np.expand_dims(node.weights['weight'].data, axis=expand_axis)
         pw_node.weights['bias'].data = node.weights['bias'].data
         # Set strategy to ensure lowercase string is passed to the template
         if model.config.is_resource_strategy(pw_node):
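The removed branch covered the case where a Dense layer's 2D kernel had been assigned to a 1x1 pointwise convolution. A minimal numpy sketch of what that reshaping did (toy shapes, illustrative only):

import numpy as np

# A Dense kernel of shape (n_in, n_out) assigned to a 1x1 PointwiseConv2D was
# padded with leading unit axes so it matched the conv weight layout.
dense_w = np.ones((16, 8))
dim = '2D'
expand_axis = tuple(range(int(dim[0])))  # int('2') -> axes (0, 1)
conv_w = np.expand_dims(dense_w, axis=expand_axis)
print(conv_w.shape)  # (1, 1, 16, 8)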

hls4ml/backends/fpga/fpga_backend.py

Lines changed: 12 additions & 4 deletions
@@ -55,8 +55,6 @@ def __init__(self, name):
             Dense,
             Conv1D,
             Conv2D,
-            SeparableConv1D,
-            SeparableConv2D,
             Pooling1D,
             Pooling2D,
             GlobalPooling1D,
@@ -79,6 +77,16 @@ def __init__(self, name):
             attrs.append(ConfigurableAttribute('reuse_factor', default=1))
             self.attribute_map[layer] = attrs

+        # seperable is kind of special because it is effectively two layers that will be split
+        for layer in (SeparableConv1D, SeparableConv2D):
+            attrs = self.attribute_map.get(layer, [])
+            attrs.append(TypeAttribute('depthwise_accum'))
+            attrs.append(TypeAttribute('pointwise_accum'))
+            attrs.append(TypeAttribute('depthwise_result'))
+            attrs.append(ConfigurableAttribute('depthwise_reuse_factor', default=1))
+            attrs.append(ConfigurableAttribute('pointwise_reuse_factor', default=1))
+            self.attribute_map[layer] = attrs
+
         act_attrs = self.attribute_map.get(Activation, [])
         act_attrs.append(ConfigurableAttribute('table_size', default=1024))
         act_attrs.append(TypeAttribute('table', default=FixedPrecisionType(18, 8)))
@@ -687,7 +695,7 @@ def generate_conv1d_line_buffer_fn(self, layer_idx, n_partitions, in_W, in_C, ke

         The HLS compiler produces suboptimal designs for a im2col algorithm implementation, so a trick we use is
         to generate a resulting a result of im2col transformation explicitly, instead of relying on loops. Since
-        the result depends on the paraleters of the convolution layer (the input size, the kernel size, stride etc),
+        the result depends on the parameters of the convolution layer (the input size, the kernel size, stride etc),
         we need to do this for every convolution layer.

         Args:
@@ -784,7 +792,7 @@ def generate_conv2d_line_buffer_fn(

         The HLS compiler produces suboptimal designs for a im2col algorithm implementation, so a trick we use is
         to generate a resulting a result of im2col transformation explicitly, instead of relying on loops. Since
-        the result depends on the paraleters of the convolution layer (the input size, the kernel size, stride etc),
+        the result depends on the parameters of the convolution layer (the input size, the kernel size, stride etc),
         we need to do this for every convolution layer.

         Args:
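With SeparableConv1D/2D handled in their own attribute block, the depthwise and pointwise halves can be typed and parallelized independently. A hedged sketch of how this might surface in a user-side layer config; the exact config key spellings ('DepthwiseReuseFactor' etc.) and the layer name are assumptions, not confirmed by this commit:

# Hypothetical hls4ml layer config exercising the new per-phase attributes.
config = {
    'Model': {'Precision': 'ap_fixed<16,6>', 'ReuseFactor': 1},
    'LayerName': {
        'sep_conv1': {  # assumed layer name
            'Precision': {
                'depthwise_accum': 'ap_fixed<24,10>',  # new TypeAttribute
                'pointwise_accum': 'ap_fixed<24,10>',  # new TypeAttribute
                'depthwise_result': 'ap_fixed<16,6>',  # new TypeAttribute
            },
            'DepthwiseReuseFactor': 2,  # assumed spelling of depthwise_reuse_factor
            'PointwiseReuseFactor': 4,  # assumed spelling of pointwise_reuse_factor
        }
    },
}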

hls4ml/backends/fpga/passes/codegen.py

Lines changed: 73 additions & 8 deletions
@@ -1,22 +1,34 @@
-from hls4ml.model.layers import Conv1D, Conv2D
+from hls4ml.model.layers import Conv1D, Conv2D, SeparableConv1D, SeparableConv2D
 from hls4ml.model.optimizer import OptimizerPass
 from hls4ml.model.types import Source


 class GenerateConvIm2col(OptimizerPass):
     '''Generates tcode for im2col step of 1D/2d convolution'''

+    # Note, DepthwizeConv1D/2D also matches because it inherits from Conv1D/2D
     def match(self, node):
-        return isinstance(node, (Conv1D, Conv2D)) and node.model.config.get_config_value('IOType') == 'io_parallel'
+        return (
+            isinstance(node, (Conv1D, Conv2D, SeparableConv1D, SeparableConv2D))
+            and node.model.config.get_config_value('IOType') == 'io_parallel'
+        )

     def transform(self, model, node):
-        node_class = node.__class__.__name__
-        if '1D' in node_class:
-            self._generate_im2col_1d(node)
-        elif '2D' in node_class:
-            self._generate_im2col_2d(node)
+        node_class = node.class_name
+        if 'Separable' in node_class:
+            if '1D' in node_class:
+                self._generate_separable_im2col_1d(node)
+            elif '2D' in node_class:
+                self._generate_separable_im2col_2d(node)
+            else:
+                raise Exception(f'Cannot generate instructions for node {node.name} ({node_class})')
         else:
-            raise Exception(f'Cannot generate instructions for node {node.name} ({node_class})')
+            if '1D' in node_class:
+                self._generate_im2col_1d(node)
+            elif '2D' in node_class:
+                self._generate_im2col_2d(node)
+            else:
+                raise Exception(f'Cannot generate instructions for node {node.name} ({node_class})')

     def _generate_im2col_1d(self, node):
         code_str = node.model.config.backend.generate_conv1d_line_buffer_fn(
@@ -49,3 +61,56 @@ def _generate_im2col_2d(self, node):
         )

         node.set_attr('line_buffer_codegen', Source(code_str))
+
+    def _generate_separable_im2col_1d(self, node):
+        dw_code_str = node.model.config.backend.generate_conv1d_line_buffer_fn(
+            str(node.get_attr('index')) + '_dw',
+            node.get_attr('n_partitions'),
+            node.get_input_variable().shape[0],
+            node.get_input_variable().shape[1],
+            kernel=node.get_attr('filt_width'),
+            stride=node.get_attr('stride_width'),
+            pad=(node.get_attr('pad_left'), node.get_attr('pad_right')),
+        )
+
+        node.set_attr('dw_line_buffer_codegen', Source(dw_code_str))
+
+        pw_code_str = node.model.config.backend.generate_conv1d_line_buffer_fn(
+            str(node.get_attr('index')) + '_pw',
+            node.get_attr('n_partitions'),
+            node.get_output_variable().shape[0],
+            node.get_input_variable().shape[1],
+            kernel=1,
+        )
+
+        node.set_attr('pw_line_buffer_codegen', Source(pw_code_str))
+
+    def _generate_separable_im2col_2d(self, node):
+        dw_code_str = node.model.config.backend.generate_conv2d_line_buffer_fn(
+            str(node.get_attr('index')) + '_dw',
+            node.get_attr('n_partitions'),
+            node.get_input_variable().shape[0],
+            node.get_input_variable().shape[1],
+            node.get_input_variable().shape[2],
+            kernel=(node.get_attr('filt_height'), node.get_attr('filt_width')),
+            stride=(node.get_attr('stride_height'), node.get_attr('stride_width')),
+            pad=(
+                node.get_attr('pad_top'),
+                node.get_attr('pad_bottom'),
+                node.get_attr('pad_left'),
+                node.get_attr('pad_right'),
+            ),
+        )
+
+        node.set_attr('dw_line_buffer_codegen', Source(dw_code_str))
+
+        pw_code_str = node.model.config.backend.generate_conv2d_line_buffer_fn(
+            str(node.get_attr('index')) + '_pw',
+            node.get_attr('n_partitions'),
+            node.get_output_variable().shape[0],
+            node.get_output_variable().shape[1],
+            node.get_input_variable().shape[2],
+            kernel=(1, 1),
+        )
+
+        node.set_attr('pw_line_buffer_codegen', Source(pw_code_str))
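The two generated functions mirror the structure of a separable convolution: the depthwise im2col slides over the input, while the pointwise stage has kernel size 1 and indexes the depthwise output, which is why the '_pw' call above is built from the output variable's shape. A toy numpy sketch of the decomposition (illustrative only, not the generated HLS code):

import numpy as np

# 1D separable conv: per-channel depthwise filter, then 1x1 channel mixing.
W, C, K, F = 8, 3, 3, 5            # input width, channels, kernel size, filters
x = np.random.rand(W, C)
dw_kernel = np.random.rand(K, C)   # one K-tap filter per input channel
pw_kernel = np.random.rand(C, F)   # 1x1 conv == per-position matmul

out_w = W - K + 1                  # 'valid' padding, stride 1
# Depthwise stage: the window slides over the *input* width.
dw_out = np.stack([(x[i:i + K] * dw_kernel).sum(axis=0) for i in range(out_w)])
# Pointwise stage: operates position-by-position on the *depthwise output*.
pw_out = dw_out @ pw_kernel
print(dw_out.shape, pw_out.shape)  # (6, 3) (6, 5)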
Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
+import numpy as np
+
+from hls4ml.backends import Backend
+from hls4ml.backends.template import FunctionCallTemplate
+from hls4ml.model.layers import Layer
+from hls4ml.model.optimizer import OptimizerPass
+from hls4ml.model.optimizer.passes.hgq_proxy_model import FixedPointQuantizer, UnaryLUT
+from hls4ml.model.types import Source
+
+
+def to_apfixed(k, b, i, RND, SAT):
+    u = 'u' if k == 0 else ''
+    return f'ap_{u}fixed<{b},{i},AP_{RND},AP_{SAT}>'
+
+
+def to_acfixed(k, b, i, RND, SAT):
+    k = 'false' if k == 0 else 'true'
+    return f'ac_fixed<{b},{i},{k},AC_{RND},AC_{SAT}>'
+
+
+def generate_mask_fn(
+    name: str, shape: tuple[int, ...], k: np.ndarray, b: np.ndarray, i: np.ndarray, RND: str, SAT: str, backend: str
+) -> str:
+    """Generate heterogenous quantization mask function, ONLY works for IOType=io_parallel"""
+    assert k.shape[0] == b.shape[0] == i.shape[0] == 1
+    assert backend.lower() in ('quartus', 'vivado', 'vitis'), f'Backend {backend} not tested'
+    Ks, Bs, Is = k[0], b[0], i[0]
+    Ks, Bs, Is = np.broadcast_to(Ks, shape), np.broadcast_to(Bs, shape), np.broadcast_to(Is, shape)
+    Ks, Bs, Is = Ks.ravel(), Bs.ravel(), Is.ravel()
+    masks = []
+    to_fixed = to_acfixed if backend.lower() == 'quartus' else to_apfixed
+    for idx, (k, b, i) in enumerate(zip(Ks, Bs, Is)):
+        if b == 0:
+            fn = f'out[{idx}] = 0;'
+        else:
+            fn = f'out[{idx}] = {to_fixed(k, b, i, RND, SAT)}(inp[{idx}]);'
+        masks.append(f'    {fn}')
+    body = "\n".join(masks)
+    mask_fn = f'''
+template<typename input_t, typename output_t>
+void {name}(input_t *inp, output_t *out) {{
+    #pragma HLS INLINE
+
+{body}
+}}
+'''
+    return mask_fn
+
+
+class ProcessFixedPointQuantizerLayer(OptimizerPass):
+    def match(self, node: Layer):
+        return isinstance(node, FixedPointQuantizer)
+
+    def transform(self, model, node: FixedPointQuantizer):
+        if node.fusible:
+            model.remove_node(node, rewire=True)
+            return True
+
+        if model.config.config['IOType'] != 'io_parallel':
+            raise NotImplementedError('Heterogenous quantization for activations is only supported with IOType=io_parallel')
+
+        backend = model.config.config['Backend']
+
+        name = node.name
+
+        assert node.mask_kbi is not None
+        k, b, i = node.mask_kbi
+        RND = node.RND
+        SAT = node.SAT
+        mask_fn: str = generate_mask_fn(name, node.get_input_variable().shape, k, b, i, RND, SAT, backend)
+
+        node.set_attr('mask_fn_codegen', Source(mask_fn))
+
+
+class ProcessFixedPointQuantizerCall(FunctionCallTemplate):
+    def __init__(self):
+        super().__init__(FixedPointQuantizer, include_header=[])
+        self.template = 'nnet::{name}<{input_t}, {output_t}>({input}, {output});'
+
+    def format(self, node):
+        params = self._default_function_params(node)
+
+        return self.template.format(**params)
+
+
+class ProcessUnaryLUTCall(FunctionCallTemplate):
+    def __init__(self):
+        super().__init__(UnaryLUT, include_header=[])
+        self.template = 'nnet::unary_lut<{input_t}, {output_t}, {config}>({input}, {output}, {table});'
+        self.include_header = [
+            'nnet_utils/nnet_activation.h',
+            'nnet_utils/nnet_activation_stream.h',
+        ]
+
+    def format(self, node):
+        params = self._default_function_params(node)
+        node.attributes['result_t'].precision = node.attributes['table_t'].precision
+        params['config'] = f'unary_lut_config{node.index}'
+        params['table'] = node.get_weights('table').name
+
+        return self.template.format(**params)
+
+
+def register_hgq_proxy_model(backend: Backend):
+    backend.register_pass('process_fixed_point_quantizer_layer', ProcessFixedPointQuantizerLayer)
+    backend.register_template(ProcessFixedPointQuantizerCall)
+    backend.register_template(ProcessUnaryLUTCall)
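For a concrete feel of what generate_mask_fn emits, a small sketch assuming the function above is in scope. The k/b/i arrays carry sign bit, total width, and integer bits per element (with a leading broadcast axis of size 1); b == 0 masks an element to zero:

import numpy as np

k = np.array([[1, 1]])  # sign bits
b = np.array([[8, 0]])  # total widths; the second element is masked to 0
i = np.array([[3, 3]])  # integer bits
print(generate_mask_fn('quantizer_0', (2,), k, b, i, 'RND', 'SAT', 'vivado'))
# The generated C++ body (vivado -> ap_fixed types) contains:
#     out[0] = ap_fixed<8,3,AP_RND,AP_SAT>(inp[0]);
#     out[1] = 0;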

hls4ml/backends/quartus/passes/core_templates.py

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,7 @@
 from hls4ml.backends.backend import get_backend
 from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate
 from hls4ml.model.layers import Activation, BatchNormalization, Dense, HardActivation, ParametrizedActivation, PReLU, Softmax
+from hls4ml.model.optimizer.passes.hgq_proxy_model import UnaryLUT

 # Dense templates

@@ -152,7 +153,7 @@ def format(self, node):

 class ActivationConfigTemplate(LayerConfigTemplate):
     def __init__(self):
-        super().__init__((Activation, ParametrizedActivation, PReLU))
+        super().__init__((Activation, ParametrizedActivation, PReLU, UnaryLUT))
         self.template = activ_config_template

     def format(self, node):
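The tuple passed to super().__init__ is the set of layer classes this config template serves, so adding UnaryLUT routes its config generation through the existing activation template. A simplified stand-in (not the real hls4ml classes) showing just that dispatch:

# Toy illustration of class-based template matching.
class LayerConfigTemplate:
    def __init__(self, layer_classes):
        self.layer_classes = layer_classes  # classes this template handles

    def match(self, node):
        return isinstance(node, self.layer_classes)


class Activation: ...
class UnaryLUT: ...


tmpl = LayerConfigTemplate((Activation, UnaryLUT))
print(tmpl.match(UnaryLUT()))  # True: UnaryLUT now shares the activation config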

hls4ml/backends/quartus/passes/pointwise.py

Lines changed: 0 additions & 5 deletions
@@ -1,7 +1,5 @@
 from copy import copy

-import numpy as np
-
 from hls4ml.backends.fpga.fpga_layers import PointwiseConv1D, PointwiseConv2D
 from hls4ml.backends.quartus.passes.convolution_templates import (
     Conv1DConfigTemplate,
@@ -86,9 +84,6 @@ def transform(self, model, node):
         pw_node = model.make_node(
             'PointwiseConv' + dim, node.name, copy(node.attributes), node.inputs.copy(), outputs=node.outputs.copy()
         )
-        if len(node.weights['weight'].data.shape) == 2:  # This can happen if we assign weights of Dense layer to 1x1 Conv2D
-            expand_axis = tuple(range(int(dim[0])))
-            pw_node.weights['weight'].data = np.expand_dims(node.weights['weight'].data, axis=expand_axis)
         pw_node.weights['bias'].data = node.weights['bias'].data
         model.replace_node(node, pw_node)

hls4ml/backends/quartus/passes/recurrent_templates.py

Lines changed: 10 additions & 1 deletion
@@ -66,6 +66,7 @@
     using activation_recr = nnet::activation::{recurrent_activation}<x_T, y_T, config_T>;

     static const unsigned reuse_factor = {reuse};
+    static const unsigned pytorch_order = {pytorch};
     static const bool store_weights_in_bram = false;
 }};\n'''

@@ -92,6 +93,7 @@ def format(self, node):
         params['config_mult_h'] = f'config{node.index}_h_mult'
         params['act_t'] = '{}_config{}'.format(node.get_attr('activation'), str(node.index) + '_act')
         params['act_recurrent_t'] = '{}_config{}'.format(node.get_attr('recurrent_activation'), str(node.index) + '_rec_act')
+        params['pytorch'] = 'true' if node.get_attr('pytorch', False) else 'false'
         gru_config = self.gru_template.format(**params)

         # Activation is on candidate hidden state, dimensionality (1, n_units)
@@ -256,6 +258,9 @@ def format(self, node):
 }};\n"""

 simple_rnn_function_template = 'nnet::simple_rnn<{input_t}, {output_t}, {config}>({input}, {output}, {weights});'
+simple_rnn_pytorch_function_template = (
+    'nnet::simple_rnn_pytorch<{input_t}, {output_t}, {config}>({input}, {output}, {weights});'
+)


 class SimpleRNNConfigTemplate(LayerConfigTemplate):
@@ -301,5 +306,9 @@ def __init__(self):

     def format(self, node):
         params = self._default_function_params(node)
-        params['weights'] = 'w{0}, wr{0}, b{0}'.format(str(node.index))
+        if node.get_attr('pytorch', False):
+            self.template = simple_rnn_pytorch_function_template
+            params['weights'] = 'w{0}, wr{0}, b{0}, br{0}'.format(str(node.index))
+        else:
+            params['weights'] = 'w{0}, wr{0}, b{0}'.format(str(node.index))
         return self.template.format(**params)
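The switch selects the simple_rnn_pytorch variant for models converted from PyTorch, whose RNNs carry a separate recurrent bias (bias_hh) that Keras folds away, hence the extra br weight. A short sketch of the two formatted calls, with assumed parameter values:

# Reproduces the two template strings from the diff and formats them.
simple_rnn_function_template = 'nnet::simple_rnn<{input_t}, {output_t}, {config}>({input}, {output}, {weights});'
simple_rnn_pytorch_function_template = (
    'nnet::simple_rnn_pytorch<{input_t}, {output_t}, {config}>({input}, {output}, {weights});'
)

params = {'input_t': 'input_t', 'output_t': 'result_t', 'config': 'config4', 'input': 'x', 'output': 'h'}
for pytorch in (False, True):
    template = simple_rnn_pytorch_function_template if pytorch else simple_rnn_function_template
    weights = 'w4, wr4, b4, br4' if pytorch else 'w4, wr4, b4'
    print(template.format(**params, weights=weights))
# nnet::simple_rnn<input_t, result_t, config4>(x, h, w4, wr4, b4);
# nnet::simple_rnn_pytorch<input_t, result_t, config4>(x, h, w4, wr4, b4, br4);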
