Skip to content

Commit d3d3f6e

Browse files
committed
Implement L0 cooperative kernel functions
Defines `urKernelSuggestMaxCooperativeGroupCountExp` and `urEnqueueCooperativeKernelLaunchExp` to enable cooperative kernels with more than one work group. Signed-off-by: Michael Aziz <michael.aziz@intel.com>
1 parent fcd3693 commit d3d3f6e

File tree

1 file changed

+263
-9
lines changed

1 file changed

+263
-9
lines changed

source/adapters/level_zero/kernel.cpp

Lines changed: 263 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -271,13 +271,264 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
271271
}
272272

273273
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
    ur_queue_handle_t Queue,   ///< [in] handle of the queue object
    ur_kernel_handle_t Kernel, ///< [in] handle of the kernel object
    uint32_t WorkDim, ///< [in] number of dimensions, from 1 to 3, to specify
                      ///< the global and work-group work-items
    const size_t
        *GlobalWorkOffset, ///< [in] pointer to an array of workDim unsigned
                           ///< values that specify the offset used to
                           ///< calculate the global ID of a work-item
    const size_t *GlobalWorkSize, ///< [in] pointer to an array of workDim
                                  ///< unsigned values that specify the number
                                  ///< of global work-items in workDim that
                                  ///< will execute the kernel function
    const size_t
        *LocalWorkSize, ///< [in][optional] pointer to an array of workDim
                        ///< unsigned values that specify the number of local
                        ///< work-items forming a work-group that will execute
                        ///< the kernel function. If nullptr, the runtime
                        ///< implementation will choose the work-group size.
    uint32_t NumEventsInWaitList, ///< [in] size of the event wait list
    const ur_event_handle_t
        *EventWaitList, ///< [in][optional][range(0, numEventsInWaitList)]
                        ///< pointer to a list of events that must be complete
                        ///< before the kernel execution. If nullptr, then
                        ///< numEventsInWaitList must be 0, indicating that no
                        ///< wait events are required.
    ur_event_handle_t
        *OutEvent ///< [in,out][optional] return an event object that identifies
                  ///< this particular kernel execution instance.
) {
  // Validate WorkDim before it is used as a bound for the caller-provided
  // arrays below; the late switch-default check came after the indexing.
  if (WorkDim < 1 || WorkDim > 3) {
    logger::error("urEnqueueCooperativeKernelLaunchExp: unsupported work_dim");
    return UR_RESULT_ERROR_INVALID_VALUE;
  }

  // Lock automatically releases when this goes out of scope.
  // Taken before reading any queue/kernel state (including ZeKernelMap),
  // which these mutexes guard.
  std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Lock(
      Queue->Mutex, Kernel->Mutex, Kernel->Program->Mutex);

  auto ZeDevice = Queue->Device->ZeDevice;

  // Resolve the device-specific L0 kernel handle for the queue's device.
  ze_kernel_handle_t ZeKernel{};
  if (Kernel->ZeKernelMap.empty()) {
    ZeKernel = Kernel->ZeKernel;
  } else {
    auto It = Kernel->ZeKernelMap.find(ZeDevice);
    if (It == Kernel->ZeKernelMap.end()) {
      /* kernel and queue don't match */
      return UR_RESULT_ERROR_INVALID_QUEUE;
    }
    ZeKernel = It->second;
  }

  if (GlobalWorkOffset != NULL) {
    if (!Queue->Device->Platform->ZeDriverGlobalOffsetExtensionFound) {
      logger::error("No global offset extension found on this driver");
      return UR_RESULT_ERROR_INVALID_VALUE;
    }

    // The caller's array only has WorkDim entries; pad the missing
    // dimensions with 0 instead of reading out of bounds.
    size_t GlobalWorkOffset3D[3]{0, 0, 0};
    std::copy(GlobalWorkOffset, GlobalWorkOffset + WorkDim,
              GlobalWorkOffset3D);
    ZE2UR_CALL(zeKernelSetGlobalOffsetExp,
               (ZeKernel, GlobalWorkOffset3D[0], GlobalWorkOffset3D[1],
                GlobalWorkOffset3D[2]));
  }

  // If there are any pending arguments set them now.
  for (auto &Arg : Kernel->PendingArguments) {
    // The ArgValue may be a NULL pointer in which case a NULL value is used for
    // the kernel argument declared as a pointer to global or constant memory.
    char **ZeHandlePtr = nullptr;
    if (Arg.Value) {
      UR_CALL(Arg.Value->getZeHandlePtr(ZeHandlePtr, Arg.AccessMode,
                                        Queue->Device));
    }
    ZE2UR_CALL(zeKernelSetArgumentValue,
               (ZeKernel, Arg.Index, Arg.Size, ZeHandlePtr));
  }
  Kernel->PendingArguments.clear();

  ze_group_count_t ZeThreadGroupDimensions{1, 1, 1};
  uint32_t WG[3]{};

  // New variable needed because GlobalWorkSize parameter might not be of size 3
  size_t GlobalWorkSize3D[3]{1, 1, 1};
  std::copy(GlobalWorkSize, GlobalWorkSize + WorkDim, GlobalWorkSize3D);

  if (LocalWorkSize) {
    // The caller's array only has WorkDim entries; pad the missing
    // dimensions with 1 instead of reading out of bounds.
    size_t LocalWorkSize3D[3]{1, 1, 1};
    std::copy(LocalWorkSize, LocalWorkSize + WorkDim, LocalWorkSize3D);
    for (int I : {0, 1, 2}) {
      // L0 only accepts 32-bit group sizes.
      UR_ASSERT(LocalWorkSize3D[I] < (std::numeric_limits<uint32_t>::max)(),
                UR_RESULT_ERROR_INVALID_VALUE);
      WG[I] = static_cast<uint32_t>(LocalWorkSize3D[I]);
    }
  } else {
    // We can't call to zeKernelSuggestGroupSize if 64-bit GlobalWorkSize
    // values do not fit to 32-bit that the API only supports currently.
    bool SuggestGroupSize = true;
    for (int I : {0, 1, 2}) {
      if (GlobalWorkSize3D[I] > UINT32_MAX) {
        SuggestGroupSize = false;
      }
    }
    if (SuggestGroupSize) {
      ZE2UR_CALL(zeKernelSuggestGroupSize,
                 (ZeKernel, GlobalWorkSize3D[0], GlobalWorkSize3D[1],
                  GlobalWorkSize3D[2], &WG[0], &WG[1], &WG[2]));
    } else {
      for (int I : {0, 1, 2}) {
        // Try to find a I-dimension WG size that the GlobalWorkSize[I] is
        // fully divisable with. Start with the max possible size in
        // each dimension.
        uint32_t GroupSize[] = {
            Queue->Device->ZeDeviceComputeProperties->maxGroupSizeX,
            Queue->Device->ZeDeviceComputeProperties->maxGroupSizeY,
            Queue->Device->ZeDeviceComputeProperties->maxGroupSizeZ};
        GroupSize[I] = (std::min)(size_t(GroupSize[I]), GlobalWorkSize3D[I]);
        while (GlobalWorkSize3D[I] % GroupSize[I]) {
          --GroupSize[I];
        }

        if (GlobalWorkSize3D[I] / GroupSize[I] > UINT32_MAX) {
          logger::error(
              "urEnqueueCooperativeKernelLaunchExp: can't find a WG size "
              "suitable for global work size > UINT32_MAX");
          return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
        }
        WG[I] = GroupSize[I];
      }
      logger::debug("urEnqueueCooperativeKernelLaunchExp: using computed WG "
                    "size = {{{}, {}, {}}}",
                    WG[0], WG[1], WG[2]);
    }
  }

  // TODO: assert if sizes do not fit into 32-bit?

  switch (WorkDim) {
  case 3:
    ZeThreadGroupDimensions.groupCountX =
        static_cast<uint32_t>(GlobalWorkSize3D[0] / WG[0]);
    ZeThreadGroupDimensions.groupCountY =
        static_cast<uint32_t>(GlobalWorkSize3D[1] / WG[1]);
    ZeThreadGroupDimensions.groupCountZ =
        static_cast<uint32_t>(GlobalWorkSize3D[2] / WG[2]);
    break;
  case 2:
    ZeThreadGroupDimensions.groupCountX =
        static_cast<uint32_t>(GlobalWorkSize3D[0] / WG[0]);
    ZeThreadGroupDimensions.groupCountY =
        static_cast<uint32_t>(GlobalWorkSize3D[1] / WG[1]);
    WG[2] = 1;
    break;
  case 1:
    ZeThreadGroupDimensions.groupCountX =
        static_cast<uint32_t>(GlobalWorkSize3D[0] / WG[0]);
    WG[1] = WG[2] = 1;
    break;

  default:
    // Unreachable: WorkDim was validated on entry. Kept for defense in depth.
    logger::error("urEnqueueCooperativeKernelLaunchExp: unsupported work_dim");
    return UR_RESULT_ERROR_INVALID_VALUE;
  }

  // Error handling for non-uniform group size case
  if (GlobalWorkSize3D[0] !=
      size_t(ZeThreadGroupDimensions.groupCountX) * WG[0]) {
    logger::error("urEnqueueCooperativeKernelLaunchExp: invalid work_dim. The "
                  "range is not a "
                  "multiple of the group size in the 1st dimension");
    return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
  }
  if (GlobalWorkSize3D[1] !=
      size_t(ZeThreadGroupDimensions.groupCountY) * WG[1]) {
    logger::error("urEnqueueCooperativeKernelLaunchExp: invalid work_dim. The "
                  "range is not a "
                  "multiple of the group size in the 2nd dimension");
    return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
  }
  if (GlobalWorkSize3D[2] !=
      size_t(ZeThreadGroupDimensions.groupCountZ) * WG[2]) {
    // Was logger::debug — an error path must log at error severity like the
    // 1st/2nd dimension checks above.
    logger::error("urEnqueueCooperativeKernelLaunchExp: invalid work_dim. The "
                  "range is not a "
                  "multiple of the group size in the 3rd dimension");
    return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
  }

  ZE2UR_CALL(zeKernelSetGroupSize, (ZeKernel, WG[0], WG[1], WG[2]));

  bool UseCopyEngine = false;
  _ur_ze_event_list_t TmpWaitList;
  UR_CALL(TmpWaitList.createAndRetainUrZeEventList(
      NumEventsInWaitList, EventWaitList, Queue, UseCopyEngine));

  // Get a new command list to be used on this call
  ur_command_list_ptr_t CommandList{};
  UR_CALL(Queue->Context->getAvailableCommandList(
      Queue, CommandList, UseCopyEngine, NumEventsInWaitList, EventWaitList,
      true /* AllowBatching */));

  ze_event_handle_t ZeEvent = nullptr;
  ur_event_handle_t InternalEvent{};
  bool IsInternal = OutEvent == nullptr;
  ur_event_handle_t *Event = OutEvent ? OutEvent : &InternalEvent;

  UR_CALL(createEventAndAssociateQueue(Queue, Event, UR_COMMAND_KERNEL_LAUNCH,
                                       CommandList, IsInternal, false));
  UR_CALL(setSignalEvent(Queue, UseCopyEngine, &ZeEvent, Event,
                         NumEventsInWaitList, EventWaitList,
                         CommandList->second.ZeQueue));
  (*Event)->WaitList = TmpWaitList;

  // Save the kernel in the event, so that when the event is signalled
  // the code can do a urKernelRelease on this kernel.
  (*Event)->CommandData = (void *)Kernel;

  // Increment the reference count of the Kernel and indicate that the Kernel
  // is in use. Once the event has been signalled, the code in
  // CleanupCompletedEvent(Event) will do a urKernelRelease to update the
  // reference count on the kernel, using the kernel saved in CommandData.
  UR_CALL(urKernelRetain(Kernel));

  // Add to list of kernels to be submitted
  if (IndirectAccessTrackingEnabled)
    Queue->KernelsToBeSubmitted.push_back(Kernel);

  if (Queue->UsingImmCmdLists && IndirectAccessTrackingEnabled) {
    // If using immediate commandlists then gathering of indirect
    // references and appending to the queue (which means submission)
    // must be done together.
    std::unique_lock<ur_shared_mutex> ContextsLock(
        Queue->Device->Platform->ContextsMutex, std::defer_lock);
    // We are going to submit kernels for execution. If indirect access flag is
    // set for a kernel then we need to make a snapshot of existing memory
    // allocations in all contexts in the platform. We need to lock the mutex
    // guarding the list of contexts in the platform to prevent creation of new
    // memory allocations in any context before we submit the kernel for
    // execution.
    ContextsLock.lock();
    Queue->CaptureIndirectAccesses();
    // Add the command to the command list, which implies submission.
    ZE2UR_CALL(zeCommandListAppendLaunchCooperativeKernel,
               (CommandList->first, ZeKernel, &ZeThreadGroupDimensions, ZeEvent,
                (*Event)->WaitList.Length, (*Event)->WaitList.ZeEventList));
  } else {
    // Add the command to the command list for later submission.
    // No lock is needed here, unlike the immediate commandlist case above,
    // because the kernels are not actually submitted yet. Kernels will be
    // submitted only when the commandlist is closed. Then, a lock is held.
    ZE2UR_CALL(zeCommandListAppendLaunchCooperativeKernel,
               (CommandList->first, ZeKernel, &ZeThreadGroupDimensions, ZeEvent,
                (*Event)->WaitList.Length, (*Event)->WaitList.ZeEventList));
  }

  logger::debug("calling zeCommandListAppendLaunchCooperativeKernel() with"
                " ZeEvent {}",
                ur_cast<std::uintptr_t>(ZeEvent));
  printZeEventList((*Event)->WaitList);

  // Execute command list asynchronously, as the event will be used
  // to track down its completion.
  UR_CALL(Queue->executeCommandList(CommandList, false, true));

  return UR_RESULT_SUCCESS;
}
282533

283534
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableWrite(
@@ -818,10 +1069,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetNativeHandle(
8181069
UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
    ur_kernel_handle_t hKernel, size_t localWorkSize,
    size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) {
  // These inputs are not forwarded to Level Zero: the driver query below
  // operates on the kernel's currently configured group size.
  // NOTE(review): presumably the group size matching localWorkSize has
  // already been set on the kernel by the caller -- confirm against the
  // UR cooperative-kernels extension spec.
  (void)localWorkSize;
  (void)dynamicSharedMemorySize;

  // Hold the kernel mutex (shared) while touching hKernel->ZeKernel.
  std::shared_lock<ur_shared_mutex> Guard(hKernel->Mutex);

  // Ask the L0 driver for the maximum number of work groups that may
  // cooperate in a single launch of this kernel.
  uint32_t GroupCount = 0;
  ZE2UR_CALL(zeKernelSuggestMaxCooperativeGroupCount,
             (hKernel->ZeKernel, &GroupCount));

  *pGroupCountRet = GroupCount;
  return UR_RESULT_SUCCESS;
}
8271081

0 commit comments

Comments
 (0)