|
| 1 | +# Copyright (c) 2024 Advanced Micro Devices, Inc. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# Redistribution and use in source and binary forms, with or without |
| 5 | +# modification, are permitted provided that the following conditions are met: |
| 6 | +# |
| 7 | +# * Redistributions of source code must retain the above copyright notice, this |
| 8 | +# list of conditions and the following disclaimer. |
| 9 | +# |
| 10 | +# * Redistributions in binary form must reproduce the above copyright notice, |
| 11 | +# this list of conditions and the following disclaimer in the documentation |
| 12 | +# and/or other materials provided with the distribution. |
| 13 | +# |
| 14 | +# * Neither the name of qonnx nor the names of its |
| 15 | +# contributors may be used to endorse or promote products derived from |
| 16 | +# this software without specific prior written permission. |
| 17 | +# |
| 18 | +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" |
| 19 | +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE |
| 20 | +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE |
| 21 | +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE |
| 22 | +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL |
| 23 | +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR |
| 24 | +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER |
| 25 | +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, |
| 26 | +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE |
| 27 | +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
| 28 | + |
| 29 | + |
| 30 | +import numpy as np |
| 31 | +import onnx |
| 32 | +from onnx import TensorProto |
| 33 | + |
| 34 | +from qonnx.transformation.base import Transformation |
| 35 | +from qonnx.transformation.general import SortGraph |
| 36 | +from qonnx.transformation.infer_shapes import InferShapes |
| 37 | +from qonnx.util.cleanup import cleanup_model |
| 38 | + |
| 39 | + |
def create_quantnode(
    model,
    quantnode_input,
    quantnode_output_shape,
    scale_value,
    zeropoint_value,
    bitwidth_value,
    narrow,
    signed,
    rounding_mode,
):
    """Build a "Quant" node reading ``quantnode_input`` and register its tensors in ``model``.

    Four float value_info entries are created with fresh names obtained from
    ``model.make_new_valueinfo_name()`` and appended to ``model.graph.value_info``:
    the node output (declared with ``quantnode_output_shape``), the scale and
    zero-point (also declared with ``quantnode_output_shape``) and the bitwidth
    (declared with shape ``[1]``). Scale, zero-point and bitwidth are converted
    to float32 numpy arrays and stored through ``model.set_initializer``.

    The node itself is NOT appended to the graph; that is left to the caller.

    Args:
        model: model wrapper whose graph receives the new tensors.
        quantnode_input: name of the tensor feeding the Quant node.
        quantnode_output_shape: shape declared for the Quant output tensor.
        scale_value: value for the scale input.
        zeropoint_value: value for the zero-point input.
        bitwidth_value: value for the bitwidth input.
        narrow: value of the node's ``narrow`` attribute.
        signed: value of the node's ``signed`` attribute.
        rounding_mode: value of the node's ``rounding_mode`` attribute.

    Returns:
        Tuple ``(quantnode, quant_tensor)``: the new node and the value_info
        describing its output tensor.
    """

    def _declare_float_tensor(shape):
        # Fresh, uniquely named FLOAT value_info registered on the graph.
        info = onnx.helper.make_tensor_value_info(model.make_new_valueinfo_name(), TensorProto.FLOAT, shape)
        model.graph.value_info.append(info)
        return info

    def _declare_float_param(value, shape):
        # Quant parameters are always stored as float32 initializers.
        data = np.array(value).astype(np.float32)
        info = _declare_float_tensor(shape)
        model.set_initializer(info.name, data)
        return info

    # Creation order is kept (output, scale, zero-point, bitwidth) so that the
    # generated tensor names are requested in the same sequence.
    output_info = _declare_float_tensor(quantnode_output_shape)
    scale_info = _declare_float_param(scale_value, quantnode_output_shape)
    zeropt_info = _declare_float_param(zeropoint_value, quantnode_output_shape)
    bitwidth_info = _declare_float_param(bitwidth_value, [1])

    node_inputs = [quantnode_input, scale_info.name, zeropt_info.name, bitwidth_info.name]
    new_node = onnx.helper.make_node(
        "Quant",
        inputs=node_inputs,
        outputs=[output_info.name],
        name="Quant_" + quantnode_input,
        narrow=narrow,
        signed=signed,
        rounding_mode=rounding_mode,
    )

    return new_node, output_info
| 82 | + |
| 83 | + |
def adjust_graph(model, input_positions, node_name, quantized_nodes):
    """Insert Quant nodes at the requested inputs/outputs of the node named ``node_name``.

    Args:
        model: model wrapper to modify. It is re-sorted after every insertion,
            so the returned model must be used instead of the argument.
        input_positions: iterable of ``(position, params)`` entries where
            ``position`` is ``("input", index)`` or ``("output", index)`` (any
            first element other than "input" is treated as an output) and
            ``params`` is indexable as
            ``(scale, zeropoint, bitwidth, narrow, signed, rounding_mode)``.
        node_name: name of the node whose input/output is to be quantized.
        quantized_nodes: list of ``(node_name, position)`` pairs already
            handled; it is extended in place for every Quant node inserted.

    Returns:
        The model with the new Quant nodes inserted and the graph sorted.

    A position is skipped (with a printed message) when it was already handled
    for this node, when the selected input is produced by a Quant node, or when
    the selected output is already consumed by a Quant node.
    """
    for pos in input_positions:
        location, params = pos[0], pos[1]
        kind, index = location[0], location[1]
        node_details = (node_name, location)
        already_quantized_msg = f"{kind} index {index} of {node_name} is already quantized."

        # Do not quantize the same input/output index of the same node twice.
        if node_details in quantized_nodes:
            print(already_quantized_msg)
            continue

        node_in_focus = model.get_node_from_name(node_name)

        if kind == "input":
            quantnode_input = node_in_focus.input[index]
            consumer_node = node_in_focus
            producer_node = model.find_producer(quantnode_input)
            needs_quant = producer_node is None or producer_node.op_type != "Quant"
        else:
            quantnode_input = node_in_focus.output[index]
            consumer_node = model.find_consumer(quantnode_input)
            needs_quant = consumer_node is None or consumer_node.op_type != "Quant"

        if not needs_quant:
            print(already_quantized_msg)
            continue

        # The Quant output has the same shape as the tensor it quantizes.
        quantnode_output_shape = model.get_tensor_shape(quantnode_input)
        quantnode, quant_tensor = create_quantnode(
            model,
            quantnode_input,
            quantnode_output_shape,
            scale_value=params[0],
            zeropoint_value=params[1],
            bitwidth_value=params[2],
            narrow=params[3],
            signed=params[4],
            rounding_mode=params[5],
        )

        if consumer_node is not None:
            consumer = model.graph.node[model.get_node_index(consumer_node)]
            if kind == "input":
                # The requested index is an input index of the consumer itself.
                consumer.input[index] = quant_tensor.name
            else:
                # ``index`` is an OUTPUT index of the producer and says nothing
                # about where the consumer reads the tensor, so rewire by tensor
                # name. Iterate over a copy because the field is edited in place.
                for slot, tensor_name in enumerate(list(consumer.input)):
                    if tensor_name == quantnode_input:
                        consumer.input[slot] = quant_tensor.name
            model.graph.node.append(quantnode)
        else:
            # No consumer: the tensor is expected to be a graph output. Replace
            # exactly that output, at its own position, instead of assuming it
            # is graph output 0.
            graph_outputs = model.graph.output
            output_slot = next(
                (slot for slot, out in enumerate(graph_outputs) if out.name == quantnode_input),
                None,
            )
            model.graph.node.append(quantnode)
            if output_slot is not None:
                # A graph output must not also be listed as an intermediate value_info.
                model.graph.value_info.remove(quant_tensor)
                graph_outputs.pop(output_slot)
                graph_outputs.insert(output_slot, quant_tensor)
            # Otherwise the tensor is dangling: the Quant node is still added,
            # but the graph outputs are left untouched.

        # The Quant node was appended at the end of the node list; restore a
        # valid topological order before handling the next position.
        model = model.transform(SortGraph())
        quantized_nodes.append(node_details)

    return model
| 139 | + |
| 140 | + |
class QuantizeGraph(Transformation):
    """Introduce Quant nodes at user-selected inputs/outputs of nodes in the graph.

    Nodes are selected either by name or by op_type, and for each selection the
    user lists the input/output indices to quantize together with the Quant
    node parameters.

    1) Expectations:
        a) ONNX model in the ModelWrapper format.
        b) Model must be cleaned using qonnx.util.cleanup.cleanup_model()
        c) Batch size to be set.

    2) Steps to transform are:
        Step1: Finding the input for the quant node.
        Step2: Finding the consumer of the quant node output.
        Step3: Finding the shape for the output tensor of quant node.
        Note: The output tensor of the quant node must have the same shape as the consumer of the input
              to the quant node.

    3) Input:
        A dict "quantnode_map" specifying the criterion, positions, and input parameters like
        scale, bitwidth, zeropoint, and others for a specific quantnode.

        Criterion:
            a) name: This will allow users to add quant nodes for specific nodes like "Conv_0" and "Gemm_0".
                     Note: using this, users can have quant nodes with different parameters. Ex: quantizing
                           "Conv_0" and "Conv_1" with bitwidth of 4 and 6, respectively.
            b) op_type: This will allow users to add quant nodes for all nodes of a particular op_type such
                        as "Conv", "Gemm", and others.
                        Note: All quant nodes created using the op_type criterion will have the same input
                              parameters (scale, zeropoint, bitwidth, and others.)
            c) name and op_type: In this case, "name" entries are processed before "op_type" entries, so
                                 a position already quantized by name is not quantized again by op_type.

        Positions: ("input", index) or ("output", index)
            a) "input": indicates that the user wants to quantize the input of the selected node.
            b) "output": indicates that the user wants to quantize the output of the selected node.
            c) index: refers to the input/output index to quantize (a node can have multiple inputs and outputs)

        Parameters (to quant node) are provided as (scale, zeropoint, bitwidth, narrow, signed, rounding_mode)

            a) Inputs: scale, zeropoint, bitwidth.
            b) Attributes: narrow, signed, rounding_mode.

    4) Assert:
        a) The input is a dictionary whose only allowed top-level keys are "name" and "op_type"; each maps
           node names (or op types) to a list of quant positions. A non-dict input raises TypeError and
           any other top-level key raises ValueError.
        b) The input dictionary must have at least one mac node (Conv, Gemm, MatMul) for the transformation.
           NOTE(review): this requirement is not checked by apply(); confirm whether it is still intended.

    5) Return:
        Returns a model with new quant nodes created at the positions specified using the "quantnode_map".

    6) Example:
        quantnode_map = {"name": {"Conv_0": [(("input", 0), (1, 0, 8, 0, 1, "ROUND")),
                                             (("input", 1), (1, 0, 8, 0, 1, "ROUND")),
                                             (("output", 0), (1, 0, 8, 0, 1, "ROUND"))],
                                  "Conv_1": [(("input", 0), (1, 0, 8, 0, 1, "ROUND"))],
                                  "Conv_2": [(("input", 1), (1, 0, 8, 0, 1, "ROUND")),
                                             (("output", 0), (1, 0, 8, 0, 1, "ROUND"))]},

                         "op_type": {"Gemm": [(("input", 0), (1, 0, 8, 0, 1, "ROUND")),
                                              (("input", 1), (1, 0, 8, 0, 1, "ROUND")),
                                              (("input", 2), (1, 0, 8, 0, 1, "ROUND")),
                                              (("output", 0), (1, 0, 8, 0, 1, "ROUND"))]}}
    """

    def __init__(self, quantnode_map):
        """Store the user-provided selection map; it is validated in apply()."""
        super().__init__()
        self.quantnode_map = quantnode_map

    def apply(self, model):
        """Insert the requested Quant nodes and return ``(model, graph_modified)``.

        ``graph_modified`` is always False in the returned tuple.

        Raises:
            TypeError: if ``quantnode_map`` is not a dictionary.
            ValueError: if ``quantnode_map`` has a key other than "name" or "op_type".
        """
        quantnode_map = self.quantnode_map

        # Validate before touching the model so that a bad map fails fast.
        # isinstance (rather than an exact type comparison) also accepts dict
        # subclasses such as OrderedDict.
        if not isinstance(quantnode_map, dict):
            raise TypeError("Input must be a dictionary.")
        if not set(quantnode_map) <= {"name", "op_type"}:
            raise ValueError("Unsupported selection type")

        model = model.transform(InferShapes())

        # Shared across both passes so that op_type entries skip positions that
        # were already quantized through a "name" entry.
        quantized_nodes = []

        # Pass 1: nodes selected individually by name.
        for node_name, input_positions in quantnode_map.get("name", {}).items():
            model = adjust_graph(model, input_positions, node_name, quantized_nodes)

        # Pass 2: every node of each requested op_type. The node list is taken
        # once per op_type, before any Quant node is inserted for it.
        for op, input_positions in quantnode_map.get("op_type", {}).items():
            for node in model.get_nodes_by_op_type(op):
                model = adjust_graph(model, input_positions, node.name, quantized_nodes)

        model = cleanup_model(model)

        graph_modified = False

        return (model, graph_modified)
0 commit comments