Skip to content

Commit 25225fd

Browse files
committed
[UR] Expand lifetime validation with
... handle type checks.
1 parent dca6c88 commit 25225fd

File tree

3 files changed

+60
-23
lines changed

3 files changed

+60
-23
lines changed

source/loader/layers/validation/ur_leak_check.hpp

Lines changed: 45 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "ur_validation_layer.hpp"
1010

1111
#include <mutex>
12+
#include <typeindex>
1213
#include <unordered_map>
1314
#include <utility>
1415

@@ -20,7 +21,12 @@ struct RefCountContext {
2021
private:
2122
struct RefRuntimeInfo {
2223
int64_t refCount;
24+
std::type_index type;
2325
std::vector<BacktraceLine> backtrace;
26+
27+
RefRuntimeInfo(int64_t refCount, std::type_index type,
28+
std::vector<BacktraceLine> backtrace)
29+
: refCount(refCount), type(type), backtrace(backtrace) {}
2430
};
2531

2632
enum RefCountUpdateType {
@@ -34,26 +40,32 @@ struct RefCountContext {
3440
std::unordered_map<void *, struct RefRuntimeInfo> counts;
3541
int64_t adapterCount = 0;
3642

37-
void updateRefCount(void *ptr, enum RefCountUpdateType type,
43+
template <typename T>
44+
void updateRefCount(T handle, enum RefCountUpdateType type,
3845
bool isAdapterHandle = false) {
3946
std::unique_lock<std::mutex> ulock(mutex);
4047

48+
void *ptr = static_cast<void *>(handle);
4149
auto it = counts.find(ptr);
4250

4351
switch (type) {
4452
case REFCOUNT_CREATE_OR_INCREASE:
4553
if (it == counts.end()) {
46-
counts[ptr] = {1, getCurrentBacktrace()};
54+
std::tie(it, std::ignore) = counts.emplace(
55+
ptr, RefRuntimeInfo{1, std::type_index(typeid(handle)),
56+
getCurrentBacktrace()});
4757
if (isAdapterHandle) {
4858
adapterCount++;
4959
}
5060
} else {
51-
counts[ptr].refCount++;
61+
it->second.refCount++;
5262
}
5363
break;
5464
case REFCOUNT_CREATE:
5565
if (it == counts.end()) {
56-
counts[ptr] = {1, getCurrentBacktrace()};
66+
std::tie(it, std::ignore) = counts.emplace(
67+
ptr, RefRuntimeInfo{1, std::type_index(typeid(handle)),
68+
getCurrentBacktrace()});
5769
} else {
5870
context.logger.error("Handle {} already exists", ptr);
5971
return;
@@ -65,29 +77,31 @@ struct RefCountContext {
6577
"Attempting to retain nonexistent handle {}", ptr);
6678
return;
6779
} else {
68-
counts[ptr].refCount++;
80+
it->second.refCount++;
6981
}
7082
break;
7183
case REFCOUNT_DECREASE:
7284
if (it == counts.end()) {
73-
counts[ptr] = {-1, getCurrentBacktrace()};
85+
std::tie(it, std::ignore) = counts.emplace(
86+
ptr, RefRuntimeInfo{-1, std::type_index(typeid(handle)),
87+
getCurrentBacktrace()});
7488
} else {
75-
counts[ptr].refCount--;
89+
it->second.refCount--;
7690
}
7791

78-
if (counts[ptr].refCount < 0) {
92+
if (it->second.refCount < 0) {
7993
context.logger.error(
8094
"Attempting to release nonexistent handle {}", ptr);
81-
} else if (counts[ptr].refCount == 0 && isAdapterHandle) {
95+
} else if (it->second.refCount == 0 && isAdapterHandle) {
8296
adapterCount--;
8397
}
8498
break;
8599
}
86100

87101
context.logger.debug("Reference count for handle {} changed to {}", ptr,
88-
counts[ptr].refCount);
102+
it->second.refCount);
89103

90-
if (counts[ptr].refCount == 0) {
104+
if (it->second.refCount == 0) {
91105
counts.erase(ptr);
92106
}
93107

@@ -99,23 +113,35 @@ struct RefCountContext {
99113
}
100114

101115
public:
102-
void createRefCount(void *ptr) { updateRefCount(ptr, REFCOUNT_CREATE); }
116+
template <typename T> void createRefCount(T handle) {
117+
updateRefCount<T>(handle, REFCOUNT_CREATE);
118+
}
103119

104-
void incrementRefCount(void *ptr, bool isAdapterHandle = false) {
105-
updateRefCount(ptr, REFCOUNT_INCREASE, isAdapterHandle);
120+
template <typename T>
121+
void incrementRefCount(T handle, bool isAdapterHandle = false) {
122+
updateRefCount(handle, REFCOUNT_INCREASE, isAdapterHandle);
106123
}
107124

108-
void decrementRefCount(void *ptr, bool isAdapterHandle = false) {
109-
updateRefCount(ptr, REFCOUNT_DECREASE, isAdapterHandle);
125+
template <typename T>
126+
void decrementRefCount(T handle, bool isAdapterHandle = false) {
127+
updateRefCount(handle, REFCOUNT_DECREASE, isAdapterHandle);
110128
}
111129

112-
void createOrIncrementRefCount(void *ptr, bool isAdapterHandle = false) {
113-
updateRefCount(ptr, REFCOUNT_CREATE_OR_INCREASE, isAdapterHandle);
130+
template <typename T>
131+
void createOrIncrementRefCount(T handle, bool isAdapterHandle = false) {
132+
updateRefCount(handle, REFCOUNT_CREATE_OR_INCREASE, isAdapterHandle);
114133
}
115134

116135
void clear() { counts.clear(); }
117136

118-
bool isReferenceValid(void *ptr) { return counts.count(ptr) > 0; }
137+
template <typename T> bool isReferenceValid(T handle) {
138+
auto it = counts.find(static_cast<void *>(handle));
139+
if (it == counts.end() || it->second.refCount < 1) {
140+
return false;
141+
}
142+
143+
return (it->second.type == std::type_index(typeid(handle)));
144+
}
119145

120146
void logInvalidReferences() {
121147
for (auto &[ptr, refRuntimeInfo] : counts) {

test/layers/validation/lifetime.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,19 @@ TEST_F(urTest, testUrAdapterHandleLifetimeExpectFail) {
99
size_t size = 0;
1010
ur_adapter_handle_t adapter = (ur_adapter_handle_t)0xC0FFEE;
1111
ur_adapter_info_t info_type = UR_ADAPTER_INFO_BACKEND;
12-
ASSERT_EQ(urAdapterGetInfo(adapter, info_type, 0, nullptr, &size),
13-
UR_RESULT_ERROR_INVALID_ARGUMENT);
12+
urAdapterGetInfo(adapter, info_type, 0, nullptr, &size);
1413
}
1514

1615
TEST_F(valAdapterTest, testUrAdapterHandleLifetimeExpectSuccess) {
1716
size_t size = 0;
1817
ur_adapter_info_t info_type = UR_ADAPTER_INFO_BACKEND;
19-
ASSERT_EQ(urAdapterGetInfo(adapter, info_type, 0, nullptr, &size),
20-
UR_RESULT_SUCCESS);
18+
urAdapterGetInfo(adapter, info_type, 0, nullptr, &size);
19+
}
20+
21+
TEST_F(valAdapterTest, testUrAdapterHandleTypeMismatchExpectFail) {
22+
size_t size = 0;
23+
// Use valid adapter handle with incorrect cast.
24+
ur_device_handle_t device = (ur_device_handle_t)adapter;
25+
ur_device_info_t info_type = UR_DEVICE_INFO_BACKEND_RUNTIME_VERSION;
26+
urDeviceGetInfo(device, info_type, 0, nullptr, &size);
2127
}

test/layers/validation/lifetime.out.match

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,10 @@
33
<VALIDATION>[ERROR]: There are no valid references to handle {{[0-9xa-fA-F]+}}
44
{{IGNORE}}
55
[ RUN ] valAdapterTest.testUrAdapterHandleLifetimeExpectSuccess
6+
<VALIDATION>[DEBUG]: Reference count for handle {{[0-9xa-fA-F]+}} changed to 1
67
{{^(?!.*There are no valid references to handle).*$}}
78
{{IGNORE}}
9+
[ RUN ] valAdapterTest.testUrAdapterHandleTypeMismatchExpectFail
10+
<VALIDATION>[DEBUG]: Reference count for handle {{[0-9xa-fA-F]+}} changed to 1
11+
<VALIDATION>[ERROR]: There are no valid references to handle {{[0-9xa-fA-F]+}}
12+
{{IGNORE}}

0 commit comments

Comments
 (0)