Skip to content

Commit 2c17f66

Browse files
authored
Merge pull request #979 from jmitrevs/qonnx-1p0
Update QONNX parsing for 1.0
2 parents 82d059b + fc0417b commit 2c17f66

36 files changed

+2883
-476
lines changed

docs/advanced/qonnx.rst

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
==============
2+
ONNX and QONNX
3+
==============
4+
5+
Parsing of ONNX and QONNX models is made in conjunction with the `qonnx <https://github.com/fastmachinelearning/qonnx>`_ package, even if it no quantization is used. This is a common initial parser shared with the AMD/Xilinx FINN project. The first step is to do constant folding, shape inference, etc., on the ONNX graph, commonly known as `cleaning`. If a model has convolution layers, the model also needs to be converted to a channels-last format, since that is what hls4ml mainly supports. The ``qonnx`` package also provides a number of additional transforms that may need to be used. For example, ``Gemm`` nodes need to converted to ``MatMul`` and ``Add`` nodes.
6+
7+
There are command-line based versions of cleaning and channels-last conversion:
8+
9+
.. code-block:: bash
10+
11+
$ qonnx_clean filename.onnx
12+
$ qonnx_to_channels_last filename_clean.onnx
13+
$ qonnx_clean filename_clean_channels_last.onnx # good to do a clean again as a last step
14+
15+
Things can similarly be done in python. This method is usually easier if you additionally need to call other transforms. An example is given below which also calls the ``GemmToMatMul`` converter:
16+
17+
.. code-block:: python
18+
19+
model = ModelWrapper('filename.onnx')
20+
model = qonnx.util.cleanup.cleanup_model(model)
21+
model = model.transform(ConvertToChannelsLastAndClean())
22+
model = model.transform(GemmToMatMul())
23+
model = qonnx.util.cleanup.cleanup_model(model)
24+
25+
``ModelWrapper`` is defined in ``qonnx.core.modelwrapper``. More information on the ``qonnx`` package can be found at the `QONNX documentation page <https://qonnx.readthedocs.io/en/latest/index.html>`_.
26+
27+
28+
The next steps are very similar to if you are using a Keras model:
29+
30+
.. code-block:: python
31+
32+
config = hls4ml.utils.config.config_from_onnx_model(
33+
model, granularity='name', backend='Vitis', default_precision='fixed<16,6>'
34+
)
35+
# modify the config as desired
36+
hls_model = hls4ml.converters.convert_from_onnx_model(
37+
model,
38+
output_dir='my-hls-test',
39+
io_type='io_stream',
40+
backend='Vitis',
41+
hls_config=config,
42+
)
43+
hls_model.compile()
44+
45+
Note, unlike the Keras version, "name" granularity is the default for ``config_from_onnx_model``, and it must be used for QONNX models. Unquantized ONNX models can use "model" if so desired, but generally there is no benefit.
46+
47+
One can subsequently call the ``predict`` function to check the performance or build the project.
48+
49+
Note that ``execute_onnx`` in ``qonnx.core.onnx_exec`` can be use to run the QONNX graphs directly, and it also provides the values at intermediate layers for validating the model (tracing).
50+
51+
Quant nodes
52+
===========
53+
54+
Documentation for quant nodes is provided in the `qonnx package <https://github.com/fastmachinelearning/qonnx/tree/main/docs/qonnx-custom-ops>`_. Note that currently hls4ml only supports the `Quant operator <https://github.com/fastmachinelearning/qonnx/tree/main/docs/qonnx-custom-ops/quant_op.md>`_. Also, not all legal ``Quant`` configurations are parsable by hls4ml or synthesizable. The ``scale``, ``zeropt``, and ``bitwidth`` values must be constant (though not necessarily scalar for the ``scale`` and ``zeropt``).
55+
56+
Generally if the ``zeropt`` is 0 and the ``scale`` is a scalar power of 2, hls4ml uses ``ap_fixed`` or ``ac_fixed`` types (depending on the backend) to represent the quantizations. In other cases, the ``scale`` and ``zeropt`` need to be explicitly handled by hls4ml, and there is more of a chance of hls4ml not being able to process the input. (Please report any issues that you find.)

docs/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
:hidden:
2323
:caption: Advanced Features
2424

25+
advanced/qonnx
2526
advanced/fifo_depth
2627
advanced/extension
2728
advanced/oneapi

example-models

hls4ml/backends/catapult/passes/pointwise.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from copy import copy
2-
31
from hls4ml.backends.catapult.passes.convolution_templates import (
42
Conv1DConfigTemplate,
53
Conv1DFunctionTemplate,
@@ -75,8 +73,10 @@ def match(self, node):
7573

7674
def transform(self, model, node):
7775
dim = node.__class__.__name__[-2:] # '1D' or '2D'
78-
pw_node = model.make_node('PointwiseConv' + dim, node.name, copy(node.attributes), node.inputs.copy())
79-
pw_node.weights['bias'].data = node.weights['bias'].data
76+
new_attrs = {k: v for k, v in node.attributes.items() if k not in ('trace', 'precision', 'reuse_factor')}
77+
pw_node = model.make_node(
78+
'PointwiseConv' + dim, node.name, new_attrs, node.inputs.copy(), outputs=node.outputs.copy()
79+
)
8080
# Set strategy to ensure lowercase string is passed to the template
8181
if model.config.is_resource_strategy(pw_node):
8282
pw_node.set_attr('strategy', 'resource')

hls4ml/backends/fpga/fpga_backend.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
LSTM,
1414
Activation,
1515
BatchNormalization,
16+
BatchNormOnnx,
17+
Conv,
1618
Conv1D,
1719
Conv2D,
1820
Dense,
@@ -22,8 +24,11 @@
2224
GarNetStack,
2325
GlobalPooling1D,
2426
GlobalPooling2D,
27+
MatMul,
28+
Merge,
2529
Pooling1D,
2630
Pooling2D,
31+
Quant,
2732
SeparableConv1D,
2833
SeparableConv2D,
2934
SimpleRNN,
@@ -63,14 +68,25 @@ def __init__(self, name):
6368
LSTM,
6469
GRU,
6570
Dot,
71+
Conv,
72+
MatMul,
6673
]
6774

6875
for layer in accum_layers:
6976
attrs = self.attribute_map.get(layer, [])
7077
attrs.append(TypeAttribute('accum'))
7178
self.attribute_map[layer] = attrs
7279

73-
rf_layers = accum_layers + [BatchNormalization, Activation, Embedding, GarNet, GarNetStack]
80+
rf_layers = accum_layers + [
81+
BatchNormalization,
82+
Activation,
83+
Embedding,
84+
GarNet,
85+
GarNetStack,
86+
Quant,
87+
BatchNormOnnx,
88+
Merge,
89+
]
7490

7591
for layer in rf_layers:
7692
attrs = self.attribute_map.get(layer, [])

hls4ml/backends/quartus/passes/pointwise.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from copy import copy
2-
31
from hls4ml.backends.fpga.fpga_layers import PointwiseConv1D, PointwiseConv2D
42
from hls4ml.backends.quartus.passes.convolution_templates import (
53
Conv1DConfigTemplate,
@@ -81,10 +79,10 @@ def match(self, node):
8179

8280
def transform(self, model, node):
8381
dim = node.__class__.__name__[-2:] # '1D' or '2D'
82+
new_attrs = {k: v for k, v in node.attributes.items() if k not in ('trace', 'precision', 'reuse_factor')}
8483
pw_node = model.make_node(
85-
'PointwiseConv' + dim, node.name, copy(node.attributes), node.inputs.copy(), outputs=node.outputs.copy()
84+
'PointwiseConv' + dim, node.name, new_attrs, node.inputs.copy(), outputs=node.outputs.copy()
8685
)
87-
pw_node.weights['bias'].data = node.weights['bias'].data
8886
model.replace_node(node, pw_node)
8987

9088
return True

hls4ml/backends/vivado/passes/pointwise.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from copy import copy
2-
31
from hls4ml.backends.fpga.fpga_layers import PointwiseConv1D, PointwiseConv2D
42
from hls4ml.backends.vivado.passes.convolution_templates import (
53
Conv1DConfigTemplate,
@@ -75,8 +73,11 @@ def match(self, node):
7573

7674
def transform(self, model, node):
7775
dim = node.__class__.__name__[-2:] # '1D' or '2D'
78-
pw_node = model.make_node('PointwiseConv' + dim, node.name, copy(node.attributes), node.inputs.copy())
79-
pw_node.weights['bias'].data = node.weights['bias'].data
76+
# to remove warning, since these get set again
77+
new_attrs = {k: v for k, v in node.attributes.items() if k not in ('trace', 'precision', 'reuse_factor')}
78+
pw_node = model.make_node(
79+
'PointwiseConv' + dim, node.name, new_attrs, node.inputs.copy(), outputs=node.outputs.copy()
80+
)
8081
# Set strategy to ensure lowercase string is passed to the template
8182
if model.config.is_resource_strategy(pw_node):
8283
pw_node.set_attr('strategy', 'resource')

hls4ml/converters/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@
1010
from hls4ml.converters.keras_to_hls import get_supported_keras_layers # noqa: F401
1111
from hls4ml.converters.keras_to_hls import parse_keras_model # noqa: F401
1212
from hls4ml.converters.keras_to_hls import keras_to_hls, register_keras_layer_handler
13-
14-
# from hls4ml.converters.pytorch_to_hls import parse_pytorch_model # noqa: F401
13+
from hls4ml.converters.onnx_to_hls import parse_onnx_model # noqa: F401
1514
from hls4ml.model import ModelGraph
1615
from hls4ml.utils.config import create_config
1716
from hls4ml.utils.symbolic_utils import LUTFunction

hls4ml/converters/keras/reshape.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ def parse_flatten_layer(keras_layer, input_names, input_shapes, data_reader):
1111
layer = parse_default_keras_layer(keras_layer, input_names)
1212

1313
layer['class_name'] = 'Reshape'
14-
layer['target_shape'] = [input_shapes[0][0], np.prod(input_shapes[0][1:])]
15-
output_shape = layer['target_shape']
14+
layer['target_shape'] = [np.prod(input_shapes[0][1:])] # target shape has no batch dimension
15+
output_shape = input_shapes[0][:1] + layer['target_shape']
1616

1717
return layer, output_shape
1818

hls4ml/converters/onnx/convolution.py

Lines changed: 62 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,85 +1,77 @@
1-
from hls4ml.converters.onnx_to_hls import (
2-
compute_pads_1d,
3-
compute_pads_2d,
4-
get_onnx_attribute,
5-
get_onnx_input_name,
6-
onnx_handler,
7-
)
8-
from hls4ml.converters.utils import compute_padding_1d, compute_padding_2d
1+
import numpy as np
2+
3+
from hls4ml.converters.onnx_to_hls import get_onnx_attribute, onnx_handler
94

105

116
@onnx_handler('Conv')
12-
def parse_conv_layer(reader, node, inputs_map, input_shapes, graph, config):
7+
def parse_conv_layer(node, input_names, input_shapes, graph):
138
layer = {}
149
layer['name'] = node.name
15-
layer['data_format'] = 'channels_first' # ONNX's default is channel first
16-
layer['inputs'] = get_onnx_input_name(node, graph)
17-
reader.add_input(layer['name'], node.input)
10+
if node.domain != 'qonnx.custom_op.channels_last':
11+
raise RuntimeError("Please convert the model to channels-last format with qonnx-to-channels-last")
12+
layer['data_format'] = 'channels_last' # QONNX needs to be channels-last.
13+
layer['inputs'] = input_names
14+
layer['outputs'] = node.output
1815

1916
strides = get_onnx_attribute(node, 'strides')
2017
kernel_shape = get_onnx_attribute(node, 'kernel_shape')
21-
22-
if len(input_shapes[0]) == 3: # Conv1D
23-
layer['class_name'] = 'Conv1D'
24-
25-
layer['in_width'] = input_shapes[0][2]
26-
layer['n_chan'] = input_shapes[0][1]
27-
layer['filt_width'] = kernel_shape[0]
28-
layer['n_filt'] = reader.get_weights_data(layer['name'], 'kernel').shape[2]
29-
layer['stride_width'] = strides[0]
30-
pads = compute_pads_1d(node, layer)
31-
18+
# Note: currently don't have support for auto_pad.
19+
pads = get_onnx_attribute(node, 'pads')
20+
dilations = get_onnx_attribute(node, 'dilations')
21+
if dilations is None:
22+
dilations = [1] * len(layer['kernel_shape'])
23+
24+
layer['in_width'] = input_shapes[0][-2]
25+
layer['n_chan'] = input_shapes[0][-1]
26+
layer['n_filt'] = input_shapes[1][0]
27+
28+
layer['group'] = int(get_onnx_attribute(node, 'group'))
29+
if layer['group'] != 1:
30+
layer['depth_multiplier'] = get_onnx_attribute(node, 'group') / layer['n_chan']
31+
if not layer['depth_multiplier'].is_integer():
32+
raise ValueError('Depth multiplier must be an integer')
33+
else:
34+
layer['depth_multiplier'] = int(layer['depth_multiplier'])
35+
36+
layer['n_dim'] = len(input_shapes[0]) - 2 # 2 comes from channels and batch dimentions
37+
if layer['n_dim'] not in (1, 2):
38+
raise ValueError("Only 1D and 2D convolutions are supported")
39+
layer['class_name'] = 'Conv'
40+
41+
# set some values needed later
42+
if layer['n_dim'] == 1:
43+
# this is 1D convolution
44+
full_width = layer['in_width'] + pads[0] + pads[1]
45+
eff_kernel_width = kernel_shape[0] * dilations[0]
46+
layer['out_width'] = int(np.ceil((full_width - eff_kernel_width + 1) / strides[0]))
47+
# for compatibility interpret some variables
3248
layer['pad_left'] = pads[0]
3349
layer['pad_right'] = pads[1]
34-
35-
if all(x == 0 for x in pads): # No padding, i.e., 'VALID' padding
36-
layer['padding'] = 'valid'
37-
else:
38-
layer['padding'] = 'same'
39-
40-
(layer['out_width'], _, _) = compute_padding_1d(
41-
layer['padding'], layer['in_width'], layer['stride_width'], layer['filt_width']
42-
)
43-
44-
output_shape = [input_shapes[0][0], layer['n_filt'], layer['out_width']]
45-
46-
elif len(input_shapes[0]) == 4: # Conv2D
47-
layer['class_name'] = 'Conv2D'
48-
49-
layer['in_height'] = input_shapes[0][2]
50-
layer['in_width'] = input_shapes[0][3]
51-
layer['n_chan'] = input_shapes[0][1]
52-
50+
layer['filt_width'] = kernel_shape[0]
51+
layer['stride_width'] = strides[0]
52+
layer['dilation_width'] = dilations[0]
53+
else:
54+
# 2d
55+
layer['in_height'] = input_shapes[0][-3]
56+
full_height = layer['in_height'] + pads[0] + pads[2]
57+
eff_kernel_height = kernel_shape[0] * dilations[0]
58+
out_height = int(np.ceil((full_height - eff_kernel_height + 1) / strides[0]))
59+
layer['out_height'] = out_height
60+
61+
full_width = input_shapes[0][-2] + pads[1] + pads[3]
62+
eff_kernel_width = kernel_shape[1] * dilations[1]
63+
out_width = int(np.ceil((full_width - eff_kernel_width + 1) / strides[1]))
64+
layer['out_width'] = out_width
65+
# for compatibility interpret some variables
66+
layer['pad_top'] = pads[0]
67+
layer['pad_left'] = pads[1]
68+
layer['pad_bottom'] = pads[2]
69+
layer['pad_right'] = pads[3]
5370
layer['filt_height'] = kernel_shape[0]
5471
layer['filt_width'] = kernel_shape[1]
55-
56-
layer['n_filt'] = next(
57-
(x.type.tensor_type.shape.dim[1].dim_value for x in graph.value_info if x.name == node.output[0]), None
58-
)
5972
layer['stride_height'] = strides[0]
6073
layer['stride_width'] = strides[1]
61-
pads = compute_pads_2d(node, layer)
62-
63-
layer['pad_top'] = pads[0]
64-
layer['pad_bottom'] = pads[2]
65-
layer['pad_left'] = pads[1]
66-
layer['pad_right'] = pads[3]
67-
68-
if all(x == 0 for x in pads): # No padding, i.e., 'VALID' padding in Keras/Tensorflow
69-
layer['padding'] = 'valid'
70-
else: # Only 'valid' and 'same' padding are available in Keras
71-
layer['padding'] = 'same'
72-
73-
(layer['out_height'], layer['out_width'], _, _, _, _) = compute_padding_2d(
74-
layer['padding'],
75-
layer['in_height'],
76-
layer['in_width'],
77-
layer['stride_height'],
78-
layer['stride_width'],
79-
layer['filt_height'],
80-
layer['filt_width'],
81-
)
82-
83-
output_shape = [input_shapes[0][0], layer['n_filt'], layer['out_height'], layer['out_width']]
74+
layer['dilation_height'] = dilations[0]
75+
layer['dilation_width'] = dilations[1]
8476

85-
return layer, output_shape
77+
return layer

0 commit comments

Comments
 (0)