
Commit 12079fe

Move the transpose matmul pass to OSS and run it earlier in the flow
Differential Revision: D73600069
Pull Request resolved: #10433
1 parent f1ceb6c

2 files changed: +146 -1 lines

backends/cadence/aot/replace_ops.py

Lines changed: 100 additions & 1 deletion
@@ -32,7 +32,10 @@
     is_quantized_tensor,
     quantize_tensor_multiplier,
 )
-from executorch.backends.cadence.aot.fuse_ops import FuseCascadedViewOps
+from executorch.backends.cadence.aot.fuse_ops import (
+    FuseCascadedTransposeOrPermuteOps,
+    FuseCascadedViewOps,
+)
 from executorch.backends.cadence.aot.pass_utils import (
     CadencePassAttribute,
     register_cadence_pass,
@@ -2290,6 +2293,101 @@ def call_operator(
         )
 
 
+@register_cadence_pass(CadencePassAttribute(opt_level=0))
+class ReplaceMatmulWithTransposedMatmulPass(ExportPass):
+    """
+    For certain backends, we have efficient kernels for transposed matmul. We
+    replace AxB with AxB' for such backends.
+    """
+
+    def call_operator(self, op, args, kwargs, meta):
+        if op != exir_ops.edge.cadence.quantized_matmul.default or args[-1] is True:
+            return super().call_operator(op, args, kwargs, meta)
+
+        # Get the args
+        if len(args) == 9:
+            (
+                X_arg,
+                X_zero_point,
+                Y_arg,
+                Y_zero_point,
+                bias,
+                out_multiplier,
+                out_shift,
+                out_zero_point,
+                transposed,
+            ) = args
+        elif len(args) == 8:
+            (
+                X_arg,
+                X_zero_point,
+                Y_arg,
+                Y_zero_point,
+                bias,
+                out_multiplier,
+                out_shift,
+                out_zero_point,
+            ) = args
+            transposed = False
+        else:
+            raise AssertionError(
+                f"Unexpected number of args for quantized_matmul: {len(args)}"
+            )
+
+        # If the matmul is already transposed, bail
+        if transposed:
+            return super().call_operator(op, args, kwargs, meta)
+
+        # Get the second tensor
+        Y_tensor = Y_arg.to_tensor() if isinstance(Y_arg, ProxyValue) else Y_arg
+        # Concretize the bias
+        zero_bias = super().call_operator(
+            exir_ops.edge.aten.full.default,
+            ([Y_tensor.size(-1)], 0),
+            {"dtype": torch.int32},
+            meta,
+        )
+
+        # If the arg was a ProxyValue, insert a transpose node. Otherwise we
+        # can simply transpose the tensor in place.
+        if isinstance(Y_arg, ProxyValue):
+            transpose_args = (Y_arg, -1, -2)
+            transpose_node = super().call_operator(
+                exir_ops.edge.aten.transpose_copy.int,
+                transpose_args,
+                {},
+                meta,
+            )
+            Y_arg_t = transpose_node
+        else:
+            Y_arg_t = Y_tensor.transpose(-1, -2)
+
+        # Construct the new args, and return the transposed matmul op
+        new_args = (
+            X_arg,
+            X_zero_point,
+            Y_arg_t,
+            Y_zero_point,
+            zero_bias,
+            out_multiplier,
+            out_shift,
+            out_zero_point,
+            True,
+        )
+        return super().call_operator(op, new_args, kwargs, meta)
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        result = super().call(graph_module)
+        # Fuse any inserted transpose node with transpose/permute nodes
+        # surrounding it.
+        result = FuseCascadedTransposeOrPermuteOps()(result.graph_module)
+        assert result is not None
+        # Replace permute with transpose.
+        result = ReplacePermuteWithTransposePass()(result.graph_module)
+        assert result is not None
+        return result
+
+
 # This class encapsulates all the functions that replace/switch one op in the
 # graph with another.
 class CadenceReplaceOpsInGraph:
@@ -2317,6 +2415,7 @@ class CadenceReplaceOpsInGraph:
             # This pass should be after passes that replace conv -> im2row + linear.
             ReplaceIm2RowWithViewPass,
             MakeSliceAndCatDimOutermostPass,
+            ReplaceMatmulWithTransposedMatmulPass,
             ReplaceNopTransposeOrPermuteWithViewPass,
             ReplaceLinearWithFullyConnectedOpPass,
             ReplaceScalarTensorWithFullPass,
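
Why the rewrite is numerically safe: with transposed=True the kernel expects its second operand already transposed, so handing it Y.transpose(-1, -2) reproduces the original X @ Y. The snippet below is only a hedged sketch of that identity in plain PyTorch; it does not call the Cadence quantized_matmul kernel and is not part of the commit.

import torch

# Illustration only: stand-in for the semantics of quantized_matmul with
# transposed=True, not the Cadence kernel implementation.
X = torch.randn(2, 48, 48)
Y = torch.randn(2, 48, 48)

Y_t = Y.transpose(-1, -2)                   # transpose node inserted by the pass
out_transposed = X @ Y_t.transpose(-1, -2)  # a transposed-matmul kernel consumes Y_t

assert torch.allclose(X @ Y, out_transposed)
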

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 46 additions & 0 deletions
@@ -16,6 +16,7 @@
 from executorch.backends.cadence.aot.compiler import (
     export_to_edge,
     quantize_and_export_to_edge,
+    quantize_pt2,
 )
 from executorch.backends.cadence.aot.graph_builder import (
     GraphBuilder,
@@ -35,6 +36,7 @@
     ReplaceGeluWithApproximateGeluPass,
     ReplaceIm2RowWithViewPass,
     ReplaceLinearWithFullyConnectedOpPass,
+    ReplaceMatmulWithTransposedMatmulPass,
     ReplaceMMWithAddMMPass,
     ReplaceNopTransposeOrPermuteWithViewPass,
     ReplacePadWithCatPass,
@@ -85,6 +87,50 @@ def assertTargetCountsEqual(
         for target, expected_count in targets_and_counts:
             self.assertTargetCountEqual(graph_module, target, expected_count)
 
+    @parameterized.expand(
+        [
+            # Regular MM
+            [(64, 33), (33, 128)],
+            # Batched MM
+            [(2, 48, 48), (2, 48, 48)],
+        ]
+    )
+    @torch.no_grad()
+    def test_replace_matmul_with_transposed_matmul(
+        self,
+        x_shape: Tuple[int],
+        y_shape: Tuple[int],
+    ) -> None:
+        class MatMul(torch.nn.Module):
+            def __init__(self) -> None:
+                super(MatMul, self).__init__()
+
+            def forward(self, x, y):
+                return torch.matmul(x, y)
+
+        model = MatMul()
+        X = torch.randn(x_shape)
+        Y = torch.randn(y_shape)
+        p = ReplaceMatmulWithTransposedMatmulPass()
+        inputs = (X, Y)
+        quantized_model = quantize_pt2(model, inputs)
+        graph_module = (
+            export_to_edge(quantized_model, inputs).exported_program().graph_module
+        )
+        # pyre-fixme[16]: Optional type has no attribute `graph_module`
+        graph_after_passes = p(graph_module).graph_module
+
+        self.assertEqual(
+            count_node(graph_after_passes, exir_ops.edge.aten.transpose_copy.int),
+            1,
+        )
+        self.assertEqual(
+            count_node(
+                graph_after_passes, exir_ops.edge.cadence.quantized_matmul.default
+            ),
+            1,
+        )
+
     @parameterized.expand(
         [
             [(3, 5), (0, 0)],
