@@ -32,7 +32,10 @@
     is_quantized_tensor,
     quantize_tensor_multiplier,
 )
-from executorch.backends.cadence.aot.fuse_ops import FuseCascadedViewOps
+from executorch.backends.cadence.aot.fuse_ops import (
+    FuseCascadedTransposeOrPermuteOps,
+    FuseCascadedViewOps,
+)
 from executorch.backends.cadence.aot.pass_utils import (
     CadencePassAttribute,
     register_cadence_pass,
@@ -2290,6 +2293,101 @@ def call_operator(
         )


+@register_cadence_pass(CadencePassAttribute(opt_level=0))
+class ReplaceMatmulWithTransposedMatmulPass(ExportPass):
+    """
+    For certain backends, we have efficient kernels for transposed matmul. We
+    replace AxB with AxB' for such backends.
+    """
+
+    def call_operator(self, op, args, kwargs, meta):
+        if op != exir_ops.edge.cadence.quantized_matmul.default or args[-1] is True:
+            return super().call_operator(op, args, kwargs, meta)
+
+        # Get the args
+        if len(args) == 9:
+            (
+                X_arg,
+                X_zero_point,
+                Y_arg,
+                Y_zero_point,
+                bias,
+                out_multiplier,
+                out_shift,
+                out_zero_point,
+                transposed,
+            ) = args
+        elif len(args) == 8:
+            (
+                X_arg,
+                X_zero_point,
+                Y_arg,
+                Y_zero_point,
+                bias,
+                out_multiplier,
+                out_shift,
+                out_zero_point,
+            ) = args
+            transposed = False
+        else:
+            raise AssertionError(
+                f"Unexpected number of args for quantized_matmul: {len(args)}"
+            )
+
+        # If the matmul is already transposed, bail
+        if transposed:
+            return super().call_operator(op, args, kwargs, meta)
+
+        # Get the second tensor
+        Y_tensor = Y_arg.to_tensor() if isinstance(Y_arg, ProxyValue) else Y_arg
+        # Concretize the bias
+        zero_bias = super().call_operator(
+            exir_ops.edge.aten.full.default,
+            ([Y_tensor.size(-1)], 0),
+            {"dtype": torch.int32},
+            meta,
+        )
+
+        # If the arg was a ProxyValue, insert a transpose node. Otherwise we
+        # can simply transpose the tensor directly.
+        if isinstance(Y_arg, ProxyValue):
+            transpose_args = (Y_arg, -1, -2)
+            transpose_node = super().call_operator(
+                exir_ops.edge.aten.transpose_copy.int,
+                transpose_args,
+                {},
+                meta,
+            )
+            Y_arg_t = transpose_node
+        else:
+            Y_arg_t = Y_tensor.transpose(-1, -2)
+
+        # Construct the new args, and return the transposed matmul op
+        new_args = (
+            X_arg,
+            X_zero_point,
+            Y_arg_t,
+            Y_zero_point,
+            zero_bias,
+            out_multiplier,
+            out_shift,
+            out_zero_point,
+            True,
+        )
+        return super().call_operator(op, new_args, kwargs, meta)
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        result = super().call(graph_module)
+        # Fuse any inserted transpose node with transpose/permute nodes
+        # surrounding it.
+        result = FuseCascadedTransposeOrPermuteOps()(result.graph_module)
+        assert result is not None
+        # Replace permute with transpose.
+        result = ReplacePermuteWithTransposePass()(result.graph_module)
+        assert result is not None
+        return result
+
+
 # This class encapsulates all the functions that replace/switch one op in the
 # graph with another.
 class CadenceReplaceOpsInGraph:
@@ -2317,6 +2415,7 @@ class CadenceReplaceOpsInGraph:
         # This pass should be after passes that replace conv -> im2row + linear.
         ReplaceIm2RowWithViewPass,
         MakeSliceAndCatDimOutermostPass,
+        ReplaceMatmulWithTransposedMatmulPass,
         ReplaceNopTransposeOrPermuteWithViewPass,
         ReplaceLinearWithFullyConnectedOpPass,
         ReplaceScalarTensorWithFullPass,
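
For intuition about the rewrite itself: the pass turns quantized_matmul(A, B) into quantized_matmul(A, B', transposed=True), which is the same computation because a transposed matmul undoes the transpose of its second operand. A minimal sketch of that identity in plain PyTorch (the name transposed_matmul_ref is hypothetical and stands in for the Cadence kernel, which additionally applies zero points, out_multiplier, and out_shift):

import torch

def transposed_matmul_ref(A: torch.Tensor, B_t: torch.Tensor) -> torch.Tensor:
    # A "transposed matmul" takes the second operand pre-transposed and
    # undoes the transpose internally: A @ (B_t)' == A @ B.
    return A @ B_t.transpose(-1, -2)

A = torch.randn(4, 8)
B = torch.randn(8, 16)
# The rewrite AxB -> transposed_matmul(A, B') preserves the result,
# since (B')' == B.
assert torch.allclose(A @ B, transposed_matmul_ref(A, B.transpose(-1, -2)))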
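
The cleanup in call() matters because the freshly inserted transpose_copy node often lands next to an existing transpose or permute in the graph; two transposes over the same dimension pair cancel, so fusing cascaded ones can eliminate the inserted node entirely. A small sketch of the cancellation the fusion pass relies on:

import torch

Y = torch.randn(3, 5, 7)
# Back-to-back transposes of the same dims are the identity, which is
# what fusing cascaded transpose/permute ops exploits to remove the pair.
assert torch.equal(Y.transpose(-1, -2).transpose(-1, -2), Y)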