Skip to content

Commit 4211b4d

Browse files
committed
[MLIR][OpenMP] Use same helper to translate all loops, NFC
This patch makes the translation of workshare loops also use the `convertLoopNestHelper` function to handle the `omp.loop_nest` wrapped by the corresponding `omp.wsloop` or `omp.distribute` operation.
1 parent b8130f8 commit 4211b4d

File tree

1 file changed

+13
-91
lines changed

1 file changed

+13
-91
lines changed

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 13 additions & 91 deletions
Original file line number | Diff line number | Diff line change
@@ -625,16 +625,14 @@ static void getSinkableAllocas(LLVM::ModuleTranslation &moduleTranslation,
625625

626626
// TODO: Make this a top-level conversion function (i.e. part of the switch
627627
// statement in `convertHostOrTargetOperation`) independent from parent
628-
// worksharing operations and update `convertOmpWsloop` to rely on this rather
629-
// than replicating the same logic.
628+
// worksharing operations.
630629
static std::optional<
631630
std::tuple<llvm::OpenMPIRBuilder::LocationDescription,
632631
llvm::IRBuilderBase::InsertPoint, llvm::CanonicalLoopInfo *>>
633-
convertLoopNestHelper(Operation &opInst, llvm::IRBuilderBase &builder,
632+
convertLoopNestHelper(omp::LoopNestOp loopOp, llvm::IRBuilderBase &builder,
634633
LLVM::ModuleTranslation &moduleTranslation,
635634
StringRef blockName) {
636635
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
637-
auto loopOp = cast<omp::LoopNestOp>(opInst);
638636

639637
// Set up the source location value for OpenMP runtime.
640638
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
@@ -643,8 +641,6 @@ convertLoopNestHelper(Operation &opInst, llvm::IRBuilderBase &builder,
643641
getSinkableAllocas(moduleTranslation, loopOp.getRegion(), allocasToSink);
644642

645643
// Generator of the canonical loop body.
646-
// TODO: support error propagation in OpenMPIRBuilder and use it instead of
647-
// relying on captured variables.
648644
SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
649645
SmallVector<llvm::OpenMPIRBuilder::InsertPointTy> bodyInsertPoints;
650646
auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
@@ -668,11 +664,14 @@ convertLoopNestHelper(Operation &opInst, llvm::IRBuilderBase &builder,
668664
unsigned size = alloca->getAllocatedType()->getPrimitiveSizeInBits() / 8;
669665
builder.CreateLifetimeStart(alloca, builder.getInt64(size));
670666
}
667+
671668
llvm::Expected<llvm::BasicBlock *> cont = convertOmpOpRegions(
672669
loopOp.getRegion(), blockName, builder, moduleTranslation);
673670
if (!cont)
674671
return cont.takeError();
672+
675673
builder.SetInsertPoint(*cont, (*cont)->begin());
674+
676675
for (auto *alloca : allocasToSink) {
677676
unsigned size = alloca->getAllocatedType()->getPrimitiveSizeInBits() / 8;
678677
builder.CreateLifetimeEnd(alloca, builder.getInt64(size));
@@ -706,7 +705,7 @@ convertLoopNestHelper(Operation &opInst, llvm::IRBuilderBase &builder,
706705
llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
707706
ompBuilder->createCanonicalLoop(
708707
loc, bodyGen, lowerBound, upperBound, step,
709-
/*IsSigned=*/true, /*InclusiveStop=*/true, computeIP);
708+
/*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);
710709

711710
if (failed(handleError(loopResult, *loopOp)))
712711
return std::nullopt;
@@ -2121,90 +2120,13 @@ static LogicalResult generateOMPWorkshareLoop(
21212120
std::optional<omp::ScheduleModifier> &scheduleMod, bool loopNeedsBarier,
21222121
llvm::omp::WorksharingLoopType workshareLoopType) {
21232122
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
2124-
// Set up the source location value for OpenMP runtime.
2125-
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2126-
2127-
SetVector<llvm::AllocaInst *> allocasToSink;
2128-
getSinkableAllocas(moduleTranslation, loopOp.getRegion(), allocasToSink);
2129-
2130-
// Generator of the canonical loop body.
2131-
SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
2132-
SmallVector<llvm::OpenMPIRBuilder::InsertPointTy> bodyInsertPoints;
2133-
auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
2134-
llvm::Value *iv) -> llvm::Error {
2135-
// Make sure further conversions know about the induction variable.
2136-
moduleTranslation.mapValue(
2137-
loopOp.getRegion().front().getArgument(loopInfos.size()), iv);
2138-
2139-
// Capture the body insertion point for use in nested loops. BodyIP of the
2140-
// CanonicalLoopInfo always points to the beginning of the entry block of
2141-
// the body.
2142-
bodyInsertPoints.push_back(ip);
2143-
2144-
if (loopInfos.size() != loopOp.getNumLoops() - 1)
2145-
return llvm::Error::success();
2146-
2147-
// Convert the body of the loop, adding lifetime markers to allocations that
2148-
// can be sunk into the new block.
2149-
builder.restoreIP(ip);
2150-
for (auto *alloca : allocasToSink) {
2151-
unsigned size = alloca->getAllocatedType()->getPrimitiveSizeInBits() / 8;
2152-
builder.CreateLifetimeStart(alloca, builder.getInt64(size));
2153-
}
2154-
2155-
llvm::Expected<llvm::BasicBlock *> cont = convertOmpOpRegions(
2156-
loopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation);
2157-
if (!cont)
2158-
return cont.takeError();
2159-
2160-
builder.SetInsertPoint(*cont, (*cont)->begin());
2161-
2162-
for (auto *alloca : allocasToSink) {
2163-
unsigned size = alloca->getAllocatedType()->getPrimitiveSizeInBits() / 8;
2164-
builder.CreateLifetimeEnd(alloca, builder.getInt64(size));
2165-
}
2166-
return llvm::Error::success();
2167-
};
2168-
2169-
// Delegate actual loop construction to the OpenMP IRBuilder.
2170-
// TODO: this currently assumes omp.loop_nest is semantically similar to SCF
2171-
// loop, i.e. it has a positive step, uses signed integer semantics.
2172-
// Reconsider this code when the nested loop operation clearly supports more
2173-
// cases.
2174-
for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
2175-
llvm::Value *lowerBound =
2176-
moduleTranslation.lookupValue(loopOp.getLoopLowerBounds()[i]);
2177-
llvm::Value *upperBound =
2178-
moduleTranslation.lookupValue(loopOp.getLoopUpperBounds()[i]);
2179-
llvm::Value *step = moduleTranslation.lookupValue(loopOp.getLoopSteps()[i]);
2180-
2181-
// Make sure loop trip count are emitted in the preheader of the outermost
2182-
// loop at the latest so that they are all available for the new collapsed
2183-
// loop will be created below.
2184-
llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
2185-
llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
2186-
if (i != 0) {
2187-
loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back());
2188-
computeIP = loopInfos.front()->getPreheaderIP();
2189-
}
2190-
2191-
llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
2192-
ompBuilder->createCanonicalLoop(
2193-
loc, bodyGen, lowerBound, upperBound, step,
2194-
/*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);
2195-
2196-
if (failed(handleError(loopResult, *loopOp)))
2197-
return failure();
2198-
2199-
loopInfos.push_back(*loopResult);
2200-
}
22012123

2202-
// Collapse loops. Store the insertion point because LoopInfos may get
2203-
// invalidated.
2204-
llvm::IRBuilderBase::InsertPoint afterIP = loopInfos.front()->getAfterIP();
2205-
llvm::CanonicalLoopInfo *loopInfo =
2206-
ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});
2124+
auto loopNestConversionResult = convertLoopNestHelper(
2125+
loopOp, builder, moduleTranslation, "omp.wsloop.region");
2126+
if (!loopNestConversionResult)
2127+
return failure();
22072128

2129+
auto [ompLoc, afterIP, loopInfo] = *loopNestConversionResult;
22082130
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
22092131
findAllocaInsertPoint(builder, moduleTranslation);
22102132

@@ -2597,7 +2519,7 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
25972519
return failure();
25982520

25992521
auto loopNestConversionResult = convertLoopNestHelper(
2600-
*loopOp, builder, moduleTranslation, "omp.simd.region");
2522+
loopOp, builder, moduleTranslation, "omp.simd.region");
26012523
if (!loopNestConversionResult)
26022524
return failure();
26032525

@@ -4302,7 +4224,7 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
43024224
// TODO: Unify host and target lowering for standalone DISTRIBUTE
43034225
if (!isGPU) {
43044226
auto loopNestConversionResult = convertLoopNestHelper(
4305-
*loopOp, builder, moduleTranslation, "omp.distribute.region");
4227+
loopOp, builder, moduleTranslation, "omp.distribute.region");
43064228
if (!loopNestConversionResult)
43074229
return llvm::make_error<PreviouslyReportedError>();
43084230

0 commit comments

Comments
 (0)