Commit 770dc44

Merge branch 'main' into selu-lambda-clean
2 parents: 60df034 + 6cdf842

58 files changed, +3325 −130 lines

.pre-commit-config.yaml (2 additions, 2 deletions)

@@ -10,7 +10,7 @@ repos:
           '--skip-string-normalization']
 
   - repo: https://github.com/tox-dev/pyproject-fmt
-    rev: v2.5.1
+    rev: v2.6.0
     hooks:
      - id: pyproject-fmt
 
@@ -35,7 +35,7 @@ repos:
      - id: isort
 
  - repo: https://github.com/asottile/pyupgrade
-    rev: v3.19.1
+    rev: v3.20.0
    hooks:
      - id: pyupgrade
        args: ["--py310-plus"]

docs/frontend/keras.rst (13 additions, 5 deletions)

@@ -1,11 +1,19 @@
-================
-Keras and QKeras
-================
+================================
+Keras and its quantized variants
+================================
 
-Keras and the quantization library QKeras are well supported in ``hls4ml``. Currently, the Keras v2 (``tf.keras``) is the preferred version, and the future versions of ``hls4ml`` will expand support for Keras v3. The frontend is based on the parsing the serialized json representation of the model.
+Keras and the quantization library QKeras are well supported in ``hls4ml``. Both Keras v2 (``tf.keras``) and the new Keras v3 are supported. While the Keras v2 support is based on parsing the serialized json representation of the model, the Keras v3 support uses direct model inspection.
 
-Currently, ``hls4ml`` can parse most Keras layers, including core layers, convolutional layers, pooling layers, recurrent layers, merging/reshaping layers and activation layers, implemented either via sequential or functional API. Notably missing are the attention and normalization layers. The equivalent QKeras API and quantizers are also supported. The ``Lambda`` layers don't save their state in the serialized format and are thus impossible to parse. In this case, the ``Lambda`` layers can be implemented as custom layers and parsed via the :ref:`Extension API`.
+Currently, ``hls4ml`` can parse most Keras layers, including core layers, convolutional layers, pooling layers, recurrent layers, merging/reshaping layers and activation layers, implemented either via sequential or functional API. Notably missing are the attention and normalization layers. The ``Lambda`` layers don't save their state in the serialized format and are thus impossible to parse. In this case, the ``Lambda`` layers can be implemented as custom layers and parsed via the :ref:`Extension API`.
 
 The ``data_format='channels_first'`` parameter of Keras layers is supported, but not extensively tested. All HLS implementations in ``hls4ml`` are based on ``channels_last`` data format and need to be converted to that format before the HLS code can be emitted. We encourage users of ``channels_first`` to report their experiences to developers on GitHub.
 
+
+* `QKeras <https://github.com/fastmachinelearning/qkeras>`_
+  The equivalent QKeras API and its quantizers are also supported by ``hls4ml``. QKeras is not compatible with Keras v3. Currently, only HGQ2 is compatible with Keras v3 (see below).
+* `HGQ <https://github.com/calad0i/HGQ>`_
+  The equivalent HGQ API is also supported. HGQ is not compatible with Keras v3. See `advanced/HGQ <../advanced/hgq.html>`__ for more information.
+* `HGQ2 <https://github.com/calad0i/HGQ2>`_
+  HGQ2 is based on Keras v3. Its support in ``hls4ml`` is currently under development.
+
 The development team of ``hls4ml`` is currently exploring options for a QKeras alternative and will provide a drop-in replacement API compatible with Keras v3.
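The Extension API route mentioned above for ``Lambda`` layers involves registering a custom layer and a parsing handler with ``hls4ml``. Below is a minimal, hedged sketch of the registration step only; the ``KScale``/``HScale`` names, the handler body and the fixed scale value are hypothetical, and a real extension also needs backend templates plus an HLS implementation, as the Extension API docs describe.

import hls4ml
from hls4ml.model.layers import Layer, register_layer


class HScale(Layer):
    """Hypothetical hls4ml layer standing in for a Keras Lambda that rescales its input."""

    def initialize(self):
        inp = self.get_input_variable()
        self.add_output_variable(inp.shape, inp.dim_names)  # output has the same shape as the input


def parse_scale_layer(keras_layer, input_names, input_shapes, data_reader):
    # Hypothetical parser: a Lambda's closure state is not serialized, so the scale is hard-coded here.
    layer = {
        'class_name': 'HScale',
        'name': keras_layer['config']['name'],
        'scale': 1.7,
    }
    return layer, input_shapes[0]


register_layer('HScale', HScale)
hls4ml.converters.register_keras_layer_handler('KScale', parse_scale_layer)
# Backend config/function templates and the HLS kernel must also be registered; see the Extension API docs.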

docs/intro/setup.rst (16 additions, 4 deletions)

@@ -37,14 +37,26 @@ version can be installed directly from ``git``:
 Dependencies
 ============
 
-The ``hls4ml`` library requires python 3.10 or later, and depends on a number of Python packages and external tools for synthesis and simulation. Python dependencies are automatically managed
-by ``pip`` or ``conda``.
+.. note::
+   As of version 1.1.0+, all conversion frontend specific packages are optional. Only install the packages you need.
 
-* `TensorFlow <https://pypi.org/project/tensorflow/>`_ (version 2.8 to 2.14) and `QKeras <https://pypi.org/project/qkeras/>`_ are required by the Keras converter. One may want to install newer versions of QKeras from GitHub. Newer versions of TensorFlow can be used, but QKeras and hls4ml do not currently support Keras v3.
+The ``hls4ml`` library requires python 3.10 or later, and depends on a number of Python packages and external tools for synthesis and simulation. Python dependencies are automatically managed by ``pip`` or ``conda``.
+
+The following Python packages are all optional and are only required if you intend to use the corresponding converter.
+
+* `Keras <https://pypi.org/project/keras/>`_ is required by the Keras converter.
+* `TensorFlow <https://pypi.org/project/tensorflow/>`_ (version 2.8 to 2.14) is required by the Keras v2 converter (Keras v2 is included in TensorFlow).
+* `Keras <https://pypi.org/project/keras/>`_ 3.0 or above is required by the Keras v3 converter. Keras v3 supports multiple backends for training and inference, and the conversion is not tied to any specific backend. Note that Keras v3 may **not** coexist with Keras v2 in the same Python environment.
 
 * `ONNX <https://pypi.org/project/onnx/>`_ (version 1.4.0 and newer) is required by the ONNX converter.
 
-* `PyTorch <https://pytorch.org/get-started>`_ package is optional. If not installed, the PyTorch converter will not be available.
+* `PyTorch <https://pytorch.org/get-started>`_ is required by the PyTorch converter.
+
+* Quantization support
+
+  * `QKeras <https://github.com/fastmachinelearning/qkeras>`_: based on Keras v2. See `frontend/keras <../frontend/keras.html>`_ for more details.
+  * `HGQ <https://github.com/calad0i/HGQ>`_: based on Keras v2. See `advanced/HGQ <../advanced/hgq.html>`_ for more details.
+  * `Brevitas <https://xilinx.github.io/brevitas/>`_: based on PyTorch. See `frontend/pytorch <../frontend/pytorch.html>`_ for more details.
+  * `QONNX <https://github.com/fastmachinelearning/qonnx>`_: based on ONNX. See `frontend/onnx <../frontend/onnx.html>`_ for more details.
 
 Running C simulation from Python requires a C++11-compatible compiler. On Linux, a GCC C++ compiler ``g++`` is required. Any version from a recent
 Linux should work. On MacOS, the *clang*-based ``g++`` is enough. For the oneAPI backend, one must have oneAPI installed, along with the FPGA compiler,

hls4ml/backends/fpga/fpga_backend.py (0 additions, 27 deletions)

@@ -917,33 +917,6 @@ def generate_conv2d_line_buffer_fn(
 
         return generated_code
 
-    @staticmethod
-    def permute_config_gen(name: str, shape: tuple[int, ...], perm: tuple[int, ...]):
-        """
-        Generate new shape and perm_strides for a permute operation. Operates by mapping the output index
-        to input index by:
-            - unravel the output index
-            - map each dimension to the corresponding stride in the input tensor, sum
-        The operation can be expressed as:
-
-        new_shape = tuple(shape[i] for i in perm)
-        strides = np.cumprod((shapes[1:] + (1,))[::-1])[::-1]
-        perm_strides = [strides[i] for i in perm]
-        out[index] = inp[np.dot(np.unravel_index(index, new_shape), perm_strides)]
-
-        Args:
-            name (str): The name of the configuration.
-            shape (tuple[int, ...]): The shape of the input tensor.
-            perm (tuple[int, ...]): The permutation of the dimensions.
-
-        Returns:
-            (new_shape, perm_strides) (tuple, tuple): the output shape and permutation strides.
-        """
-        new_shape = tuple(shape[i] for i in perm)
-        strides = np.cumprod((shape[1:] + (1,))[::-1])[::-1]
-        perm_strides = tuple(int(strides[i]) for i in perm)
-        return (new_shape, perm_strides)
-
     @model_optimizer()
     def write_hls(self, model):
         self.writer.write_hls(model)
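The removed helper is apparently relocated rather than dropped (a ``transpose_config_gen`` imported from ``hls4ml.utils.transpose_utils`` takes its place in the templates below). As a sanity check of the stride-based index mapping its docstring describes, here is a small self-contained sketch; the shape and permutation are arbitrary example values:

import numpy as np

def permute_config(shape, perm):
    """Reproduce the mapping from the removed docstring: output flat index -> input flat index."""
    new_shape = tuple(shape[i] for i in perm)              # permuted (output) shape
    strides = np.cumprod((shape[1:] + (1,))[::-1])[::-1]   # row-major strides of the input tensor
    perm_strides = tuple(int(strides[i]) for i in perm)    # input stride seen by each output dimension
    return new_shape, perm_strides

shape, perm = (2, 3, 4), (2, 0, 1)                         # arbitrary example
new_shape, perm_strides = permute_config(shape, perm)

inp = np.arange(np.prod(shape)).reshape(shape)
flat_inp = inp.ravel()
out = np.empty(new_shape, dtype=inp.dtype).ravel()
for index in range(out.size):
    # out[index] = inp[dot(unravel(index, new_shape), perm_strides)]
    out[index] = flat_inp[np.dot(np.unravel_index(index, new_shape), perm_strides)]

# The stride-based gather matches a plain numpy transpose.
assert np.array_equal(out.reshape(new_shape), np.transpose(inp, perm))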

hls4ml/backends/fpga/passes/clone.py (2 additions, 2 deletions)

@@ -79,9 +79,9 @@ def transform(self, model, node):
             n_outputs = len(output_map[output]) + in_output
             if n_outputs == 1:
                 continue
-            if n_outputs > 3:
+            if n_outputs > 7:
                 msg = f'ERROR: Cloning output {output} of {node.class_name}\
-                    ({node.name}) more than 3 times not currently supported'
+                    ({node.name}) more than 7 times not currently supported'
                 raise ValueError(msg)
 
             out_var = node.get_output_variable(output)

hls4ml/backends/oneapi/oneapi_backend.py (9 additions, 5 deletions)

@@ -130,12 +130,15 @@ def get_default_flow(self):
     def get_writer_flow(self):
         return self._writer_flow
 
-    def create_initial_config(self, part='Arria10', clock_period=5, io_type='io_parallel', write_tar=False, **_):
+    def create_initial_config(
+        self, part='Agilex7', clock_period=5, hyperopt_handshake=False, io_type='io_parallel', write_tar=False, **_
+    ):
         """Create initial configuration of the oneAPI backend.
 
         Args:
-            part (str, optional): The FPGA part to be used. Defaults to 'Arria10'.
-            clock_period (int, optional): The clock period. Defaults to 5.
+            part (str, optional): The FPGA part to be used. Defaults to 'Agilex7'.
+            clock_period (int, optional): The clock period in ns. Defaults to 5.
+            hyperopt_handshake (bool, optional): Should hyper-optimized handshaking be used? Defaults to False
             io_type (str, optional): Type of implementation used. One of
                 'io_parallel' or 'io_stream'. Defaults to 'io_parallel'.
             write_tar (bool, optional): If True, compresses the output directory into a .tar.gz file. Defaults to False.
 
@@ -146,8 +149,9 @@ def create_initial_config(self, part='Arria10', clock_period=5, io_type='io_para
 
         config = {}
 
-        config['Part'] = part if part is not None else 'Arria10'
+        config['Part'] = part if part is not None else 'Agilex7'
         config['ClockPeriod'] = clock_period
+        config['HyperoptHandshake'] = hyperopt_handshake
         config['IOType'] = io_type
         config['HLSConfig'] = {}
         config['WriterConfig'] = {
 
@@ -167,7 +171,7 @@ def compile(self, model):
             Exception: If the project failed to compile
 
         Returns:
-            string: Returns the name of the compiled library.
+            Path: Returns the name of the compiled library.
         """
         outdir = Path(Path.cwd(), model.config.get_output_dir())
         builddir = outdir / 'build'
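A hedged usage sketch of how these oneAPI defaults could be exercised from the public converter API; it assumes the usual keyword pass-through from ``convert_from_keras_model`` to ``create_initial_config``, and the toy model and output directory are placeholders:

import hls4ml
from tensorflow import keras

# Placeholder model; any parsable Keras model would do.
inp = keras.Input(shape=(16,))
out = keras.layers.Dense(8, activation='relu')(inp)
model = keras.Model(inp, out)

config = hls4ml.utils.config_from_keras_model(model, granularity='model')

# Keyword arguments are forwarded to the backend's create_initial_config, so the new
# defaults (part='Agilex7', hyperopt_handshake=False) apply unless overridden here.
hls_model = hls4ml.converters.convert_from_keras_model(
    model,
    hls_config=config,
    backend='oneAPI',
    output_dir='oneapi_prj',   # placeholder path
    clock_period=5,            # ns
    io_type='io_parallel',
)
hls_model.compile()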

hls4ml/backends/oneapi/passes/reshaping_templates.py (3 additions, 10 deletions)

@@ -3,6 +3,7 @@
 from hls4ml.backends.oneapi.oneapi_template import StreamFunctionCallTemplate, TaskSequenceTemplate
 from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate
 from hls4ml.model.layers import Reshape, Resize, Transpose, ZeroPadding1D, ZeroPadding2D
+from hls4ml.utils.transpose_utils import transpose_config_gen
 
 # ZeroPadding templates
 
@@ -185,16 +186,8 @@ def format(self, node):
         perm = tuple(node.get_attr('perm'))
         name = f'config{node.index}'
 
-        new_shape, perm_strides = node.model.config.backend.permute_config_gen(name, shape, perm)
-        return transpose_config_template.format(
-            dims=len(shape),
-            N=int(np.prod(shape)),
-            from_shape=', '.join(str(x) for x in shape),
-            perm=', '.join(str(x) for x in perm),
-            perm_strides=', '.join(str(x) for x in perm_strides),
-            to_shape=', '.join(str(x) for x in new_shape),
-            config_name=name,
-        )
+        conf = transpose_config_gen(name, shape, perm)
+        return transpose_config_template.format(**conf)
 
 
 class TransposeFunctionTemplate(FunctionCallTemplate):
New file: 109 additions & 0 deletions

@@ -0,0 +1,109 @@
from math import ceil

from hls4ml.backends.backend import get_backend
from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate
from hls4ml.model.layers import Einsum
from hls4ml.utils.transpose_utils import transpose_config_gen

from .reshaping_templates import transpose_config_template

# Shared Dense template
# Einsum template

einsum_config_template = '''
struct config{index} {{
    typedef config{index}_tpose_inp0 tpose_inp0_config;
    typedef config{index}_tpose_inp1 tpose_inp1_config;
    typedef config{index}_tpose_out tpose_out_conf;

    typedef {accum_t.name} accum_t;

    // Layer Sizes
    static const unsigned n_free0 = {n_free0};
    static const unsigned n_free1 = {n_free1};
    static const unsigned n_contract = {n_contract};
    static const unsigned n_inplace = {n_inplace};

    // Resource reuse info
    static const unsigned io_type = nnet::{iotype};
    static const unsigned strategy = nnet::{strategy};
    static const unsigned reuse_factor = {reuse_factor};
    static const unsigned multiplier_limit = {multiplier_limit};
    static const bool store_weights_in_bram = false; // NOT USED

    template <class x_T, class y_T>
    using product = nnet::product::{product_type}<x_T, y_T>;
}};
'''

einsum_function_template = 'nnet::einsum<{input0_t}, {input1_t}, {output_t}, {config}>({input0}, {input1}, {output});'

einsum_include_list = ['nnet_utils/nnet_einsum.h']


class EinsumConfigTemplate(LayerConfigTemplate):
    def __init__(self):
        super().__init__(Einsum)
        self.template = einsum_config_template

    def format(self, node: Einsum):
        default_params = self._default_config_params(node)

        strategy = node.attributes['strategy']
        io_type = node.model.config.get_config_value('IOType')

        assert io_type == 'io_parallel', 'EinsumDense layer only supports io_parallel for now'
        assert strategy.lower() == 'latency', 'EinsumDense layer only supports Latency strategy for now'

        # EinsumDense config
        params = default_params.copy()
        params['strategy'] = strategy
        params['n_free0'] = node.attributes['n_free0']
        params['n_free1'] = node.attributes['n_free1']
        params['n_contract'] = node.attributes['n_contract']
        params['n_inplace'] = node.attributes['n_inplace']
        inp0_t = node.get_input_variable(node.inputs[0]).type.precision
        inp1_t = node.get_input_variable(node.inputs[1]).type.precision
        params['product_type'] = get_backend('vivado').product_type(inp0_t, inp1_t)

        total_mults = params['n_free0'] * params['n_free1'] * params['n_contract'] * params['n_inplace']
        params['multiplier_limit'] = ceil(total_mults / params['reuse_factor'])

        einsum_conf = self.template.format(**params)

        # inp/out transpose config
        inp0_shape = node.attributes['inp0_shape']
        inp1_shape = node.attributes['inp1_shape']
        out_interpert_shape = node.attributes['out_interpert_shape']
        inp0_tpose_idxs = node.attributes['inp0_tpose_idxs']
        inp1_tpose_idxs = node.attributes['inp1_tpose_idxs']
        out_tpose_idxs = node.attributes['out_tpose_idxs']
        tpose_inp0_config_name = f'config{node.index}_tpose_inp0'
        tpose_inp1_config_name = f'config{node.index}_tpose_inp1'
        tpose_out_conf_name = f'config{node.index}_tpose_out'

        conf = transpose_config_gen(tpose_inp0_config_name, inp0_shape, inp0_tpose_idxs)
        inp0_tpose_conf = transpose_config_template.format(**conf)
        conf = transpose_config_gen(tpose_inp1_config_name, inp1_shape, inp1_tpose_idxs)
        inp1_tpose_conf = transpose_config_template.format(**conf)
        conf = transpose_config_gen(tpose_out_conf_name, out_interpert_shape, out_tpose_idxs)
        out_tpose_conf = transpose_config_template.format(**conf)

        return '\n\n'.join((inp0_tpose_conf, inp1_tpose_conf, out_tpose_conf, einsum_conf))


class EinsumFunctionTemplate(FunctionCallTemplate):
    def __init__(self):
        super().__init__(Einsum, include_header=einsum_include_list)
        self.template = einsum_function_template

    def format(self, node: Einsum):
        params = {}
        params['config'] = f'config{node.index}'
        params['input0_t'] = node.get_input_variable(node.inputs[0]).type.name
        params['input1_t'] = node.get_input_variable(node.inputs[1]).type.name
        params['output_t'] = node.get_output_variable().type.name
        params['input0'] = node.get_input_variable(node.inputs[0]).name
        params['input1'] = node.get_input_variable(node.inputs[1]).name
        params['output'] = node.get_output_variable().name
        return self.template.format(**params)
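To make the resource bookkeeping in ``EinsumConfigTemplate.format`` concrete, a short sketch of the ``multiplier_limit`` arithmetic; the layer sizes and reuse factor below are made-up example values:

from math import ceil

# Made-up EinsumDense sizes; the names mirror the attributes used in the config above.
n_free0, n_free1, n_contract, n_inplace = 4, 8, 16, 2
reuse_factor = 8

total_mults = n_free0 * n_free1 * n_contract * n_inplace  # 1024 multiplications per inference
multiplier_limit = ceil(total_mults / reuse_factor)       # 128 multipliers requested in parallel
print(total_mults, multiplier_limit)                      # 1024 128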
