Skip to content

Commit fd61cfe

Browse files
authored
Merge pull request #94 from fastmachinelearning/transformation_introducing_quantnode
Transformation pass to introduce quant nodes
2 parents fe4aa37 + 5e2d0b8 commit fd61cfe

File tree

3 files changed

+395
-1
lines changed

3 files changed

+395
-1
lines changed

src/qonnx/core/modelwrapper.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -535,7 +535,7 @@ def get_non_finn_nodes(self):
535535
return list(filter(lambda x: not util.is_finn_op(x.domain), self.graph.node))
536536

537537
def get_node_index(self, node):
538-
"""Returns current index of given node."""
538+
"""Returns current index of given node, or None if not found."""
539539
n_ind = 0
540540
try:
541541
for n in self.graph.node:
@@ -544,6 +544,17 @@ def get_node_index(self, node):
544544
n_ind += 1
545545
except ValueError:
546546
return None
547+
return None
548+
549+
def get_node_from_name(self, node_name):
    """Returns the node with the specified name, or None if not found.

    The nodes in self.graph.node are scanned in order and the first one whose
    ``name`` equals ``node_name`` is returned.
    """
    # Iterating the node list and comparing names cannot raise ValueError, so
    # the search needs no exception handling; "not found" is simply the
    # default value of next().
    return next((node for node in self.graph.node if node.name == node_name), None)
547558

548559
def get_tensor_layout(self, tensor_name):
549560
"""Returns the data layout annotation of tensor with given name.
Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
# Copyright (c) 2024 Advanced Micro Devices, Inc.
2+
# All rights reserved.
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions are met:
6+
#
7+
# * Redistributions of source code must retain the above copyright notice, this
8+
# list of conditions and the following disclaimer.
9+
#
10+
# * Redistributions in binary form must reproduce the above copyright notice,
11+
# this list of conditions and the following disclaimer in the documentation
12+
# and/or other materials provided with the distribution.
13+
#
14+
# * Neither the name of qonnx nor the names of its
15+
# contributors may be used to endorse or promote products derived from
16+
# this software without specific prior written permission.
17+
#
18+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28+
29+
30+
import numpy as np
31+
import onnx
32+
from onnx import TensorProto
33+
34+
from qonnx.transformation.base import Transformation
35+
from qonnx.transformation.general import SortGraph
36+
from qonnx.transformation.infer_shapes import InferShapes
37+
from qonnx.util.cleanup import cleanup_model
38+
39+
40+
def create_quantnode(
    model,
    quantnode_input,
    quantnode_output_shape,
    scale_value,
    zeropoint_value,
    bitwidth_value,
    narrow,
    signed,
    rounding_mode,
):
    """Build a "Quant" node that reads the tensor `quantnode_input`.

    Four new FLOAT value_info entries are appended to model.graph.value_info,
    in this order: the Quant output, the scale, the zero point and the
    bitwidth. The last three are additionally registered as float32
    initializers through model.set_initializer.

    The node itself is only constructed, not inserted into the graph; the
    caller is responsible for appending it and rewiring its neighbours.

    Returns a tuple (quant node, value_info of the quant node's output).
    """

    def _new_float_value_info(shape):
        # Fresh, uniquely named FLOAT tensor annotation recorded on the graph.
        info = onnx.helper.make_tensor_value_info(model.make_new_valueinfo_name(), TensorProto.FLOAT, shape)
        model.graph.value_info.append(info)
        return info

    def _new_float_initializer(value, shape):
        # Convert first, so a bad value fails before the graph is touched.
        data = np.array(value).astype(np.float32)
        info = _new_float_value_info(shape)
        model.set_initializer(info.name, data)
        return info

    quant_tensor = _new_float_value_info(quantnode_output_shape)

    # NOTE(review): scale and zero point are annotated with the output shape
    # although their data has whatever shape np.array(value) yields; confirm
    # that set_initializer reconciles the annotation with the actual data.
    s_value = _new_float_initializer(scale_value, quantnode_output_shape)
    z_value = _new_float_initializer(zeropoint_value, quantnode_output_shape)
    b_value = _new_float_initializer(bitwidth_value, [1])

    node_inputs = [quantnode_input, s_value.name, z_value.name, b_value.name]
    quantnode = onnx.helper.make_node(
        "Quant",
        inputs=node_inputs,
        outputs=[quant_tensor.name],
        name="Quant_" + quantnode_input,
        narrow=narrow,
        signed=signed,
        rounding_mode=rounding_mode,
    )

    return quantnode, quant_tensor
82+
83+
84+
def adjust_graph(model, input_positions, node_name, quantized_nodes):
    """Insert Quant nodes at the requested inputs/outputs of one node.

    Parameters:
        model: the model wrapper to modify.
        input_positions: iterable of entries shaped
            ((kind, index), (scale, zeropoint, bitwidth, narrow, signed, rounding_mode)).
            kind "input" selects input `index` of the node; any other value is
            treated as selecting output `index`.
        node_name: name of the node whose input/output is to be quantized.
        quantized_nodes: list of (node_name, (kind, index)) tuples already
            handled. It is checked to avoid quantizing the same position twice
            and is extended in place for every Quant node that gets inserted.

    Returns:
        The model after all insertions. Each insertion is followed by a
        SortGraph pass, so the returned object may differ from the one passed in.

    Raises:
        ValueError: if no node named `node_name` exists in the graph.
    """
    for pos in input_positions:
        location = pos[0]
        params = pos[1]
        kind = location[0]
        index = location[1]
        node_details = (node_name, location)
        already_quantized_msg = f"{kind} index {index} of {node_name} is already quantized."

        # Same node and same input/output index was handled by an earlier request.
        if node_details in quantized_nodes:
            print(already_quantized_msg)
            continue

        node_in_focus = model.get_node_from_name(node_name)
        if node_in_focus is None:
            # Fail with a readable message instead of an AttributeError on None.
            raise ValueError(f"Node '{node_name}' was not found in the graph.")

        if kind == "input":
            quantnode_input = node_in_focus.input[index]
            consumer_node = node_in_focus
            # An input is already quantized when a Quant node produces it.
            neighbour = model.find_producer(quantnode_input)
        else:
            quantnode_input = node_in_focus.output[index]
            consumer_node = model.find_consumer(quantnode_input)
            # An output is already quantized when a Quant node consumes it.
            neighbour = consumer_node

        if neighbour is not None and neighbour.op_type == "Quant":
            print(already_quantized_msg)
            continue

        # The Quant output has the same shape as the tensor it quantizes.
        quantnode_output_shape = model.get_tensor_shape(quantnode_input)
        quantnode, quant_tensor = create_quantnode(
            model,
            quantnode_input,
            quantnode_output_shape,
            scale_value=params[0],
            zeropoint_value=params[1],
            bitwidth_value=params[2],
            narrow=params[3],
            signed=params[4],
            rounding_mode=params[5],
        )

        if consumer_node is not None:
            if kind == "input":
                slots = [index]
            else:
                # `index` is an output index of node_in_focus and says nothing
                # about which input slot of the consumer reads the tensor, so
                # locate the slot(s) by tensor name instead.
                slots = [i for i, name in enumerate(consumer_node.input) if name == quantnode_input]
            node_pos = model.get_node_index(consumer_node)
            for slot in slots:
                model.graph.node[node_pos].input[slot] = quant_tensor.name
            model.graph.node.append(quantnode)
        else:
            # No consumer: the Quant output takes over the role of graph
            # output, so it must not stay listed as an intermediate value_info.
            model.graph.value_info.remove(quant_tensor)
            model.graph.node.append(quantnode)
            output_names = [out.name for out in model.graph.output]
            # Replace the graph output that actually carries the quantized
            # tensor; fall back to position 0 (the previous behaviour) when
            # the tensor is not listed as a graph output.
            out_pos = output_names.index(quantnode_input) if quantnode_input in output_names else 0
            model.graph.output.insert(out_pos, quant_tensor)
            model.graph.output.pop(out_pos + 1)

        # The Quant node was appended at the end; restore topological order.
        model = model.transform(SortGraph())
        quantized_nodes.append(node_details)

    return model
139+
140+
141+
class QuantizeGraph(Transformation):
    """This transformation can be used to introduce a Quant node for a specific type of node in the graph.
    Users are able to specify the location of the quant node by providing the input and output index
    as the parameters.

    1) Expectations:
        a) Onnx model in the ModelWrapper format.
        b) Model must be cleaned using qonnx.util.cleanup.cleanup_model()
        c) Batchsize to be set.

    2) Steps to transform are:
        Step1: Finding the input for the quant node.
        Step2: Finding the consumer of the quant node output.
        Step3: Finding the shape for the output tensor of quant node.
        Note: The output tensor of the quant node has the same shape as the tensor that is
              fed into the quant node.

    3) Input:
        A dict "quantnode_map" specifying the criterion, positions, and input parameters like
        scale, bitwidth, zeropoint, and others for a specific quantnode.

        Criterion:
            a) name: This allows users to add quant nodes for specific nodes like "Conv_0" and "Gemm_0".
                     Note: using this, users can have quant nodes with different parameters. Ex: quantizing
                           "Conv_0" and "Conv_1" with bitwidth of 4 and 6, respectively.
            b) op_type: This allows users to add quant nodes for all nodes of a particular op_type such
                        as "Conv", "Gemm", and others.
                        Note: All quant nodes created using the op_type criterion have the same input
                              parameters (scale, zeropoint, bitwidth, and others.)
            c) name and op_type: In this case, the "name" entries are processed first, so they take
                                 precedence over "op_type" entries targeting the same position.

        Positions: ("input", index) or ("output", index)
            a) "input": indicates that the user wants to quantize the input of the selected node.
            b) "output": indicates that the user wants to quantize the output of the selected node.
            c) index: refers to the input/output index to quantize (a node can have multiple inputs and outputs)

        Parameters (to quant node) are provided as (scale, zeropoint, bitwidth, narrow, signed, rounding_mode)

            a) Inputs: scale, zeropoint, bitwidth.
            b) Attributes: narrow, signed, rounding_mode.

    4) Validation:
        a) "quantnode_map" must be a dictionary, otherwise a TypeError is raised.
        b) Its top-level keys must be a subset of {"name", "op_type"}, otherwise an exception
           ("Unsupported selection type") is raised.
        Note: the transformation itself does not check which kind of node (e.g. Conv, Gemm, MatMul)
              is selected.

    5) Return:
        Returns a model with new quant nodes created at the positions specified using the "quantnode_map".

    6) Example:
        quantnode_map = {"name": {"Conv_0": [(("input", 0), (1, 0, 8, 0, 1, "ROUND")),
                                             (("input", 1), (1, 0, 8, 0, 1, "ROUND")),
                                             (("output", 0), (1, 0, 8, 0, 1, "ROUND"))],
                                  "Conv_1": [(("input", 0), (1, 0, 8, 0, 1, "ROUND"))],
                                  "Conv_2": [(("input", 1), (1, 0, 8, 0, 1, "ROUND")),
                                             (("output", 0), (1, 0, 8, 0, 1, "ROUND"))]},

                         "op_type": {"Gemm": [(("input", 0), (1, 0, 8, 0, 1, "ROUND")),
                                              (("input", 1), (1, 0, 8, 0, 1, "ROUND")),
                                              (("input", 2), (1, 0, 8, 0, 1, "ROUND")),
                                              (("output", 0), (1, 0, 8, 0, 1, "ROUND"))]}}
    """

    def __init__(self, quantnode_map):
        super().__init__()
        self.quantnode_map = quantnode_map

    def apply(self, model):
        """Insert all requested Quant nodes and return (model, False).

        Raises TypeError when quantnode_map is not a dict and ValueError when
        it contains a top-level key other than "name" or "op_type".
        """
        quantnode_map = self.quantnode_map

        # Validate the request before doing any work on the model.
        if not isinstance(quantnode_map, dict):
            raise TypeError("Input must be a dictionary.")
        if not set(quantnode_map) <= {"name", "op_type"}:
            # ValueError is an Exception subclass, so existing handlers still match.
            raise ValueError("Unsupported selection type")

        # Tensor shapes are needed to annotate the Quant node outputs.
        model = model.transform(InferShapes())

        # Shared across both criteria so a position quantized by "name" is not
        # quantized a second time through "op_type".
        quantized_nodes = []

        # Explicitly named nodes go first, which gives them precedence.
        for node_name, input_positions in quantnode_map.get("name", {}).items():
            model = adjust_graph(model, input_positions, node_name, quantized_nodes)

        for op, input_positions in quantnode_map.get("op_type", {}).items():
            # The node list is taken once per op type; only the names are used
            # afterwards, because adjust_graph may hand back a new model object.
            for node in model.get_nodes_by_op_type(op):
                model = adjust_graph(model, input_positions, node.name, quantized_nodes)

        model = cleanup_model(model)

        # Every requested position is handled in this single call, so the
        # transformation always reports that no further pass is required.
        graph_modified = False

        return (model, graph_modified)

0 commit comments

Comments
 (0)