Skip to content

Commit d730336

Browse files
author
Mahesh Ravishankar
committed
[mlir][Linalg] NFC: Combine elementwise fusion test passes.
There are a few different test passes that check elementwise fusion in Linalg. Consolidate them to a single pass controlled by different pass options (in keeping with how `TestLinalgTransforms` exists).
1 parent bf02586 commit d730336

File tree

5 files changed

+65
-79
lines changed

5 files changed

+65
-79
lines changed

mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns -split-input-file | FileCheck %s
1+
// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-generic-ops -split-input-file | FileCheck %s
22

33
#map0 = affine_map<(d0, d1) -> (d0, d1)>
44
#binary2Dpointwise = {

mlir/test/Dialect/Linalg/fusion-push-reshape.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s -test-linalg-push-reshape -split-input-file | FileCheck %s
1+
// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=push-expanding-reshape -split-input-file | FileCheck %s
22

33
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
44
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d1)>

mlir/test/Dialect/Linalg/reshape_control_fusion.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt -test-linalg-control-fusion-by-expansion %s -split-input-file | FileCheck %s
1+
// RUN: mlir-opt -test-linalg-elementwise-fusion-patterns=control-fusion-by-expansion %s -split-input-file | FileCheck %s
22

33
func @control_producer_reshape_fusion(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?xf32>) -> tensor<?x?xf32> {
44
%c0 = arith.constant 0 : index

mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp

Lines changed: 62 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -58,87 +58,77 @@ struct TestLinalgElementwiseFusion
5858
return "Test Linalg element wise operation fusion patterns";
5959
}
6060

61-
void runOnOperation() override {
62-
MLIRContext *context = &this->getContext();
63-
FuncOp funcOp = this->getOperation();
64-
RewritePatternSet fusionPatterns(context);
65-
66-
linalg::populateElementwiseOpsFusionPatterns(
67-
fusionPatterns,
68-
linalg::LinalgElementwiseFusionOptions()
69-
.setControlElementwiseOpsFusionFn(setFusedOpOperandLimit<4>));
70-
71-
(void)applyPatternsAndFoldGreedily(funcOp.getBody(),
72-
std::move(fusionPatterns));
73-
}
74-
};
75-
76-
struct TestLinalgControlFuseByExpansion
77-
: public PassWrapper<TestLinalgControlFuseByExpansion,
78-
OperationPass<FuncOp>> {
79-
void getDependentDialects(DialectRegistry &registry) const override {
80-
registry
81-
.insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>();
82-
}
83-
StringRef getArgument() const final {
84-
return "test-linalg-control-fusion-by-expansion";
85-
}
86-
StringRef getDescription() const final {
87-
return "Test controlling of fusion of elementwise ops with reshape by "
88-
"expansion";
89-
}
61+
Option<bool>
62+
fuseGenericOps(*this, "fuse-generic-ops",
63+
llvm::cl::desc("Test fusion of generic operations."),
64+
llvm::cl::init(false));
65+
66+
Option<bool> controlFuseByExpansion(
67+
*this, "control-fusion-by-expansion",
68+
llvm::cl::desc(
69+
"Test controlling fusion of reshape with generic op by expansion"),
70+
llvm::cl::init(false));
71+
72+
Option<bool>
73+
pushExpandingReshape(*this, "push-expanding-reshape",
74+
llvm::cl::desc("Test linalg expand_shape -> generic "
75+
"to generic -> expand_shape pattern"),
76+
llvm::cl::init(false));
9077

9178
void runOnOperation() override {
9279
MLIRContext *context = &this->getContext();
9380
FuncOp funcOp = this->getOperation();
94-
RewritePatternSet fusionPatterns(context);
95-
96-
linalg::ControlElementwiseOpsFusionFn controlReshapeFusionFn =
97-
[](const OpResult &producer, OpOperand &consumer) {
98-
if (auto collapseOp =
99-
producer.getDefiningOp<tensor::CollapseShapeOp>()) {
100-
if (!collapseOp.src().getDefiningOp<linalg::LinalgOp>()) {
101-
return false;
81+
82+
if (fuseGenericOps) {
83+
RewritePatternSet fusionPatterns(context);
84+
linalg::populateElementwiseOpsFusionPatterns(
85+
fusionPatterns,
86+
linalg::LinalgElementwiseFusionOptions()
87+
.setControlElementwiseOpsFusionFn(setFusedOpOperandLimit<4>));
88+
89+
(void)applyPatternsAndFoldGreedily(funcOp.getBody(),
90+
std::move(fusionPatterns));
91+
return;
92+
}
93+
94+
if (controlFuseByExpansion) {
95+
RewritePatternSet fusionPatterns(context);
96+
97+
linalg::ControlElementwiseOpsFusionFn controlReshapeFusionFn =
98+
[](const OpResult &producer, OpOperand &consumer) {
99+
if (auto collapseOp =
100+
producer.getDefiningOp<tensor::CollapseShapeOp>()) {
101+
if (!collapseOp.src().getDefiningOp<linalg::LinalgOp>()) {
102+
return false;
103+
}
102104
}
103-
}
104-
if (auto expandOp =
105-
dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
106-
if (expandOp->hasOneUse()) {
107-
OpOperand &use = *expandOp->getUses().begin();
108-
auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner());
109-
if (linalgOp && linalgOp.isOutputTensor(&use))
110-
return true;
105+
if (auto expandOp =
106+
dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
107+
if (expandOp->hasOneUse()) {
108+
OpOperand &use = *expandOp->getUses().begin();
109+
auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner());
110+
if (linalgOp && linalgOp.isOutputTensor(&use))
111+
return true;
112+
}
111113
}
112-
}
113-
return linalg::skipUnitDimReshape(producer, consumer);
114-
};
115-
116-
linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns,
117-
controlReshapeFusionFn);
118-
(void)applyPatternsAndFoldGreedily(funcOp.getBody(),
119-
std::move(fusionPatterns));
114+
return linalg::skipUnitDimReshape(producer, consumer);
115+
};
116+
117+
linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns,
118+
controlReshapeFusionFn);
119+
(void)applyPatternsAndFoldGreedily(funcOp.getBody(),
120+
std::move(fusionPatterns));
121+
return;
122+
}
123+
124+
if (pushExpandingReshape) {
125+
RewritePatternSet patterns(context);
126+
linalg::populatePushReshapeOpsPatterns(patterns);
127+
(void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
128+
}
120129
}
121130
};
122131

123-
struct TestPushExpandingReshape
124-
: public PassWrapper<TestPushExpandingReshape, OperationPass<FuncOp>> {
125-
void getDependentDialects(DialectRegistry &registry) const override {
126-
registry
127-
.insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>();
128-
}
129-
StringRef getArgument() const final { return "test-linalg-push-reshape"; }
130-
StringRef getDescription() const final {
131-
return "Test Linalg reshape push patterns";
132-
}
133-
134-
void runOnOperation() override {
135-
MLIRContext *context = &this->getContext();
136-
FuncOp funcOp = this->getOperation();
137-
RewritePatternSet patterns(context);
138-
linalg::populatePushReshapeOpsPatterns(patterns);
139-
(void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
140-
}
141-
};
142132
} // namespace
143133

144134
namespace test {

mlir/tools/mlir-opt/mlir-opt.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,8 @@ void registerTestGenericIRVisitorsPass();
8181
void registerTestGenericIRVisitorsInterruptPass();
8282
void registerTestInterfaces();
8383
void registerTestLinalgCodegenStrategy();
84-
void registerTestLinalgControlFuseByExpansion();
8584
void registerTestLinalgDistribution();
8685
void registerTestLinalgElementwiseFusion();
87-
void registerTestPushExpandingReshape();
8886
void registerTestLinalgFusionTransforms();
8987
void registerTestLinalgTensorFusionTransforms();
9088
void registerTestLinalgTiledLoopFusionTransforms();
@@ -172,10 +170,8 @@ void registerTestPasses() {
172170
mlir::test::registerTestGenericIRVisitorsPass();
173171
mlir::test::registerTestInterfaces();
174172
mlir::test::registerTestLinalgCodegenStrategy();
175-
mlir::test::registerTestLinalgControlFuseByExpansion();
176173
mlir::test::registerTestLinalgDistribution();
177174
mlir::test::registerTestLinalgElementwiseFusion();
178-
mlir::test::registerTestPushExpandingReshape();
179175
mlir::test::registerTestLinalgFusionTransforms();
180176
mlir::test::registerTestLinalgTensorFusionTransforms();
181177
mlir::test::registerTestLinalgTiledLoopFusionTransforms();

0 commit comments

Comments
 (0)