Skip to content

Commit 5177cad

Browse files
committed
[flang][OpenMP] Support MLIR lowering of linear clause for omp.wsloop
1 parent 7d92756 commit 5177cad

File tree

3 files changed

+66
-41
lines changed

3 files changed

+66
-41
lines changed

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1816,8 +1816,7 @@ static void genSimdClauses(
18161816
cp.processReduction(loc, clauseOps, reductionSyms);
18171817
cp.processSafelen(clauseOps);
18181818
cp.processSimdlen(clauseOps);
1819-
1820-
cp.processTODO<clause::Linear>(loc, llvm::omp::Directive::OMPD_simd);
1819+
cp.processLinear(clauseOps);
18211820
}
18221821

18231822
static void genSingleClauses(lower::AbstractConverter &converter,
@@ -2007,9 +2006,9 @@ static void genWsloopClauses(
20072006
cp.processOrdered(clauseOps);
20082007
cp.processReduction(loc, clauseOps, reductionSyms);
20092008
cp.processSchedule(stmtCtx, clauseOps);
2009+
cp.processLinear(clauseOps);
20102010

2011-
cp.processTODO<clause::Allocate, clause::Linear>(
2012-
loc, llvm::omp::Directive::OMPD_do);
2011+
cp.processTODO<clause::Allocate>(loc, llvm::omp::Directive::OMPD_do);
20132012
}
20142013

20152014
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2608,7 +2608,7 @@ void SimdOp::build(OpBuilder &builder, OperationState &state,
26082608
// TODO Store clauses in op: linearVars, linearStepVars
26092609
SimdOp::build(builder, state, clauses.alignedVars,
26102610
makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr,
2611-
/*linear_vars=*/{}, /*linear_step_vars=*/{},
2611+
clauses.linearVars, clauses.linearStepVars,
26122612
clauses.nontemporalVars, clauses.order, clauses.orderMod,
26132613
clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
26142614
clauses.privateNeedsBarrier, clauses.reductionMod,

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 62 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,9 @@ class LinearClauseProcessor {
147147

148148
public:
149149
// Allocate space for linear variabes
150-
void createLinearVar(llvm::IRBuilderBase &builder,
151-
LLVM::ModuleTranslation &moduleTranslation,
152-
mlir::Value &linearVar) {
150+
LogicalResult createLinearVar(llvm::IRBuilderBase &builder,
151+
LLVM::ModuleTranslation &moduleTranslation,
152+
mlir::Value &linearVar, Operation &op) {
153153
if (llvm::AllocaInst *linearVarAlloca = dyn_cast<llvm::AllocaInst>(
154154
moduleTranslation.lookupValue(linearVar))) {
155155
linearPreconditionVars.push_back(builder.CreateAlloca(
@@ -159,7 +159,12 @@ class LinearClauseProcessor {
159159
linearOrigVal.push_back(moduleTranslation.lookupValue(linearVar));
160160
linearLoopBodyTemps.push_back(linearLoopBodyTemp);
161161
linearOrigVars.push_back(linearVarAlloca);
162+
return success();
162163
}
164+
165+
else
166+
return op.emitError() << "not yet implemented: linear clause support"
167+
<< " for non alloca linear variables";
163168
}
164169

165170
// Initialize linear step
@@ -169,20 +174,15 @@ class LinearClauseProcessor {
169174
}
170175

171176
// Emit IR for initialization of linear variables
172-
llvm::OpenMPIRBuilder::InsertPointOrErrorTy
173-
initLinearVar(llvm::IRBuilderBase &builder,
174-
LLVM::ModuleTranslation &moduleTranslation,
175-
llvm::BasicBlock *loopPreHeader) {
177+
void initLinearVar(llvm::IRBuilderBase &builder,
178+
LLVM::ModuleTranslation &moduleTranslation,
179+
llvm::BasicBlock *loopPreHeader) {
176180
builder.SetInsertPoint(loopPreHeader->getTerminator());
177181
for (size_t index = 0; index < linearOrigVars.size(); index++) {
178182
llvm::LoadInst *linearVarLoad = builder.CreateLoad(
179183
linearOrigVars[index]->getAllocatedType(), linearOrigVars[index]);
180184
builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
181185
}
182-
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
183-
moduleTranslation.getOpenMPBuilder()->createBarrier(
184-
builder.saveIP(), llvm::omp::OMPD_barrier);
185-
return afterBarrierIP;
186186
}
187187

188188
// Emit IR for updating Linear variables
@@ -193,18 +193,27 @@ class LinearClauseProcessor {
193193
// Emit increments for linear vars
194194
llvm::LoadInst *linearVarStart =
195195
builder.CreateLoad(linearOrigVars[index]->getAllocatedType(),
196-
197196
linearPreconditionVars[index]);
197+
198198
auto mulInst = builder.CreateMul(loopInductionVar, linearSteps[index]);
199-
auto addInst = builder.CreateAdd(linearVarStart, mulInst);
200-
builder.CreateStore(addInst, linearLoopBodyTemps[index]);
199+
if (linearOrigVars[index]->getAllocatedType()->isIntegerTy()) {
200+
auto addInst = builder.CreateAdd(linearVarStart, mulInst);
201+
builder.CreateStore(addInst, linearLoopBodyTemps[index]);
202+
} else if (linearOrigVars[index]
203+
->getAllocatedType()
204+
->isFloatingPointTy()) {
205+
auto cvt = builder.CreateSIToFP(
206+
mulInst, linearOrigVars[index]->getAllocatedType());
207+
auto addInst = builder.CreateFAdd(linearVarStart, cvt);
208+
builder.CreateStore(addInst, linearLoopBodyTemps[index]);
209+
}
201210
}
202211
}
203212

204213
// Linear variable finalization is conditional on the last logical iteration.
205214
// Create BB splits to manage the same.
206-
void outlineLinearFinalizationBB(llvm::IRBuilderBase &builder,
207-
llvm::BasicBlock *loopExit) {
215+
void splitLinearFiniBB(llvm::IRBuilderBase &builder,
216+
llvm::BasicBlock *loopExit) {
208217
linearFinalizationBB = loopExit->splitBasicBlock(
209218
loopExit->getTerminator(), "omp_loop.linear_finalization");
210219
linearExitBB = linearFinalizationBB->splitBasicBlock(
@@ -339,10 +348,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
339348
if (!op.getIsDevicePtrVars().empty())
340349
result = todo("is_device_ptr");
341350
};
342-
auto checkLinear = [&todo](auto op, LogicalResult &result) {
343-
if (!op.getLinearVars().empty() || !op.getLinearStepVars().empty())
344-
result = todo("linear");
345-
};
346351
auto checkNowait = [&todo](auto op, LogicalResult &result) {
347352
if (op.getNowait())
348353
result = todo("nowait");
@@ -432,18 +437,14 @@ static LogicalResult checkImplementationStatus(Operation &op) {
432437
})
433438
.Case([&](omp::WsloopOp op) {
434439
checkAllocate(op, result);
435-
checkLinear(op, result);
436440
checkOrder(op, result);
437441
checkReduction(op, result);
438442
})
439443
.Case([&](omp::ParallelOp op) {
440444
checkAllocate(op, result);
441445
checkReduction(op, result);
442446
})
443-
.Case([&](omp::SimdOp op) {
444-
checkLinear(op, result);
445-
checkReduction(op, result);
446-
})
447+
.Case([&](omp::SimdOp op) { checkReduction(op, result); })
447448
.Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
448449
omp::AtomicCaptureOp>([&](auto op) { checkHint(op, result); })
449450
.Case<omp::TargetEnterDataOp, omp::TargetExitDataOp, omp::TargetUpdateOp>(
@@ -2587,13 +2588,13 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
25872588

25882589
// Initialize linear variables and linear step
25892590
LinearClauseProcessor linearClauseProcessor;
2590-
if (wsloopOp.getLinearVars().size()) {
2591-
for (mlir::Value linearVar : wsloopOp.getLinearVars())
2592-
linearClauseProcessor.createLinearVar(builder, moduleTranslation,
2593-
linearVar);
2594-
for (mlir::Value linearStep : wsloopOp.getLinearStepVars())
2595-
linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
2591+
for (mlir::Value linearVar : wsloopOp.getLinearVars()) {
2592+
if (failed(linearClauseProcessor.createLinearVar(builder, moduleTranslation,
2593+
linearVar, opInst)))
2594+
return failure();
25962595
}
2596+
for (mlir::Value linearStep : wsloopOp.getLinearStepVars())
2597+
linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
25972598

25982599
llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
25992600
wsloopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation);
@@ -2605,16 +2606,17 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
26052606

26062607
// Emit Initialization and Update IR for linear variables
26072608
if (wsloopOp.getLinearVars().size()) {
2609+
linearClauseProcessor.initLinearVar(builder, moduleTranslation,
2610+
loopInfo->getPreheader());
26082611
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
2609-
linearClauseProcessor.initLinearVar(builder, moduleTranslation,
2610-
loopInfo->getPreheader());
2612+
moduleTranslation.getOpenMPBuilder()->createBarrier(
2613+
builder.saveIP(), llvm::omp::OMPD_barrier);
26112614
if (failed(handleError(afterBarrierIP, *loopOp)))
26122615
return failure();
26132616
builder.restoreIP(*afterBarrierIP);
26142617
linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
26152618
loopInfo->getIndVar());
2616-
linearClauseProcessor.outlineLinearFinalizationBB(builder,
2617-
loopInfo->getExit());
2619+
linearClauseProcessor.splitLinearFiniBB(builder, loopInfo->getExit());
26182620
}
26192621

26202622
builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
@@ -2882,6 +2884,17 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
28822884
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
28832885
findAllocaInsertPoint(builder, moduleTranslation);
28842886

2887+
// Create linear variables and initialize linear step
2888+
LinearClauseProcessor linearClauseProcessor;
2889+
2890+
for (mlir::Value linearVar : simdOp.getLinearVars()) {
2891+
if (failed(linearClauseProcessor.createLinearVar(builder, moduleTranslation,
2892+
linearVar, opInst)))
2893+
return failure();
2894+
}
2895+
for (mlir::Value linearStep : simdOp.getLinearStepVars())
2896+
linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
2897+
28852898
llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
28862899
builder, moduleTranslation, privateVarsInfo, allocaIP);
28872900
if (handleError(afterAllocas, opInst).failed())
@@ -2945,14 +2958,27 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
29452958
if (failed(handleError(regionBlock, opInst)))
29462959
return failure();
29472960

2948-
builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
29492961
llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
2962+
2963+
// Emit Initialization for linear variables
2964+
if (simdOp.getLinearVars().size()) {
2965+
linearClauseProcessor.initLinearVar(builder, moduleTranslation,
2966+
loopInfo->getPreheader());
2967+
linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
2968+
loopInfo->getIndVar());
2969+
}
2970+
builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
2971+
29502972
ompBuilder->applySimd(loopInfo, alignedVars,
29512973
simdOp.getIfExpr()
29522974
? moduleTranslation.lookupValue(simdOp.getIfExpr())
29532975
: nullptr,
29542976
order, simdlen, safelen);
29552977

2978+
for (size_t index = 0; index < simdOp.getLinearVars().size(); index++)
2979+
linearClauseProcessor.rewriteInPlace(builder, "omp.loop_nest.region",
2980+
index);
2981+
29562982
// We now need to reduce the per-simd-lane reduction variable into the
29572983
// original variable. This works a bit differently to other reductions (e.g.
29582984
// wsloop) because we don't need to call into the OpenMP runtime to handle

0 commit comments

Comments
 (0)