Skip to content

Commit 95279e4

Browse files
committed
[Transform] introduce ExtractQuantScaleZeroPt and simple test
1 parent cadd6b2 commit 95279e4

File tree

2 files changed

+235
-0
lines changed

2 files changed

+235
-0
lines changed
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
# Copyright (c) 2023 Advanced Micro Devices, Inc.
2+
# All rights reserved.
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions are met:
6+
#
7+
# * Redistributions of source code must retain the above copyright notice, this
8+
# list of conditions and the following disclaimer.
9+
#
10+
# * Redistributions in binary form must reproduce the above copyright notice,
11+
# this list of conditions and the following disclaimer in the documentation
12+
# and/or other materials provided with the distribution.
13+
#
14+
# * Neither the name of qonnx nor the names of its
15+
# contributors may be used to endorse or promote products derived from
16+
# this software without specific prior written permission.
17+
#
18+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28+
29+
import numpy as np
30+
from onnx import TensorProto, helper
31+
32+
from qonnx.core.modelwrapper import ModelWrapper
33+
from qonnx.transformation.base import Transformation
34+
from qonnx.transformation.general import GiveUniqueParameterTensors, SortGraph
35+
from qonnx.transformation.remove import RemoveIdentityOps
36+
37+
38+
class ExtractQuantScaleZeroPt(Transformation):
    """Extract any non-identity scale and zero-point Quant inputs as
    separate Div/Mul (for scale) and Add/Sub (for zeropoint) nodes,
    preceding and following the Quant node.

    After this transform, every rewritten Quant node operates with
    scale=1.0 and zeropoint=0.0; the original scale/zeropoint tensors are
    applied explicitly around it:

        Div(scale) -> Add(zeropt) -> Quant(1, 0) -> Sub(zeropt) -> Mul(scale)
    """

    def apply(self, model: ModelWrapper):
        """Rewrite one qualifying Quant node per invocation.

        Returns (model, True) after a rewrite so the transformation is
        re-applied until no Quant node with a non-unit scale or non-zero
        zeropoint initializer remains, then returns (model, False).
        """
        graph = model.graph
        for node in graph.node:
            if node.op_type == "Quant":
                quant_node = node
                # Quant inputs: (data, scale, zeropoint, bitwidth)
                input_nm, scale_nm, zeropt_nm, _ = node.input
                scale_t = model.get_initializer(scale_nm)
                zeropt_t = model.get_initializer(zeropt_nm)
                ishp = model.get_tensor_shape(input_nm)
                # only extract initializers that are not already identities;
                # dynamic (non-initializer) scale/zeropt inputs are left alone
                extract_scale = False
                extract_zeropt = False
                if scale_t is not None and (scale_t != 1).any():
                    extract_scale = True
                if zeropt_t is not None and (zeropt_t != 0).any():
                    extract_zeropt = True
                if (not extract_scale) and (not extract_zeropt):
                    continue
                running_input = input_nm
                if extract_scale:
                    # create new Div node that divides the input by the scale
                    inp_scaled_nm = model.make_new_valueinfo_name()
                    inp_scaled = helper.make_tensor_value_info(
                        inp_scaled_nm,
                        TensorProto.FLOAT,
                        ishp,
                    )
                    graph.value_info.append(inp_scaled)
                    inp_scale_node = helper.make_node("Div", [running_input, scale_nm], [inp_scaled_nm])
                    graph.node.append(inp_scale_node)
                    # remove scale from Quant node by pointing it at a
                    # fresh unit-scale initializer
                    new_scale_nm = model.make_new_valueinfo_name()
                    model.set_initializer(new_scale_nm, np.asarray(1.0, dtype=np.float32))
                    quant_node.input[1] = new_scale_nm
                    running_input = inp_scaled_nm
                if extract_zeropt:
                    # create new Add node that adds the zeropoint to
                    # the scaled input
                    inp_zeropt_nm = model.make_new_valueinfo_name()
                    inp_zeropt = helper.make_tensor_value_info(
                        inp_zeropt_nm,
                        TensorProto.FLOAT,
                        ishp,
                    )
                    graph.value_info.append(inp_zeropt)
                    inp_zeropt_node = helper.make_node("Add", [running_input, zeropt_nm], [inp_zeropt_nm])
                    graph.node.append(inp_zeropt_node)
                    # remove zeropt from Quant node by pointing it at a
                    # fresh zero-valued initializer
                    new_zeropt_nm = model.make_new_valueinfo_name()
                    model.set_initializer(new_zeropt_nm, np.asarray(0.0, dtype=np.float32))
                    quant_node.input[2] = new_zeropt_nm
                    running_input = inp_zeropt_nm
                # rewire node input to any newly created Div/Add nodes
                quant_node.input[0] = running_input
                last_node = quant_node
                final_output = quant_node.output[0]
                if extract_zeropt:
                    # create new Sub node that subtracts the zeropoint from
                    # the output
                    out_zeropt_nm = model.make_new_valueinfo_name()
                    out_zeropt = helper.make_tensor_value_info(
                        out_zeropt_nm,
                        TensorProto.FLOAT,
                        ishp,
                    )
                    graph.value_info.append(out_zeropt)
                    out_zeropt_node = helper.make_node("Sub", [out_zeropt_nm, zeropt_nm], [final_output])
                    last_node.output[0] = out_zeropt_nm
                    graph.node.append(out_zeropt_node)
                    # important: when tracking a pointer to newly added nodes,
                    # ensure the item from the container is used, and not the
                    # make_node result -- those are different objects
                    # e.g. if we use last_node = out_zeropt_node below,
                    # this will point to the wrong object and cause bugs later
                    last_node = graph.node[-1]
                if extract_scale:
                    # create new Mul node that applies the output scale
                    out_scale_nm = model.make_new_valueinfo_name()
                    out_scale = helper.make_tensor_value_info(
                        out_scale_nm,
                        TensorProto.FLOAT,
                        ishp,
                    )
                    last_node.output[0] = out_scale_nm
                    graph.value_info.append(out_scale)
                    out_scale_node = helper.make_node("Mul", [out_scale_nm, scale_nm], [final_output])
                    graph.node.append(out_scale_node)

                if extract_scale or extract_zeropt:
                    # since we used append() for new nodes, need to call
                    # SortGraph to ensure correct (topological) order
                    model = model.transform(SortGraph())
                    # remove any identity Div/Mul/Add/Sub this rewrite may
                    # have introduced (e.g. when only one of scale/zeropt
                    # was extracted)
                    model = model.transform(RemoveIdentityOps())
                    # ensure unique parameter tensors
                    model = model.transform(GiveUniqueParameterTensors())
                    return model, True

        return model, False
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Copyright (c) 2023 Advanced Micro Devices, Inc.
2+
# All rights reserved.
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions are met:
6+
#
7+
# * Redistributions of source code must retain the above copyright notice, this
8+
# list of conditions and the following disclaimer.
9+
#
10+
# * Redistributions in binary form must reproduce the above copyright notice,
11+
# this list of conditions and the following disclaimer in the documentation
12+
# and/or other materials provided with the distribution.
13+
#
14+
# * Neither the name of qonnx nor the names of its
15+
# contributors may be used to endorse or promote products derived from
16+
# this software without specific prior written permission.
17+
#
18+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28+
29+
import numpy as np
30+
import onnx.parser as oprs
31+
32+
from qonnx.core.modelwrapper import ModelWrapper
33+
from qonnx.core.onnx_exec import execute_onnx
34+
from qonnx.transformation.extract_quant_scale_zeropt import ExtractQuantScaleZeroPt
35+
36+
37+
def make_test_model():
    """Build a single-node ONNX model wrapping a qonnx Quant op with
    random (non-identity) scale and zeropoint initializers.

    Returns a ModelWrapper with graph input ``in0`` and output ``out0``.
    """
    ishp = (1, 10)
    ishp_str = str(list(ishp))
    channelwise = True
    bitwidth = np.asarray(4.0, dtype=np.float32)
    if channelwise:
        q_attr_shp = ishp
    else:
        # must be a tuple: it is unpacked into np.random.rand(*...) and
        # converted via list(...) below (a bare int would crash both)
        q_attr_shp = (1,)
    attrshp_str = str(list(q_attr_shp))
    # fixed seed for reproducible scale/zeropoint parameters
    np.random.seed(0)
    scale = np.random.rand(*q_attr_shp).astype(np.float32)
    zeropt = np.random.rand(*q_attr_shp).astype(np.float32)
    signed = 1
    narrow = 1
    rounding_mode = "ROUND"

    # textual ONNX graph, parsed below by onnx.parser
    model_txt = f"""
    <
        ir_version: 7,
        opset_import: ["" : 9]
    >
    agraph (float{ishp_str} in0) => (float{ishp_str} out0)
    <
        float{attrshp_str} scale_param,
        float{attrshp_str} zeropt_param,
        float bitwidth_param
    >
    {{
        out0 = qonnx.custom_op.general.Quant<
            signed={str(signed)},
            narrow={str(narrow)},
            rounding_mode="{rounding_mode}"
        >(in0, scale_param, zeropt_param, bitwidth_param)
    }}
    """
    model = oprs.parse_model(model_txt)
    model = ModelWrapper(model)
    model.set_initializer("scale_param", scale)
    model.set_initializer("zeropt_param", zeropt)
    model.set_initializer("bitwidth_param", bitwidth)
    return model
79+
80+
81+
def test_extract_quant_scale_zeropt():
    """ExtractQuantScaleZeroPt must leave the model output unchanged while
    reducing the Quant node to unit scale and zero zeropoint."""
    model_orig = make_test_model()
    in_shape = model_orig.get_tensor_shape("in0")
    stimulus = np.random.rand(*in_shape).astype(np.float32)
    # golden output from the untransformed model
    expected = execute_onnx(model_orig, {"in0": stimulus})["out0"]
    model_xfrm = model_orig.transform(ExtractQuantScaleZeroPt())
    produced = execute_onnx(model_xfrm, {"in0": stimulus})["out0"]
    # numerical equivalence after the rewrite
    assert np.allclose(expected, produced)
    # the surviving Quant node must carry identity parameters
    quant = model_xfrm.get_nodes_by_op_type("Quant")[0]
    assert model_xfrm.get_initializer(quant.input[1]) == 1
    assert model_xfrm.get_initializer(quant.input[2]) == 0

0 commit comments

Comments
 (0)