|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +# pyre-strict |
| 8 | + |
| 9 | +from typing import Optional, Tuple |
| 10 | + |
| 11 | +import executorch.backends.vulkan.utils as utils |
| 12 | +import torch |
| 13 | + |
| 14 | +import torch.nn.functional as F |
| 15 | + |
| 16 | +from executorch.backends.transforms.utils import get_param_tensor, is_param_node |
| 17 | +from executorch.exir import ExportedProgram |
| 18 | +from executorch.exir.dialects._ops import ops as exir_ops |
| 19 | +from executorch.exir.pass_base import ExportPass, PassResult |
| 20 | + |
| 21 | +################# |
| 22 | +## linear_qcnw ## |
| 23 | +################# |
| 24 | + |
| 25 | + |
| 26 | +def matches_linear_qcnw_pattern( # noqa: C901 |
| 27 | + program: ExportedProgram, node: torch.fx.Node |
| 28 | +) -> Optional[Tuple[torch.qscheme, int]]: |
| 29 | + """ |
| 30 | + Checks if the nodes surrounding a linear node matches the pattern for weight only |
| 31 | + quantized linear, where the weight is quantized channelswise to n bits. |
| 32 | +
|
| 33 | + If the graph pattern matches, then return a tuple of (quantization_method, nbits) |
| 34 | + describing the type of quantization used for the weights. Otherwise, return None. |
| 35 | + """ |
| 36 | + if not utils.is_linear_node(node): |
| 37 | + return None |
| 38 | + |
| 39 | + input_node = node.args[0] |
| 40 | + weight_node = node.args[1] |
| 41 | + |
| 42 | + # Type checking |
| 43 | + if not isinstance(weight_node, torch.fx.Node): |
| 44 | + return None |
| 45 | + if not isinstance(input_node, torch.fx.Node): |
| 46 | + return None |
| 47 | + |
| 48 | + # The input arg should not be a dequant node; if it is, then it is indicative that |
| 49 | + # dynamically quantized linear should be used instead |
| 50 | + if utils.is_dequant_node(input_node): |
| 51 | + return None |
| 52 | + |
| 53 | + # The weight arg should be a dequant node dequantizing the quantized weight |
| 54 | + # Furthermore, the op expects per channel quantization of the weight |
| 55 | + if not utils.is_dequant_per_channel_node(weight_node): |
| 56 | + return None |
| 57 | + |
| 58 | + orig_weight = weight_node.args[0] |
| 59 | + zeros = weight_node.args[2] |
| 60 | + |
| 61 | + # Type checking |
| 62 | + if not isinstance(orig_weight, torch.fx.Node): |
| 63 | + return None |
| 64 | + if not is_param_node(program, orig_weight): |
| 65 | + return None |
| 66 | + if not isinstance(zeros, torch.fx.Node): |
| 67 | + return None |
| 68 | + if not is_param_node(program, zeros): |
| 69 | + return None |
| 70 | + |
| 71 | + zeros_tensor = get_param_tensor(program, zeros) |
| 72 | + if not isinstance(zeros_tensor, torch.Tensor): |
| 73 | + return None |
| 74 | + |
| 75 | + quant_method = torch.per_channel_affine |
| 76 | + # Check for symmetric quantization, where the zeros used for dequantization will |
| 77 | + # actually be all zeros. |
| 78 | + if torch.all(zeros_tensor == 0): |
| 79 | + quant_method = torch.per_channel_symmetric |
| 80 | + |
| 81 | + orig_weight_tensor = get_param_tensor(program, orig_weight) |
| 82 | + if not isinstance(orig_weight_tensor, torch.Tensor): |
| 83 | + return None |
| 84 | + # Sanity check the dtype of the quantized weight |
| 85 | + if orig_weight_tensor.dtype != torch.int8: |
| 86 | + return None |
| 87 | + |
| 88 | + quant_min = orig_weight_tensor.min().item() |
| 89 | + quant_max = orig_weight_tensor.max().item() |
| 90 | + # Determine the number of bits the weight has been quantized to |
| 91 | + if quant_min >= -8 and quant_max <= 7: |
| 92 | + return quant_method, 4 |
| 93 | + elif quant_min >= -128 and quant_max <= 127: |
| 94 | + return quant_method, 8 |
| 95 | + |
| 96 | + return None |
| 97 | + |
| 98 | + |
| 99 | +def pack_4bit_weight_tensor(inp: torch.Tensor) -> torch.Tensor: |
| 100 | + """ |
| 101 | + Given a 8-bit weight tensor containing values quantized to 4 bits, create a packed |
| 102 | + weight tensor by packing 2 4-bit values in one unsigned 8-bit value. |
| 103 | +
|
| 104 | + An input weight tensor of shape (M, K) will produce a packed weight tensor of shape |
| 105 | + (M, K / 2). |
| 106 | + """ |
| 107 | + |
| 108 | + # Assert we got a properly quantized tensor. |
| 109 | + min, max = inp.min().item(), inp.max().item() |
| 110 | + assert ( |
| 111 | + max <= 7 and min >= -8 |
| 112 | + ), f"convert_to_qc4w: [min,max] out of [-8, 7] range, got [{min}, {max}]" |
| 113 | + |
| 114 | + # Assuming we have a 2d tensor |
| 115 | + if inp.ndim != 2: |
| 116 | + inp = inp.squeeze() |
| 117 | + assert ( |
| 118 | + inp.ndim == 2 |
| 119 | + ), f"convert_to_qc4w: expecting input tensor to be 2d, got {inp.ndim}" |
| 120 | + |
| 121 | + # pad ic |
| 122 | + if inp.shape[-1] % 2 != 0: |
| 123 | + inp = F.pad(input=inp, pad=(0, 1, 0, 0), mode="constant", value=0) |
| 124 | + |
| 125 | + # Shape after padding |
| 126 | + oc, ic = inp.shape |
| 127 | + assert ic % 2 == 0, "convert_to_qc4w: expecting ic to be even" |
| 128 | + |
| 129 | + # Adjust inp tensor for zp |
| 130 | + inp = inp.to(dtype=torch.uint8) + 8 |
| 131 | + |
| 132 | + # Prepare the Result tensor |
| 133 | + inp = inp.contiguous().view(-1) |
| 134 | + return (inp[::2] << 4 | inp[1::2]).view(oc, int(ic / 2)) |
| 135 | + |
| 136 | + |
| 137 | +def fuse_into_linear_qcnw_node( |
| 138 | + program: ExportedProgram, |
| 139 | + graph_module: torch.fx.GraphModule, |
| 140 | + linear_node: torch.fx.Node, |
| 141 | + quant_method: torch.qscheme, |
| 142 | + nbits: int, |
| 143 | +) -> None: |
| 144 | + """ |
| 145 | + The weight_int8pack_mm operator represents a weight only quantized linear operator, |
| 146 | + where the weight tensor has been quantized channelswise to nbits bits. |
| 147 | +
|
| 148 | + After the PT2E quantization flow, the expected graph pattern is |
| 149 | +
|
| 150 | + dq_weight = dequantize(weight, scales) |
| 151 | + out = linear(activation, dq_weight, bias?) |
| 152 | +
|
| 153 | + The goal of this function is to condense that sequence into |
| 154 | +
|
| 155 | + out = quantized_linear(activation, dq_weight, scales) |
| 156 | + out = out + bias |
| 157 | + """ |
| 158 | + activation = linear_node.args[0] |
| 159 | + dq_weight_node = linear_node.args[1] |
| 160 | + assert isinstance(activation, torch.fx.Node) |
| 161 | + assert isinstance(dq_weight_node, torch.fx.Node) |
| 162 | + |
| 163 | + bias = None |
| 164 | + if len(linear_node.args) > 2: |
| 165 | + bias = linear_node.args[2] |
| 166 | + assert isinstance(bias, torch.fx.Node) |
| 167 | + |
| 168 | + orig_weight = dq_weight_node.args[0] |
| 169 | + scale = dq_weight_node.args[1] |
| 170 | + |
| 171 | + # For 4 bit quantization, pack the weight tensor |
| 172 | + if nbits == 4: |
| 173 | + assert isinstance(orig_weight, torch.fx.Node) |
| 174 | + orig_weight_tensor = get_param_tensor(program, orig_weight) |
| 175 | + assert isinstance(orig_weight_tensor, torch.Tensor) |
| 176 | + packed_weight_tensor = pack_4bit_weight_tensor(orig_weight_tensor) |
| 177 | + utils.update_program_state_dict( |
| 178 | + program, |
| 179 | + orig_weight.name, |
| 180 | + packed_weight_tensor, |
| 181 | + ) |
| 182 | + orig_weight.meta["val"] = orig_weight.meta["val"][:, ::2].to(torch.uint8) |
| 183 | + |
| 184 | + if nbits == 8 and quant_method == torch.per_channel_symmetric: |
| 185 | + op_target = exir_ops.edge.aten._weight_int8pack_mm.default |
| 186 | + elif nbits == 4 and quant_method == torch.per_channel_symmetric: |
| 187 | + op_target = exir_ops.edge.et_vk.linear_qcs4w.default |
| 188 | + else: |
| 189 | + raise NotImplementedError( |
| 190 | + "only 4 and 8 bits per channel symmetric quant supported for linear_qcnw" |
| 191 | + ) |
| 192 | + |
| 193 | + with graph_module.graph.inserting_before(linear_node): |
| 194 | + weight_int8pack_mm_node = graph_module.graph.create_node( |
| 195 | + "call_function", |
| 196 | + op_target, |
| 197 | + (activation, orig_weight, scale), |
| 198 | + ) |
| 199 | + if bias: |
| 200 | + add_node = graph_module.graph.create_node( |
| 201 | + "call_function", |
| 202 | + exir_ops.edge.aten.add.Tensor, |
| 203 | + (weight_int8pack_mm_node, bias), |
| 204 | + ) |
| 205 | + linear_node.replace_all_uses_with(add_node) |
| 206 | + else: |
| 207 | + linear_node.replace_all_uses_with(weight_int8pack_mm_node) |
| 208 | + graph_module.graph.erase_node(linear_node) |
| 209 | + graph_module.graph.erase_node(dq_weight_node) |
| 210 | + |
| 211 | + |
| 212 | +class FuseQuantizedOpsTransform(ExportPass): |
| 213 | + def __init__(self, exported_program: ExportedProgram) -> None: |
| 214 | + super().__init__() |
| 215 | + self.program = exported_program |
| 216 | + |
| 217 | + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: |
| 218 | + for node in graph_module.graph.nodes: |
| 219 | + qcnw_details = matches_linear_qcnw_pattern(self.program, node) |
| 220 | + if qcnw_details is not None: |
| 221 | + qcnw_method, qcnw_nbits = qcnw_details |
| 222 | + fuse_into_linear_qcnw_node( |
| 223 | + self.program, graph_module, node, qcnw_method, qcnw_nbits |
| 224 | + ) |
| 225 | + |
| 226 | + graph_module.recompile() |
| 227 | + graph_module = super().call(graph_module).graph_module |
| 228 | + |
| 229 | + return PassResult(graph_module, True) |
0 commit comments