@@ -125,6 +125,7 @@ static T getPerfectlyNested(Operation *op) {
125
125
/// E()
126
126
127
127
struct FissionWorkdistribute : public OpRewritePattern <omp::WorkdistributeOp> {
128
+ static bool fissionWorkdistributePatternMatched;
128
129
using OpRewritePattern::OpRewritePattern;
129
130
LogicalResult matchAndRewrite (omp::WorkdistributeOp workdistribute,
130
131
PatternRewriter &rewriter) const override {
@@ -210,9 +211,12 @@ struct FissionWorkdistribute : public OpRewritePattern<omp::WorkdistributeOp> {
210
211
changed = true ;
211
212
}
212
213
}
214
+ if (changed)
215
+ fissionWorkdistributePatternMatched = true ;
213
216
return success (changed);
214
217
}
215
218
};
219
+ bool FissionWorkdistribute::fissionWorkdistributePatternMatched = false ; // Out-of-class definition of the static flag; matchAndRewrite sets it to true when `changed` is true. NOTE(review): being static, it is shared by every instance of the pattern, and the visible code never resets it to false, so one match would stay observable in later runs — confirm this is intended.
216
220
217
221
/// If fir.do_loop is present inside teams workdistribute
218
222
///
@@ -296,6 +300,7 @@ static void genWsLoopOp(mlir::PatternRewriter &rewriter, fir::DoLoopOp doLoop,
296
300
}
297
301
298
302
struct WorkdistributeDoLower : public OpRewritePattern <omp::WorkdistributeOp> {
303
+ static bool workdistributeDoLowerPatternMatched;
299
304
using OpRewritePattern::OpRewritePattern;
300
305
LogicalResult matchAndRewrite (omp::WorkdistributeOp workdistribute,
301
306
PatternRewriter &rewriter) const override {
@@ -309,12 +314,15 @@ struct WorkdistributeDoLower : public OpRewritePattern<omp::WorkdistributeOp> {
309
314
genLoopNestClauseOps (rewriter, doLoop, loopNestClauseOps);
310
315
genWsLoopOp (rewriter, doLoop, loopNestClauseOps, true );
311
316
rewriter.eraseOp (workdistribute);
317
+ workdistributeDoLowerPatternMatched = true ;
312
318
return success ();
313
319
}
314
320
return failure ();
315
321
}
316
322
};
317
323
324
+ bool WorkdistributeDoLower::workdistributeDoLowerPatternMatched = false ; // Out-of-class definition of the static flag; matchAndRewrite sets it to true right after erasing the workdistribute op. NOTE(review): being static, it is shared by every instance of the pattern, and the visible code never resets it to false, so one match would stay observable in later runs — confirm this is intended.
325
+
318
326
/// If A() and B() are present inside teams workdistribute
319
327
///
320
328
/// omp.teams {
@@ -331,6 +339,7 @@ struct WorkdistributeDoLower : public OpRewritePattern<omp::WorkdistributeOp> {
331
339
///
332
340
333
341
struct TeamsWorkdistributeToSingle : public OpRewritePattern <omp::TeamsOp> {
342
+ static bool teamsWorkdistributeToSinglePatternMatched;
334
343
using OpRewritePattern::OpRewritePattern;
335
344
LogicalResult matchAndRewrite (omp::TeamsOp teamsOp,
336
345
PatternRewriter &rewriter) const override {
@@ -343,9 +352,12 @@ struct TeamsWorkdistributeToSingle : public OpRewritePattern<omp::TeamsOp> {
343
352
rewriter.eraseOp (workdistributeBlock->getTerminator ());
344
353
rewriter.inlineBlockBefore (workdistributeBlock, teamsOp);
345
354
rewriter.eraseOp (workdistributeOp);
355
+ teamsWorkdistributeToSinglePatternMatched = true ;
346
356
return success ();
347
357
}
348
358
};
359
+ bool TeamsWorkdistributeToSingle::teamsWorkdistributeToSinglePatternMatched =
360
+ false ; // Out-of-class definition of the static flag; matchAndRewrite sets it to true right after erasing the workdistribute op. NOTE(review): being static, it is shared by every instance of the pattern, and the visible code never resets it to false, so one match would stay observable in later runs — confirm this is intended.
349
361
350
362
struct SplitTargetResult {
351
363
omp::TargetOp targetOp;
@@ -517,28 +529,6 @@ static bool usedOutsideSplit(Value v, Operation *split) {
517
529
return false ;
518
530
};
519
531
520
- static bool isOpToBeCached (Operation *op) { // Returns false only for a fir::LoadOp whose memref is an entry-block argument of the omp::TargetOp that owns op's block; returns true for every other op.
521
- if (auto loadOp = dyn_cast<fir::LoadOp>(op)) {
522
- Value memref = loadOp.getMemref ();
523
- if (auto blockArg = dyn_cast<BlockArgument>(memref)) { // Memrefs that are not block arguments fall through to the final 'return true'.
524
- // 'op' is an operation within the targetOp that 'splitBefore' is also in.
525
- Operation *parentOpOfLoadBlock = op->getBlock ()->getParentOp ();
526
- // Ensure the blockArg belongs to the entry block of this parent omp.TargetOp.
527
- // This implies the load is from a variable directly mapped into the target region.
528
- if (isa<omp::TargetOp>(parentOpOfLoadBlock) &&
529
- !parentOpOfLoadBlock->getRegions ().empty ()) { // The non-empty check guards the front() calls on the next line.
530
- Block *targetOpEntryBlock = &parentOpOfLoadBlock->getRegions ().front ().front ();
531
- if (blockArg.getOwner () == targetOpEntryBlock) {
532
- // This load is from a direct argument of the target op.
533
- // It's safe to recompute.
534
- return false ;
535
- }
536
- }
537
- }
538
- }
539
- return true ; // Default for every op that did not match the load-from-target-argument case above.
540
- }
541
-
542
532
static bool isRecomputableAfterFission (Operation *op, Operation *splitBefore) {
543
533
if (isa<fir::DeclareOp>(op))
544
534
return true ;
@@ -892,23 +882,31 @@ class LowerWorkdistributePass
892
882
config.setRegionSimplificationLevel (GreedySimplifyRegionLevel::Disabled);
893
883
894
884
Operation *op = getOperation ();
885
+ bool anyPatternChanged = false ;
895
886
{
896
887
RewritePatternSet patterns (&context);
897
888
patterns.insert <FissionWorkdistribute, WorkdistributeDoLower>(&context);
898
889
if (failed (applyPatternsGreedily (op, std::move (patterns), config))) {
899
890
emitError (op->getLoc (), DEBUG_TYPE " pass failed\n " );
900
891
signalPassFailure ();
901
892
}
893
+ anyPatternChanged |=
894
+ FissionWorkdistribute::fissionWorkdistributePatternMatched;
895
+ anyPatternChanged |=
896
+ WorkdistributeDoLower::workdistributeDoLowerPatternMatched;
902
897
}
903
898
{
904
899
RewritePatternSet patterns (&context);
905
- patterns.insert <TeamsWorkdistributeToSingle>(&context);
900
+ patterns.insert <WorkdistributeDoLower, TeamsWorkdistributeToSingle>(
901
+ &context);
906
902
if (failed (applyPatternsGreedily (op, std::move (patterns), config))) {
907
903
emitError (op->getLoc (), DEBUG_TYPE " pass failed\n " );
908
904
signalPassFailure ();
909
905
}
906
+ anyPatternChanged |= TeamsWorkdistributeToSingle::
907
+ teamsWorkdistributeToSinglePatternMatched;
910
908
}
911
- {
909
+ if (anyPatternChanged) {
912
910
SmallVector<omp::TargetOp> targetOps;
913
911
op->walk ([&](omp::TargetOp targetOp) { targetOps.push_back (targetOp); });
914
912
IRRewriter rewriter (&context);
@@ -917,7 +915,6 @@ class LowerWorkdistributePass
917
915
if (res) fissionTarget (res->targetOp , rewriter);
918
916
}
919
917
}
920
-
921
918
}
922
919
};
923
920
} // namespace
0 commit comments