@@ -420,8 +420,6 @@ ur_result_t urEnqueueKernelLaunch(
420
420
// / [out][optional] return an event object that identifies this
421
421
// / particular kernel execution instance.
422
422
ur_event_handle_t *phEvent) {
423
- auto pfnKernelLaunch = getContext ()->urDdiTable .Enqueue .pfnKernelLaunch ;
424
-
425
423
getContext ()->logger .debug (" ==== urEnqueueKernelLaunch" );
426
424
427
425
USMLaunchInfo LaunchInfo (GetContext (hQueue), GetDevice (hQueue),
@@ -431,22 +429,14 @@ ur_result_t urEnqueueKernelLaunch(
431
429
432
430
UR_CALL (getMsanInterceptor ()->preLaunchKernel (hKernel, hQueue, LaunchInfo));
433
431
434
- ur_event_handle_t hEvent{};
435
- ur_result_t result =
436
- pfnKernelLaunch (hQueue, hKernel, workDim, pGlobalWorkOffset,
437
- pGlobalWorkSize, LaunchInfo.LocalWorkSize .data (),
438
- numEventsInWaitList, phEventWaitList, &hEvent);
439
-
440
- if (result == UR_RESULT_SUCCESS) {
441
- UR_CALL (
442
- getMsanInterceptor ()->postLaunchKernel (hKernel, hQueue, LaunchInfo));
443
- }
432
+ UR_CALL (getContext ()->urDdiTable .Enqueue .pfnKernelLaunch (
433
+ hQueue, hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize,
434
+ LaunchInfo.LocalWorkSize .data (), numEventsInWaitList, phEventWaitList,
435
+ phEvent));
444
436
445
- if (phEvent) {
446
- *phEvent = hEvent;
447
- }
437
+ UR_CALL (getMsanInterceptor ()->postLaunchKernel (hKernel, hQueue, LaunchInfo));
448
438
449
- return result ;
439
+ return UR_RESULT_SUCCESS ;
450
440
}
451
441
452
442
// /////////////////////////////////////////////////////////////////////////////
@@ -1323,6 +1313,58 @@ ur_result_t urEnqueueMemUnmap(
1323
1313
return UR_RESULT_SUCCESS;
1324
1314
}
1325
1315
1316
+ ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp (
1317
+ // / [in] handle of the queue object
1318
+ ur_queue_handle_t hQueue,
1319
+ // / [in] handle of the kernel object
1320
+ ur_kernel_handle_t hKernel,
1321
+ // / [in] number of dimensions, from 1 to 3, to specify the global and
1322
+ // / work-group work-items
1323
+ uint32_t workDim,
1324
+ // / [in] pointer to an array of workDim unsigned values that specify the
1325
+ // / offset used to calculate the global ID of a work-item
1326
+ const size_t *pGlobalWorkOffset,
1327
+ // / [in] pointer to an array of workDim unsigned values that specify the
1328
+ // / number of global work-items in workDim that will execute the kernel
1329
+ // / function
1330
+ const size_t *pGlobalWorkSize,
1331
+ // / [in][optional] pointer to an array of workDim unsigned values that
1332
+ // / specify the number of local work-items forming a work-group that will
1333
+ // / execute the kernel function.
1334
+ // / If nullptr, the runtime implementation will choose the work-group size.
1335
+ const size_t *pLocalWorkSize,
1336
+ // / [in] size of the event wait list
1337
+ uint32_t numEventsInWaitList,
1338
+ // / [in][optional][range(0, numEventsInWaitList)] pointer to a list of
1339
+ // / events that must be complete before the kernel execution.
1340
+ // / If nullptr, the numEventsInWaitList must be 0, indicating that no wait
1341
+ // / event.
1342
+ const ur_event_handle_t *phEventWaitList,
1343
+ // / [out][optional][alloc] return an event object that identifies this
1344
+ // / particular kernel execution instance. If phEventWaitList and phEvent
1345
+ // / are not NULL, phEvent must not refer to an element of the
1346
+ // / phEventWaitList array.
1347
+ ur_event_handle_t *phEvent) {
1348
+
1349
+ getContext ()->logger .debug (" ==== urEnqueueCooperativeKernelLaunchExp" );
1350
+
1351
+ USMLaunchInfo LaunchInfo (GetContext (hQueue), GetDevice (hQueue),
1352
+ pGlobalWorkSize, pLocalWorkSize, pGlobalWorkOffset,
1353
+ workDim);
1354
+ UR_CALL (LaunchInfo.initialize ());
1355
+
1356
+ UR_CALL (getMsanInterceptor ()->preLaunchKernel (hKernel, hQueue, LaunchInfo));
1357
+
1358
+ UR_CALL (getContext ()->urDdiTable .EnqueueExp .pfnCooperativeKernelLaunchExp (
1359
+ hQueue, hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize,
1360
+ LaunchInfo.LocalWorkSize .data (), numEventsInWaitList, phEventWaitList,
1361
+ phEvent));
1362
+
1363
+ UR_CALL (getMsanInterceptor ()->postLaunchKernel (hKernel, hQueue, LaunchInfo));
1364
+
1365
+ return UR_RESULT_SUCCESS;
1366
+ }
1367
+
1326
1368
// /////////////////////////////////////////////////////////////////////////////
1327
1369
// / @brief Intercept function for urKernelRetain
1328
1370
ur_result_t urKernelRetain (
@@ -1912,6 +1954,25 @@ ur_result_t urCheckVersion(ur_api_version_t version) {
1912
1954
return UR_RESULT_SUCCESS;
1913
1955
}
1914
1956
1957
+ // /////////////////////////////////////////////////////////////////////////////
1958
+ // / @brief Exported function for filling application's EnqueueExp table
1959
+ // / with current process' addresses
1960
+ // /
1961
+ // / @returns
1962
+ // / - ::UR_RESULT_SUCCESS
1963
+ // / - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
1964
+ __urdlllocal ur_result_t UR_APICALL urGetEnqueueExpProcAddrTable (
1965
+ // / [in,out] pointer to table of DDI function pointers
1966
+ ur_enqueue_exp_dditable_t *pDdiTable) {
1967
+ if (nullptr == pDdiTable) {
1968
+ return UR_RESULT_ERROR_INVALID_NULL_POINTER;
1969
+ }
1970
+
1971
+ pDdiTable->pfnCooperativeKernelLaunchExp =
1972
+ ur_sanitizer_layer::msan::urEnqueueCooperativeKernelLaunchExp;
1973
+ return UR_RESULT_SUCCESS;
1974
+ }
1975
+
1915
1976
} // namespace msan
1916
1977
1917
1978
ur_result_t initMsanDDITable (ur_dditable_t *dditable) {
@@ -1966,6 +2027,11 @@ ur_result_t initMsanDDITable(ur_dditable_t *dditable) {
1966
2027
result = ur_sanitizer_layer::msan::urGetUSMProcAddrTable (&dditable->USM );
1967
2028
}
1968
2029
2030
+ if (UR_RESULT_SUCCESS == result) {
2031
+ result = ur_sanitizer_layer::msan::urGetEnqueueExpProcAddrTable (
2032
+ &dditable->EnqueueExp );
2033
+ }
2034
+
1969
2035
if (result != UR_RESULT_SUCCESS) {
1970
2036
getContext ()->logger .error (" Initialize MSAN DDI table failed: {}" , result);
1971
2037
}
0 commit comments