@@ -48,25 +48,21 @@ using namespace mlir;
48
48
49
49
namespace {
50
50
51
- template <typename T>
52
- static T getPerfectlyNested (Operation *op) {
53
- if (op->getNumRegions () != 1 )
54
- return nullptr ;
55
- auto ®ion = op->getRegion (0 );
56
- if (region.getBlocks ().size () != 1 )
57
- return nullptr ;
58
- auto *block = ®ion.front ();
59
- auto *firstOp = &block->front ();
60
- if (auto nested = dyn_cast<T>(firstOp))
61
- if (firstOp->getNextNode () == block->getTerminator ())
62
- return nested;
63
- return nullptr ;
51
+ static bool isRuntimeCall (Operation *op) {
52
+ if (auto callOp = dyn_cast<fir::CallOp>(op)) {
53
+ auto callee = callOp.getCallee ();
54
+ if (!callee)
55
+ return false ;
56
+ auto *func = op->getParentOfType <ModuleOp>().lookupSymbol (*callee);
57
+ if (func->getAttr (fir::FIROpsDialect::getFirRuntimeAttrName ()))
58
+ return true ;
59
+ }
60
+ return false ;
64
61
}
65
62
66
63
// / This is the single source of truth about whether we should parallelize an
67
- // / operation nested in an omp.workdistribute region.
64
+ // / operation nested in an omp.workdistribute region.
68
65
static bool shouldParallelize (Operation *op) {
69
- // Currently we cannot parallelize operations with results that have uses
70
66
if (llvm::any_of (op->getResults (),
71
67
[](OpResult v) -> bool { return !v.use_empty (); }))
72
68
return false ;
@@ -77,21 +73,28 @@ static bool shouldParallelize(Operation *op) {
77
73
return false ;
78
74
return *unordered;
79
75
}
80
- if (auto callOp = dyn_cast<fir::CallOp>(op)) {
81
- auto callee = callOp.getCallee ();
82
- if (!callee)
83
- return false ;
84
- auto *func = op->getParentOfType <ModuleOp>().lookupSymbol (*callee);
85
- // TODO need to insert a check here whether it is a call we can actually
86
- // parallelize currently
87
- if (func->getAttr (fir::FIROpsDialect::getFirRuntimeAttrName ()))
88
- return true ;
89
- return false ;
76
+ if (isRuntimeCall (op)) {
77
+ return true ;
90
78
}
91
79
// We cannot parallelize anything else
92
80
return false ;
93
81
}
94
82
83
+ template <typename T>
84
+ static T getPerfectlyNested (Operation *op) {
85
+ if (op->getNumRegions () != 1 )
86
+ return nullptr ;
87
+ auto ®ion = op->getRegion (0 );
88
+ if (region.getBlocks ().size () != 1 )
89
+ return nullptr ;
90
+ auto *block = ®ion.front ();
91
+ auto *firstOp = &block->front ();
92
+ if (auto nested = dyn_cast<T>(firstOp))
93
+ if (firstOp->getNextNode () == block->getTerminator ())
94
+ return nested;
95
+ return nullptr ;
96
+ }
97
+
95
98
// / If B() and D() are parallelizable,
96
99
// /
97
100
// / omp.teams {
@@ -138,17 +141,33 @@ struct FissionWorkdistribute : public OpRewritePattern<omp::WorkdistributeOp> {
138
141
emitError (loc, " teams with multiple blocks\n " );
139
142
return failure ();
140
143
}
141
- if (teams.getRegion ().getBlocks ().front ().getOperations ().size () != 2 ) {
142
- emitError (loc, " teams with multiple nested ops\n " );
143
- return failure ();
144
- }
145
144
146
145
auto *teamsBlock = &teams.getRegion ().front ();
146
+ bool changed = false ;
147
+ // Move the ops inside teams and before workdistribute outside.
148
+ IRMapping irMapping;
149
+ llvm::SmallVector<Operation *> teamsHoisted;
150
+ for (auto &op : teams.getOps ()) {
151
+ if (&op == workdistribute) {
152
+ break ;
153
+ }
154
+ if (shouldParallelize (&op)) {
155
+ emitError (loc,
156
+ " teams has parallelize ops before first workdistribute\n " );
157
+ return failure ();
158
+ } else {
159
+ rewriter.setInsertionPoint (teams);
160
+ rewriter.clone (op, irMapping);
161
+ teamsHoisted.push_back (&op);
162
+ changed = true ;
163
+ }
164
+ }
165
+ for (auto *op : teamsHoisted)
166
+ rewriter.replaceOp (op, irMapping.lookup (op));
147
167
148
168
// While we have unhandled operations in the original workdistribute
149
169
auto *workdistributeBlock = &workdistribute.getRegion ().front ();
150
170
auto *terminator = workdistributeBlock->getTerminator ();
151
- bool changed = false ;
152
171
while (&workdistributeBlock->front () != terminator) {
153
172
rewriter.setInsertionPoint (teams);
154
173
IRMapping mapping;
@@ -194,9 +213,51 @@ struct FissionWorkdistribute : public OpRewritePattern<omp::WorkdistributeOp> {
194
213
}
195
214
};
196
215
216
+ // / If fir.do_loop is present inside teams workdistribute
217
+ // /
218
+ // / omp.teams {
219
+ // / omp.workdistribute {
220
+ // / fir.do_loop unordered {
221
+ // / ...
222
+ // / }
223
+ // / }
224
+ // / }
225
+ // /
226
+ // / Then, it is lowered to
227
+ // /
228
+ // / omp.teams {
229
+ // / omp.parallel {
230
+ // / omp.distribute {
231
+ // / omp.wsloop {
232
+ // / omp.loop_nest
233
+ // / ...
234
+ // / }
235
+ // / }
236
+ // / }
237
+ // / }
238
+
239
+ static void genParallelOp (Location loc, PatternRewriter &rewriter,
240
+ bool composite) {
241
+ auto parallelOp = rewriter.create <mlir::omp::ParallelOp>(loc);
242
+ parallelOp.setComposite (composite);
243
+ rewriter.createBlock (¶llelOp.getRegion ());
244
+ rewriter.setInsertionPoint (rewriter.create <mlir::omp::TerminatorOp>(loc));
245
+ return ;
246
+ }
247
+
248
+ static void genDistributeOp (Location loc, PatternRewriter &rewriter,
249
+ bool composite) {
250
+ mlir::omp::DistributeOperands distributeClauseOps;
251
+ auto distributeOp =
252
+ rewriter.create <mlir::omp::DistributeOp>(loc, distributeClauseOps);
253
+ distributeOp.setComposite (composite);
254
+ auto distributeBlock = rewriter.createBlock (&distributeOp.getRegion ());
255
+ rewriter.setInsertionPointToStart (distributeBlock);
256
+ return ;
257
+ }
258
+
197
259
static void
198
- genLoopNestClauseOps (mlir::Location loc, mlir::PatternRewriter &rewriter,
199
- fir::DoLoopOp loop,
260
+ genLoopNestClauseOps (mlir::PatternRewriter &rewriter, fir::DoLoopOp loop,
200
261
mlir::omp::LoopNestOperands &loopNestClauseOps) {
201
262
assert (loopNestClauseOps.loopLowerBounds .empty () &&
202
263
" Loop nest bounds were already emitted!" );
@@ -207,9 +268,11 @@ genLoopNestClauseOps(mlir::Location loc, mlir::PatternRewriter &rewriter,
207
268
}
208
269
209
270
static void genWsLoopOp (mlir::PatternRewriter &rewriter, fir::DoLoopOp doLoop,
210
- const mlir::omp::LoopNestOperands &clauseOps) {
271
+ const mlir::omp::LoopNestOperands &clauseOps,
272
+ bool composite) {
211
273
212
274
auto wsloopOp = rewriter.create <mlir::omp::WsloopOp>(doLoop.getLoc ());
275
+ wsloopOp.setComposite (composite);
213
276
rewriter.createBlock (&wsloopOp.getRegion ());
214
277
215
278
auto loopNestOp =
@@ -231,57 +294,20 @@ static void genWsLoopOp(mlir::PatternRewriter &rewriter, fir::DoLoopOp doLoop,
231
294
return ;
232
295
}
233
296
234
- // / If fir.do_loop is present inside teams workdistribute
235
- // /
236
- // / omp.teams {
237
- // / omp.workdistribute {
238
- // / fir.do_loop unoredered {
239
- // / ...
240
- // / }
241
- // / }
242
- // / }
243
- // /
244
- // / Then, its lowered to
245
- // /
246
- // / omp.teams {
247
- // / omp.workdistribute {
248
- // / omp.parallel {
249
- // / omp.wsloop {
250
- // / omp.loop_nest
251
- // / ...
252
- // / }
253
- // / }
254
- // / }
255
- // / }
256
- // / }
257
-
258
- struct TeamsWorkdistributeLowering : public OpRewritePattern <omp::TeamsOp> {
297
+ struct WorkdistributeDoLower : public OpRewritePattern <omp::WorkdistributeOp> {
259
298
using OpRewritePattern::OpRewritePattern;
260
- LogicalResult matchAndRewrite (omp::TeamsOp teamsOp ,
299
+ LogicalResult matchAndRewrite (omp::WorkdistributeOp workdistribute ,
261
300
PatternRewriter &rewriter) const override {
262
- auto teamsLoc = teamsOp->getLoc ();
263
- auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
264
- if (!workdistributeOp) {
265
- LLVM_DEBUG (llvm::dbgs () << DEBUG_TYPE << " No workdistribute nested\n " );
266
- return failure ();
267
- }
268
- assert (teamsOp.getReductionVars ().empty ());
269
-
270
- auto doLoop = getPerfectlyNested<fir::DoLoopOp>(workdistributeOp);
301
+ auto doLoop = getPerfectlyNested<fir::DoLoopOp>(workdistribute);
302
+ auto wdLoc = workdistribute->getLoc ();
271
303
if (doLoop && shouldParallelize (doLoop)) {
272
-
273
- auto parallelOp = rewriter.create <mlir::omp::ParallelOp>(teamsLoc);
274
- rewriter.createBlock (¶llelOp.getRegion ());
275
- rewriter.setInsertionPoint (
276
- rewriter.create <mlir::omp::TerminatorOp>(doLoop.getLoc ()));
277
-
304
+ assert (doLoop.getReduceOperands ().empty ());
305
+ genParallelOp (wdLoc, rewriter, true );
306
+ genDistributeOp (wdLoc, rewriter, true );
278
307
mlir::omp::LoopNestOperands loopNestClauseOps;
279
- genLoopNestClauseOps (doLoop.getLoc (), rewriter, doLoop,
280
- loopNestClauseOps);
281
-
282
- genWsLoopOp (rewriter, doLoop, loopNestClauseOps);
283
- rewriter.setInsertionPoint (doLoop);
284
- rewriter.eraseOp (doLoop);
308
+ genLoopNestClauseOps (rewriter, doLoop, loopNestClauseOps);
309
+ genWsLoopOp (rewriter, doLoop, loopNestClauseOps, true );
310
+ rewriter.eraseOp (workdistribute);
285
311
return success ();
286
312
}
287
313
return failure ();
@@ -315,7 +341,7 @@ struct TeamsWorkdistributeToSingle : public OpRewritePattern<omp::TeamsOp> {
315
341
Block *workdistributeBlock = &workdistributeOp.getRegion ().front ();
316
342
rewriter.eraseOp (workdistributeBlock->getTerminator ());
317
343
rewriter.inlineBlockBefore (workdistributeBlock, teamsOp);
318
- rewriter.eraseOp (teamsOp );
344
+ rewriter.eraseOp (workdistributeOp );
319
345
return success ();
320
346
}
321
347
};
@@ -332,17 +358,15 @@ class LowerWorkdistributePass
332
358
Operation *op = getOperation ();
333
359
{
334
360
RewritePatternSet patterns (&context);
335
- patterns.insert <FissionWorkdistribute, TeamsWorkdistributeLowering>(
336
- &context);
361
+ patterns.insert <FissionWorkdistribute, WorkdistributeDoLower>(&context);
337
362
if (failed (applyPatternsGreedily (op, std::move (patterns), config))) {
338
363
emitError (op->getLoc (), DEBUG_TYPE " pass failed\n " );
339
364
signalPassFailure ();
340
365
}
341
366
}
342
367
{
343
368
RewritePatternSet patterns (&context);
344
- patterns.insert <TeamsWorkdistributeLowering, TeamsWorkdistributeToSingle>(
345
- &context);
369
+ patterns.insert <TeamsWorkdistributeToSingle>(&context);
346
370
if (failed (applyPatternsGreedily (op, std::move (patterns), config))) {
347
371
emitError (op->getLoc (), DEBUG_TYPE " pass failed\n " );
348
372
signalPassFailure ();
0 commit comments