Skip to content

Commit 40615b3

Browse files
authored
[NNPA] Call ScrubDisposablePass before ZHighConstPropagation (#2583) (#2588)
Cherry pick of #2583 to previous release
1 parent ca5af9a commit 40615b3

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

src/Accelerators/NNPA/Compiler/NNPACompilerUtils.cpp

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ void addONNXToZHighPasses(
7979
pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
8080
// There are more opportunities for const propagation once all zhigh ops were
8181
// generated.
82+
pm.addNestedPass<func::FuncOp>(onnx_mlir::createDecomposeONNXToONNXPass());
8283
pm.addNestedPass<func::FuncOp>(onnx_mlir::createConstPropONNXToONNXPass());
8384
pm.addPass(mlir::createCanonicalizerPass());
8485
// Layout propagation at ZHighIR.
@@ -103,13 +104,6 @@ void addONNXToZHighPasses(
103104
if (nnpaEnableZHighToOnnx)
104105
pm.addNestedPass<func::FuncOp>(onnx_mlir::createZHighToONNXPass());
105106

106-
// Constant propagation at ZHighIR: constant stickify.
107-
// Only support BE machines.
108-
bool isBE = llvm::support::endian::system_endianness() ==
109-
llvm::support::endianness::big;
110-
if (isBE)
111-
pm.addNestedPass<func::FuncOp>(
112-
onnx_mlir::zhigh::createZHighConstPropagationPass());
113107
// One more call to ONNX shape inference/canonicalization/... to update shape
114108
// if possible.
115109
if (enableONNXHybridPass) {
@@ -121,6 +115,19 @@ void addONNXToZHighPasses(
121115
pm.addPass(mlir::createCanonicalizerPass());
122116
pm.addNestedPass<func::FuncOp>(onnx_mlir::createShapeInferencePass());
123117
}
118+
119+
// Replace every DisposableElementsAttr with DenseElementsAttr.
120+
// ZHighConstPropagation currently assumes that DenseElementsAttr is used.
121+
pm.addPass(createScrubDisposablePass());
122+
123+
// Constant propagation at ZHighIR: constant stickify.
124+
// Only support BE machines.
125+
bool isBE = llvm::support::endian::system_endianness() ==
126+
llvm::support::endianness::big;
127+
if (isBE)
128+
pm.addNestedPass<func::FuncOp>(
129+
onnx_mlir::zhigh::createZHighConstPropagationPass());
130+
124131
// Remove common sub-expressions.
125132
pm.addPass(mlir::createCSEPass());
126133

0 commit comments

Comments
 (0)