-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[Flang] Implement !$omp unroll using omp.unroll_heuristic #144785
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: users/meinersbur/flang_canonical-loop_ops-lowering
Are you sure you want to change the base?
Changes from all commits
fed2aa7
30a533e
a12dca3
4577333
2415e33
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -64,6 +64,28 @@ static void processHostEvalClauses(lower::AbstractConverter &converter, | |
lower::pft::Evaluation &eval, | ||
mlir::Location loc); | ||
|
||
static llvm::omp::Directive | ||
getOpenMPDirectiveEnum(const parser::OmpLoopDirective &beginStatment) { | ||
return beginStatment.v; | ||
} | ||
|
||
static llvm::omp::Directive getOpenMPDirectiveEnum( | ||
const parser::OmpBeginLoopDirective &beginLoopDirective) { | ||
return getOpenMPDirectiveEnum( | ||
std::get<parser::OmpLoopDirective>(beginLoopDirective.t)); | ||
} | ||
|
||
static llvm::omp::Directive | ||
getOpenMPDirectiveEnum(const parser::OpenMPLoopConstruct &ompLoopConstruct) { | ||
return getOpenMPDirectiveEnum( | ||
std::get<parser::OmpBeginLoopDirective>(ompLoopConstruct.t)); | ||
} | ||
|
||
static llvm::omp::Directive getOpenMPDirectiveEnum( | ||
const common::Indirection<parser::OpenMPLoopConstruct> &ompLoopConstruct) { | ||
return getOpenMPDirectiveEnum(ompLoopConstruct.value()); | ||
} | ||
|
||
namespace { | ||
/// Structure holding information that is needed to pass host-evaluated | ||
/// information to later lowering stages. | ||
|
@@ -2154,6 +2176,163 @@ genLoopOp(lower::AbstractConverter &converter, lower::SymMap &symTable, | |
return loopOp; | ||
} | ||
|
||
static mlir::omp::CanonicalLoopOp | ||
genCanonicalLoopOp(lower::AbstractConverter &converter, lower::SymMap &symTable, | ||
semantics::SemanticsContext &semaCtx, | ||
lower::pft::Evaluation &eval, mlir::Location loc, | ||
const ConstructQueue &queue, | ||
ConstructQueue::const_iterator item, | ||
llvm::ArrayRef<const semantics::Symbol *> ivs, | ||
llvm::omp::Directive directive, DataSharingProcessor &dsp) { | ||
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); | ||
|
||
assert(ivs.size() == 1 && "Nested loops not yet implemented"); | ||
const semantics::Symbol *iv = ivs[0]; | ||
|
||
auto &nestedEval = eval.getFirstNestedEvaluation(); | ||
if (nestedEval.getIf<parser::DoConstruct>()->IsDoConcurrent()) { | ||
// OpenMP specifies DO CONCURRENT only with the `!omp loop` construct. Will | ||
// need to add special cases for this combination. | ||
TODO(loc, "DO CONCURRENT as canonical loop not supported"); | ||
} | ||
|
||
// Get the loop bounds (and increment) | ||
auto &doLoopEval = nestedEval.getFirstNestedEvaluation(); | ||
auto *doStmt = doLoopEval.getIf<parser::NonLabelDoStmt>(); | ||
assert(doStmt && "Expected do loop to be in the nested evaluation"); | ||
auto &loopControl = std::get<std::optional<parser::LoopControl>>(doStmt->t); | ||
assert(loopControl.has_value()); | ||
auto *bounds = std::get_if<parser::LoopControl::Bounds>(&loopControl->u); | ||
assert(bounds && "Expected bounds for canonical loop"); | ||
lower::StatementContext stmtCtx; | ||
mlir::Value loopLBVar = fir::getBase( | ||
converter.genExprValue(*semantics::GetExpr(bounds->lower), stmtCtx)); | ||
mlir::Value loopUBVar = fir::getBase( | ||
converter.genExprValue(*semantics::GetExpr(bounds->upper), stmtCtx)); | ||
mlir::Value loopStepVar = [&]() { | ||
if (bounds->step) { | ||
return fir::getBase( | ||
converter.genExprValue(*semantics::GetExpr(bounds->step), stmtCtx)); | ||
} | ||
|
||
// If `step` is not present, assume it is `1`. | ||
return firOpBuilder.createIntegerConstant(loc, firOpBuilder.getI32Type(), | ||
1); | ||
}(); | ||
|
||
// Get the integer kind for the loop variable and cast the loop bounds | ||
size_t loopVarTypeSize = bounds->name.thing.symbol->GetUltimate().size(); | ||
mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize); | ||
loopLBVar = firOpBuilder.createConvert(loc, loopVarType, loopLBVar); | ||
loopUBVar = firOpBuilder.createConvert(loc, loopVarType, loopUBVar); | ||
loopStepVar = firOpBuilder.createConvert(loc, loopVarType, loopStepVar); | ||
|
||
// Start lowering | ||
mlir::Value zero = firOpBuilder.createIntegerConstant(loc, loopVarType, 0); | ||
mlir::Value one = firOpBuilder.createIntegerConstant(loc, loopVarType, 1); | ||
mlir::Value isDownwards = firOpBuilder.create<mlir::arith::CmpIOp>( | ||
loc, mlir::arith::CmpIPredicate::slt, loopStepVar, zero); | ||
|
||
// Ensure we are counting upwards. If not, negate step and swap lb and ub. | ||
mlir::Value negStep = | ||
firOpBuilder.create<mlir::arith::SubIOp>(loc, zero, loopStepVar); | ||
mlir::Value incr = firOpBuilder.create<mlir::arith::SelectOp>( | ||
loc, isDownwards, negStep, loopStepVar); | ||
mlir::Value lb = firOpBuilder.create<mlir::arith::SelectOp>( | ||
loc, isDownwards, loopUBVar, loopLBVar); | ||
mlir::Value ub = firOpBuilder.create<mlir::arith::SelectOp>( | ||
loc, isDownwards, loopLBVar, loopUBVar); | ||
|
||
// Compute the trip count assuming lb <= ub. This guarantees that the result | ||
// is non-negative and we can use unsigned arithmetic. | ||
mlir::Value span = firOpBuilder.create<mlir::arith::SubIOp>( | ||
loc, ub, lb, ::mlir::arith::IntegerOverflowFlags::nuw); | ||
mlir::Value tcMinusOne = | ||
firOpBuilder.create<mlir::arith::DivUIOp>(loc, span, incr); | ||
mlir::Value tcIfLooping = firOpBuilder.create<mlir::arith::AddIOp>( | ||
loc, tcMinusOne, one, ::mlir::arith::IntegerOverflowFlags::nuw); | ||
|
||
// Fall back to 0 if lb > ub | ||
mlir::Value isZeroTC = firOpBuilder.create<mlir::arith::CmpIOp>( | ||
loc, mlir::arith::CmpIPredicate::slt, ub, lb); | ||
mlir::Value tripcount = firOpBuilder.create<mlir::arith::SelectOp>( | ||
loc, isZeroTC, zero, tcIfLooping); | ||
|
||
// Create the CLI handle. | ||
auto newcli = firOpBuilder.create<mlir::omp::NewCliOp>(loc); | ||
mlir::Value cli = newcli.getResult(); | ||
|
||
auto ivCallback = [&](mlir::Operation *op) | ||
-> llvm::SmallVector<const Fortran::semantics::Symbol *> { | ||
mlir::Region ®ion = op->getRegion(0); | ||
|
||
// Create the op's region skeleton (BB taking the iv as argument) | ||
firOpBuilder.createBlock(®ion, {}, {loopVarType}, {loc}); | ||
|
||
// Compute the value of the loop variable from the logical iteration number. | ||
mlir::Value natIterNum = fir::getBase(region.front().getArgument(0)); | ||
mlir::Value scaled = | ||
firOpBuilder.create<mlir::arith::MulIOp>(loc, natIterNum, loopStepVar); | ||
mlir::Value userVal = | ||
firOpBuilder.create<mlir::arith::AddIOp>(loc, loopLBVar, scaled); | ||
|
||
// The argument is not currently in memory, so make a temporary for the | ||
// argument, and store it there, then bind that location to the argument. | ||
mlir::Operation *storeOp = | ||
createAndSetPrivatizedLoopVar(converter, loc, userVal, iv); | ||
|
||
firOpBuilder.setInsertionPointAfter(storeOp); | ||
return {iv}; | ||
}; | ||
|
||
// Create the omp.canonical_loop operation | ||
auto canonLoop = genOpWithBody<mlir::omp::CanonicalLoopOp>( | ||
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, nestedEval, | ||
directive) | ||
.setClauses(&item->clauses) | ||
.setDataSharingProcessor(&dsp) | ||
.setGenRegionEntryCb(ivCallback), | ||
queue, item, tripcount, cli); | ||
|
||
firOpBuilder.setInsertionPointAfter(canonLoop); | ||
return canonLoop; | ||
} | ||
|
||
static void genUnrollOp(Fortran::lower::AbstractConverter &converter, | ||
Fortran::lower::SymMap &symTable, | ||
lower::StatementContext &stmtCtx, | ||
Fortran::semantics::SemanticsContext &semaCtx, | ||
Fortran::lower::pft::Evaluation &eval, | ||
mlir::Location loc, const ConstructQueue &queue, | ||
ConstructQueue::const_iterator item) { | ||
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); | ||
|
||
mlir::omp::LoopRelatedClauseOps loopInfo; | ||
llvm::SmallVector<const semantics::Symbol *> iv; | ||
collectLoopRelatedInfo(converter, loc, eval, item->clauses, loopInfo, iv); | ||
Comment on lines
+2310
to
+2312
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this be moved to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The number of loops to collect info for depends on the construct (e.g. unroll: 1, tile: number of elements in the |
||
|
||
// Clauses for unrolling not yet implemnted | ||
ClauseProcessor cp(converter, semaCtx, item->clauses); | ||
cp.processTODO<clause::Partial, clause::Full>( | ||
loc, llvm::omp::Directive::OMPD_unroll); | ||
|
||
// Even though unroll does not support data-sharing clauses, but this is | ||
// required to fill the symbol table. | ||
DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval, | ||
/*shouldCollectPreDeterminedSymbols=*/true, | ||
/*useDelayedPrivatization=*/false, symTable); | ||
dsp.processStep1(); | ||
|
||
// Emit the associated loop | ||
auto canonLoop = | ||
genCanonicalLoopOp(converter, symTable, semaCtx, eval, loc, queue, item, | ||
iv, llvm::omp::Directive::OMPD_unroll, dsp); | ||
|
||
// Apply unrolling to it | ||
auto cli = canonLoop.getCli(); | ||
firOpBuilder.create<mlir::omp::UnrollHeuristicOp>(loc, cli); | ||
} | ||
|
||
static mlir::omp::MaskedOp | ||
genMaskedOp(lower::AbstractConverter &converter, lower::SymMap &symTable, | ||
lower::StatementContext &stmtCtx, | ||
|
@@ -3334,12 +3513,14 @@ static void genOMPDispatch(lower::AbstractConverter &converter, | |
newOp = genTeamsOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, | ||
item); | ||
break; | ||
case llvm::omp::Directive::OMPD_tile: | ||
case llvm::omp::Directive::OMPD_unroll: { | ||
case llvm::omp::Directive::OMPD_tile: { | ||
unsigned version = semaCtx.langOptions().OpenMPVersion; | ||
TODO(loc, "Unhandled loop directive (" + | ||
llvm::omp::getOpenMPDirectiveName(dir, version) + ")"); | ||
} | ||
case llvm::omp::Directive::OMPD_unroll: | ||
genUnrollOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item); | ||
break; | ||
tblah marked this conversation as resolved.
Show resolved
Hide resolved
|
||
// case llvm::omp::Directive::OMPD_workdistribute: | ||
case llvm::omp::Directive::OMPD_workshare: | ||
newOp = genWorkshareOp(converter, symTable, stmtCtx, semaCtx, eval, loc, | ||
|
@@ -3775,12 +3956,25 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable, | |
if (auto *ompNestedLoopCons{ | ||
std::get_if<common::Indirection<parser::OpenMPLoopConstruct>>( | ||
&*optLoopCons)}) { | ||
genOMP(converter, symTable, semaCtx, eval, ompNestedLoopCons->value()); | ||
llvm::omp::Directive nestedDirective = | ||
getOpenMPDirectiveEnum(*ompNestedLoopCons); | ||
switch (nestedDirective) { | ||
case llvm::omp::Directive::OMPD_tile: | ||
// Emit the omp.loop_nest with annotation for tiling | ||
genOMP(converter, symTable, semaCtx, eval, ompNestedLoopCons->value()); | ||
break; | ||
default: { | ||
unsigned version = semaCtx.langOptions().OpenMPVersion; | ||
TODO(currentLocation, | ||
"Applying a loop-associated on the loop generated by the " + | ||
llvm::omp::getOpenMPDirectiveName(nestedDirective, version) + | ||
" construct"); | ||
} | ||
} | ||
} | ||
} | ||
|
||
llvm::omp::Directive directive = | ||
std::get<parser::OmpLoopDirective>(beginLoopDirective.t).v; | ||
llvm::omp::Directive directive = getOpenMPDirectiveEnum(beginLoopDirective); | ||
const parser::CharBlock &source = | ||
std::get<parser::OmpLoopDirective>(beginLoopDirective.t).source; | ||
ConstructQueue queue{ | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=51 -o - %s 2>&1 | FileCheck %s | ||
|
||
|
||
subroutine omp_unroll_heuristic01(lb, ub, inc) | ||
integer res, i, lb, ub, inc | ||
|
||
!$omp unroll | ||
do i = lb, ub, inc | ||
res = i | ||
end do | ||
!$omp end unroll | ||
|
||
end subroutine omp_unroll_heuristic01 | ||
|
||
|
||
!CHECK-LABEL: func.func @_QPomp_unroll_heuristic01( | ||
!CHECK: %c0_i32 = arith.constant 0 : i32 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you capture SSA value names using descriptive names (e.g. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is not possible in MLIR: SSA value names are always automatically generated and cannot be user-provided. In this case there is a special handler for |
||
!CHECK-NEXT: %c1_i32 = arith.constant 1 : i32 | ||
!CHECK-NEXT: %13 = arith.cmpi slt, %12, %c0_i32 : i32 | ||
!CHECK-NEXT: %14 = arith.subi %c0_i32, %12 : i32 | ||
!CHECK-NEXT: %15 = arith.select %13, %14, %12 : i32 | ||
!CHECK-NEXT: %16 = arith.select %13, %11, %10 : i32 | ||
!CHECK-NEXT: %17 = arith.select %13, %10, %11 : i32 | ||
!CHECK-NEXT: %18 = arith.subi %17, %16 overflow<nuw> : i32 | ||
!CHECK-NEXT: %19 = arith.divui %18, %15 : i32 | ||
!CHECK-NEXT: %20 = arith.addi %19, %c1_i32 overflow<nuw> : i32 | ||
!CHECK-NEXT: %21 = arith.cmpi slt, %17, %16 : i32 | ||
!CHECK-NEXT: %22 = arith.select %21, %c0_i32, %20 : i32 | ||
!CHECK-NEXT: %canonloop_s0 = omp.new_cli | ||
!CHECK-NEXT: omp.canonical_loop(%canonloop_s0) %iv : i32 in range(%22) { | ||
!CHECK-NEXT: %23 = arith.muli %iv, %12 : i32 | ||
!CHECK-NEXT: %24 = arith.addi %10, %23 : i32 | ||
!CHECK-NEXT: hlfir.assign %24 to %9#0 : i32, !fir.ref<i32> | ||
!CHECK-NEXT: %25 = fir.load %9#0 : !fir.ref<i32> | ||
!CHECK-NEXT: hlfir.assign %25 to %6#0 : i32, !fir.ref<i32> | ||
!CHECK-NEXT: omp.terminator | ||
!CHECK-NEXT: } | ||
!CHECK-NEXT: omp.unroll_heuristic(%canonloop_s0) | ||
!CHECK-NEXT: return |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=51 -o - %s 2>&1 | FileCheck %s | ||
|
||
|
||
subroutine omp_unroll_heuristic_nested02(outer_lb, outer_ub, outer_inc, inner_lb, inner_ub, inner_inc) | ||
integer res, i, j, inner_lb, inner_ub, inner_inc, outer_lb, outer_ub, outer_inc | ||
|
||
!$omp unroll | ||
do i = outer_lb, outer_ub, outer_inc | ||
!$omp unroll | ||
do j = inner_lb, inner_ub, inner_inc | ||
res = i + j | ||
end do | ||
!$omp end unroll | ||
end do | ||
!$omp end unroll | ||
|
||
end subroutine omp_unroll_heuristic_nested02 | ||
|
||
|
||
!CHECK-LABEL: func.func @_QPomp_unroll_heuristic_nested02(%arg0: !fir.ref<i32> {fir.bindc_name = "outer_lb"}, %arg1: !fir.ref<i32> {fir.bindc_name = "outer_ub"}, %arg2: !fir.ref<i32> {fir.bindc_name = "outer_inc"}, %arg3: !fir.ref<i32> {fir.bindc_name = "inner_lb"}, %arg4: !fir.ref<i32> {fir.bindc_name = "inner_ub"}, %arg5: !fir.ref<i32> {fir.bindc_name = "inner_inc"}) { | ||
!CHECK: %c0_i32 = arith.constant 0 : i32 | ||
!CHECK-NEXT: %c1_i32 = arith.constant 1 : i32 | ||
!CHECK-NEXT: %18 = arith.cmpi slt, %17, %c0_i32 : i32 | ||
!CHECK-NEXT: %19 = arith.subi %c0_i32, %17 : i32 | ||
!CHECK-NEXT: %20 = arith.select %18, %19, %17 : i32 | ||
!CHECK-NEXT: %21 = arith.select %18, %16, %15 : i32 | ||
!CHECK-NEXT: %22 = arith.select %18, %15, %16 : i32 | ||
!CHECK-NEXT: %23 = arith.subi %22, %21 overflow<nuw> : i32 | ||
!CHECK-NEXT: %24 = arith.divui %23, %20 : i32 | ||
!CHECK-NEXT: %25 = arith.addi %24, %c1_i32 overflow<nuw> : i32 | ||
!CHECK-NEXT: %26 = arith.cmpi slt, %22, %21 : i32 | ||
!CHECK-NEXT: %27 = arith.select %26, %c0_i32, %25 : i32 | ||
!CHECK-NEXT: %canonloop_s0 = omp.new_cli | ||
!CHECK-NEXT: omp.canonical_loop(%canonloop_s0) %iv : i32 in range(%27) { | ||
!CHECK-NEXT: %28 = arith.muli %iv, %17 : i32 | ||
!CHECK-NEXT: %29 = arith.addi %15, %28 : i32 | ||
!CHECK-NEXT: hlfir.assign %29 to %14#0 : i32, !fir.ref<i32> | ||
!CHECK-NEXT: %30 = fir.alloca i32 {bindc_name = "j", pinned, uniq_name = "_QFomp_unroll_heuristic_nested02Ej"} | ||
!CHECK-NEXT: %31:2 = hlfir.declare %30 {uniq_name = "_QFomp_unroll_heuristic_nested02Ej"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>) | ||
!CHECK-NEXT: %32 = fir.load %4#0 : !fir.ref<i32> | ||
!CHECK-NEXT: %33 = fir.load %5#0 : !fir.ref<i32> | ||
!CHECK-NEXT: %34 = fir.load %3#0 : !fir.ref<i32> | ||
!CHECK-NEXT: %c0_i32_0 = arith.constant 0 : i32 | ||
!CHECK-NEXT: %c1_i32_1 = arith.constant 1 : i32 | ||
!CHECK-NEXT: %35 = arith.cmpi slt, %34, %c0_i32_0 : i32 | ||
!CHECK-NEXT: %36 = arith.subi %c0_i32_0, %34 : i32 | ||
!CHECK-NEXT: %37 = arith.select %35, %36, %34 : i32 | ||
!CHECK-NEXT: %38 = arith.select %35, %33, %32 : i32 | ||
!CHECK-NEXT: %39 = arith.select %35, %32, %33 : i32 | ||
!CHECK-NEXT: %40 = arith.subi %39, %38 overflow<nuw> : i32 | ||
!CHECK-NEXT: %41 = arith.divui %40, %37 : i32 | ||
!CHECK-NEXT: %42 = arith.addi %41, %c1_i32_1 overflow<nuw> : i32 | ||
!CHECK-NEXT: %43 = arith.cmpi slt, %39, %38 : i32 | ||
!CHECK-NEXT: %44 = arith.select %43, %c0_i32_0, %42 : i32 | ||
!CHECK-NEXT: %canonloop_s0_s0 = omp.new_cli | ||
!CHECK-NEXT: omp.canonical_loop(%canonloop_s0_s0) %iv_2 : i32 in range(%44) { | ||
!CHECK-NEXT: %45 = arith.muli %iv_2, %34 : i32 | ||
!CHECK-NEXT: %46 = arith.addi %32, %45 : i32 | ||
!CHECK-NEXT: hlfir.assign %46 to %31#0 : i32, !fir.ref<i32> | ||
!CHECK-NEXT: %47 = fir.load %14#0 : !fir.ref<i32> | ||
!CHECK-NEXT: %48 = fir.load %31#0 : !fir.ref<i32> | ||
!CHECK-NEXT: %49 = arith.addi %47, %48 : i32 | ||
!CHECK-NEXT: hlfir.assign %49 to %12#0 : i32, !fir.ref<i32> | ||
!CHECK-NEXT: omp.terminator | ||
!CHECK-NEXT: } | ||
!CHECK-NEXT: omp.unroll_heuristic(%canonloop_s0_s0) | ||
!CHECK-NEXT: omp.terminator | ||
!CHECK-NEXT: } | ||
!CHECK-NEXT: omp.unroll_heuristic(%canonloop_s0) | ||
!CHECK-NEXT: return |
Uh oh!
There was an error while loading. Please reload this page.