//
//===----------------------------------------------------------------------===//

-#include <flang/Optimizer/Builder/FIRBuilder.h>
-#include <flang/Optimizer/Dialect/FIROps.h>
-#include <flang/Optimizer/Dialect/FIRType.h>
-#include <flang/Optimizer/HLFIR/HLFIROps.h>
-#include <flang/Optimizer/OpenMP/Passes.h>
-#include <llvm/ADT/BreadthFirstIterator.h>
-#include <llvm/ADT/STLExtras.h>
-#include <llvm/ADT/SmallVectorExtras.h>
-#include <llvm/ADT/iterator_range.h>
-#include <llvm/Support/ErrorHandling.h>
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "flang/Optimizer/HLFIR/Passes.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
-#include <mlir/Dialect/OpenMP/OpenMPClauseOperands.h>
-#include <mlir/Dialect/OpenMP/OpenMPDialect.h>
-#include <mlir/Dialect/SCF/IR/SCF.h>
+#include <mlir/Dialect/Utils/IndexingUtils.h>
+#include <mlir/IR/BlockSupport.h>
#include <mlir/IR/BuiltinOps.h>
+#include <mlir/IR/Diagnostics.h>
#include <mlir/IR/IRMapping.h>
-#include <mlir/IR/OpDefinition.h>
#include <mlir/IR/PatternMatch.h>
-#include <mlir/IR/Value.h>
-#include <mlir/IR/Visitors.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
#include <mlir/Support/LLVM.h>
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
+#include <optional>
#include <variant>

namespace flangomp {
@@ -48,52 +43,188 @@ using namespace mlir;

namespace {

-struct WorkdistributeToSingle : public mlir::OpRewritePattern<mlir::omp::WorkdistributeOp> {
-  using OpRewritePattern::OpRewritePattern;
-  mlir::LogicalResult
-  matchAndRewrite(mlir::omp::WorkdistributeOp workdistribute,
-                  mlir::PatternRewriter &rewriter) const override {
-    auto loc = workdistribute->getLoc();
-    auto teams = llvm::dyn_cast<mlir::omp::TeamsOp>(workdistribute->getParentOp());
-    if (!teams) {
-      mlir::emitError(loc, "workdistribute not nested in teams\n");
-      return mlir::failure();
-    }
-    if (workdistribute.getRegion().getBlocks().size() != 1) {
-      mlir::emitError(loc, "workdistribute with multiple blocks\n");
-      return mlir::failure();
+template <typename T>
+static T getPerfectlyNested(Operation *op) {
+  if (op->getNumRegions() != 1)
+    return nullptr;
+  auto &region = op->getRegion(0);
+  if (region.getBlocks().size() != 1)
+    return nullptr;
+  auto *block = &region.front();
+  auto *firstOp = &block->front();
+  if (auto nested = dyn_cast<T>(firstOp))
+    if (firstOp->getNextNode() == block->getTerminator())
+      return nested;
+  return nullptr;
+}
+
+/// This is the single source of truth about whether we should parallelize an
+/// operation nested in an omp.workdistribute region.
+static bool shouldParallelize(Operation *op) {
+  // Currently we cannot parallelize operations with results that have uses.
+  if (llvm::any_of(op->getResults(),
+                   [](OpResult v) -> bool { return !v.use_empty(); }))
+    return false;
+  // We will parallelize unordered loops - these come from array syntax.
+  if (auto loop = dyn_cast<fir::DoLoopOp>(op)) {
+    auto unordered = loop.getUnordered();
+    if (!unordered)
+      return false;
+    return *unordered;
+  }
+  if (auto callOp = dyn_cast<fir::CallOp>(op)) {
+    auto callee = callOp.getCallee();
+    if (!callee)
+      return false;
+    auto *func = op->getParentOfType<ModuleOp>().lookupSymbol(*callee);
+    // TODO: Insert a check here whether it is a call we can actually
+    // parallelize.
+    if (func->getAttr(fir::FIROpsDialect::getFirRuntimeAttrName()))
+      return true;
+    return false;
+  }
+  // We cannot parallelize anything else.
+  return false;
+}
+
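To make the predicate concrete: the only operations it accepts are unordered fir.do_loop ops whose results are unused, and fir.call ops whose callee carries the fir.runtime attribute; everything else stays on the sequential path. A rough sketch of the three cases, in the same abbreviated IR style as the doc comments below (operand lists and types elided, the runtime symbol name is only illustrative):

    omp.workdistribute {
      // parallelizable: array-syntax loop lowered to an unordered fir.do_loop
      fir.do_loop %i = %lb to %ub step %c1 unordered { ... }
      // parallelizable: call to a function tagged with the fir.runtime attribute
      fir.call @_FortranAAssign(...) : (...) -> ()
      // not parallelizable: ordinary user call whose result is used later
      %x = fir.call @user_func(...) : (...) -> i32
      ...
    }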
+struct WorkdistributeToSingle : public OpRewritePattern<omp::TeamsOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(omp::TeamsOp teamsOp,
+                                PatternRewriter &rewriter) const override {
+    auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
+    if (!workdistributeOp) {
+      LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << " No workdistribute nested\n");
+      return failure();
    }
-    if (teams.getRegion().getBlocks().size() != 1) {
-      mlir::emitError(loc, "teams with multiple blocks\n");
-      return mlir::failure();
+
+    Block *workdistributeBlock = &workdistributeOp.getRegion().front();
+    rewriter.eraseOp(workdistributeBlock->getTerminator());
+    rewriter.inlineBlockBefore(workdistributeBlock, teamsOp);
+    rewriter.eraseOp(teamsOp);
+    workdistributeOp.emitWarning("unable to parallelize coexecute");
+    return success();
+  }
+};
+
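WorkdistributeToSingle is the fallback lowering: once an omp.teams contains nothing but a perfectly nested omp.workdistribute (one block whose only non-terminator op is the workdistribute), the workdistribute body is inlined in front of the construct, both wrappers are erased, and the code runs sequentially with a warning. A minimal before/after sketch in the same abbreviated notation (terminators omitted):

    // before
    omp.teams {
      omp.workdistribute {
        A()
      }
    }

    // after WorkdistributeToSingle
    A()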
+/// If B() and D() are parallelizable,
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     A()
+///     B()
+///     C()
+///     D()
+///     E()
+///   }
+/// }
+///
+/// becomes
+///
+/// A()
+/// omp.teams {
+///   omp.workdistribute {
+///     B()
+///   }
+/// }
+/// C()
+/// omp.teams {
+///   omp.workdistribute {
+///     D()
+///   }
+/// }
+/// E()
+
+struct FissionWorkdistribute
+    : public OpRewritePattern<omp::WorkdistributeOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult
+  matchAndRewrite(omp::WorkdistributeOp workdistribute,
+                  PatternRewriter &rewriter) const override {
+    auto loc = workdistribute->getLoc();
+    auto teams = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp());
+    if (!teams) {
+      emitError(loc, "workdistribute not nested in teams\n");
+      return failure();
+    }
+    if (workdistribute.getRegion().getBlocks().size() != 1) {
+      emitError(loc, "workdistribute with multiple blocks\n");
+      return failure();
+    }
+    if (teams.getRegion().getBlocks().size() != 1) {
+      emitError(loc, "teams with multiple blocks\n");
+      return failure();
+    }
+    if (teams.getRegion().getBlocks().front().getOperations().size() != 2) {
+      emitError(loc, "teams with multiple nested ops\n");
+      return failure();
+    }
+
+    auto *teamsBlock = &teams.getRegion().front();
+
+    // While we have unhandled operations in the original workdistribute,
+    // hoist the non-parallelizable prefix out of the construct and split each
+    // parallelizable op into its own teams/workdistribute nest.
+    auto *workdistributeBlock = &workdistribute.getRegion().front();
+    auto *terminator = workdistributeBlock->getTerminator();
+    bool changed = false;
+    while (&workdistributeBlock->front() != terminator) {
+      rewriter.setInsertionPoint(teams);
+      IRMapping mapping;
+      llvm::SmallVector<Operation *> hoisted;
+      Operation *parallelize = nullptr;
+      for (auto &op : workdistribute.getOps()) {
+        if (&op == terminator) {
+          break;
        }
-    if (teams.getRegion().getBlocks().front().getOperations().size() != 2) {
-      mlir::emitError(loc, "teams with multiple nested ops\n");
-      return mlir::failure();
+        if (shouldParallelize(&op)) {
+          parallelize = &op;
+          break;
+        } else {
+          rewriter.clone(op, mapping);
+          hoisted.push_back(&op);
+          changed = true;
        }
-    mlir::Block *workdistributeBlock = &workdistribute.getRegion().front();
-    rewriter.eraseOp(workdistributeBlock->getTerminator());
-    rewriter.inlineBlockBefore(workdistributeBlock, teams);
-    rewriter.eraseOp(teams);
-    return mlir::success();
+      }
+
+      for (auto *op : hoisted)
+        rewriter.replaceOp(op, mapping.lookup(op));
+
+      if (parallelize && hoisted.empty() &&
+          parallelize->getNextNode() == terminator)
+        break;
+      if (parallelize) {
+        auto newTeams = rewriter.cloneWithoutRegions(teams);
+        auto *newTeamsBlock = rewriter.createBlock(
+            &newTeams.getRegion(), newTeams.getRegion().begin(), {}, {});
+        for (auto arg : teamsBlock->getArguments())
+          newTeamsBlock->addArgument(arg.getType(), arg.getLoc());
+        auto newWorkdistribute = rewriter.create<omp::WorkdistributeOp>(loc);
+        rewriter.create<omp::TerminatorOp>(loc);
+        rewriter.createBlock(&newWorkdistribute.getRegion(),
+                             newWorkdistribute.getRegion().begin(), {}, {});
+        auto *cloned = rewriter.clone(*parallelize);
+        rewriter.replaceOp(parallelize, cloned);
+        rewriter.create<omp::TerminatorOp>(loc);
+        changed = true;
+      }
    }
+    return success(changed);
+  }
};

class LowerWorkdistributePass
    : public flangomp::impl::LowerWorkdistributeBase<LowerWorkdistributePass> {
public:
  void runOnOperation() override {
-    mlir::MLIRContext &context = getContext();
-    mlir::RewritePatternSet patterns(&context);
-    mlir::GreedyRewriteConfig config;
+    MLIRContext &context = getContext();
+    RewritePatternSet patterns(&context);
+    GreedyRewriteConfig config;
    // prevent the pattern driver from merging blocks
    config.setRegionSimplificationLevel(
-        mlir::GreedySimplifyRegionLevel::Disabled);
+        GreedySimplifyRegionLevel::Disabled);

-    patterns.insert<WorkdistributeToSingle>(&context);
-    mlir::Operation *op = getOperation();
-    if (mlir::failed(mlir::applyPatternsGreedily(op, std::move(patterns), config))) {
-      mlir::emitError(op->getLoc(), DEBUG_TYPE " pass failed\n");
+    patterns.insert<FissionWorkdistribute, WorkdistributeToSingle>(&context);
+    Operation *op = getOperation();
+    if (failed(applyPatternsGreedily(op, std::move(patterns), config))) {
+      emitError(op->getLoc(), DEBUG_TYPE " pass failed\n");
      signalPassFailure();
    }
  }
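With both patterns registered in the same greedy rewrite, the doc-comment example above appears to lower in two stages: FissionWorkdistribute hoists A(), C() and E() and isolates B() and D() in their own teams/workdistribute nests, and, since no parallelizing rewrites exist yet, WorkdistributeToSingle then unwraps those nests with its warning. A rough sketch of the net effect of this change alone (abbreviated notation as in the doc comment):

    // after FissionWorkdistribute (B and D parallelizable)
    A()
    omp.teams { omp.workdistribute { B() } }
    C()
    omp.teams { omp.workdistribute { D() } }
    E()

    // after WorkdistributeToSingle (no parallelizing pattern yet)
    A()
    B()
    C()
    D()
    E()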