@@ -789,9 +789,17 @@ int32_t __tgt_rtl_init_device(int device_id) {
789
789
DP (" Default number of teams set according to environment %d\n " ,
790
790
DeviceInfo.EnvNumTeams );
791
791
} else {
792
- DeviceInfo.NumTeams [device_id] = RTLDeviceInfoTy::DefaultNumTeams;
793
- DP (" Default number of teams set according to library's default %d\n " ,
794
- RTLDeviceInfoTy::DefaultNumTeams);
792
+ char *TeamsPerCUEnvStr = getenv (" OMP_TARGET_TEAMS_PER_PROC" );
793
+ int TeamsPerCU = 1 ; // default number of teams per CU is 1
794
+ if (TeamsPerCUEnvStr) {
795
+ TeamsPerCU = std::stoi (TeamsPerCUEnvStr);
796
+ }
797
+
798
+ DeviceInfo.NumTeams [device_id] =
799
+ TeamsPerCU * DeviceInfo.ComputeUnits [device_id];
800
+ DP (" Default number of teams = %d * number of compute units %d\n " ,
801
+ TeamsPerCU,
802
+ DeviceInfo.ComputeUnits [device_id]);
795
803
}
796
804
797
805
if (DeviceInfo.NumTeams [device_id] > DeviceInfo.GroupsPerDevice [device_id]) {
@@ -1548,11 +1556,12 @@ int32_t __tgt_rtl_data_delete(int device_id, void *tgt_ptr) {
1548
1556
// loop_tripcount.
1549
1557
void getLaunchVals (int &threadsPerGroup, int &num_groups, int ConstWGSize,
1550
1558
int ExecutionMode, int EnvTeamLimit, int EnvNumTeams,
1551
- int num_teams, int thread_limit, uint64_t loop_tripcount) {
1559
+ int num_teams, int thread_limit, uint64_t loop_tripcount,
1560
+ int32_t device_id) {
1552
1561
1553
1562
int Max_Teams = DeviceInfo.EnvMaxTeamsDefault > 0
1554
1563
? DeviceInfo.EnvMaxTeamsDefault
1555
- : DeviceInfo.Max_Teams ;
1564
+ : DeviceInfo.NumTeams [device_id] ;
1556
1565
if (Max_Teams > DeviceInfo.HardTeamLimit )
1557
1566
Max_Teams = DeviceInfo.HardTeamLimit ;
1558
1567
@@ -1752,7 +1761,8 @@ int32_t __tgt_rtl_run_target_team_region_locked(
1752
1761
DeviceInfo.EnvNumTeams ,
1753
1762
num_teams, // From run_region arg
1754
1763
thread_limit, // From run_region arg
1755
- loop_tripcount // From run_region arg
1764
+ loop_tripcount, // From run_region arg
1765
+ KernelInfo->device_id
1756
1766
);
1757
1767
1758
1768
if (print_kernel_trace == 4 )
0 commit comments