
Commit 0282540

[ET-VK] Introduce generic export pass for fusing Q/DQ nodes
## Context

When quantizing models with the PT2E quantization flow, quantize/dequantize nodes are inserted into the graph. These quantize/dequantize nodes must be fused with operators such as `aten.linear.default` to produce nodes corresponding to quantized operators (e.g. `weight_int8pack_mm`), so that quantized operator implementations are called at runtime.

Currently, this op fusion is done by the `fuse_dequant_linear.py` pass; however, it handles only one specific fusion pattern, which produces a `weight_int8pack_mm` operator. As more quantized operators are supported in ET-VK via the PT2E quantization flow, a more generic fusion pass is needed that can handle a variety of fusion patterns.

## Changes

* Introduce the `FuseQuantizedOpsTransform()` pass. I elected to add a new pass under the `backends/vulkan/_passes` directory, rather than modify the existing pass, because I anticipate that the majority of the fusion patterns will be specific to ET-VK.
* Remove the existing `FuseDequantLinearPass()`.
* Switch to the `FuseQuantizedOpsTransform` pass instead of the old `FuseDequantLinear` pass.
* Add the `test_vulkan_passes` Python test for export passes.
* Apply some small refactors to the `test_vulkan_delegate` Python test to improve code organization.

Differential Revision: [D73794042](https://our.internmc.facebook.com/intern/diff/D73794042/)

ghstack-source-id: 280746102
Pull Request resolved: #10525
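A minimal sketch (not part of this commit) of how the new pass could be applied to an edge-dialect graph module; the helper name `fuse_quantized_ops` is mine, and the surrounding export/quantization steps are assumed to have already produced the dequantize → linear patterns described above.

```python
from executorch.backends.vulkan._passes import FuseQuantizedOpsTransform
from torch.fx import GraphModule


def fuse_quantized_ops(graph_module: GraphModule) -> GraphModule:
    # ExportPass instances are callable and return a PassResult wrapping
    # the transformed graph module.
    result = FuseQuantizedOpsTransform()(graph_module)
    return result.graph_module
```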
1 parent df75088 commit 0282540

11 files changed: +452 −168 lines

backends/transforms/fuse_dequant_linear.py

-77
This file was deleted.

backends/transforms/targets.bzl

-15

@@ -77,21 +77,6 @@ def define_common_targets():
         ],
     )
 
-    runtime.python_library(
-        name = "fuse_dequant_linear",
-        srcs = ["fuse_dequant_linear.py"],
-        visibility = [
-            "//executorch/backends/...",
-        ],
-        deps = [
-            ":utils",
-            "//caffe2:torch",
-            "//executorch/exir:pass_base",
-            "//executorch/exir:sym_util",
-            "//executorch/exir/dialects:lib",
-        ],
-    )
-
     runtime.python_library(
         name = "view_copy_to_squeeze_unsqueeze",
         srcs = ["view_copy_to_squeeze_unsqueeze.py"],

backends/vulkan/_passes/TARGETS

+17

@@ -3,6 +3,21 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
 
 oncall("executorch")
 
+runtime.python_library(
+    name = "fuse_quantized_ops",
+    srcs = ["fuse_quantized_ops.py"],
+    visibility = [
+        "//executorch/backends/...",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/backends/vulkan:utils_lib",
+        "//executorch/exir:pass_base",
+        "//executorch/exir:sym_util",
+        "//executorch/exir/dialects:lib",
+    ],
+)
+
 runtime.python_library(
     name = "insert_prepack_nodes",
     srcs = ["insert_prepack_nodes.py"],
@@ -13,6 +28,7 @@ runtime.python_library(
         "//caffe2:torch",
         "//executorch/exir:pass_base",
         "//executorch/backends/vulkan:utils_lib",
+        "//executorch/backends/vulkan:op_registry",
     ],
 )
 
@@ -110,6 +126,7 @@ runtime.python_library(
         "//executorch/examples/...",
     ],
     deps = [
+        ":fuse_quantized_ops",
        ":insert_prepack_nodes",
        ":int4_weight_only_quantizer",
        ":remove_asserts",

backends/vulkan/_passes/__init__.py

+4

@@ -6,6 +6,9 @@
 
 # pyre-strict
 
+from executorch.backends.vulkan._passes.fuse_quantized_ops import (
+    FuseQuantizedOpsTransform,
+)
 from executorch.backends.vulkan._passes.insert_prepack_nodes import insert_prepack_nodes
 from executorch.backends.vulkan._passes.int4_weight_only_quantizer import (
     VkInt4WeightOnlyQuantizer,
@@ -26,6 +29,7 @@
 from executorch.backends.vulkan._passes.tag_memory_meta_pass import TagMemoryMetaPass
 
 __all__ = [
+    "FuseQuantizedOpsTransform",
     "insert_prepack_nodes",
     "VkInt4WeightOnlyQuantizer",
     "remove_asserts",

backends/vulkan/_passes/fuse_quantized_ops.py

+110

@@ -0,0 +1,110 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+import executorch.backends.vulkan.utils as utils
+import torch
+
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+
+#############################
+## aten.weight_int8pack_mm ##
+#############################
+
+
+def matches_int8pack_mm_pattern(node: torch.fx.Node) -> bool:
+    if not utils.is_linear_node(node):
+        return False
+
+    input_node = node.args[0]
+    weight_node = node.args[1]
+
+    # Type checking
+    if not isinstance(weight_node, torch.fx.Node):
+        return False
+    if not isinstance(input_node, torch.fx.Node):
+        return False
+
+    # The weight arg should be a dequant node dequantizing the quantized weight
+    # Furthermore, the op expects per channel quantization of the weight
+    if not utils.is_dequant_per_channel_node(weight_node):
+        return False
+
+    orig_weight = weight_node.args[0]
+    if not isinstance(orig_weight, torch.fx.Node):
+        return False
+
+    # The quantized weight data should be an int8 tensor
+    if orig_weight.meta["val"].dtype != torch.int8:
+        return False
+
+    # The input arg should not be a dequant node
+    if utils.is_dequant_node(input_node):
+        return False
+
+    return True
+
+
+def fuse_into_weight_int8pack_mm_node(
+    graph_module: torch.fx.GraphModule,
+    linear_node: torch.fx.Node,
+) -> None:
+    """
+    The weight_int8pack_mm operator represents a weight-only quantized linear operator.
+    After the PT2E quantization flow, the expected graph pattern is
+
+        dq_weight = dequantize(weight, scales)
+        out = linear(activation, dq_weight, bias?)
+
+    The goal of this function is to condense that sequence into
+
+        out = weight_int8pack_mm(activation, weight, scales)
+        out = out + bias
+    """
+    activation = linear_node.args[0]
+    dq_weight_node = linear_node.args[1]
+    assert isinstance(activation, torch.fx.Node)
+    assert isinstance(dq_weight_node, torch.fx.Node)
+
+    bias = None
+    if len(linear_node.args) > 2:
+        bias = linear_node.args[2]
+        assert isinstance(bias, torch.fx.Node)
+
+    orig_weight = dq_weight_node.args[0]
+    scale = dq_weight_node.args[1]
+
+    with graph_module.graph.inserting_before(linear_node):
+        weight_int8pack_mm_node = graph_module.graph.create_node(
+            "call_function",
+            exir_ops.edge.aten._weight_int8pack_mm.default,
+            (activation, orig_weight, scale),
+        )
+        if bias:
+            add_node = graph_module.graph.create_node(
+                "call_function",
+                exir_ops.edge.aten.add.Tensor,
+                (weight_int8pack_mm_node, bias),
+            )
+            linear_node.replace_all_uses_with(add_node)
+        else:
+            linear_node.replace_all_uses_with(weight_int8pack_mm_node)
+        graph_module.graph.erase_node(linear_node)
+        graph_module.graph.erase_node(dq_weight_node)
+
+
+class FuseQuantizedOpsTransform(ExportPass):
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        for node in graph_module.graph.nodes:
+            if matches_int8pack_mm_pattern(node):
+                fuse_into_weight_int8pack_mm_node(graph_module, node)
+
+        graph_module.recompile()
+        graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, True)
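
As a quick sanity check on the rewrite above, one can count the fused nodes in a transformed graph. The helper below is a hypothetical illustration, not part of this commit; it relies only on the `exir_ops` target that the pass itself creates.

```python
import torch

from executorch.exir.dialects._ops import ops as exir_ops


def count_weight_int8pack_mm_nodes(gm: torch.fx.GraphModule) -> int:
    # Count call_function nodes whose target is the fused quantized linear op
    # created by fuse_into_weight_int8pack_mm_node.
    target = exir_ops.edge.aten._weight_int8pack_mm.default
    return sum(
        1
        for node in gm.graph.nodes
        if node.op == "call_function" and node.target == target
    )
```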

backends/vulkan/targets.bzl

+1 −1

@@ -280,6 +280,7 @@ def define_common_targets(is_fbcode = False):
         deps = [
             "//caffe2:torch",
             "//executorch/exir:tensor",
+            "//executorch/exir/backend/canonical_partitioners:config_partitioner_lib",
             "//executorch/backends/vulkan/serialization:lib",
         ]
     )
@@ -332,7 +333,6 @@ def define_common_targets(is_fbcode = False):
             "//executorch/backends/transforms:addmm_mm_to_linear",
             "//executorch/backends/transforms:fuse_batch_norm_with_conv",
             "//executorch/backends/transforms:fuse_conv_with_clamp",
-            "//executorch/backends/transforms:fuse_dequant_linear",
             "//executorch/backends/transforms:fuse_view_copy",
             "//executorch/backends/transforms:remove_clone_ops",
             "//executorch/backends/transforms:view_copy_to_squeeze_unsqueeze",

backends/vulkan/test/TARGETS

+13

@@ -24,6 +24,19 @@ python_unittest(
     ],
 )
 
+python_unittest(
+    name = "test_vulkan_passes",
+    srcs = [
+        "test_vulkan_passes.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/backends/vulkan/_passes:vulkan_passes",
+        "//executorch/backends/vulkan/quantizer:vulkan_quantizer",
+        "//executorch/backends/vulkan:vulkan_preprocess",
+    ]
+)
+
 python_unittest(
     name = "test_vulkan_delegate_header",
     srcs = [

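The body of `test_vulkan_passes.py` is not included in this diff. For orientation, the sketch below shows the kind of PT2E setup such a test needs in order to produce the dequantize → `aten.linear` pattern that `FuseQuantizedOpsTransform` targets; the model, shapes, and the `quantizer` argument are assumptions, not code from this commit.

```python
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer import Quantizer
from torch.export import export_for_training


class SimpleLinear(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(64, 32)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


def quantize_with_pt2e(quantizer: Quantizer) -> torch.fx.GraphModule:
    # Capture the pre-autograd graph that PT2E quantization operates on.
    model = SimpleLinear().eval()
    sample_inputs = (torch.randn(1, 64),)
    captured = export_for_training(model, sample_inputs).module()

    prepared = prepare_pt2e(captured, quantizer)
    prepared(*sample_inputs)  # calibration
    # convert_pt2e inserts the quantize/dequantize nodes that the fusion
    # pass later matches around aten.linear.
    return convert_pt2e(prepared)
```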