Skip to content

Commit 6a45873

Browse files
committed
[Flang][MLIR] Add !$omp unroll construct and omp.unroll_heuristic modeling
1 parent 383b326 commit 6a45873

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+1718
-35
lines changed

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 158 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2128,6 +2128,161 @@ genLoopOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
21282128
return loopOp;
21292129
}
21302130

2131+
/// Lower the DO loop associated with an OpenMP loop-transformation construct
/// to an `omp.canonical_loop` operation.
///
/// The loop bounds and step are evaluated before the operation, converted to
/// the loop variable's integer type, and turned into a trip count that is
/// passed to the operation together with a freshly created CLI handle. Inside
/// the operation's region the user-visible loop variable is recomputed from
/// the logical iteration number as `lb + iteration * step`.
///
/// \param eval      Evaluation of the construct; its first nested evaluation
///                  must be the DO construct to lower.
/// \param ivs       Loop induction variable symbols; exactly one is supported.
/// \param directive Directive forwarded to OpWithBodyGenInfo.
/// \param dsp       Data-sharing processor forwarded to OpWithBodyGenInfo.
/// \returns the created `omp.canonical_loop`; the builder's insertion point is
///          left immediately after it.
static mlir::omp::CanonicalLoopOp
genCanonicalLoopOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
                   semantics::SemanticsContext &semaCtx,
                   lower::pft::Evaluation &eval, mlir::Location loc,
                   const ConstructQueue &queue,
                   ConstructQueue::const_iterator item,
                   llvm::ArrayRef<const semantics::Symbol *> ivs,
                   llvm::omp::Directive directive, DataSharingProcessor &dsp) {
  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

  assert(ivs.size() == 1 && "Nested loops not yet implemented");
  const semantics::Symbol *iv = ivs[0];

  auto &nestedEval = eval.getFirstNestedEvaluation();
  // Check the cast before dereferencing it, like the other getIf/get_if
  // results below.
  auto *doConstruct = nestedEval.getIf<parser::DoConstruct>();
  assert(doConstruct && "Expected do construct to be the nested evaluation");
  if (doConstruct->IsDoConcurrent()) {
    TODO(loc, "Do Concurrent in unroll construct");
  }

  // Get the loop bounds (and increment)
  auto &doLoopEval = nestedEval.getFirstNestedEvaluation();
  auto *doStmt = doLoopEval.getIf<parser::NonLabelDoStmt>();
  assert(doStmt && "Expected do loop to be in the nested evaluation");
  auto &loopControl = std::get<std::optional<parser::LoopControl>>(doStmt->t);
  assert(loopControl.has_value());
  auto *bounds = std::get_if<parser::LoopControl::Bounds>(&loopControl->u);
  assert(bounds && "Expected bounds for canonical loop");
  lower::StatementContext stmtCtx;
  mlir::Value loopLBVar = fir::getBase(
      converter.genExprValue(*semantics::GetExpr(bounds->lower), stmtCtx));
  mlir::Value loopUBVar = fir::getBase(
      converter.genExprValue(*semantics::GetExpr(bounds->upper), stmtCtx));
  mlir::Value loopStepVar = [&]() -> mlir::Value {
    if (bounds->step)
      return fir::getBase(
          converter.genExprValue(*semantics::GetExpr(bounds->step), stmtCtx));
    // If `step` is not present, assume it is `1`.
    return firOpBuilder.createIntegerConstant(loc, firOpBuilder.getI32Type(),
                                              1);
  }();

  // Get the integer kind for the loop variable and cast the loop bounds
  size_t loopVarTypeSize = bounds->name.thing.symbol->GetUltimate().size();
  mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
  loopLBVar = firOpBuilder.createConvert(loc, loopVarType, loopLBVar);
  loopUBVar = firOpBuilder.createConvert(loc, loopVarType, loopUBVar);
  loopStepVar = firOpBuilder.createConvert(loc, loopVarType, loopStepVar);

  // Start lowering
  mlir::Value zero = firOpBuilder.createIntegerConstant(loc, loopVarType, 0);
  mlir::Value one = firOpBuilder.createIntegerConstant(loc, loopVarType, 1);
  mlir::Value isDownwards = firOpBuilder.create<mlir::arith::CmpIOp>(
      loc, mlir::arith::CmpIPredicate::slt, loopStepVar, zero);

  // Ensure we are counting upwards. If not, negate step and swap lb and ub.
  mlir::Value negStep =
      firOpBuilder.create<mlir::arith::SubIOp>(loc, zero, loopStepVar);
  mlir::Value incr = firOpBuilder.create<mlir::arith::SelectOp>(
      loc, isDownwards, negStep, loopStepVar);
  mlir::Value lb = firOpBuilder.create<mlir::arith::SelectOp>(
      loc, isDownwards, loopUBVar, loopLBVar);
  mlir::Value ub = firOpBuilder.create<mlir::arith::SelectOp>(
      loc, isDownwards, loopLBVar, loopUBVar);

  // Compute the trip count assuming lb <= ub (signed). The mathematical span
  // ub - lb is then non-negative and fits the unsigned range of the type, so
  // the division below can be unsigned.
  //
  // The subtraction itself must be allowed to wrap: the bounds are signed, and
  // for lb < 0 <= ub the bit patterns satisfy ub <u lb, i.e. the subtraction
  // borrows when viewed as unsigned. Marking it `nuw` would make the span (and
  // the trip count) poison for loops such as `do i = -5, 5`.
  mlir::Value span = firOpBuilder.create<mlir::arith::SubIOp>(loc, ub, lb);
  mlir::Value tcMinusOne =
      firOpBuilder.create<mlir::arith::DivUIOp>(loc, span, incr);
  mlir::Value tcIfLooping = firOpBuilder.create<mlir::arith::AddIOp>(
      loc, tcMinusOne, one, ::mlir::arith::IntegerOverflowFlags::nuw);

  // Fall back to 0 if lb > ub
  mlir::Value isZeroTC = firOpBuilder.create<mlir::arith::CmpIOp>(
      loc, mlir::arith::CmpIPredicate::slt, ub, lb);
  mlir::Value tripcount = firOpBuilder.create<mlir::arith::SelectOp>(
      loc, isZeroTC, zero, tcIfLooping);

  // Create the CLI handle.
  auto newcli = firOpBuilder.create<mlir::omp::NewCliOp>(loc);
  mlir::Value cli = newcli.getResult();

  auto ivCallback = [&](mlir::Operation *op)
      -> llvm::SmallVector<const Fortran::semantics::Symbol *> {
    mlir::Region &region = op->getRegion(0);

    // Create the op's region skeleton (BB taking the iv as argument)
    firOpBuilder.createBlock(&region, {}, {loopVarType}, {loc});

    // Compute the value of the loop variable from the logical iteration number.
    // This uses the original (unswapped) lower bound and signed step, so
    // downward loops still count from the user's lower bound.
    mlir::Value natIterNum = fir::getBase(region.front().getArgument(0));
    mlir::Value scaled =
        firOpBuilder.create<mlir::arith::MulIOp>(loc, natIterNum, loopStepVar);
    mlir::Value userVal =
        firOpBuilder.create<mlir::arith::AddIOp>(loc, loopLBVar, scaled);

    // The argument is not currently in memory, so make a temporary for the
    // argument, and store it there, then bind that location to the argument.
    mlir::Operation *storeOp =
        createAndSetPrivatizedLoopVar(converter, loc, userVal, iv);

    firOpBuilder.setInsertionPointAfter(storeOp);
    return {iv};
  };

  // Create the omp.canonical_loop operation
  auto canonLoop = genOpWithBody<mlir::omp::CanonicalLoopOp>(
      OpWithBodyGenInfo(converter, symTable, semaCtx, loc, nestedEval,
                        directive)
          .setClauses(&item->clauses)
          .setDataSharingProcessor(&dsp)
          .setGenRegionEntryCb(ivCallback),
      queue, item, tripcount, cli);

  firOpBuilder.setInsertionPointAfter(canonLoop);
  return canonLoop;
}
2250+
2251+
/// Lower an `!$omp unroll` construct: emit the associated loop as an
/// `omp.canonical_loop` and attach an `omp.unroll_heuristic` operation to its
/// CLI handle. The `partial` and `full` clauses are reported as TODO.
static void genUnrollOp(Fortran::lower::AbstractConverter &converter,
                        Fortran::lower::SymMap &symTable,
                        lower::StatementContext &stmtCtx,
                        Fortran::semantics::SemanticsContext &semaCtx,
                        Fortran::lower::pft::Evaluation &eval,
                        mlir::Location loc, const ConstructQueue &queue,
                        ConstructQueue::const_iterator item) {
  // Gather the loop-related clause operands and the induction variable
  // symbols of the associated loop.
  mlir::omp::LoopRelatedClauseOps loopClauseOps;
  llvm::SmallVector<const semantics::Symbol *> inductionVars;
  collectLoopRelatedInfo(converter, loc, eval, item->clauses, loopClauseOps,
                         inductionVars);

  // Clauses for unrolling are not yet implemented.
  ClauseProcessor clauseProc(converter, semaCtx, item->clauses);
  clauseProc.processTODO<clause::Partial, clause::Full>(
      loc, llvm::omp::Directive::OMPD_unroll);

  // Unroll does not support data-sharing clauses, but the data-sharing
  // processor is still required to fill the symbol table.
  DataSharingProcessor dsp(converter, semaCtx, item->clauses, eval,
                           /*shouldCollectPreDeterminedSymbols=*/true,
                           /*useDelayedPrivatization=*/false, symTable);
  dsp.processStep1();

  // Emit the associated loop.
  mlir::omp::CanonicalLoopOp loopOp = genCanonicalLoopOp(
      converter, symTable, semaCtx, eval, loc, queue, item, inductionVars,
      llvm::omp::Directive::OMPD_unroll, dsp);

  // Apply unrolling to the loop through its CLI handle.
  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
  builder.create<mlir::omp::UnrollHeuristicOp>(loc, loopOp.getCli());
}
2285+
21312286
static mlir::omp::MaskedOp
21322287
genMaskedOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
21332288
lower::StatementContext &stmtCtx,
@@ -3516,12 +3671,9 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
35163671
newOp = genTeamsOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue,
35173672
item);
35183673
break;
3519-
case llvm::omp::Directive::OMPD_tile:
3520-
case llvm::omp::Directive::OMPD_unroll: {
3521-
unsigned version = semaCtx.langOptions().OpenMPVersion;
3522-
TODO(loc, "Unhandled loop directive (" +
3523-
llvm::omp::getOpenMPDirectiveName(dir, version) + ")");
3524-
}
3674+
case llvm::omp::Directive::OMPD_unroll:
3675+
genUnrollOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item);
3676+
break;
35253677
// case llvm::omp::Directive::OMPD_workdistribute:
35263678
case llvm::omp::Directive::OMPD_workshare:
35273679
newOp = genWorkshareOp(converter, symTable, stmtCtx, semaCtx, eval, loc,

flang/test/Driver/Inputs/basic_cross_linux_tree/usr/bin/i386-unknown-linux-gnu-as

100755100644
File mode changed.

flang/test/Driver/Inputs/basic_cross_linux_tree/usr/bin/i386-unknown-linux-gnu-ld

100755100644
File mode changed.

flang/test/Driver/Inputs/basic_cross_linux_tree/usr/bin/i386-unknown-linux-gnu-ld.bfd

100755100644
File mode changed.

flang/test/Driver/Inputs/basic_cross_linux_tree/usr/bin/i386-unknown-linux-gnu-ld.gold

100755100644
File mode changed.

flang/test/Driver/Inputs/basic_cross_linux_tree/usr/bin/x86_64-unknown-linux-gnu-as

100755100644
File mode changed.

flang/test/Driver/Inputs/basic_cross_linux_tree/usr/bin/x86_64-unknown-linux-gnu-ld

100755100644
File mode changed.

flang/test/Driver/Inputs/basic_cross_linux_tree/usr/bin/x86_64-unknown-linux-gnu-ld.bfd

100755100644
File mode changed.

flang/test/Driver/Inputs/basic_cross_linux_tree/usr/bin/x86_64-unknown-linux-gnu-ld.gold

100755100644
File mode changed.

flang/test/Driver/Inputs/basic_cross_linux_tree/usr/i386-unknown-linux-gnu/bin/as

100755100644
File mode changed.

0 commit comments

Comments
 (0)