Skip to content

Commit c9bc414

Browse files
[libomptarget][amdgpu] Let default number of teams equal number of CUs
1 parent e191d31 commit c9bc414

File tree

1 file changed

+16
-6
lines changed
  • openmp/libomptarget/plugins/amdgpu/src

1 file changed

+16
-6
lines changed

openmp/libomptarget/plugins/amdgpu/src/rtl.cpp

Lines changed: 16 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -789,9 +789,17 @@ int32_t __tgt_rtl_init_device(int device_id) {
789 789
DP("Default number of teams set according to environment %d\n",
790 790
DeviceInfo.EnvNumTeams);
791 791
} else {
792-
DeviceInfo.NumTeams[device_id] = RTLDeviceInfoTy::DefaultNumTeams;
793-
DP("Default number of teams set according to library's default %d\n",
794-
RTLDeviceInfoTy::DefaultNumTeams);
792+
char *TeamsPerCUEnvStr = getenv("OMP_TARGET_TEAMS_PER_PROC");
793+
int TeamsPerCU = 1; // default number of teams per CU is 1
794+
if (TeamsPerCUEnvStr) {
795+
TeamsPerCU = std::stoi(TeamsPerCUEnvStr);
796+
}
797+
798+
DeviceInfo.NumTeams[device_id] =
799+
TeamsPerCU * DeviceInfo.ComputeUnits[device_id];
800+
DP("Default number of teams = %d * number of compute units %d\n",
801+
TeamsPerCU,
802+
DeviceInfo.ComputeUnits[device_id]);
795 803
}
796 804

797 805
if (DeviceInfo.NumTeams[device_id] > DeviceInfo.GroupsPerDevice[device_id]) {
@@ -1548,11 +1556,12 @@ int32_t __tgt_rtl_data_delete(int device_id, void *tgt_ptr) {
1548 1556
// loop_tripcount.
1549 1557
void getLaunchVals(int &threadsPerGroup, int &num_groups, int ConstWGSize,
1550 1558
int ExecutionMode, int EnvTeamLimit, int EnvNumTeams,
1551-
int num_teams, int thread_limit, uint64_t loop_tripcount) {
1559+
int num_teams, int thread_limit, uint64_t loop_tripcount,
1560+
int32_t device_id) {
1552 1561

1553 1562
int Max_Teams = DeviceInfo.EnvMaxTeamsDefault > 0
1554 1563
? DeviceInfo.EnvMaxTeamsDefault
1555-
: DeviceInfo.Max_Teams;
1564+
: DeviceInfo.NumTeams[device_id];
1556 1565
if (Max_Teams > DeviceInfo.HardTeamLimit)
1557 1566
Max_Teams = DeviceInfo.HardTeamLimit;
1558 1567

@@ -1752,7 +1761,8 @@ int32_t __tgt_rtl_run_target_team_region_locked(
1752 1761
DeviceInfo.EnvNumTeams,
1753 1762
num_teams, // From run_region arg
1754 1763
thread_limit, // From run_region arg
1755-
loop_tripcount // From run_region arg
1764+
loop_tripcount, // From run_region arg
1765+
KernelInfo->device_id
1756 1766
);
1757 1767

1758 1768
if (print_kernel_trace == 4)

0 commit comments

Comments
 (0)