Skip to content

Commit beedc09

Browse files
authored
[DeviceSanitizer] Add support for urEnqueueCooperativeKernelLaunchExp for ASAN and MSAN (#18198)
Device ASAN and Device MSAN did not yet support this API; this patch addresses that problem. Some SYCL e2e tests failed with device sanitizers enabled because of the missing API support. --------- Signed-off-by: Wu, Yingcong <yingcong.wu@intel.com>
1 parent e365f42 commit beedc09

File tree

2 files changed

+163
-30
lines changed

2 files changed

+163
-30
lines changed

unified-runtime/source/loader/layers/sanitizer/asan/asan_ddi.cpp

Lines changed: 81 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -513,22 +513,14 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueKernelLaunch(
513513

514514
UR_CALL(getAsanInterceptor()->preLaunchKernel(hKernel, hQueue, LaunchInfo));
515515

516-
ur_event_handle_t hEvent{};
517-
ur_result_t result =
518-
pfnKernelLaunch(hQueue, hKernel, workDim, pGlobalWorkOffset,
519-
pGlobalWorkSize, LaunchInfo.LocalWorkSize.data(),
520-
numEventsInWaitList, phEventWaitList, &hEvent);
521-
522-
if (result == UR_RESULT_SUCCESS) {
523-
UR_CALL(
524-
getAsanInterceptor()->postLaunchKernel(hKernel, hQueue, LaunchInfo));
525-
}
516+
UR_CALL(getContext()->urDdiTable.Enqueue.pfnKernelLaunch(
517+
hQueue, hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize,
518+
LaunchInfo.LocalWorkSize.data(), numEventsInWaitList, phEventWaitList,
519+
phEvent));
526520

527-
if (phEvent) {
528-
*phEvent = hEvent;
529-
}
521+
UR_CALL(getAsanInterceptor()->postLaunchKernel(hKernel, hQueue, LaunchInfo));
530522

531-
return result;
523+
return UR_RESULT_SUCCESS;
532524
}
533525

534526
///////////////////////////////////////////////////////////////////////////////
@@ -1410,6 +1402,57 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemUnmap(
14101402
return UR_RESULT_SUCCESS;
14111403
}
14121404

1405+
ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
1406+
/// [in] handle of the queue object
1407+
ur_queue_handle_t hQueue,
1408+
/// [in] handle of the kernel object
1409+
ur_kernel_handle_t hKernel,
1410+
/// [in] number of dimensions, from 1 to 3, to specify the global and
1411+
/// work-group work-items
1412+
uint32_t workDim,
1413+
/// [in] pointer to an array of workDim unsigned values that specify the
1414+
/// offset used to calculate the global ID of a work-item
1415+
const size_t *pGlobalWorkOffset,
1416+
/// [in] pointer to an array of workDim unsigned values that specify the
1417+
/// number of global work-items in workDim that will execute the kernel
1418+
/// function
1419+
const size_t *pGlobalWorkSize,
1420+
/// [in][optional] pointer to an array of workDim unsigned values that
1421+
/// specify the number of local work-items forming a work-group that will
1422+
/// execute the kernel function.
1423+
/// If nullptr, the runtime implementation will choose the work-group size.
1424+
const size_t *pLocalWorkSize,
1425+
/// [in] size of the event wait list
1426+
uint32_t numEventsInWaitList,
1427+
/// [in][optional][range(0, numEventsInWaitList)] pointer to a list of
1428+
/// events that must be complete before the kernel execution.
1429+
/// If nullptr, the numEventsInWaitList must be 0, indicating that no wait
1430+
/// event.
1431+
const ur_event_handle_t *phEventWaitList,
1432+
/// [out][optional][alloc] return an event object that identifies this
1433+
/// particular kernel execution instance. If phEventWaitList and phEvent
1434+
/// are not NULL, phEvent must not refer to an element of the
1435+
/// phEventWaitList array.
1436+
ur_event_handle_t *phEvent) {
1437+
1438+
getContext()->logger.debug("==== urEnqueueCooperativeKernelLaunchExp");
1439+
1440+
LaunchInfo LaunchInfo(GetContext(hQueue), GetDevice(hQueue), pGlobalWorkSize,
1441+
pLocalWorkSize, pGlobalWorkOffset, workDim);
1442+
UR_CALL(LaunchInfo.Data.syncToDevice(hQueue));
1443+
1444+
UR_CALL(getAsanInterceptor()->preLaunchKernel(hKernel, hQueue, LaunchInfo));
1445+
1446+
UR_CALL(getContext()->urDdiTable.EnqueueExp.pfnCooperativeKernelLaunchExp(
1447+
hQueue, hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize,
1448+
LaunchInfo.LocalWorkSize.data(), numEventsInWaitList, phEventWaitList,
1449+
phEvent));
1450+
1451+
UR_CALL(getAsanInterceptor()->postLaunchKernel(hKernel, hQueue, LaunchInfo));
1452+
1453+
return UR_RESULT_SUCCESS;
1454+
}
1455+
14131456
///////////////////////////////////////////////////////////////////////////////
14141457
/// @brief Intercept function for urKernelRetain
14151458
__urdlllocal ur_result_t UR_APICALL urKernelRetain(
@@ -1952,6 +1995,25 @@ __urdlllocal ur_result_t UR_APICALL urGetDeviceProcAddrTable(
19521995
return result;
19531996
}
19541997

1998+
///////////////////////////////////////////////////////////////////////////////
1999+
/// @brief Exported function for filling application's EnqueueExp table
2000+
/// with current process' addresses
2001+
///
2002+
/// @returns
2003+
/// - ::UR_RESULT_SUCCESS
2004+
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
2005+
__urdlllocal ur_result_t UR_APICALL urGetEnqueueExpProcAddrTable(
2006+
/// [in,out] pointer to table of DDI function pointers
2007+
ur_enqueue_exp_dditable_t *pDdiTable) {
2008+
if (nullptr == pDdiTable) {
2009+
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
2010+
}
2011+
2012+
pDdiTable->pfnCooperativeKernelLaunchExp =
2013+
ur_sanitizer_layer::asan::urEnqueueCooperativeKernelLaunchExp;
2014+
return UR_RESULT_SUCCESS;
2015+
}
2016+
19552017
template <class A, class B> struct NotSupportedApi;
19562018

19572019
template <class MsgType, class R, class... A>
@@ -2147,6 +2209,11 @@ ur_result_t initAsanDDITable(ur_dditable_t *dditable) {
21472209
UR_API_VERSION_CURRENT, &dditable->VirtualMem);
21482210
}
21492211

2212+
if (UR_RESULT_SUCCESS == result) {
2213+
result = ur_sanitizer_layer::asan::urGetEnqueueExpProcAddrTable(
2214+
&dditable->EnqueueExp);
2215+
}
2216+
21502217
if (result != UR_RESULT_SUCCESS) {
21512218
getContext()->logger.error("Initialize ASAN DDI table failed: {}", result);
21522219
}

unified-runtime/source/loader/layers/sanitizer/msan/msan_ddi.cpp

Lines changed: 82 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -420,8 +420,6 @@ ur_result_t urEnqueueKernelLaunch(
420420
/// [out][optional] return an event object that identifies this
421421
/// particular kernel execution instance.
422422
ur_event_handle_t *phEvent) {
423-
auto pfnKernelLaunch = getContext()->urDdiTable.Enqueue.pfnKernelLaunch;
424-
425423
getContext()->logger.debug("==== urEnqueueKernelLaunch");
426424

427425
USMLaunchInfo LaunchInfo(GetContext(hQueue), GetDevice(hQueue),
@@ -431,22 +429,14 @@ ur_result_t urEnqueueKernelLaunch(
431429

432430
UR_CALL(getMsanInterceptor()->preLaunchKernel(hKernel, hQueue, LaunchInfo));
433431

434-
ur_event_handle_t hEvent{};
435-
ur_result_t result =
436-
pfnKernelLaunch(hQueue, hKernel, workDim, pGlobalWorkOffset,
437-
pGlobalWorkSize, LaunchInfo.LocalWorkSize.data(),
438-
numEventsInWaitList, phEventWaitList, &hEvent);
439-
440-
if (result == UR_RESULT_SUCCESS) {
441-
UR_CALL(
442-
getMsanInterceptor()->postLaunchKernel(hKernel, hQueue, LaunchInfo));
443-
}
432+
UR_CALL(getContext()->urDdiTable.Enqueue.pfnKernelLaunch(
433+
hQueue, hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize,
434+
LaunchInfo.LocalWorkSize.data(), numEventsInWaitList, phEventWaitList,
435+
phEvent));
444436

445-
if (phEvent) {
446-
*phEvent = hEvent;
447-
}
437+
UR_CALL(getMsanInterceptor()->postLaunchKernel(hKernel, hQueue, LaunchInfo));
448438

449-
return result;
439+
return UR_RESULT_SUCCESS;
450440
}
451441

452442
///////////////////////////////////////////////////////////////////////////////
@@ -1323,6 +1313,58 @@ ur_result_t urEnqueueMemUnmap(
13231313
return UR_RESULT_SUCCESS;
13241314
}
13251315

1316+
ur_result_t UR_APICALL urEnqueueCooperativeKernelLaunchExp(
1317+
/// [in] handle of the queue object
1318+
ur_queue_handle_t hQueue,
1319+
/// [in] handle of the kernel object
1320+
ur_kernel_handle_t hKernel,
1321+
/// [in] number of dimensions, from 1 to 3, to specify the global and
1322+
/// work-group work-items
1323+
uint32_t workDim,
1324+
/// [in] pointer to an array of workDim unsigned values that specify the
1325+
/// offset used to calculate the global ID of a work-item
1326+
const size_t *pGlobalWorkOffset,
1327+
/// [in] pointer to an array of workDim unsigned values that specify the
1328+
/// number of global work-items in workDim that will execute the kernel
1329+
/// function
1330+
const size_t *pGlobalWorkSize,
1331+
/// [in][optional] pointer to an array of workDim unsigned values that
1332+
/// specify the number of local work-items forming a work-group that will
1333+
/// execute the kernel function.
1334+
/// If nullptr, the runtime implementation will choose the work-group size.
1335+
const size_t *pLocalWorkSize,
1336+
/// [in] size of the event wait list
1337+
uint32_t numEventsInWaitList,
1338+
/// [in][optional][range(0, numEventsInWaitList)] pointer to a list of
1339+
/// events that must be complete before the kernel execution.
1340+
/// If nullptr, the numEventsInWaitList must be 0, indicating that no wait
1341+
/// event.
1342+
const ur_event_handle_t *phEventWaitList,
1343+
/// [out][optional][alloc] return an event object that identifies this
1344+
/// particular kernel execution instance. If phEventWaitList and phEvent
1345+
/// are not NULL, phEvent must not refer to an element of the
1346+
/// phEventWaitList array.
1347+
ur_event_handle_t *phEvent) {
1348+
1349+
getContext()->logger.debug("==== urEnqueueCooperativeKernelLaunchExp");
1350+
1351+
USMLaunchInfo LaunchInfo(GetContext(hQueue), GetDevice(hQueue),
1352+
pGlobalWorkSize, pLocalWorkSize, pGlobalWorkOffset,
1353+
workDim);
1354+
UR_CALL(LaunchInfo.initialize());
1355+
1356+
UR_CALL(getMsanInterceptor()->preLaunchKernel(hKernel, hQueue, LaunchInfo));
1357+
1358+
UR_CALL(getContext()->urDdiTable.EnqueueExp.pfnCooperativeKernelLaunchExp(
1359+
hQueue, hKernel, workDim, pGlobalWorkOffset, pGlobalWorkSize,
1360+
LaunchInfo.LocalWorkSize.data(), numEventsInWaitList, phEventWaitList,
1361+
phEvent));
1362+
1363+
UR_CALL(getMsanInterceptor()->postLaunchKernel(hKernel, hQueue, LaunchInfo));
1364+
1365+
return UR_RESULT_SUCCESS;
1366+
}
1367+
13261368
///////////////////////////////////////////////////////////////////////////////
13271369
/// @brief Intercept function for urKernelRetain
13281370
ur_result_t urKernelRetain(
@@ -1912,6 +1954,25 @@ ur_result_t urCheckVersion(ur_api_version_t version) {
19121954
return UR_RESULT_SUCCESS;
19131955
}
19141956

1957+
///////////////////////////////////////////////////////////////////////////////
1958+
/// @brief Exported function for filling application's EnqueueExp table
1959+
/// with current process' addresses
1960+
///
1961+
/// @returns
1962+
/// - ::UR_RESULT_SUCCESS
1963+
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
1964+
__urdlllocal ur_result_t UR_APICALL urGetEnqueueExpProcAddrTable(
1965+
/// [in,out] pointer to table of DDI function pointers
1966+
ur_enqueue_exp_dditable_t *pDdiTable) {
1967+
if (nullptr == pDdiTable) {
1968+
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
1969+
}
1970+
1971+
pDdiTable->pfnCooperativeKernelLaunchExp =
1972+
ur_sanitizer_layer::msan::urEnqueueCooperativeKernelLaunchExp;
1973+
return UR_RESULT_SUCCESS;
1974+
}
1975+
19151976
} // namespace msan
19161977

19171978
ur_result_t initMsanDDITable(ur_dditable_t *dditable) {
@@ -1966,6 +2027,11 @@ ur_result_t initMsanDDITable(ur_dditable_t *dditable) {
19662027
result = ur_sanitizer_layer::msan::urGetUSMProcAddrTable(&dditable->USM);
19672028
}
19682029

2030+
if (UR_RESULT_SUCCESS == result) {
2031+
result = ur_sanitizer_layer::msan::urGetEnqueueExpProcAddrTable(
2032+
&dditable->EnqueueExp);
2033+
}
2034+
19692035
if (result != UR_RESULT_SUCCESS) {
19702036
getContext()->logger.error("Initialize MSAN DDI table failed: {}", result);
19712037
}

0 commit comments

Comments
 (0)