@@ -58,87 +58,77 @@ struct TestLinalgElementwiseFusion
58
58
return " Test Linalg element wise operation fusion patterns" ;
59
59
}
60
60
61
- void runOnOperation () override {
62
- MLIRContext *context = &this ->getContext ();
63
- FuncOp funcOp = this ->getOperation ();
64
- RewritePatternSet fusionPatterns (context);
65
-
66
- linalg::populateElementwiseOpsFusionPatterns (
67
- fusionPatterns,
68
- linalg::LinalgElementwiseFusionOptions ()
69
- .setControlElementwiseOpsFusionFn (setFusedOpOperandLimit<4 >));
70
-
71
- (void )applyPatternsAndFoldGreedily (funcOp.getBody (),
72
- std::move (fusionPatterns));
73
- }
74
- };
75
-
76
- struct TestLinalgControlFuseByExpansion
77
- : public PassWrapper<TestLinalgControlFuseByExpansion,
78
- OperationPass<FuncOp>> {
79
- void getDependentDialects (DialectRegistry ®istry) const override {
80
- registry
81
- .insert <AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>();
82
- }
83
- StringRef getArgument () const final {
84
- return " test-linalg-control-fusion-by-expansion" ;
85
- }
86
- StringRef getDescription () const final {
87
- return " Test controlling of fusion of elementwise ops with reshape by "
88
- " expansion" ;
89
- }
61
+ Option<bool >
62
+ fuseGenericOps (*this , " fuse-generic-ops" ,
63
+ llvm::cl::desc (" Test fusion of generic operations." ),
64
+ llvm::cl::init(false ));
65
+
66
+ Option<bool > controlFuseByExpansion (
67
+ *this , " control-fusion-by-expansion" ,
68
+ llvm::cl::desc (
69
+ " Test controlling fusion of reshape with generic op by expansion" ),
70
+ llvm::cl::init(false ));
71
+
72
+ Option<bool >
73
+ pushExpandingReshape (*this , " push-expanding-reshape" ,
74
+ llvm::cl::desc (" Test linalg expand_shape -> generic "
75
+ " to generic -> expand_shape pattern" ),
76
+ llvm::cl::init(false ));
90
77
91
78
void runOnOperation () override {
92
79
MLIRContext *context = &this ->getContext ();
93
80
FuncOp funcOp = this ->getOperation ();
94
- RewritePatternSet fusionPatterns (context);
95
-
96
- linalg::ControlElementwiseOpsFusionFn controlReshapeFusionFn =
97
- [](const OpResult &producer, OpOperand &consumer) {
98
- if (auto collapseOp =
99
- producer.getDefiningOp <tensor::CollapseShapeOp>()) {
100
- if (!collapseOp.src ().getDefiningOp <linalg::LinalgOp>()) {
101
- return false ;
81
+
82
+ if (fuseGenericOps) {
83
+ RewritePatternSet fusionPatterns (context);
84
+ linalg::populateElementwiseOpsFusionPatterns (
85
+ fusionPatterns,
86
+ linalg::LinalgElementwiseFusionOptions ()
87
+ .setControlElementwiseOpsFusionFn (setFusedOpOperandLimit<4 >));
88
+
89
+ (void )applyPatternsAndFoldGreedily (funcOp.getBody (),
90
+ std::move (fusionPatterns));
91
+ return ;
92
+ }
93
+
94
+ if (controlFuseByExpansion) {
95
+ RewritePatternSet fusionPatterns (context);
96
+
97
+ linalg::ControlElementwiseOpsFusionFn controlReshapeFusionFn =
98
+ [](const OpResult &producer, OpOperand &consumer) {
99
+ if (auto collapseOp =
100
+ producer.getDefiningOp <tensor::CollapseShapeOp>()) {
101
+ if (!collapseOp.src ().getDefiningOp <linalg::LinalgOp>()) {
102
+ return false ;
103
+ }
102
104
}
103
- }
104
- if ( auto expandOp =
105
- dyn_cast<tensor::ExpandShapeOp>(consumer. getOwner () )) {
106
- if ( expandOp->hasOneUse ()) {
107
- OpOperand &use = *expandOp-> getUses (). begin ( );
108
- auto linalgOp = dyn_cast<linalg::LinalgOp>(use. getOwner ());
109
- if (linalgOp && linalgOp. isOutputTensor (&use))
110
- return true ;
105
+ if ( auto expandOp =
106
+ dyn_cast<tensor::ExpandShapeOp>(consumer. getOwner ())) {
107
+ if (expandOp-> hasOneUse ( )) {
108
+ OpOperand &use = * expandOp->getUses (). begin ();
109
+ auto linalgOp = dyn_cast<linalg::LinalgOp>(use. getOwner () );
110
+ if (linalgOp && linalgOp. isOutputTensor (&use))
111
+ return true ;
112
+ }
111
113
}
112
- }
113
- return linalg::skipUnitDimReshape (producer, consumer);
114
- };
115
-
116
- linalg::populateFoldReshapeOpsByExpansionPatterns (fusionPatterns,
117
- controlReshapeFusionFn);
118
- (void )applyPatternsAndFoldGreedily (funcOp.getBody (),
119
- std::move (fusionPatterns));
114
+ return linalg::skipUnitDimReshape (producer, consumer);
115
+ };
116
+
117
+ linalg::populateFoldReshapeOpsByExpansionPatterns (fusionPatterns,
118
+ controlReshapeFusionFn);
119
+ (void )applyPatternsAndFoldGreedily (funcOp.getBody (),
120
+ std::move (fusionPatterns));
121
+ return ;
122
+ }
123
+
124
+ if (pushExpandingReshape) {
125
+ RewritePatternSet patterns (context);
126
+ linalg::populatePushReshapeOpsPatterns (patterns);
127
+ (void )applyPatternsAndFoldGreedily (funcOp.getBody (), std::move (patterns));
128
+ }
120
129
}
121
130
};
122
131
123
- struct TestPushExpandingReshape
124
- : public PassWrapper<TestPushExpandingReshape, OperationPass<FuncOp>> {
125
- void getDependentDialects (DialectRegistry ®istry) const override {
126
- registry
127
- .insert <AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>();
128
- }
129
- StringRef getArgument () const final { return " test-linalg-push-reshape" ; }
130
- StringRef getDescription () const final {
131
- return " Test Linalg reshape push patterns" ;
132
- }
133
-
134
- void runOnOperation () override {
135
- MLIRContext *context = &this ->getContext ();
136
- FuncOp funcOp = this ->getOperation ();
137
- RewritePatternSet patterns (context);
138
- linalg::populatePushReshapeOpsPatterns (patterns);
139
- (void )applyPatternsAndFoldGreedily (funcOp.getBody (), std::move (patterns));
140
- }
141
- };
142
132
} // namespace
143
133
144
134
namespace test {
0 commit comments