Skip to content

Commit 12c8312

Browse files
authored
Merge pull request #849 from 0x12CC/cooperative_kernels
Add cooperative kernels experimental feature
2 parents bcf2b2a + bb542f3 commit 12c8312

File tree

14 files changed

+1388
-0
lines changed

14 files changed

+1388
-0
lines changed

include/ur.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,8 @@ class ur_function_v(IntEnum):
198198
COMMAND_BUFFER_APPEND_MEM_BUFFER_WRITE_RECT_EXP = 190 ## Enumerator for ::urCommandBufferAppendMemBufferWriteRectExp
199199
COMMAND_BUFFER_APPEND_MEM_BUFFER_READ_RECT_EXP = 191## Enumerator for ::urCommandBufferAppendMemBufferReadRectExp
200200
COMMAND_BUFFER_APPEND_MEM_BUFFER_FILL_EXP = 192 ## Enumerator for ::urCommandBufferAppendMemBufferFillExp
201+
ENQUEUE_COOPERATIVE_KERNEL_LAUNCH_EXP = 193 ## Enumerator for ::urEnqueueCooperativeKernelLaunchExp
202+
KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT_EXP = 194## Enumerator for ::urKernelSuggestMaxCooperativeGroupCountExp
201203

202204
class ur_function_t(c_int):
203205
def __str__(self):
@@ -2272,6 +2274,11 @@ class ur_exp_command_buffer_sync_point_t(c_ulong):
22722274
class ur_exp_command_buffer_handle_t(c_void_p):
22732275
pass
22742276

2277+
###############################################################################
2278+
## @brief The extension string that identifies support for cooperative kernels
2279+
## and is returned when querying device extensions.
2280+
UR_COOPERATIVE_KERNELS_EXTENSION_STRING_EXP = "ur_exp_cooperative_kernels"
2281+
22752282
###############################################################################
22762283
## @brief Supported peer info
22772284
class ur_exp_peer_info_v(IntEnum):
@@ -2715,6 +2722,21 @@ class ur_kernel_dditable_t(Structure):
27152722
("pfnSetSpecializationConstants", c_void_p) ## _urKernelSetSpecializationConstants_t
27162723
]
27172724

2725+
###############################################################################
2726+
## @brief Function-pointer for urKernelSuggestMaxCooperativeGroupCountExp
2727+
if __use_win_types:
2728+
_urKernelSuggestMaxCooperativeGroupCountExp_t = WINFUNCTYPE( ur_result_t, ur_kernel_handle_t, POINTER(c_ulong) )
2729+
else:
2730+
_urKernelSuggestMaxCooperativeGroupCountExp_t = CFUNCTYPE( ur_result_t, ur_kernel_handle_t, POINTER(c_ulong) )
2731+
2732+
2733+
###############################################################################
2734+
## @brief Table of KernelExp function pointers
2735+
class ur_kernel_exp_dditable_t(Structure):
2736+
_fields_ = [
2737+
("pfnSuggestMaxCooperativeGroupCountExp", c_void_p) ## _urKernelSuggestMaxCooperativeGroupCountExp_t
2738+
]
2739+
27182740
###############################################################################
27192741
## @brief Function-pointer for urSamplerCreate
27202742
if __use_win_types:
@@ -3142,6 +3164,21 @@ class ur_enqueue_dditable_t(Structure):
31423164
("pfnWriteHostPipe", c_void_p) ## _urEnqueueWriteHostPipe_t
31433165
]
31443166

3167+
###############################################################################
3168+
## @brief Function-pointer for urEnqueueCooperativeKernelLaunchExp
3169+
if __use_win_types:
3170+
_urEnqueueCooperativeKernelLaunchExp_t = WINFUNCTYPE( ur_result_t, ur_queue_handle_t, ur_kernel_handle_t, c_ulong, POINTER(c_size_t), POINTER(c_size_t), POINTER(c_size_t), c_ulong, POINTER(ur_event_handle_t), POINTER(ur_event_handle_t) )
3171+
else:
3172+
_urEnqueueCooperativeKernelLaunchExp_t = CFUNCTYPE( ur_result_t, ur_queue_handle_t, ur_kernel_handle_t, c_ulong, POINTER(c_size_t), POINTER(c_size_t), POINTER(c_size_t), c_ulong, POINTER(ur_event_handle_t), POINTER(ur_event_handle_t) )
3173+
3174+
3175+
###############################################################################
3176+
## @brief Table of EnqueueExp function pointers
3177+
class ur_enqueue_exp_dditable_t(Structure):
3178+
_fields_ = [
3179+
("pfnCooperativeKernelLaunchExp", c_void_p) ## _urEnqueueCooperativeKernelLaunchExp_t
3180+
]
3181+
31453182
###############################################################################
31463183
## @brief Function-pointer for urQueueGetInfo
31473184
if __use_win_types:
@@ -3774,11 +3811,13 @@ class ur_dditable_t(Structure):
37743811
("Event", ur_event_dditable_t),
37753812
("Program", ur_program_dditable_t),
37763813
("Kernel", ur_kernel_dditable_t),
3814+
("KernelExp", ur_kernel_exp_dditable_t),
37773815
("Sampler", ur_sampler_dditable_t),
37783816
("Mem", ur_mem_dditable_t),
37793817
("PhysicalMem", ur_physical_mem_dditable_t),
37803818
("Global", ur_global_dditable_t),
37813819
("Enqueue", ur_enqueue_dditable_t),
3820+
("EnqueueExp", ur_enqueue_exp_dditable_t),
37823821
("Queue", ur_queue_dditable_t),
37833822
("BindlessImagesExp", ur_bindless_images_exp_dditable_t),
37843823
("USM", ur_usm_dditable_t),
@@ -3899,6 +3938,16 @@ def __init__(self, version : ur_api_version_t):
38993938
self.urKernelSetArgMemObj = _urKernelSetArgMemObj_t(self.__dditable.Kernel.pfnSetArgMemObj)
39003939
self.urKernelSetSpecializationConstants = _urKernelSetSpecializationConstants_t(self.__dditable.Kernel.pfnSetSpecializationConstants)
39013940

3941+
# call driver to get function pointers
3942+
KernelExp = ur_kernel_exp_dditable_t()
3943+
r = ur_result_v(self.__dll.urGetKernelExpProcAddrTable(version, byref(KernelExp)))
3944+
if r != ur_result_v.SUCCESS:
3945+
raise Exception(r)
3946+
self.__dditable.KernelExp = KernelExp
3947+
3948+
# attach function interface to function address
3949+
self.urKernelSuggestMaxCooperativeGroupCountExp = _urKernelSuggestMaxCooperativeGroupCountExp_t(self.__dditable.KernelExp.pfnSuggestMaxCooperativeGroupCountExp)
3950+
39023951
# call driver to get function pointers
39033952
Sampler = ur_sampler_dditable_t()
39043953
r = ur_result_v(self.__dll.urGetSamplerProcAddrTable(version, byref(Sampler)))
@@ -3993,6 +4042,16 @@ def __init__(self, version : ur_api_version_t):
39934042
self.urEnqueueReadHostPipe = _urEnqueueReadHostPipe_t(self.__dditable.Enqueue.pfnReadHostPipe)
39944043
self.urEnqueueWriteHostPipe = _urEnqueueWriteHostPipe_t(self.__dditable.Enqueue.pfnWriteHostPipe)
39954044

4045+
# call driver to get function pointers
4046+
EnqueueExp = ur_enqueue_exp_dditable_t()
4047+
r = ur_result_v(self.__dll.urGetEnqueueExpProcAddrTable(version, byref(EnqueueExp)))
4048+
if r != ur_result_v.SUCCESS:
4049+
raise Exception(r)
4050+
self.__dditable.EnqueueExp = EnqueueExp
4051+
4052+
# attach function interface to function address
4053+
self.urEnqueueCooperativeKernelLaunchExp = _urEnqueueCooperativeKernelLaunchExp_t(self.__dditable.EnqueueExp.pfnCooperativeKernelLaunchExp)
4054+
39964055
# call driver to get function pointers
39974056
Queue = ur_queue_dditable_t()
39984057
r = ur_result_v(self.__dll.urGetQueueProcAddrTable(version, byref(Queue)))

include/ur_api.h

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,8 @@ typedef enum ur_function_t {
207207
UR_FUNCTION_COMMAND_BUFFER_APPEND_MEM_BUFFER_WRITE_RECT_EXP = 190, ///< Enumerator for ::urCommandBufferAppendMemBufferWriteRectExp
208208
UR_FUNCTION_COMMAND_BUFFER_APPEND_MEM_BUFFER_READ_RECT_EXP = 191, ///< Enumerator for ::urCommandBufferAppendMemBufferReadRectExp
209209
UR_FUNCTION_COMMAND_BUFFER_APPEND_MEM_BUFFER_FILL_EXP = 192, ///< Enumerator for ::urCommandBufferAppendMemBufferFillExp
210+
UR_FUNCTION_ENQUEUE_COOPERATIVE_KERNEL_LAUNCH_EXP = 193, ///< Enumerator for ::urEnqueueCooperativeKernelLaunchExp
211+
UR_FUNCTION_KERNEL_SUGGEST_MAX_COOPERATIVE_GROUP_COUNT_EXP = 194, ///< Enumerator for ::urKernelSuggestMaxCooperativeGroupCountExp
210212
/// @cond
211213
UR_FUNCTION_FORCE_UINT32 = 0x7fffffff
212214
/// @endcond
@@ -8171,6 +8173,90 @@ urCommandBufferEnqueueExp(
81718173
///< command-buffer execution instance.
81728174
);
81738175

8176+
#if !defined(__GNUC__)
8177+
#pragma endregion
8178+
#endif
8179+
// Intel 'oneAPI' Unified Runtime Experimental APIs for Cooperative Kernels
8180+
#if !defined(__GNUC__)
8181+
#pragma region cooperative kernels(experimental)
8182+
#endif
8183+
///////////////////////////////////////////////////////////////////////////////
8184+
#ifndef UR_COOPERATIVE_KERNELS_EXTENSION_STRING_EXP
8185+
/// @brief The extension string that identifies support for cooperative kernels
8186+
/// and is returned when querying device extensions.
8187+
#define UR_COOPERATIVE_KERNELS_EXTENSION_STRING_EXP "ur_exp_cooperative_kernels"
8188+
#endif // UR_COOPERATIVE_KERNELS_EXTENSION_STRING_EXP
8189+
8190+
///////////////////////////////////////////////////////////////////////////////
8191+
/// @brief Enqueue a command to execute a cooperative kernel
8192+
///
8193+
/// @returns
8194+
/// - ::UR_RESULT_SUCCESS
8195+
/// - ::UR_RESULT_ERROR_UNINITIALIZED
8196+
/// - ::UR_RESULT_ERROR_DEVICE_LOST
8197+
/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC
8198+
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
8199+
/// + `NULL == hQueue`
8200+
/// + `NULL == hKernel`
8201+
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
8202+
/// + `NULL == pGlobalWorkOffset`
8203+
/// + `NULL == pGlobalWorkSize`
8204+
/// - ::UR_RESULT_ERROR_INVALID_QUEUE
8205+
/// - ::UR_RESULT_ERROR_INVALID_KERNEL
8206+
/// - ::UR_RESULT_ERROR_INVALID_EVENT
8207+
/// - ::UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST
8208+
/// + `phEventWaitList == NULL && numEventsInWaitList > 0`
8209+
/// + `phEventWaitList != NULL && numEventsInWaitList == 0`
8210+
/// + If event objects in phEventWaitList are not valid events.
8211+
/// - ::UR_RESULT_ERROR_INVALID_WORK_DIMENSION
8212+
/// - ::UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE
8213+
/// - ::UR_RESULT_ERROR_INVALID_VALUE
8214+
/// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY
8215+
/// - ::UR_RESULT_ERROR_OUT_OF_RESOURCES
8216+
UR_APIEXPORT ur_result_t UR_APICALL
8217+
urEnqueueCooperativeKernelLaunchExp(
8218+
ur_queue_handle_t hQueue, ///< [in] handle of the queue object
8219+
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
8220+
uint32_t workDim, ///< [in] number of dimensions, from 1 to 3, to specify the global and
8221+
///< work-group work-items
8222+
const size_t *pGlobalWorkOffset, ///< [in] pointer to an array of workDim unsigned values that specify the
8223+
///< offset used to calculate the global ID of a work-item
8224+
const size_t *pGlobalWorkSize, ///< [in] pointer to an array of workDim unsigned values that specify the
8225+
///< number of global work-items in workDim that will execute the kernel
8226+
///< function
8227+
const size_t *pLocalWorkSize, ///< [in][optional] pointer to an array of workDim unsigned values that
8228+
///< specify the number of local work-items forming a work-group that will
8229+
///< execute the kernel function.
8230+
///< If nullptr, the runtime implementation will choose the work-group
8231+
///< size.
8232+
uint32_t numEventsInWaitList, ///< [in] size of the event wait list
8233+
const ur_event_handle_t *phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of
8234+
///< events that must be complete before the kernel execution.
8235+
///< If nullptr, numEventsInWaitList must be 0, indicating that there are no
8236+
///< events to wait on.
8237+
ur_event_handle_t *phEvent ///< [out][optional] return an event object that identifies this particular
8238+
///< kernel execution instance.
8239+
);
8240+
8241+
///////////////////////////////////////////////////////////////////////////////
8242+
/// @brief Query the maximum number of work groups for a cooperative kernel
8243+
///
8244+
/// @returns
8245+
/// - ::UR_RESULT_SUCCESS
8246+
/// - ::UR_RESULT_ERROR_UNINITIALIZED
8247+
/// - ::UR_RESULT_ERROR_DEVICE_LOST
8248+
/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC
8249+
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
8250+
/// + `NULL == hKernel`
8251+
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
8252+
/// + `NULL == pGroupCountRet`
8253+
/// - ::UR_RESULT_ERROR_INVALID_KERNEL
8254+
UR_APIEXPORT ur_result_t UR_APICALL
8255+
urKernelSuggestMaxCooperativeGroupCountExp(
8256+
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
8257+
uint32_t *pGroupCountRet ///< [out] pointer to maximum number of groups
8258+
);
8259+
81748260
#if !defined(__GNUC__)
81758261
#pragma endregion
81768262
#endif
@@ -8939,6 +9025,15 @@ typedef struct ur_kernel_set_specialization_constants_params_t {
89399025
const ur_specialization_constant_info_t **ppSpecConstants;
89409026
} ur_kernel_set_specialization_constants_params_t;
89419027

9028+
///////////////////////////////////////////////////////////////////////////////
9029+
/// @brief Function parameters for urKernelSuggestMaxCooperativeGroupCountExp
9030+
/// @details Each entry is a pointer to the parameter passed to the function;
9031+
/// allowing the callback the ability to modify the parameter's value
9032+
typedef struct ur_kernel_suggest_max_cooperative_group_count_exp_params_t {
9033+
ur_kernel_handle_t *phKernel;
9034+
uint32_t **ppGroupCountRet;
9035+
} ur_kernel_suggest_max_cooperative_group_count_exp_params_t;
9036+
89429037
///////////////////////////////////////////////////////////////////////////////
89439038
/// @brief Function parameters for urSamplerCreate
89449039
/// @details Each entry is a pointer to the parameter passed to the function;
@@ -9586,6 +9681,22 @@ typedef struct ur_enqueue_write_host_pipe_params_t {
95869681
ur_event_handle_t **pphEvent;
95879682
} ur_enqueue_write_host_pipe_params_t;
95889683

9684+
///////////////////////////////////////////////////////////////////////////////
9685+
/// @brief Function parameters for urEnqueueCooperativeKernelLaunchExp
9686+
/// @details Each entry is a pointer to the parameter passed to the function;
9687+
/// allowing the callback the ability to modify the parameter's value
9688+
typedef struct ur_enqueue_cooperative_kernel_launch_exp_params_t {
9689+
ur_queue_handle_t *phQueue;
9690+
ur_kernel_handle_t *phKernel;
9691+
uint32_t *pworkDim;
9692+
const size_t **ppGlobalWorkOffset;
9693+
const size_t **ppGlobalWorkSize;
9694+
const size_t **ppLocalWorkSize;
9695+
uint32_t *pnumEventsInWaitList;
9696+
const ur_event_handle_t **pphEventWaitList;
9697+
ur_event_handle_t **pphEvent;
9698+
} ur_enqueue_cooperative_kernel_launch_exp_params_t;
9699+
95899700
///////////////////////////////////////////////////////////////////////////////
95909701
/// @brief Function parameters for urQueueGetInfo
95919702
/// @details Each entry is a pointer to the parameter passed to the function;

include/ur_ddi.h

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,39 @@ typedef ur_result_t(UR_APICALL *ur_pfnGetKernelProcAddrTable_t)(
567567
ur_api_version_t,
568568
ur_kernel_dditable_t *);
569569

570+
///////////////////////////////////////////////////////////////////////////////
571+
/// @brief Function-pointer for urKernelSuggestMaxCooperativeGroupCountExp
572+
typedef ur_result_t(UR_APICALL *ur_pfnKernelSuggestMaxCooperativeGroupCountExp_t)(
573+
ur_kernel_handle_t,
574+
uint32_t *);
575+
576+
///////////////////////////////////////////////////////////////////////////////
577+
/// @brief Table of KernelExp function pointers
578+
typedef struct ur_kernel_exp_dditable_t {
579+
ur_pfnKernelSuggestMaxCooperativeGroupCountExp_t pfnSuggestMaxCooperativeGroupCountExp;
580+
} ur_kernel_exp_dditable_t;
581+
582+
///////////////////////////////////////////////////////////////////////////////
583+
/// @brief Exported function for filling application's KernelExp table
584+
/// with current process' addresses
585+
///
586+
/// @returns
587+
/// - ::UR_RESULT_SUCCESS
588+
/// - ::UR_RESULT_ERROR_UNINITIALIZED
589+
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
590+
/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION
591+
UR_DLLEXPORT ur_result_t UR_APICALL
592+
urGetKernelExpProcAddrTable(
593+
ur_api_version_t version, ///< [in] API version requested
594+
ur_kernel_exp_dditable_t *pDdiTable ///< [in,out] pointer to table of DDI function pointers
595+
);
596+
597+
///////////////////////////////////////////////////////////////////////////////
598+
/// @brief Function-pointer for urGetKernelExpProcAddrTable
599+
typedef ur_result_t(UR_APICALL *ur_pfnGetKernelExpProcAddrTable_t)(
600+
ur_api_version_t,
601+
ur_kernel_exp_dditable_t *);
602+
570603
///////////////////////////////////////////////////////////////////////////////
571604
/// @brief Function-pointer for urSamplerCreate
572605
typedef ur_result_t(UR_APICALL *ur_pfnSamplerCreate_t)(
@@ -1246,6 +1279,46 @@ typedef ur_result_t(UR_APICALL *ur_pfnGetEnqueueProcAddrTable_t)(
12461279
ur_api_version_t,
12471280
ur_enqueue_dditable_t *);
12481281

1282+
///////////////////////////////////////////////////////////////////////////////
1283+
/// @brief Function-pointer for urEnqueueCooperativeKernelLaunchExp
1284+
typedef ur_result_t(UR_APICALL *ur_pfnEnqueueCooperativeKernelLaunchExp_t)(
1285+
ur_queue_handle_t,
1286+
ur_kernel_handle_t,
1287+
uint32_t,
1288+
const size_t *,
1289+
const size_t *,
1290+
const size_t *,
1291+
uint32_t,
1292+
const ur_event_handle_t *,
1293+
ur_event_handle_t *);
1294+
1295+
///////////////////////////////////////////////////////////////////////////////
1296+
/// @brief Table of EnqueueExp function pointers
1297+
typedef struct ur_enqueue_exp_dditable_t {
1298+
ur_pfnEnqueueCooperativeKernelLaunchExp_t pfnCooperativeKernelLaunchExp;
1299+
} ur_enqueue_exp_dditable_t;
1300+
1301+
///////////////////////////////////////////////////////////////////////////////
1302+
/// @brief Exported function for filling application's EnqueueExp table
1303+
/// with current process' addresses
1304+
///
1305+
/// @returns
1306+
/// - ::UR_RESULT_SUCCESS
1307+
/// - ::UR_RESULT_ERROR_UNINITIALIZED
1308+
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
1309+
/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION
1310+
UR_DLLEXPORT ur_result_t UR_APICALL
1311+
urGetEnqueueExpProcAddrTable(
1312+
ur_api_version_t version, ///< [in] API version requested
1313+
ur_enqueue_exp_dditable_t *pDdiTable ///< [in,out] pointer to table of DDI function pointers
1314+
);
1315+
1316+
///////////////////////////////////////////////////////////////////////////////
1317+
/// @brief Function-pointer for urGetEnqueueExpProcAddrTable
1318+
typedef ur_result_t(UR_APICALL *ur_pfnGetEnqueueExpProcAddrTable_t)(
1319+
ur_api_version_t,
1320+
ur_enqueue_exp_dditable_t *);
1321+
12491322
///////////////////////////////////////////////////////////////////////////////
12501323
/// @brief Function-pointer for urQueueGetInfo
12511324
typedef ur_result_t(UR_APICALL *ur_pfnQueueGetInfo_t)(
@@ -2154,11 +2227,13 @@ typedef struct ur_dditable_t {
21542227
ur_event_dditable_t Event;
21552228
ur_program_dditable_t Program;
21562229
ur_kernel_dditable_t Kernel;
2230+
ur_kernel_exp_dditable_t KernelExp;
21572231
ur_sampler_dditable_t Sampler;
21582232
ur_mem_dditable_t Mem;
21592233
ur_physical_mem_dditable_t PhysicalMem;
21602234
ur_global_dditable_t Global;
21612235
ur_enqueue_dditable_t Enqueue;
2236+
ur_enqueue_exp_dditable_t EnqueueExp;
21622237
ur_queue_dditable_t Queue;
21632238
ur_bindless_images_exp_dditable_t BindlessImagesExp;
21642239
ur_usm_dditable_t USM;

0 commit comments

Comments
 (0)