Skip to content

Commit 60351b6

Browse files
committed
update to workdistribute lowering
1 parent df65bd5 commit 60351b6

File tree

3 files changed

+139
-105
lines changed

3 files changed

+139
-105
lines changed

flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp

Lines changed: 109 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -48,25 +48,21 @@ using namespace mlir;
4848

4949
namespace {
5050

51-
template <typename T>
52-
static T getPerfectlyNested(Operation *op) {
53-
if (op->getNumRegions() != 1)
54-
return nullptr;
55-
auto &region = op->getRegion(0);
56-
if (region.getBlocks().size() != 1)
57-
return nullptr;
58-
auto *block = &region.front();
59-
auto *firstOp = &block->front();
60-
if (auto nested = dyn_cast<T>(firstOp))
61-
if (firstOp->getNextNode() == block->getTerminator())
62-
return nested;
63-
return nullptr;
51+
static bool isRuntimeCall(Operation *op) {
52+
if (auto callOp = dyn_cast<fir::CallOp>(op)) {
53+
auto callee = callOp.getCallee();
54+
if (!callee)
55+
return false;
56+
auto *func = op->getParentOfType<ModuleOp>().lookupSymbol(*callee);
57+
if (func->getAttr(fir::FIROpsDialect::getFirRuntimeAttrName()))
58+
return true;
59+
}
60+
return false;
6461
}
6562

6663
/// This is the single source of truth about whether we should parallelize an
67-
/// operation nested in an omp.workdistribute region.
64+
/// operation nested in an omp.execute region.
6865
static bool shouldParallelize(Operation *op) {
69-
// Currently we cannot parallelize operations with results that have uses
7066
if (llvm::any_of(op->getResults(),
7167
[](OpResult v) -> bool { return !v.use_empty(); }))
7268
return false;
@@ -77,21 +73,28 @@ static bool shouldParallelize(Operation *op) {
7773
return false;
7874
return *unordered;
7975
}
80-
if (auto callOp = dyn_cast<fir::CallOp>(op)) {
81-
auto callee = callOp.getCallee();
82-
if (!callee)
83-
return false;
84-
auto *func = op->getParentOfType<ModuleOp>().lookupSymbol(*callee);
85-
// TODO need to insert a check here whether it is a call we can actually
86-
// parallelize currently
87-
if (func->getAttr(fir::FIROpsDialect::getFirRuntimeAttrName()))
88-
return true;
89-
return false;
76+
if (isRuntimeCall(op)) {
77+
return true;
9078
}
9179
// We cannot parallelize anything else
9280
return false;
9381
}
9482

83+
template <typename T>
84+
static T getPerfectlyNested(Operation *op) {
85+
if (op->getNumRegions() != 1)
86+
return nullptr;
87+
auto &region = op->getRegion(0);
88+
if (region.getBlocks().size() != 1)
89+
return nullptr;
90+
auto *block = &region.front();
91+
auto *firstOp = &block->front();
92+
if (auto nested = dyn_cast<T>(firstOp))
93+
if (firstOp->getNextNode() == block->getTerminator())
94+
return nested;
95+
return nullptr;
96+
}
97+
9598
/// If B() and D() are parallelizable,
9699
///
97100
/// omp.teams {
@@ -138,17 +141,33 @@ struct FissionWorkdistribute : public OpRewritePattern<omp::WorkdistributeOp> {
138141
emitError(loc, "teams with multiple blocks\n");
139142
return failure();
140143
}
141-
if (teams.getRegion().getBlocks().front().getOperations().size() != 2) {
142-
emitError(loc, "teams with multiple nested ops\n");
143-
return failure();
144-
}
145144

146145
auto *teamsBlock = &teams.getRegion().front();
146+
bool changed = false;
147+
// Move the ops inside teams and before workdistribute outside.
148+
IRMapping irMapping;
149+
llvm::SmallVector<Operation *> teamsHoisted;
150+
for (auto &op : teams.getOps()) {
151+
if (&op == workdistribute) {
152+
break;
153+
}
154+
if (shouldParallelize(&op)) {
155+
emitError(loc,
156+
"teams has parallelize ops before first workdistribute\n");
157+
return failure();
158+
} else {
159+
rewriter.setInsertionPoint(teams);
160+
rewriter.clone(op, irMapping);
161+
teamsHoisted.push_back(&op);
162+
changed = true;
163+
}
164+
}
165+
for (auto *op : teamsHoisted)
166+
rewriter.replaceOp(op, irMapping.lookup(op));
147167

148168
// While we have unhandled operations in the original workdistribute
149169
auto *workdistributeBlock = &workdistribute.getRegion().front();
150170
auto *terminator = workdistributeBlock->getTerminator();
151-
bool changed = false;
152171
while (&workdistributeBlock->front() != terminator) {
153172
rewriter.setInsertionPoint(teams);
154173
IRMapping mapping;
@@ -194,9 +213,51 @@ struct FissionWorkdistribute : public OpRewritePattern<omp::WorkdistributeOp> {
194213
}
195214
};
196215

216+
/// If fir.do_loop is present inside teams workdistribute
217+
///
218+
/// omp.teams {
219+
/// omp.workdistribute {
220+
/// fir.do_loop unordered {
221+
/// ...
222+
/// }
223+
/// }
224+
/// }
225+
///
226+
/// Then, it's lowered to
227+
///
228+
/// omp.teams {
229+
/// omp.parallel {
230+
/// omp.distribute {
231+
/// omp.wsloop {
232+
/// omp.loop_nest
233+
/// ...
234+
/// }
235+
/// }
236+
/// }
237+
/// }
238+
239+
static void genParallelOp(Location loc, PatternRewriter &rewriter,
240+
bool composite) {
241+
auto parallelOp = rewriter.create<mlir::omp::ParallelOp>(loc);
242+
parallelOp.setComposite(composite);
243+
rewriter.createBlock(&parallelOp.getRegion());
244+
rewriter.setInsertionPoint(rewriter.create<mlir::omp::TerminatorOp>(loc));
245+
return;
246+
}
247+
248+
static void genDistributeOp(Location loc, PatternRewriter &rewriter,
249+
bool composite) {
250+
mlir::omp::DistributeOperands distributeClauseOps;
251+
auto distributeOp =
252+
rewriter.create<mlir::omp::DistributeOp>(loc, distributeClauseOps);
253+
distributeOp.setComposite(composite);
254+
auto distributeBlock = rewriter.createBlock(&distributeOp.getRegion());
255+
rewriter.setInsertionPointToStart(distributeBlock);
256+
return;
257+
}
258+
197259
static void
198-
genLoopNestClauseOps(mlir::Location loc, mlir::PatternRewriter &rewriter,
199-
fir::DoLoopOp loop,
260+
genLoopNestClauseOps(mlir::PatternRewriter &rewriter, fir::DoLoopOp loop,
200261
mlir::omp::LoopNestOperands &loopNestClauseOps) {
201262
assert(loopNestClauseOps.loopLowerBounds.empty() &&
202263
"Loop nest bounds were already emitted!");
@@ -207,9 +268,11 @@ genLoopNestClauseOps(mlir::Location loc, mlir::PatternRewriter &rewriter,
207268
}
208269

209270
static void genWsLoopOp(mlir::PatternRewriter &rewriter, fir::DoLoopOp doLoop,
210-
const mlir::omp::LoopNestOperands &clauseOps) {
271+
const mlir::omp::LoopNestOperands &clauseOps,
272+
bool composite) {
211273

212274
auto wsloopOp = rewriter.create<mlir::omp::WsloopOp>(doLoop.getLoc());
275+
wsloopOp.setComposite(composite);
213276
rewriter.createBlock(&wsloopOp.getRegion());
214277

215278
auto loopNestOp =
@@ -231,57 +294,20 @@ static void genWsLoopOp(mlir::PatternRewriter &rewriter, fir::DoLoopOp doLoop,
231294
return;
232295
}
233296

234-
/// If fir.do_loop is present inside teams workdistribute
235-
///
236-
/// omp.teams {
237-
/// omp.workdistribute {
238-
/// fir.do_loop unoredered {
239-
/// ...
240-
/// }
241-
/// }
242-
/// }
243-
///
244-
/// Then, its lowered to
245-
///
246-
/// omp.teams {
247-
/// omp.workdistribute {
248-
/// omp.parallel {
249-
/// omp.wsloop {
250-
/// omp.loop_nest
251-
/// ...
252-
/// }
253-
/// }
254-
/// }
255-
/// }
256-
/// }
257-
258-
struct TeamsWorkdistributeLowering : public OpRewritePattern<omp::TeamsOp> {
297+
struct WorkdistributeDoLower : public OpRewritePattern<omp::WorkdistributeOp> {
259298
using OpRewritePattern::OpRewritePattern;
260-
LogicalResult matchAndRewrite(omp::TeamsOp teamsOp,
299+
LogicalResult matchAndRewrite(omp::WorkdistributeOp workdistribute,
261300
PatternRewriter &rewriter) const override {
262-
auto teamsLoc = teamsOp->getLoc();
263-
auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
264-
if (!workdistributeOp) {
265-
LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << " No workdistribute nested\n");
266-
return failure();
267-
}
268-
assert(teamsOp.getReductionVars().empty());
269-
270-
auto doLoop = getPerfectlyNested<fir::DoLoopOp>(workdistributeOp);
301+
auto doLoop = getPerfectlyNested<fir::DoLoopOp>(workdistribute);
302+
auto wdLoc = workdistribute->getLoc();
271303
if (doLoop && shouldParallelize(doLoop)) {
272-
273-
auto parallelOp = rewriter.create<mlir::omp::ParallelOp>(teamsLoc);
274-
rewriter.createBlock(&parallelOp.getRegion());
275-
rewriter.setInsertionPoint(
276-
rewriter.create<mlir::omp::TerminatorOp>(doLoop.getLoc()));
277-
304+
assert(doLoop.getReduceOperands().empty());
305+
genParallelOp(wdLoc, rewriter, true);
306+
genDistributeOp(wdLoc, rewriter, true);
278307
mlir::omp::LoopNestOperands loopNestClauseOps;
279-
genLoopNestClauseOps(doLoop.getLoc(), rewriter, doLoop,
280-
loopNestClauseOps);
281-
282-
genWsLoopOp(rewriter, doLoop, loopNestClauseOps);
283-
rewriter.setInsertionPoint(doLoop);
284-
rewriter.eraseOp(doLoop);
308+
genLoopNestClauseOps(rewriter, doLoop, loopNestClauseOps);
309+
genWsLoopOp(rewriter, doLoop, loopNestClauseOps, true);
310+
rewriter.eraseOp(workdistribute);
285311
return success();
286312
}
287313
return failure();
@@ -315,7 +341,7 @@ struct TeamsWorkdistributeToSingle : public OpRewritePattern<omp::TeamsOp> {
315341
Block *workdistributeBlock = &workdistributeOp.getRegion().front();
316342
rewriter.eraseOp(workdistributeBlock->getTerminator());
317343
rewriter.inlineBlockBefore(workdistributeBlock, teamsOp);
318-
rewriter.eraseOp(teamsOp);
344+
rewriter.eraseOp(workdistributeOp);
319345
return success();
320346
}
321347
};
@@ -332,17 +358,15 @@ class LowerWorkdistributePass
332358
Operation *op = getOperation();
333359
{
334360
RewritePatternSet patterns(&context);
335-
patterns.insert<FissionWorkdistribute, TeamsWorkdistributeLowering>(
336-
&context);
361+
patterns.insert<FissionWorkdistribute, WorkdistributeDoLower>(&context);
337362
if (failed(applyPatternsGreedily(op, std::move(patterns), config))) {
338363
emitError(op->getLoc(), DEBUG_TYPE " pass failed\n");
339364
signalPassFailure();
340365
}
341366
}
342367
{
343368
RewritePatternSet patterns(&context);
344-
patterns.insert<TeamsWorkdistributeLowering, TeamsWorkdistributeToSingle>(
345-
&context);
369+
patterns.insert<TeamsWorkdistributeToSingle>(&context);
346370
if (failed(applyPatternsGreedily(op, std::move(patterns), config))) {
347371
emitError(op->getLoc(), DEBUG_TYPE " pass failed\n");
348372
signalPassFailure();

flang/test/Transforms/OpenMP/lower-workdistribute-doloop.mlir

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,18 @@
22

33
// CHECK-LABEL: func.func @x({{.*}})
44
// CHECK: %[[VAL_0:.*]] = arith.constant 0 : index
5-
// CHECK: omp.parallel {
6-
// CHECK: omp.wsloop {
7-
// CHECK: omp.loop_nest (%[[VAL_1:.*]]) : index = (%[[ARG0:.*]]) to (%[[ARG1:.*]]) inclusive step (%[[ARG2:.*]]) {
8-
// CHECK: fir.store %[[VAL_0]] to %[[ARG4:.*]] : !fir.ref<index>
9-
// CHECK: omp.yield
10-
// CHECK: }
11-
// CHECK: }
5+
// CHECK: omp.teams {
6+
// CHECK: omp.parallel {
7+
// CHECK: omp.distribute {
8+
// CHECK: omp.wsloop {
9+
// CHECK: omp.loop_nest (%[[VAL_1:.*]]) : index = (%[[ARG0:.*]]) to (%[[ARG1:.*]]) inclusive step (%[[ARG2:.*]]) {
10+
// CHECK: fir.store %[[VAL_0]] to %[[ARG4:.*]] : !fir.ref<index>
11+
// CHECK: omp.yield
12+
// CHECK: }
13+
// CHECK: } {omp.composite}
14+
// CHECK: } {omp.composite}
15+
// CHECK: omp.terminator
16+
// CHECK: } {omp.composite}
1217
// CHECK: omp.terminator
1318
// CHECK: }
1419
// CHECK: return

flang/test/Transforms/OpenMP/lower-workdistribute-fission.mlir

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,26 @@
11
// RUN: fir-opt --lower-workdistribute %s | FileCheck %s
22

3-
// CHECK-LABEL: func.func @test_fission_workdistribute({{.*}}) {
3+
// CHECK-LABEL: func.func @test_fission_workdistribute(
44
// CHECK: %[[VAL_0:.*]] = arith.constant 0 : index
55
// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
66
// CHECK: %[[VAL_2:.*]] = arith.constant 9 : index
77
// CHECK: %[[VAL_3:.*]] = arith.constant 5.000000e+00 : f32
88
// CHECK: fir.store %[[VAL_3]] to %[[ARG2:.*]] : !fir.ref<f32>
9-
// CHECK: omp.parallel {
10-
// CHECK: omp.wsloop {
11-
// CHECK: omp.loop_nest (%[[VAL_4:.*]]) : index = (%[[VAL_0]]) to (%[[VAL_2]]) inclusive step (%[[VAL_1]]) {
12-
// CHECK: %[[VAL_5:.*]] = fir.coordinate_of %[[ARG0:.*]], %[[VAL_4]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
13-
// CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_5]] : !fir.ref<f32>
14-
// CHECK: %[[VAL_7:.*]] = fir.coordinate_of %[[ARG1:.*]], %[[VAL_4]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
15-
// CHECK: fir.store %[[VAL_6]] to %[[VAL_7]] : !fir.ref<f32>
16-
// CHECK: omp.yield
17-
// CHECK: }
18-
// CHECK: }
9+
// CHECK: omp.teams {
10+
// CHECK: omp.parallel {
11+
// CHECK: omp.distribute {
12+
// CHECK: omp.wsloop {
13+
// CHECK: omp.loop_nest (%[[VAL_4:.*]]) : index = (%[[VAL_0]]) to (%[[VAL_2]]) inclusive step (%[[VAL_1]]) {
14+
// CHECK: %[[VAL_5:.*]] = fir.coordinate_of %[[ARG0:.*]], %[[VAL_4]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
15+
// CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_5]] : !fir.ref<f32>
16+
// CHECK: %[[VAL_7:.*]] = fir.coordinate_of %[[ARG1:.*]], %[[VAL_4]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
17+
// CHECK: fir.store %[[VAL_6]] to %[[VAL_7]] : !fir.ref<f32>
18+
// CHECK: omp.yield
19+
// CHECK: }
20+
// CHECK: } {omp.composite}
21+
// CHECK: } {omp.composite}
22+
// CHECK: omp.terminator
23+
// CHECK: } {omp.composite}
1924
// CHECK: omp.terminator
2025
// CHECK: }
2126
// CHECK: fir.call @regular_side_effect_func(%[[ARG2:.*]]) : (!fir.ref<f32>) -> ()
@@ -24,8 +29,8 @@
2429
// CHECK: %[[VAL_9:.*]] = fir.coordinate_of %[[ARG0]], %[[VAL_8]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
2530
// CHECK: fir.store %[[VAL_3]] to %[[VAL_9]] : !fir.ref<f32>
2631
// CHECK: }
27-
// CHECK: %[[VAL_10:.*]] = fir.load %[[ARG2]] : !fir.ref<f32>
28-
// CHECK: fir.store %[[VAL_10]] to %[[ARG3]] : !fir.ref<f32>
32+
// CHECK: %[[VAL_10:.*]] = fir.load %[[ARG2:.*]] : !fir.ref<f32>
33+
// CHECK: fir.store %[[VAL_10]] to %[[ARG3:.*]] : !fir.ref<f32>
2934
// CHECK: return
3035
// CHECK: }
3136
module {

0 commit comments

Comments
 (0)