@@ -271,13 +271,264 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
}

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
-    ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
-    const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
-    const size_t *pLocalWorkSize, uint32_t numEventsInWaitList,
-    const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
-  return urEnqueueKernelLaunch(hQueue, hKernel, workDim, pGlobalWorkOffset,
-                               pGlobalWorkSize, pLocalWorkSize,
-                               numEventsInWaitList, phEventWaitList, phEvent);
+    ur_queue_handle_t Queue,   ///< [in] handle of the queue object
+    ur_kernel_handle_t Kernel, ///< [in] handle of the kernel object
+    uint32_t WorkDim, ///< [in] number of dimensions, from 1 to 3, to specify
+                      ///< the global and work-group work-items
+    const size_t
+        *GlobalWorkOffset, ///< [in] pointer to an array of workDim unsigned
+                           ///< values that specify the offset used to
+                           ///< calculate the global ID of a work-item
+    const size_t *GlobalWorkSize, ///< [in] pointer to an array of workDim
+                                  ///< unsigned values that specify the number
+                                  ///< of global work-items in workDim that
+                                  ///< will execute the kernel function
+    const size_t
+        *LocalWorkSize, ///< [in][optional] pointer to an array of workDim
+                        ///< unsigned values that specify the number of local
+                        ///< work-items forming a work-group that will execute
+                        ///< the kernel function. If nullptr, the runtime
+                        ///< implementation will choose the work-group size.
+    uint32_t NumEventsInWaitList, ///< [in] size of the event wait list
+    const ur_event_handle_t
+        *EventWaitList, ///< [in][optional][range(0, numEventsInWaitList)]
+                        ///< pointer to a list of events that must be complete
+                        ///< before the kernel execution. If nullptr,
+                        ///< numEventsInWaitList must be 0, indicating no
+                        ///< wait events.
+    ur_event_handle_t
+        *OutEvent ///< [in,out][optional] return an event object that
+                  ///< identifies this particular kernel execution instance.
+    ) {
+  auto ZeDevice = Queue->Device->ZeDevice;
+
+  ze_kernel_handle_t ZeKernel{};
+  if (Kernel->ZeKernelMap.empty()) {
+    ZeKernel = Kernel->ZeKernel;
+  } else {
+    auto It = Kernel->ZeKernelMap.find(ZeDevice);
+    if (It == Kernel->ZeKernelMap.end()) {
+      // Kernel and queue don't match.
+      return UR_RESULT_ERROR_INVALID_QUEUE;
+    }
+    ZeKernel = It->second;
+  }
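+  // Note: for multi-device programs, ZeKernelMap holds one Level Zero kernel
+  // handle per device; an empty map means the kernel was built for a single
+  // device, so ZeKernel can be used directly (a reading of this adapter's
+  // bookkeeping, not spec-mandated behavior).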
+  // Lock automatically releases when this goes out of scope.
+  std::scoped_lock<ur_shared_mutex, ur_shared_mutex, ur_shared_mutex> Lock(
+      Queue->Mutex, Kernel->Mutex, Kernel->Program->Mutex);
+  if (GlobalWorkOffset != NULL) {
+    if (!Queue->Device->Platform->ZeDriverGlobalOffsetExtensionFound) {
+      logger::error("No global offset extension found on this driver");
+      return UR_RESULT_ERROR_INVALID_VALUE;
+    }
+
+    ZE2UR_CALL(zeKernelSetGlobalOffsetExp,
+               (ZeKernel, GlobalWorkOffset[0], GlobalWorkOffset[1],
+                GlobalWorkOffset[2]));
+  }
+
+  // If there are any pending arguments, set them now.
+  for (auto &Arg : Kernel->PendingArguments) {
+    // The ArgValue may be a NULL pointer, in which case a NULL value is used
+    // for the kernel argument declared as a pointer to global or constant
+    // memory.
+    char **ZeHandlePtr = nullptr;
+    if (Arg.Value) {
+      UR_CALL(Arg.Value->getZeHandlePtr(ZeHandlePtr, Arg.AccessMode,
+                                        Queue->Device));
+    }
+    ZE2UR_CALL(zeKernelSetArgumentValue,
+               (ZeKernel, Arg.Index, Arg.Size, ZeHandlePtr));
+  }
+  Kernel->PendingArguments.clear();
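+  // These arguments were recorded by earlier kernel-argument calls (e.g.
+  // urKernelSetArgMemObj) and are materialized here, once the device that
+  // will run the launch is known.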
+
+  ze_group_count_t ZeThreadGroupDimensions{1, 1, 1};
+  uint32_t WG[3]{};
+
+  // A new variable is needed because the GlobalWorkSize parameter might not
+  // be of size 3.
+  size_t GlobalWorkSize3D[3]{1, 1, 1};
+  std::copy(GlobalWorkSize, GlobalWorkSize + WorkDim, GlobalWorkSize3D);
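+  // e.g. WorkDim == 2 with GlobalWorkSize == {64, 48} yields
+  // GlobalWorkSize3D == {64, 48, 1}.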
+
+  if (LocalWorkSize) {
+    // Level Zero group sizes are 32-bit values.
+    UR_ASSERT(LocalWorkSize[0] < (std::numeric_limits<uint32_t>::max)(),
+              UR_RESULT_ERROR_INVALID_VALUE);
+    UR_ASSERT(LocalWorkSize[1] < (std::numeric_limits<uint32_t>::max)(),
+              UR_RESULT_ERROR_INVALID_VALUE);
+    UR_ASSERT(LocalWorkSize[2] < (std::numeric_limits<uint32_t>::max)(),
+              UR_RESULT_ERROR_INVALID_VALUE);
+    WG[0] = static_cast<uint32_t>(LocalWorkSize[0]);
+    WG[1] = static_cast<uint32_t>(LocalWorkSize[1]);
+    WG[2] = static_cast<uint32_t>(LocalWorkSize[2]);
+  } else {
+    // We can't call zeKernelSuggestGroupSize if the 64-bit GlobalWorkSize
+    // values do not fit into the 32-bit range that the API currently
+    // supports.
+    bool SuggestGroupSize = true;
+    for (int I : {0, 1, 2}) {
+      if (GlobalWorkSize3D[I] > UINT32_MAX) {
+        SuggestGroupSize = false;
+      }
+    }
+    if (SuggestGroupSize) {
+      ZE2UR_CALL(zeKernelSuggestGroupSize,
+                 (ZeKernel, GlobalWorkSize3D[0], GlobalWorkSize3D[1],
+                  GlobalWorkSize3D[2], &WG[0], &WG[1], &WG[2]));
+    } else {
+      for (int I : {0, 1, 2}) {
+        // Try to find an I-dimension WG size that GlobalWorkSize[I] is
+        // fully divisible by. Start with the max possible size in
+        // each dimension.
+        uint32_t GroupSize[] = {
+            Queue->Device->ZeDeviceComputeProperties->maxGroupSizeX,
+            Queue->Device->ZeDeviceComputeProperties->maxGroupSizeY,
+            Queue->Device->ZeDeviceComputeProperties->maxGroupSizeZ};
+        GroupSize[I] = (std::min)(size_t(GroupSize[I]), GlobalWorkSize3D[I]);
+        while (GlobalWorkSize3D[I] % GroupSize[I]) {
+          --GroupSize[I];
+        }
+
+        if (GlobalWorkSize3D[I] / GroupSize[I] > UINT32_MAX) {
+          logger::error(
+              "urEnqueueCooperativeKernelLaunchExp: can't find a WG size "
+              "suitable for global work size > UINT32_MAX");
+          return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
+        }
+        WG[I] = GroupSize[I];
+      }
+      logger::debug("urEnqueueCooperativeKernelLaunchExp: using computed WG "
+                    "size = {{{}, {}, {}}}",
+                    WG[0], WG[1], WG[2]);
+    }
+  }
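+  // e.g. with GlobalWorkSize3D[I] == 100 and a device maximum of 64, the
+  // search above decrements from 64 until 100 % 50 == 0, so WG[I] == 50,
+  // the largest divisor of 100 that fits the device limit.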
+
+  // TODO: assert if sizes do not fit into 32-bit?
+
+  switch (WorkDim) {
+  case 3:
+    ZeThreadGroupDimensions.groupCountX =
+        static_cast<uint32_t>(GlobalWorkSize3D[0] / WG[0]);
+    ZeThreadGroupDimensions.groupCountY =
+        static_cast<uint32_t>(GlobalWorkSize3D[1] / WG[1]);
+    ZeThreadGroupDimensions.groupCountZ =
+        static_cast<uint32_t>(GlobalWorkSize3D[2] / WG[2]);
+    break;
+  case 2:
+    ZeThreadGroupDimensions.groupCountX =
+        static_cast<uint32_t>(GlobalWorkSize3D[0] / WG[0]);
+    ZeThreadGroupDimensions.groupCountY =
+        static_cast<uint32_t>(GlobalWorkSize3D[1] / WG[1]);
+    WG[2] = 1;
+    break;
+  case 1:
+    ZeThreadGroupDimensions.groupCountX =
+        static_cast<uint32_t>(GlobalWorkSize3D[0] / WG[0]);
+    WG[1] = WG[2] = 1;
+    break;
+
+  default:
+    logger::error("urEnqueueCooperativeKernelLaunchExp: unsupported work_dim");
+    return UR_RESULT_ERROR_INVALID_VALUE;
+  }
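+  // e.g. WorkDim == 1 with GlobalWorkSize3D[0] == 1024 and WG[0] == 256
+  // yields groupCountX == 4, while the Y/Z counts keep their default of 1.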
+
+  // Error handling for the non-uniform group size case.
+  if (GlobalWorkSize3D[0] !=
+      size_t(ZeThreadGroupDimensions.groupCountX) * WG[0]) {
+    logger::error("urEnqueueCooperativeKernelLaunchExp: invalid work_dim. The "
+                  "range is not a multiple of the group size in the 1st "
+                  "dimension");
+    return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
+  }
+  if (GlobalWorkSize3D[1] !=
+      size_t(ZeThreadGroupDimensions.groupCountY) * WG[1]) {
+    logger::error("urEnqueueCooperativeKernelLaunchExp: invalid work_dim. The "
+                  "range is not a multiple of the group size in the 2nd "
+                  "dimension");
+    return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
+  }
+  if (GlobalWorkSize3D[2] !=
+      size_t(ZeThreadGroupDimensions.groupCountZ) * WG[2]) {
+    logger::error("urEnqueueCooperativeKernelLaunchExp: invalid work_dim. The "
+                  "range is not a multiple of the group size in the 3rd "
+                  "dimension");
+    return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
+  }
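+  // e.g. GlobalWorkSize3D[0] == 100 with an explicit WG[0] == 32 gives
+  // groupCountX == 3, and 3 * 32 != 100, so such a launch is rejected.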
+
+  ZE2UR_CALL(zeKernelSetGroupSize, (ZeKernel, WG[0], WG[1], WG[2]));
+
+  bool UseCopyEngine = false;
+  _ur_ze_event_list_t TmpWaitList;
+  UR_CALL(TmpWaitList.createAndRetainUrZeEventList(
+      NumEventsInWaitList, EventWaitList, Queue, UseCopyEngine));
+
+  // Get a new command list to be used on this call.
+  ur_command_list_ptr_t CommandList{};
+  UR_CALL(Queue->Context->getAvailableCommandList(
+      Queue, CommandList, UseCopyEngine, NumEventsInWaitList, EventWaitList,
+      true /* AllowBatching */));
+
+  ze_event_handle_t ZeEvent = nullptr;
+  ur_event_handle_t InternalEvent{};
+  bool IsInternal = OutEvent == nullptr;
+  ur_event_handle_t *Event = OutEvent ? OutEvent : &InternalEvent;
+
+  UR_CALL(createEventAndAssociateQueue(Queue, Event, UR_COMMAND_KERNEL_LAUNCH,
+                                       CommandList, IsInternal, false));
+  UR_CALL(setSignalEvent(Queue, UseCopyEngine, &ZeEvent, Event,
+                         NumEventsInWaitList, EventWaitList,
+                         CommandList->second.ZeQueue));
+  (*Event)->WaitList = TmpWaitList;
+
+  // Save the kernel in the event, so that when the event is signalled
+  // the code can do a urKernelRelease on this kernel.
+  (*Event)->CommandData = (void *)Kernel;
+
+  // Increment the reference count of the Kernel and indicate that the Kernel
+  // is in use. Once the event has been signalled, the code in
+  // CleanupCompletedEvent(Event) will do a urKernelRelease to update the
+  // reference count on the kernel, using the kernel saved in CommandData.
+  UR_CALL(urKernelRetain(Kernel));
+
+  // Add to the list of kernels to be submitted.
+  if (IndirectAccessTrackingEnabled)
+    Queue->KernelsToBeSubmitted.push_back(Kernel);
+
+  if (Queue->UsingImmCmdLists && IndirectAccessTrackingEnabled) {
+    // If using immediate commandlists then the gathering of indirect
+    // references and appending to the queue (which means submission)
+    // must be done together.
+    std::unique_lock<ur_shared_mutex> ContextsLock(
+        Queue->Device->Platform->ContextsMutex, std::defer_lock);
+    // We are going to submit kernels for execution. If the indirect access
+    // flag is set for a kernel, then we need to make a snapshot of existing
+    // memory allocations in all contexts in the platform. We need to lock the
+    // mutex guarding the list of contexts in the platform to prevent creation
+    // of new memory allocations in any context before we submit the kernel
+    // for execution.
+    ContextsLock.lock();
+    Queue->CaptureIndirectAccesses();
+    // Add the command to the command list, which implies submission.
+    ZE2UR_CALL(zeCommandListAppendLaunchCooperativeKernel,
+               (CommandList->first, ZeKernel, &ZeThreadGroupDimensions,
+                ZeEvent, (*Event)->WaitList.Length,
+                (*Event)->WaitList.ZeEventList));
+  } else {
+    // Add the command to the command list for later submission.
+    // No lock is needed here, unlike the immediate commandlist case above,
+    // because the kernels are not actually submitted yet. Kernels will be
+    // submitted only when the commandlist is closed. Then, a lock is held.
+    ZE2UR_CALL(zeCommandListAppendLaunchCooperativeKernel,
+               (CommandList->first, ZeKernel, &ZeThreadGroupDimensions,
+                ZeEvent, (*Event)->WaitList.Length,
+                (*Event)->WaitList.ZeEventList));
+  }
+
+  logger::debug("calling zeCommandListAppendLaunchCooperativeKernel() with"
+                " ZeEvent {}",
+                ur_cast<std::uintptr_t>(ZeEvent));
+  printZeEventList((*Event)->WaitList);
+
+  // Execute the command list asynchronously, as the event will be used
+  // to track down its completion.
+  UR_CALL(Queue->executeCommandList(CommandList, false, true));
+
+  return UR_RESULT_SUCCESS;
}
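+// A minimal host-side usage sketch (illustrative only; the queue/kernel
+// setup is assumed, not part of this patch): launch a 1-D cooperative
+// kernel with a 256-wide work-group and wait for it to finish.
+//
+//   size_t Global[1] = {1024};
+//   size_t Local[1] = {256};
+//   ur_event_handle_t Ev = nullptr;
+//   UR_CALL(urEnqueueCooperativeKernelLaunchExp(Queue, Kernel, 1, nullptr,
+//                                               Global, Local, 0, nullptr,
+//                                               &Ev));
+//   UR_CALL(urEventWait(1, &Ev));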
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableWrite(
@@ -829,10 +1080,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetNativeHandle(
UR_APIEXPORT ur_result_t UR_APICALL urKernelSuggestMaxCooperativeGroupCountExp(
    ur_kernel_handle_t hKernel, size_t localWorkSize,
    size_t dynamicSharedMemorySize, uint32_t *pGroupCountRet) {
-  (void)hKernel;
  (void)localWorkSize;
  (void)dynamicSharedMemorySize;
-  *pGroupCountRet = 1;
+  std::shared_lock<ur_shared_mutex> Guard(hKernel->Mutex);
+  uint32_t TotalGroupCount = 0;
+  ZE2UR_CALL(zeKernelSuggestMaxCooperativeGroupCount,
+             (hKernel->ZeKernel, &TotalGroupCount));
+  *pGroupCountRet = TotalGroupCount;
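+  // Note: zeKernelSuggestMaxCooperativeGroupCount derives its suggestion
+  // from the group size currently set on the kernel, which is why
+  // localWorkSize and dynamicSharedMemorySize are not forwarded here.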
return UR_RESULT_SUCCESS;
}