Skip to content

Commit d3348e3

Browse files
authored
Merge pull request #194 from fastmachinelearning/feature/quant2qcdq_improvements
Minor QuantToQCDQ improvements
2 parents cae4dc7 + 21328ae commit d3348e3

File tree

2 files changed

+44
-93
lines changed

2 files changed

+44
-93
lines changed

src/qonnx/transformation/qonnx_to_qcdq.py

Lines changed: 44 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -28,22 +28,39 @@
2828

2929
import numpy as np
3030
from onnx import TensorProto, helper
31-
from warnings import warn
32-
33-
from onnxscript.rewriter import pattern, rewrite
3431
from onnxscript import ir
32+
from onnxscript.rewriter import pattern, rewrite
33+
from warnings import warn
3534

3635
from qonnx.core.modelwrapper import ModelWrapper
3736
from qonnx.transformation.base import Transformation
3837
from qonnx.transformation.general import MovePadAttributeToTensor, RemoveUnusedTensors
3938

39+
# TODO: operate on IntQuant when Quant -> IntQuant is added to cleanup steps
40+
# TODO: add transformation for Trunc to standard ops
41+
# TODO: add transformation for BipolarQuant to standard ops
42+
# TODO: add transformation for FloatQuant to standard ops
43+
4044

4145
# Target patterns
4246
def quant_pattern_brevitas(qonnx_op, x, scale, zero_point, bitwidth, signed, narrow, rounding_mode):
43-
return qonnx_op.Quant(x, scale, zero_point, bitwidth, signed=signed, narrow=narrow, _allow_other_attributes=True, _domain="onnx.brevitas")
47+
return qonnx_op.Quant(
48+
x, scale, zero_point, bitwidth, signed=signed, narrow=narrow, _allow_other_attributes=True, _domain="onnx.brevitas"
49+
)
50+
4451

4552
def quant_pattern_qonnx(qonnx_op, x, scale, zero_point, bitwidth, signed, narrow, rounding_mode):
46-
return qonnx_op.Quant(x, scale, zero_point, bitwidth, signed=signed, narrow=narrow, _allow_other_attributes=True, _domain="qonnx.custom_op.general")
53+
return qonnx_op.Quant(
54+
x,
55+
scale,
56+
zero_point,
57+
bitwidth,
58+
signed=signed,
59+
narrow=narrow,
60+
_allow_other_attributes=True,
61+
_domain="qonnx.custom_op.general",
62+
)
63+
4764

4865
# Replacement pattern
4966
def qcdq_pattern(op, x, scale, zero_point, bitwidth, signed, narrow, rounding_mode):
@@ -54,10 +71,15 @@ def qcdq_pattern(op, x, scale, zero_point, bitwidth, signed, narrow, rounding_mo
5471
# Create the QuantizeLinear node
5572
scale_np = scale.const_value.numpy()
5673
scale_np_new = scale_np.squeeze()
74+
# TODO add support for non-zero zero_point, taking different definitions into account:
75+
# DequantizeLinear(QuantizeLinear(x)) uses scale * ((saturate((x / scale) + zero_point) - zero_point))
76+
# Quant(x) uses scale * (round(clip(x / scale + zero_point)) - zero_point)
5777
if scale_np_new.ndim == 1:
5878
qnt_axis = scale_np.shape.index(scale_np_new.shape[0])
5979
c_scale = helper.make_tensor("scale", scale.dtype, scale_np_new.shape, scale_np_new)
60-
c_zero_point = helper.make_tensor("zero_point", new_dtype, scale_np_new.shape, np.zeros(scale_np_new.shape, dtype=np_dtype))
80+
c_zero_point = helper.make_tensor(
81+
"zero_point", new_dtype, scale_np_new.shape, np.zeros(scale_np_new.shape, dtype=np_dtype)
82+
)
6183
else:
6284
qnt_axis = None
6385
c_scale = helper.make_tensor("scale", scale.dtype, (), [scale_np_new.item()])
@@ -71,17 +93,17 @@ def qcdq_pattern(op, x, scale, zero_point, bitwidth, signed, narrow, rounding_mo
7193
if (signed.value and narrow.value) or (bw_val < 8):
7294
# Compute clipping values
7395
if signed.value:
74-
max_val = 2 ** (bw_val-1) - 1
96+
max_val = 2 ** (bw_val - 1) - 1
7597
if narrow.value:
76-
min_val = -2 ** (bw_val-1) + 1
98+
min_val = -(2 ** (bw_val - 1)) + 1
7799
else:
78-
min_val = -2 ** (bw_val-1)
100+
min_val = -(2 ** (bw_val - 1))
79101
else:
80102
min_val = 0
81103
if narrow.value:
82-
max_val = 2 ** bw_val - 2
104+
max_val = 2**bw_val - 2
83105
else:
84-
max_val = 2 ** bw_val - 1
106+
max_val = 2**bw_val - 1
85107

86108
if isinstance(min_val, np.ndarray):
87109
min_val = min_val.astype(np_dtype)
@@ -105,6 +127,7 @@ def qcdq_pattern(op, x, scale, zero_point, bitwidth, signed, narrow, rounding_mo
105127
# Create the DequantizeLinear node
106128
return op.DequantizeLinear(qnt, scale, zero_point, axis=qnt_axis)
107129

130+
108131
def is_valid_qcdq_transformation(context, x, scale, zero_point, bitwidth, signed, narrow, rounding_mode, **_) -> bool:
109132
"""Condition to check if the Quant node can be replaced.
110133
The following conditions must be satisfied:
@@ -135,12 +158,13 @@ def is_valid_qcdq_transformation(context, x, scale, zero_point, bitwidth, signed
135158
return False
136159

137160
# Check rounding mode
138-
if rounding_mode is None: # No rounding_mode specified, assume default to be `ROUND`
161+
if rounding_mode is None: # No rounding_mode specified, assume default to be `ROUND`
139162
return True
140163
if rounding_mode.value != "ROUND":
141164
return False
142165
return True
143166

167+
144168
class QuantToQCDQ(Transformation):
145169
"""Replace QONNX Quant-style quantization nodes with QuantizeLinear
146170
-> Clip -> DequantizeLinear (QCDQ)-style quantization nodes. The following
@@ -153,22 +177,12 @@ class QuantToQCDQ(Transformation):
153177
- the rounding_mode attribute must be ROUND
154178
BipolarQuant is not (yet) supported.
155179
"""
180+
156181
def __init__(self):
157182
super().__init__()
158-
rewrite_rule_qcdq_brevitas = pattern.RewriteRule(
159-
quant_pattern_brevitas,
160-
qcdq_pattern,
161-
is_valid_qcdq_transformation
162-
)
163-
rewrite_rule_qcdq_qonnx = pattern.RewriteRule(
164-
quant_pattern_qonnx,
165-
qcdq_pattern,
166-
is_valid_qcdq_transformation
167-
)
168-
self._rewrite_rule_set = pattern.RewriteRuleSet([
169-
rewrite_rule_qcdq_brevitas,
170-
rewrite_rule_qcdq_qonnx
171-
], commute=True)
183+
rewrite_rule_qcdq_brevitas = pattern.RewriteRule(quant_pattern_brevitas, qcdq_pattern, is_valid_qcdq_transformation)
184+
rewrite_rule_qcdq_qonnx = pattern.RewriteRule(quant_pattern_qonnx, qcdq_pattern, is_valid_qcdq_transformation)
185+
self._rewrite_rule_set = pattern.RewriteRuleSet([rewrite_rule_qcdq_brevitas, rewrite_rule_qcdq_qonnx], commute=True)
172186

173187
self._preserve_qnt_optypes = ["Quant", "BipolarQuant", "QuantizeLinear", "DequantizeLinear"]
174188

@@ -197,7 +211,10 @@ def apply(self, model: ModelWrapper):
197211

198212
qdq_min_opset = 10
199213
if model.model.opset_import[0].version < qdq_min_opset:
200-
warn(f"QCDQ QuantizeLinear requires ONNX opset >= {qdq_min_opset} but found {model.model.opset_import[0].version}")
214+
warn(
215+
f"QCDQ QuantizeLinear requires ONNX opset >= {qdq_min_opset} but found"
216+
" {model.model.opset_import[0].version}"
217+
)
201218
warn(f"Forcing opset version {qdq_min_opset} upgrade to ensure valid ONNX")
202219
model.model.opset_import[0].version = qdq_min_opset
203220
# Ensure new Pad node requirements are respected

tests/transformation/test_qonnx_to_qcdq.py

Lines changed: 0 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,6 @@
3939
from qonnx.util.test import download_model, get_golden_in_and_output, test_model_details
4040

4141
qonnxtoqcdq_details = {
42-
"Conv_bias_example": {
43-
# The model does not have any Quant node
44-
"nonconvertible_quant": 0,
45-
"exp_qdq_nodes": 0,
46-
"exp_clip_nodes": 0,
47-
},
4842
"FINN-CNV_W1A1": {
4943
"nonconvertible_quant": 0,
5044
"exp_qdq_nodes": 1,
@@ -63,12 +57,6 @@
6357
# input quantizer doesn't need Clip so 1 less
6458
"exp_clip_nodes": 17,
6559
},
66-
"FINN-TFC_W1A1": {
67-
# The model does not have any Quant node
68-
"nonconvertible_quant": 0,
69-
"exp_qdq_nodes": 0,
70-
"exp_clip_nodes": 0,
71-
},
7260
"FINN-TFC_W1A2": {
7361
# all Quant nodes convertible to QCDQ
7462
"nonconvertible_quant": 0,
@@ -100,60 +88,6 @@
10088
"exp_qdq_nodes": 49,
10189
"exp_clip_nodes": 49,
10290
},
103-
"rn18_w4a4_a2q_13b": {
104-
# 25 bit bias quant not convertible to QCDQ
105-
"nonconvertible_quant": 1,
106-
"exp_qdq_nodes": 49,
107-
"exp_clip_nodes": 49,
108-
},
109-
"rn18_w4a4_a2q_14b": {
110-
# 25 bit bias quant not convertible to QCDQ
111-
"nonconvertible_quant": 1,
112-
"exp_qdq_nodes": 49,
113-
"exp_clip_nodes": 49,
114-
},
115-
"rn18_w4a4_a2q_15b": {
116-
# 25 bit bias quant not convertible to QCDQ
117-
"nonconvertible_quant": 1,
118-
"exp_qdq_nodes": 49,
119-
"exp_clip_nodes": 49,
120-
},
121-
"rn18_w4a4_a2q_16b": {
122-
# 25 bit bias quant not convertible to QCDQ
123-
"nonconvertible_quant": 1,
124-
"exp_qdq_nodes": 49,
125-
"exp_clip_nodes": 49,
126-
},
127-
"rn18_w4a4_a2q_plus_12b": {
128-
# 25 bit bias quant not convertible to QCDQ
129-
"nonconvertible_quant": 1,
130-
"exp_qdq_nodes": 49,
131-
"exp_clip_nodes": 49,
132-
},
133-
"rn18_w4a4_a2q_plus_13b": {
134-
# 25 bit bias quant not convertible to QCDQ
135-
"nonconvertible_quant": 1,
136-
"exp_qdq_nodes": 49,
137-
"exp_clip_nodes": 49,
138-
},
139-
"rn18_w4a4_a2q_plus_14b": {
140-
# 25 bit bias quant not convertible to QCDQ
141-
"nonconvertible_quant": 1,
142-
"exp_qdq_nodes": 49,
143-
"exp_clip_nodes": 49,
144-
},
145-
"rn18_w4a4_a2q_plus_15b": {
146-
# 25 bit bias quant not convertible to QCDQ
147-
"nonconvertible_quant": 1,
148-
"exp_qdq_nodes": 49,
149-
"exp_clip_nodes": 49,
150-
},
151-
"rn18_w4a4_a2q_plus_16b": {
152-
# 25 bit bias quant not convertible to QCDQ
153-
"nonconvertible_quant": 1,
154-
"exp_qdq_nodes": 49,
155-
"exp_clip_nodes": 49,
156-
},
15791
}
15892

15993
# inherit basics for matching testcases from test util

0 commit comments

Comments (0)