Skip to content

Commit 538a935

Browse files
author
Harish
committed
Transformation pass to introduce quantnodes
1 parent cadd6b2 commit 538a935

File tree

2 files changed

+410
-0
lines changed

2 files changed

+410
-0
lines changed
Lines changed: 263 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,263 @@
1+
# Copyright (c) 2024 Advanced Micro Devices, Inc.
2+
# All rights reserved.
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions are met:
6+
#
7+
# * Redistributions of source code must retain the above copyright notice, this
8+
# list of conditions and the following disclaimer.
9+
#
10+
# * Redistributions in binary form must reproduce the above copyright notice,
11+
# this list of conditions and the following disclaimer in the documentation
12+
# and/or other materials provided with the distribution.
13+
#
14+
# * Neither the name of qonnx nor the names of its
15+
# contributors may be used to endorse or promote products derived from
16+
# this software without specific prior written permission.
17+
#
18+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28+
29+
30+
import numpy as np
31+
import onnx
32+
from onnx import TensorProto
33+
34+
from qonnx.core.datatype import DataType
35+
from qonnx.core.modelwrapper import ModelWrapper
36+
from qonnx.transformation.base import Transformation
37+
from qonnx.transformation.general import SortGraph
38+
from qonnx.transformation.infer_shapes import InferShapes
39+
from qonnx.util.basic import qonnx_make_model
40+
from qonnx.util.cleanup import cleanup_model
41+
42+
43+
class graph_util:
    """Helper utilities for locating nodes in a QONNX model graph and for
    inserting Quant nodes at chosen input/output positions."""

    def get_node_id(self, model):
        """Return a dict mapping each node name to its positional index in the graph."""
        return {node.name: index for index, node in enumerate(model.graph.node)}

    def node_from_name(self, model, node_name):
        """Return the first node whose name equals `node_name`, or None if absent."""
        for node in model.graph.node:
            if node.name == node_name:
                return node
        return None

    def identify_nodes(self, model, node_type):
        """Return all nodes in the graph whose op_type equals `node_type`."""
        return [node for node in model.graph.node if node.op_type == node_type]

    def create_node(
        self,
        model,
        quantnode_input,
        quantnode_output_shape,
        node_count,
        tensor_count,
        scale_value,
        zeropoint_value,
        bitwidth_value,
        narrow,
        signed,
        rounding_mode,
    ):
        """Build a Quant node (plus its scale/zero-point/bitwidth initializers)
        that quantizes the tensor `quantnode_input`.

        Returns
        -------
        (quant_node, quant_tensor)
            The new Quant NodeProto and the ValueInfoProto of its output tensor.
            The output tensor is registered in the model's value_info with QONNX
            datatype UINT8 (fixed here regardless of `signed`/`bitwidth_value` —
            NOTE(review): confirm UINT8 is the intended annotation for signed quantizers).
        """
        quantnode_output_dtype = DataType["UINT8"]
        quant_tensor = onnx.helper.make_tensor_value_info(
            model.make_new_valueinfo_name(), TensorProto.FLOAT, quantnode_output_shape
        )
        model.graph.value_info.append(quant_tensor)
        model.set_tensor_datatype(quant_tensor.name, quantnode_output_dtype)

        # Scale, zero-point and bitwidth are stored as FLOAT32 initializers,
        # as required by the QONNX Quant op.
        stationary_input_dtype = DataType["FLOAT32"]

        scale_tensor = np.array(scale_value).astype(np.float32)
        s_value = onnx.helper.make_tensor_value_info(
            model.make_new_valueinfo_name(), TensorProto.FLOAT, quantnode_output_shape
        )
        model.graph.value_info.append(s_value)
        model.set_tensor_datatype(s_value.name, stationary_input_dtype)
        model.set_initializer(s_value.name, scale_tensor)

        zeropt_tensor = np.array(zeropoint_value).astype(np.float32)
        z_value = onnx.helper.make_tensor_value_info(
            model.make_new_valueinfo_name(), TensorProto.FLOAT, quantnode_output_shape
        )
        model.graph.value_info.append(z_value)
        model.set_tensor_datatype(z_value.name, stationary_input_dtype)
        model.set_initializer(z_value.name, zeropt_tensor)

        # Bitwidth is always a scalar (shape [1]).
        bitwidth_tensor = np.array(bitwidth_value).astype(np.float32)
        b_value = onnx.helper.make_tensor_value_info(model.make_new_valueinfo_name(), TensorProto.FLOAT, [1])
        model.graph.value_info.append(b_value)
        model.set_tensor_datatype(b_value.name, stationary_input_dtype)
        model.set_initializer(b_value.name, bitwidth_tensor)

        quant_node = onnx.helper.make_node(
            "Quant",
            inputs=[quantnode_input, s_value.name, z_value.name, b_value.name],
            outputs=[quant_tensor.name],
            name="Quant_node_" + str(node_count) + str(tensor_count),
            narrow=narrow,
            signed=signed,
            rounding_mode=rounding_mode,
        )

        return quant_node, quant_tensor

    def adjust_graph(self, model, input_positions, node_in_focus, quantized_nodes, node_count):
        """Insert Quant nodes at the requested positions of `node_in_focus`.

        Parameters
        ----------
        model : ModelWrapper
            The model to modify (returned, possibly re-sorted).
        input_positions : list
            Each entry `pos` is a pair: `pos[0] = ("input"|"output", tensor_index)`
            selects where to quantize, and `pos[1] = (scale, zeropoint, bitwidth,
            narrow, signed, rounding_mode)` supplies the Quant parameters.
        node_in_focus : NodeProto
            The node whose input/output tensors are to be quantized.
        quantized_nodes : list
            Mutated in place; records `(node_name, position)` pairs so the same
            position is never quantized twice across calls.
        node_count : int
            Used only to build unique Quant node names.
        """
        tensor_count = 0
        for pos in input_positions:
            location = pos[0]  # ("input"|"output", index)
            qparams = pos[1]  # (scale, zeropoint, bitwidth, narrow, signed, rounding_mode)
            node_details = (node_in_focus.name, location)

            # Never quantize the same node at the same input/output index twice.
            if node_details in quantized_nodes:
                print(f"{location[0]} index {location[1]} of {node_in_focus.name} is already quantized.")
                continue

            if location[0] == "input":
                target_tensor = node_in_focus.input[location[1]]
                consumer_node = node_in_focus
                producer_node = model.find_producer(target_tensor)
                # Skip if the tensor is already produced by a Quant node.
                needs_quant = producer_node is None or producer_node.op_type != "Quant"
            else:
                target_tensor = node_in_focus.output[location[1]]
                consumer_node = model.find_consumer(target_tensor)
                producer_node = model.find_producer(target_tensor)
                # Skip if the tensor is already consumed by a Quant node.
                needs_quant = consumer_node is None or consumer_node.op_type != "Quant"

            if not needs_quant:
                print(f"{location[0]} index {location[1]} of {node_in_focus.name} is already quantized.")
                continue

            node_index = self.get_node_id(model)  # name -> position, used to rewire the consumer below
            quantnode_output_shape = model.get_tensor_shape(target_tensor)  # Quant output keeps the input's shape

            quant_node, quant_tensor = self.create_node(
                model,
                target_tensor,
                quantnode_output_shape,
                node_count,
                tensor_count,
                scale_value=qparams[0],
                zeropoint_value=qparams[1],
                bitwidth_value=qparams[2],
                narrow=qparams[3],
                signed=qparams[4],
                rounding_mode=qparams[5],
            )

            if consumer_node is not None:
                # Reroute the consumer to read the quantized tensor instead.
                # NOTE(review): the consumer's *input* is indexed with the position
                # index, which for the "output" case is the producer's output index —
                # confirm this is intended for multi-input consumers.
                node_pos = node_index[consumer_node.name]
                model.graph.node[node_pos].input[location[1]] = quant_tensor.name
                model.graph.node.append(quant_node)
            else:
                # No consumer: the Quant output replaces the first graph output.
                model.graph.value_info.remove(quant_tensor)
                model.graph.node.append(quant_node)
                model.graph.output.insert(0, quant_tensor)
                model.graph.output.pop(1)

            # Keep the node list topologically sorted after every insertion.
            model = model.transform(SortGraph())
            tensor_count += 1
            quantized_nodes.append(node_details)

        return model
181+
182+
183+
class IntroduceQuantnode(Transformation):
    """This transformation can be used to introduce a Quant node for a specific type of node in the graph.
    Users are able to specify the location of the quant node by providing the input and output indices
    as the parameters.

    1) Expectations:
        a) Onnx model in the ModelWrapper format.
        b) Model must be cleaned using qonnx.util.cleanup.cleanup_model().
        c) Batch size to be set.

    2) Steps to transform are:
        Step1: Finding the input for the quant node.
        Step2: Finding the consumer of the quant node output.
        Step3: Finding the shape for the output tensor of quant node.
        Note: The output tensor of the quant node must have the same shape as the
              consumer of the input to the quant node.

    3) Introduction to quantnodes will be done with precedence to "Name" in comparison to "op_type".

    4) Assert:
        a) The input is a dictionary representing the node names as keys and a list of quant positions
           as values.
        b) The input dictionary must have at least one mac node (Conv, gemm, matmul) for the transformation.

    5) Return:
        Returns a cleaned version of the model.
    """

    def __init__(self, quant_node_inputs):
        """`quant_node_inputs` is a dict with keys "name" and/or "op_type"; each value
        maps a node name (or op type) to the list of quant positions for that node."""
        super().__init__()
        self.quant_node_inputs = quant_node_inputs
        self.graph_util = graph_util()

    def apply(self, model):
        """Insert Quant nodes at every user-specified position, then rebuild and
        clean the model. Returns `(model, False)` — a single application suffices.

        Raises
        ------
        TypeError
            If `quant_node_inputs` is not a dict.
        ValueError
            If the dict has keys other than "name"/"op_type", or a named node
            does not exist in the graph.
        """
        model = model.transform(InferShapes())
        if not isinstance(self.quant_node_inputs, dict):
            raise TypeError("Input must be a dictionary.")

        selection_type = self.quant_node_inputs.keys()
        if not set(selection_type) <= {"name", "op_type"}:
            raise ValueError("Unsupported selection type")

        node_count = 0
        quantized_nodes = []  # shared across branches so a position is never quantized twice

        if "name" in selection_type:
            # Keys are unique node names; values are the quant positions for that node.
            for node_name, input_positions in self.quant_node_inputs["name"].items():
                node_in_focus = self.graph_util.node_from_name(model, node_name)
                if node_in_focus is None:
                    # Fail loudly instead of crashing later with an opaque AttributeError.
                    raise ValueError(f"Node '{node_name}' not found in the graph.")
                model = self.graph_util.adjust_graph(
                    model, input_positions, node_in_focus, quantized_nodes, node_count
                )
                node_count += 1

        if "op_type" in selection_type:
            # Keys are op types; the same positions apply to every node of that type.
            for op, input_positions in self.quant_node_inputs["op_type"].items():
                for node_in_focus in self.graph_util.identify_nodes(model, op):
                    model = self.graph_util.adjust_graph(
                        model, input_positions, node_in_focus, quantized_nodes, node_count
                    )
                    node_count += 1

        # Rebuild the model proto from the modified graph, rewrap and clean it.
        model = ModelWrapper(qonnx_make_model(model.graph))
        model = cleanup_model(model)

        graph_modified = False
        return (model, graph_modified)

0 commit comments

Comments
 (0)