Skip to content

Commit 0b783c5

Browse files
committed
save
1 parent ad288bb commit 0b783c5

8 files changed

+145
-11
lines changed

source/loader/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,8 @@ if(UR_ENABLE_SANITIZER)
179179
${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/ur_sanddi.cpp
180180
${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/ur_sanitizer_layer.cpp
181181
${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/ur_sanitizer_layer.hpp
182+
${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/ur_obj_handler.cpp
183+
${CMAKE_CURRENT_SOURCE_DIR}/layers/sanitizer/ur_obj_handler.hpp
182184
)
183185

184186
if(UR_ENABLE_SYMBOLIZER)

source/loader/layers/sanitizer/asan/asan_ddi.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -509,6 +509,7 @@ __urdlllocal ur_result_t UR_APICALL urContextCreate(
509509
pfnCreate(numDevices, phDevices, pProperties, phContext);
510510

511511
if (result == UR_RESULT_SUCCESS) {
512+
getContext()->objectHandler.add(*phContext);
512513
UR_CALL(setupContext(*phContext, numDevices, phDevices));
513514
}
514515

@@ -543,6 +544,7 @@ __urdlllocal ur_result_t UR_APICALL urContextCreateWithNativeHandle(
543544
phDevices, pProperties, phContext);
544545

545546
if (result == UR_RESULT_SUCCESS) {
547+
getContext()->objectHandler.add(*phContext);
546548
UR_CALL(setupContext(*phContext, numDevices, phDevices));
547549
}
548550

@@ -563,7 +565,8 @@ __urdlllocal ur_result_t UR_APICALL urContextRetain(
563565

564566
getContext()->logger.debug("==== urContextRetain");
565567

566-
UR_CALL(pfnRetain(hContext));
568+
// UR_CALL(pfnRetain(hContext));
569+
UR_CALL(getContext()->objectHandler.retain(hContext));
567570

568571
auto ContextInfo = getAsanInterceptor()->getContextInfo(hContext);
569572
UR_ASSERT(ContextInfo != nullptr, UR_RESULT_ERROR_INVALID_VALUE);
@@ -585,7 +588,8 @@ __urdlllocal ur_result_t UR_APICALL urContextRelease(
585588

586589
getContext()->logger.debug("==== urContextRelease");
587590

588-
UR_CALL(pfnRelease(hContext));
591+
// UR_CALL(pfnRelease(hContext));
592+
UR_CALL(getContext()->objectHandler.release(hContext));
589593

590594
auto ContextInfo = getAsanInterceptor()->getContextInfo(hContext);
591595
UR_ASSERT(ContextInfo != nullptr, UR_RESULT_ERROR_INVALID_VALUE);
@@ -2037,6 +2041,8 @@ ur_result_t initAsanDDITable(ur_dditable_t *dditable) {
20372041

20382042
getContext()->logger.always("==== DeviceSanitizer: ASAN");
20392043

2044+
getContext()->objectHandler.installDdiTable(dditable);
2045+
20402046
if (UR_RESULT_SUCCESS == result) {
20412047
result = ur_sanitizer_layer::asan::urGetGlobalProcAddrTable(
20422048
UR_API_VERSION_CURRENT, &dditable->Global);

source/loader/layers/sanitizer/asan/asan_interceptor.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -896,7 +896,8 @@ ContextInfo::~ContextInfo() {
896896
assert(URes == UR_RESULT_SUCCESS);
897897
}
898898

899-
URes = getContext()->urDdiTable.Context.pfnRelease(Handle);
899+
// URes = getContext()->urDdiTable.Context.pfnRelease(Handle);
900+
URes = getContext()->objectHandler.release(Handle);
900901
assert(URes == UR_RESULT_SUCCESS);
901902

902903
// check memory leaks
@@ -944,7 +945,8 @@ AsanRuntimeDataWrapper::~AsanRuntimeDataWrapper() {
944945

945946
LaunchInfo::~LaunchInfo() {
946947
[[maybe_unused]] ur_result_t Result;
947-
Result = getContext()->urDdiTable.Context.pfnRelease(Context);
948+
// Result = getContext()->urDdiTable.Context.pfnRelease(Context);
949+
Result = getContext()->objectHandler.release(Context);
948950
assert(Result == UR_RESULT_SUCCESS);
949951
Result = getContext()->urDdiTable.Device.pfnRelease(Device);
950952
assert(Result == UR_RESULT_SUCCESS);

source/loader/layers/sanitizer/asan/asan_interceptor.hpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,7 @@ struct ContextInfo {
148148
AsanStatsWrapper Stats;
149149

150150
explicit ContextInfo(ur_context_handle_t Context) : Handle(Context) {
151-
[[maybe_unused]] auto Result =
152-
getContext()->urDdiTable.Context.pfnRetain(Context);
151+
[[maybe_unused]] auto Result = getContext()->objectHandler.retain(Handle);
153152
assert(Result == UR_RESULT_SUCCESS);
154153
}
155154

@@ -252,9 +251,9 @@ struct LaunchInfo {
252251
this->LocalWorkSize =
253252
std::vector<size_t>(LocalWorkSize, LocalWorkSize + WorkDim);
254253
}
255-
[[maybe_unused]] auto Result =
256-
getContext()->urDdiTable.Context.pfnRetain(Context);
254+
[[maybe_unused]] auto Result = getContext()->objectHandler.retain(Context);
257255
assert(Result == UR_RESULT_SUCCESS);
256+
258257
Result = getContext()->urDdiTable.Device.pfnRetain(Device);
259258
assert(Result == UR_RESULT_SUCCESS);
260259
}

source/loader/layers/sanitizer/asan/asan_shadow.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ namespace asan {
2222
std::shared_ptr<ShadowMemory> GetShadowMemory(ur_context_handle_t Context,
2323
ur_device_handle_t Device,
2424
DeviceType Type) {
25+
getContext()->objectHandler.use(Context);
2526
if (Type == DeviceType::CPU) {
2627
static std::shared_ptr<ShadowMemory> ShadowCPU =
2728
std::make_shared<ShadowMemoryCPU>(Context, Device);
@@ -109,6 +110,8 @@ ur_result_t ShadowMemoryGPU::Setup() {
109110
// the SVM range, so that GFX driver will automatically switch to reservation on the GPU
110111
// heap.
111112
const void *StartAddress = (void *)(0x100'0000'0000'0000ULL);
113+
114+
getContext()->objectHandler.use(Context);
112115
// TODO: Protect Bad Zone
113116
auto Result = getContext()->urDdiTable.VirtualMem.pfnReserve(
114117
Context, StartAddress, ShadowSize, (void **)&ShadowBegin);
@@ -120,7 +123,8 @@ ur_result_t ShadowMemoryGPU::Setup() {
120123
}
121124
ShadowEnd = ShadowBegin + ShadowSize;
122125
// Retain the context which reserves shadow memory
123-
getContext()->urDdiTable.Context.pfnRetain(Context);
126+
// getContext()->urDdiTable.Context.pfnRetain(Context);
127+
getContext()->objectHandler.retain(Context);
124128

125129
// Set shadow memory for null pointer
126130
// For GPU, we use up to 1 page of shadow memory
@@ -147,6 +151,7 @@ ur_result_t ShadowMemoryGPU::Destory() {
147151
}
148152

149153
static ur_result_t Result = [this]() {
154+
getContext()->objectHandler.use(Context);
150155
const size_t PageSize = GetVirtualMemGranularity(Context, Device);
151156
for (auto [MappedPtr, PhysicalMem] : VirtualMemMaps) {
152157
UR_CALL(getContext()->urDdiTable.VirtualMem.pfnUnmap(
@@ -156,7 +161,8 @@ ur_result_t ShadowMemoryGPU::Destory() {
156161
}
157162
UR_CALL(getContext()->urDdiTable.VirtualMem.pfnFree(
158163
Context, (const void *)ShadowBegin, GetShadowSize()));
159-
UR_CALL(getContext()->urDdiTable.Context.pfnRelease(Context));
164+
// UR_CALL(getContext()->urDdiTable.Context.pfnRelease(Context));
165+
UR_CALL(getContext()->objectHandler.release(Context));
160166
return UR_RESULT_SUCCESS;
161167
}();
162168
if (!Result) {
@@ -171,7 +177,8 @@ ur_result_t ShadowMemoryGPU::Destory() {
171177
if (ShadowBegin != 0) {
172178
UR_CALL(getContext()->urDdiTable.VirtualMem.pfnFree(
173179
Context, (const void *)ShadowBegin, GetShadowSize()));
174-
UR_CALL(getContext()->urDdiTable.Context.pfnRelease(Context));
180+
// UR_CALL(getContext()->urDdiTable.Context.pfnRelease(Context));
181+
UR_CALL(getContext()->objectHandler.release(Context));
175182
ShadowBegin = ShadowEnd = 0;
176183
}
177184
return UR_RESULT_SUCCESS;

source/loader/layers/sanitizer/ur_obj_handler.cpp

Whitespace-only changes.
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
/*
2+
The UrObjectHandler is intended to provide global maps of all UrObjects and their corresponding XXXInfo objects that are used by the sanitizer layers.
3+
It also provides a checker that tracks each UrObject's status in order to catch use-after-release cases.
4+
*/
5+
6+
/**
7+
 * 20250107: As a first step, this is implemented only as a checker for use-after-release cases.
8+
*/
9+
10+
#include "ur/ur.hpp"
11+
#include "ur_api.h"
12+
13+
#include <atomic>
14+
#include <cassert>
15+
#include <unordered_map>
16+
#include <variant>
17+
18+
#pragma once
19+
20+
namespace ur_sanitizer_layer {
21+
22+
typedef std::variant<ur_context_handle_t, ur_device_handle_t> UrObjectT;
23+
24+
class UrObjectHandler {
25+
public:
26+
void add(UrObjectT urObject) {
27+
std::scoped_lock<ur_shared_mutex> Guard(urObjectStatusMapMutex);
28+
// if (urObjectStatusMap.find(urObject) != urObjectStatusMap.end()) {
29+
// if (urObjectStatusMap[urObject].refCount > 0) {
30+
// assert(false && "Add of a exist object");
31+
// } else {
32+
// // remove an old object
33+
// ; // Nothing to do for now as we only do ref-counting
34+
// }
35+
// }
36+
assert(urObjectStatusMap.find(urObject) == urObjectStatusMap.end() &&
37+
"Add of a exist object");
38+
std::ignore = urObjectStatusMap[urObject];
39+
}
40+
41+
ur_result_t retain(UrObjectT urObject) {
42+
assert(ddiTableInstalled && "DdiTable is not installed");
43+
assert(urObjectStatusMap.find(urObject) != urObjectStatusMap.end() &&
44+
"Retain of a nonexistent object");
45+
urObjectStatusMap[urObject].retain();
46+
47+
if (std::holds_alternative<ur_context_handle_t>(urObject)) {
48+
return urDdiTable.Context.pfnRetain(
49+
std::get<ur_context_handle_t>(urObject));
50+
} else if (std::holds_alternative<ur_device_handle_t>(urObject)) {
51+
return urDdiTable.Device.pfnRetain(
52+
std::get<ur_device_handle_t>(urObject));
53+
}
54+
assert(false && "Abonomal object type");
55+
return UR_RESULT_SUCCESS;
56+
}
57+
58+
ur_result_t release(UrObjectT urObject) {
59+
assert(ddiTableInstalled && "DdiTable is not installed");
60+
assert(urObjectStatusMap.find(urObject) != urObjectStatusMap.end() &&
61+
"Release of a nonexistent object");
62+
urObjectStatusMap[urObject].release();
63+
if (urObjectStatusMap[urObject].refCount == 0) {
64+
std::scoped_lock<ur_shared_mutex> Guard(urObjectStatusMapMutex);
65+
urObjectStatusMap.erase(urObject);
66+
}
67+
68+
if (std::holds_alternative<ur_context_handle_t>(urObject)) {
69+
return urDdiTable.Context.pfnRelease(
70+
std::get<ur_context_handle_t>(urObject));
71+
} else if (std::holds_alternative<ur_device_handle_t>(urObject)) {
72+
return urDdiTable.Device.pfnRelease(
73+
std::get<ur_device_handle_t>(urObject));
74+
}
75+
assert(false && "Abonomal object type");
76+
return UR_RESULT_SUCCESS;
77+
}
78+
79+
bool use(UrObjectT urObject) {
80+
assert(urObjectStatusMap.find(urObject) != urObjectStatusMap.end() &&
81+
"Use of a nonexistent object");
82+
return urObjectStatusMap[urObject].use();
83+
}
84+
85+
void installDdiTable(ur_dditable_t *dditable) {
86+
urDdiTable = *dditable;
87+
ddiTableInstalled = true;
88+
}
89+
90+
~UrObjectHandler() {
91+
// Check for not released objects
92+
}
93+
94+
private:
95+
struct UrObjectInfo {
96+
UrObjectInfo() : refCount(1) {}
97+
std::atomic<int> refCount;
98+
99+
void retain() { refCount += 1; }
100+
void release() {
101+
assert(refCount > 0 && "Release of a invalid object");
102+
refCount -= 1;
103+
}
104+
bool use() {
105+
assert(refCount > 0 && "Use of a invalid object");
106+
return refCount > 0;
107+
}
108+
};
109+
110+
ur_dditable_t urDdiTable;
111+
bool ddiTableInstalled = false;
112+
ur_shared_mutex urObjectStatusMapMutex;
113+
std::unordered_map<UrObjectT, UrObjectInfo> urObjectStatusMap;
114+
};
115+
116+
} // namespace ur_sanitizer_layer

source/loader/layers/sanitizer/ur_sanitizer_layer.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "logger/ur_logger.hpp"
1616
#include "ur_proxy_layer.hpp"
17+
#include "ur_obj_handler.hpp"
1718

1819
#define SANITIZER_COMP_NAME "sanitizer layer"
1920

@@ -33,6 +34,7 @@ class __urdlllocal context_t : public proxy_layer_context_t,
3334
ur_dditable_t urDdiTable = {};
3435
logger::Logger logger;
3536
SanitizerType enabledType = SanitizerType::None;
37+
UrObjectHandler objectHandler;
3638

3739
context_t();
3840
~context_t();

0 commit comments

Comments
 (0)