Skip to content

Commit 7568f71

Browse files
author
Mahesh Ravishankar
committed
Revert "[mlir][Linalg] NFC: Combine elementwise fusion test passes."
This reverts commit d730336.
1 parent 157bbe6 commit 7568f71

File tree

5 files changed

+79
-65
lines changed

5 files changed

+79
-65
lines changed

mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-generic-ops -split-input-file | FileCheck %s
1+
// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns -split-input-file | FileCheck %s
22

33
#map0 = affine_map<(d0, d1) -> (d0, d1)>
44
#binary2Dpointwise = {

mlir/test/Dialect/Linalg/fusion-push-reshape.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=push-expanding-reshape -split-input-file | FileCheck %s
1+
// RUN: mlir-opt %s -test-linalg-push-reshape -split-input-file | FileCheck %s
22

33
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
44
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d1)>

mlir/test/Dialect/Linalg/reshape_control_fusion.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt -test-linalg-elementwise-fusion-patterns=control-fusion-by-expansion %s -split-input-file | FileCheck %s
1+
// RUN: mlir-opt -test-linalg-control-fusion-by-expansion %s -split-input-file | FileCheck %s
22

33
func @control_producer_reshape_fusion(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?xf32>) -> tensor<?x?xf32> {
44
%c0 = arith.constant 0 : index

mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp

Lines changed: 72 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -58,77 +58,87 @@ struct TestLinalgElementwiseFusion
5858
return "Test Linalg element wise operation fusion patterns";
5959
}
6060

61-
Option<bool>
62-
fuseGenericOps(*this, "fuse-generic-ops",
63-
llvm::cl::desc("Test fusion of generic operations."),
64-
llvm::cl::init(false));
65-
66-
Option<bool> controlFuseByExpansion(
67-
*this, "control-fusion-by-expansion",
68-
llvm::cl::desc(
69-
"Test controlling fusion of reshape with generic op by expansion"),
70-
llvm::cl::init(false));
71-
72-
Option<bool>
73-
pushExpandingReshape(*this, "push-expanding-reshape",
74-
llvm::cl::desc("Test linalg expand_shape -> generic "
75-
"to generic -> expand_shape pattern"),
76-
llvm::cl::init(false));
77-
7861
void runOnOperation() override {
7962
MLIRContext *context = &this->getContext();
8063
FuncOp funcOp = this->getOperation();
64+
RewritePatternSet fusionPatterns(context);
65+
66+
linalg::populateElementwiseOpsFusionPatterns(
67+
fusionPatterns,
68+
linalg::LinalgElementwiseFusionOptions()
69+
.setControlElementwiseOpsFusionFn(setFusedOpOperandLimit<4>));
8170

82-
if (fuseGenericOps) {
83-
RewritePatternSet fusionPatterns(context);
84-
linalg::populateElementwiseOpsFusionPatterns(
85-
fusionPatterns,
86-
linalg::LinalgElementwiseFusionOptions()
87-
.setControlElementwiseOpsFusionFn(setFusedOpOperandLimit<4>));
88-
89-
(void)applyPatternsAndFoldGreedily(funcOp.getBody(),
90-
std::move(fusionPatterns));
91-
return;
92-
}
93-
94-
if (controlFuseByExpansion) {
95-
RewritePatternSet fusionPatterns(context);
96-
97-
linalg::ControlElementwiseOpsFusionFn controlReshapeFusionFn =
98-
[](const OpResult &producer, OpOperand &consumer) {
99-
if (auto collapseOp =
100-
producer.getDefiningOp<tensor::CollapseShapeOp>()) {
101-
if (!collapseOp.src().getDefiningOp<linalg::LinalgOp>()) {
102-
return false;
103-
}
71+
(void)applyPatternsAndFoldGreedily(funcOp.getBody(),
72+
std::move(fusionPatterns));
73+
}
74+
};
75+
76+
struct TestLinalgControlFuseByExpansion
77+
: public PassWrapper<TestLinalgControlFuseByExpansion,
78+
OperationPass<FuncOp>> {
79+
void getDependentDialects(DialectRegistry &registry) const override {
80+
registry
81+
.insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>();
82+
}
83+
StringRef getArgument() const final {
84+
return "test-linalg-control-fusion-by-expansion";
85+
}
86+
StringRef getDescription() const final {
87+
return "Test controlling of fusion of elementwise ops with reshape by "
88+
"expansion";
89+
}
90+
91+
void runOnOperation() override {
92+
MLIRContext *context = &this->getContext();
93+
FuncOp funcOp = this->getOperation();
94+
RewritePatternSet fusionPatterns(context);
95+
96+
linalg::ControlElementwiseOpsFusionFn controlReshapeFusionFn =
97+
[](const OpResult &producer, OpOperand &consumer) {
98+
if (auto collapseOp =
99+
producer.getDefiningOp<tensor::CollapseShapeOp>()) {
100+
if (!collapseOp.src().getDefiningOp<linalg::LinalgOp>()) {
101+
return false;
104102
}
105-
if (auto expandOp =
106-
dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
107-
if (expandOp->hasOneUse()) {
108-
OpOperand &use = *expandOp->getUses().begin();
109-
auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner());
110-
if (linalgOp && linalgOp.isOutputTensor(&use))
111-
return true;
112-
}
103+
}
104+
if (auto expandOp =
105+
dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
106+
if (expandOp->hasOneUse()) {
107+
OpOperand &use = *expandOp->getUses().begin();
108+
auto linalgOp = dyn_cast<linalg::LinalgOp>(use.getOwner());
109+
if (linalgOp && linalgOp.isOutputTensor(&use))
110+
return true;
113111
}
114-
return linalg::skipUnitDimReshape(producer, consumer);
115-
};
116-
117-
linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns,
118-
controlReshapeFusionFn);
119-
(void)applyPatternsAndFoldGreedily(funcOp.getBody(),
120-
std::move(fusionPatterns));
121-
return;
122-
}
123-
124-
if (pushExpandingReshape) {
125-
RewritePatternSet patterns(context);
126-
linalg::populatePushReshapeOpsPatterns(patterns);
127-
(void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
128-
}
112+
}
113+
return linalg::skipUnitDimReshape(producer, consumer);
114+
};
115+
116+
linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns,
117+
controlReshapeFusionFn);
118+
(void)applyPatternsAndFoldGreedily(funcOp.getBody(),
119+
std::move(fusionPatterns));
129120
}
130121
};
131122

123+
struct TestPushExpandingReshape
124+
: public PassWrapper<TestPushExpandingReshape, OperationPass<FuncOp>> {
125+
void getDependentDialects(DialectRegistry &registry) const override {
126+
registry
127+
.insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>();
128+
}
129+
StringRef getArgument() const final { return "test-linalg-push-reshape"; }
130+
StringRef getDescription() const final {
131+
return "Test Linalg reshape push patterns";
132+
}
133+
134+
void runOnOperation() override {
135+
MLIRContext *context = &this->getContext();
136+
FuncOp funcOp = this->getOperation();
137+
RewritePatternSet patterns(context);
138+
linalg::populatePushReshapeOpsPatterns(patterns);
139+
(void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
140+
}
141+
};
132142
} // namespace
133143

134144
namespace test {

mlir/tools/mlir-opt/mlir-opt.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,10 @@ void registerTestGenericIRVisitorsPass();
8181
void registerTestGenericIRVisitorsInterruptPass();
8282
void registerTestInterfaces();
8383
void registerTestLinalgCodegenStrategy();
84+
void registerTestLinalgControlFuseByExpansion();
8485
void registerTestLinalgDistribution();
8586
void registerTestLinalgElementwiseFusion();
87+
void registerTestPushExpandingReshape();
8688
void registerTestLinalgFusionTransforms();
8789
void registerTestLinalgTensorFusionTransforms();
8890
void registerTestLinalgTiledLoopFusionTransforms();
@@ -170,8 +172,10 @@ void registerTestPasses() {
170172
mlir::test::registerTestGenericIRVisitorsPass();
171173
mlir::test::registerTestInterfaces();
172174
mlir::test::registerTestLinalgCodegenStrategy();
175+
mlir::test::registerTestLinalgControlFuseByExpansion();
173176
mlir::test::registerTestLinalgDistribution();
174177
mlir::test::registerTestLinalgElementwiseFusion();
178+
mlir::test::registerTestPushExpandingReshape();
175179
mlir::test::registerTestLinalgFusionTransforms();
176180
mlir::test::registerTestLinalgTensorFusionTransforms();
177181
mlir::test::registerTestLinalgTiledLoopFusionTransforms();

0 commit comments

Comments
 (0)