@@ -58,77 +58,87 @@ struct TestLinalgElementwiseFusion
58
58
return " Test Linalg element wise operation fusion patterns" ;
59
59
}
60
60
61
- Option<bool >
62
- fuseGenericOps (*this , " fuse-generic-ops" ,
63
- llvm::cl::desc (" Test fusion of generic operations." ),
64
- llvm::cl::init(false ));
65
-
66
- Option<bool > controlFuseByExpansion (
67
- *this , " control-fusion-by-expansion" ,
68
- llvm::cl::desc (
69
- " Test controlling fusion of reshape with generic op by expansion" ),
70
- llvm::cl::init(false ));
71
-
72
- Option<bool >
73
- pushExpandingReshape (*this , " push-expanding-reshape" ,
74
- llvm::cl::desc (" Test linalg expand_shape -> generic "
75
- " to generic -> expand_shape pattern" ),
76
- llvm::cl::init(false ));
77
-
78
61
void runOnOperation () override {
79
62
MLIRContext *context = &this ->getContext ();
80
63
FuncOp funcOp = this ->getOperation ();
64
+ RewritePatternSet fusionPatterns (context);
65
+
66
+ linalg::populateElementwiseOpsFusionPatterns (
67
+ fusionPatterns,
68
+ linalg::LinalgElementwiseFusionOptions ()
69
+ .setControlElementwiseOpsFusionFn (setFusedOpOperandLimit<4 >));
81
70
82
- if (fuseGenericOps) {
83
- RewritePatternSet fusionPatterns (context);
84
- linalg::populateElementwiseOpsFusionPatterns (
85
- fusionPatterns,
86
- linalg::LinalgElementwiseFusionOptions ()
87
- .setControlElementwiseOpsFusionFn (setFusedOpOperandLimit<4 >));
88
-
89
- (void )applyPatternsAndFoldGreedily (funcOp.getBody (),
90
- std::move (fusionPatterns));
91
- return ;
92
- }
93
-
94
- if (controlFuseByExpansion) {
95
- RewritePatternSet fusionPatterns (context);
96
-
97
- linalg::ControlElementwiseOpsFusionFn controlReshapeFusionFn =
98
- [](const OpResult &producer, OpOperand &consumer) {
99
- if (auto collapseOp =
100
- producer.getDefiningOp <tensor::CollapseShapeOp>()) {
101
- if (!collapseOp.src ().getDefiningOp <linalg::LinalgOp>()) {
102
- return false ;
103
- }
71
+ (void )applyPatternsAndFoldGreedily (funcOp.getBody (),
72
+ std::move (fusionPatterns));
73
+ }
74
+ };
75
+
76
+ struct TestLinalgControlFuseByExpansion
77
+ : public PassWrapper<TestLinalgControlFuseByExpansion,
78
+ OperationPass<FuncOp>> {
79
+ void getDependentDialects (DialectRegistry ®istry) const override {
80
+ registry
81
+ .insert <AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>();
82
+ }
83
+ StringRef getArgument () const final {
84
+ return " test-linalg-control-fusion-by-expansion" ;
85
+ }
86
+ StringRef getDescription () const final {
87
+ return " Test controlling of fusion of elementwise ops with reshape by "
88
+ " expansion" ;
89
+ }
90
+
91
+ void runOnOperation () override {
92
+ MLIRContext *context = &this ->getContext ();
93
+ FuncOp funcOp = this ->getOperation ();
94
+ RewritePatternSet fusionPatterns (context);
95
+
96
+ linalg::ControlElementwiseOpsFusionFn controlReshapeFusionFn =
97
+ [](const OpResult &producer, OpOperand &consumer) {
98
+ if (auto collapseOp =
99
+ producer.getDefiningOp <tensor::CollapseShapeOp>()) {
100
+ if (!collapseOp.src ().getDefiningOp <linalg::LinalgOp>()) {
101
+ return false ;
104
102
}
105
- if ( auto expandOp =
106
- dyn_cast<tensor::ExpandShapeOp>(consumer. getOwner ())) {
107
- if (expandOp-> hasOneUse ( )) {
108
- OpOperand &use = * expandOp->getUses (). begin ();
109
- auto linalgOp = dyn_cast<linalg::LinalgOp>(use. getOwner () );
110
- if ( linalgOp && linalgOp. isOutputTensor (& use))
111
- return true ;
112
- }
103
+ }
104
+ if ( auto expandOp =
105
+ dyn_cast<tensor::ExpandShapeOp>(consumer. getOwner () )) {
106
+ if ( expandOp->hasOneUse ()) {
107
+ OpOperand &use = *expandOp-> getUses (). begin ( );
108
+ auto linalgOp = dyn_cast<linalg::LinalgOp>( use. getOwner ());
109
+ if (linalgOp && linalgOp. isOutputTensor (&use))
110
+ return true ;
113
111
}
114
- return linalg::skipUnitDimReshape (producer, consumer);
115
- };
116
-
117
- linalg::populateFoldReshapeOpsByExpansionPatterns (fusionPatterns,
118
- controlReshapeFusionFn);
119
- (void )applyPatternsAndFoldGreedily (funcOp.getBody (),
120
- std::move (fusionPatterns));
121
- return ;
122
- }
123
-
124
- if (pushExpandingReshape) {
125
- RewritePatternSet patterns (context);
126
- linalg::populatePushReshapeOpsPatterns (patterns);
127
- (void )applyPatternsAndFoldGreedily (funcOp.getBody (), std::move (patterns));
128
- }
112
+ }
113
+ return linalg::skipUnitDimReshape (producer, consumer);
114
+ };
115
+
116
+ linalg::populateFoldReshapeOpsByExpansionPatterns (fusionPatterns,
117
+ controlReshapeFusionFn);
118
+ (void )applyPatternsAndFoldGreedily (funcOp.getBody (),
119
+ std::move (fusionPatterns));
129
120
}
130
121
};
131
122
123
+ struct TestPushExpandingReshape
124
+ : public PassWrapper<TestPushExpandingReshape, OperationPass<FuncOp>> {
125
+ void getDependentDialects (DialectRegistry ®istry) const override {
126
+ registry
127
+ .insert <AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>();
128
+ }
129
+ StringRef getArgument () const final { return " test-linalg-push-reshape" ; }
130
+ StringRef getDescription () const final {
131
+ return " Test Linalg reshape push patterns" ;
132
+ }
133
+
134
+ void runOnOperation () override {
135
+ MLIRContext *context = &this ->getContext ();
136
+ FuncOp funcOp = this ->getOperation ();
137
+ RewritePatternSet patterns (context);
138
+ linalg::populatePushReshapeOpsPatterns (patterns);
139
+ (void )applyPatternsAndFoldGreedily (funcOp.getBody (), std::move (patterns));
140
+ }
141
+ };
132
142
} // namespace
133
143
134
144
namespace test {
0 commit comments