Skip to content

Commit 616d637

Browse files
committed
[mlir][llvm][OpenMP] Support translation for linear clause in omp.wsloop
1 parent e4f3cb2 commit 616d637

File tree

5 files changed

+289
-15
lines changed

5 files changed

+289
-15
lines changed

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3580,6 +3580,9 @@ class CanonicalLoopInfo {
35803580
BasicBlock *Latch = nullptr;
35813581
BasicBlock *Exit = nullptr;
35823582

3583+
/// Holds the LLVM IR value for the `lastiter` of the canonical loop.
3584+
Value *LastIter = nullptr;
3585+
35833586
/// Add the control blocks of this loop to \p BBs.
35843587
///
35853588
/// This does not include any block from the body, including the one returned
@@ -3612,6 +3615,18 @@ class CanonicalLoopInfo {
36123615
void mapIndVar(llvm::function_ref<Value *(Instruction *)> Updater);
36133616

36143617
public:
3618+
/// Sets the last iteration variable for this loop.
3619+
void setLastIter(Value *IterVar) { LastIter = IterVar; }
3620+
3621+
/// Returns the last iteration variable for this loop.
3622+
/// Certain use-cases (like translation of linear clause) may access
3623+
/// this variable even after a loop transformation. Hence, do not guard
3624+
/// this getter function by `isValid`. It is the responsibility of the
3625+
/// caller to ensure this functionality is not invoked by a non-outlined
3626+
/// CanonicalLoopInfo object (in which case, `setLastIter` will never be
3627+
/// invoked and `LastIter` will be by default `nullptr`).
3628+
Value *getLastIter() const { return LastIter; }
3629+
36153630
/// Returns whether this object currently represents the IR of a loop. If
36163631
/// returning false, it may have been consumed by a loop transformation or not
36173632
/// been initialized. Do not use in this case;

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4254,6 +4254,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyStaticWorkshareLoop(
42544254
Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
42554255
Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
42564256
Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");
4257+
CLI->setLastIter(PLastIter);
42574258

42584259
// At the end of the preheader, prepare for calling the "init" function by
42594260
// storing the current loop bounds into the allocated space. A canonical loop
@@ -4361,6 +4362,7 @@ OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(DebugLoc DL,
43614362
Value *PUpperBound =
43624363
Builder.CreateAlloca(InternalIVTy, nullptr, "p.upperbound");
43634364
Value *PStride = Builder.CreateAlloca(InternalIVTy, nullptr, "p.stride");
4365+
CLI->setLastIter(PLastIter);
43644366

43654367
// Set up the source location value for the OpenMP runtime.
43664368
Builder.restoreIP(CLI->getPreheaderIP());
@@ -4844,6 +4846,7 @@ OpenMPIRBuilder::applyDynamicWorkshareLoop(DebugLoc DL, CanonicalLoopInfo *CLI,
48444846
Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
48454847
Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
48464848
Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");
4849+
CLI->setLastIter(PLastIter);
48474850

48484851
// At the end of the preheader, prepare for calling the "init" function by
48494852
// storing the current loop bounds into the allocated space. A canonical loop

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 183 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,146 @@ class PreviouslyReportedError
124124

125125
char PreviouslyReportedError::ID = 0;
126126

127+
/*
128+
* Custom class for processing linear clause for omp.wsloop
129+
* and omp.simd. Linear clause translation requires setup,
130+
* initialization, update, and finalization at varying
131+
* basic blocks in the IR. This class helps maintain
132+
* internal state to allow consistent translation in
133+
* each of these stages.
134+
*/
135+
136+
class LinearClauseProcessor {
137+
138+
private:
139+
SmallVector<llvm::Value *> linearPreconditionVars;
140+
SmallVector<llvm::Value *> linearLoopBodyTemps;
141+
SmallVector<llvm::AllocaInst *> linearOrigVars;
142+
SmallVector<llvm::Value *> linearOrigVal;
143+
SmallVector<llvm::Value *> linearSteps;
144+
llvm::BasicBlock *linearFinalizationBB;
145+
llvm::BasicBlock *linearExitBB;
146+
llvm::BasicBlock *linearLastIterExitBB;
147+
148+
public:
149+
// Allocate space for linear variabes
150+
void createLinearVar(llvm::IRBuilderBase &builder,
151+
LLVM::ModuleTranslation &moduleTranslation,
152+
mlir::Value &linearVar) {
153+
if (llvm::AllocaInst *linearVarAlloca = dyn_cast<llvm::AllocaInst>(
154+
moduleTranslation.lookupValue(linearVar))) {
155+
linearPreconditionVars.push_back(builder.CreateAlloca(
156+
linearVarAlloca->getAllocatedType(), nullptr, ".linear_var"));
157+
llvm::Value *linearLoopBodyTemp = builder.CreateAlloca(
158+
linearVarAlloca->getAllocatedType(), nullptr, ".linear_result");
159+
linearOrigVal.push_back(moduleTranslation.lookupValue(linearVar));
160+
linearLoopBodyTemps.push_back(linearLoopBodyTemp);
161+
linearOrigVars.push_back(linearVarAlloca);
162+
}
163+
}
164+
165+
// Initialize linear step
166+
inline void initLinearStep(LLVM::ModuleTranslation &moduleTranslation,
167+
mlir::Value &linearStep) {
168+
linearSteps.push_back(moduleTranslation.lookupValue(linearStep));
169+
}
170+
171+
// Emit IR for initialization of linear variables
172+
llvm::OpenMPIRBuilder::InsertPointOrErrorTy
173+
initLinearVar(llvm::IRBuilderBase &builder,
174+
LLVM::ModuleTranslation &moduleTranslation,
175+
llvm::BasicBlock *loopPreHeader) {
176+
builder.SetInsertPoint(loopPreHeader->getTerminator());
177+
for (size_t index = 0; index < linearOrigVars.size(); index++) {
178+
llvm::LoadInst *linearVarLoad = builder.CreateLoad(
179+
linearOrigVars[index]->getAllocatedType(), linearOrigVars[index]);
180+
builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
181+
}
182+
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
183+
moduleTranslation.getOpenMPBuilder()->createBarrier(
184+
builder.saveIP(), llvm::omp::OMPD_barrier);
185+
return afterBarrierIP;
186+
}
187+
188+
// Emit IR for updating Linear variables
189+
void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
190+
llvm::Value *loopInductionVar) {
191+
builder.SetInsertPoint(loopBody->getTerminator());
192+
for (size_t index = 0; index < linearPreconditionVars.size(); index++) {
193+
// Emit increments for linear vars
194+
llvm::LoadInst *linearVarStart =
195+
builder.CreateLoad(linearOrigVars[index]->getAllocatedType(),
196+
197+
linearPreconditionVars[index]);
198+
auto mulInst = builder.CreateMul(loopInductionVar, linearSteps[index]);
199+
auto addInst = builder.CreateAdd(linearVarStart, mulInst);
200+
builder.CreateStore(addInst, linearLoopBodyTemps[index]);
201+
}
202+
}
203+
204+
// Linear variable finalization is conditional on the last logical iteration.
205+
// Create BB splits to manage the same.
206+
void outlineLinearFinalizationBB(llvm::IRBuilderBase &builder,
207+
llvm::BasicBlock *loopExit) {
208+
linearFinalizationBB = loopExit->splitBasicBlock(
209+
loopExit->getTerminator(), "omp_loop.linear_finalization");
210+
linearExitBB = linearFinalizationBB->splitBasicBlock(
211+
linearFinalizationBB->getTerminator(), "omp_loop.linear_exit");
212+
linearLastIterExitBB = linearFinalizationBB->splitBasicBlock(
213+
linearFinalizationBB->getTerminator(), "omp_loop.linear_lastiter_exit");
214+
}
215+
216+
// Finalize the linear vars
217+
llvm::OpenMPIRBuilder::InsertPointOrErrorTy
218+
finalizeLinearVar(llvm::IRBuilderBase &builder,
219+
LLVM::ModuleTranslation &moduleTranslation,
220+
llvm::Value *lastIter) {
221+
// Emit condition to check whether last logical iteration is being executed
222+
builder.SetInsertPoint(linearFinalizationBB->getTerminator());
223+
llvm::Value *loopLastIterLoad = builder.CreateLoad(
224+
llvm::Type::getInt32Ty(builder.getContext()), lastIter);
225+
llvm::Value *isLast =
226+
builder.CreateCmp(llvm::CmpInst::ICMP_NE, loopLastIterLoad,
227+
llvm::ConstantInt::get(
228+
llvm::Type::getInt32Ty(builder.getContext()), 0));
229+
// Store the linear variable values to original variables.
230+
builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
231+
for (size_t index = 0; index < linearOrigVars.size(); index++) {
232+
llvm::LoadInst *linearVarTemp =
233+
builder.CreateLoad(linearOrigVars[index]->getAllocatedType(),
234+
linearLoopBodyTemps[index]);
235+
builder.CreateStore(linearVarTemp, linearOrigVars[index]);
236+
}
237+
238+
// Create conditional branch such that the linear variable
239+
// values are stored to original variables only at the
240+
// last logical iteration
241+
builder.SetInsertPoint(linearFinalizationBB->getTerminator());
242+
builder.CreateCondBr(isLast, linearLastIterExitBB, linearExitBB);
243+
linearFinalizationBB->getTerminator()->eraseFromParent();
244+
// Emit barrier
245+
builder.SetInsertPoint(linearExitBB->getTerminator());
246+
return moduleTranslation.getOpenMPBuilder()->createBarrier(
247+
builder.saveIP(), llvm::omp::OMPD_barrier);
248+
}
249+
250+
// Rewrite all uses of the original variable in `BBName`
251+
// with the linear variable in-place
252+
void rewriteInPlace(llvm::IRBuilderBase &builder, std::string BBName,
253+
size_t varIndex) {
254+
llvm::SmallVector<llvm::User *> users;
255+
for (llvm::User *user : linearOrigVal[varIndex]->users())
256+
users.push_back(user);
257+
for (auto *user : users) {
258+
if (auto *userInst = dyn_cast<llvm::Instruction>(user)) {
259+
if (userInst->getParent()->getName().str() == BBName)
260+
user->replaceUsesOfWith(linearOrigVal[varIndex],
261+
linearLoopBodyTemps[varIndex]);
262+
}
263+
}
264+
}
265+
};
266+
127267
} // namespace
128268

129269
/// Looks up from the operation from and returns the PrivateClauseOp with
@@ -292,7 +432,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
292432
})
293433
.Case([&](omp::WsloopOp op) {
294434
checkAllocate(op, result);
295-
checkLinear(op, result);
296435
checkOrder(op, result);
297436
checkReduction(op, result);
298437
})
@@ -2423,15 +2562,40 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
24232562
llvm::omp::Directive::OMPD_for);
24242563

24252564
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2565+
2566+
// Initialize linear variables and linear step
2567+
LinearClauseProcessor linearClauseProcessor;
2568+
if (wsloopOp.getLinearVars().size()) {
2569+
for (mlir::Value linearVar : wsloopOp.getLinearVars())
2570+
linearClauseProcessor.createLinearVar(builder, moduleTranslation,
2571+
linearVar);
2572+
for (mlir::Value linearStep : wsloopOp.getLinearStepVars())
2573+
linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
2574+
}
2575+
24262576
llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
24272577
wsloopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation);
24282578

24292579
if (failed(handleError(regionBlock, opInst)))
24302580
return failure();
24312581

2432-
builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
24332582
llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
24342583

2584+
// Emit Initialization and Update IR for linear variables
2585+
if (wsloopOp.getLinearVars().size()) {
2586+
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
2587+
linearClauseProcessor.initLinearVar(builder, moduleTranslation,
2588+
loopInfo->getPreheader());
2589+
if (failed(handleError(afterBarrierIP, *loopOp)))
2590+
return failure();
2591+
builder.restoreIP(*afterBarrierIP);
2592+
linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
2593+
loopInfo->getIndVar());
2594+
linearClauseProcessor.outlineLinearFinalizationBB(builder,
2595+
loopInfo->getExit());
2596+
}
2597+
2598+
builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
24352599
llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
24362600
ompBuilder->applyWorkshareLoop(
24372601
ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
@@ -2443,6 +2607,23 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
24432607
if (failed(handleError(wsloopIP, opInst)))
24442608
return failure();
24452609

2610+
// Emit finalization and in-place rewrites for linear vars.
2611+
if (wsloopOp.getLinearVars().size()) {
2612+
llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
2613+
assert(loopInfo->getLastIter() &&
2614+
"`lastiter` in CanonicalLoopInfo is nullptr");
2615+
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
2616+
linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
2617+
loopInfo->getLastIter());
2618+
if (failed(handleError(afterBarrierIP, *loopOp)))
2619+
return failure();
2620+
builder.restoreIP(*afterBarrierIP);
2621+
for (size_t index = 0; index < wsloopOp.getLinearVars().size(); index++)
2622+
linearClauseProcessor.rewriteInPlace(builder, "omp.loop_nest.region",
2623+
index);
2624+
builder.restoreIP(oldIP);
2625+
}
2626+
24462627
// Set the correct branch target for task cancellation
24472628
popCancelFinalizationCB(cancelTerminators, *ompBuilder, wsloopIP.get());
24482629

mlir/test/Target/LLVMIR/openmp-llvm.mlir

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,94 @@ llvm.func @wsloop_simple(%arg0: !llvm.ptr) {
358358

359359
// -----
360360

361+
// CHECK-LABEL: wsloop_linear
362+
363+
// CHECK: {{.*}} = alloca i32, i64 1, align 4
364+
// CHECK: %[[Y:.*]] = alloca i32, i64 1, align 4
365+
// CHECK: %[[X:.*]] = alloca i32, i64 1, align 4
366+
367+
// CHECK: entry:
368+
// CHECK: %[[LINEAR_VAR:.*]] = alloca i32, align 4
369+
// CHECK: %[[LINEAR_RESULT:.*]] = alloca i32, align 4
370+
// CHECK: br label %omp_loop.preheader
371+
372+
// CHECK: omp_loop.preheader:
373+
// CHECK: %[[LOAD:.*]] = load i32, ptr %[[X]], align 4
374+
// CHECK: store i32 %[[LOAD]], ptr %[[LINEAR_VAR]], align 4
375+
// CHECK: %omp_global_thread_num = call i32 @__kmpc_global_thread_num(ptr @2)
376+
// CHECK: call void @__kmpc_barrier(ptr @1, i32 %omp_global_thread_num)
377+
378+
// CHECK: omp_loop.body:
379+
// CHECK: %[[LOOP_IV:.*]] = add i32 %omp_loop.iv, {{.*}}
380+
// CHECK: %[[LINEAR_LOAD:.*]] = load i32, ptr %[[LINEAR_VAR]], align 4
381+
// CHECK: %[[MUL:.*]] = mul i32 %[[LOOP_IV]], 1
382+
// CHECK: %[[ADD:.*]] = add i32 %[[LINEAR_LOAD]], %[[MUL]]
383+
// CHECK: store i32 %[[ADD]], ptr %[[LINEAR_RESULT]], align 4
384+
// CHECK: br label %omp.loop_nest.region
385+
386+
// CHECK: omp.loop_nest.region:
387+
// CHECK: %[[LINEAR_LOAD:.*]] = load i32, ptr %[[LINEAR_RESULT]], align 4
388+
// CHECK: %[[ADD:.*]] = add i32 %[[LINEAR_LOAD]], 2
389+
// CHECK: store i32 %[[ADD]], ptr %[[Y]], align 4
390+
391+
// CHECK: omp_loop.exit:
392+
// CHECK: call void @__kmpc_for_static_fini(ptr @2, i32 %omp_global_thread_num4)
393+
// CHECK: %omp_global_thread_num5 = call i32 @__kmpc_global_thread_num(ptr @2)
394+
// CHECK: call void @__kmpc_barrier(ptr @3, i32 %omp_global_thread_num5)
395+
// CHECK: br label %omp_loop.linear_finalization
396+
397+
// CHECK: omp_loop.linear_finalization:
398+
// CHECK: %[[LAST_ITER:.*]] = load i32, ptr %p.lastiter, align 4
399+
// CHECK: %[[CMP:.*]] = icmp ne i32 %[[LAST_ITER]], 0
400+
// CHECK: br i1 %[[CMP]], label %omp_loop.linear_lastiter_exit, label %omp_loop.linear_exit
401+
402+
// CHECK: omp_loop.linear_lastiter_exit:
403+
// CHECK: %[[LINEAR_RESULT_LOAD:.*]] = load i32, ptr %[[LINEAR_RESULT]], align 4
404+
// CHECK: store i32 %[[LINEAR_RESULT_LOAD]], ptr %[[X]], align 4
405+
// CHECK: br label %omp_loop.linear_exit
406+
407+
// CHECK: omp_loop.linear_exit:
408+
// CHECK: %omp_global_thread_num6 = call i32 @__kmpc_global_thread_num(ptr @2)
409+
// CHECK: call void @__kmpc_barrier(ptr @1, i32 %omp_global_thread_num6)
410+
// CHECK: br label %omp_loop.after
411+
412+
llvm.func @wsloop_linear() {
413+
%0 = llvm.mlir.constant(1 : i64) : i64
414+
%1 = llvm.alloca %0 x i32 {bindc_name = "i", pinned} : (i64) -> !llvm.ptr
415+
%2 = llvm.mlir.constant(1 : i64) : i64
416+
%3 = llvm.alloca %2 x i32 {bindc_name = "y"} : (i64) -> !llvm.ptr
417+
%4 = llvm.mlir.constant(1 : i64) : i64
418+
%5 = llvm.alloca %4 x i32 {bindc_name = "x"} : (i64) -> !llvm.ptr
419+
%6 = llvm.mlir.constant(1 : i64) : i64
420+
%7 = llvm.alloca %6 x i32 {bindc_name = "i"} : (i64) -> !llvm.ptr
421+
%8 = llvm.mlir.constant(2 : i32) : i32
422+
%9 = llvm.mlir.constant(10 : i32) : i32
423+
%10 = llvm.mlir.constant(1 : i32) : i32
424+
%11 = llvm.mlir.constant(1 : i64) : i64
425+
%12 = llvm.mlir.constant(1 : i64) : i64
426+
%13 = llvm.mlir.constant(1 : i64) : i64
427+
%14 = llvm.mlir.constant(1 : i64) : i64
428+
omp.wsloop linear(%5 = %10 : !llvm.ptr) {
429+
omp.loop_nest (%arg0) : i32 = (%10) to (%9) inclusive step (%10) {
430+
llvm.store %arg0, %1 : i32, !llvm.ptr
431+
%15 = llvm.load %5 : !llvm.ptr -> i32
432+
%16 = llvm.add %15, %8 : i32
433+
llvm.store %16, %3 : i32, !llvm.ptr
434+
%17 = llvm.add %arg0, %10 : i32
435+
%18 = llvm.icmp "sgt" %17, %9 : i32
436+
llvm.cond_br %18, ^bb1, ^bb2
437+
^bb1: // pred: ^bb0
438+
llvm.store %17, %1 : i32, !llvm.ptr
439+
llvm.br ^bb2
440+
^bb2: // 2 preds: ^bb0, ^bb1
441+
omp.yield
442+
}
443+
}
444+
llvm.return
445+
}
446+
447+
// -----
448+
361449
// CHECK-LABEL: @wsloop_inclusive_1
362450
llvm.func @wsloop_inclusive_1(%arg0: !llvm.ptr) {
363451
%0 = llvm.mlir.constant(42 : index) : i64

mlir/test/Target/LLVMIR/openmp-todo.mlir

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -511,19 +511,6 @@ llvm.func @wsloop_allocate(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
511511

512512
// -----
513513

514-
llvm.func @wsloop_linear(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
515-
// expected-error@below {{not yet implemented: Unhandled clause linear in omp.wsloop operation}}
516-
// expected-error@below {{LLVM Translation failed for operation: omp.wsloop}}
517-
omp.wsloop linear(%x = %step : !llvm.ptr) {
518-
omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
519-
omp.yield
520-
}
521-
}
522-
llvm.return
523-
}
524-
525-
// -----
526-
527514
llvm.func @wsloop_order(%lb : i32, %ub : i32, %step : i32) {
528515
// expected-error@below {{not yet implemented: Unhandled clause order in omp.wsloop operation}}
529516
// expected-error@below {{LLVM Translation failed for operation: omp.wsloop}}

0 commit comments

Comments
 (0)