
Commit 1afec04

[ET-VK] Introduce generic export pass for fusing Q/DQ nodes
Pull Request resolved: #10525

## Context

When quantizing models with the PT2E quantization flow, quantize/dequantize nodes are inserted into the graph. However, these quantize/dequantize nodes must be fused with operators such as `aten.linear.default` to produce nodes corresponding to quantized operators (e.g. `weight_int8pack_mm`) so that quantized operator implementations can be called at runtime.

Currently, op fusion is done by the `fuse_dequant_linear.py` pass; however, it only handles one specific fusion pattern, which produces the `weight_int8pack_mm` operator. As more quantized operators are supported in ET-VK via the PT2E quantization flow, a more generic fusion pass is needed that can handle a variety of fusion patterns.

## Changes

* Introduce the `FuseQuantizedOpsTransform()` pass. I elected to introduce a new pass under the `backends/vulkan/_passes` directory, as opposed to modifying the existing pass, because I anticipate the majority of the fusion patterns to be specific to ET-VK.
* Remove the existing `FuseDequantLinearPass()` and switch to using the `FuseQuantizedOpsTransform` pass in its place.
* Add a `test_vulkan_passes` Python test for export passes, and refactor the `test_vulkan_delegate` Python test to improve code organization.
* Introduce the `linear_qcsnw` nomenclature:
  * `q` - quantized
  * `c` - per-channel / channelwise
  * `s` - symmetric
  * `n` - number of bits (`qcs4w` for 4-bit quant, `qcs8w` for 8-bit quant)
  * `w` - weight quantized
* Add a custom op for `linear_qcs4w` (4-bit weight quantized linear) and give the quantized op fusion pass the ability to produce this op.
* Slightly rename/refactor the quantization config retrieval functions in the `VulkanQuantizer` to improve clarity and API flexibility.

ghstack-source-id: 281448174
@exported-using-ghexport

Differential Revision: [D73794042](https://our.internmc.facebook.com/intern/diff/D73794042/)
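For reference, a minimal sketch of applying the new pass by hand (an illustrative assumption, not code from this commit: in the actual flow the pass is wired into the Vulkan backend's pass pipeline, and `edge_program` below is a placeholder for an edge-dialect `ExportedProgram` built from a PT2E-quantized model):

```python
# Hypothetical usage sketch; `edge_program` is assumed to be an edge-dialect
# ExportedProgram produced from a PT2E-quantized model.
from executorch.backends.vulkan._passes import FuseQuantizedOpsTransform

fuse_pass = FuseQuantizedOpsTransform(edge_program)

# ExportPass subclasses are callable on a GraphModule and return a PassResult.
result = fuse_pass(edge_program.graph_module)
fused_gm = result.graph_module  # dequant + linear chains fused into
                                # _weight_int8pack_mm / linear_qcs4w nodes
```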
1 parent 280db15 · commit 1afec04

14 files changed (+712, −202 lines)

backends/transforms/fuse_dequant_linear.py (−77 lines)

This file was deleted.

backends/transforms/targets.bzl (−15 lines)

@@ -77,21 +77,6 @@ def define_common_targets():
         ],
     )
 
-    runtime.python_library(
-        name = "fuse_dequant_linear",
-        srcs = ["fuse_dequant_linear.py"],
-        visibility = [
-            "//executorch/backends/...",
-        ],
-        deps = [
-            ":utils",
-            "//caffe2:torch",
-            "//executorch/exir:pass_base",
-            "//executorch/exir:sym_util",
-            "//executorch/exir/dialects:lib",
-        ],
-    )
-
     runtime.python_library(
         name = "view_copy_to_squeeze_unsqueeze",
         srcs = ["view_copy_to_squeeze_unsqueeze.py"],

backends/vulkan/_passes/TARGETS (+19 lines)

@@ -3,6 +3,23 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
 
 oncall("executorch")
 
+runtime.python_library(
+    name = "fuse_quantized_ops",
+    srcs = ["fuse_quantized_ops.py"],
+    visibility = [
+        "//executorch/backends/...",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/backends/transforms:utils",
+        "//executorch/backends/vulkan:custom_ops_lib",
+        "//executorch/backends/vulkan:utils_lib",
+        "//executorch/exir:pass_base",
+        "//executorch/exir:sym_util",
+        "//executorch/exir/dialects:lib",
+    ],
+)
+
 runtime.python_library(
     name = "insert_prepack_nodes",
     srcs = ["insert_prepack_nodes.py"],
@@ -13,6 +30,7 @@ runtime.python_library(
         "//caffe2:torch",
         "//executorch/exir:pass_base",
         "//executorch/backends/vulkan:utils_lib",
+        "//executorch/backends/vulkan:op_registry",
     ],
 )
 
@@ -110,6 +128,7 @@
         "//executorch/examples/...",
     ],
     deps = [
+        ":fuse_quantized_ops",
        ":insert_prepack_nodes",
        ":int4_weight_only_quantizer",
        ":remove_asserts",

backends/vulkan/_passes/__init__.py (+4 lines)

@@ -6,6 +6,9 @@
 
 # pyre-strict
 
+from executorch.backends.vulkan._passes.fuse_quantized_ops import (
+    FuseQuantizedOpsTransform,
+)
 from executorch.backends.vulkan._passes.insert_prepack_nodes import insert_prepack_nodes
 from executorch.backends.vulkan._passes.int4_weight_only_quantizer import (
     VkInt4WeightOnlyQuantizer,
@@ -26,6 +29,7 @@
 from executorch.backends.vulkan._passes.tag_memory_meta_pass import TagMemoryMetaPass
 
 __all__ = [
+    "FuseQuantizedOpsTransform",
     "insert_prepack_nodes",
     "VkInt4WeightOnlyQuantizer",
     "remove_asserts",
backends/vulkan/_passes/fuse_quantized_ops.py (new file, +229 lines)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Optional, Tuple

import executorch.backends.vulkan.utils as utils
import torch

import torch.nn.functional as F

from executorch.backends.transforms.utils import get_param_tensor, is_param_node
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

#################
## linear_qcnw ##
#################


def matches_linear_qcnw_pattern(  # noqa: C901
    program: ExportedProgram, node: torch.fx.Node
) -> Optional[Tuple[torch.qscheme, int]]:
    """
    Checks if the nodes surrounding a linear node match the pattern for weight-only
    quantized linear, where the weight is quantized channelwise to n bits.

    If the graph pattern matches, then return a tuple of (quantization_method, nbits)
    describing the type of quantization used for the weights. Otherwise, return None.
    """
    if not utils.is_linear_node(node):
        return None

    input_node = node.args[0]
    weight_node = node.args[1]

    # Type checking
    if not isinstance(weight_node, torch.fx.Node):
        return None
    if not isinstance(input_node, torch.fx.Node):
        return None

    # The input arg should not be a dequant node; if it is, then it is indicative that
    # dynamically quantized linear should be used instead
    if utils.is_dequant_node(input_node):
        return None

    # The weight arg should be a dequant node dequantizing the quantized weight
    # Furthermore, the op expects per channel quantization of the weight
    if not utils.is_dequant_per_channel_node(weight_node):
        return None

    orig_weight = weight_node.args[0]
    zeros = weight_node.args[2]

    # Type checking
    if not isinstance(orig_weight, torch.fx.Node):
        return None
    if not is_param_node(program, orig_weight):
        return None
    if not isinstance(zeros, torch.fx.Node):
        return None
    if not is_param_node(program, zeros):
        return None

    zeros_tensor = get_param_tensor(program, zeros)
    if not isinstance(zeros_tensor, torch.Tensor):
        return None

    quant_method = torch.per_channel_affine
    # Check for symmetric quantization, where the zeros used for dequantization will
    # actually be all zeros.
    if torch.all(zeros_tensor == 0):
        quant_method = torch.per_channel_symmetric

    orig_weight_tensor = get_param_tensor(program, orig_weight)
    if not isinstance(orig_weight_tensor, torch.Tensor):
        return None
    # Sanity check the dtype of the quantized weight
    if orig_weight_tensor.dtype != torch.int8:
        return None

    quant_min = orig_weight_tensor.min().item()
    quant_max = orig_weight_tensor.max().item()
    # Determine the number of bits the weight has been quantized to
    if quant_min >= -8 and quant_max <= 7:
        return quant_method, 4
    elif quant_min >= -128 and quant_max <= 127:
        return quant_method, 8

    return None


def pack_4bit_weight_tensor(inp: torch.Tensor) -> torch.Tensor:
    """
    Given an 8-bit weight tensor containing values quantized to 4 bits, create a packed
    weight tensor by packing 2 4-bit values in one unsigned 8-bit value.

    An input weight tensor of shape (M, K) will produce a packed weight tensor of shape
    (M, K / 2).
    """

    # Assert we got a properly quantized tensor.
    min, max = inp.min().item(), inp.max().item()
    assert (
        max <= 7 and min >= -8
    ), f"convert_to_qc4w: [min,max] out of [-8, 7] range, got [{min}, {max}]"

    # Assuming we have a 2d tensor
    if inp.ndim != 2:
        inp = inp.squeeze()
    assert (
        inp.ndim == 2
    ), f"convert_to_qc4w: expecting input tensor to be 2d, got {inp.ndim}"

    # pad ic
    if inp.shape[-1] % 2 != 0:
        inp = F.pad(input=inp, pad=(0, 1, 0, 0), mode="constant", value=0)

    # Shape after padding
    oc, ic = inp.shape
    assert ic % 2 == 0, "convert_to_qc4w: expecting ic to be even"

    # Adjust inp tensor for zp
    inp = inp.to(dtype=torch.uint8) + 8

    # Prepare the Result tensor
    inp = inp.contiguous().view(-1)
    return (inp[::2] << 4 | inp[1::2]).view(oc, int(ic / 2))


def fuse_into_linear_qcnw_node(
    program: ExportedProgram,
    graph_module: torch.fx.GraphModule,
    linear_node: torch.fx.Node,
    quant_method: torch.qscheme,
    nbits: int,
) -> None:
    """
    The weight_int8pack_mm operator represents a weight-only quantized linear operator,
    where the weight tensor has been quantized channelwise to nbits bits.

    After the PT2E quantization flow, the expected graph pattern is

        dq_weight = dequantize(weight, scales)
        out = linear(activation, dq_weight, bias?)

    The goal of this function is to condense that sequence into

        out = quantized_linear(activation, weight, scales)
        out = out + bias
    """
    activation = linear_node.args[0]
    dq_weight_node = linear_node.args[1]
    assert isinstance(activation, torch.fx.Node)
    assert isinstance(dq_weight_node, torch.fx.Node)

    bias = None
    if len(linear_node.args) > 2:
        bias = linear_node.args[2]
        assert isinstance(bias, torch.fx.Node)

    orig_weight = dq_weight_node.args[0]
    scale = dq_weight_node.args[1]

    # For 4 bit quantization, pack the weight tensor
    if nbits == 4:
        assert isinstance(orig_weight, torch.fx.Node)
        orig_weight_tensor = get_param_tensor(program, orig_weight)
        assert isinstance(orig_weight_tensor, torch.Tensor)
        packed_weight_tensor = pack_4bit_weight_tensor(orig_weight_tensor)
        utils.update_program_state_dict(
            program,
            orig_weight.name,
            packed_weight_tensor,
        )
        orig_weight.meta["val"] = orig_weight.meta["val"][:, ::2].to(torch.uint8)

    if nbits == 8 and quant_method == torch.per_channel_symmetric:
        op_target = exir_ops.edge.aten._weight_int8pack_mm.default
    elif nbits == 4 and quant_method == torch.per_channel_symmetric:
        op_target = exir_ops.edge.et_vk.linear_qcs4w.default
    else:
        raise NotImplementedError(
            "only 4 and 8 bits per channel symmetric quant supported for linear_qcnw"
        )

    with graph_module.graph.inserting_before(linear_node):
        weight_int8pack_mm_node = graph_module.graph.create_node(
            "call_function",
            op_target,
            (activation, orig_weight, scale),
        )
        if bias:
            add_node = graph_module.graph.create_node(
                "call_function",
                exir_ops.edge.aten.add.Tensor,
                (weight_int8pack_mm_node, bias),
            )
            linear_node.replace_all_uses_with(add_node)
        else:
            linear_node.replace_all_uses_with(weight_int8pack_mm_node)
        graph_module.graph.erase_node(linear_node)
        graph_module.graph.erase_node(dq_weight_node)


class FuseQuantizedOpsTransform(ExportPass):
    def __init__(self, exported_program: ExportedProgram) -> None:
        super().__init__()
        self.program = exported_program

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        for node in graph_module.graph.nodes:
            qcnw_details = matches_linear_qcnw_pattern(self.program, node)
            if qcnw_details is not None:
                qcnw_method, qcnw_nbits = qcnw_details
                fuse_into_linear_qcnw_node(
                    self.program, graph_module, node, qcnw_method, qcnw_nbits
                )

        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)
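As a quick illustration of the 4-bit packing scheme implemented by `pack_4bit_weight_tensor` above, here is a small hypothetical check (not part of the commit): each value is offset by +8 into the unsigned range, and pairs of adjacent columns are packed into one byte, high nibble first.

```python
import torch

from executorch.backends.vulkan._passes.fuse_quantized_ops import (
    pack_4bit_weight_tensor,
)

# A 1x4 int8 weight whose values already fit in the 4-bit range [-8, 7].
w = torch.tensor([[-8, 7, 0, 3]], dtype=torch.int8)

packed = pack_4bit_weight_tensor(w)  # shape (1, 2), dtype uint8

# Each output byte holds (value + 8) for two adjacent columns, high nibble first:
# (-8 + 8) = 0x0, (7 + 8) = 0xF -> 0x0F;  (0 + 8) = 0x8, (3 + 8) = 0xB -> 0x8B
assert packed.dtype == torch.uint8
assert packed.tolist() == [[0x0F, 0x8B]]
```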
