Skip to content

[llvm][mlir][OpenMP] Support translation for linear clause in omp.wsloop and omp.simd #139386

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

NimishMishra
Copy link
Contributor

This patch adds support for LLVM translation of linear clause on omp.wsloop (except for linear modifiers).

@llvmbot
Copy link
Member

llvmbot commented May 10, 2025

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-openmp

@llvm/pr-subscribers-flang-openmp

Author: None (NimishMishra)

Changes

This patch adds support for LLVM translation of linear clause on omp.wsloop (except for linear modifiers).


Patch is 25.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/139386.diff

10 Files Affected:

  • (modified) flang/lib/Lower/OpenMP/ClauseProcessor.cpp (+34)
  • (modified) flang/lib/Lower/OpenMP/ClauseProcessor.h (+1)
  • (modified) flang/lib/Lower/OpenMP/DataSharingProcessor.cpp (+3-2)
  • (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+2-2)
  • (added) flang/test/Lower/OpenMP/wsloop-linear.f90 (+57)
  • (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+15)
  • (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+3)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+183-2)
  • (modified) mlir/test/Target/LLVMIR/openmp-llvm.mlir (+88)
  • (modified) mlir/test/Target/LLVMIR/openmp-todo.mlir (-13)
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 79b5087e4da68..8ba2f604df80a 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -1060,6 +1060,40 @@ bool ClauseProcessor::processIsDevicePtr(
       });
 }
 
+bool ClauseProcessor::processLinear(mlir::omp::LinearClauseOps &result) const {
+  lower::StatementContext stmtCtx;
+  return findRepeatableClause<
+      omp::clause::Linear>([&](const omp::clause::Linear &clause,
+                               const parser::CharBlock &) {
+    auto &objects = std::get<omp::ObjectList>(clause.t);
+    for (const omp::Object &object : objects) {
+      semantics::Symbol *sym = object.sym();
+      const mlir::Value variable = converter.getSymbolAddress(*sym);
+      result.linearVars.push_back(variable);
+    }
+    if (objects.size()) {
+      if (auto &mod =
+              std::get<std::optional<omp::clause::Linear::StepComplexModifier>>(
+                  clause.t)) {
+        mlir::Value operand =
+            fir::getBase(converter.genExprValue(toEvExpr(*mod), stmtCtx));
+        result.linearStepVars.append(objects.size(), operand);
+      } else if (std::get<std::optional<omp::clause::Linear::LinearModifier>>(
+                     clause.t)) {
+        mlir::Location currentLocation = converter.getCurrentLocation();
+        TODO(currentLocation, "Linear modifiers not yet implemented");
+      } else {
+        // If nothing is present, add the default step of 1.
+        fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+        mlir::Location currentLocation = converter.getCurrentLocation();
+        mlir::Value operand = firOpBuilder.createIntegerConstant(
+            currentLocation, firOpBuilder.getI32Type(), 1);
+        result.linearStepVars.append(objects.size(), operand);
+      }
+    }
+  });
+}
+
 bool ClauseProcessor::processLink(
     llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
   return findRepeatableClause<omp::clause::Link>(
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 7857ba3fd0845..0ec41bdd33256 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -122,6 +122,7 @@ class ClauseProcessor {
   bool processIsDevicePtr(
       mlir::omp::IsDevicePtrClauseOps &result,
       llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
+  bool processLinear(mlir::omp::LinearClauseOps &result) const;
   bool
   processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
 
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
index 7eec598645eac..2a1c94407e1c8 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
@@ -213,14 +213,15 @@ void DataSharingProcessor::collectSymbolsForPrivatization() {
   // so, we won't need to explicitely handle block objects (or forget to do
   // so).
   for (auto *sym : explicitlyPrivatizedSymbols)
-    allPrivatizedSymbols.insert(sym);
+    if (!sym->test(Fortran::semantics::Symbol::Flag::OmpLinear))
+      allPrivatizedSymbols.insert(sym);
 }
 
 bool DataSharingProcessor::needBarrier() {
   // Emit implicit barrier to synchronize threads and avoid data races on
   // initialization of firstprivate variables and post-update of lastprivate
   // variables.
-  // Emit implicit barrier for linear clause. Maybe on somewhere else.
+  // Emit implicit barrier for linear clause in the OpenMPIRBuilder.
   for (const semantics::Symbol *sym : allPrivatizedSymbols) {
     if (sym->test(semantics::Symbol::Flag::OmpLastPrivate) &&
         (sym->test(semantics::Symbol::Flag::OmpFirstPrivate) ||
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 54560729eb4af..6fa915b4364f9 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1841,13 +1841,13 @@ static void genWsloopClauses(
     llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) {
   ClauseProcessor cp(converter, semaCtx, clauses);
   cp.processNowait(clauseOps);
+  cp.processLinear(clauseOps);
   cp.processOrder(clauseOps);
   cp.processOrdered(clauseOps);
   cp.processReduction(loc, clauseOps, reductionSyms);
   cp.processSchedule(stmtCtx, clauseOps);
 
-  cp.processTODO<clause::Allocate, clause::Linear>(
-      loc, llvm::omp::Directive::OMPD_do);
+  cp.processTODO<clause::Allocate>(loc, llvm::omp::Directive::OMPD_do);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/flang/test/Lower/OpenMP/wsloop-linear.f90 b/flang/test/Lower/OpenMP/wsloop-linear.f90
new file mode 100644
index 0000000000000..b99677108be2f
--- /dev/null
+++ b/flang/test/Lower/OpenMP/wsloop-linear.f90
@@ -0,0 +1,57 @@
+! This test checks lowering of OpenMP DO Directive (Worksharing)
+! with linear clause
+
+! RUN: %flang_fc1 -fopenmp -emit-hlfir %s -o - 2>&1 | FileCheck %s
+
+!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFsimple_linearEx"}
+!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFsimple_linearEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[const:.*]] = arith.constant 1 : i32
+subroutine simple_linear
+    implicit none
+    integer :: x, y, i
+    !CHECK: omp.wsloop linear(%[[X]]#0 = %[[const]] : !fir.ref<i32>) {{.*}}
+    !$omp do linear(x)
+    !CHECK: %[[LOAD:.*]] = fir.load %[[X]]#0 : !fir.ref<i32>
+    !CHECK: %[[const:.*]] = arith.constant 2 : i32
+    !CHECK: %[[RESULT:.*]] = arith.addi %[[LOAD]], %[[const]] : i32
+    do i = 1, 10
+        y = x + 2
+    end do
+    !$omp end do
+end subroutine
+
+
+!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFlinear_stepEx"}
+!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFlinear_stepEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+subroutine linear_step
+    implicit none
+    integer :: x, y, i
+    !CHECK: %[[const:.*]] = arith.constant 4 : i32
+    !CHECK: omp.wsloop linear(%[[X]]#0 = %[[const]] : !fir.ref<i32>) {{.*}}
+    !$omp do linear(x:4)
+    !CHECK: %[[LOAD:.*]] = fir.load %[[X]]#0 : !fir.ref<i32>
+    !CHECK: %[[const:.*]] = arith.constant 2 : i32
+    !CHECK: %[[RESULT:.*]] = arith.addi %[[LOAD]], %[[const]] : i32   
+    do i = 1, 10
+        y = x + 2
+    end do
+    !$omp end do
+end subroutine
+
+!CHECK: %[[A_alloca:.*]] = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFlinear_exprEa"}
+!CHECK: %[[A:.*]]:2 = hlfir.declare %[[A_alloca]] {uniq_name = "_QFlinear_exprEa"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFlinear_exprEx"}
+!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFlinear_exprEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+subroutine linear_expr
+    implicit none
+    integer :: x, y, i, a
+    !CHECK: %[[LOAD_A:.*]] = fir.load %[[A]]#0 : !fir.ref<i32>
+    !CHECK: %[[const:.*]] = arith.constant 4 : i32
+    !CHECK: %[[LINEAR_EXPR:.*]] = arith.addi %[[LOAD_A]], %[[const]] : i32
+    !CHECK: omp.wsloop linear(%[[X]]#0 = %[[LINEAR_EXPR]] : !fir.ref<i32>) {{.*}}
+    !$omp do linear(x:a+4)
+    do i = 1, 10
+        y = x + 2
+    end do
+    !$omp end do
+end subroutine
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index ffc0fd0a0bdac..68f15d5c7d41e 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -3580,6 +3580,9 @@ class CanonicalLoopInfo {
   BasicBlock *Latch = nullptr;
   BasicBlock *Exit = nullptr;
 
+  // Hold the MLIR value for the `lastiter` of the canonical loop.
+  Value *LastIter = nullptr;
+
   /// Add the control blocks of this loop to \p BBs.
   ///
   /// This does not include any block from the body, including the one returned
@@ -3612,6 +3615,18 @@ class CanonicalLoopInfo {
   void mapIndVar(llvm::function_ref<Value *(Instruction *)> Updater);
 
 public:
+  /// Sets the last iteration variable for this loop.
+  void setLastIter(Value *IterVar) { LastIter = std::move(IterVar); }
+
+  /// Returns the last iteration variable for this loop.
+  /// Certain use-cases (like translation of linear clause) may access
+  /// this variable even after a loop transformation. Hence, do not guard
+  /// this getter function by `isValid`. It is the responsibility of the
+  /// callee to ensure this functionality is not invoked by a non-outlined
+  /// CanonicalLoopInfo object (in which case, `setLastIter` will never be
+  /// invoked and `LastIter` will be by default `nullptr`).
+  Value *getLastIter() { return LastIter; }
+
   /// Returns whether this object currently represents the IR of a loop. If
   /// returning false, it may have been consumed by a loop transformation or not
   /// been intialized. Do not use in this case;
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index a1268ca76b2d5..991cdb7b6b416 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -4254,6 +4254,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyStaticWorkshareLoop(
   Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
   Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
   Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");
+  CLI->setLastIter(PLastIter);
 
   // At the end of the preheader, prepare for calling the "init" function by
   // storing the current loop bounds into the allocated space. A canonical loop
@@ -4361,6 +4362,7 @@ OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(DebugLoc DL,
   Value *PUpperBound =
       Builder.CreateAlloca(InternalIVTy, nullptr, "p.upperbound");
   Value *PStride = Builder.CreateAlloca(InternalIVTy, nullptr, "p.stride");
+  CLI->setLastIter(PLastIter);
 
   // Set up the source location value for the OpenMP runtime.
   Builder.restoreIP(CLI->getPreheaderIP());
@@ -4844,6 +4846,7 @@ OpenMPIRBuilder::applyDynamicWorkshareLoop(DebugLoc DL, CanonicalLoopInfo *CLI,
   Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
   Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
   Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");
+  CLI->setLastIter(PLastIter);
 
   // At the end of the preheader, prepare for calling the "init" function by
   // storing the current loop bounds into the allocated space. A canonical loop
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 9f7b5605556e6..571505ab9b9aa 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -124,6 +124,146 @@ class PreviouslyReportedError
 
 char PreviouslyReportedError::ID = 0;
 
+/*
+ * Custom class for processing linear clause for omp.wsloop
+ * and omp.simd. Linear clause translation requires setup,
+ * initialization, update, and finalization at varying
+ * basic blocks in the IR. This class helps maintain
+ * internal state to allow consistent translation in
+ * each of these stages.
+ */
+
+class LinearClauseProcessor {
+
+private:
+  SmallVector<llvm::Value *> linearPreconditionVars;
+  SmallVector<llvm::Value *> linearLoopBodyTemps;
+  SmallVector<llvm::AllocaInst *> linearOrigVars;
+  SmallVector<llvm::Value *> linearOrigVal;
+  SmallVector<llvm::Value *> linearSteps;
+  llvm::BasicBlock *linearFinalizationBB;
+  llvm::BasicBlock *linearExitBB;
+  llvm::BasicBlock *linearLastIterExitBB;
+
+public:
+  // Allocate space for linear variabes
+  void createLinearVar(llvm::IRBuilderBase &builder,
+                       LLVM::ModuleTranslation &moduleTranslation,
+                       mlir::Value &linearVar) {
+    if (llvm::AllocaInst *linearVarAlloca = dyn_cast<llvm::AllocaInst>(
+            moduleTranslation.lookupValue(linearVar))) {
+      linearPreconditionVars.push_back(builder.CreateAlloca(
+          linearVarAlloca->getAllocatedType(), nullptr, ".linear_var"));
+      llvm::Value *linearLoopBodyTemp = builder.CreateAlloca(
+          linearVarAlloca->getAllocatedType(), nullptr, ".linear_result");
+      linearOrigVal.push_back(moduleTranslation.lookupValue(linearVar));
+      linearLoopBodyTemps.push_back(linearLoopBodyTemp);
+      linearOrigVars.push_back(linearVarAlloca);
+    }
+  }
+
+  // Initialize linear step
+  inline void initLinearStep(LLVM::ModuleTranslation &moduleTranslation,
+                             mlir::Value &linearStep) {
+    linearSteps.push_back(moduleTranslation.lookupValue(linearStep));
+  }
+
+  // Emit IR for initialization of linear variables
+  llvm::OpenMPIRBuilder::InsertPointOrErrorTy
+  initLinearVar(llvm::IRBuilderBase &builder,
+                LLVM::ModuleTranslation &moduleTranslation,
+                llvm::BasicBlock *loopPreHeader) {
+    builder.SetInsertPoint(loopPreHeader->getTerminator());
+    for (size_t index = 0; index < linearOrigVars.size(); index++) {
+      llvm::LoadInst *linearVarLoad = builder.CreateLoad(
+          linearOrigVars[index]->getAllocatedType(), linearOrigVars[index]);
+      builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
+    }
+    llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
+        moduleTranslation.getOpenMPBuilder()->createBarrier(
+            builder.saveIP(), llvm::omp::OMPD_barrier);
+    return afterBarrierIP;
+  }
+
+  // Emit IR for updating Linear variables
+  void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
+                       llvm::Value *loopInductionVar) {
+    builder.SetInsertPoint(loopBody->getTerminator());
+    for (size_t index = 0; index < linearPreconditionVars.size(); index++) {
+      // Emit increments for linear vars
+      llvm::LoadInst *linearVarStart =
+          builder.CreateLoad(linearOrigVars[index]->getAllocatedType(),
+
+                             linearPreconditionVars[index]);
+      auto mulInst = builder.CreateMul(loopInductionVar, linearSteps[index]);
+      auto addInst = builder.CreateAdd(linearVarStart, mulInst);
+      builder.CreateStore(addInst, linearLoopBodyTemps[index]);
+    }
+  }
+
+  // Linear variable finalization is conditional on the last logical iteration.
+  // Create BB splits to manage the same.
+  void outlineLinearFinalizationBB(llvm::IRBuilderBase &builder,
+                                   llvm::BasicBlock *loopExit) {
+    linearFinalizationBB = loopExit->splitBasicBlock(
+        loopExit->getTerminator(), "omp_loop.linear_finalization");
+    linearExitBB = linearFinalizationBB->splitBasicBlock(
+        linearFinalizationBB->getTerminator(), "omp_loop.linear_exit");
+    linearLastIterExitBB = linearFinalizationBB->splitBasicBlock(
+        linearFinalizationBB->getTerminator(), "omp_loop.linear_lastiter_exit");
+  }
+
+  // Finalize the linear vars
+  llvm::OpenMPIRBuilder::InsertPointOrErrorTy
+  finalizeLinearVar(llvm::IRBuilderBase &builder,
+                    LLVM::ModuleTranslation &moduleTranslation,
+                    llvm::Value *lastIter) {
+    // Emit condition to check whether last logical iteration is being executed
+    builder.SetInsertPoint(linearFinalizationBB->getTerminator());
+    llvm::Value *loopLastIterLoad = builder.CreateLoad(
+        llvm::Type::getInt32Ty(builder.getContext()), lastIter);
+    llvm::Value *isLast =
+        builder.CreateCmp(llvm::CmpInst::ICMP_NE, loopLastIterLoad,
+                          llvm::ConstantInt::get(
+                              llvm::Type::getInt32Ty(builder.getContext()), 0));
+    // Store the linear variable values to original variables.
+    builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
+    for (size_t index = 0; index < linearOrigVars.size(); index++) {
+      llvm::LoadInst *linearVarTemp =
+          builder.CreateLoad(linearOrigVars[index]->getAllocatedType(),
+                             linearLoopBodyTemps[index]);
+      builder.CreateStore(linearVarTemp, linearOrigVars[index]);
+    }
+
+    // Create conditional branch such that the linear variable
+    // values are stored to original variables only at the
+    // last logical iteration
+    builder.SetInsertPoint(linearFinalizationBB->getTerminator());
+    builder.CreateCondBr(isLast, linearLastIterExitBB, linearExitBB);
+    linearFinalizationBB->getTerminator()->eraseFromParent();
+    // Emit barrier
+    builder.SetInsertPoint(linearExitBB->getTerminator());
+    return moduleTranslation.getOpenMPBuilder()->createBarrier(
+        builder.saveIP(), llvm::omp::OMPD_barrier);
+  }
+
+  // Rewrite all uses of the original variable in `BBName`
+  //  with the linear variable in-place
+  void rewriteInPlace(llvm::IRBuilderBase &builder, std::string BBName,
+                      size_t varIndex) {
+    llvm::SmallVector<llvm::User *> users;
+    for (llvm::User *user : linearOrigVal[varIndex]->users())
+      users.push_back(user);
+    for (auto *user : users) {
+      if (auto *userInst = dyn_cast<llvm::Instruction>(user)) {
+        if (userInst->getParent()->getName().str() == BBName)
+          user->replaceUsesOfWith(linearOrigVal[varIndex],
+                                  linearLoopBodyTemps[varIndex]);
+      }
+    }
+  }
+};
+
 } // namespace
 
 /// Looks up from the operation from and returns the PrivateClauseOp with
@@ -292,7 +432,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
       })
       .Case([&](omp::WsloopOp op) {
         checkAllocate(op, result);
-        checkLinear(op, result);
         checkOrder(op, result);
         checkReduction(op, result);
       })
@@ -2423,15 +2562,40 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
                            llvm::omp::Directive::OMPD_for);
 
   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+
+  // Initialize linear variables and linear step
+  LinearClauseProcessor linearClauseProcessor;
+  if (wsloopOp.getLinearVars().size()) {
+    for (mlir::Value linearVar : wsloopOp.getLinearVars())
+      linearClauseProcessor.createLinearVar(builder, moduleTranslation,
+                                            linearVar);
+    for (mlir::Value linearStep : wsloopOp.getLinearStepVars())
+      linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
+  }
+
   llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
       wsloopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation);
 
   if (failed(handleError(regionBlock, opInst)))
     return failure();
 
-  builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
   llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
 
+  // Emit Initialization and Update IR for linear variables
+  if (wsloopOp.getLinearVars().size()) {
+    llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
+        linearClauseProcessor.initLinearVar(builder, moduleTranslation,
+                                            loopInfo->getPreheader());
+    if (failed(handleError(afterBarrierIP, *loopOp)))
+      return failure();
+    builder.restoreIP(*afterBarrierIP);
+    linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
+                                          loopInfo->getIndVar());
+    linearClauseProcessor.outlineLinearFinalizationBB(builder,
+                                                      loopInfo->getExit());
+  }
+
+  builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
   llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
       ompBuilder->applyWorkshareLoop(
           ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
@@ -2443,6 +2607,23 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
   if (failed(handleError(wsloopIP, opInst)))
     return failure();
 
+  // Emit finalization and in-place rewrites for linear vars.
+  if (wsloopOp.getLinearVars().size()) {
+    llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
+    assert(loopInfo->getLastIter() &&
+           "`lastiter` in CanonicalLoopInfo is nullptr"...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented May 10, 2025

@llvm/pr-subscribers-mlir-llvm

Author: None (NimishMishra)

Changes

This patch adds support for LLVM translation of linear clause on omp.wsloop (except for linear modifiers).


Patch is 25.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/139386.diff

10 Files Affected:

  • (modified) flang/lib/Lower/OpenMP/ClauseProcessor.cpp (+34)
  • (modified) flang/lib/Lower/OpenMP/ClauseProcessor.h (+1)
  • (modified) flang/lib/Lower/OpenMP/DataSharingProcessor.cpp (+3-2)
  • (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+2-2)
  • (added) flang/test/Lower/OpenMP/wsloop-linear.f90 (+57)
  • (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+15)
  • (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+3)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+183-2)
  • (modified) mlir/test/Target/LLVMIR/openmp-llvm.mlir (+88)
  • (modified) mlir/test/Target/LLVMIR/openmp-todo.mlir (-13)
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 79b5087e4da68..8ba2f604df80a 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -1060,6 +1060,40 @@ bool ClauseProcessor::processIsDevicePtr(
       });
 }
 
+bool ClauseProcessor::processLinear(mlir::omp::LinearClauseOps &result) const {
+  lower::StatementContext stmtCtx;
+  return findRepeatableClause<
+      omp::clause::Linear>([&](const omp::clause::Linear &clause,
+                               const parser::CharBlock &) {
+    auto &objects = std::get<omp::ObjectList>(clause.t);
+    for (const omp::Object &object : objects) {
+      semantics::Symbol *sym = object.sym();
+      const mlir::Value variable = converter.getSymbolAddress(*sym);
+      result.linearVars.push_back(variable);
+    }
+    if (objects.size()) {
+      if (auto &mod =
+              std::get<std::optional<omp::clause::Linear::StepComplexModifier>>(
+                  clause.t)) {
+        mlir::Value operand =
+            fir::getBase(converter.genExprValue(toEvExpr(*mod), stmtCtx));
+        result.linearStepVars.append(objects.size(), operand);
+      } else if (std::get<std::optional<omp::clause::Linear::LinearModifier>>(
+                     clause.t)) {
+        mlir::Location currentLocation = converter.getCurrentLocation();
+        TODO(currentLocation, "Linear modifiers not yet implemented");
+      } else {
+        // If nothing is present, add the default step of 1.
+        fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+        mlir::Location currentLocation = converter.getCurrentLocation();
+        mlir::Value operand = firOpBuilder.createIntegerConstant(
+            currentLocation, firOpBuilder.getI32Type(), 1);
+        result.linearStepVars.append(objects.size(), operand);
+      }
+    }
+  });
+}
+
 bool ClauseProcessor::processLink(
     llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
   return findRepeatableClause<omp::clause::Link>(
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 7857ba3fd0845..0ec41bdd33256 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -122,6 +122,7 @@ class ClauseProcessor {
   bool processIsDevicePtr(
       mlir::omp::IsDevicePtrClauseOps &result,
       llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
+  bool processLinear(mlir::omp::LinearClauseOps &result) const;
   bool
   processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
 
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
index 7eec598645eac..2a1c94407e1c8 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
@@ -213,14 +213,15 @@ void DataSharingProcessor::collectSymbolsForPrivatization() {
   // so, we won't need to explicitely handle block objects (or forget to do
   // so).
   for (auto *sym : explicitlyPrivatizedSymbols)
-    allPrivatizedSymbols.insert(sym);
+    if (!sym->test(Fortran::semantics::Symbol::Flag::OmpLinear))
+      allPrivatizedSymbols.insert(sym);
 }
 
 bool DataSharingProcessor::needBarrier() {
   // Emit implicit barrier to synchronize threads and avoid data races on
   // initialization of firstprivate variables and post-update of lastprivate
   // variables.
-  // Emit implicit barrier for linear clause. Maybe on somewhere else.
+  // Emit implicit barrier for linear clause in the OpenMPIRBuilder.
   for (const semantics::Symbol *sym : allPrivatizedSymbols) {
     if (sym->test(semantics::Symbol::Flag::OmpLastPrivate) &&
         (sym->test(semantics::Symbol::Flag::OmpFirstPrivate) ||
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 54560729eb4af..6fa915b4364f9 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1841,13 +1841,13 @@ static void genWsloopClauses(
     llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) {
   ClauseProcessor cp(converter, semaCtx, clauses);
   cp.processNowait(clauseOps);
+  cp.processLinear(clauseOps);
   cp.processOrder(clauseOps);
   cp.processOrdered(clauseOps);
   cp.processReduction(loc, clauseOps, reductionSyms);
   cp.processSchedule(stmtCtx, clauseOps);
 
-  cp.processTODO<clause::Allocate, clause::Linear>(
-      loc, llvm::omp::Directive::OMPD_do);
+  cp.processTODO<clause::Allocate>(loc, llvm::omp::Directive::OMPD_do);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/flang/test/Lower/OpenMP/wsloop-linear.f90 b/flang/test/Lower/OpenMP/wsloop-linear.f90
new file mode 100644
index 0000000000000..b99677108be2f
--- /dev/null
+++ b/flang/test/Lower/OpenMP/wsloop-linear.f90
@@ -0,0 +1,57 @@
+! This test checks lowering of OpenMP DO Directive (Worksharing)
+! with linear clause
+
+! RUN: %flang_fc1 -fopenmp -emit-hlfir %s -o - 2>&1 | FileCheck %s
+
+!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFsimple_linearEx"}
+!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFsimple_linearEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[const:.*]] = arith.constant 1 : i32
+subroutine simple_linear
+    implicit none
+    integer :: x, y, i
+    !CHECK: omp.wsloop linear(%[[X]]#0 = %[[const]] : !fir.ref<i32>) {{.*}}
+    !$omp do linear(x)
+    !CHECK: %[[LOAD:.*]] = fir.load %[[X]]#0 : !fir.ref<i32>
+    !CHECK: %[[const:.*]] = arith.constant 2 : i32
+    !CHECK: %[[RESULT:.*]] = arith.addi %[[LOAD]], %[[const]] : i32
+    do i = 1, 10
+        y = x + 2
+    end do
+    !$omp end do
+end subroutine
+
+
+!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFlinear_stepEx"}
+!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFlinear_stepEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+subroutine linear_step
+    implicit none
+    integer :: x, y, i
+    !CHECK: %[[const:.*]] = arith.constant 4 : i32
+    !CHECK: omp.wsloop linear(%[[X]]#0 = %[[const]] : !fir.ref<i32>) {{.*}}
+    !$omp do linear(x:4)
+    !CHECK: %[[LOAD:.*]] = fir.load %[[X]]#0 : !fir.ref<i32>
+    !CHECK: %[[const:.*]] = arith.constant 2 : i32
+    !CHECK: %[[RESULT:.*]] = arith.addi %[[LOAD]], %[[const]] : i32   
+    do i = 1, 10
+        y = x + 2
+    end do
+    !$omp end do
+end subroutine
+
+!CHECK: %[[A_alloca:.*]] = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFlinear_exprEa"}
+!CHECK: %[[A:.*]]:2 = hlfir.declare %[[A_alloca]] {uniq_name = "_QFlinear_exprEa"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFlinear_exprEx"}
+!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFlinear_exprEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+subroutine linear_expr
+    implicit none
+    integer :: x, y, i, a
+    !CHECK: %[[LOAD_A:.*]] = fir.load %[[A]]#0 : !fir.ref<i32>
+    !CHECK: %[[const:.*]] = arith.constant 4 : i32
+    !CHECK: %[[LINEAR_EXPR:.*]] = arith.addi %[[LOAD_A]], %[[const]] : i32
+    !CHECK: omp.wsloop linear(%[[X]]#0 = %[[LINEAR_EXPR]] : !fir.ref<i32>) {{.*}}
+    !$omp do linear(x:a+4)
+    do i = 1, 10
+        y = x + 2
+    end do
+    !$omp end do
+end subroutine
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index ffc0fd0a0bdac..68f15d5c7d41e 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -3580,6 +3580,9 @@ class CanonicalLoopInfo {
   BasicBlock *Latch = nullptr;
   BasicBlock *Exit = nullptr;
 
+  // Hold the MLIR value for the `lastiter` of the canonical loop.
+  Value *LastIter = nullptr;
+
   /// Add the control blocks of this loop to \p BBs.
   ///
   /// This does not include any block from the body, including the one returned
@@ -3612,6 +3615,18 @@ class CanonicalLoopInfo {
   void mapIndVar(llvm::function_ref<Value *(Instruction *)> Updater);
 
 public:
+  /// Sets the last iteration variable for this loop.
+  void setLastIter(Value *IterVar) { LastIter = std::move(IterVar); }
+
+  /// Returns the last iteration variable for this loop.
+  /// Certain use-cases (like translation of linear clause) may access
+  /// this variable even after a loop transformation. Hence, do not guard
+  /// this getter function by `isValid`. It is the responsibility of the
+  /// callee to ensure this functionality is not invoked by a non-outlined
+  /// CanonicalLoopInfo object (in which case, `setLastIter` will never be
+  /// invoked and `LastIter` will be by default `nullptr`).
+  Value *getLastIter() { return LastIter; }
+
   /// Returns whether this object currently represents the IR of a loop. If
   /// returning false, it may have been consumed by a loop transformation or not
   /// been intialized. Do not use in this case;
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index a1268ca76b2d5..991cdb7b6b416 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -4254,6 +4254,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyStaticWorkshareLoop(
   Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
   Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
   Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");
+  CLI->setLastIter(PLastIter);
 
   // At the end of the preheader, prepare for calling the "init" function by
   // storing the current loop bounds into the allocated space. A canonical loop
@@ -4361,6 +4362,7 @@ OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(DebugLoc DL,
   Value *PUpperBound =
       Builder.CreateAlloca(InternalIVTy, nullptr, "p.upperbound");
   Value *PStride = Builder.CreateAlloca(InternalIVTy, nullptr, "p.stride");
+  CLI->setLastIter(PLastIter);
 
   // Set up the source location value for the OpenMP runtime.
   Builder.restoreIP(CLI->getPreheaderIP());
@@ -4844,6 +4846,7 @@ OpenMPIRBuilder::applyDynamicWorkshareLoop(DebugLoc DL, CanonicalLoopInfo *CLI,
   Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
   Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
   Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");
+  CLI->setLastIter(PLastIter);
 
   // At the end of the preheader, prepare for calling the "init" function by
   // storing the current loop bounds into the allocated space. A canonical loop
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 9f7b5605556e6..571505ab9b9aa 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -124,6 +124,146 @@ class PreviouslyReportedError
 
 char PreviouslyReportedError::ID = 0;
 
+/*
+ * Custom class for processing linear clause for omp.wsloop
+ * and omp.simd. Linear clause translation requires setup,
+ * initialization, update, and finalization at varying
+ * basic blocks in the IR. This class helps maintain
+ * internal state to allow consistent translation in
+ * each of these stages.
+ */
+
+class LinearClauseProcessor {
+
+private:
+  SmallVector<llvm::Value *> linearPreconditionVars;
+  SmallVector<llvm::Value *> linearLoopBodyTemps;
+  SmallVector<llvm::AllocaInst *> linearOrigVars;
+  SmallVector<llvm::Value *> linearOrigVal;
+  SmallVector<llvm::Value *> linearSteps;
+  llvm::BasicBlock *linearFinalizationBB;
+  llvm::BasicBlock *linearExitBB;
+  llvm::BasicBlock *linearLastIterExitBB;
+
+public:
+  // Allocate space for linear variabes
+  void createLinearVar(llvm::IRBuilderBase &builder,
+                       LLVM::ModuleTranslation &moduleTranslation,
+                       mlir::Value &linearVar) {
+    if (llvm::AllocaInst *linearVarAlloca = dyn_cast<llvm::AllocaInst>(
+            moduleTranslation.lookupValue(linearVar))) {
+      linearPreconditionVars.push_back(builder.CreateAlloca(
+          linearVarAlloca->getAllocatedType(), nullptr, ".linear_var"));
+      llvm::Value *linearLoopBodyTemp = builder.CreateAlloca(
+          linearVarAlloca->getAllocatedType(), nullptr, ".linear_result");
+      linearOrigVal.push_back(moduleTranslation.lookupValue(linearVar));
+      linearLoopBodyTemps.push_back(linearLoopBodyTemp);
+      linearOrigVars.push_back(linearVarAlloca);
+    }
+  }
+
+  // Initialize linear step
+  inline void initLinearStep(LLVM::ModuleTranslation &moduleTranslation,
+                             mlir::Value &linearStep) {
+    linearSteps.push_back(moduleTranslation.lookupValue(linearStep));
+  }
+
+  // Emit IR for initialization of linear variables
+  llvm::OpenMPIRBuilder::InsertPointOrErrorTy
+  initLinearVar(llvm::IRBuilderBase &builder,
+                LLVM::ModuleTranslation &moduleTranslation,
+                llvm::BasicBlock *loopPreHeader) {
+    builder.SetInsertPoint(loopPreHeader->getTerminator());
+    for (size_t index = 0; index < linearOrigVars.size(); index++) {
+      llvm::LoadInst *linearVarLoad = builder.CreateLoad(
+          linearOrigVars[index]->getAllocatedType(), linearOrigVars[index]);
+      builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
+    }
+    llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
+        moduleTranslation.getOpenMPBuilder()->createBarrier(
+            builder.saveIP(), llvm::omp::OMPD_barrier);
+    return afterBarrierIP;
+  }
+
+  // Emit IR for updating Linear variables
+  void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
+                       llvm::Value *loopInductionVar) {
+    builder.SetInsertPoint(loopBody->getTerminator());
+    for (size_t index = 0; index < linearPreconditionVars.size(); index++) {
+      // Emit increments for linear vars
+      llvm::LoadInst *linearVarStart =
+          builder.CreateLoad(linearOrigVars[index]->getAllocatedType(),
+
+                             linearPreconditionVars[index]);
+      auto mulInst = builder.CreateMul(loopInductionVar, linearSteps[index]);
+      auto addInst = builder.CreateAdd(linearVarStart, mulInst);
+      builder.CreateStore(addInst, linearLoopBodyTemps[index]);
+    }
+  }
+
+  // Linear variable finalization is conditional on the last logical iteration.
+  // Create BB splits to manage the same.
+  void outlineLinearFinalizationBB(llvm::IRBuilderBase &builder,
+                                   llvm::BasicBlock *loopExit) {
+    linearFinalizationBB = loopExit->splitBasicBlock(
+        loopExit->getTerminator(), "omp_loop.linear_finalization");
+    linearExitBB = linearFinalizationBB->splitBasicBlock(
+        linearFinalizationBB->getTerminator(), "omp_loop.linear_exit");
+    linearLastIterExitBB = linearFinalizationBB->splitBasicBlock(
+        linearFinalizationBB->getTerminator(), "omp_loop.linear_lastiter_exit");
+  }
+
+  // Finalize the linear vars
+  llvm::OpenMPIRBuilder::InsertPointOrErrorTy
+  finalizeLinearVar(llvm::IRBuilderBase &builder,
+                    LLVM::ModuleTranslation &moduleTranslation,
+                    llvm::Value *lastIter) {
+    // Emit condition to check whether last logical iteration is being executed
+    builder.SetInsertPoint(linearFinalizationBB->getTerminator());
+    llvm::Value *loopLastIterLoad = builder.CreateLoad(
+        llvm::Type::getInt32Ty(builder.getContext()), lastIter);
+    llvm::Value *isLast =
+        builder.CreateCmp(llvm::CmpInst::ICMP_NE, loopLastIterLoad,
+                          llvm::ConstantInt::get(
+                              llvm::Type::getInt32Ty(builder.getContext()), 0));
+    // Store the linear variable values to original variables.
+    builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
+    for (size_t index = 0; index < linearOrigVars.size(); index++) {
+      llvm::LoadInst *linearVarTemp =
+          builder.CreateLoad(linearOrigVars[index]->getAllocatedType(),
+                             linearLoopBodyTemps[index]);
+      builder.CreateStore(linearVarTemp, linearOrigVars[index]);
+    }
+
+    // Create conditional branch such that the linear variable
+    // values are stored to original variables only at the
+    // last logical iteration
+    builder.SetInsertPoint(linearFinalizationBB->getTerminator());
+    builder.CreateCondBr(isLast, linearLastIterExitBB, linearExitBB);
+    linearFinalizationBB->getTerminator()->eraseFromParent();
+    // Emit barrier
+    builder.SetInsertPoint(linearExitBB->getTerminator());
+    return moduleTranslation.getOpenMPBuilder()->createBarrier(
+        builder.saveIP(), llvm::omp::OMPD_barrier);
+  }
+
+  // Rewrite all uses of the original variable in `BBName`
+  //  with the linear variable in-place
+  void rewriteInPlace(llvm::IRBuilderBase &builder, std::string BBName,
+                      size_t varIndex) {
+    llvm::SmallVector<llvm::User *> users;
+    for (llvm::User *user : linearOrigVal[varIndex]->users())
+      users.push_back(user);
+    for (auto *user : users) {
+      if (auto *userInst = dyn_cast<llvm::Instruction>(user)) {
+        if (userInst->getParent()->getName().str() == BBName)
+          user->replaceUsesOfWith(linearOrigVal[varIndex],
+                                  linearLoopBodyTemps[varIndex]);
+      }
+    }
+  }
+};
+
 } // namespace
 
 /// Looks up from the operation from and returns the PrivateClauseOp with
@@ -292,7 +432,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
       })
       .Case([&](omp::WsloopOp op) {
         checkAllocate(op, result);
-        checkLinear(op, result);
         checkOrder(op, result);
         checkReduction(op, result);
       })
@@ -2423,15 +2562,40 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
                            llvm::omp::Directive::OMPD_for);
 
   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+
+  // Initialize linear variables and linear step
+  LinearClauseProcessor linearClauseProcessor;
+  if (wsloopOp.getLinearVars().size()) {
+    for (mlir::Value linearVar : wsloopOp.getLinearVars())
+      linearClauseProcessor.createLinearVar(builder, moduleTranslation,
+                                            linearVar);
+    for (mlir::Value linearStep : wsloopOp.getLinearStepVars())
+      linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
+  }
+
   llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
       wsloopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation);
 
   if (failed(handleError(regionBlock, opInst)))
     return failure();
 
-  builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
   llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
 
+  // Emit Initialization and Update IR for linear variables
+  if (wsloopOp.getLinearVars().size()) {
+    llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
+        linearClauseProcessor.initLinearVar(builder, moduleTranslation,
+                                            loopInfo->getPreheader());
+    if (failed(handleError(afterBarrierIP, *loopOp)))
+      return failure();
+    builder.restoreIP(*afterBarrierIP);
+    linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
+                                          loopInfo->getIndVar());
+    linearClauseProcessor.outlineLinearFinalizationBB(builder,
+                                                      loopInfo->getExit());
+  }
+
+  builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
   llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
       ompBuilder->applyWorkshareLoop(
           ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
@@ -2443,6 +2607,23 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
   if (failed(handleError(wsloopIP, opInst)))
     return failure();
 
+  // Emit finalization and in-place rewrites for linear vars.
+  if (wsloopOp.getLinearVars().size()) {
+    llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
+    assert(loopInfo->getLastIter() &&
+           "`lastiter` in CanonicalLoopInfo is nullptr"...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented May 10, 2025

@llvm/pr-subscribers-flang-fir-hlfir

Author: None (NimishMishra)

Changes

This patch adds support for LLVM translation of linear clause on omp.wsloop (except for linear modifiers).


Patch is 25.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/139386.diff

10 Files Affected:

  • (modified) flang/lib/Lower/OpenMP/ClauseProcessor.cpp (+34)
  • (modified) flang/lib/Lower/OpenMP/ClauseProcessor.h (+1)
  • (modified) flang/lib/Lower/OpenMP/DataSharingProcessor.cpp (+3-2)
  • (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+2-2)
  • (added) flang/test/Lower/OpenMP/wsloop-linear.f90 (+57)
  • (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+15)
  • (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+3)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+183-2)
  • (modified) mlir/test/Target/LLVMIR/openmp-llvm.mlir (+88)
  • (modified) mlir/test/Target/LLVMIR/openmp-todo.mlir (-13)
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 79b5087e4da68..8ba2f604df80a 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -1060,6 +1060,40 @@ bool ClauseProcessor::processIsDevicePtr(
       });
 }
 
+bool ClauseProcessor::processLinear(mlir::omp::LinearClauseOps &result) const {
+  lower::StatementContext stmtCtx;
+  return findRepeatableClause<
+      omp::clause::Linear>([&](const omp::clause::Linear &clause,
+                               const parser::CharBlock &) {
+    auto &objects = std::get<omp::ObjectList>(clause.t);
+    for (const omp::Object &object : objects) {
+      semantics::Symbol *sym = object.sym();
+      const mlir::Value variable = converter.getSymbolAddress(*sym);
+      result.linearVars.push_back(variable);
+    }
+    if (objects.size()) {
+      if (auto &mod =
+              std::get<std::optional<omp::clause::Linear::StepComplexModifier>>(
+                  clause.t)) {
+        mlir::Value operand =
+            fir::getBase(converter.genExprValue(toEvExpr(*mod), stmtCtx));
+        result.linearStepVars.append(objects.size(), operand);
+      } else if (std::get<std::optional<omp::clause::Linear::LinearModifier>>(
+                     clause.t)) {
+        mlir::Location currentLocation = converter.getCurrentLocation();
+        TODO(currentLocation, "Linear modifiers not yet implemented");
+      } else {
+        // If nothing is present, add the default step of 1.
+        fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+        mlir::Location currentLocation = converter.getCurrentLocation();
+        mlir::Value operand = firOpBuilder.createIntegerConstant(
+            currentLocation, firOpBuilder.getI32Type(), 1);
+        result.linearStepVars.append(objects.size(), operand);
+      }
+    }
+  });
+}
+
 bool ClauseProcessor::processLink(
     llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
   return findRepeatableClause<omp::clause::Link>(
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 7857ba3fd0845..0ec41bdd33256 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -122,6 +122,7 @@ class ClauseProcessor {
   bool processIsDevicePtr(
       mlir::omp::IsDevicePtrClauseOps &result,
       llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
+  bool processLinear(mlir::omp::LinearClauseOps &result) const;
   bool
   processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
 
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
index 7eec598645eac..2a1c94407e1c8 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
@@ -213,14 +213,15 @@ void DataSharingProcessor::collectSymbolsForPrivatization() {
   // so, we won't need to explicitely handle block objects (or forget to do
   // so).
   for (auto *sym : explicitlyPrivatizedSymbols)
-    allPrivatizedSymbols.insert(sym);
+    if (!sym->test(Fortran::semantics::Symbol::Flag::OmpLinear))
+      allPrivatizedSymbols.insert(sym);
 }
 
 bool DataSharingProcessor::needBarrier() {
   // Emit implicit barrier to synchronize threads and avoid data races on
   // initialization of firstprivate variables and post-update of lastprivate
   // variables.
-  // Emit implicit barrier for linear clause. Maybe on somewhere else.
+  // Emit implicit barrier for linear clause in the OpenMPIRBuilder.
   for (const semantics::Symbol *sym : allPrivatizedSymbols) {
     if (sym->test(semantics::Symbol::Flag::OmpLastPrivate) &&
         (sym->test(semantics::Symbol::Flag::OmpFirstPrivate) ||
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 54560729eb4af..6fa915b4364f9 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1841,13 +1841,13 @@ static void genWsloopClauses(
     llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) {
   ClauseProcessor cp(converter, semaCtx, clauses);
   cp.processNowait(clauseOps);
+  cp.processLinear(clauseOps);
   cp.processOrder(clauseOps);
   cp.processOrdered(clauseOps);
   cp.processReduction(loc, clauseOps, reductionSyms);
   cp.processSchedule(stmtCtx, clauseOps);
 
-  cp.processTODO<clause::Allocate, clause::Linear>(
-      loc, llvm::omp::Directive::OMPD_do);
+  cp.processTODO<clause::Allocate>(loc, llvm::omp::Directive::OMPD_do);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/flang/test/Lower/OpenMP/wsloop-linear.f90 b/flang/test/Lower/OpenMP/wsloop-linear.f90
new file mode 100644
index 0000000000000..b99677108be2f
--- /dev/null
+++ b/flang/test/Lower/OpenMP/wsloop-linear.f90
@@ -0,0 +1,57 @@
+! This test checks lowering of OpenMP DO Directive (Worksharing)
+! with linear clause
+
+! RUN: %flang_fc1 -fopenmp -emit-hlfir %s -o - 2>&1 | FileCheck %s
+
+!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFsimple_linearEx"}
+!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFsimple_linearEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[const:.*]] = arith.constant 1 : i32
+subroutine simple_linear
+    implicit none
+    integer :: x, y, i
+    !CHECK: omp.wsloop linear(%[[X]]#0 = %[[const]] : !fir.ref<i32>) {{.*}}
+    !$omp do linear(x)
+    !CHECK: %[[LOAD:.*]] = fir.load %[[X]]#0 : !fir.ref<i32>
+    !CHECK: %[[const:.*]] = arith.constant 2 : i32
+    !CHECK: %[[RESULT:.*]] = arith.addi %[[LOAD]], %[[const]] : i32
+    do i = 1, 10
+        y = x + 2
+    end do
+    !$omp end do
+end subroutine
+
+
+!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFlinear_stepEx"}
+!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFlinear_stepEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+subroutine linear_step
+    implicit none
+    integer :: x, y, i
+    !CHECK: %[[const:.*]] = arith.constant 4 : i32
+    !CHECK: omp.wsloop linear(%[[X]]#0 = %[[const]] : !fir.ref<i32>) {{.*}}
+    !$omp do linear(x:4)
+    !CHECK: %[[LOAD:.*]] = fir.load %[[X]]#0 : !fir.ref<i32>
+    !CHECK: %[[const:.*]] = arith.constant 2 : i32
+    !CHECK: %[[RESULT:.*]] = arith.addi %[[LOAD]], %[[const]] : i32   
+    do i = 1, 10
+        y = x + 2
+    end do
+    !$omp end do
+end subroutine
+
+!CHECK: %[[A_alloca:.*]] = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFlinear_exprEa"}
+!CHECK: %[[A:.*]]:2 = hlfir.declare %[[A_alloca]] {uniq_name = "_QFlinear_exprEa"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFlinear_exprEx"}
+!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFlinear_exprEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+subroutine linear_expr
+    implicit none
+    integer :: x, y, i, a
+    !CHECK: %[[LOAD_A:.*]] = fir.load %[[A]]#0 : !fir.ref<i32>
+    !CHECK: %[[const:.*]] = arith.constant 4 : i32
+    !CHECK: %[[LINEAR_EXPR:.*]] = arith.addi %[[LOAD_A]], %[[const]] : i32
+    !CHECK: omp.wsloop linear(%[[X]]#0 = %[[LINEAR_EXPR]] : !fir.ref<i32>) {{.*}}
+    !$omp do linear(x:a+4)
+    do i = 1, 10
+        y = x + 2
+    end do
+    !$omp end do
+end subroutine
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index ffc0fd0a0bdac..68f15d5c7d41e 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -3580,6 +3580,9 @@ class CanonicalLoopInfo {
   BasicBlock *Latch = nullptr;
   BasicBlock *Exit = nullptr;
 
+  // Hold the MLIR value for the `lastiter` of the canonical loop.
+  Value *LastIter = nullptr;
+
   /// Add the control blocks of this loop to \p BBs.
   ///
   /// This does not include any block from the body, including the one returned
@@ -3612,6 +3615,18 @@ class CanonicalLoopInfo {
   void mapIndVar(llvm::function_ref<Value *(Instruction *)> Updater);
 
 public:
+  /// Sets the last iteration variable for this loop.
+  void setLastIter(Value *IterVar) { LastIter = std::move(IterVar); }
+
+  /// Returns the last iteration variable for this loop.
+  /// Certain use-cases (like translation of linear clause) may access
+  /// this variable even after a loop transformation. Hence, do not guard
+  /// this getter function by `isValid`. It is the responsibility of the
+  /// callee to ensure this functionality is not invoked by a non-outlined
+  /// CanonicalLoopInfo object (in which case, `setLastIter` will never be
+  /// invoked and `LastIter` will be by default `nullptr`).
+  Value *getLastIter() { return LastIter; }
+
   /// Returns whether this object currently represents the IR of a loop. If
   /// returning false, it may have been consumed by a loop transformation or not
   /// been intialized. Do not use in this case;
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index a1268ca76b2d5..991cdb7b6b416 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -4254,6 +4254,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyStaticWorkshareLoop(
   Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
   Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
   Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");
+  CLI->setLastIter(PLastIter);
 
   // At the end of the preheader, prepare for calling the "init" function by
   // storing the current loop bounds into the allocated space. A canonical loop
@@ -4361,6 +4362,7 @@ OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(DebugLoc DL,
   Value *PUpperBound =
       Builder.CreateAlloca(InternalIVTy, nullptr, "p.upperbound");
   Value *PStride = Builder.CreateAlloca(InternalIVTy, nullptr, "p.stride");
+  CLI->setLastIter(PLastIter);
 
   // Set up the source location value for the OpenMP runtime.
   Builder.restoreIP(CLI->getPreheaderIP());
@@ -4844,6 +4846,7 @@ OpenMPIRBuilder::applyDynamicWorkshareLoop(DebugLoc DL, CanonicalLoopInfo *CLI,
   Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
   Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
   Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");
+  CLI->setLastIter(PLastIter);
 
   // At the end of the preheader, prepare for calling the "init" function by
   // storing the current loop bounds into the allocated space. A canonical loop
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 9f7b5605556e6..571505ab9b9aa 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -124,6 +124,146 @@ class PreviouslyReportedError
 
 char PreviouslyReportedError::ID = 0;
 
+/*
+ * Custom class for processing linear clause for omp.wsloop
+ * and omp.simd. Linear clause translation requires setup,
+ * initialization, update, and finalization at varying
+ * basic blocks in the IR. This class helps maintain
+ * internal state to allow consistent translation in
+ * each of these stages.
+ */
+
+class LinearClauseProcessor {
+
+private:
+  SmallVector<llvm::Value *> linearPreconditionVars;
+  SmallVector<llvm::Value *> linearLoopBodyTemps;
+  SmallVector<llvm::AllocaInst *> linearOrigVars;
+  SmallVector<llvm::Value *> linearOrigVal;
+  SmallVector<llvm::Value *> linearSteps;
+  llvm::BasicBlock *linearFinalizationBB;
+  llvm::BasicBlock *linearExitBB;
+  llvm::BasicBlock *linearLastIterExitBB;
+
+public:
+  // Allocate space for linear variabes
+  void createLinearVar(llvm::IRBuilderBase &builder,
+                       LLVM::ModuleTranslation &moduleTranslation,
+                       mlir::Value &linearVar) {
+    if (llvm::AllocaInst *linearVarAlloca = dyn_cast<llvm::AllocaInst>(
+            moduleTranslation.lookupValue(linearVar))) {
+      linearPreconditionVars.push_back(builder.CreateAlloca(
+          linearVarAlloca->getAllocatedType(), nullptr, ".linear_var"));
+      llvm::Value *linearLoopBodyTemp = builder.CreateAlloca(
+          linearVarAlloca->getAllocatedType(), nullptr, ".linear_result");
+      linearOrigVal.push_back(moduleTranslation.lookupValue(linearVar));
+      linearLoopBodyTemps.push_back(linearLoopBodyTemp);
+      linearOrigVars.push_back(linearVarAlloca);
+    }
+  }
+
+  // Initialize linear step
+  inline void initLinearStep(LLVM::ModuleTranslation &moduleTranslation,
+                             mlir::Value &linearStep) {
+    linearSteps.push_back(moduleTranslation.lookupValue(linearStep));
+  }
+
+  // Emit IR for initialization of linear variables
+  llvm::OpenMPIRBuilder::InsertPointOrErrorTy
+  initLinearVar(llvm::IRBuilderBase &builder,
+                LLVM::ModuleTranslation &moduleTranslation,
+                llvm::BasicBlock *loopPreHeader) {
+    builder.SetInsertPoint(loopPreHeader->getTerminator());
+    for (size_t index = 0; index < linearOrigVars.size(); index++) {
+      llvm::LoadInst *linearVarLoad = builder.CreateLoad(
+          linearOrigVars[index]->getAllocatedType(), linearOrigVars[index]);
+      builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
+    }
+    llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
+        moduleTranslation.getOpenMPBuilder()->createBarrier(
+            builder.saveIP(), llvm::omp::OMPD_barrier);
+    return afterBarrierIP;
+  }
+
+  // Emit IR for updating Linear variables
+  void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
+                       llvm::Value *loopInductionVar) {
+    builder.SetInsertPoint(loopBody->getTerminator());
+    for (size_t index = 0; index < linearPreconditionVars.size(); index++) {
+      // Emit increments for linear vars
+      llvm::LoadInst *linearVarStart =
+          builder.CreateLoad(linearOrigVars[index]->getAllocatedType(),
+
+                             linearPreconditionVars[index]);
+      auto mulInst = builder.CreateMul(loopInductionVar, linearSteps[index]);
+      auto addInst = builder.CreateAdd(linearVarStart, mulInst);
+      builder.CreateStore(addInst, linearLoopBodyTemps[index]);
+    }
+  }
+
+  // Linear variable finalization is conditional on the last logical iteration.
+  // Create BB splits to manage the same.
+  void outlineLinearFinalizationBB(llvm::IRBuilderBase &builder,
+                                   llvm::BasicBlock *loopExit) {
+    linearFinalizationBB = loopExit->splitBasicBlock(
+        loopExit->getTerminator(), "omp_loop.linear_finalization");
+    linearExitBB = linearFinalizationBB->splitBasicBlock(
+        linearFinalizationBB->getTerminator(), "omp_loop.linear_exit");
+    linearLastIterExitBB = linearFinalizationBB->splitBasicBlock(
+        linearFinalizationBB->getTerminator(), "omp_loop.linear_lastiter_exit");
+  }
+
+  // Finalize the linear vars
+  llvm::OpenMPIRBuilder::InsertPointOrErrorTy
+  finalizeLinearVar(llvm::IRBuilderBase &builder,
+                    LLVM::ModuleTranslation &moduleTranslation,
+                    llvm::Value *lastIter) {
+    // Emit condition to check whether last logical iteration is being executed
+    builder.SetInsertPoint(linearFinalizationBB->getTerminator());
+    llvm::Value *loopLastIterLoad = builder.CreateLoad(
+        llvm::Type::getInt32Ty(builder.getContext()), lastIter);
+    llvm::Value *isLast =
+        builder.CreateCmp(llvm::CmpInst::ICMP_NE, loopLastIterLoad,
+                          llvm::ConstantInt::get(
+                              llvm::Type::getInt32Ty(builder.getContext()), 0));
+    // Store the linear variable values to original variables.
+    builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
+    for (size_t index = 0; index < linearOrigVars.size(); index++) {
+      llvm::LoadInst *linearVarTemp =
+          builder.CreateLoad(linearOrigVars[index]->getAllocatedType(),
+                             linearLoopBodyTemps[index]);
+      builder.CreateStore(linearVarTemp, linearOrigVars[index]);
+    }
+
+    // Create conditional branch such that the linear variable
+    // values are stored to original variables only at the
+    // last logical iteration
+    builder.SetInsertPoint(linearFinalizationBB->getTerminator());
+    builder.CreateCondBr(isLast, linearLastIterExitBB, linearExitBB);
+    linearFinalizationBB->getTerminator()->eraseFromParent();
+    // Emit barrier
+    builder.SetInsertPoint(linearExitBB->getTerminator());
+    return moduleTranslation.getOpenMPBuilder()->createBarrier(
+        builder.saveIP(), llvm::omp::OMPD_barrier);
+  }
+
+  // Rewrite all uses of the original variable in `BBName`
+  //  with the linear variable in-place
+  void rewriteInPlace(llvm::IRBuilderBase &builder, std::string BBName,
+                      size_t varIndex) {
+    llvm::SmallVector<llvm::User *> users;
+    for (llvm::User *user : linearOrigVal[varIndex]->users())
+      users.push_back(user);
+    for (auto *user : users) {
+      if (auto *userInst = dyn_cast<llvm::Instruction>(user)) {
+        if (userInst->getParent()->getName().str() == BBName)
+          user->replaceUsesOfWith(linearOrigVal[varIndex],
+                                  linearLoopBodyTemps[varIndex]);
+      }
+    }
+  }
+};
+
 } // namespace
 
 /// Looks up from the operation from and returns the PrivateClauseOp with
@@ -292,7 +432,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
       })
       .Case([&](omp::WsloopOp op) {
         checkAllocate(op, result);
-        checkLinear(op, result);
         checkOrder(op, result);
         checkReduction(op, result);
       })
@@ -2423,15 +2562,40 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
                            llvm::omp::Directive::OMPD_for);
 
   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+
+  // Initialize linear variables and linear step
+  LinearClauseProcessor linearClauseProcessor;
+  if (wsloopOp.getLinearVars().size()) {
+    for (mlir::Value linearVar : wsloopOp.getLinearVars())
+      linearClauseProcessor.createLinearVar(builder, moduleTranslation,
+                                            linearVar);
+    for (mlir::Value linearStep : wsloopOp.getLinearStepVars())
+      linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
+  }
+
   llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
       wsloopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation);
 
   if (failed(handleError(regionBlock, opInst)))
     return failure();
 
-  builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
   llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
 
+  // Emit Initialization and Update IR for linear variables
+  if (wsloopOp.getLinearVars().size()) {
+    llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
+        linearClauseProcessor.initLinearVar(builder, moduleTranslation,
+                                            loopInfo->getPreheader());
+    if (failed(handleError(afterBarrierIP, *loopOp)))
+      return failure();
+    builder.restoreIP(*afterBarrierIP);
+    linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
+                                          loopInfo->getIndVar());
+    linearClauseProcessor.outlineLinearFinalizationBB(builder,
+                                                      loopInfo->getExit());
+  }
+
+  builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
   llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
       ompBuilder->applyWorkshareLoop(
           ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
@@ -2443,6 +2607,23 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
   if (failed(handleError(wsloopIP, opInst)))
     return failure();
 
+  // Emit finalization and in-place rewrites for linear vars.
+  if (wsloopOp.getLinearVars().size()) {
+    llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
+    assert(loopInfo->getLastIter() &&
+           "`lastiter` in CanonicalLoopInfo is nullptr"...
[truncated]

@NimishMishra
Copy link
Contributor Author

The following test case from OpenMP examples document compiles successfully with a combination of PR #139385 and this PR:

program linear_loop
 use omp_lib
 implicit none
 integer, parameter :: N = 100
 real :: a(N), b(N/2)
 integer :: i, j

 do i = 1, N
 a(i) = i
 end do

 j = 0
 !$omp parallel
 !$omp do linear(j:1)
 do i = 1, N, 2
 j = j + 1
 b(j) = a(i) * 2.0
 end do
 !$omp end parallel

 print *, j, b(1), b(j)
 ! print out: 50 2.0 198.0

 end program

Flang output: 50 2. 198. as expected.

@NimishMishra
Copy link
Contributor Author

This PR is stacked over #139385 to allow for easy testing. I intend to merge this PR only after #139385 is merged.

Copy link
Contributor

@tblah tblah left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the progress so far

// Emit condition to check whether last logical iteration is being executed
builder.SetInsertPoint(linearFinalizationBB->getTerminator());
llvm::Value *loopLastIterLoad = builder.CreateLoad(
llvm::Type::getInt32Ty(builder.getContext()), lastIter);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this always i32?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am thinking of this from the perspective of checking whether the current iteration is the last iteration of the loop or not. Now that you mention it, I am not sure whether p.lastiter in canonical loop bodygen is a bool (i.e. a flag denoting whether this is the last iteration or not) or an integer (i.e. holding end - 1, where end is the loop end bound). It is better to have this load match the datatype of its counterpart in canonical loop bodygen

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked the workshare loop allocas, %p.lastiter is alloca i32, align 4. So we can use a load i32 too. Would that be ok?


// Initialize linear variables and linear step
LinearClauseProcessor linearClauseProcessor;
if (wsloopOp.getLinearVars().size()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: This if is unnecessary. The for loops will not execute if there are no linear vars.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


// Linear variable finalization is conditional on the last logical iteration.
// Create BB splits to manage the same.
void outlineLinearFinalizationBB(llvm::IRBuilderBase &builder,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For me at least, this isn't what I expect from the term "outline". For me, "outline" would mean moving blocks into a different function. Perhaps a better name would be splitLinearFiniBlock.

Feel free to ignore if others disagree.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

split is fine; I will change. Thanks

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Copy link
Contributor Author

@NimishMishra NimishMishra left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @tblah for the review and apologies for taking a while to circle back on this.

I will wait for your response on the floating point support of linear steps; rest changes look fine to me, I'll address them.


// Linear variable finalization is conditional on the last logical iteration.
// Create BB splits to manage the same.
void outlineLinearFinalizationBB(llvm::IRBuilderBase &builder,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

split is fine; I will change. Thanks

// Emit condition to check whether last logical iteration is being executed
builder.SetInsertPoint(linearFinalizationBB->getTerminator());
llvm::Value *loopLastIterLoad = builder.CreateLoad(
llvm::Type::getInt32Ty(builder.getContext()), lastIter);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am thinking of this from the perspective of checking whether the current iteration is the last iteration of the loop or not. Now that you mention it, I am not sure whether p.lastiter in canonical loop bodygen is a bool (i.e. a flag denoting whether this is the last iteration or not) or an integer (i.e. holding end - 1, where end is the loop end bound). It is better to have this load match the datatype of its counterpart in canonical loop bodygen

@NimishMishra NimishMishra force-pushed the linear_clause_translation branch from b783aa2 to 39ca840 Compare July 9, 2025 12:49
Copy link
Contributor Author

@NimishMishra NimishMishra left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @tblah for the comments. I have changed the implementation to address your comments as well as support linear clause on simd.

Please note that I am planning to raise another PR for implicit linearization in simd. So I am not planning to merge this PR unless the that other PR gets reviewed and accepted.


// Linear variable finalization is conditional on the last logical iteration.
// Create BB splits to manage the same.
void outlineLinearFinalizationBB(llvm::IRBuilderBase &builder,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

// Emit condition to check whether last logical iteration is being executed
builder.SetInsertPoint(linearFinalizationBB->getTerminator());
llvm::Value *loopLastIterLoad = builder.CreateLoad(
llvm::Type::getInt32Ty(builder.getContext()), lastIter);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked the workshare loop allocas, %p.lastiter is alloca i32, align 4. So we can use a load i32 too. Would that be ok?


// Initialize linear variables and linear step
LinearClauseProcessor linearClauseProcessor;
if (wsloopOp.getLinearVars().size()) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

@NimishMishra NimishMishra changed the title [llvm][mlir][OpenMP] Support translation for linear clause in omp.wsloop [llvm][mlir][OpenMP] Support translation for linear clause in omp.wsloop and omp.simd Jul 9, 2025
@@ -2608,7 +2608,7 @@ void SimdOp::build(OpBuilder &builder, OperationState &state,
// TODO Store clauses in op: linearVars, linearStepVars
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// TODO Store clauses in op: linearVars, linearStepVars

@@ -1816,8 +1816,7 @@ static void genSimdClauses(
cp.processReduction(loc, clauseOps, reductionSyms);
cp.processSafelen(clauseOps);
cp.processSimdlen(clauseOps);

cp.processTODO<clause::Linear>(loc, llvm::omp::Directive::OMPD_simd);
cp.processLinear(clauseOps);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current implementation of processLinear does not seem to guarantee that the variable comes from an alloca. For example what if it is a function argument?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I ran some tests, and the semantic checks were able to catch those cases. For instance, if the linear variable is anything other than an integer type, we get a semantic error noting that a linear variable with the REF modifier needs to be of integer type.

Do you think we should nevertheless add a check during FIR gen? I added a check during the translation but skipped adding it during FIR gen because of this reason.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

subroutine test(lin, lb, ub, step)
  integer :: lin, lb, ub, step, i

  !$omp do linear(lin)
  do i = lb,ub,step
    l = l + 2
  enddo
endsubroutine

Comment on lines +424 to +430
%8 = llvm.mlir.constant(2 : i32) : i32
%9 = llvm.mlir.constant(10 : i32) : i32
%10 = llvm.mlir.constant(1 : i32) : i32
%11 = llvm.mlir.constant(1 : i64) : i64
%12 = llvm.mlir.constant(1 : i64) : i64
%13 = llvm.mlir.constant(1 : i64) : i64
%14 = llvm.mlir.constant(1 : i64) : i64
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: these could be trivially cleaned up with something like mlir-opt --canonicalize test_file.mlir

// CHECK: omp_loop.exit:
// CHECK: call void @__kmpc_for_static_fini(ptr @2, i32 %omp_global_thread_num4)
// CHECK: %omp_global_thread_num5 = call i32 @__kmpc_global_thread_num(ptr @2)
// CHECK: call void @__kmpc_barrier(ptr @3, i32 %omp_global_thread_num5)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need both this barrier and the one after processing the linear variables.

// CHECK: br label %omp_loop.header

// CHECK: omp_loop.body:
// CHECK: %[[LINEAR_LOAD:.*]] = load i32, ptr %[[LINEAR_VAR]], align 4
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please could you test that these loads and stores do get the llvm access group metadata

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang:openmp OpenMP related changes to Clang flang:fir-hlfir flang:openmp flang Flang issues not falling into any other category mlir:llvm mlir:openmp mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants