Skip to content

Commit 3fd11f1

Browse files
authored
Merge pull request #1246 from 0x12CC/cooperative_kernel_functions
[UR] Add default implementation for cooperative kernel functions
2 parents 24078c2 + 8a8d704 commit 3fd11f1

21 files changed

+176
-24
lines changed

include/ur_api.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8692,8 +8692,12 @@ urEnqueueCooperativeKernelLaunchExp(
86928692
/// - ::UR_RESULT_ERROR_INVALID_KERNEL
86938693
UR_APIEXPORT ur_result_t UR_APICALL
86948694
urKernelSuggestMaxCooperativeGroupCountExp(
8695-
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
8696-
uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups
8695+
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
8696+
size_t localWorkSize, ///< [in] number of local work-items that will form a work-group when the
8697+
///< kernel is launched
8698+
size_t dynamicSharedMemorySize, ///< [in] size of dynamic shared memory, for each work-group, in bytes,
8699+
///< that will be used when the kernel is launched
8700+
uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups
86978701
);
86988702

86998703
#if !defined(__GNUC__)
@@ -9641,6 +9645,8 @@ typedef struct ur_kernel_set_specialization_constants_params_t {
96419645
/// allowing the callback the ability to modify the parameter's value
96429646
typedef struct ur_kernel_suggest_max_cooperative_group_count_exp_params_t {
96439647
ur_kernel_handle_t *phKernel;
9648+
size_t *plocalWorkSize;
9649+
size_t *pdynamicSharedMemorySize;
96449650
uint32_t **ppGroupCountRet;
96459651
} ur_kernel_suggest_max_cooperative_group_count_exp_params_t;
96469652

include/ur_ddi.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -627,6 +627,8 @@ typedef ur_result_t(UR_APICALL *ur_pfnGetKernelProcAddrTable_t)(
627627
/// @brief Function-pointer for urKernelSuggestMaxCooperativeGroupCountExp
628628
typedef ur_result_t(UR_APICALL *ur_pfnKernelSuggestMaxCooperativeGroupCountExp_t)(
629629
ur_kernel_handle_t,
630+
size_t,
631+
size_t,
630632
uint32_t *);
631633

632634
///////////////////////////////////////////////////////////////////////////////

include/ur_print.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11399,6 +11399,16 @@ inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct
1139911399
ur::details::printPtr(os,
1140011400
*(params->phKernel));
1140111401

11402+
os << ", ";
11403+
os << ".localWorkSize = ";
11404+
11405+
os << *(params->plocalWorkSize);
11406+
11407+
os << ", ";
11408+
os << ".dynamicSharedMemorySize = ";
11409+
11410+
os << *(params->pdynamicSharedMemorySize);
11411+
1140211412
os << ", ";
1140311413
os << ".pGroupCountRet = ";
1140411414

scripts/core/exp-cooperative-kernels.yml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,12 @@ params:
7878
- type: $x_kernel_handle_t
7979
name: hKernel
8080
desc: "[in] handle of the kernel object"
81+
- type: size_t
82+
name: localWorkSize
83+
desc: "[in] number of local work-items that will form a work-group when the kernel is launched"
84+
- type: size_t
85+
name: dynamicSharedMemorySize
86+
desc: "[in] size of dynamic shared memory, for each work-group, in bytes, that will be used when the kernel is launched"
8187
- type: "uint32_t*"
8288
name: "pGroupCountRet"
8389
desc: "[out] pointer to maximum number of groups"

source/adapters/cuda/enqueue.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
494494
return Result;
495495
}
496496

497+
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
498+
ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
499+
const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
500+
const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
501+
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
502+
return urEnqueueKernelLaunch(hQueue, hKernel, workDim, pGlobalWorkOffset,
503+
pGlobalWorkSize, pLocalWorkSize,
504+
numEventsInWaitList, phEventWaitList, phEvent);
505+
}
506+
497507
/// Set parameters for general 3D memory copy.
498508
/// If the source and/or destination is on the device, SrcPtr and/or DstPtr
499509
/// must be a pointer to a CUdeviceptr

source/adapters/cuda/kernel.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetNativeHandle(
169169
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
170170
}
171171

172+
UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
173+
ur_kernel_handle_t hKernel, size_t localWorkSize,
174+
size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) {
175+
(void)hKernel;
176+
(void)localWorkSize;
177+
(void)dynamicSharedMemorySize;
178+
*pGroupCountRet = 1;
179+
return UR_RESULT_SUCCESS;
180+
}
181+
172182
UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgValue(
173183
ur_kernel_handle_t hKernel, uint32_t argIndex, size_t argSize,
174184
const ur_kernel_arg_value_properties_t *pProperties,

source/adapters/cuda/ur_interface_loader.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,8 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueExpProcAddrTable(
404404
return result;
405405
}
406406

407-
pDdiTable->pfnCooperativeKernelLaunchExp = nullptr;
407+
pDdiTable->pfnCooperativeKernelLaunchExp =
408+
urEnqueueCooperativeKernelLaunchExp;
408409

409410
return UR_RESULT_SUCCESS;
410411
}
@@ -416,7 +417,8 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetKernelExpProcAddrTable(
416417
return result;
417418
}
418419

419-
pDdiTable->pfnSuggestMaxCooperativeGroupCountExp = nullptr;
420+
pDdiTable->pfnSuggestMaxCooperativeGroupCountExp =
421+
urKernelSuggestMaxCooperativeGroupCountExp;
420422

421423
return UR_RESULT_SUCCESS;
422424
}

source/adapters/hip/enqueue.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
465465
return Result;
466466
}
467467

468+
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
469+
ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
470+
const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
471+
const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
472+
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
473+
return urEnqueueKernelLaunch(hQueue, hKernel, workDim, pGlobalWorkOffset,
474+
pGlobalWorkSize, pLocalWorkSize,
475+
numEventsInWaitList, phEventWaitList, phEvent);
476+
}
477+
468478
/// Enqueues a wait on the given queue for all events.
469479
/// See \ref enqueueEventWait
470480
///

source/adapters/hip/kernel.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,16 @@ urKernelGetNativeHandle(ur_kernel_handle_t, ur_native_handle_t *) {
158158
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
159159
}
160160

161+
UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
162+
ur_kernel_handle_t hKernel, size_t localWorkSize,
163+
size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) {
164+
(void)hKernel;
165+
(void)localWorkSize;
166+
(void)dynamicSharedMemorySize;
167+
*pGroupCountRet = 1;
168+
return UR_RESULT_SUCCESS;
169+
}
170+
161171
UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgValue(
162172
ur_kernel_handle_t hKernel, uint32_t argIndex, size_t argSize,
163173
const ur_kernel_arg_value_properties_t *, const void *pArgValue) {

source/adapters/hip/ur_interface_loader.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,8 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueExpProcAddrTable(
374374
return result;
375375
}
376376

377-
pDdiTable->pfnCooperativeKernelLaunchExp = nullptr;
377+
pDdiTable->pfnCooperativeKernelLaunchExp =
378+
urEnqueueCooperativeKernelLaunchExp;
378379

379380
return UR_RESULT_SUCCESS;
380381
}
@@ -386,7 +387,8 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetKernelExpProcAddrTable(
386387
return result;
387388
}
388389

389-
pDdiTable->pfnSuggestMaxCooperativeGroupCountExp = nullptr;
390+
pDdiTable->pfnSuggestMaxCooperativeGroupCountExp =
391+
urKernelSuggestMaxCooperativeGroupCountExp;
390392

391393
return UR_RESULT_SUCCESS;
392394
}

0 commit comments

Comments
 (0)