Skip to content

Commit 0eb08b6

Browse files
authored
Merge pull request #2506 from AllanZyne/review/yang/fix_kernel_native
[DeviceASAN] Fix urKernelCreateWithNativeHandle segfault
2 parents 46bbad2 + b4c5f1f commit 0eb08b6

File tree

6 files changed: 85 additions and 137 deletions.

source/loader/layers/sanitizer/asan/asan_ddi.cpp

Lines changed: 17 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -1335,28 +1335,6 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemUnmap(
13351335
return UR_RESULT_SUCCESS;
13361336
}
13371337

1338-
///////////////////////////////////////////////////////////////////////////////
1339-
/// @brief Intercept function for urKernelCreate
1340-
__urdlllocal ur_result_t UR_APICALL urKernelCreate(
1341-
ur_program_handle_t hProgram, ///< [in] handle of the program instance
1342-
const char *pKernelName, ///< [in] pointer to null-terminated string.
1343-
ur_kernel_handle_t
1344-
*phKernel ///< [out] pointer to handle of kernel object created.
1345-
) {
1346-
auto pfnCreate = getContext()->urDdiTable.Kernel.pfnCreate;
1347-
1348-
if (nullptr == pfnCreate) {
1349-
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
1350-
}
1351-
1352-
getContext()->logger.debug("==== urKernelCreate");
1353-
1354-
UR_CALL(pfnCreate(hProgram, pKernelName, phKernel));
1355-
UR_CALL(getAsanInterceptor()->insertKernel(*phKernel));
1356-
1357-
return UR_RESULT_SUCCESS;
1358-
}
1359-
13601338
///////////////////////////////////////////////////////////////////////////////
13611339
/// @brief Intercept function for urKernelRetain
13621340
__urdlllocal ur_result_t UR_APICALL urKernelRetain(
@@ -1372,8 +1350,8 @@ __urdlllocal ur_result_t UR_APICALL urKernelRetain(
13721350

13731351
UR_CALL(pfnRetain(hKernel));
13741352

1375-
auto KernelInfo = getAsanInterceptor()->getKernelInfo(hKernel);
1376-
KernelInfo->RefCount++;
1353+
auto &KernelInfo = getAsanInterceptor()->getOrCreateKernelInfo(hKernel);
1354+
KernelInfo.RefCount++;
13771355

13781356
return UR_RESULT_SUCCESS;
13791357
}
@@ -1392,9 +1370,9 @@ __urdlllocal ur_result_t urKernelRelease(
13921370
getContext()->logger.debug("==== urKernelRelease");
13931371
UR_CALL(pfnRelease(hKernel));
13941372

1395-
auto KernelInfo = getAsanInterceptor()->getKernelInfo(hKernel);
1396-
if (--KernelInfo->RefCount == 0) {
1397-
UR_CALL(getAsanInterceptor()->eraseKernel(hKernel));
1373+
auto &KernelInfo = getAsanInterceptor()->getOrCreateKernelInfo(hKernel);
1374+
if (--KernelInfo.RefCount == 0) {
1375+
UR_CALL(getAsanInterceptor()->eraseKernelInfo(hKernel));
13981376
}
13991377

14001378
return UR_RESULT_SUCCESS;
@@ -1423,9 +1401,9 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgValue(
14231401
if (argSize == sizeof(ur_mem_handle_t) &&
14241402
(MemBuffer = getAsanInterceptor()->getMemBuffer(
14251403
*ur_cast<const ur_mem_handle_t *>(pArgValue)))) {
1426-
auto KernelInfo = getAsanInterceptor()->getKernelInfo(hKernel);
1427-
std::scoped_lock<ur_shared_mutex> Guard(KernelInfo->Mutex);
1428-
KernelInfo->BufferArgs[argIndex] = std::move(MemBuffer);
1404+
auto &KernelInfo = getAsanInterceptor()->getOrCreateKernelInfo(hKernel);
1405+
std::scoped_lock<ur_shared_mutex> Guard(KernelInfo.Mutex);
1406+
KernelInfo.BufferArgs[argIndex] = std::move(MemBuffer);
14291407
} else {
14301408
UR_CALL(
14311409
pfnSetArgValue(hKernel, argIndex, argSize, pProperties, pArgValue));
@@ -1453,9 +1431,9 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgMemObj(
14531431

14541432
std::shared_ptr<MemBuffer> MemBuffer;
14551433
if ((MemBuffer = getAsanInterceptor()->getMemBuffer(hArgValue))) {
1456-
auto KernelInfo = getAsanInterceptor()->getKernelInfo(hKernel);
1457-
std::scoped_lock<ur_shared_mutex> Guard(KernelInfo->Mutex);
1458-
KernelInfo->BufferArgs[argIndex] = std::move(MemBuffer);
1434+
auto &KernelInfo = getAsanInterceptor()->getOrCreateKernelInfo(hKernel);
1435+
std::scoped_lock<ur_shared_mutex> Guard(KernelInfo.Mutex);
1436+
KernelInfo.BufferArgs[argIndex] = std::move(MemBuffer);
14591437
} else {
14601438
UR_CALL(pfnSetArgMemObj(hKernel, argIndex, pProperties, hArgValue));
14611439
}
@@ -1484,12 +1462,12 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgLocal(
14841462
argSize);
14851463

14861464
{
1487-
auto KI = getAsanInterceptor()->getKernelInfo(hKernel);
1488-
std::scoped_lock<ur_shared_mutex> Guard(KI->Mutex);
1465+
auto &KI = getAsanInterceptor()->getOrCreateKernelInfo(hKernel);
1466+
std::scoped_lock<ur_shared_mutex> Guard(KI.Mutex);
14891467
// TODO: get local variable alignment
14901468
auto argSizeWithRZ = GetSizeAndRedzoneSizeForLocal(
14911469
argSize, ASAN_SHADOW_GRANULARITY, ASAN_SHADOW_GRANULARITY);
1492-
KI->LocalArgs[argIndex] = LocalArgsInfo{argSize, argSizeWithRZ};
1470+
KI.LocalArgs[argIndex] = LocalArgsInfo{argSize, argSizeWithRZ};
14931471
argSize = argSizeWithRZ;
14941472
}
14951473

@@ -1522,9 +1500,9 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgPointer(
15221500

15231501
std::shared_ptr<KernelInfo> KI;
15241502
if (getAsanInterceptor()->getOptions().DetectKernelArguments) {
1525-
auto KI = getAsanInterceptor()->getKernelInfo(hKernel);
1526-
std::scoped_lock<ur_shared_mutex> Guard(KI->Mutex);
1527-
KI->PointerArgs[argIndex] = {pArgValue, GetCurrentBacktrace()};
1503+
auto &KI = getAsanInterceptor()->getOrCreateKernelInfo(hKernel);
1504+
std::scoped_lock<ur_shared_mutex> Guard(KI.Mutex);
1505+
KI.PointerArgs[argIndex] = {pArgValue, GetCurrentBacktrace()};
15281506
}
15291507

15301508
ur_result_t result =
@@ -1708,7 +1686,6 @@ __urdlllocal ur_result_t UR_APICALL urGetKernelProcAddrTable(
17081686

17091687
ur_result_t result = UR_RESULT_SUCCESS;
17101688

1711-
pDdiTable->pfnCreate = ur_sanitizer_layer::asan::urKernelCreate;
17121689
pDdiTable->pfnRetain = ur_sanitizer_layer::asan::urKernelRetain;
17131690
pDdiTable->pfnRelease = ur_sanitizer_layer::asan::urKernelRelease;
17141691
pDdiTable->pfnSetArgValue = ur_sanitizer_layer::asan::urKernelSetArgValue;

source/loader/layers/sanitizer/asan/asan_interceptor.cpp

Lines changed: 21 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -639,22 +639,26 @@ ur_result_t AsanInterceptor::eraseProgram(ur_program_handle_t Program) {
639639
return UR_RESULT_SUCCESS;
640640
}
641641

642-
ur_result_t AsanInterceptor::insertKernel(ur_kernel_handle_t Kernel) {
643-
std::scoped_lock<ur_shared_mutex> Guard(m_KernelMapMutex);
644-
if (m_KernelMap.find(Kernel) != m_KernelMap.end()) {
645-
return UR_RESULT_SUCCESS;
642+
KernelInfo &AsanInterceptor::getOrCreateKernelInfo(ur_kernel_handle_t Kernel) {
643+
{
644+
std::shared_lock<ur_shared_mutex> Guard(m_KernelMapMutex);
645+
if (m_KernelMap.find(Kernel) != m_KernelMap.end()) {
646+
return *m_KernelMap[Kernel].get();
647+
}
646648
}
647649

648-
auto hProgram = GetProgram(Kernel);
649-
auto PI = getAsanInterceptor()->getProgramInfo(hProgram);
650+
// Create new KernelInfo
651+
auto Program = GetProgram(Kernel);
652+
auto PI = getProgramInfo(Program);
650653
bool IsInstrumented = PI->isKernelInstrumented(Kernel);
651654

655+
std::scoped_lock<ur_shared_mutex> Guard(m_KernelMapMutex);
652656
m_KernelMap.emplace(Kernel,
653-
std::make_shared<KernelInfo>(Kernel, IsInstrumented));
654-
return UR_RESULT_SUCCESS;
657+
std::make_unique<KernelInfo>(Kernel, IsInstrumented));
658+
return *m_KernelMap[Kernel].get();
655659
}
656660

657-
ur_result_t AsanInterceptor::eraseKernel(ur_kernel_handle_t Kernel) {
661+
ur_result_t AsanInterceptor::eraseKernelInfo(ur_kernel_handle_t Kernel) {
658662
std::scoped_lock<ur_shared_mutex> Guard(m_KernelMapMutex);
659663
assert(m_KernelMap.find(Kernel) != m_KernelMap.end());
660664
m_KernelMap.erase(Kernel);
@@ -691,7 +695,8 @@ ur_result_t AsanInterceptor::prepareLaunch(
691695
std::shared_ptr<ContextInfo> &ContextInfo,
692696
std::shared_ptr<DeviceInfo> &DeviceInfo, ur_queue_handle_t Queue,
693697
ur_kernel_handle_t Kernel, LaunchInfo &LaunchInfo) {
694-
auto KernelInfo = getKernelInfo(Kernel);
698+
auto &KernelInfo = getOrCreateKernelInfo(Kernel);
699+
std::shared_lock<ur_shared_mutex> Guard(KernelInfo.Mutex);
695700

696701
auto ArgNums = GetKernelNumArgs(Kernel);
697702
auto LocalMemoryUsage =
@@ -703,11 +708,11 @@ ur_result_t AsanInterceptor::prepareLaunch(
703708
"KernelInfo {} (Name={}, ArgNums={}, IsInstrumented={}, "
704709
"LocalMemory={}, PrivateMemory={})",
705710
(void *)Kernel, GetKernelName(Kernel), ArgNums,
706-
KernelInfo->IsInstrumented, LocalMemoryUsage, PrivateMemoryUsage);
711+
KernelInfo.IsInstrumented, LocalMemoryUsage, PrivateMemoryUsage);
707712

708713
// Validate pointer arguments
709714
if (getOptions().DetectKernelArguments) {
710-
for (const auto &[ArgIndex, PtrPair] : KernelInfo->PointerArgs) {
715+
for (const auto &[ArgIndex, PtrPair] : KernelInfo.PointerArgs) {
711716
auto Ptr = PtrPair.first;
712717
if (Ptr == nullptr) {
713718
continue;
@@ -722,7 +727,7 @@ ur_result_t AsanInterceptor::prepareLaunch(
722727
}
723728

724729
// Set membuffer arguments
725-
for (const auto &[ArgIndex, MemBuffer] : KernelInfo->BufferArgs) {
730+
for (const auto &[ArgIndex, MemBuffer] : KernelInfo.BufferArgs) {
726731
char *ArgPointer = nullptr;
727732
UR_CALL(MemBuffer->getHandle(DeviceInfo->Handle, ArgPointer));
728733
ur_result_t URes = getContext()->urDdiTable.Kernel.pfnSetArgPointer(
@@ -735,7 +740,7 @@ ur_result_t AsanInterceptor::prepareLaunch(
735740
}
736741
}
737742

738-
if (!KernelInfo->IsInstrumented) {
743+
if (!KernelInfo.IsInstrumented) {
739744
return UR_RESULT_SUCCESS;
740745
}
741746

@@ -830,9 +835,9 @@ ur_result_t AsanInterceptor::prepareLaunch(
830835
}
831836

832837
// Write local arguments info
833-
if (!KernelInfo->LocalArgs.empty()) {
838+
if (!KernelInfo.LocalArgs.empty()) {
834839
std::vector<LocalArgsInfo> LocalArgsInfo;
835-
for (auto [ArgIndex, ArgInfo] : KernelInfo->LocalArgs) {
840+
for (auto [ArgIndex, ArgInfo] : KernelInfo.LocalArgs) {
836841
LocalArgsInfo.push_back(ArgInfo);
837842
getContext()->logger.debug(
838843
"local_args (argIndex={}, size={}, sizeWithRZ={})", ArgIndex,

source/loader/layers/sanitizer/asan/asan_interceptor.hpp

Lines changed: 3 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -308,9 +308,6 @@ class AsanInterceptor {
308308
ur_result_t insertProgram(ur_program_handle_t Program);
309309
ur_result_t eraseProgram(ur_program_handle_t Program);
310310

311-
ur_result_t insertKernel(ur_kernel_handle_t Kernel);
312-
ur_result_t eraseKernel(ur_kernel_handle_t Kernel);
313-
314311
ur_result_t insertMemBuffer(std::shared_ptr<MemBuffer> MemBuffer);
315312
ur_result_t eraseMemBuffer(ur_mem_handle_t MemHandle);
316313
std::shared_ptr<MemBuffer> getMemBuffer(ur_mem_handle_t MemHandle);
@@ -350,11 +347,8 @@ class AsanInterceptor {
350347
return nullptr;
351348
}
352349

353-
std::shared_ptr<KernelInfo> getKernelInfo(ur_kernel_handle_t Kernel) {
354-
std::shared_lock<ur_shared_mutex> Guard(m_KernelMapMutex);
355-
assert(m_KernelMap.find(Kernel) != m_KernelMap.end());
356-
return m_KernelMap[Kernel];
357-
}
350+
KernelInfo &getOrCreateKernelInfo(ur_kernel_handle_t Kernel);
351+
ur_result_t eraseKernelInfo(ur_kernel_handle_t Kernel);
358352

359353
const AsanOptions &getOptions() { return m_Options; }
360354

@@ -401,7 +395,7 @@ class AsanInterceptor {
401395
m_ProgramMap;
402396
ur_shared_mutex m_ProgramMapMutex;
403397

404-
std::unordered_map<ur_kernel_handle_t, std::shared_ptr<KernelInfo>>
398+
std::unordered_map<ur_kernel_handle_t, std::unique_ptr<KernelInfo>>
405399
m_KernelMap;
406400
ur_shared_mutex m_KernelMapMutex;
407401

source/loader/layers/sanitizer/msan/msan_ddi.cpp

Lines changed: 13 additions & 51 deletions
Original file line number | Diff line number | Diff line change
@@ -50,12 +50,6 @@ ur_result_t setupContext(ur_context_handle_t Context, uint32_t numDevices,
5050
return UR_RESULT_SUCCESS;
5151
}
5252

53-
bool isInstrumentedKernel(ur_kernel_handle_t hKernel) {
54-
auto hProgram = GetProgram(hKernel);
55-
auto PI = getMsanInterceptor()->getProgramInfo(hProgram);
56-
return PI->isKernelInstrumented(hKernel);
57-
}
58-
5953
} // namespace
6054

6155
///////////////////////////////////////////////////////////////////////////////
@@ -354,12 +348,6 @@ ur_result_t urEnqueueKernelLaunch(
354348

355349
getContext()->logger.debug("==== urEnqueueKernelLaunch");
356350

357-
if (!isInstrumentedKernel(hKernel)) {
358-
return pfnKernelLaunch(hQueue, hKernel, workDim, pGlobalWorkOffset,
359-
pGlobalWorkSize, pLocalWorkSize,
360-
numEventsInWaitList, phEventWaitList, phEvent);
361-
}
362-
363351
USMLaunchInfo LaunchInfo(GetContext(hQueue), GetDevice(hQueue),
364352
pGlobalWorkSize, pLocalWorkSize, pGlobalWorkOffset,
365353
workDim);
@@ -1155,26 +1143,6 @@ ur_result_t urEnqueueMemUnmap(
11551143
return UR_RESULT_SUCCESS;
11561144
}
11571145

1158-
///////////////////////////////////////////////////////////////////////////////
1159-
/// @brief Intercept function for urKernelCreate
1160-
ur_result_t urKernelCreate(
1161-
ur_program_handle_t hProgram, ///< [in] handle of the program instance
1162-
const char *pKernelName, ///< [in] pointer to null-terminated string.
1163-
ur_kernel_handle_t
1164-
*phKernel ///< [out] pointer to handle of kernel object created.
1165-
) {
1166-
auto pfnCreate = getContext()->urDdiTable.Kernel.pfnCreate;
1167-
1168-
getContext()->logger.debug("==== urKernelCreate");
1169-
1170-
UR_CALL(pfnCreate(hProgram, pKernelName, phKernel));
1171-
if (isInstrumentedKernel(*phKernel)) {
1172-
UR_CALL(getMsanInterceptor()->insertKernel(*phKernel));
1173-
}
1174-
1175-
return UR_RESULT_SUCCESS;
1176-
}
1177-
11781146
///////////////////////////////////////////////////////////////////////////////
11791147
/// @brief Intercept function for urKernelRetain
11801148
ur_result_t urKernelRetain(
@@ -1186,10 +1154,8 @@ ur_result_t urKernelRetain(
11861154

11871155
UR_CALL(pfnRetain(hKernel));
11881156

1189-
auto KernelInfo = getMsanInterceptor()->getKernelInfo(hKernel);
1190-
if (KernelInfo) {
1191-
KernelInfo->RefCount++;
1192-
}
1157+
auto &KernelInfo = getMsanInterceptor()->getOrCreateKernelInfo(hKernel);
1158+
KernelInfo.RefCount++;
11931159

11941160
return UR_RESULT_SUCCESS;
11951161
}
@@ -1204,11 +1170,9 @@ ur_result_t urKernelRelease(
12041170
getContext()->logger.debug("==== urKernelRelease");
12051171
UR_CALL(pfnRelease(hKernel));
12061172

1207-
auto KernelInfo = getMsanInterceptor()->getKernelInfo(hKernel);
1208-
if (KernelInfo) {
1209-
if (--KernelInfo->RefCount == 0) {
1210-
UR_CALL(getMsanInterceptor()->eraseKernel(hKernel));
1211-
}
1173+
auto &KernelInfo = getMsanInterceptor()->getOrCreateKernelInfo(hKernel);
1174+
if (--KernelInfo.RefCount == 0) {
1175+
UR_CALL(getMsanInterceptor()->eraseKernelInfo(hKernel));
12121176
}
12131177

12141178
return UR_RESULT_SUCCESS;
@@ -1230,13 +1194,12 @@ ur_result_t urKernelSetArgValue(
12301194
getContext()->logger.debug("==== urKernelSetArgValue");
12311195

12321196
std::shared_ptr<MemBuffer> MemBuffer;
1233-
std::shared_ptr<KernelInfo> KernelInfo;
12341197
if (argSize == sizeof(ur_mem_handle_t) &&
12351198
(MemBuffer = getMsanInterceptor()->getMemBuffer(
1236-
*ur_cast<const ur_mem_handle_t *>(pArgValue))) &&
1237-
(KernelInfo = getMsanInterceptor()->getKernelInfo(hKernel))) {
1238-
std::scoped_lock<ur_shared_mutex> Guard(KernelInfo->Mutex);
1239-
KernelInfo->BufferArgs[argIndex] = std::move(MemBuffer);
1199+
*ur_cast<const ur_mem_handle_t *>(pArgValue)))) {
1200+
auto &KernelInfo = getMsanInterceptor()->getOrCreateKernelInfo(hKernel);
1201+
std::scoped_lock<ur_shared_mutex> Guard(KernelInfo.Mutex);
1202+
KernelInfo.BufferArgs[argIndex] = std::move(MemBuffer);
12401203
} else {
12411204
UR_CALL(
12421205
pfnSetArgValue(hKernel, argIndex, argSize, pProperties, pArgValue));
@@ -1260,10 +1223,10 @@ ur_result_t urKernelSetArgMemObj(
12601223

12611224
std::shared_ptr<MemBuffer> MemBuffer;
12621225
std::shared_ptr<KernelInfo> KernelInfo;
1263-
if ((MemBuffer = getMsanInterceptor()->getMemBuffer(hArgValue)) &&
1264-
(KernelInfo = getMsanInterceptor()->getKernelInfo(hKernel))) {
1265-
std::scoped_lock<ur_shared_mutex> Guard(KernelInfo->Mutex);
1266-
KernelInfo->BufferArgs[argIndex] = std::move(MemBuffer);
1226+
if ((MemBuffer = getMsanInterceptor()->getMemBuffer(hArgValue))) {
1227+
auto &KernelInfo = getMsanInterceptor()->getOrCreateKernelInfo(hKernel);
1228+
std::scoped_lock<ur_shared_mutex> Guard(KernelInfo.Mutex);
1229+
KernelInfo.BufferArgs[argIndex] = std::move(MemBuffer);
12671230
} else {
12681231
UR_CALL(pfnSetArgMemObj(hKernel, argIndex, pProperties, hArgValue));
12691232
}
@@ -1348,7 +1311,6 @@ ur_result_t urGetKernelProcAddrTable(
13481311
) {
13491312
ur_result_t result = UR_RESULT_SUCCESS;
13501313

1351-
pDdiTable->pfnCreate = ur_sanitizer_layer::msan::urKernelCreate;
13521314
pDdiTable->pfnRetain = ur_sanitizer_layer::msan::urKernelRetain;
13531315
pDdiTable->pfnRelease = ur_sanitizer_layer::msan::urKernelRelease;
13541316
pDdiTable->pfnSetArgValue = ur_sanitizer_layer::msan::urKernelSetArgValue;

0 commit comments

Comments
 (0)