Skip to content

Commit fbc3f7b

Browse files
authored
[OpenMP][Offload][AMDGPU] Add envar for setting CU multiplier (llvm#1143)
2 parents e938489 + 482b93c commit fbc3f7b

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed

offload/plugins-nextgen/amdgpu/src/rtl.cpp

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1095,11 +1095,18 @@ struct AMDGPUKernelTy : public GenericKernelTy {
10951095
// Honor OMP_NUM_TEAMS environment variable for XteamReduction kernel
10961096
// type, if possible.
10971097
int32_t NumTeamsEnvVar = GenericDevice.getOMPNumTeams();
1098+
// CU mulitiplier from envar.
1099+
uint32_t EnvarCUMultiplier = GenericDevice.getXTeamRedTeamsPerCU();
10981100

10991101
if (GenericDevice.isFastReductionEnabled()) {
11001102
// When fast reduction is enabled, the number of teams is capped by
11011103
// the MaxCUMultiplier constant.
1102-
MaxNumGroups = DeviceNumCUs * llvm::omp::xteam_red::MaxCUMultiplier;
1104+
// When envar is enabled, use it for computing MaxNumGroup.
1105+
if (EnvarCUMultiplier > 0)
1106+
MaxNumGroups = DeviceNumCUs * EnvarCUMultiplier;
1107+
else
1108+
MaxNumGroups = DeviceNumCUs * llvm::omp::xteam_red::MaxCUMultiplier;
1109+
11031110
} else {
11041111
// When fast reduction is not enabled, the number of teams is capped
11051112
// by the metadata that clang CodeGen created. The number of teams
@@ -1110,7 +1117,13 @@ struct AMDGPUKernelTy : public GenericKernelTy {
11101117
// ConstWGSize is the block size that CodeGen used.
11111118
uint32_t CUMultiplier =
11121119
llvm::omp::xteam_red::getXteamRedCUMultiplier(ConstWGSize);
1113-
MaxNumGroups = DeviceNumCUs * CUMultiplier;
1120+
1121+
if (EnvarCUMultiplier > 0) {
1122+
MaxNumGroups =
1123+
DeviceNumCUs * std::min(CUMultiplier, EnvarCUMultiplier);
1124+
} else {
1125+
MaxNumGroups = DeviceNumCUs * CUMultiplier;
1126+
}
11141127
}
11151128

11161129
// If envar OMPX_XTEAMREDUCTION_OCCUPANCY_BASED_OPT is set and no
@@ -2915,6 +2928,8 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
29152928
"LIBOMPTARGET_AMDGPU_GENERIC_SPMD_TEAMS_PER_CU", 6),
29162929
OMPX_BigJumpLoopTeamsPerCU(
29172930
"LIBOMPTARGET_AMDGPU_BIG_JUMP_LOOP_TEAMS_PER_CU", 0),
2931+
OMPX_XTeamRedTeamsPerCU("LIBOMPTARGET_AMDGPU_XTEAM_RED_TEAMS_PER_CU",
2932+
0),
29182933
OMPX_BigJumpLoopMaxTotalTeams(
29192934
"LIBOMPTARGET_AMDGPU_BIG_JUMP_LOOP_MAX_TOTAL_TEAMS", 1024 * 1024),
29202935
OMPX_LowTripCount("LIBOMPTARGET_AMDGPU_LOW_TRIPCOUNT", 9000),
@@ -2980,6 +2995,9 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
29802995
virtual uint32_t getOMPXBigJumpLoopTeamsPerCU() const override {
29812996
return OMPX_BigJumpLoopTeamsPerCU;
29822997
}
2998+
virtual uint32_t getXTeamRedTeamsPerCU() const override {
2999+
return OMPX_XTeamRedTeamsPerCU;
3000+
}
29833001
virtual uint32_t getOMPXBigJumpLoopMaxTotalTeams() const override {
29843002
return OMPX_BigJumpLoopMaxTotalTeams;
29853003
}
@@ -4427,6 +4445,12 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
44274445
/// OMPX_BigJumpLoopTeamsPerCU * #CUs.
44284446
UInt32Envar OMPX_BigJumpLoopTeamsPerCU;
44294447

4448+
/// Envar for controlling the number of teams relative to the number of
4449+
/// compute units (CUs) for cross-team-reduction kernels. 0 indicates that
4450+
/// this value is not specified. If non-zero, the number of teams =
4451+
/// OMPX_XTeamRedTeamsPerCU * #CUs.
4452+
UInt32Envar OMPX_XTeamRedTeamsPerCU;
4453+
44304454
/// Envar controlling the maximum number of teams per device for
44314455
/// Big-Jump-Loop kernels.
44324456
UInt32Envar OMPX_BigJumpLoopMaxTotalTeams;

offload/plugins-nextgen/common/include/PluginInterface.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1017,6 +1017,9 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
10171017
virtual uint32_t getOMPXBigJumpLoopTeamsPerCU() const {
10181018
llvm_unreachable("Unimplemented");
10191019
}
1020+
virtual uint32_t getXTeamRedTeamsPerCU() const {
1021+
llvm_unreachable("Unimplemented");
1022+
}
10201023
virtual uint32_t getOMPXBigJumpLoopMaxTotalTeams() const {
10211024
llvm_unreachable("Unimplemented");
10221025
}

0 commit comments

Comments
 (0)