Skip to content

Commit f2992aa

Browse files
committed
[MLIR][OpenMP] Improve translation of omp.distribute and omp.wsloop
This patch refactors the translation of the `omp.wsloop` and `omp.distribute` operations to LLVM IR, simplifying the logic and minimizing edge cases in their handling. Comments are added to help follow the non-trivial parts of the logic. The `generateOMPWorkshareLoop` helper function is inlined into the translation functions of both operations, because doing so simplifies the translation of `omp.distribute`, which would otherwise only run part of that helper in some cases. The resulting code duplication is limited to the translation of the operation's region (which is already done for every other operation) and a single call to `applyWorkshareLoop`.
1 parent 27d68a8 commit f2992aa

File tree

1 file changed

+69
-88
lines changed

1 file changed

+69
-88
lines changed

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 69 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -2120,38 +2120,6 @@ convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder,
21202120
return success();
21212121
}
21222122

2123-
static LogicalResult generateOMPWorkshareLoop(
2124-
Operation &opInst, llvm::IRBuilderBase &builder,
2125-
LLVM::ModuleTranslation &moduleTranslation, llvm::Value *chunk,
2126-
bool isOrdered, bool isSimd, omp::ClauseScheduleKind &schedule,
2127-
std::optional<omp::ScheduleModifier> &scheduleMod, bool loopNeedsBarrier,
2128-
llvm::omp::WorksharingLoopType workshareLoopType) {
2129-
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
2130-
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2131-
2132-
llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
2133-
opInst.getRegion(0), "omp.wsloop.region", builder, moduleTranslation);
2134-
2135-
if (failed(handleError(regionBlock, opInst)))
2136-
return failure();
2137-
2138-
builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
2139-
2140-
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2141-
findAllocaInsertPoint(builder, moduleTranslation);
2142-
llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
2143-
2144-
llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
2145-
ompBuilder->applyWorkshareLoop(
2146-
ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
2147-
convertToScheduleKind(schedule), chunk, isSimd,
2148-
scheduleMod == omp::ScheduleModifier::monotonic,
2149-
scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
2150-
workshareLoopType);
2151-
2152-
return handleError(wsloopIP, opInst);
2153-
}
2154-
21552123
/// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
21562124
static LogicalResult
21572125
convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
@@ -2240,22 +2208,36 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
22402208
bool isOrdered = wsloopOp.getOrdered().has_value();
22412209
std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
22422210
bool isSimd = wsloopOp.getScheduleSimd();
2243-
auto distributeParentOp = dyn_cast<omp::DistributeOp>(opInst.getParentOp());
2244-
// bool distributeCodeGen = opInst.getParentOfType<omp::DistributeOp>();
2211+
2212+
// The only legal way for the direct parent to be omp.distribute is that this
2213+
// represents 'distribute parallel do'. Otherwise, this is a regular
2214+
// worksharing loop.
22452215
llvm::omp::WorksharingLoopType workshareLoopType =
2246-
llvm::omp::WorksharingLoopType::ForStaticLoop;
2247-
if (distributeParentOp) {
2248-
if (isa<omp::ParallelOp>(distributeParentOp->getParentOp()))
2249-
workshareLoopType =
2250-
llvm::omp::WorksharingLoopType::DistributeForStaticLoop;
2251-
else
2252-
workshareLoopType = llvm::omp::WorksharingLoopType::DistributeStaticLoop;
2253-
}
2216+
llvm::isa_and_present<omp::DistributeOp>(opInst.getParentOp())
2217+
? llvm::omp::WorksharingLoopType::DistributeForStaticLoop
2218+
: llvm::omp::WorksharingLoopType::ForStaticLoop;
22542219

22552220
bool loopNeedsBarrier = !wsloopOp.getNowait();
2256-
if (failed(generateOMPWorkshareLoop(opInst, builder, moduleTranslation, chunk,
2257-
isOrdered, isSimd, schedule, scheduleMod,
2258-
loopNeedsBarrier, workshareLoopType)))
2221+
2222+
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2223+
llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
2224+
wsloopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation);
2225+
2226+
if (failed(handleError(regionBlock, opInst)))
2227+
return failure();
2228+
2229+
builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
2230+
llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
2231+
2232+
llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
2233+
moduleTranslation.getOpenMPBuilder()->applyWorkshareLoop(
2234+
ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
2235+
convertToScheduleKind(schedule), chunk, isSimd,
2236+
scheduleMod == omp::ScheduleModifier::monotonic,
2237+
scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
2238+
workshareLoopType);
2239+
2240+
if (failed(handleError(wsloopIP, opInst)))
22592241
return failure();
22602242

22612243
// Process the reductions if required.
@@ -4169,7 +4151,6 @@ static LogicalResult
41694151
convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
41704152
LLVM::ModuleTranslation &moduleTranslation) {
41714153
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4172-
// FIXME: This ignores any other nested wrappers (e.g. omp.wsloop, omp.simd).
41734154
auto distributeOp = cast<omp::DistributeOp>(opInst);
41744155
if (failed(checkImplementationStatus(opInst)))
41754156
return failure();
@@ -4216,46 +4197,45 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
42164197
// DistributeOp has only one region associated with it.
42174198
builder.restoreIP(codeGenIP);
42184199

4219-
if (!distributeOp.isComposite() ||
4220-
isa<omp::SimdOp>(distributeOp.getNestedWrapper())) {
4221-
// Convert a standalone DISTRIBUTE construct.
4222-
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4223-
bool isGPU = ompBuilder->Config.isGPU();
4224-
// TODO: Unify host and target lowering for standalone DISTRIBUTE
4225-
if (!isGPU) {
4226-
llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
4227-
distributeOp.getRegion(), "omp.distribute.region", builder,
4228-
moduleTranslation);
4229-
if (!regionBlock)
4230-
return regionBlock.takeError();
4231-
builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
4232-
return llvm::Error::success();
4233-
}
4234-
// TODO: Add support for clauses which are valid for DISTRIBUTE construct
4235-
// Static schedule is the default.
4236-
auto schedule = omp::ClauseScheduleKind::Static;
4237-
bool isOrdered = false;
4238-
std::optional<omp::ScheduleModifier> scheduleMod;
4239-
bool isSimd = false;
4240-
llvm::omp::WorksharingLoopType workshareLoopType =
4241-
llvm::omp::WorksharingLoopType::DistributeStaticLoop;
4242-
bool loopNeedsBarier = true;
4243-
llvm::Value *chunk = nullptr;
4244-
auto loopNestConversionResult = generateOMPWorkshareLoop(
4245-
opInst, builder, moduleTranslation, chunk, isOrdered, isSimd,
4246-
schedule, scheduleMod, loopNeedsBarier, workshareLoopType);
4247-
4248-
if (failed(loopNestConversionResult))
4249-
return llvm::make_error<PreviouslyReportedError>();
4250-
} else {
4251-
// Convert a DISTRIBUTE leaf as part of a composite construct.
4252-
mlir::Region &reg = distributeOp.getRegion();
4253-
llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
4254-
reg, "omp.distribute.region", builder, moduleTranslation);
4255-
if (!regionBlock)
4256-
return regionBlock.takeError();
4257-
builder.SetInsertPoint((*regionBlock)->getTerminator());
4258-
}
4200+
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4201+
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4202+
llvm::Expected<llvm::BasicBlock *> regionBlock =
4203+
convertOmpOpRegions(distributeOp.getRegion(), "omp.distribute.region",
4204+
builder, moduleTranslation);
4205+
if (!regionBlock)
4206+
return regionBlock.takeError();
4207+
builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
4208+
4209+
// Skip applying a workshare loop below when translating 'distribute
4210+
// parallel do' (it's been already handled by this point while translating
4211+
// the nested omp.wsloop) and when not targeting a GPU.
4212+
if (isa_and_present<omp::WsloopOp>(distributeOp.getNestedWrapper()) ||
4213+
!ompBuilder->Config.isGPU())
4214+
return llvm::Error::success();
4215+
4216+
// TODO: Add support for clauses which are valid for DISTRIBUTE construct
4217+
// Static schedule is the default.
4218+
auto schedule = omp::ClauseScheduleKind::Static;
4219+
bool isOrdered = false;
4220+
std::optional<omp::ScheduleModifier> scheduleMod;
4221+
bool isSimd = false;
4222+
llvm::omp::WorksharingLoopType workshareLoopType =
4223+
llvm::omp::WorksharingLoopType::DistributeStaticLoop;
4224+
bool loopNeedsBarier = true;
4225+
llvm::Value *chunk = nullptr;
4226+
4227+
llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
4228+
llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
4229+
ompBuilder->applyWorkshareLoop(
4230+
ompLoc.DL, loopInfo, allocaIP, loopNeedsBarier,
4231+
convertToScheduleKind(schedule), chunk, isSimd,
4232+
scheduleMod == omp::ScheduleModifier::monotonic,
4233+
scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
4234+
workshareLoopType);
4235+
4236+
if (!wsloopIP)
4237+
return wsloopIP.takeError();
4238+
42594239
return llvm::Error::success();
42604240
};
42614241

@@ -4265,8 +4245,9 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
42654245
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
42664246
ompBuilder->createDistribute(ompLoc, allocaIP, bodyGenCB);
42674247

4268-
if (!afterIP)
4269-
return opInst.emitError(llvm::toString(afterIP.takeError()));
4248+
if (failed(handleError(afterIP, opInst)))
4249+
return failure();
4250+
42704251
builder.restoreIP(*afterIP);
42714252

42724253
if (doDistributeReduction) {

0 commit comments

Comments
 (0)