Skip to content

Commit 72e80a4

Browse files
authored
Merge pull request #2316 from 0x12CC/coop_kernel_query
Change `urSuggestMaxCooperativeGroupCountExp` to accept ND size parameter
2 parents 6e5d0e6 + 9c7e56c commit 72e80a4

File tree

16 files changed, with 105 additions and 41 deletions.

include/ur_api.h

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -9486,13 +9486,17 @@ urEnqueueCooperativeKernelLaunchExp(
94869486
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
94879487
/// + `NULL == hKernel`
94889488
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
9489+
/// + `NULL == pLocalWorkSize`
94899490
/// + `NULL == pGroupCountRet`
94909491
/// - ::UR_RESULT_ERROR_INVALID_KERNEL
94919492
UR_APIEXPORT ur_result_t UR_APICALL
94929493
urKernelSuggestMaxCooperativeGroupCountExp(
94939494
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
9494-
size_t localWorkSize, ///< [in] number of local work-items that will form a work-group when the
9495-
///< kernel is launched
9495+
uint32_t workDim, ///< [in] number of dimensions, from 1 to 3, to specify the work-group
9496+
///< work-items
9497+
const size_t *pLocalWorkSize, ///< [in] pointer to an array of workDim unsigned values that specify the
9498+
///< number of local work-items forming a work-group that will execute the
9499+
///< kernel function.
94969500
size_t dynamicSharedMemorySize, ///< [in] size of dynamic shared memory, for each work-group, in bytes,
94979501
///< that will be used when the kernel is launched
94989502
uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups
@@ -11028,7 +11032,8 @@ typedef struct ur_kernel_set_specialization_constants_params_t {
1102811032
/// allowing the callback the ability to modify the parameter's value
1102911033
typedef struct ur_kernel_suggest_max_cooperative_group_count_exp_params_t {
1103011034
ur_kernel_handle_t *phKernel;
11031-
size_t *plocalWorkSize;
11035+
uint32_t *pworkDim;
11036+
const size_t **ppLocalWorkSize;
1103211037
size_t *pdynamicSharedMemorySize;
1103311038
uint32_t **ppGroupCountRet;
1103411039
} ur_kernel_suggest_max_cooperative_group_count_exp_params_t;

include/ur_ddi.h

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -651,7 +651,8 @@ typedef ur_result_t(UR_APICALL *ur_pfnGetKernelProcAddrTable_t)(
651651
/// @brief Function-pointer for urKernelSuggestMaxCooperativeGroupCountExp
652652
typedef ur_result_t(UR_APICALL *ur_pfnKernelSuggestMaxCooperativeGroupCountExp_t)(
653653
ur_kernel_handle_t,
654-
size_t,
654+
uint32_t,
655+
const size_t *,
655656
size_t,
656657
uint32_t *);
657658

include/ur_print.hpp

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -13074,9 +13074,15 @@ inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct
1307413074
*(params->phKernel));
1307513075

1307613076
os << ", ";
13077-
os << ".localWorkSize = ";
13077+
os << ".workDim = ";
13078+
13079+
os << *(params->pworkDim);
13080+
13081+
os << ", ";
13082+
os << ".pLocalWorkSize = ";
1307813083

13079-
os << *(params->plocalWorkSize);
13084+
ur::details::printPtr(os,
13085+
*(params->ppLocalWorkSize));
1308013086

1308113087
os << ", ";
1308213088
os << ".dynamicSharedMemorySize = ";

scripts/core/exp-cooperative-kernels.yml

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -78,9 +78,13 @@ params:
7878
- type: $x_kernel_handle_t
7979
name: hKernel
8080
desc: "[in] handle of the kernel object"
81-
- type: size_t
82-
name: localWorkSize
83-
desc: "[in] number of local work-items that will form a work-group when the kernel is launched"
81+
- type: uint32_t
82+
name: workDim
83+
desc: "[in] number of dimensions, from 1 to 3, to specify the work-group work-items"
84+
- type: "const size_t*"
85+
name: pLocalWorkSize
86+
desc: |
87+
[in] pointer to an array of workDim unsigned values that specify the number of local work-items forming a work-group that will execute the kernel function.
8488
- type: size_t
8589
name: dynamicSharedMemorySize
8690
desc: "[in] size of dynamic shared memory, for each work-group, in bytes, that will be used when the kernel is launched"

source/adapters/cuda/kernel.cpp

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -190,10 +190,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetNativeHandle(
190190
}
191191

192192
UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
193-
ur_kernel_handle_t hKernel, size_t localWorkSize,
193+
ur_kernel_handle_t hKernel, uint32_t workDim, const size_t *pLocalWorkSize,
194194
size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) {
195195
UR_ASSERT(hKernel, UR_RESULT_ERROR_INVALID_KERNEL);
196196

197+
size_t localWorkSize = pLocalWorkSize[0];
198+
localWorkSize *= (workDim >= 2 ? pLocalWorkSize[1] : 1);
199+
localWorkSize *= (workDim == 3 ? pLocalWorkSize[2] : 1);
200+
197201
// We need to set the active current device for this kernel explicitly here,
198202
// because the occupancy querying API does not take device parameter.
199203
ur_device_handle_t Device = hKernel->getProgram()->getDevice();

source/adapters/hip/kernel.cpp

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -169,10 +169,11 @@ urKernelGetNativeHandle(ur_kernel_handle_t, ur_native_handle_t *) {
169169
}
170170

171171
UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
172-
ur_kernel_handle_t hKernel, size_t localWorkSize,
172+
ur_kernel_handle_t hKernel, uint32_t workDim, const size_t *pLocalWorkSize,
173173
size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) {
174174
std::ignore = hKernel;
175-
std::ignore = localWorkSize;
175+
std::ignore = workDim;
176+
std::ignore = pLocalWorkSize;
176177
std::ignore = dynamicSharedMemorySize;
177178
std::ignore = pGroupCountRet;
178179
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;

source/adapters/level_zero/kernel.cpp

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1054,11 +1054,17 @@ ur_result_t urKernelGetNativeHandle(
10541054
}
10551055

10561056
ur_result_t urKernelSuggestMaxCooperativeGroupCountExp(
1057-
ur_kernel_handle_t hKernel, size_t localWorkSize,
1057+
ur_kernel_handle_t hKernel, uint32_t workDim, const size_t *pLocalWorkSize,
10581058
size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) {
1059-
(void)localWorkSize;
10601059
(void)dynamicSharedMemorySize;
10611060
std::shared_lock<ur_shared_mutex> Guard(hKernel->Mutex);
1061+
1062+
uint32_t WG[3];
1063+
WG[0] = ur_cast<uint32_t>(pLocalWorkSize[0]);
1064+
WG[1] = workDim >= 2 ? ur_cast<uint32_t>(pLocalWorkSize[1]) : 1;
1065+
WG[2] = workDim == 3 ? ur_cast<uint32_t>(pLocalWorkSize[2]) : 1;
1066+
ZE2UR_CALL(zeKernelSetGroupSize, (hKernel->ZeKernel, WG[0], WG[1], WG[2]));
1067+
10621068
uint32_t TotalGroupCount = 0;
10631069
ZE2UR_CALL(zeKernelSuggestMaxCooperativeGroupCount,
10641070
(hKernel->ZeKernel, &TotalGroupCount));

source/adapters/level_zero/ur_interface_loader.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -687,7 +687,7 @@ ur_result_t urEnqueueCooperativeKernelLaunchExp(
687687
const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
688688
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent);
689689
ur_result_t urKernelSuggestMaxCooperativeGroupCountExp(
690-
ur_kernel_handle_t hKernel, size_t localWorkSize,
690+
ur_kernel_handle_t hKernel, uint32_t workDim, const size_t *pLocalWorkSize,
691691
size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet);
692692
ur_result_t urEnqueueTimestampRecordingExp(
693693
ur_queue_handle_t hQueue, bool blocking, uint32_t numEventsInWaitList,

source/adapters/level_zero/v2/api.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -568,7 +568,7 @@ ur_result_t urCommandBufferCommandGetInfoExp(
568568
}
569569

570570
ur_result_t urKernelSuggestMaxCooperativeGroupCountExp(
571-
ur_kernel_handle_t hKernel, size_t localWorkSize,
571+
ur_kernel_handle_t hKernel, uint32_t workDim, const size_t *pLocalWorkSize,
572572
size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) {
573573
logger::error("{} function not implemented!", __FUNCTION__);
574574
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;

source/adapters/mock/ur_mockddi.cpp

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -10003,9 +10003,13 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
1000310003
/// @brief Intercept function for urKernelSuggestMaxCooperativeGroupCountExp
1000410004
__urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
1000510005
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
10006-
size_t
10007-
localWorkSize, ///< [in] number of local work-items that will form a work-group when the
10008-
///< kernel is launched
10006+
uint32_t
10007+
workDim, ///< [in] number of dimensions, from 1 to 3, to specify the work-group
10008+
///< work-items
10009+
const size_t *
10010+
pLocalWorkSize, ///< [in] pointer to an array of workDim unsigned values that specify the
10011+
///< number of local work-items forming a work-group that will execute the
10012+
///< kernel function.
1000910013
size_t
1001010014
dynamicSharedMemorySize, ///< [in] size of dynamic shared memory, for each work-group, in bytes,
1001110015
///< that will be used when the kernel is launched
@@ -10014,7 +10018,8 @@ __urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
1001410018
ur_result_t result = UR_RESULT_SUCCESS;
1001510019

1001610020
ur_kernel_suggest_max_cooperative_group_count_exp_params_t params = {
10017-
&hKernel, &localWorkSize, &dynamicSharedMemorySize, &pGroupCountRet};
10021+
&hKernel, &workDim, &pLocalWorkSize, &dynamicSharedMemorySize,
10022+
&pGroupCountRet};
1001810023

1001910024
auto beforeCallback = reinterpret_cast<ur_mock_callback_t>(
1002010025
mock::getCallbacks().get_before_callback(

0 commit comments

Comments
 (0)