
Commit 2d138e2

[Flang][OpenMP] Add FissionWorkdistribute lowering.
Fission logic inspired by ivanradanov's implementation: c97eca4
1 parent 92c4480 commit 2d138e2

File tree

3 files changed: +243 -52 lines

flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp

Lines changed: 182 additions & 51 deletions

@@ -10,31 +10,26 @@
//
//===----------------------------------------------------------------------===//

-#include <flang/Optimizer/Builder/FIRBuilder.h>
-#include <flang/Optimizer/Dialect/FIROps.h>
-#include <flang/Optimizer/Dialect/FIRType.h>
-#include <flang/Optimizer/HLFIR/HLFIROps.h>
-#include <flang/Optimizer/OpenMP/Passes.h>
-#include <llvm/ADT/BreadthFirstIterator.h>
-#include <llvm/ADT/STLExtras.h>
-#include <llvm/ADT/SmallVectorExtras.h>
-#include <llvm/ADT/iterator_range.h>
-#include <llvm/Support/ErrorHandling.h>
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "flang/Optimizer/HLFIR/Passes.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
-#include <mlir/Dialect/OpenMP/OpenMPClauseOperands.h>
-#include <mlir/Dialect/OpenMP/OpenMPDialect.h>
-#include <mlir/Dialect/SCF/IR/SCF.h>
+#include <mlir/Dialect/Utils/IndexingUtils.h>
+#include <mlir/IR/BlockSupport.h>
#include <mlir/IR/BuiltinOps.h>
+#include <mlir/IR/Diagnostics.h>
#include <mlir/IR/IRMapping.h>
-#include <mlir/IR/OpDefinition.h>
#include <mlir/IR/PatternMatch.h>
-#include <mlir/IR/Value.h>
-#include <mlir/IR/Visitors.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
#include <mlir/Support/LLVM.h>
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
+#include <optional>
#include <variant>

namespace flangomp {
@@ -48,52 +43,188 @@ using namespace mlir;

namespace {

-struct WorkdistributeToSingle : public mlir::OpRewritePattern<mlir::omp::WorkdistributeOp> {
-  using OpRewritePattern::OpRewritePattern;
-  mlir::LogicalResult
-  matchAndRewrite(mlir::omp::WorkdistributeOp workdistribute,
-                  mlir::PatternRewriter &rewriter) const override {
-    auto loc = workdistribute->getLoc();
-    auto teams = llvm::dyn_cast<mlir::omp::TeamsOp>(workdistribute->getParentOp());
-    if (!teams) {
-      mlir::emitError(loc, "workdistribute not nested in teams\n");
-      return mlir::failure();
-    }
-    if (workdistribute.getRegion().getBlocks().size() != 1) {
-      mlir::emitError(loc, "workdistribute with multiple blocks\n");
-      return mlir::failure();
+template <typename T>
+static T getPerfectlyNested(Operation *op) {
+  if (op->getNumRegions() != 1)
+    return nullptr;
+  auto &region = op->getRegion(0);
+  if (region.getBlocks().size() != 1)
+    return nullptr;
+  auto *block = &region.front();
+  auto *firstOp = &block->front();
+  if (auto nested = dyn_cast<T>(firstOp))
+    if (firstOp->getNextNode() == block->getTerminator())
+      return nested;
+  return nullptr;
+}
+
+/// This is the single source of truth about whether we should parallelize an
+/// operation nested in an omp.workdistribute region.
+static bool shouldParallelize(Operation *op) {
+  // Currently we cannot parallelize operations with results that have uses.
+  if (llvm::any_of(op->getResults(),
+                   [](OpResult v) -> bool { return !v.use_empty(); }))
+    return false;
+  // We will parallelize unordered loops - these come from array syntax.
+  if (auto loop = dyn_cast<fir::DoLoopOp>(op)) {
+    auto unordered = loop.getUnordered();
+    if (!unordered)
+      return false;
+    return *unordered;
+  }
+  if (auto callOp = dyn_cast<fir::CallOp>(op)) {
+    auto callee = callOp.getCallee();
+    if (!callee)
+      return false;
+    auto *func = op->getParentOfType<ModuleOp>().lookupSymbol(*callee);
+    // TODO: need to insert a check here whether it is a call we can actually
+    // parallelize currently.
+    if (func->getAttr(fir::FIROpsDialect::getFirRuntimeAttrName()))
+      return true;
+    return false;
+  }
+  // We cannot parallelize anything else.
+  return false;
+}
+
+struct WorkdistributeToSingle : public OpRewritePattern<omp::TeamsOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(omp::TeamsOp teamsOp,
+                                PatternRewriter &rewriter) const override {
+    auto workdistributeOp = getPerfectlyNested<omp::WorkdistributeOp>(teamsOp);
+    if (!workdistributeOp) {
+      LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << " No workdistribute nested\n");
+      return failure();
    }
-    if (teams.getRegion().getBlocks().size() != 1) {
-      mlir::emitError(loc, "teams with multiple blocks\n");
-      return mlir::failure();
+
+    Block *workdistributeBlock = &workdistributeOp.getRegion().front();
+    rewriter.eraseOp(workdistributeBlock->getTerminator());
+    rewriter.inlineBlockBefore(workdistributeBlock, teamsOp);
+    rewriter.eraseOp(teamsOp);
+    workdistributeOp.emitWarning("unable to parallelize coexecute");
+    return success();
+  }
+};
+
+/// If B() and D() are parallelizable,
+///
+/// omp.teams {
+///   omp.workdistribute {
+///     A()
+///     B()
+///     C()
+///     D()
+///     E()
+///   }
+/// }
+///
+/// becomes
+///
+/// A()
+/// omp.teams {
+///   omp.workdistribute {
+///     B()
+///   }
+/// }
+/// C()
+/// omp.teams {
+///   omp.workdistribute {
+///     D()
+///   }
+/// }
+/// E()
+
+struct FissionWorkdistribute
+    : public OpRewritePattern<omp::WorkdistributeOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult
+  matchAndRewrite(omp::WorkdistributeOp workdistribute,
+                  PatternRewriter &rewriter) const override {
+    auto loc = workdistribute->getLoc();
+    auto teams = dyn_cast<omp::TeamsOp>(workdistribute->getParentOp());
+    if (!teams) {
+      emitError(loc, "workdistribute not nested in teams\n");
+      return failure();
+    }
+    if (workdistribute.getRegion().getBlocks().size() != 1) {
+      emitError(loc, "workdistribute with multiple blocks\n");
+      return failure();
+    }
+    if (teams.getRegion().getBlocks().size() != 1) {
+      emitError(loc, "teams with multiple blocks\n");
+      return failure();
+    }
+    if (teams.getRegion().getBlocks().front().getOperations().size() != 2) {
+      emitError(loc, "teams with multiple nested ops\n");
+      return failure();
+    }
+
+    auto *teamsBlock = &teams.getRegion().front();
+
+    // While we have unhandled operations in the original workdistribute.
+    auto *workdistributeBlock = &workdistribute.getRegion().front();
+    auto *terminator = workdistributeBlock->getTerminator();
+    bool changed = false;
+    while (&workdistributeBlock->front() != terminator) {
+      rewriter.setInsertionPoint(teams);
+      IRMapping mapping;
+      llvm::SmallVector<Operation *> hoisted;
+      Operation *parallelize = nullptr;
+      for (auto &op : workdistribute.getOps()) {
+        if (&op == terminator) {
+          break;
        }
-    if (teams.getRegion().getBlocks().front().getOperations().size() != 2) {
-      mlir::emitError(loc, "teams with multiple nested ops\n");
-      return mlir::failure();
+        if (shouldParallelize(&op)) {
+          parallelize = &op;
+          break;
+        } else {
+          rewriter.clone(op, mapping);
+          hoisted.push_back(&op);
+          changed = true;
        }
-    mlir::Block *workdistributeBlock = &workdistribute.getRegion().front();
-    rewriter.eraseOp(workdistributeBlock->getTerminator());
-    rewriter.inlineBlockBefore(workdistributeBlock, teams);
-    rewriter.eraseOp(teams);
-    return mlir::success();
+      }
+
+      for (auto *op : hoisted)
+        rewriter.replaceOp(op, mapping.lookup(op));
+
+      if (parallelize && hoisted.empty() &&
+          parallelize->getNextNode() == terminator)
+        break;
+      if (parallelize) {
+        auto newTeams = rewriter.cloneWithoutRegions(teams);
+        auto *newTeamsBlock = rewriter.createBlock(
+            &newTeams.getRegion(), newTeams.getRegion().begin(), {}, {});
+        for (auto arg : teamsBlock->getArguments())
+          newTeamsBlock->addArgument(arg.getType(), arg.getLoc());
+        auto newWorkdistribute = rewriter.create<omp::WorkdistributeOp>(loc);
+        rewriter.create<omp::TerminatorOp>(loc);
+        rewriter.createBlock(&newWorkdistribute.getRegion(),
+                             newWorkdistribute.getRegion().begin(), {}, {});
+        auto *cloned = rewriter.clone(*parallelize);
+        rewriter.replaceOp(parallelize, cloned);
+        rewriter.create<omp::TerminatorOp>(loc);
+        changed = true;
+      }
    }
+    return success(changed);
+  }
};

class LowerWorkdistributePass
    : public flangomp::impl::LowerWorkdistributeBase<LowerWorkdistributePass> {
public:
  void runOnOperation() override {
-    mlir::MLIRContext &context = getContext();
-    mlir::RewritePatternSet patterns(&context);
-    mlir::GreedyRewriteConfig config;
+    MLIRContext &context = getContext();
+    RewritePatternSet patterns(&context);
+    GreedyRewriteConfig config;
    // prevent the pattern driver from merging blocks
    config.setRegionSimplificationLevel(
-        mlir::GreedySimplifyRegionLevel::Disabled);
+        GreedySimplifyRegionLevel::Disabled);

-    patterns.insert<WorkdistributeToSingle>(&context);
-    mlir::Operation *op = getOperation();
-    if (mlir::failed(mlir::applyPatternsGreedily(op, std::move(patterns), config))) {
-      mlir::emitError(op->getLoc(), DEBUG_TYPE " pass failed\n");
+    patterns.insert<FissionWorkdistribute, WorkdistributeToSingle>(&context);
+    Operation *op = getOperation();
+    if (failed(applyPatternsGreedily(op, std::move(patterns), config))) {
+      emitError(op->getLoc(), DEBUG_TYPE " pass failed\n");
      signalPassFailure();
    }
  }
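For orientation, shouldParallelize above currently admits exactly two kinds of operations, and only when their results are unused: unordered fir.do_loop ops (generated from array syntax) and fir.call ops whose callee carries the fir.runtime attribute. A rough MLIR sketch of the distinction follows; the names %lb, %ub, %c1, %ref and @assumed_runtime_func are illustrative placeholders, not taken from this commit.

// Fissioned into its own omp.teams/omp.workdistribute nest: unordered loop.
fir.do_loop %i = %lb to %ub step %c1 unordered {
  // loop body
}
// Also fissioned: call to a function declared with the fir.runtime attribute.
fir.call @assumed_runtime_func(%ref) : (!fir.ref<f32>) -> ()
// Hoisted before the omp.teams instead: any other op, or an op whose results have uses.
%x = fir.load %ref : !fir.ref<f32>

The new test below exercises both parallelizable cases together with hoisted side-effecting ops.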
Lines changed: 60 additions & 0 deletions

@@ -0,0 +1,60 @@
// RUN: fir-opt --lower-workdistribute %s | FileCheck %s

// CHECK-LABEL: func.func @test_fission_workdistribute({{.*}}) {
// CHECK:   %[[VAL_0:.*]] = arith.constant 0 : index
// CHECK:   %[[VAL_1:.*]] = arith.constant 1 : index
// CHECK:   %[[VAL_2:.*]] = arith.constant 9 : index
// CHECK:   %[[VAL_3:.*]] = arith.constant 5.000000e+00 : f32
// CHECK:   fir.store %[[VAL_3]] to %[[ARG2:.*]] : !fir.ref<f32>
// CHECK:   fir.do_loop %[[VAL_4:.*]] = %[[VAL_0]] to %[[VAL_2]] step %[[VAL_1]] unordered {
// CHECK:     %[[VAL_5:.*]] = fir.coordinate_of %[[ARG0:.*]], %[[VAL_4]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
// CHECK:     %[[VAL_6:.*]] = fir.load %[[VAL_5]] : !fir.ref<f32>
// CHECK:     %[[VAL_7:.*]] = fir.coordinate_of %[[ARG1:.*]], %[[VAL_4]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
// CHECK:     fir.store %[[VAL_6]] to %[[VAL_7]] : !fir.ref<f32>
// CHECK:   }
// CHECK:   fir.call @regular_side_effect_func(%[[ARG2:.*]]) : (!fir.ref<f32>) -> ()
// CHECK:   fir.call @my_fir_parallel_runtime_func(%[[ARG3:.*]]) : (!fir.ref<f32>) -> ()
// CHECK:   fir.do_loop %[[VAL_8:.*]] = %[[VAL_0]] to %[[VAL_2]] step %[[VAL_1]] {
// CHECK:     %[[VAL_9:.*]] = fir.coordinate_of %[[ARG0:.*]], %[[VAL_8]] : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
// CHECK:     fir.store %[[VAL_3]] to %[[VAL_9]] : !fir.ref<f32>
// CHECK:   }
// CHECK:   %[[VAL_10:.*]] = fir.load %[[ARG2:.*]] : !fir.ref<f32>
// CHECK:   fir.store %[[VAL_10]] to %[[ARG3:.*]] : !fir.ref<f32>
// CHECK:   return
// CHECK: }
module {
  func.func @regular_side_effect_func(%arg0: !fir.ref<f32>) {
    return
  }
  func.func @my_fir_parallel_runtime_func(%arg0: !fir.ref<f32>) attributes {fir.runtime} {
    return
  }
  func.func @test_fission_workdistribute(%arr1: !fir.ref<!fir.array<10xf32>>, %arr2: !fir.ref<!fir.array<10xf32>>, %scalar_ref1: !fir.ref<f32>, %scalar_ref2: !fir.ref<f32>) {
    %c0_idx = arith.constant 0 : index
    %c1_idx = arith.constant 1 : index
    %c9_idx = arith.constant 9 : index
    %float_val = arith.constant 5.0 : f32
    omp.teams {
      omp.workdistribute {
        fir.store %float_val to %scalar_ref1 : !fir.ref<f32>
        fir.do_loop %iv = %c0_idx to %c9_idx step %c1_idx unordered {
          %elem_ptr_arr1 = fir.coordinate_of %arr1, %iv : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
          %loaded_val_loop1 = fir.load %elem_ptr_arr1 : !fir.ref<f32>
          %elem_ptr_arr2 = fir.coordinate_of %arr2, %iv : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
          fir.store %loaded_val_loop1 to %elem_ptr_arr2 : !fir.ref<f32>
        }
        fir.call @regular_side_effect_func(%scalar_ref1) : (!fir.ref<f32>) -> ()
        fir.call @my_fir_parallel_runtime_func(%scalar_ref2) : (!fir.ref<f32>) -> ()
        fir.do_loop %jv = %c0_idx to %c9_idx step %c1_idx {
          %elem_ptr_ordered_loop = fir.coordinate_of %arr1, %jv : (!fir.ref<!fir.array<10xf32>>, index) -> !fir.ref<f32>
          fir.store %float_val to %elem_ptr_ordered_loop : !fir.ref<f32>
        }
        %loaded_for_hoist = fir.load %scalar_ref1 : !fir.ref<f32>
        fir.store %loaded_for_hoist to %scalar_ref2 : !fir.ref<f32>
        omp.terminator
      }
      omp.terminator
    }
    return
  }
}
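To see how the two patterns compose on this test: after FissionWorkdistribute alone, each parallelizable op (the unordered loop and the fir.runtime call) would sit in its own omp.teams/omp.workdistribute nest, with the remaining ops hoisted around them, roughly as in the sketch below (an illustrative intermediate state, not checked output; loop bodies elided). Since the pass currently registers only these two patterns, WorkdistributeToSingle then collapses each single-op nest, emitting the "unable to parallelize coexecute" warning, which is why the CHECK lines above contain no omp.teams at all.

fir.store %float_val to %scalar_ref1 : !fir.ref<f32>
omp.teams {
  omp.workdistribute {
    fir.do_loop %iv = %c0_idx to %c9_idx step %c1_idx unordered { ... }
    omp.terminator
  }
  omp.terminator
}
fir.call @regular_side_effect_func(%scalar_ref1) : (!fir.ref<f32>) -> ()
omp.teams {
  omp.workdistribute {
    fir.call @my_fir_parallel_runtime_func(%scalar_ref2) : (!fir.ref<f32>) -> ()
    omp.terminator
  }
  omp.terminator
}
fir.do_loop %jv = %c0_idx to %c9_idx step %c1_idx { ... }
%loaded_for_hoist = fir.load %scalar_ref1 : !fir.ref<f32>
fir.store %loaded_for_hoist to %scalar_ref2 : !fir.ref<f32>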

flang/test/Transforms/OpenMP/lower-workdistribute.mlir renamed to flang/test/Transforms/OpenMP/lower-workdistribute-to-single.mlir

Lines changed: 1 addition & 1 deletion

@@ -49,4 +49,4 @@ func.func @_QPtarget_simple() {
    omp.terminator
  }
  return
-}
+}
