Skip to content

Commit 26c83e4

Browse files
committed
[Flang] Bail out if lower-workdistribute didn't patternmatch.
1 parent 7a06703 commit 26c83e4

File tree

3 files changed

+24
-28
lines changed

3 files changed

+24
-28
lines changed

flang-rt/lib/runtime/assign_omp.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ RT_EXT_API_GROUP_BEGIN
6868
void RTDEF(Assign_omp)(Descriptor &to, const Descriptor &from,
6969
const char *sourceFile, int sourceLine, omp::OMPDeviceTy omp_device) {
7070
Terminator terminator{sourceFile, sourceLine};
71-
omp::Assign(to, from, terminator,
71+
Fortran::runtime::omp::Assign(to, from, terminator,
7272
MaybeReallocate | NeedFinalization | ComponentCanBeDefinedAssignment,
7373
omp_device);
7474
}

flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ static T getPerfectlyNested(Operation *op) {
125125
/// E()
126126

127127
struct FissionWorkdistribute : public OpRewritePattern<omp::WorkdistributeOp> {
128+
static bool fissionWorkdistributePatternMatched;
128129
using OpRewritePattern::OpRewritePattern;
129130
LogicalResult matchAndRewrite(omp::WorkdistributeOp workdistribute,
130131
PatternRewriter &rewriter) const override {
@@ -210,9 +211,12 @@ struct FissionWorkdistribute : public OpRewritePattern<omp::WorkdistributeOp> {
210211
changed = true;
211212
}
212213
}
214+
if (changed)
215+
fissionWorkdistributePatternMatched = true;
213216
return success(changed);
214217
}
215218
};
219+
bool FissionWorkdistribute::fissionWorkdistributePatternMatched = false;
216220

217221
/// If fir.do_loop is present inside teams workdistribute
218222
///
@@ -296,6 +300,7 @@ static void genWsLoopOp(mlir::PatternRewriter &rewriter, fir::DoLoopOp doLoop,
296300
}
297301

298302
struct WorkdistributeDoLower : public OpRewritePattern<omp::WorkdistributeOp> {
303+
static bool workdistributeDoLowerPatternMatched;
299304
using OpRewritePattern::OpRewritePattern;
300305
LogicalResult matchAndRewrite(omp::WorkdistributeOp workdistribute,
301306
PatternRewriter &rewriter) const override {
@@ -309,12 +314,15 @@ struct WorkdistributeDoLower : public OpRewritePattern<omp::WorkdistributeOp> {
309314
genLoopNestClauseOps(rewriter, doLoop, loopNestClauseOps);
310315
genWsLoopOp(rewriter, doLoop, loopNestClauseOps, true);
311316
rewriter.eraseOp(workdistribute);
317+
workdistributeDoLowerPatternMatched = true;
312318
return success();
313319
}
314320
return failure();
315321
}
316322
};
317323

324+
bool WorkdistributeDoLower::workdistributeDoLowerPatternMatched = false;
325+
318326
/// If A() and B () are present inside teams workdistribute
319327
///
320328
/// omp.teams {
@@ -331,6 +339,7 @@ struct WorkdistributeDoLower : public OpRewritePattern<omp::WorkdistributeOp> {
331339
///
332340

333341
struct TeamsWorkdistributeToSingle : public OpRewritePattern<omp::TeamsOp> {
342+
static bool teamsWorkdistributeToSinglePatternMatched;
334343
using OpRewritePattern::OpRewritePattern;
335344
LogicalResult matchAndRewrite(omp::TeamsOp teamsOp,
336345
PatternRewriter &rewriter) const override {
@@ -343,9 +352,12 @@ struct TeamsWorkdistributeToSingle : public OpRewritePattern<omp::TeamsOp> {
343352
rewriter.eraseOp(workdistributeBlock->getTerminator());
344353
rewriter.inlineBlockBefore(workdistributeBlock, teamsOp);
345354
rewriter.eraseOp(workdistributeOp);
355+
teamsWorkdistributeToSinglePatternMatched = true;
346356
return success();
347357
}
348358
};
359+
bool TeamsWorkdistributeToSingle::teamsWorkdistributeToSinglePatternMatched =
360+
false;
349361

350362
struct SplitTargetResult {
351363
omp::TargetOp targetOp;
@@ -517,28 +529,6 @@ static bool usedOutsideSplit(Value v, Operation *split) {
517529
return false;
518530
};
519531

520-
static bool isOpToBeCached(Operation *op) {
521-
if (auto loadOp = dyn_cast<fir::LoadOp>(op)) {
522-
Value memref = loadOp.getMemref();
523-
if (auto blockArg = dyn_cast<BlockArgument>(memref)) {
524-
// 'op' is an operation within the targetOp that 'splitBefore' is also in.
525-
Operation *parentOpOfLoadBlock = op->getBlock()->getParentOp();
526-
// Ensure the blockArg belongs to the entry block of this parent omp.TargetOp.
527-
// This implies the load is from a variable directly mapped into the target region.
528-
if (isa<omp::TargetOp>(parentOpOfLoadBlock) &&
529-
!parentOpOfLoadBlock->getRegions().empty()) {
530-
Block *targetOpEntryBlock = &parentOpOfLoadBlock->getRegions().front().front();
531-
if (blockArg.getOwner() == targetOpEntryBlock) {
532-
// This load is from a direct argument of the target op.
533-
// It's safe to recompute.
534-
return false;
535-
}
536-
}
537-
}
538-
}
539-
return true;
540-
}
541-
542532
static bool isRecomputableAfterFission(Operation *op, Operation *splitBefore) {
543533
if (isa<fir::DeclareOp>(op))
544534
return true;
@@ -892,23 +882,31 @@ class LowerWorkdistributePass
892882
config.setRegionSimplificationLevel(GreedySimplifyRegionLevel::Disabled);
893883

894884
Operation *op = getOperation();
885+
bool anyPatternChanged = false;
895886
{
896887
RewritePatternSet patterns(&context);
897888
patterns.insert<FissionWorkdistribute, WorkdistributeDoLower>(&context);
898889
if (failed(applyPatternsGreedily(op, std::move(patterns), config))) {
899890
emitError(op->getLoc(), DEBUG_TYPE " pass failed\n");
900891
signalPassFailure();
901892
}
893+
anyPatternChanged |=
894+
FissionWorkdistribute::fissionWorkdistributePatternMatched;
895+
anyPatternChanged |=
896+
WorkdistributeDoLower::workdistributeDoLowerPatternMatched;
902897
}
903898
{
904899
RewritePatternSet patterns(&context);
905-
patterns.insert<TeamsWorkdistributeToSingle>(&context);
900+
patterns.insert<WorkdistributeDoLower, TeamsWorkdistributeToSingle>(
901+
&context);
906902
if (failed(applyPatternsGreedily(op, std::move(patterns), config))) {
907903
emitError(op->getLoc(), DEBUG_TYPE " pass failed\n");
908904
signalPassFailure();
909905
}
906+
anyPatternChanged |= TeamsWorkdistributeToSingle::
907+
teamsWorkdistributeToSinglePatternMatched;
910908
}
911-
{
909+
if (anyPatternChanged) {
912910
SmallVector<omp::TargetOp> targetOps;
913911
op->walk([&](omp::TargetOp targetOp) { targetOps.push_back(targetOp); });
914912
IRRewriter rewriter(&context);
@@ -917,7 +915,6 @@ class LowerWorkdistributePass
917915
if (res) fissionTarget(res->targetOp, rewriter);
918916
}
919917
}
920-
921918
}
922919
};
923920
} // namespace

flang/test/Transforms/OpenMP/lower-workdistribute-target.mlir

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@
1010
// CHECK: %[[VAL_7:.*]] = fir.coordinate_of %[[VAL_6]], r : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>) -> !fir.ref<f32>
1111
// CHECK: %[[VAL_8:.*]] = omp.map.info var_ptr(%[[VAL_7]] : !fir.ref<f32>, f32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<f32> {name = "sa%[[VAL_4]]%[[VAL_9:.*]]"}
1212
// CHECK: %[[VAL_10:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>, !fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>) map_clauses(tofrom) capture(ByRef) members(%[[VAL_3]], %[[VAL_8]] : [1, 0], [1, 1] : !fir.ref<i32>, !fir.ref<f32>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>> {name = "sa", partial_map = true}
13-
// CHECK: %[[VAL_11:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>, !fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) members(%[[VAL_3]], %[[VAL_8]] : [1, 0], [1, 1] : !fir.ref<i32>, !fir.ref<f32>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>> {name = "sa", partial_map = true}
14-
// CHECK: omp.target_data map_entries(%[[VAL_10]] : !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) {
13+
// CHECK: omp.target map_entries(%[[VAL_10]] -> %[[VAL_11:.*]] : !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) {
1514
// CHECK: omp.terminator
1615
// CHECK: }
1716
// CHECK: return

0 commit comments

Comments
 (0)