@@ -124,6 +124,146 @@ class PreviouslyReportedError
124
124
125
125
char PreviouslyReportedError::ID = 0 ;
126
126
127
+ /*
128
+ * Custom class for processing linear clause for omp.wsloop
129
+ * and omp.simd. Linear clause translation requires setup,
130
+ * initialization, update, and finalization at varying
131
+ * basic blocks in the IR. This class helps maintain
132
+ * internal state to allow consistent translation in
133
+ * each of these stages.
134
+ */
135
+
136
+ class LinearClauseProcessor {
137
+
138
+ private:
139
+ SmallVector<llvm::Value *> linearPreconditionVars;
140
+ SmallVector<llvm::Value *> linearLoopBodyTemps;
141
+ SmallVector<llvm::AllocaInst *> linearOrigVars;
142
+ SmallVector<llvm::Value *> linearOrigVal;
143
+ SmallVector<llvm::Value *> linearSteps;
144
+ llvm::BasicBlock *linearFinalizationBB;
145
+ llvm::BasicBlock *linearExitBB;
146
+ llvm::BasicBlock *linearLastIterExitBB;
147
+
148
+ public:
149
+ // Allocate space for linear variabes
150
+ void createLinearVar (llvm::IRBuilderBase &builder,
151
+ LLVM::ModuleTranslation &moduleTranslation,
152
+ mlir::Value &linearVar) {
153
+ if (llvm::AllocaInst *linearVarAlloca = dyn_cast<llvm::AllocaInst>(
154
+ moduleTranslation.lookupValue (linearVar))) {
155
+ linearPreconditionVars.push_back (builder.CreateAlloca (
156
+ linearVarAlloca->getAllocatedType (), nullptr , " .linear_var" ));
157
+ llvm::Value *linearLoopBodyTemp = builder.CreateAlloca (
158
+ linearVarAlloca->getAllocatedType (), nullptr , " .linear_result" );
159
+ linearOrigVal.push_back (moduleTranslation.lookupValue (linearVar));
160
+ linearLoopBodyTemps.push_back (linearLoopBodyTemp);
161
+ linearOrigVars.push_back (linearVarAlloca);
162
+ }
163
+ }
164
+
165
+ // Initialize linear step
166
+ inline void initLinearStep (LLVM::ModuleTranslation &moduleTranslation,
167
+ mlir::Value &linearStep) {
168
+ linearSteps.push_back (moduleTranslation.lookupValue (linearStep));
169
+ }
170
+
171
+ // Emit IR for initialization of linear variables
172
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy
173
+ initLinearVar (llvm::IRBuilderBase &builder,
174
+ LLVM::ModuleTranslation &moduleTranslation,
175
+ llvm::BasicBlock *loopPreHeader) {
176
+ builder.SetInsertPoint (loopPreHeader->getTerminator ());
177
+ for (size_t index = 0 ; index < linearOrigVars.size (); index++) {
178
+ llvm::LoadInst *linearVarLoad = builder.CreateLoad (
179
+ linearOrigVars[index]->getAllocatedType (), linearOrigVars[index]);
180
+ builder.CreateStore (linearVarLoad, linearPreconditionVars[index]);
181
+ }
182
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
183
+ moduleTranslation.getOpenMPBuilder ()->createBarrier (
184
+ builder.saveIP (), llvm::omp::OMPD_barrier);
185
+ return afterBarrierIP;
186
+ }
187
+
188
+ // Emit IR for updating Linear variables
189
+ void updateLinearVar (llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
190
+ llvm::Value *loopInductionVar) {
191
+ builder.SetInsertPoint (loopBody->getTerminator ());
192
+ for (size_t index = 0 ; index < linearPreconditionVars.size (); index++) {
193
+ // Emit increments for linear vars
194
+ llvm::LoadInst *linearVarStart =
195
+ builder.CreateLoad (linearOrigVars[index]->getAllocatedType (),
196
+
197
+ linearPreconditionVars[index]);
198
+ auto mulInst = builder.CreateMul (loopInductionVar, linearSteps[index]);
199
+ auto addInst = builder.CreateAdd (linearVarStart, mulInst);
200
+ builder.CreateStore (addInst, linearLoopBodyTemps[index]);
201
+ }
202
+ }
203
+
204
+ // Linear variable finalization is conditional on the last logical iteration.
205
+ // Create BB splits to manage the same.
206
+ void outlineLinearFinalizationBB (llvm::IRBuilderBase &builder,
207
+ llvm::BasicBlock *loopExit) {
208
+ linearFinalizationBB = loopExit->splitBasicBlock (
209
+ loopExit->getTerminator (), " omp_loop.linear_finalization" );
210
+ linearExitBB = linearFinalizationBB->splitBasicBlock (
211
+ linearFinalizationBB->getTerminator (), " omp_loop.linear_exit" );
212
+ linearLastIterExitBB = linearFinalizationBB->splitBasicBlock (
213
+ linearFinalizationBB->getTerminator (), " omp_loop.linear_lastiter_exit" );
214
+ }
215
+
216
+ // Finalize the linear vars
217
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy
218
+ finalizeLinearVar (llvm::IRBuilderBase &builder,
219
+ LLVM::ModuleTranslation &moduleTranslation,
220
+ llvm::Value *lastIter) {
221
+ // Emit condition to check whether last logical iteration is being executed
222
+ builder.SetInsertPoint (linearFinalizationBB->getTerminator ());
223
+ llvm::Value *loopLastIterLoad = builder.CreateLoad (
224
+ llvm::Type::getInt32Ty (builder.getContext ()), lastIter);
225
+ llvm::Value *isLast =
226
+ builder.CreateCmp (llvm::CmpInst::ICMP_NE, loopLastIterLoad,
227
+ llvm::ConstantInt::get (
228
+ llvm::Type::getInt32Ty (builder.getContext ()), 0 ));
229
+ // Store the linear variable values to original variables.
230
+ builder.SetInsertPoint (linearLastIterExitBB->getTerminator ());
231
+ for (size_t index = 0 ; index < linearOrigVars.size (); index++) {
232
+ llvm::LoadInst *linearVarTemp =
233
+ builder.CreateLoad (linearOrigVars[index]->getAllocatedType (),
234
+ linearLoopBodyTemps[index]);
235
+ builder.CreateStore (linearVarTemp, linearOrigVars[index]);
236
+ }
237
+
238
+ // Create conditional branch such that the linear variable
239
+ // values are stored to original variables only at the
240
+ // last logical iteration
241
+ builder.SetInsertPoint (linearFinalizationBB->getTerminator ());
242
+ builder.CreateCondBr (isLast, linearLastIterExitBB, linearExitBB);
243
+ linearFinalizationBB->getTerminator ()->eraseFromParent ();
244
+ // Emit barrier
245
+ builder.SetInsertPoint (linearExitBB->getTerminator ());
246
+ return moduleTranslation.getOpenMPBuilder ()->createBarrier (
247
+ builder.saveIP (), llvm::omp::OMPD_barrier);
248
+ }
249
+
250
+ // Rewrite all uses of the original variable in `BBName`
251
+ // with the linear variable in-place
252
+ void rewriteInPlace (llvm::IRBuilderBase &builder, std::string BBName,
253
+ size_t varIndex) {
254
+ llvm::SmallVector<llvm::User *> users;
255
+ for (llvm::User *user : linearOrigVal[varIndex]->users ())
256
+ users.push_back (user);
257
+ for (auto *user : users) {
258
+ if (auto *userInst = dyn_cast<llvm::Instruction>(user)) {
259
+ if (userInst->getParent ()->getName ().str () == BBName)
260
+ user->replaceUsesOfWith (linearOrigVal[varIndex],
261
+ linearLoopBodyTemps[varIndex]);
262
+ }
263
+ }
264
+ }
265
+ };
266
+
127
267
} // namespace
128
268
129
269
// / Looks up from the operation from and returns the PrivateClauseOp with
@@ -292,7 +432,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
292
432
})
293
433
.Case ([&](omp::WsloopOp op) {
294
434
checkAllocate (op, result);
295
- checkLinear (op, result);
296
435
checkOrder (op, result);
297
436
checkReduction (op, result);
298
437
})
@@ -2423,15 +2562,40 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
2423
2562
llvm::omp::Directive::OMPD_for);
2424
2563
2425
2564
llvm::OpenMPIRBuilder::LocationDescription ompLoc (builder);
2565
+
2566
+ // Initialize linear variables and linear step
2567
+ LinearClauseProcessor linearClauseProcessor;
2568
+ if (wsloopOp.getLinearVars ().size ()) {
2569
+ for (mlir::Value linearVar : wsloopOp.getLinearVars ())
2570
+ linearClauseProcessor.createLinearVar (builder, moduleTranslation,
2571
+ linearVar);
2572
+ for (mlir::Value linearStep : wsloopOp.getLinearStepVars ())
2573
+ linearClauseProcessor.initLinearStep (moduleTranslation, linearStep);
2574
+ }
2575
+
2426
2576
llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions (
2427
2577
wsloopOp.getRegion (), " omp.wsloop.region" , builder, moduleTranslation);
2428
2578
2429
2579
if (failed (handleError (regionBlock, opInst)))
2430
2580
return failure ();
2431
2581
2432
- builder.SetInsertPoint (*regionBlock, (*regionBlock)->begin ());
2433
2582
llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo (moduleTranslation);
2434
2583
2584
+ // Emit Initialization and Update IR for linear variables
2585
+ if (wsloopOp.getLinearVars ().size ()) {
2586
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
2587
+ linearClauseProcessor.initLinearVar (builder, moduleTranslation,
2588
+ loopInfo->getPreheader ());
2589
+ if (failed (handleError (afterBarrierIP, *loopOp)))
2590
+ return failure ();
2591
+ builder.restoreIP (*afterBarrierIP);
2592
+ linearClauseProcessor.updateLinearVar (builder, loopInfo->getBody (),
2593
+ loopInfo->getIndVar ());
2594
+ linearClauseProcessor.outlineLinearFinalizationBB (builder,
2595
+ loopInfo->getExit ());
2596
+ }
2597
+
2598
+ builder.SetInsertPoint (*regionBlock, (*regionBlock)->begin ());
2435
2599
llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
2436
2600
ompBuilder->applyWorkshareLoop (
2437
2601
ompLoc.DL , loopInfo, allocaIP, loopNeedsBarrier,
@@ -2443,6 +2607,23 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
2443
2607
if (failed (handleError (wsloopIP, opInst)))
2444
2608
return failure ();
2445
2609
2610
+ // Emit finalization and in-place rewrites for linear vars.
2611
+ if (wsloopOp.getLinearVars ().size ()) {
2612
+ llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP ();
2613
+ assert (loopInfo->getLastIter () &&
2614
+ " `lastiter` in CanonicalLoopInfo is nullptr" );
2615
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
2616
+ linearClauseProcessor.finalizeLinearVar (builder, moduleTranslation,
2617
+ loopInfo->getLastIter ());
2618
+ if (failed (handleError (afterBarrierIP, *loopOp)))
2619
+ return failure ();
2620
+ builder.restoreIP (*afterBarrierIP);
2621
+ for (size_t index = 0 ; index < wsloopOp.getLinearVars ().size (); index++)
2622
+ linearClauseProcessor.rewriteInPlace (builder, " omp.loop_nest.region" ,
2623
+ index);
2624
+ builder.restoreIP (oldIP);
2625
+ }
2626
+
2446
2627
// Set the correct branch target for task cancellation
2447
2628
popCancelFinalizationCB (cancelTerminators, *ompBuilder, wsloopIP.get ());
2448
2629
0 commit comments