@@ -147,9 +147,9 @@ class LinearClauseProcessor {
147
147
148
148
public:
149
149
// Allocate space for linear variabes
150
- void createLinearVar (llvm::IRBuilderBase &builder,
151
- LLVM::ModuleTranslation &moduleTranslation,
152
- mlir::Value &linearVar) {
150
+ LogicalResult createLinearVar (llvm::IRBuilderBase &builder,
151
+ LLVM::ModuleTranslation &moduleTranslation,
152
+ mlir::Value &linearVar, Operation &op ) {
153
153
if (llvm::AllocaInst *linearVarAlloca = dyn_cast<llvm::AllocaInst>(
154
154
moduleTranslation.lookupValue (linearVar))) {
155
155
linearPreconditionVars.push_back (builder.CreateAlloca (
@@ -159,7 +159,12 @@ class LinearClauseProcessor {
159
159
linearOrigVal.push_back (moduleTranslation.lookupValue (linearVar));
160
160
linearLoopBodyTemps.push_back (linearLoopBodyTemp);
161
161
linearOrigVars.push_back (linearVarAlloca);
162
+ return success ();
162
163
}
164
+
165
+ else
166
+ return op.emitError () << " not yet implemented: linear clause support"
167
+ << " for non alloca linear variables" ;
163
168
}
164
169
165
170
// Initialize linear step
@@ -169,20 +174,15 @@ class LinearClauseProcessor {
169
174
}
170
175
171
176
// Emit IR for initialization of linear variables
172
- llvm::OpenMPIRBuilder::InsertPointOrErrorTy
173
- initLinearVar (llvm::IRBuilderBase &builder,
174
- LLVM::ModuleTranslation &moduleTranslation,
175
- llvm::BasicBlock *loopPreHeader) {
177
+ void initLinearVar (llvm::IRBuilderBase &builder,
178
+ LLVM::ModuleTranslation &moduleTranslation,
179
+ llvm::BasicBlock *loopPreHeader) {
176
180
builder.SetInsertPoint (loopPreHeader->getTerminator ());
177
181
for (size_t index = 0 ; index < linearOrigVars.size (); index++) {
178
182
llvm::LoadInst *linearVarLoad = builder.CreateLoad (
179
183
linearOrigVars[index]->getAllocatedType (), linearOrigVars[index]);
180
184
builder.CreateStore (linearVarLoad, linearPreconditionVars[index]);
181
185
}
182
- llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
183
- moduleTranslation.getOpenMPBuilder ()->createBarrier (
184
- builder.saveIP (), llvm::omp::OMPD_barrier);
185
- return afterBarrierIP;
186
186
}
187
187
188
188
// Emit IR for updating Linear variables
@@ -193,18 +193,27 @@ class LinearClauseProcessor {
193
193
// Emit increments for linear vars
194
194
llvm::LoadInst *linearVarStart =
195
195
builder.CreateLoad (linearOrigVars[index]->getAllocatedType (),
196
-
197
196
linearPreconditionVars[index]);
197
+
198
198
auto mulInst = builder.CreateMul (loopInductionVar, linearSteps[index]);
199
- auto addInst = builder.CreateAdd (linearVarStart, mulInst);
200
- builder.CreateStore (addInst, linearLoopBodyTemps[index]);
199
+ if (linearOrigVars[index]->getAllocatedType ()->isIntegerTy ()) {
200
+ auto addInst = builder.CreateAdd (linearVarStart, mulInst);
201
+ builder.CreateStore (addInst, linearLoopBodyTemps[index]);
202
+ } else if (linearOrigVars[index]
203
+ ->getAllocatedType ()
204
+ ->isFloatingPointTy ()) {
205
+ auto cvt = builder.CreateSIToFP (
206
+ mulInst, linearOrigVars[index]->getAllocatedType ());
207
+ auto addInst = builder.CreateFAdd (linearVarStart, cvt);
208
+ builder.CreateStore (addInst, linearLoopBodyTemps[index]);
209
+ }
201
210
}
202
211
}
203
212
204
213
// Linear variable finalization is conditional on the last logical iteration.
205
214
// Create BB splits to manage the same.
206
- void outlineLinearFinalizationBB (llvm::IRBuilderBase &builder,
207
- llvm::BasicBlock *loopExit) {
215
+ void splitLinearFiniBB (llvm::IRBuilderBase &builder,
216
+ llvm::BasicBlock *loopExit) {
208
217
linearFinalizationBB = loopExit->splitBasicBlock (
209
218
loopExit->getTerminator (), " omp_loop.linear_finalization" );
210
219
linearExitBB = linearFinalizationBB->splitBasicBlock (
@@ -339,10 +348,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
339
348
if (!op.getIsDevicePtrVars ().empty ())
340
349
result = todo (" is_device_ptr" );
341
350
};
342
- auto checkLinear = [&todo](auto op, LogicalResult &result) {
343
- if (!op.getLinearVars ().empty () || !op.getLinearStepVars ().empty ())
344
- result = todo (" linear" );
345
- };
346
351
auto checkNowait = [&todo](auto op, LogicalResult &result) {
347
352
if (op.getNowait ())
348
353
result = todo (" nowait" );
@@ -432,18 +437,14 @@ static LogicalResult checkImplementationStatus(Operation &op) {
432
437
})
433
438
.Case ([&](omp::WsloopOp op) {
434
439
checkAllocate (op, result);
435
- checkLinear (op, result);
436
440
checkOrder (op, result);
437
441
checkReduction (op, result);
438
442
})
439
443
.Case ([&](omp::ParallelOp op) {
440
444
checkAllocate (op, result);
441
445
checkReduction (op, result);
442
446
})
443
- .Case ([&](omp::SimdOp op) {
444
- checkLinear (op, result);
445
- checkReduction (op, result);
446
- })
447
+ .Case ([&](omp::SimdOp op) { checkReduction (op, result); })
447
448
.Case <omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
448
449
omp::AtomicCaptureOp>([&](auto op) { checkHint (op, result); })
449
450
.Case <omp::TargetEnterDataOp, omp::TargetExitDataOp, omp::TargetUpdateOp>(
@@ -2587,13 +2588,13 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
2587
2588
2588
2589
// Initialize linear variables and linear step
2589
2590
LinearClauseProcessor linearClauseProcessor;
2590
- if (wsloopOp.getLinearVars ().size ()) {
2591
- for (mlir::Value linearVar : wsloopOp.getLinearVars ())
2592
- linearClauseProcessor.createLinearVar (builder, moduleTranslation,
2593
- linearVar);
2594
- for (mlir::Value linearStep : wsloopOp.getLinearStepVars ())
2595
- linearClauseProcessor.initLinearStep (moduleTranslation, linearStep);
2591
+ for (mlir::Value linearVar : wsloopOp.getLinearVars ()) {
2592
+ if (failed (linearClauseProcessor.createLinearVar (builder, moduleTranslation,
2593
+ linearVar, opInst)))
2594
+ return failure ();
2596
2595
}
2596
+ for (mlir::Value linearStep : wsloopOp.getLinearStepVars ())
2597
+ linearClauseProcessor.initLinearStep (moduleTranslation, linearStep);
2597
2598
2598
2599
llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions (
2599
2600
wsloopOp.getRegion (), " omp.wsloop.region" , builder, moduleTranslation);
@@ -2605,16 +2606,17 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
2605
2606
2606
2607
// Emit Initialization and Update IR for linear variables
2607
2608
if (wsloopOp.getLinearVars ().size ()) {
2609
+ linearClauseProcessor.initLinearVar (builder, moduleTranslation,
2610
+ loopInfo->getPreheader ());
2608
2611
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
2609
- linearClauseProcessor. initLinearVar (builder, moduleTranslation,
2610
- loopInfo-> getPreheader () );
2612
+ moduleTranslation. getOpenMPBuilder ()-> createBarrier (
2613
+ builder. saveIP (), llvm::omp::OMPD_barrier );
2611
2614
if (failed (handleError (afterBarrierIP, *loopOp)))
2612
2615
return failure ();
2613
2616
builder.restoreIP (*afterBarrierIP);
2614
2617
linearClauseProcessor.updateLinearVar (builder, loopInfo->getBody (),
2615
2618
loopInfo->getIndVar ());
2616
- linearClauseProcessor.outlineLinearFinalizationBB (builder,
2617
- loopInfo->getExit ());
2619
+ linearClauseProcessor.splitLinearFiniBB (builder, loopInfo->getExit ());
2618
2620
}
2619
2621
2620
2622
builder.SetInsertPoint (*regionBlock, (*regionBlock)->begin ());
@@ -2882,6 +2884,17 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
2882
2884
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2883
2885
findAllocaInsertPoint (builder, moduleTranslation);
2884
2886
2887
+ // Create linear variables and initialize linear step
2888
+ LinearClauseProcessor linearClauseProcessor;
2889
+
2890
+ for (mlir::Value linearVar : simdOp.getLinearVars ()) {
2891
+ if (failed (linearClauseProcessor.createLinearVar (builder, moduleTranslation,
2892
+ linearVar, opInst)))
2893
+ return failure ();
2894
+ }
2895
+ for (mlir::Value linearStep : simdOp.getLinearStepVars ())
2896
+ linearClauseProcessor.initLinearStep (moduleTranslation, linearStep);
2897
+
2885
2898
llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars (
2886
2899
builder, moduleTranslation, privateVarsInfo, allocaIP);
2887
2900
if (handleError (afterAllocas, opInst).failed ())
@@ -2945,14 +2958,27 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
2945
2958
if (failed (handleError (regionBlock, opInst)))
2946
2959
return failure ();
2947
2960
2948
- builder.SetInsertPoint (*regionBlock, (*regionBlock)->begin ());
2949
2961
llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo (moduleTranslation);
2962
+
2963
+ // Emit Initialization for linear variables
2964
+ if (simdOp.getLinearVars ().size ()) {
2965
+ linearClauseProcessor.initLinearVar (builder, moduleTranslation,
2966
+ loopInfo->getPreheader ());
2967
+ linearClauseProcessor.updateLinearVar (builder, loopInfo->getBody (),
2968
+ loopInfo->getIndVar ());
2969
+ }
2970
+ builder.SetInsertPoint (*regionBlock, (*regionBlock)->begin ());
2971
+
2950
2972
ompBuilder->applySimd (loopInfo, alignedVars,
2951
2973
simdOp.getIfExpr ()
2952
2974
? moduleTranslation.lookupValue (simdOp.getIfExpr ())
2953
2975
: nullptr ,
2954
2976
order, simdlen, safelen);
2955
2977
2978
+ for (size_t index = 0 ; index < simdOp.getLinearVars ().size (); index++)
2979
+ linearClauseProcessor.rewriteInPlace (builder, " omp.loop_nest.region" ,
2980
+ index);
2981
+
2956
2982
// We now need to reduce the per-simd-lane reduction variable into the
2957
2983
// original variable. This works a bit differently to other reductions (e.g.
2958
2984
// wsloop) because we don't need to call into the OpenMP runtime to handle
0 commit comments