|
| 1 | +# Copyright (c) 2024 Advanced Micro Devices, Inc. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# Redistribution and use in source and binary forms, with or without |
| 5 | +# modification, are permitted provided that the following conditions are met: |
| 6 | +# |
| 7 | +# * Redistributions of source code must retain the above copyright notice, this |
| 8 | +# list of conditions and the following disclaimer. |
| 9 | +# |
| 10 | +# * Redistributions in binary form must reproduce the above copyright notice, |
| 11 | +# this list of conditions and the following disclaimer in the documentation |
| 12 | +# and/or other materials provided with the distribution. |
| 13 | +# |
| 14 | +# * Neither the name of qonnx nor the names of its |
| 15 | +# contributors may be used to endorse or promote products derived from |
| 16 | +# this software without specific prior written permission. |
| 17 | +# |
| 18 | +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" |
| 19 | +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE |
| 20 | +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE |
| 21 | +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE |
| 22 | +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL |
| 23 | +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR |
| 24 | +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER |
| 25 | +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, |
| 26 | +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE |
| 27 | +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
| 28 | + |
| 29 | + |
| 30 | +import numpy as np |
| 31 | +import onnx |
| 32 | +from onnx import TensorProto |
| 33 | + |
| 34 | +from qonnx.core.datatype import DataType |
| 35 | +from qonnx.core.modelwrapper import ModelWrapper |
| 36 | +from qonnx.transformation.base import Transformation |
| 37 | +from qonnx.transformation.general import SortGraph |
| 38 | +from qonnx.transformation.infer_shapes import InferShapes |
| 39 | +from qonnx.util.basic import qonnx_make_model |
| 40 | +from qonnx.util.cleanup import cleanup_model |
| 41 | + |
| 42 | + |
class graph_util:
    """Helper utilities for locating nodes in an ONNX graph and for inserting
    Quant nodes around them. Used by IntroduceQuantnode."""

    def get_node_id(self, model):
        """Return a dict mapping each node name to its positional index in
        ``model.graph.node``."""
        return {node.name: idx for idx, node in enumerate(model.graph.node)}

    def node_from_name(self, model, node_name):
        """Return the first node whose ``name`` equals ``node_name``, or
        ``None`` if the graph contains no such node."""
        for node in model.graph.node:
            if node.name == node_name:
                return node
        return None

    def identify_nodes(self, model, node_type):
        """Return a list of all nodes in the graph whose ``op_type`` equals
        ``node_type``."""
        return [node for node in model.graph.node if node.op_type == node_type]

    def _add_initializer_tensor(self, model, value, shape, dtype):
        """Create a fresh FLOAT value_info with the given shape, register its
        QONNX datatype, and set ``value`` (cast to float32) as its initializer.

        Shared by the scale / zero-point / bitwidth setup in ``create_node``.
        Returns the new value_info proto.
        """
        init_array = np.array(value).astype(np.float32)
        value_info = onnx.helper.make_tensor_value_info(model.make_new_valueinfo_name(), TensorProto.FLOAT, shape)
        model.graph.value_info.append(value_info)
        model.set_tensor_datatype(value_info.name, dtype)
        model.set_initializer(value_info.name, init_array)
        return value_info

    def create_node(
        self,
        model,
        quantnode_input,
        quantnode_output_shape,
        node_count,
        tensor_count,
        scale_value,
        zeropoint_value,
        bitwidth_value,
        narrow,
        signed,
        rounding_mode,
    ):
        """Build a Quant node reading ``quantnode_input`` plus the three
        constant inputs (scale, zero-point, bitwidth) it requires.

        Registers the Quant output tensor and the three constant tensors in
        ``model.graph.value_info`` but does NOT append the node to the graph;
        the caller wires it in. Returns ``(quant_node, quant_tensor)``.
        """
        # NOTE(review): output datatype is hard-coded to UINT8 regardless of the
        # ``signed`` / ``bitwidth_value`` arguments — confirm this is intended.
        quantnode_output_dtype = DataType["UINT8"]
        quant_tensor = onnx.helper.make_tensor_value_info(
            model.make_new_valueinfo_name(), TensorProto.FLOAT, quantnode_output_shape
        )
        model.graph.value_info.append(quant_tensor)
        model.set_tensor_datatype(quant_tensor.name, quantnode_output_dtype)

        # Scale and zero-point match the quantized tensor's shape; bitwidth is
        # a single scalar. All three are stored as float32 initializers.
        stationary_input_dtype = DataType["FLOAT32"]
        s_value = self._add_initializer_tensor(model, scale_value, quantnode_output_shape, stationary_input_dtype)
        z_value = self._add_initializer_tensor(model, zeropoint_value, quantnode_output_shape, stationary_input_dtype)
        b_value = self._add_initializer_tensor(model, bitwidth_value, [1], stationary_input_dtype)

        quant_node = onnx.helper.make_node(
            "Quant",
            inputs=[quantnode_input, s_value.name, z_value.name, b_value.name],
            outputs=[quant_tensor.name],
            name="Quant_node_" + str(node_count) + str(tensor_count),
            narrow=narrow,
            signed=signed,
            rounding_mode=rounding_mode,
        )

        return quant_node, quant_tensor

    def adjust_graph(self, model, input_positions, node_in_focus, quantized_nodes, node_count):
        """Insert Quant nodes at the requested positions around ``node_in_focus``.

        ``input_positions`` entries look like
        ``(("input"|"output", index), (scale, zeropoint, bitwidth, narrow, signed, rounding_mode))``
        — presumably supplied via IntroduceQuantnode's constructor dict; verify
        against the caller. ``quantized_nodes`` is mutated in place so the same
        (node, position) pair is never quantized twice across calls.
        Returns the (re-sorted) model.
        """
        tensor_count = 0
        for pos in input_positions:
            node_details = (node_in_focus.name, pos[0])
            # Skip positions we have already handled (same node + same index).
            if node_details in quantized_nodes:
                print(f"{pos[0][0]} index {pos[0][1]} of {node_in_focus.name} is already quantized.")
                continue

            if pos[0][0] == "input":
                input_to_quantnode = node_in_focus.input[pos[0][1]]
                consumer_node = node_in_focus
                producer_node = model.find_producer(input_to_quantnode)
                # Only quantize if the producer is not already a Quant node.
                needs_quant = producer_node is None or producer_node.op_type != "Quant"
            else:
                input_to_quantnode = node_in_focus.output[pos[0][1]]
                consumer_node = model.find_consumer(input_to_quantnode)
                producer_node = model.find_producer(input_to_quantnode)
                # Only quantize if the consumer is not already a Quant node.
                needs_quant = consumer_node is None or consumer_node.op_type != "Quant"

            if not needs_quant:
                print(f"{pos[0][0]} index {pos[0][1]} of {node_in_focus.name} is already quantized.")
                continue

            node_indx = self.get_node_id(model)  # index of each node in the graph
            quantnode_output_shape = model.get_tensor_shape(input_to_quantnode)  # Step: 3

            quant_node, quant_tensor = self.create_node(
                model,
                input_to_quantnode,
                quantnode_output_shape,
                node_count,
                tensor_count,
                scale_value=pos[1][0],
                zeropoint_value=pos[1][1],
                bitwidth_value=pos[1][2],
                narrow=pos[1][3],
                signed=pos[1][4],
                rounding_mode=pos[1][5],
            )

            if consumer_node is not None:
                # Rewire the consumer to read from the new Quant output.
                # NOTE(review): for "output" positions this indexes the
                # consumer's inputs with the producer-side index pos[0][1] —
                # confirm the indices always coincide.
                node_pos = node_indx[consumer_node.name]
                model.graph.node[node_pos].input[pos[0][1]] = quant_tensor.name
                model.graph.node.append(quant_node)
            else:
                # No consumer: the Quant output replaces the current top-level
                # graph output (it must live in graph.output, not value_info).
                model.graph.value_info.remove(quant_tensor)
                model.graph.node.append(quant_node)
                model.graph.output.insert(0, quant_tensor)
                model.graph.output.pop(1)

            model = model.transform(SortGraph())
            tensor_count += 1
            quantized_nodes.append(node_details)

        return model
| 181 | + |
| 182 | + |
class IntroduceQuantnode(Transformation):
    """Introduce a Quant node for specific nodes in the graph.

    Users specify the location of each Quant node by providing input and
    output indices as parameters.

    1) Expectations:
        a) ONNX model in the ModelWrapper format.
        b) Model must be cleaned using qonnx.util.cleanup.cleanup_model().
        c) Batch size must be set.

    2) Transformation steps:
        Step 1: Find the input for the Quant node.
        Step 2: Find the consumer of the Quant node output.
        Step 3: Find the shape for the output tensor of the Quant node.
        Note: The output tensor of the Quant node must have the same shape as
        the consumer of the input to the Quant node.

    3) Quant nodes are introduced with precedence to "name" over "op_type".

    4) Input contract:
        a) The input is a dictionary keyed by "name" and/or "op_type"; each
           value maps node names / op types to a list of quant positions.
        b) The input dictionary must contain at least one MAC node
           (Conv, Gemm, MatMul) for the transformation.

    5) Return:
        A cleaned version of the model.
    """

    def __init__(self, quant_node_inputs):
        super().__init__()
        # Dict keyed by "name" and/or "op_type" (see class docstring).
        self.quant_node_inputs = quant_node_inputs
        self.graph_util = graph_util()

    def apply(self, model):
        """Insert the requested Quant nodes, then rebuild and clean the model.

        Raises TypeError if the constructor argument was not a dict, and
        Exception if it contains keys other than "name"/"op_type".
        """
        model = model.transform(InferShapes())
        if not isinstance(self.quant_node_inputs, dict):
            raise TypeError("Input must be a dictionary.")
        selection_type = self.quant_node_inputs.keys()
        if not set(selection_type) <= {"name", "op_type"}:
            raise Exception("Unsupported selection type")

        node_count = 0
        quantized_nodes = []  # (node_name, position) pairs already handled

        if "name" in selection_type:
            # Mapping of unique node names -> list of quant positions.
            by_name = self.quant_node_inputs["name"]
            for node_name in by_name.keys():
                node_in_focus = self.graph_util.node_from_name(model, node_name)
                input_positions = by_name[node_name]
                model = self.graph_util.adjust_graph(
                    model, input_positions, node_in_focus, quantized_nodes, node_count
                )
                node_count += 1

        if "op_type" in selection_type:
            # Mapping of op types -> list of quant positions, applied to every
            # node of that op type.
            by_op_type = self.quant_node_inputs["op_type"]
            for op in by_op_type.keys():
                node_list = self.graph_util.identify_nodes(model, op)
                input_positions = by_op_type[op]
                for node_in_focus in node_list:
                    model = self.graph_util.adjust_graph(
                        model, input_positions, node_in_focus, quantized_nodes, node_count
                    )
                    node_count += 1

        # Rebuild the protobuf model and clean it so names/shapes are consistent.
        model = qonnx_make_model(model.graph)
        model = ModelWrapper(model)
        model = cleanup_model(model)

        # Single-pass transformation: never request re-application.
        return (model, False)
0 commit comments