Skip to content

Commit b5cefa4

Browse files
committed
Update query in adapters
Signed-off-by: Michael Aziz <michael.aziz@intel.com>
1 parent 1a10dab commit b5cefa4

File tree

4 files changed

+17
-6
lines changed

4 files changed

+17
-6
lines changed

source/adapters/cuda/kernel.cpp

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -190,10 +190,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetNativeHandle(
190190
}
191191

192192
UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
193-
ur_kernel_handle_t hKernel, size_t localWorkSize,
193+
ur_kernel_handle_t hKernel, uint32_t workDim, const size_t *pLocalWorkSize,
194194
size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) {
195195
UR_ASSERT(hKernel, UR_RESULT_ERROR_INVALID_KERNEL);
196196

197+
size_t localWorkSize = pLocalWorkSize[0];
198+
localWorkSize *= (workDim >= 2 ? pLocalWorkSize[1] : 1);
199+
localWorkSize *= (workDim == 3 ? pLocalWorkSize[2] : 1);
200+
197201
// We need to set the active current device for this kernel explicitly here,
198202
// because the occupancy querying API does not take device parameter.
199203
ur_device_handle_t Device = hKernel->getProgram()->getDevice();

source/adapters/hip/kernel.cpp

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -169,10 +169,11 @@ urKernelGetNativeHandle(ur_kernel_handle_t, ur_native_handle_t *) {
169169
}
170170

171171
UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
172-
ur_kernel_handle_t hKernel, size_t localWorkSize,
172+
ur_kernel_handle_t hKernel, uint32_t workDim, const size_t *pLocalWorkSize,
173173
size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) {
174174
std::ignore = hKernel;
175-
std::ignore = localWorkSize;
175+
std::ignore = workDim;
176+
std::ignore = pLocalWorkSize;
176177
std::ignore = dynamicSharedMemorySize;
177178
std::ignore = pGroupCountRet;
178179
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;

source/adapters/level_zero/kernel.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1051,11 +1051,17 @@ ur_result_t urKernelGetNativeHandle(
10511051
}
10521052

10531053
ur_result_t urKernelSuggestMaxCooperativeGroupCountExp(
1054-
ur_kernel_handle_t hKernel, size_t localWorkSize,
1054+
ur_kernel_handle_t hKernel, uint32_t workDim, const size_t *pLocalWorkSize,
10551055
size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) {
1056-
(void)localWorkSize;
10571056
(void)dynamicSharedMemorySize;
10581057
std::shared_lock<ur_shared_mutex> Guard(hKernel->Mutex);
1058+
1059+
uint32_t WG[3];
1060+
WG[0] = ur_cast<uint32_t>(pLocalWorkSize[0]);
1061+
WG[1] = workDim >= 2 ? ur_cast<uint32_t>(pLocalWorkSize[1]) : 1;
1062+
WG[2] = workDim == 3 ? ur_cast<uint32_t>(pLocalWorkSize[2]) : 1;
1063+
ZE2UR_CALL(zeKernelSetGroupSize, (hKernel->ZeKernel, WG[0], WG[1], WG[2]));
1064+
10591065
uint32_t TotalGroupCount = 0;
10601066
ZE2UR_CALL(zeKernelSuggestMaxCooperativeGroupCount,
10611067
(hKernel->ZeKernel, &TotalGroupCount));

source/adapters/level_zero/ur_interface_loader.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -687,7 +687,7 @@ ur_result_t urEnqueueCooperativeKernelLaunchExp(
687687
const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
688688
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent);
689689
ur_result_t urKernelSuggestMaxCooperativeGroupCountExp(
690-
ur_kernel_handle_t hKernel, size_t localWorkSize,
690+
ur_kernel_handle_t hKernel, uint32_t workDim, const size_t *pLocalWorkSize,
691691
size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet);
692692
ur_result_t urEnqueueTimestampRecordingExp(
693693
ur_queue_handle_t hQueue, bool blocking, uint32_t numEventsInWaitList,

0 commit comments

Comments
 (0)