Skip to content

Commit 72b8744

Browse files
authored
[MLIR][OpenMP] Reduce overhead of target compilation (#130945)
This patch avoids calling `TargetOp::getInnermostCapturedOmpOp` multiple times during initialization of default and runtime target attributes in MLIR to LLVM IR translation of `omp.target` operations. This is a potentially expensive operation, so this change should help keep compile times lower.
1 parent f2541ce commit 72b8744

File tree

3 files changed

+30
-15
lines changed

3 files changed

+30
-15
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1342,7 +1342,11 @@ def TargetOp : OpenMP_Op<"target", traits = [
13421342

13431343
/// Infers the kernel type (Generic, SPMD or Generic-SPMD) based on the
13441344
/// contents of the target region.
1345-
llvm::omp::OMPTgtExecModeFlags getKernelExecFlags();
1345+
///
1346+
/// \param capturedOp result of a still valid (no modifications made to any
1347+
/// nested operations) previous call to `getInnermostCapturedOmpOp()`.
1348+
static llvm::omp::OMPTgtExecModeFlags
1349+
getKernelExecFlags(Operation *capturedOp);
13461350
}] # clausesExtraClassDeclaration;
13471351

13481352
let assemblyFormat = clausesAssemblyFormat # [{

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1905,7 +1905,8 @@ LogicalResult TargetOp::verifyRegions() {
19051905
return emitError("target containing multiple 'omp.teams' nested ops");
19061906

19071907
// Check that host_eval values are only used in legal ways.
1908-
llvm::omp::OMPTgtExecModeFlags execFlags = getKernelExecFlags();
1908+
llvm::omp::OMPTgtExecModeFlags execFlags =
1909+
getKernelExecFlags(getInnermostCapturedOmpOp());
19091910
for (Value hostEvalArg :
19101911
cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
19111912
for (Operation *user : hostEvalArg.getUsers()) {
@@ -2025,12 +2026,20 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
20252026
return capturedOp;
20262027
}
20272028

2028-
llvm::omp::OMPTgtExecModeFlags TargetOp::getKernelExecFlags() {
2029+
llvm::omp::OMPTgtExecModeFlags
2030+
TargetOp::getKernelExecFlags(Operation *capturedOp) {
20292031
using namespace llvm::omp;
20302032

2033+
// A non-null captured op is only valid if it resides inside of a TargetOp
2034+
// and is the result of calling getInnermostCapturedOmpOp() on it.
2035+
TargetOp targetOp =
2036+
capturedOp ? capturedOp->getParentOfType<TargetOp>() : nullptr;
2037+
assert((!capturedOp ||
2038+
(targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
2039+
"unexpected captured op");
2040+
20312041
// Make sure this region is capturing a loop. Otherwise, it's a generic
20322042
// kernel.
2033-
Operation *capturedOp = getInnermostCapturedOmpOp();
20342043
if (!isa_and_present<LoopNestOp>(capturedOp))
20352044
return OMP_TGT_EXEC_MODE_GENERIC;
20362045

@@ -2054,7 +2063,7 @@ llvm::omp::OMPTgtExecModeFlags TargetOp::getKernelExecFlags() {
20542063
if (!isa_and_present<TeamsOp>(teamsOp))
20552064
return OMP_TGT_EXEC_MODE_GENERIC;
20562065

2057-
if (teamsOp->getParentOp() == *this)
2066+
if (teamsOp->getParentOp() == targetOp.getOperation())
20582067
return OMP_TGT_EXEC_MODE_GENERIC_SPMD;
20592068
}
20602069

@@ -2075,7 +2084,7 @@ llvm::omp::OMPTgtExecModeFlags TargetOp::getKernelExecFlags() {
20752084
if (!isa_and_present<TeamsOp>(teamsOp))
20762085
return OMP_TGT_EXEC_MODE_GENERIC;
20772086

2078-
if (teamsOp->getParentOp() == *this)
2087+
if (teamsOp->getParentOp() == targetOp.getOperation())
20792088
return OMP_TGT_EXEC_MODE_SPMD;
20802089
}
20812090

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4557,11 +4557,10 @@ static std::optional<int64_t> extractConstInteger(Value value) {
45574557
/// function for the target region, so that they can be used to initialize the
45584558
/// corresponding global `ConfigurationEnvironmentTy` structure.
45594559
static void
4560-
initTargetDefaultAttrs(omp::TargetOp targetOp,
4560+
initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
45614561
llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
45624562
bool isTargetDevice) {
45634563
// TODO: Handle constant 'if' clauses.
4564-
Operation *capturedOp = targetOp.getInnermostCapturedOmpOp();
45654564

45664565
Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
45674566
if (!isTargetDevice) {
@@ -4643,7 +4642,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp,
46434642
combinedMaxThreadsVal = maxThreadsVal;
46444643

46454644
// Update kernel bounds structure for the `OpenMPIRBuilder` to use.
4646-
attrs.ExecFlags = targetOp.getKernelExecFlags();
4645+
attrs.ExecFlags = targetOp.getKernelExecFlags(capturedOp);
46474646
attrs.MinTeams = minTeamsVal;
46484647
attrs.MaxTeams.front() = maxTeamsVal;
46494648
attrs.MinThreads = 1;
@@ -4659,10 +4658,9 @@ initTargetDefaultAttrs(omp::TargetOp targetOp,
46594658
static void
46604659
initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
46614660
LLVM::ModuleTranslation &moduleTranslation,
4662-
omp::TargetOp targetOp,
4661+
omp::TargetOp targetOp, Operation *capturedOp,
46634662
llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
4664-
omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(
4665-
targetOp.getInnermostCapturedOmpOp());
4663+
omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(capturedOp);
46664664
unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
46674665

46684666
Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
@@ -4689,7 +4687,8 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
46894687
if (numThreads)
46904688
attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);
46914689

4692-
if (targetOp.getKernelExecFlags() != llvm::omp::OMP_TGT_EXEC_MODE_GENERIC) {
4690+
if (targetOp.getKernelExecFlags(capturedOp) !=
4691+
llvm::omp::OMP_TGT_EXEC_MODE_GENERIC) {
46934692
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
46944693
attrs.LoopTripCount = nullptr;
46954694

@@ -4938,12 +4937,15 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
49384937

49394938
llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
49404939
llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
4941-
initTargetDefaultAttrs(targetOp, defaultAttrs, isTargetDevice);
4940+
Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
4941+
initTargetDefaultAttrs(targetOp, targetCapturedOp, defaultAttrs,
4942+
isTargetDevice);
49424943

49434944
// Collect host-evaluated values needed to properly launch the kernel from the
49444945
// host.
49454946
if (!isTargetDevice)
4946-
initTargetRuntimeAttrs(builder, moduleTranslation, targetOp, runtimeAttrs);
4947+
initTargetRuntimeAttrs(builder, moduleTranslation, targetOp,
4948+
targetCapturedOp, runtimeAttrs);
49474949

49484950
// Pass host-evaluated values as parameters to the kernel / host fallback,
49494951
// except if they are constants. In any case, map the MLIR block argument to

0 commit comments

Comments
 (0)