28
28
29
29
import numpy as np
30
30
from onnx import TensorProto , helper
31
- from warnings import warn
32
-
33
- from onnxscript .rewriter import pattern , rewrite
34
31
from onnxscript import ir
32
+ from onnxscript .rewriter import pattern , rewrite
33
+ from warnings import warn
35
34
36
35
from qonnx .core .modelwrapper import ModelWrapper
37
36
from qonnx .transformation .base import Transformation
38
37
from qonnx .transformation .general import MovePadAttributeToTensor , RemoveUnusedTensors
39
38
39
+ # TODO: operate on IntQuant when Quant -> IntQuant is added to cleanup steps
40
+ # TODO: add transformation for Trunc to standard ops
41
+ # TODO: add transformation for BipolarQuant to standard ops
42
+ # TODO: add transformation for FloatQuant to standard ops
43
+
40
44
41
45
# Target patterns
42
46
def quant_pattern_brevitas(qonnx_op, x, scale, zero_point, bitwidth, signed, narrow, rounding_mode):
    """Target pattern: a Brevitas-exported ``Quant`` node (domain ``onnx.brevitas``).

    ``rounding_mode`` is captured by the pattern engine from the signature but is
    not constrained here; it is checked later in the rewrite condition.
    """
    # Gather attribute-style arguments once; the positional tensors stay inline.
    quant_attrs = {
        "signed": signed,
        "narrow": narrow,
        "_allow_other_attributes": True,
        "_domain": "onnx.brevitas",
    }
    return qonnx_op.Quant(x, scale, zero_point, bitwidth, **quant_attrs)
44
51
45
52
def quant_pattern_qonnx(qonnx_op, x, scale, zero_point, bitwidth, signed, narrow, rounding_mode):
    """Target pattern: a QONNX custom-op ``Quant`` node (domain ``qonnx.custom_op.general``).

    Mirrors :func:`quant_pattern_brevitas`, differing only in the node domain.
    ``rounding_mode`` is captured for the later validity check, not matched here.
    """
    inputs = (x, scale, zero_point, bitwidth)
    return qonnx_op.Quant(
        *inputs,
        signed=signed,
        narrow=narrow,
        _allow_other_attributes=True,
        _domain="qonnx.custom_op.general",
    )
47
64
48
65
# Replacement pattern
49
66
def qcdq_pattern (op , x , scale , zero_point , bitwidth , signed , narrow , rounding_mode ):
@@ -54,10 +71,15 @@ def qcdq_pattern(op, x, scale, zero_point, bitwidth, signed, narrow, rounding_mo
54
71
# Create the QuantizeLinear node
55
72
scale_np = scale .const_value .numpy ()
56
73
scale_np_new = scale_np .squeeze ()
74
+ # TODO add support for non-zero zero_point, taking different definitions into account:
75
+ # DequantizeLinear(QuantizeLinear(x)) uses scale * ((saturate((x / scale) + zero_point) - zero_point))
76
+ # Quant(x) uses scale * (round(clip(x / scale + zero_point)) - zero_point)
57
77
if scale_np_new .ndim == 1 :
58
78
qnt_axis = scale_np .shape .index (scale_np_new .shape [0 ])
59
79
c_scale = helper .make_tensor ("scale" , scale .dtype , scale_np_new .shape , scale_np_new )
60
- c_zero_point = helper .make_tensor ("zero_point" , new_dtype , scale_np_new .shape , np .zeros (scale_np_new .shape , dtype = np_dtype ))
80
+ c_zero_point = helper .make_tensor (
81
+ "zero_point" , new_dtype , scale_np_new .shape , np .zeros (scale_np_new .shape , dtype = np_dtype )
82
+ )
61
83
else :
62
84
qnt_axis = None
63
85
c_scale = helper .make_tensor ("scale" , scale .dtype , (), [scale_np_new .item ()])
@@ -71,17 +93,17 @@ def qcdq_pattern(op, x, scale, zero_point, bitwidth, signed, narrow, rounding_mo
71
93
if (signed .value and narrow .value ) or (bw_val < 8 ):
72
94
# Compute clipping values
73
95
if signed .value :
74
- max_val = 2 ** (bw_val - 1 ) - 1
96
+ max_val = 2 ** (bw_val - 1 ) - 1
75
97
if narrow .value :
76
- min_val = - 2 ** (bw_val - 1 ) + 1
98
+ min_val = - ( 2 ** (bw_val - 1 ) ) + 1
77
99
else :
78
- min_val = - 2 ** (bw_val - 1 )
100
+ min_val = - ( 2 ** (bw_val - 1 ) )
79
101
else :
80
102
min_val = 0
81
103
if narrow .value :
82
- max_val = 2 ** bw_val - 2
104
+ max_val = 2 ** bw_val - 2
83
105
else :
84
- max_val = 2 ** bw_val - 1
106
+ max_val = 2 ** bw_val - 1
85
107
86
108
if isinstance (min_val , np .ndarray ):
87
109
min_val = min_val .astype (np_dtype )
@@ -105,6 +127,7 @@ def qcdq_pattern(op, x, scale, zero_point, bitwidth, signed, narrow, rounding_mo
105
127
# Create the DequantizeLinear node
106
128
return op .DequantizeLinear (qnt , scale , zero_point , axis = qnt_axis )
107
129
130
+
108
131
def is_valid_qcdq_transformation (context , x , scale , zero_point , bitwidth , signed , narrow , rounding_mode , ** _ ) -> bool :
109
132
"""Condition to check if the Quant node can be replaced.
110
133
The following conditions must be satisfied:
@@ -135,12 +158,13 @@ def is_valid_qcdq_transformation(context, x, scale, zero_point, bitwidth, signed
135
158
return False
136
159
137
160
# Check rounding mode
138
- if rounding_mode is None : # No rounding_mode specified, assume default to be `ROUND`
161
+ if rounding_mode is None : # No rounding_mode specified, assume default to be `ROUND`
139
162
return True
140
163
if rounding_mode .value != "ROUND" :
141
164
return False
142
165
return True
143
166
167
+
144
168
class QuantToQCDQ (Transformation ):
145
169
"""Replace QONNX Quant-style quantization nodes with QuantizeLinear
146
170
-> Clip -> DequantizeLinear (QCDQ)-style quantization nodes. The following
@@ -153,22 +177,12 @@ class QuantToQCDQ(Transformation):
153
177
- the rounding_mode attribute must be ROUND
154
178
BipolarQuant is not (yet) supported.
155
179
"""
180
+
156
181
def __init__ (self ):
157
182
super ().__init__ ()
158
- rewrite_rule_qcdq_brevitas = pattern .RewriteRule (
159
- quant_pattern_brevitas ,
160
- qcdq_pattern ,
161
- is_valid_qcdq_transformation
162
- )
163
- rewrite_rule_qcdq_qonnx = pattern .RewriteRule (
164
- quant_pattern_qonnx ,
165
- qcdq_pattern ,
166
- is_valid_qcdq_transformation
167
- )
168
- self ._rewrite_rule_set = pattern .RewriteRuleSet ([
169
- rewrite_rule_qcdq_brevitas ,
170
- rewrite_rule_qcdq_qonnx
171
- ], commute = True )
183
+ rewrite_rule_qcdq_brevitas = pattern .RewriteRule (quant_pattern_brevitas , qcdq_pattern , is_valid_qcdq_transformation )
184
+ rewrite_rule_qcdq_qonnx = pattern .RewriteRule (quant_pattern_qonnx , qcdq_pattern , is_valid_qcdq_transformation )
185
+ self ._rewrite_rule_set = pattern .RewriteRuleSet ([rewrite_rule_qcdq_brevitas , rewrite_rule_qcdq_qonnx ], commute = True )
172
186
173
187
self ._preserve_qnt_optypes = ["Quant" , "BipolarQuant" , "QuantizeLinear" , "DequantizeLinear" ]
174
188
@@ -197,7 +211,10 @@ def apply(self, model: ModelWrapper):
197
211
198
212
qdq_min_opset = 10
199
213
if model .model .opset_import [0 ].version < qdq_min_opset :
200
- warn (f"QCDQ QuantizeLinear requires ONNX opset >= { qdq_min_opset } but found { model .model .opset_import [0 ].version } " )
214
+ warn (
215
+ f"QCDQ QuantizeLinear requires ONNX opset >= { qdq_min_opset } but found"
216
+ " {model.model.opset_import[0].version}"
217
+ )
201
218
warn (f"Forcing opset version { qdq_min_opset } upgrade to ensure valid ONNX" )
202
219
model .model .opset_import [0 ].version = qdq_min_opset
203
220
# Ensure new Pad node requirements are respected
0 commit comments