Skip to content

Commit 811b046

Browse files
[TOSA] Add TosaLayerwiseConstantFoldPass and TosaReduceTransposes passes (llvm#4165)
Add the following passes to TorchBackendToTosaBackendPipeline: - TosaLayerwiseConstantFoldPass: fold full-layer operations on TOSA consts - TosaReduceTransposes: remove unnecessary TOSA transposes to reduce data movements Signed-off-by: Justin Ngo <justin.ngo@arm.com>
1 parent b66526a commit 811b046

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

lib/Dialect/TorchConversion/Transforms/Passes.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,12 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline(
117117
const TorchConversion::TosaBackendPipelineOptions &options) {
118118
pm.addNestedPass<func::FuncOp>(
119119
createConvertTorchToTosaPass(options.requireFullTosaConversion));
120+
// Fold full-layer operations on TOSA constants
121+
pm.addNestedPass<func::FuncOp>(createTosaLayerwiseConstantFoldPass());
122+
123+
// Perform transpose reductions for avoidable data movements
124+
pm.addNestedPass<func::FuncOp>(createTosaReduceTransposes());
125+
120126
// Perform rank broadcasting so TosaToLinalg pass works
121127
pm.addNestedPass<func::FuncOp>(createTosaMakeBroadcastablePass());
122128

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1731,6 +1731,15 @@
17311731
"HBC_basic",
17321732
# 1D inputs cause generated tosa.negate ops to crash downstream
17331733
"NllLossModule_1D_basic",
1734+
# BertModule is not crashing, but is timing out due to TosaLayerwiseConstantFoldPass:
1735+
# Exception ignored on calling ctypes callback function: <function RefBackendInvoker.__init__.<locals>.consume_return_funcs at 0x765783f12c20>
1736+
# Traceback (most recent call last):
1737+
# File "torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py", line 101, in consume_return_funcs
1738+
# def consume_return_funcs(*args):
1739+
# File "torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/framework.py", line 316, in handle_timeout
1740+
# raise TimeoutError(self.error_message)
1741+
# TimeoutError: Timeout
1742+
"BertModule_basic",
17341743
}
17351744

17361745
# Write the TOSA set as a "passing" set as it is very early in development

0 commit comments

Comments
 (0)