Skip to content

Commit 2af159d

Browse files
authored
Merge pull request #1808 from zhaomaosu/detect-memory-leak
[DeviceSanitizer] Support detecting memory leaks of USM
2 parents e8182b8 + 5653b30 commit 2af159d

File tree

5 files changed

+90
-17
lines changed

5 files changed

+90
-17
lines changed

source/loader/layers/sanitizer/asan_interceptor.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -947,6 +947,34 @@ SanitizerInterceptor::findAllocInfoByAddress(uptr Address) {
947947
return It;
948948
}
949949

950+
std::vector<AllocationIterator>
951+
SanitizerInterceptor::findAllocInfoByContext(ur_context_handle_t Context) {
952+
std::shared_lock<ur_shared_mutex> Guard(m_AllocationMapMutex);
953+
std::vector<AllocationIterator> AllocInfos;
954+
for (auto It = m_AllocationMap.begin(); It != m_AllocationMap.end(); It++) {
955+
const auto &[_, AI] = *It;
956+
if (AI->Context == Context) {
957+
AllocInfos.emplace_back(It);
958+
}
959+
}
960+
return AllocInfos;
961+
}
962+
963+
/// Calls the context DDI's pfnRelease on Handle, then reports a memory leak
/// for every allocation recorded for this context that is not yet marked as
/// released.
ContextInfo::~ContextInfo() {
    [[maybe_unused]] auto Result =
        getContext()->urDdiTable.Context.pfnRelease(Handle);
    assert(Result == UR_RESULT_SUCCESS);

    // From here on Handle is only passed as a lookup key; the lookup compares
    // it against each allocation's Context field.
    const auto Candidates =
        getContext()->interceptor->findAllocInfoByContext(Handle);
    for (const auto &Entry : Candidates) {
        const auto &Info = Entry->second;
        if (Info->IsReleased) {
            continue;
        }
        ReportMemoryLeak(Info);
    }
}
977+
950978
ur_result_t USMLaunchInfo::initialize() {
951979
UR_CALL(getContext()->urDdiTable.Context.pfnRetain(Context));
952980
UR_CALL(getContext()->urDdiTable.Device.pfnRetain(Device));

source/loader/layers/sanitizer/asan_interceptor.hpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ struct DeviceInfo {
4545
// Device features
4646
bool IsSupportSharedSystemUSM = false;
4747

48+
// lock this mutex if the following fields are accessed
4849
ur_mutex Mutex;
4950
std::queue<std::shared_ptr<AllocInfo>> Quarantine;
5051
size_t QuarantineSize = 0;
@@ -59,6 +60,7 @@ struct DeviceInfo {
5960
struct QueueInfo {
6061
ur_queue_handle_t Handle;
6162

63+
// lock this mutex if the following fields are accessed
6264
ur_shared_mutex Mutex;
6365
ur_event_handle_t LastEvent;
6466

@@ -78,8 +80,10 @@ struct QueueInfo {
7880

7981
struct KernelInfo {
8082
ur_kernel_handle_t Handle;
81-
ur_shared_mutex Mutex;
8283
std::atomic<int32_t> RefCount = 1;
84+
85+
// lock this mutex if the following fields are accessed
86+
ur_shared_mutex Mutex;
8387
std::unordered_map<uint32_t, std::shared_ptr<MemBuffer>> BufferArgs;
8488
std::unordered_map<uint32_t, std::pair<const void *, StackTrace>>
8589
PointerArgs;
@@ -102,6 +106,7 @@ struct KernelInfo {
102106

103107
struct ContextInfo {
104108
ur_context_handle_t Handle;
109+
std::atomic<int32_t> RefCount = 1;
105110

106111
std::vector<ur_device_handle_t> DeviceList;
107112
std::unordered_map<ur_device_handle_t, AllocInfoList> AllocInfosMap;
@@ -112,11 +117,7 @@ struct ContextInfo {
112117
assert(Result == UR_RESULT_SUCCESS);
113118
}
114119

115-
~ContextInfo() {
116-
[[maybe_unused]] auto Result =
117-
getContext()->urDdiTable.Context.pfnRelease(Handle);
118-
assert(Result == UR_RESULT_SUCCESS);
119-
}
120+
~ContextInfo();
120121

121122
void insertAllocInfo(const std::vector<ur_device_handle_t> &Devices,
122123
std::shared_ptr<AllocInfo> &AI) {
@@ -211,6 +212,9 @@ class SanitizerInterceptor {
211212

212213
std::optional<AllocationIterator> findAllocInfoByAddress(uptr Address);
213214

215+
std::vector<AllocationIterator>
216+
findAllocInfoByContext(ur_context_handle_t Context);
217+
214218
std::shared_ptr<ContextInfo> getContextInfo(ur_context_handle_t Context) {
215219
std::shared_lock<ur_shared_mutex> Guard(m_ContextMapMutex);
216220
assert(m_ContextMap.find(Context) != m_ContextMap.end());

source/loader/layers/sanitizer/asan_report.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,16 @@ void ReportDoubleFree(uptr Addr, const StackTrace &Stack,
7979
AI->AllocStack.print();
8080
}
8181

82+
/// Log a memory-leak report for a single allocation.
///
/// Emits a header line naming the allocation type, a line with the leaked
/// size (UserEnd - UserBegin) and the user start address, and finally prints
/// the stack trace recorded in AllocStack.
///
/// \param AI  allocation to report.
void ReportMemoryLeak(const std::shared_ptr<AllocInfo> &AI) {
    const auto LeakedBytes = AI->UserEnd - AI->UserBegin;
    void *const LeakedAddr = reinterpret_cast<void *>(AI->UserBegin);

    getContext()->logger.always(
        "\n====ERROR: DeviceSanitizer: detected memory leaks of {}",
        ToString(AI->Type));
    getContext()->logger.always(
        "Direct leak of {} byte(s) at {} allocated from:", LeakedBytes,
        LeakedAddr);
    AI->AllocStack.print();
}
91+
8292
void ReportFatalError(const DeviceSanitizerReport &Report) {
8393
getContext()->logger.always("\n====ERROR: DeviceSanitizer: {}",
8494
ToString(Report.ErrorType));

source/loader/layers/sanitizer/asan_report.hpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,10 @@ void ReportBadContext(uptr Addr, const StackTrace &stack,
3232
void ReportDoubleFree(uptr Addr, const StackTrace &Stack,
3333
const std::shared_ptr<AllocInfo> &AllocInfo);
3434

35-
// This type of error is usually unexpected mistake and doesn't have enough debug information
35+
void ReportMemoryLeak(const std::shared_ptr<AllocInfo> &AI);
36+
37+
// This type of error is usually an unexpected mistake and lacks enough
38+
// debug information
3639
void ReportFatalError(const DeviceSanitizerReport &Report);
3740

3841
void ReportGenericError(const DeviceSanitizerReport &Report,

source/loader/layers/sanitizer/ur_sanddi.cpp

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,29 @@ __urdlllocal ur_result_t UR_APICALL urContextCreateWithNativeHandle(
411411
return result;
412412
}
413413

414+
///////////////////////////////////////////////////////////////////////////////
415+
/// @brief Intercept function for urContextRetain
416+
__urdlllocal ur_result_t UR_APICALL urContextRetain(
417+
ur_context_handle_t
418+
hContext ///< [in] handle of the context to get a reference of.
419+
) {
420+
auto pfnRetain = getContext()->urDdiTable.Context.pfnRetain;
421+
422+
if (nullptr == pfnRetain) {
423+
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
424+
}
425+
426+
getContext()->logger.debug("==== urContextRetain");
427+
428+
UR_CALL(pfnRetain(hContext));
429+
430+
auto ContextInfo = getContext()->interceptor->getContextInfo(hContext);
431+
UR_ASSERT(ContextInfo != nullptr, UR_RESULT_ERROR_INVALID_VALUE);
432+
ContextInfo->RefCount++;
433+
434+
return UR_RESULT_SUCCESS;
435+
}
436+
414437
///////////////////////////////////////////////////////////////////////////////
415438
/// @brief Intercept function for urContextRelease
416439
__urdlllocal ur_result_t UR_APICALL urContextRelease(
@@ -424,10 +447,15 @@ __urdlllocal ur_result_t UR_APICALL urContextRelease(
424447

425448
getContext()->logger.debug("==== urContextRelease");
426449

427-
UR_CALL(getContext()->interceptor->eraseContext(hContext));
428-
ur_result_t result = pfnRelease(hContext);
450+
UR_CALL(pfnRelease(hContext));
429451

430-
return result;
452+
auto ContextInfo = getContext()->interceptor->getContextInfo(hContext);
453+
UR_ASSERT(ContextInfo != nullptr, UR_RESULT_ERROR_INVALID_VALUE);
454+
if (--ContextInfo->RefCount == 0) {
455+
UR_CALL(getContext()->interceptor->eraseContext(hContext));
456+
}
457+
458+
return UR_RESULT_SUCCESS;
431459
}
432460

433461
///////////////////////////////////////////////////////////////////////////////
@@ -1207,9 +1235,9 @@ __urdlllocal ur_result_t UR_APICALL urKernelRetain(
12071235

12081236
UR_CALL(pfnRetain(hKernel));
12091237

1210-
if (auto KernelInfo = getContext()->interceptor->getKernelInfo(hKernel)) {
1211-
KernelInfo->RefCount++;
1212-
}
1238+
auto KernelInfo = getContext()->interceptor->getKernelInfo(hKernel);
1239+
UR_ASSERT(KernelInfo != nullptr, UR_RESULT_ERROR_INVALID_VALUE);
1240+
KernelInfo->RefCount++;
12131241

12141242
return UR_RESULT_SUCCESS;
12151243
}
@@ -1228,10 +1256,9 @@ __urdlllocal ur_result_t urKernelRelease(
12281256
getContext()->logger.debug("==== urKernelRelease");
12291257
UR_CALL(pfnRelease(hKernel));
12301258

1231-
if (auto KernelInfo = getContext()->interceptor->getKernelInfo(hKernel)) {
1232-
if (--KernelInfo->RefCount != 0) {
1233-
return UR_RESULT_SUCCESS;
1234-
}
1259+
auto KernelInfo = getContext()->interceptor->getKernelInfo(hKernel);
1260+
UR_ASSERT(KernelInfo != nullptr, UR_RESULT_ERROR_INVALID_VALUE);
1261+
if (--KernelInfo->RefCount == 0) {
12351262
UR_CALL(getContext()->interceptor->eraseKernel(hKernel));
12361263
}
12371264

@@ -1426,6 +1453,7 @@ __urdlllocal ur_result_t UR_APICALL urGetContextProcAddrTable(
14261453
ur_result_t result = UR_RESULT_SUCCESS;
14271454

14281455
pDdiTable->pfnCreate = ur_sanitizer_layer::urContextCreate;
1456+
pDdiTable->pfnRetain = ur_sanitizer_layer::urContextRetain;
14291457
pDdiTable->pfnRelease = ur_sanitizer_layer::urContextRelease;
14301458

14311459
pDdiTable->pfnCreateWithNativeHandle =

0 commit comments

Comments
 (0)