Skip to content

Commit ea0f3a1

Browse files
authored
Merge pull request #2277 from igchor/cooperative_fix
[Spec] fix urKernelSuggestMaxCooperativeGroupCountExp
2 parents bb64b3e + 4a89e1c commit ea0f3a1

File tree

16 files changed

+69
-17
lines changed

16 files changed

+69
-17
lines changed

include/ur_api.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9543,13 +9543,15 @@ urEnqueueCooperativeKernelLaunchExp(
95439543
/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC
95449544
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
95459545
/// + `NULL == hKernel`
9546+
/// + `NULL == hDevice`
95469547
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
95479548
/// + `NULL == pLocalWorkSize`
95489549
/// + `NULL == pGroupCountRet`
95499550
/// - ::UR_RESULT_ERROR_INVALID_KERNEL
95509551
UR_APIEXPORT ur_result_t UR_APICALL
95519552
urKernelSuggestMaxCooperativeGroupCountExp(
95529553
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
9554+
ur_device_handle_t hDevice, ///< [in] handle of the device object
95539555
uint32_t workDim, ///< [in] number of dimensions, from 1 to 3, to specify the work-group
95549556
///< work-items
95559557
const size_t *pLocalWorkSize, ///< [in] pointer to an array of workDim unsigned values that specify the
@@ -11090,6 +11092,7 @@ typedef struct ur_kernel_set_specialization_constants_params_t {
1109011092
/// allowing the callback the ability to modify the parameter's value
1109111093
typedef struct ur_kernel_suggest_max_cooperative_group_count_exp_params_t {
1109211094
ur_kernel_handle_t *phKernel;
11095+
ur_device_handle_t *phDevice;
1109311096
uint32_t *pworkDim;
1109411097
const size_t **ppLocalWorkSize;
1109511098
size_t *pdynamicSharedMemorySize;

include/ur_ddi.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -651,6 +651,7 @@ typedef ur_result_t(UR_APICALL *ur_pfnGetKernelProcAddrTable_t)(
651651
/// @brief Function-pointer for urKernelSuggestMaxCooperativeGroupCountExp
652652
typedef ur_result_t(UR_APICALL *ur_pfnKernelSuggestMaxCooperativeGroupCountExp_t)(
653653
ur_kernel_handle_t,
654+
ur_device_handle_t,
654655
uint32_t,
655656
const size_t *,
656657
size_t,

include/ur_print.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13203,6 +13203,12 @@ inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct
1320313203
ur::details::printPtr(os,
1320413204
*(params->phKernel));
1320513205

13206+
os << ", ";
13207+
os << ".hDevice = ";
13208+
13209+
ur::details::printPtr(os,
13210+
*(params->phDevice));
13211+
1320613212
os << ", ";
1320713213
os << ".workDim = ";
1320813214

scripts/core/exp-cooperative-kernels.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,9 @@ params:
7878
- type: $x_kernel_handle_t
7979
name: hKernel
8080
desc: "[in] handle of the kernel object"
81+
- type: $x_device_handle_t
82+
name: hDevice
83+
desc: "[in] handle of the device object"
8184
- type: uint32_t
8285
name: workDim
8386
desc: "[in] number of dimensions, from 1 to 3, to specify the work-group work-items"

source/adapters/cuda/kernel.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,10 +190,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetNativeHandle(
190190
}
191191

192192
UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
193-
ur_kernel_handle_t hKernel, uint32_t workDim, const size_t *pLocalWorkSize,
194-
size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) {
193+
ur_kernel_handle_t hKernel, ur_device_handle_t hDevice, uint32_t workDim,
194+
const size_t *pLocalWorkSize, size_t dynamicSharedMemorySize,
195+
uint32_t *pGroupCountRet) {
195196
UR_ASSERT(hKernel, UR_RESULT_ERROR_INVALID_KERNEL);
196197

198+
std::ignore = hDevice;
199+
197200
size_t localWorkSize = pLocalWorkSize[0];
198201
localWorkSize *= (workDim >= 2 ? pLocalWorkSize[1] : 1);
199202
localWorkSize *= (workDim == 3 ? pLocalWorkSize[2] : 1);

source/adapters/hip/kernel.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,11 @@ urKernelGetNativeHandle(ur_kernel_handle_t, ur_native_handle_t *) {
169169
}
170170

171171
UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
172-
ur_kernel_handle_t hKernel, uint32_t workDim, const size_t *pLocalWorkSize,
173-
size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) {
172+
ur_kernel_handle_t hKernel, ur_device_handle_t hDevice, uint32_t workDim,
173+
const size_t *pLocalWorkSize, size_t dynamicSharedMemorySize,
174+
uint32_t *pGroupCountRet) {
174175
std::ignore = hKernel;
176+
std::ignore = hDevice;
175177
std::ignore = workDim;
176178
std::ignore = pLocalWorkSize;
177179
std::ignore = dynamicSharedMemorySize;

source/adapters/level_zero/kernel.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1054,8 +1054,9 @@ ur_result_t urKernelGetNativeHandle(
10541054
}
10551055

10561056
ur_result_t urKernelSuggestMaxCooperativeGroupCountExp(
1057-
ur_kernel_handle_t hKernel, uint32_t workDim, const size_t *pLocalWorkSize,
1058-
size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) {
1057+
ur_kernel_handle_t hKernel, ur_device_handle_t hDevice, uint32_t workDim,
1058+
const size_t *pLocalWorkSize, size_t dynamicSharedMemorySize,
1059+
uint32_t *pGroupCountRet) {
10591060
(void)dynamicSharedMemorySize;
10601061
std::shared_lock<ur_shared_mutex> Guard(hKernel->Mutex);
10611062

@@ -1066,8 +1067,10 @@ ur_result_t urKernelSuggestMaxCooperativeGroupCountExp(
10661067
ZE2UR_CALL(zeKernelSetGroupSize, (hKernel->ZeKernel, WG[0], WG[1], WG[2]));
10671068

10681069
uint32_t TotalGroupCount = 0;
1070+
ze_kernel_handle_t ZeKernel;
1071+
UR_CALL(getZeKernel(hDevice->ZeDevice, hKernel, &ZeKernel));
10691072
ZE2UR_CALL(zeKernelSuggestMaxCooperativeGroupCount,
1070-
(hKernel->ZeKernel, &TotalGroupCount));
1073+
(ZeKernel, &TotalGroupCount));
10711074
*pGroupCountRet = TotalGroupCount;
10721075
return UR_RESULT_SUCCESS;
10731076
}

source/adapters/level_zero/ur_interface_loader.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -691,8 +691,9 @@ ur_result_t urEnqueueCooperativeKernelLaunchExp(
691691
const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
692692
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent);
693693
ur_result_t urKernelSuggestMaxCooperativeGroupCountExp(
694-
ur_kernel_handle_t hKernel, uint32_t workDim, const size_t *pLocalWorkSize,
695-
size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet);
694+
ur_kernel_handle_t hKernel, ur_device_handle_t hDevice, uint32_t workDim,
695+
const size_t *pLocalWorkSize, size_t dynamicSharedMemorySize,
696+
uint32_t *pGroupCountRet);
696697
ur_result_t urEnqueueTimestampRecordingExp(
697698
ur_queue_handle_t hQueue, bool blocking, uint32_t numEventsInWaitList,
698699
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent);

source/adapters/level_zero/v2/api.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -485,8 +485,9 @@ ur_result_t urCommandBufferCommandGetInfoExp(
485485
}
486486

487487
ur_result_t urKernelSuggestMaxCooperativeGroupCountExp(
488-
ur_kernel_handle_t hKernel, uint32_t workDim, const size_t *pLocalWorkSize,
489-
size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) {
488+
ur_kernel_handle_t hKernel, ur_device_handle_t hDevice, uint32_t workDim,
489+
const size_t *pLocalWorkSize, size_t dynamicSharedMemorySize,
490+
uint32_t *pGroupCountRet) {
490491
logger::error("{} function not implemented!", __FUNCTION__);
491492
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
492493
}

source/adapters/mock/ur_mockddi.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10057,6 +10057,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
1005710057
/// @brief Intercept function for urKernelSuggestMaxCooperativeGroupCountExp
1005810058
__urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
1005910059
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
10060+
ur_device_handle_t hDevice, ///< [in] handle of the device object
1006010061
uint32_t
1006110062
workDim, ///< [in] number of dimensions, from 1 to 3, to specify the work-group
1006210063
///< work-items
@@ -10072,7 +10073,11 @@ __urdlllocal ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
1007210073
ur_result_t result = UR_RESULT_SUCCESS;
1007310074

1007410075
ur_kernel_suggest_max_cooperative_group_count_exp_params_t params = {
10075-
&hKernel, &workDim, &pLocalWorkSize, &dynamicSharedMemorySize,
10076+
&hKernel,
10077+
&hDevice,
10078+
&workDim,
10079+
&pLocalWorkSize,
10080+
&dynamicSharedMemorySize,
1007610081
&pGroupCountRet};
1007710082

1007810083
auto beforeCallback = reinterpret_cast<ur_mock_callback_t>(

0 commit comments

Comments
 (0)