-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[llvm][mlir][OpenMP] Support translation for linear clause in omp.wsloop and omp.simd #139386
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-flang-openmp Author: None (NimishMishra) Changes: This patch adds support for LLVM translation of the linear clause on omp.wsloop (except for linear modifiers). The patch is 25.69 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/139386.diff 10 files affected:
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 79b5087e4da68..8ba2f604df80a 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -1060,6 +1060,40 @@ bool ClauseProcessor::processIsDevicePtr(
});
}
+/// Lowers every `linear` clause into \p result: the address of each listed
+/// object is appended to `linearVars`, and one step value per object is
+/// appended to `linearStepVars`. Returns the result of
+/// `findRepeatableClause`.
+bool ClauseProcessor::processLinear(mlir::omp::LinearClauseOps &result) const {
+  lower::StatementContext stmtCtx;
+  auto lowerLinearClause = [&](const omp::clause::Linear &clause,
+                               const parser::CharBlock &) {
+    using StepModifier = omp::clause::Linear::StepComplexModifier;
+    using LinearModifier = omp::clause::Linear::LinearModifier;
+
+    auto &objects = std::get<omp::ObjectList>(clause.t);
+    for (const omp::Object &object : objects) {
+      semantics::Symbol *sym = object.sym();
+      result.linearVars.push_back(converter.getSymbolAddress(*sym));
+    }
+    // Without objects there is nothing to attach a step to.
+    if (!objects.size())
+      return;
+
+    if (auto &step = std::get<std::optional<StepModifier>>(clause.t)) {
+      // Explicit step: evaluate the expression once and reuse the value for
+      // every object of this clause.
+      mlir::Value stepValue =
+          fir::getBase(converter.genExprValue(toEvExpr(*step), stmtCtx));
+      result.linearStepVars.append(objects.size(), stepValue);
+      return;
+    }
+
+    if (std::get<std::optional<LinearModifier>>(clause.t)) {
+      mlir::Location currentLocation = converter.getCurrentLocation();
+      TODO(currentLocation, "Linear modifiers not yet implemented");
+      return;
+    }
+
+    // If nothing is present, add the default step of 1.
+    fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+    mlir::Location currentLocation = converter.getCurrentLocation();
+    mlir::Value unitStep = firOpBuilder.createIntegerConstant(
+        currentLocation, firOpBuilder.getI32Type(), 1);
+    result.linearStepVars.append(objects.size(), unitStep);
+  };
+  return findRepeatableClause<omp::clause::Linear>(lowerLinearClause);
+}
+
bool ClauseProcessor::processLink(
llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
return findRepeatableClause<omp::clause::Link>(
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 7857ba3fd0845..0ec41bdd33256 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -122,6 +122,7 @@ class ClauseProcessor {
bool processIsDevicePtr(
mlir::omp::IsDevicePtrClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
+ bool processLinear(mlir::omp::LinearClauseOps &result) const;
bool
processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
index 7eec598645eac..2a1c94407e1c8 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
@@ -213,14 +213,15 @@ void DataSharingProcessor::collectSymbolsForPrivatization() {
// so, we won't need to explicitely handle block objects (or forget to do
// so).
for (auto *sym : explicitlyPrivatizedSymbols)
- allPrivatizedSymbols.insert(sym);
+ if (!sym->test(Fortran::semantics::Symbol::Flag::OmpLinear))
+ allPrivatizedSymbols.insert(sym);
}
bool DataSharingProcessor::needBarrier() {
// Emit implicit barrier to synchronize threads and avoid data races on
// initialization of firstprivate variables and post-update of lastprivate
// variables.
- // Emit implicit barrier for linear clause. Maybe on somewhere else.
+ // Emit implicit barrier for linear clause in the OpenMPIRBuilder.
for (const semantics::Symbol *sym : allPrivatizedSymbols) {
if (sym->test(semantics::Symbol::Flag::OmpLastPrivate) &&
(sym->test(semantics::Symbol::Flag::OmpFirstPrivate) ||
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 54560729eb4af..6fa915b4364f9 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1841,13 +1841,13 @@ static void genWsloopClauses(
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processNowait(clauseOps);
+ cp.processLinear(clauseOps);
cp.processOrder(clauseOps);
cp.processOrdered(clauseOps);
cp.processReduction(loc, clauseOps, reductionSyms);
cp.processSchedule(stmtCtx, clauseOps);
- cp.processTODO<clause::Allocate, clause::Linear>(
- loc, llvm::omp::Directive::OMPD_do);
+ cp.processTODO<clause::Allocate>(loc, llvm::omp::Directive::OMPD_do);
}
//===----------------------------------------------------------------------===//
diff --git a/flang/test/Lower/OpenMP/wsloop-linear.f90 b/flang/test/Lower/OpenMP/wsloop-linear.f90
new file mode 100644
index 0000000000000..b99677108be2f
--- /dev/null
+++ b/flang/test/Lower/OpenMP/wsloop-linear.f90
@@ -0,0 +1,57 @@
+! This test checks lowering of OpenMP DO Directive (Worksharing)
+! with linear clause
+
+! RUN: %flang_fc1 -fopenmp -emit-hlfir %s -o - 2>&1 | FileCheck %s
+
+!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFsimple_linearEx"}
+!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFsimple_linearEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[const:.*]] = arith.constant 1 : i32
+subroutine simple_linear ! linear(x) without a step: the checks expect the i32 constant 1 as the step operand
+ implicit none
+ integer :: x, y, i
+ !CHECK: omp.wsloop linear(%[[X]]#0 = %[[const]] : !fir.ref<i32>) {{.*}}
+ !$omp do linear(x)
+ !CHECK: %[[LOAD:.*]] = fir.load %[[X]]#0 : !fir.ref<i32>
+ !CHECK: %[[const:.*]] = arith.constant 2 : i32
+ !CHECK: %[[RESULT:.*]] = arith.addi %[[LOAD]], %[[const]] : i32
+ do i = 1, 10
+ y = x + 2
+ end do
+ !$omp end do
+end subroutine
+
+
+!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFlinear_stepEx"}
+!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFlinear_stepEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+subroutine linear_step ! linear(x:4): the checks expect the i32 constant 4 as the step operand
+ implicit none
+ integer :: x, y, i
+ !CHECK: %[[const:.*]] = arith.constant 4 : i32
+ !CHECK: omp.wsloop linear(%[[X]]#0 = %[[const]] : !fir.ref<i32>) {{.*}}
+ !$omp do linear(x:4)
+ !CHECK: %[[LOAD:.*]] = fir.load %[[X]]#0 : !fir.ref<i32>
+ !CHECK: %[[const:.*]] = arith.constant 2 : i32
+ !CHECK: %[[RESULT:.*]] = arith.addi %[[LOAD]], %[[const]] : i32
+ do i = 1, 10
+ y = x + 2
+ end do
+ !$omp end do
+end subroutine
+
+!CHECK: %[[A_alloca:.*]] = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFlinear_exprEa"}
+!CHECK: %[[A:.*]]:2 = hlfir.declare %[[A_alloca]] {uniq_name = "_QFlinear_exprEa"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFlinear_exprEx"}
+!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFlinear_exprEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+subroutine linear_expr ! linear(x:a+4): the checks expect the step operand to be the arith.addi of the loaded a and the constant 4
+ implicit none
+ integer :: x, y, i, a
+ !CHECK: %[[LOAD_A:.*]] = fir.load %[[A]]#0 : !fir.ref<i32>
+ !CHECK: %[[const:.*]] = arith.constant 4 : i32
+ !CHECK: %[[LINEAR_EXPR:.*]] = arith.addi %[[LOAD_A]], %[[const]] : i32
+ !CHECK: omp.wsloop linear(%[[X]]#0 = %[[LINEAR_EXPR]] : !fir.ref<i32>) {{.*}}
+ !$omp do linear(x:a+4)
+ do i = 1, 10
+ y = x + 2
+ end do
+ !$omp end do
+end subroutine
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index ffc0fd0a0bdac..68f15d5c7d41e 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -3580,6 +3580,9 @@ class CanonicalLoopInfo {
BasicBlock *Latch = nullptr;
BasicBlock *Exit = nullptr;
+ // LLVM IR value holding the `lastiter` of the canonical loop; stays nullptr until `setLastIter` is called.
+ Value *LastIter = nullptr;
+
/// Add the control blocks of this loop to \p BBs.
///
/// This does not include any block from the body, including the one returned
@@ -3612,6 +3615,18 @@ class CanonicalLoopInfo {
void mapIndVar(llvm::function_ref<Value *(Instruction *)> Updater);
public:
+ /// Sets the last iteration variable for this loop.
+ ///
+ /// \param IterVar value later returned by `getLastIter`; the pointer is
+ /// stored as-is.
+ void setLastIter(Value *IterVar) { LastIter = IterVar; }
+
+ /// Returns the last iteration variable for this loop.
+ /// Certain use-cases (like translation of linear clause) may access
+ /// this variable even after a loop transformation. Hence, do not guard
+ /// this getter function by `isValid`. It is the responsibility of the
+ /// caller to ensure this is not invoked on a CanonicalLoopInfo object for
+ /// which `setLastIter` was never called (in which case `LastIter` still
+ /// has its default value, `nullptr`).
+ Value *getLastIter() const { return LastIter; }
+
/// Returns whether this object currently represents the IR of a loop. If
/// returning false, it may have been consumed by a loop transformation or not
/// been intialized. Do not use in this case;
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index a1268ca76b2d5..991cdb7b6b416 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -4254,6 +4254,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyStaticWorkshareLoop(
Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");
+ CLI->setLastIter(PLastIter);
// At the end of the preheader, prepare for calling the "init" function by
// storing the current loop bounds into the allocated space. A canonical loop
@@ -4361,6 +4362,7 @@ OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(DebugLoc DL,
Value *PUpperBound =
Builder.CreateAlloca(InternalIVTy, nullptr, "p.upperbound");
Value *PStride = Builder.CreateAlloca(InternalIVTy, nullptr, "p.stride");
+ CLI->setLastIter(PLastIter);
// Set up the source location value for the OpenMP runtime.
Builder.restoreIP(CLI->getPreheaderIP());
@@ -4844,6 +4846,7 @@ OpenMPIRBuilder::applyDynamicWorkshareLoop(DebugLoc DL, CanonicalLoopInfo *CLI,
Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");
+ CLI->setLastIter(PLastIter);
// At the end of the preheader, prepare for calling the "init" function by
// storing the current loop bounds into the allocated space. A canonical loop
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 9f7b5605556e6..571505ab9b9aa 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -124,6 +124,146 @@ class PreviouslyReportedError
char PreviouslyReportedError::ID = 0;
+/// Helper for translating the linear clause of omp.wsloop and omp.simd.
+/// Linear clause translation requires setup, initialization, update, and
+/// finalization at varying basic blocks in the IR. This class keeps the
+/// state shared by those stages so that each of them sees consistent values.
+///
+/// The per-variable vectors are index-aligned: entry `i` of every vector
+/// describes the same linear variable.
+class LinearClauseProcessor {
+
+private:
+  // Allocas (".linear_var") that receive the value of each linear variable
+  // in the loop preheader.
+  SmallVector<llvm::Value *> linearPreconditionVars;
+  // Allocas (".linear_result") that receive the value computed in the body.
+  SmallVector<llvm::Value *> linearLoopBodyTemps;
+  // The original variables, as allocas.
+  SmallVector<llvm::AllocaInst *> linearOrigVars;
+  // The original variables, as plain values (used when rewriting uses).
+  SmallVector<llvm::Value *> linearOrigVal;
+  // Step of each registered linear variable.
+  SmallVector<llvm::Value *> linearSteps;
+  // One flag per `createLinearVar` call: true when the variable was skipped
+  // because it is not an alloca. Used to drop the matching step as well.
+  SmallVector<bool> linearVarSkipped;
+  // Number of `initLinearStep` calls seen so far.
+  size_t numStepsSeen = 0;
+  // Blocks created by `outlineLinearFinalizationBB`; null until then.
+  llvm::BasicBlock *linearFinalizationBB = nullptr;
+  llvm::BasicBlock *linearExitBB = nullptr;
+  llvm::BasicBlock *linearLastIterExitBB = nullptr;
+
+public:
+  // Allocate space for linear variables. Variables whose translated value is
+  // not an alloca are skipped.
+  void createLinearVar(llvm::IRBuilderBase &builder,
+                       LLVM::ModuleTranslation &moduleTranslation,
+                       mlir::Value &linearVar) {
+    llvm::Value *origVal = moduleTranslation.lookupValue(linearVar);
+    auto *linearVarAlloca = dyn_cast<llvm::AllocaInst>(origVal);
+    linearVarSkipped.push_back(linearVarAlloca == nullptr);
+    if (!linearVarAlloca)
+      return;
+    linearPreconditionVars.push_back(builder.CreateAlloca(
+        linearVarAlloca->getAllocatedType(), nullptr, ".linear_var"));
+    llvm::Value *linearLoopBodyTemp = builder.CreateAlloca(
+        linearVarAlloca->getAllocatedType(), nullptr, ".linear_result");
+    linearOrigVal.push_back(origVal);
+    linearLoopBodyTemps.push_back(linearLoopBodyTemp);
+    linearOrigVars.push_back(linearVarAlloca);
+  }
+
+  // Initialize linear step. Steps are expected in the same order as the
+  // variables passed to `createLinearVar`; the step of a skipped variable is
+  // dropped so that `linearSteps` stays index-aligned with the variables.
+  inline void initLinearStep(LLVM::ModuleTranslation &moduleTranslation,
+                             mlir::Value &linearStep) {
+    const size_t varIndex = numStepsSeen++;
+    if (varIndex < linearVarSkipped.size() && linearVarSkipped[varIndex])
+      return;
+    linearSteps.push_back(moduleTranslation.lookupValue(linearStep));
+  }
+
+  // Emit IR for initialization of linear variables: copy every original
+  // variable into its ".linear_var" slot at the end of `loopPreHeader`, then
+  // emit a barrier. Returns the result of `createBarrier`.
+  llvm::OpenMPIRBuilder::InsertPointOrErrorTy
+  initLinearVar(llvm::IRBuilderBase &builder,
+                LLVM::ModuleTranslation &moduleTranslation,
+                llvm::BasicBlock *loopPreHeader) {
+    builder.SetInsertPoint(loopPreHeader->getTerminator());
+    for (size_t index = 0; index < linearOrigVars.size(); index++) {
+      llvm::LoadInst *linearVarLoad = builder.CreateLoad(
+          linearOrigVars[index]->getAllocatedType(), linearOrigVars[index]);
+      builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
+    }
+    llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
+        moduleTranslation.getOpenMPBuilder()->createBarrier(
+            builder.saveIP(), llvm::omp::OMPD_barrier);
+    return afterBarrierIP;
+  }
+
+  // Emit IR for updating Linear variables: at the end of `loopBody`, store
+  // `start + loopInductionVar * step` into each ".linear_result" slot.
+  void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
+                       llvm::Value *loopInductionVar) {
+    builder.SetInsertPoint(loopBody->getTerminator());
+    for (size_t index = 0; index < linearPreconditionVars.size(); index++) {
+      // Emit increments for linear vars
+      llvm::LoadInst *linearVarStart =
+          builder.CreateLoad(linearOrigVars[index]->getAllocatedType(),
+                             linearPreconditionVars[index]);
+      llvm::Value *mulInst =
+          builder.CreateMul(loopInductionVar, linearSteps[index]);
+      llvm::Value *addInst = builder.CreateAdd(linearVarStart, mulInst);
+      builder.CreateStore(addInst, linearLoopBodyTemps[index]);
+    }
+  }
+
+  // Linear variable finalization is conditional on the last logical iteration.
+  // Create BB splits to manage the same.
+  void outlineLinearFinalizationBB(llvm::IRBuilderBase &builder,
+                                   llvm::BasicBlock *loopExit) {
+    linearFinalizationBB = loopExit->splitBasicBlock(
+        loopExit->getTerminator(), "omp_loop.linear_finalization");
+    linearExitBB = linearFinalizationBB->splitBasicBlock(
+        linearFinalizationBB->getTerminator(), "omp_loop.linear_exit");
+    linearLastIterExitBB = linearFinalizationBB->splitBasicBlock(
+        linearFinalizationBB->getTerminator(), "omp_loop.linear_lastiter_exit");
+  }
+
+  // Finalize the linear vars. `lastIter` is loaded as an i32; when it is
+  // non-zero the ".linear_result" values are copied back to the original
+  // variables. Must be called after `outlineLinearFinalizationBB`. Returns
+  // the result of the trailing `createBarrier`.
+  llvm::OpenMPIRBuilder::InsertPointOrErrorTy
+  finalizeLinearVar(llvm::IRBuilderBase &builder,
+                    LLVM::ModuleTranslation &moduleTranslation,
+                    llvm::Value *lastIter) {
+    assert(linearFinalizationBB && linearExitBB && linearLastIterExitBB &&
+           "outlineLinearFinalizationBB must be called first");
+    // Emit condition to check whether last logical iteration is being executed
+    builder.SetInsertPoint(linearFinalizationBB->getTerminator());
+    llvm::Value *loopLastIterLoad = builder.CreateLoad(
+        llvm::Type::getInt32Ty(builder.getContext()), lastIter);
+    llvm::Value *isLast =
+        builder.CreateCmp(llvm::CmpInst::ICMP_NE, loopLastIterLoad,
+                          llvm::ConstantInt::get(
+                              llvm::Type::getInt32Ty(builder.getContext()), 0));
+    // Store the linear variable values to original variables.
+    builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
+    for (size_t index = 0; index < linearOrigVars.size(); index++) {
+      llvm::LoadInst *linearVarTemp =
+          builder.CreateLoad(linearOrigVars[index]->getAllocatedType(),
+                             linearLoopBodyTemps[index]);
+      builder.CreateStore(linearVarTemp, linearOrigVars[index]);
+    }
+
+    // Create conditional branch such that the linear variable
+    // values are stored to original variables only at the
+    // last logical iteration
+    builder.SetInsertPoint(linearFinalizationBB->getTerminator());
+    builder.CreateCondBr(isLast, linearLastIterExitBB, linearExitBB);
+    // The conditional branch was inserted before the old terminator, which
+    // is still the last instruction of the block; remove it.
+    linearFinalizationBB->getTerminator()->eraseFromParent();
+    // Emit barrier
+    builder.SetInsertPoint(linearExitBB->getTerminator());
+    return moduleTranslation.getOpenMPBuilder()->createBarrier(
+        builder.saveIP(), llvm::omp::OMPD_barrier);
+  }
+
+  // Rewrite all uses of the original variable in `BBName`
+  // with the linear variable in-place
+  void rewriteInPlace(llvm::IRBuilderBase &builder, std::string BBName,
+                      size_t varIndex) {
+    llvm::Value *origVal = linearOrigVal[varIndex];
+    // Snapshot the users first: replaceUsesOfWith changes the use list that
+    // `users()` iterates over.
+    llvm::SmallVector<llvm::User *> users;
+    for (llvm::User *user : origVal->users())
+      users.push_back(user);
+    for (llvm::User *user : users) {
+      auto *userInst = dyn_cast<llvm::Instruction>(user);
+      if (!userInst)
+        continue;
+      // Compare the block name directly instead of building a std::string
+      // for every user.
+      if (userInst->getParent()->getName() == BBName)
+        user->replaceUsesOfWith(origVal, linearLoopBodyTemps[varIndex]);
+    }
+  }
+};
+
} // namespace
/// Looks up from the operation from and returns the PrivateClauseOp with
@@ -292,7 +432,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
})
.Case([&](omp::WsloopOp op) {
checkAllocate(op, result);
- checkLinear(op, result);
checkOrder(op, result);
checkReduction(op, result);
})
@@ -2423,15 +2562,40 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
llvm::omp::Directive::OMPD_for);
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+
+ // Initialize linear variables and linear step
+ LinearClauseProcessor linearClauseProcessor;
+ if (wsloopOp.getLinearVars().size()) {
+ for (mlir::Value linearVar : wsloopOp.getLinearVars())
+ linearClauseProcessor.createLinearVar(builder, moduleTranslation,
+ linearVar);
+ for (mlir::Value linearStep : wsloopOp.getLinearStepVars())
+ linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
+ }
+
llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
wsloopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation);
if (failed(handleError(regionBlock, opInst)))
return failure();
- builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
+ // Emit Initialization and Update IR for linear variables
+ if (wsloopOp.getLinearVars().size()) {
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
+ linearClauseProcessor.initLinearVar(builder, moduleTranslation,
+ loopInfo->getPreheader());
+ if (failed(handleError(afterBarrierIP, *loopOp)))
+ return failure();
+ builder.restoreIP(*afterBarrierIP);
+ linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
+ loopInfo->getIndVar());
+ linearClauseProcessor.outlineLinearFinalizationBB(builder,
+ loopInfo->getExit());
+ }
+
+ builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
ompBuilder->applyWorkshareLoop(
ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
@@ -2443,6 +2607,23 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
if (failed(handleError(wsloopIP, opInst)))
return failure();
+ // Emit finalization and in-place rewrites for linear vars.
+ if (wsloopOp.getLinearVars().size()) {
+ llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
+ assert(loopInfo->getLastIter() &&
+ "`lastiter` in CanonicalLoopInfo is nullptr"...
[truncated]
|
@llvm/pr-subscribers-mlir-llvm Author: None (NimishMishra) Changes: This patch adds support for LLVM translation of the linear clause on omp.wsloop (except for linear modifiers). The patch is 25.69 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/139386.diff 10 files affected:
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 79b5087e4da68..8ba2f604df80a 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -1060,6 +1060,40 @@ bool ClauseProcessor::processIsDevicePtr(
});
}
+/// Lowers every `linear` clause into \p result: the address of each listed
+/// object is appended to `linearVars`, and one step value per object is
+/// appended to `linearStepVars`. Returns the result of
+/// `findRepeatableClause`.
+bool ClauseProcessor::processLinear(mlir::omp::LinearClauseOps &result) const {
+  lower::StatementContext stmtCtx;
+  auto lowerLinearClause = [&](const omp::clause::Linear &clause,
+                               const parser::CharBlock &) {
+    using StepModifier = omp::clause::Linear::StepComplexModifier;
+    using LinearModifier = omp::clause::Linear::LinearModifier;
+
+    auto &objects = std::get<omp::ObjectList>(clause.t);
+    for (const omp::Object &object : objects) {
+      semantics::Symbol *sym = object.sym();
+      result.linearVars.push_back(converter.getSymbolAddress(*sym));
+    }
+    // Without objects there is nothing to attach a step to.
+    if (!objects.size())
+      return;
+
+    if (auto &step = std::get<std::optional<StepModifier>>(clause.t)) {
+      // Explicit step: evaluate the expression once and reuse the value for
+      // every object of this clause.
+      mlir::Value stepValue =
+          fir::getBase(converter.genExprValue(toEvExpr(*step), stmtCtx));
+      result.linearStepVars.append(objects.size(), stepValue);
+      return;
+    }
+
+    if (std::get<std::optional<LinearModifier>>(clause.t)) {
+      mlir::Location currentLocation = converter.getCurrentLocation();
+      TODO(currentLocation, "Linear modifiers not yet implemented");
+      return;
+    }
+
+    // If nothing is present, add the default step of 1.
+    fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+    mlir::Location currentLocation = converter.getCurrentLocation();
+    mlir::Value unitStep = firOpBuilder.createIntegerConstant(
+        currentLocation, firOpBuilder.getI32Type(), 1);
+    result.linearStepVars.append(objects.size(), unitStep);
+  };
+  return findRepeatableClause<omp::clause::Linear>(lowerLinearClause);
+}
+
bool ClauseProcessor::processLink(
llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
return findRepeatableClause<omp::clause::Link>(
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 7857ba3fd0845..0ec41bdd33256 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -122,6 +122,7 @@ class ClauseProcessor {
bool processIsDevicePtr(
mlir::omp::IsDevicePtrClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
+ bool processLinear(mlir::omp::LinearClauseOps &result) const;
bool
processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
index 7eec598645eac..2a1c94407e1c8 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
@@ -213,14 +213,15 @@ void DataSharingProcessor::collectSymbolsForPrivatization() {
// so, we won't need to explicitely handle block objects (or forget to do
// so).
for (auto *sym : explicitlyPrivatizedSymbols)
- allPrivatizedSymbols.insert(sym);
+ if (!sym->test(Fortran::semantics::Symbol::Flag::OmpLinear))
+ allPrivatizedSymbols.insert(sym);
}
bool DataSharingProcessor::needBarrier() {
// Emit implicit barrier to synchronize threads and avoid data races on
// initialization of firstprivate variables and post-update of lastprivate
// variables.
- // Emit implicit barrier for linear clause. Maybe on somewhere else.
+ // Emit implicit barrier for linear clause in the OpenMPIRBuilder.
for (const semantics::Symbol *sym : allPrivatizedSymbols) {
if (sym->test(semantics::Symbol::Flag::OmpLastPrivate) &&
(sym->test(semantics::Symbol::Flag::OmpFirstPrivate) ||
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 54560729eb4af..6fa915b4364f9 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1841,13 +1841,13 @@ static void genWsloopClauses(
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processNowait(clauseOps);
+ cp.processLinear(clauseOps);
cp.processOrder(clauseOps);
cp.processOrdered(clauseOps);
cp.processReduction(loc, clauseOps, reductionSyms);
cp.processSchedule(stmtCtx, clauseOps);
- cp.processTODO<clause::Allocate, clause::Linear>(
- loc, llvm::omp::Directive::OMPD_do);
+ cp.processTODO<clause::Allocate>(loc, llvm::omp::Directive::OMPD_do);
}
//===----------------------------------------------------------------------===//
diff --git a/flang/test/Lower/OpenMP/wsloop-linear.f90 b/flang/test/Lower/OpenMP/wsloop-linear.f90
new file mode 100644
index 0000000000000..b99677108be2f
--- /dev/null
+++ b/flang/test/Lower/OpenMP/wsloop-linear.f90
@@ -0,0 +1,57 @@
+! This test checks lowering of OpenMP DO Directive (Worksharing)
+! with linear clause
+
+! RUN: %flang_fc1 -fopenmp -emit-hlfir %s -o - 2>&1 | FileCheck %s
+
+!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFsimple_linearEx"}
+!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFsimple_linearEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[const:.*]] = arith.constant 1 : i32
+subroutine simple_linear ! linear(x) without a step: the checks expect the i32 constant 1 as the step operand
+ implicit none
+ integer :: x, y, i
+ !CHECK: omp.wsloop linear(%[[X]]#0 = %[[const]] : !fir.ref<i32>) {{.*}}
+ !$omp do linear(x)
+ !CHECK: %[[LOAD:.*]] = fir.load %[[X]]#0 : !fir.ref<i32>
+ !CHECK: %[[const:.*]] = arith.constant 2 : i32
+ !CHECK: %[[RESULT:.*]] = arith.addi %[[LOAD]], %[[const]] : i32
+ do i = 1, 10
+ y = x + 2
+ end do
+ !$omp end do
+end subroutine
+
+
+!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFlinear_stepEx"}
+!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFlinear_stepEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+subroutine linear_step ! linear(x:4): the checks expect the i32 constant 4 as the step operand
+ implicit none
+ integer :: x, y, i
+ !CHECK: %[[const:.*]] = arith.constant 4 : i32
+ !CHECK: omp.wsloop linear(%[[X]]#0 = %[[const]] : !fir.ref<i32>) {{.*}}
+ !$omp do linear(x:4)
+ !CHECK: %[[LOAD:.*]] = fir.load %[[X]]#0 : !fir.ref<i32>
+ !CHECK: %[[const:.*]] = arith.constant 2 : i32
+ !CHECK: %[[RESULT:.*]] = arith.addi %[[LOAD]], %[[const]] : i32
+ do i = 1, 10
+ y = x + 2
+ end do
+ !$omp end do
+end subroutine
+
+!CHECK: %[[A_alloca:.*]] = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFlinear_exprEa"}
+!CHECK: %[[A:.*]]:2 = hlfir.declare %[[A_alloca]] {uniq_name = "_QFlinear_exprEa"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFlinear_exprEx"}
+!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFlinear_exprEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+subroutine linear_expr ! linear(x:a+4): the checks expect the step operand to be the arith.addi of the loaded a and the constant 4
+ implicit none
+ integer :: x, y, i, a
+ !CHECK: %[[LOAD_A:.*]] = fir.load %[[A]]#0 : !fir.ref<i32>
+ !CHECK: %[[const:.*]] = arith.constant 4 : i32
+ !CHECK: %[[LINEAR_EXPR:.*]] = arith.addi %[[LOAD_A]], %[[const]] : i32
+ !CHECK: omp.wsloop linear(%[[X]]#0 = %[[LINEAR_EXPR]] : !fir.ref<i32>) {{.*}}
+ !$omp do linear(x:a+4)
+ do i = 1, 10
+ y = x + 2
+ end do
+ !$omp end do
+end subroutine
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index ffc0fd0a0bdac..68f15d5c7d41e 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -3580,6 +3580,9 @@ class CanonicalLoopInfo {
BasicBlock *Latch = nullptr;
BasicBlock *Exit = nullptr;
+ // LLVM IR value holding the `lastiter` of the canonical loop; stays nullptr until `setLastIter` is called.
+ Value *LastIter = nullptr;
+
/// Add the control blocks of this loop to \p BBs.
///
/// This does not include any block from the body, including the one returned
@@ -3612,6 +3615,18 @@ class CanonicalLoopInfo {
void mapIndVar(llvm::function_ref<Value *(Instruction *)> Updater);
public:
+ /// Sets the last iteration variable for this loop.
+ ///
+ /// \param IterVar value later returned by `getLastIter`; the pointer is
+ /// stored as-is.
+ void setLastIter(Value *IterVar) { LastIter = IterVar; }
+
+ /// Returns the last iteration variable for this loop.
+ /// Certain use-cases (like translation of linear clause) may access
+ /// this variable even after a loop transformation. Hence, do not guard
+ /// this getter function by `isValid`. It is the responsibility of the
+ /// caller to ensure this is not invoked on a CanonicalLoopInfo object for
+ /// which `setLastIter` was never called (in which case `LastIter` still
+ /// has its default value, `nullptr`).
+ Value *getLastIter() const { return LastIter; }
+
/// Returns whether this object currently represents the IR of a loop. If
/// returning false, it may have been consumed by a loop transformation or not
/// been intialized. Do not use in this case;
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index a1268ca76b2d5..991cdb7b6b416 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -4254,6 +4254,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyStaticWorkshareLoop(
Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");
+ CLI->setLastIter(PLastIter);
// At the end of the preheader, prepare for calling the "init" function by
// storing the current loop bounds into the allocated space. A canonical loop
@@ -4361,6 +4362,7 @@ OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(DebugLoc DL,
Value *PUpperBound =
Builder.CreateAlloca(InternalIVTy, nullptr, "p.upperbound");
Value *PStride = Builder.CreateAlloca(InternalIVTy, nullptr, "p.stride");
+ CLI->setLastIter(PLastIter);
// Set up the source location value for the OpenMP runtime.
Builder.restoreIP(CLI->getPreheaderIP());
@@ -4844,6 +4846,7 @@ OpenMPIRBuilder::applyDynamicWorkshareLoop(DebugLoc DL, CanonicalLoopInfo *CLI,
Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");
+ CLI->setLastIter(PLastIter);
// At the end of the preheader, prepare for calling the "init" function by
// storing the current loop bounds into the allocated space. A canonical loop
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 9f7b5605556e6..571505ab9b9aa 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -124,6 +124,146 @@ class PreviouslyReportedError
char PreviouslyReportedError::ID = 0;
+/*
+ * Custom class for processing linear clause for omp.wsloop
+ * and omp.simd. Linear clause translation requires setup,
+ * initialization, update, and finalization at varying
+ * basic blocks in the IR. This class helps maintain
+ * internal state to allow consistent translation in
+ * each of these stages.
+ */
+
+class LinearClauseProcessor {
+
+private:
+ SmallVector<llvm::Value *> linearPreconditionVars;
+ SmallVector<llvm::Value *> linearLoopBodyTemps;
+ SmallVector<llvm::AllocaInst *> linearOrigVars;
+ SmallVector<llvm::Value *> linearOrigVal;
+ SmallVector<llvm::Value *> linearSteps;
+ llvm::BasicBlock *linearFinalizationBB;
+ llvm::BasicBlock *linearExitBB;
+ llvm::BasicBlock *linearLastIterExitBB;
+
+public:
+ // Allocate space for linear variables
+ void createLinearVar(llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation,
+ mlir::Value &linearVar) {
+ if (llvm::AllocaInst *linearVarAlloca = dyn_cast<llvm::AllocaInst>(
+ moduleTranslation.lookupValue(linearVar))) {
+ linearPreconditionVars.push_back(builder.CreateAlloca(
+ linearVarAlloca->getAllocatedType(), nullptr, ".linear_var"));
+ llvm::Value *linearLoopBodyTemp = builder.CreateAlloca(
+ linearVarAlloca->getAllocatedType(), nullptr, ".linear_result");
+ linearOrigVal.push_back(moduleTranslation.lookupValue(linearVar));
+ linearLoopBodyTemps.push_back(linearLoopBodyTemp);
+ linearOrigVars.push_back(linearVarAlloca);
+ }
+ }
+
+ // Initialize linear step
+ inline void initLinearStep(LLVM::ModuleTranslation &moduleTranslation,
+ mlir::Value &linearStep) {
+ linearSteps.push_back(moduleTranslation.lookupValue(linearStep));
+ }
+
+ // Emit IR for initialization of linear variables
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy
+ initLinearVar(llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation,
+ llvm::BasicBlock *loopPreHeader) {
+ builder.SetInsertPoint(loopPreHeader->getTerminator());
+ for (size_t index = 0; index < linearOrigVars.size(); index++) {
+ llvm::LoadInst *linearVarLoad = builder.CreateLoad(
+ linearOrigVars[index]->getAllocatedType(), linearOrigVars[index]);
+ builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
+ }
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
+ moduleTranslation.getOpenMPBuilder()->createBarrier(
+ builder.saveIP(), llvm::omp::OMPD_barrier);
+ return afterBarrierIP;
+ }
+
+ // Emit IR for updating Linear variables
+ void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
+ llvm::Value *loopInductionVar) {
+ builder.SetInsertPoint(loopBody->getTerminator());
+ for (size_t index = 0; index < linearPreconditionVars.size(); index++) {
+ // Emit increments for linear vars
+ llvm::LoadInst *linearVarStart =
+ builder.CreateLoad(linearOrigVars[index]->getAllocatedType(),
+
+ linearPreconditionVars[index]);
+ auto mulInst = builder.CreateMul(loopInductionVar, linearSteps[index]);
+ auto addInst = builder.CreateAdd(linearVarStart, mulInst);
+ builder.CreateStore(addInst, linearLoopBodyTemps[index]);
+ }
+ }
+
+ // Linear variable finalization is conditional on the last logical iteration.
+ // Create BB splits to manage the same.
+ void outlineLinearFinalizationBB(llvm::IRBuilderBase &builder,
+ llvm::BasicBlock *loopExit) {
+ linearFinalizationBB = loopExit->splitBasicBlock(
+ loopExit->getTerminator(), "omp_loop.linear_finalization");
+ linearExitBB = linearFinalizationBB->splitBasicBlock(
+ linearFinalizationBB->getTerminator(), "omp_loop.linear_exit");
+ linearLastIterExitBB = linearFinalizationBB->splitBasicBlock(
+ linearFinalizationBB->getTerminator(), "omp_loop.linear_lastiter_exit");
+ }
+
+ // Finalize the linear vars
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy
+ finalizeLinearVar(llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation,
+ llvm::Value *lastIter) {
+ // Emit condition to check whether last logical iteration is being executed
+ builder.SetInsertPoint(linearFinalizationBB->getTerminator());
+ llvm::Value *loopLastIterLoad = builder.CreateLoad(
+ llvm::Type::getInt32Ty(builder.getContext()), lastIter);
+ llvm::Value *isLast =
+ builder.CreateCmp(llvm::CmpInst::ICMP_NE, loopLastIterLoad,
+ llvm::ConstantInt::get(
+ llvm::Type::getInt32Ty(builder.getContext()), 0));
+ // Store the linear variable values to original variables.
+ builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
+ for (size_t index = 0; index < linearOrigVars.size(); index++) {
+ llvm::LoadInst *linearVarTemp =
+ builder.CreateLoad(linearOrigVars[index]->getAllocatedType(),
+ linearLoopBodyTemps[index]);
+ builder.CreateStore(linearVarTemp, linearOrigVars[index]);
+ }
+
+ // Create conditional branch such that the linear variable
+ // values are stored to original variables only at the
+ // last logical iteration
+ builder.SetInsertPoint(linearFinalizationBB->getTerminator());
+ builder.CreateCondBr(isLast, linearLastIterExitBB, linearExitBB);
+ linearFinalizationBB->getTerminator()->eraseFromParent();
+ // Emit barrier
+ builder.SetInsertPoint(linearExitBB->getTerminator());
+ return moduleTranslation.getOpenMPBuilder()->createBarrier(
+ builder.saveIP(), llvm::omp::OMPD_barrier);
+ }
+
+ // Rewrite all uses of the original variable in `BBName`
+ // with the linear variable in-place
+ void rewriteInPlace(llvm::IRBuilderBase &builder, std::string BBName,
+ size_t varIndex) {
+ llvm::SmallVector<llvm::User *> users;
+ for (llvm::User *user : linearOrigVal[varIndex]->users())
+ users.push_back(user);
+ for (auto *user : users) {
+ if (auto *userInst = dyn_cast<llvm::Instruction>(user)) {
+ if (userInst->getParent()->getName().str() == BBName)
+ user->replaceUsesOfWith(linearOrigVal[varIndex],
+ linearLoopBodyTemps[varIndex]);
+ }
+ }
+ }
+};
+
} // namespace
/// Looks up from the operation from and returns the PrivateClauseOp with
@@ -292,7 +432,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
})
.Case([&](omp::WsloopOp op) {
checkAllocate(op, result);
- checkLinear(op, result);
checkOrder(op, result);
checkReduction(op, result);
})
@@ -2423,15 +2562,40 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
llvm::omp::Directive::OMPD_for);
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+
+ // Initialize linear variables and linear step
+ LinearClauseProcessor linearClauseProcessor;
+ if (wsloopOp.getLinearVars().size()) {
+ for (mlir::Value linearVar : wsloopOp.getLinearVars())
+ linearClauseProcessor.createLinearVar(builder, moduleTranslation,
+ linearVar);
+ for (mlir::Value linearStep : wsloopOp.getLinearStepVars())
+ linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
+ }
+
llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
wsloopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation);
if (failed(handleError(regionBlock, opInst)))
return failure();
- builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
+ // Emit Initialization and Update IR for linear variables
+ if (wsloopOp.getLinearVars().size()) {
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
+ linearClauseProcessor.initLinearVar(builder, moduleTranslation,
+ loopInfo->getPreheader());
+ if (failed(handleError(afterBarrierIP, *loopOp)))
+ return failure();
+ builder.restoreIP(*afterBarrierIP);
+ linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
+ loopInfo->getIndVar());
+ linearClauseProcessor.outlineLinearFinalizationBB(builder,
+ loopInfo->getExit());
+ }
+
+ builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
ompBuilder->applyWorkshareLoop(
ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
@@ -2443,6 +2607,23 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
if (failed(handleError(wsloopIP, opInst)))
return failure();
+ // Emit finalization and in-place rewrites for linear vars.
+ if (wsloopOp.getLinearVars().size()) {
+ llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
+ assert(loopInfo->getLastIter() &&
+ "`lastiter` in CanonicalLoopInfo is nullptr"...
[truncated]
|
@llvm/pr-subscribers-flang-fir-hlfir Author: None (NimishMishra) ChangesThis patch adds support for LLVM translation of linear clause on omp.wsloop (except for linear modifiers). Patch is 25.69 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/139386.diff 10 Files Affected:
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 79b5087e4da68..8ba2f604df80a 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -1060,6 +1060,40 @@ bool ClauseProcessor::processIsDevicePtr(
});
}
+bool ClauseProcessor::processLinear(mlir::omp::LinearClauseOps &result) const {
+ lower::StatementContext stmtCtx;
+ return findRepeatableClause<
+ omp::clause::Linear>([&](const omp::clause::Linear &clause,
+ const parser::CharBlock &) {
+ auto &objects = std::get<omp::ObjectList>(clause.t);
+ for (const omp::Object &object : objects) {
+ semantics::Symbol *sym = object.sym();
+ const mlir::Value variable = converter.getSymbolAddress(*sym);
+ result.linearVars.push_back(variable);
+ }
+ if (objects.size()) {
+ if (auto &mod =
+ std::get<std::optional<omp::clause::Linear::StepComplexModifier>>(
+ clause.t)) {
+ mlir::Value operand =
+ fir::getBase(converter.genExprValue(toEvExpr(*mod), stmtCtx));
+ result.linearStepVars.append(objects.size(), operand);
+ } else if (std::get<std::optional<omp::clause::Linear::LinearModifier>>(
+ clause.t)) {
+ mlir::Location currentLocation = converter.getCurrentLocation();
+ TODO(currentLocation, "Linear modifiers not yet implemented");
+ } else {
+ // If nothing is present, add the default step of 1.
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+ mlir::Location currentLocation = converter.getCurrentLocation();
+ mlir::Value operand = firOpBuilder.createIntegerConstant(
+ currentLocation, firOpBuilder.getI32Type(), 1);
+ result.linearStepVars.append(objects.size(), operand);
+ }
+ }
+ });
+}
+
bool ClauseProcessor::processLink(
llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
return findRepeatableClause<omp::clause::Link>(
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 7857ba3fd0845..0ec41bdd33256 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -122,6 +122,7 @@ class ClauseProcessor {
bool processIsDevicePtr(
mlir::omp::IsDevicePtrClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
+ bool processLinear(mlir::omp::LinearClauseOps &result) const;
bool
processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
index 7eec598645eac..2a1c94407e1c8 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
@@ -213,14 +213,15 @@ void DataSharingProcessor::collectSymbolsForPrivatization() {
// so, we won't need to explicitely handle block objects (or forget to do
// so).
for (auto *sym : explicitlyPrivatizedSymbols)
- allPrivatizedSymbols.insert(sym);
+ if (!sym->test(Fortran::semantics::Symbol::Flag::OmpLinear))
+ allPrivatizedSymbols.insert(sym);
}
bool DataSharingProcessor::needBarrier() {
// Emit implicit barrier to synchronize threads and avoid data races on
// initialization of firstprivate variables and post-update of lastprivate
// variables.
- // Emit implicit barrier for linear clause. Maybe on somewhere else.
+ // Emit implicit barrier for linear clause in the OpenMPIRBuilder.
for (const semantics::Symbol *sym : allPrivatizedSymbols) {
if (sym->test(semantics::Symbol::Flag::OmpLastPrivate) &&
(sym->test(semantics::Symbol::Flag::OmpFirstPrivate) ||
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 54560729eb4af..6fa915b4364f9 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1841,13 +1841,13 @@ static void genWsloopClauses(
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processNowait(clauseOps);
+ cp.processLinear(clauseOps);
cp.processOrder(clauseOps);
cp.processOrdered(clauseOps);
cp.processReduction(loc, clauseOps, reductionSyms);
cp.processSchedule(stmtCtx, clauseOps);
- cp.processTODO<clause::Allocate, clause::Linear>(
- loc, llvm::omp::Directive::OMPD_do);
+ cp.processTODO<clause::Allocate>(loc, llvm::omp::Directive::OMPD_do);
}
//===----------------------------------------------------------------------===//
diff --git a/flang/test/Lower/OpenMP/wsloop-linear.f90 b/flang/test/Lower/OpenMP/wsloop-linear.f90
new file mode 100644
index 0000000000000..b99677108be2f
--- /dev/null
+++ b/flang/test/Lower/OpenMP/wsloop-linear.f90
@@ -0,0 +1,57 @@
+! This test checks lowering of OpenMP DO Directive (Worksharing)
+! with linear clause
+
+! RUN: %flang_fc1 -fopenmp -emit-hlfir %s -o - 2>&1 | FileCheck %s
+
+!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFsimple_linearEx"}
+!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFsimple_linearEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[const:.*]] = arith.constant 1 : i32
+subroutine simple_linear
+ implicit none
+ integer :: x, y, i
+ !CHECK: omp.wsloop linear(%[[X]]#0 = %[[const]] : !fir.ref<i32>) {{.*}}
+ !$omp do linear(x)
+ !CHECK: %[[LOAD:.*]] = fir.load %[[X]]#0 : !fir.ref<i32>
+ !CHECK: %[[const:.*]] = arith.constant 2 : i32
+ !CHECK: %[[RESULT:.*]] = arith.addi %[[LOAD]], %[[const]] : i32
+ do i = 1, 10
+ y = x + 2
+ end do
+ !$omp end do
+end subroutine
+
+
+!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFlinear_stepEx"}
+!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFlinear_stepEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+subroutine linear_step
+ implicit none
+ integer :: x, y, i
+ !CHECK: %[[const:.*]] = arith.constant 4 : i32
+ !CHECK: omp.wsloop linear(%[[X]]#0 = %[[const]] : !fir.ref<i32>) {{.*}}
+ !$omp do linear(x:4)
+ !CHECK: %[[LOAD:.*]] = fir.load %[[X]]#0 : !fir.ref<i32>
+ !CHECK: %[[const:.*]] = arith.constant 2 : i32
+ !CHECK: %[[RESULT:.*]] = arith.addi %[[LOAD]], %[[const]] : i32
+ do i = 1, 10
+ y = x + 2
+ end do
+ !$omp end do
+end subroutine
+
+!CHECK: %[[A_alloca:.*]] = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFlinear_exprEa"}
+!CHECK: %[[A:.*]]:2 = hlfir.declare %[[A_alloca]] {uniq_name = "_QFlinear_exprEa"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFlinear_exprEx"}
+!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFlinear_exprEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+subroutine linear_expr
+ implicit none
+ integer :: x, y, i, a
+ !CHECK: %[[LOAD_A:.*]] = fir.load %[[A]]#0 : !fir.ref<i32>
+ !CHECK: %[[const:.*]] = arith.constant 4 : i32
+ !CHECK: %[[LINEAR_EXPR:.*]] = arith.addi %[[LOAD_A]], %[[const]] : i32
+ !CHECK: omp.wsloop linear(%[[X]]#0 = %[[LINEAR_EXPR]] : !fir.ref<i32>) {{.*}}
+ !$omp do linear(x:a+4)
+ do i = 1, 10
+ y = x + 2
+ end do
+ !$omp end do
+end subroutine
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index ffc0fd0a0bdac..68f15d5c7d41e 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -3580,6 +3580,9 @@ class CanonicalLoopInfo {
BasicBlock *Latch = nullptr;
BasicBlock *Exit = nullptr;
+ // Hold the LLVM IR value for the `lastiter` of the canonical loop.
+ Value *LastIter = nullptr;
+
/// Add the control blocks of this loop to \p BBs.
///
/// This does not include any block from the body, including the one returned
@@ -3612,6 +3615,18 @@ class CanonicalLoopInfo {
void mapIndVar(llvm::function_ref<Value *(Instruction *)> Updater);
public:
+ /// Sets the last iteration variable for this loop.
+ void setLastIter(Value *IterVar) { LastIter = std::move(IterVar); }
+
+ /// Returns the last iteration variable for this loop.
+ /// Certain use-cases (like translation of linear clause) may access
+ /// this variable even after a loop transformation. Hence, do not guard
+ /// this getter function by `isValid`. It is the responsibility of the
+ /// caller to ensure this functionality is not invoked by a non-outlined
+ /// CanonicalLoopInfo object (in which case, `setLastIter` will never be
+ /// invoked and `LastIter` will be by default `nullptr`).
+ Value *getLastIter() { return LastIter; }
+
/// Returns whether this object currently represents the IR of a loop. If
/// returning false, it may have been consumed by a loop transformation or not
/// been intialized. Do not use in this case;
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index a1268ca76b2d5..991cdb7b6b416 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -4254,6 +4254,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyStaticWorkshareLoop(
Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");
+ CLI->setLastIter(PLastIter);
// At the end of the preheader, prepare for calling the "init" function by
// storing the current loop bounds into the allocated space. A canonical loop
@@ -4361,6 +4362,7 @@ OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(DebugLoc DL,
Value *PUpperBound =
Builder.CreateAlloca(InternalIVTy, nullptr, "p.upperbound");
Value *PStride = Builder.CreateAlloca(InternalIVTy, nullptr, "p.stride");
+ CLI->setLastIter(PLastIter);
// Set up the source location value for the OpenMP runtime.
Builder.restoreIP(CLI->getPreheaderIP());
@@ -4844,6 +4846,7 @@ OpenMPIRBuilder::applyDynamicWorkshareLoop(DebugLoc DL, CanonicalLoopInfo *CLI,
Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");
+ CLI->setLastIter(PLastIter);
// At the end of the preheader, prepare for calling the "init" function by
// storing the current loop bounds into the allocated space. A canonical loop
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 9f7b5605556e6..571505ab9b9aa 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -124,6 +124,146 @@ class PreviouslyReportedError
char PreviouslyReportedError::ID = 0;
+/*
+ * Custom class for processing linear clause for omp.wsloop
+ * and omp.simd. Linear clause translation requires setup,
+ * initialization, update, and finalization at varying
+ * basic blocks in the IR. This class helps maintain
+ * internal state to allow consistent translation in
+ * each of these stages.
+ */
+
+class LinearClauseProcessor {
+
+private:
+ SmallVector<llvm::Value *> linearPreconditionVars;
+ SmallVector<llvm::Value *> linearLoopBodyTemps;
+ SmallVector<llvm::AllocaInst *> linearOrigVars;
+ SmallVector<llvm::Value *> linearOrigVal;
+ SmallVector<llvm::Value *> linearSteps;
+ llvm::BasicBlock *linearFinalizationBB;
+ llvm::BasicBlock *linearExitBB;
+ llvm::BasicBlock *linearLastIterExitBB;
+
+public:
+ // Allocate space for linear variables
+ void createLinearVar(llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation,
+ mlir::Value &linearVar) {
+ if (llvm::AllocaInst *linearVarAlloca = dyn_cast<llvm::AllocaInst>(
+ moduleTranslation.lookupValue(linearVar))) {
+ linearPreconditionVars.push_back(builder.CreateAlloca(
+ linearVarAlloca->getAllocatedType(), nullptr, ".linear_var"));
+ llvm::Value *linearLoopBodyTemp = builder.CreateAlloca(
+ linearVarAlloca->getAllocatedType(), nullptr, ".linear_result");
+ linearOrigVal.push_back(moduleTranslation.lookupValue(linearVar));
+ linearLoopBodyTemps.push_back(linearLoopBodyTemp);
+ linearOrigVars.push_back(linearVarAlloca);
+ }
+ }
+
+ // Initialize linear step
+ inline void initLinearStep(LLVM::ModuleTranslation &moduleTranslation,
+ mlir::Value &linearStep) {
+ linearSteps.push_back(moduleTranslation.lookupValue(linearStep));
+ }
+
+ // Emit IR for initialization of linear variables
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy
+ initLinearVar(llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation,
+ llvm::BasicBlock *loopPreHeader) {
+ builder.SetInsertPoint(loopPreHeader->getTerminator());
+ for (size_t index = 0; index < linearOrigVars.size(); index++) {
+ llvm::LoadInst *linearVarLoad = builder.CreateLoad(
+ linearOrigVars[index]->getAllocatedType(), linearOrigVars[index]);
+ builder.CreateStore(linearVarLoad, linearPreconditionVars[index]);
+ }
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
+ moduleTranslation.getOpenMPBuilder()->createBarrier(
+ builder.saveIP(), llvm::omp::OMPD_barrier);
+ return afterBarrierIP;
+ }
+
+ // Emit IR for updating Linear variables
+ void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
+ llvm::Value *loopInductionVar) {
+ builder.SetInsertPoint(loopBody->getTerminator());
+ for (size_t index = 0; index < linearPreconditionVars.size(); index++) {
+ // Emit increments for linear vars
+ llvm::LoadInst *linearVarStart =
+ builder.CreateLoad(linearOrigVars[index]->getAllocatedType(),
+
+ linearPreconditionVars[index]);
+ auto mulInst = builder.CreateMul(loopInductionVar, linearSteps[index]);
+ auto addInst = builder.CreateAdd(linearVarStart, mulInst);
+ builder.CreateStore(addInst, linearLoopBodyTemps[index]);
+ }
+ }
+
+ // Linear variable finalization is conditional on the last logical iteration.
+ // Create BB splits to manage the same.
+ void outlineLinearFinalizationBB(llvm::IRBuilderBase &builder,
+ llvm::BasicBlock *loopExit) {
+ linearFinalizationBB = loopExit->splitBasicBlock(
+ loopExit->getTerminator(), "omp_loop.linear_finalization");
+ linearExitBB = linearFinalizationBB->splitBasicBlock(
+ linearFinalizationBB->getTerminator(), "omp_loop.linear_exit");
+ linearLastIterExitBB = linearFinalizationBB->splitBasicBlock(
+ linearFinalizationBB->getTerminator(), "omp_loop.linear_lastiter_exit");
+ }
+
+ // Finalize the linear vars
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy
+ finalizeLinearVar(llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation,
+ llvm::Value *lastIter) {
+ // Emit condition to check whether last logical iteration is being executed
+ builder.SetInsertPoint(linearFinalizationBB->getTerminator());
+ llvm::Value *loopLastIterLoad = builder.CreateLoad(
+ llvm::Type::getInt32Ty(builder.getContext()), lastIter);
+ llvm::Value *isLast =
+ builder.CreateCmp(llvm::CmpInst::ICMP_NE, loopLastIterLoad,
+ llvm::ConstantInt::get(
+ llvm::Type::getInt32Ty(builder.getContext()), 0));
+ // Store the linear variable values to original variables.
+ builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
+ for (size_t index = 0; index < linearOrigVars.size(); index++) {
+ llvm::LoadInst *linearVarTemp =
+ builder.CreateLoad(linearOrigVars[index]->getAllocatedType(),
+ linearLoopBodyTemps[index]);
+ builder.CreateStore(linearVarTemp, linearOrigVars[index]);
+ }
+
+ // Create conditional branch such that the linear variable
+ // values are stored to original variables only at the
+ // last logical iteration
+ builder.SetInsertPoint(linearFinalizationBB->getTerminator());
+ builder.CreateCondBr(isLast, linearLastIterExitBB, linearExitBB);
+ linearFinalizationBB->getTerminator()->eraseFromParent();
+ // Emit barrier
+ builder.SetInsertPoint(linearExitBB->getTerminator());
+ return moduleTranslation.getOpenMPBuilder()->createBarrier(
+ builder.saveIP(), llvm::omp::OMPD_barrier);
+ }
+
+ // Rewrite all uses of the original variable in `BBName`
+ // with the linear variable in-place
+ void rewriteInPlace(llvm::IRBuilderBase &builder, std::string BBName,
+ size_t varIndex) {
+ llvm::SmallVector<llvm::User *> users;
+ for (llvm::User *user : linearOrigVal[varIndex]->users())
+ users.push_back(user);
+ for (auto *user : users) {
+ if (auto *userInst = dyn_cast<llvm::Instruction>(user)) {
+ if (userInst->getParent()->getName().str() == BBName)
+ user->replaceUsesOfWith(linearOrigVal[varIndex],
+ linearLoopBodyTemps[varIndex]);
+ }
+ }
+ }
+};
+
} // namespace
/// Looks up from the operation from and returns the PrivateClauseOp with
@@ -292,7 +432,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
})
.Case([&](omp::WsloopOp op) {
checkAllocate(op, result);
- checkLinear(op, result);
checkOrder(op, result);
checkReduction(op, result);
})
@@ -2423,15 +2562,40 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
llvm::omp::Directive::OMPD_for);
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+
+ // Initialize linear variables and linear step
+ LinearClauseProcessor linearClauseProcessor;
+ if (wsloopOp.getLinearVars().size()) {
+ for (mlir::Value linearVar : wsloopOp.getLinearVars())
+ linearClauseProcessor.createLinearVar(builder, moduleTranslation,
+ linearVar);
+ for (mlir::Value linearStep : wsloopOp.getLinearStepVars())
+ linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
+ }
+
llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
wsloopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation);
if (failed(handleError(regionBlock, opInst)))
return failure();
- builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
+ // Emit Initialization and Update IR for linear variables
+ if (wsloopOp.getLinearVars().size()) {
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
+ linearClauseProcessor.initLinearVar(builder, moduleTranslation,
+ loopInfo->getPreheader());
+ if (failed(handleError(afterBarrierIP, *loopOp)))
+ return failure();
+ builder.restoreIP(*afterBarrierIP);
+ linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(),
+ loopInfo->getIndVar());
+ linearClauseProcessor.outlineLinearFinalizationBB(builder,
+ loopInfo->getExit());
+ }
+
+ builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
ompBuilder->applyWorkshareLoop(
ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
@@ -2443,6 +2607,23 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
if (failed(handleError(wsloopIP, opInst)))
return failure();
+ // Emit finalization and in-place rewrites for linear vars.
+ if (wsloopOp.getLinearVars().size()) {
+ llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
+ assert(loopInfo->getLastIter() &&
+ "`lastiter` in CanonicalLoopInfo is nullptr"...
[truncated]
|
The following test case from OpenMP examples document compiles successfully with a combination of PR #139385 and this PR:
Flang output: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the progress so far
// Emit condition to check whether last logical iteration is being executed | ||
builder.SetInsertPoint(linearFinalizationBB->getTerminator()); | ||
llvm::Value *loopLastIterLoad = builder.CreateLoad( | ||
llvm::Type::getInt32Ty(builder.getContext()), lastIter); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this always i32?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am thinking of this from the perspective of checking whether the current iteration is the last iteration of the loop or not. Now that you mention it, I am not sure whether `p.lastiter` in canonical loop bodygen is a `bool` (i.e. a flag denoting whether this is the last iteration or not) or an integer (i.e. holding `end` - 1, where `end` is the loop end bound). It is better to have this load match the datatype of its counterpart in canonical loop bodygen.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I checked the workshare loop allocas; %p.lastiter is `alloca i32, align 4`. So we can use a `load i32` too. Would that be ok?
|
||
// Initialize linear variables and linear step | ||
LinearClauseProcessor linearClauseProcessor; | ||
if (wsloopOp.getLinearVars().size()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: This if is unnecessary. The for loops will not execute if there are no linear vars.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
|
||
// Linear variable finalization is conditional on the last logical iteration. | ||
// Create BB splits to manage the same. | ||
void outlineLinearFinalizationBB(llvm::IRBuilderBase &builder, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For me at least, this isn't what I expect from the term "outline". For me, "outline" would mean moving blocks into a different function. Perhaps a better name would be `splitLinearFiniBlock`.
Feel free to ignore if others disagree.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`split` is fine; I will change. Thanks
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @tblah for the review and apologies for taking a while to circle back on this.
I will wait for your response on the floating point support of linear steps; the rest of the changes look fine to me, and I'll address them.
|
||
// Linear variable finalization is conditional on the last logical iteration. | ||
// Create BB splits to manage the same. | ||
void outlineLinearFinalizationBB(llvm::IRBuilderBase &builder, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`split` is fine; I will change. Thanks
// Emit condition to check whether last logical iteration is being executed | ||
builder.SetInsertPoint(linearFinalizationBB->getTerminator()); | ||
llvm::Value *loopLastIterLoad = builder.CreateLoad( | ||
llvm::Type::getInt32Ty(builder.getContext()), lastIter); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am thinking of this from the perspective of checking whether the current iteration is the last iteration of the loop or not. Now that you mention it, I am not sure whether `p.lastiter` in canonical loop bodygen is a `bool` (i.e. a flag denoting whether this is the last iteration or not) or an integer (i.e. holding `end` - 1, where `end` is the loop end bound). It is better to have this load match the datatype of its counterpart in canonical loop bodygen.
b783aa2
to
39ca840
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @tblah for the comments. I have changed the implementation to address your comments as well as support linear clause on simd.
Please note that I am planning to raise another PR for implicit linearization in simd. So I am not planning to merge this PR unless that other PR gets reviewed and accepted.
|
||
// Linear variable finalization is conditional on the last logical iteration. | ||
// Create BB splits to manage the same. | ||
void outlineLinearFinalizationBB(llvm::IRBuilderBase &builder, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
// Emit condition to check whether last logical iteration is being executed | ||
builder.SetInsertPoint(linearFinalizationBB->getTerminator()); | ||
llvm::Value *loopLastIterLoad = builder.CreateLoad( | ||
llvm::Type::getInt32Ty(builder.getContext()), lastIter); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I checked the workshare loop allocas; `%p.lastiter` is `alloca i32, align 4`. So we can use a `load i32` too. Would that be ok?
|
||
// Initialize linear variables and linear step | ||
LinearClauseProcessor linearClauseProcessor; | ||
if (wsloopOp.getLinearVars().size()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
@@ -2608,7 +2608,7 @@ void SimdOp::build(OpBuilder &builder, OperationState &state, | |||
// TODO Store clauses in op: linearVars, linearStepVars |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
// TODO Store clauses in op: linearVars, linearStepVars |
@@ -1816,8 +1816,7 @@ static void genSimdClauses( | |||
cp.processReduction(loc, clauseOps, reductionSyms); | |||
cp.processSafelen(clauseOps); | |||
cp.processSimdlen(clauseOps); | |||
|
|||
cp.processTODO<clause::Linear>(loc, llvm::omp::Directive::OMPD_simd); | |||
cp.processLinear(clauseOps); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The current implementation of `processLinear` does not seem to guarantee that the variable comes from an alloca. For example, what if it is a function argument?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I ran some tests, and the semantic checks were able to catch those cases. For instance, if the linear variable is anything other than an `integer` type, we get a semantic error noting that a linear variable with the REF modifier needs to be of integer type.
Do you think we should nevertheless add a check during FIR gen? I added a check during the translation but skipped adding it during FIR gen because of this reason.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
subroutine test(lin, lb, ub, step)
integer :: lin, lb, ub, step, i
!$omp do linear(lin)
do i = lb,ub,step
l = l + 2
enddo
endsubroutine
%8 = llvm.mlir.constant(2 : i32) : i32 | ||
%9 = llvm.mlir.constant(10 : i32) : i32 | ||
%10 = llvm.mlir.constant(1 : i32) : i32 | ||
%11 = llvm.mlir.constant(1 : i64) : i64 | ||
%12 = llvm.mlir.constant(1 : i64) : i64 | ||
%13 = llvm.mlir.constant(1 : i64) : i64 | ||
%14 = llvm.mlir.constant(1 : i64) : i64 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: these could be trivially cleaned up with something like mlir-opt --canonicalize test_file.mlir
// CHECK: omp_loop.exit: | ||
// CHECK: call void @__kmpc_for_static_fini(ptr @2, i32 %omp_global_thread_num4) | ||
// CHECK: %omp_global_thread_num5 = call i32 @__kmpc_global_thread_num(ptr @2) | ||
// CHECK: call void @__kmpc_barrier(ptr @3, i32 %omp_global_thread_num5) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we need both this barrier and the one after processing the linear variables.
// CHECK: br label %omp_loop.header | ||
|
||
// CHECK: omp_loop.body: | ||
// CHECK: %[[LINEAR_LOAD:.*]] = load i32, ptr %[[LINEAR_VAR]], align 4 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please could you test that these loads and stores do get the LLVM access group metadata?
This patch adds support for LLVM translation of linear clause on omp.wsloop (except for linear modifiers).