9
9
#include " ur_validation_layer.hpp"
10
10
11
11
#include < mutex>
12
+ #include < typeindex>
12
13
#include < unordered_map>
13
14
#include < utility>
14
15
@@ -20,7 +21,12 @@ struct RefCountContext {
20
21
private:
21
22
struct RefRuntimeInfo {
22
23
int64_t refCount;
24
+ std::type_index type;
23
25
std::vector<BacktraceLine> backtrace;
26
+
27
+ RefRuntimeInfo (int64_t refCount, std::type_index type,
28
+ std::vector<BacktraceLine> backtrace)
29
+ : refCount(refCount), type(type), backtrace(backtrace) {}
24
30
};
25
31
26
32
enum RefCountUpdateType {
@@ -34,26 +40,32 @@ struct RefCountContext {
34
40
std::unordered_map<void *, struct RefRuntimeInfo > counts;
35
41
int64_t adapterCount = 0 ;
36
42
37
- void updateRefCount (void *ptr, enum RefCountUpdateType type,
43
+ template <typename T>
44
+ void updateRefCount (T handle, enum RefCountUpdateType type,
38
45
bool isAdapterHandle = false ) {
39
46
std::unique_lock<std::mutex> ulock (mutex);
40
47
48
+ void *ptr = static_cast <void *>(handle);
41
49
auto it = counts.find (ptr);
42
50
43
51
switch (type) {
44
52
case REFCOUNT_CREATE_OR_INCREASE:
45
53
if (it == counts.end ()) {
46
- counts[ptr] = {1 , getCurrentBacktrace ()};
54
+ std::tie (it, std::ignore) = counts.emplace (
55
+ ptr, RefRuntimeInfo{1 , std::type_index (typeid (handle)),
56
+ getCurrentBacktrace ()});
47
57
if (isAdapterHandle) {
48
58
adapterCount++;
49
59
}
50
60
} else {
51
- counts[ptr] .refCount ++;
61
+ it-> second .refCount ++;
52
62
}
53
63
break ;
54
64
case REFCOUNT_CREATE:
55
65
if (it == counts.end ()) {
56
- counts[ptr] = {1 , getCurrentBacktrace ()};
66
+ std::tie (it, std::ignore) = counts.emplace (
67
+ ptr, RefRuntimeInfo{1 , std::type_index (typeid (handle)),
68
+ getCurrentBacktrace ()});
57
69
} else {
58
70
context.logger .error (" Handle {} already exists" , ptr);
59
71
return ;
@@ -65,29 +77,31 @@ struct RefCountContext {
65
77
" Attempting to retain nonexistent handle {}" , ptr);
66
78
return ;
67
79
} else {
68
- counts[ptr] .refCount ++;
80
+ it-> second .refCount ++;
69
81
}
70
82
break ;
71
83
case REFCOUNT_DECREASE:
72
84
if (it == counts.end ()) {
73
- counts[ptr] = {-1 , getCurrentBacktrace ()};
85
+ std::tie (it, std::ignore) = counts.emplace (
86
+ ptr, RefRuntimeInfo{-1 , std::type_index (typeid (handle)),
87
+ getCurrentBacktrace ()});
74
88
} else {
75
- counts[ptr] .refCount --;
89
+ it-> second .refCount --;
76
90
}
77
91
78
- if (counts[ptr] .refCount < 0 ) {
92
+ if (it-> second .refCount < 0 ) {
79
93
context.logger .error (
80
94
" Attempting to release nonexistent handle {}" , ptr);
81
- } else if (counts[ptr] .refCount == 0 && isAdapterHandle) {
95
+ } else if (it-> second .refCount == 0 && isAdapterHandle) {
82
96
adapterCount--;
83
97
}
84
98
break ;
85
99
}
86
100
87
101
context.logger .debug (" Reference count for handle {} changed to {}" , ptr,
88
- counts[ptr] .refCount );
102
+ it-> second .refCount );
89
103
90
- if (counts[ptr] .refCount == 0 ) {
104
+ if (it-> second .refCount == 0 ) {
91
105
counts.erase (ptr);
92
106
}
93
107
@@ -99,23 +113,35 @@ struct RefCountContext {
99
113
}
100
114
101
115
public:
102
- void createRefCount (void *ptr) { updateRefCount (ptr, REFCOUNT_CREATE); }
116
+ template <typename T> void createRefCount (T handle) {
117
+ updateRefCount<T>(handle, REFCOUNT_CREATE);
118
+ }
103
119
104
- void incrementRefCount (void *ptr, bool isAdapterHandle = false ) {
105
- updateRefCount (ptr, REFCOUNT_INCREASE, isAdapterHandle);
120
+ template <typename T>
121
+ void incrementRefCount (T handle, bool isAdapterHandle = false ) {
122
+ updateRefCount (handle, REFCOUNT_INCREASE, isAdapterHandle);
106
123
}
107
124
108
- void decrementRefCount (void *ptr, bool isAdapterHandle = false ) {
109
- updateRefCount (ptr, REFCOUNT_DECREASE, isAdapterHandle);
125
+ template <typename T>
126
+ void decrementRefCount (T handle, bool isAdapterHandle = false ) {
127
+ updateRefCount (handle, REFCOUNT_DECREASE, isAdapterHandle);
110
128
}
111
129
112
- void createOrIncrementRefCount (void *ptr, bool isAdapterHandle = false ) {
113
- updateRefCount (ptr, REFCOUNT_CREATE_OR_INCREASE, isAdapterHandle);
130
+ template <typename T>
131
+ void createOrIncrementRefCount (T handle, bool isAdapterHandle = false ) {
132
+ updateRefCount (handle, REFCOUNT_CREATE_OR_INCREASE, isAdapterHandle);
114
133
}
115
134
116
135
void clear () { counts.clear (); }
117
136
118
- bool isReferenceValid (void *ptr) { return counts.count (ptr) > 0 ; }
137
+ template <typename T> bool isReferenceValid (T handle) {
138
+ auto it = counts.find (static_cast <void *>(handle));
139
+ if (it == counts.end () || it->second .refCount < 1 ) {
140
+ return false ;
141
+ }
142
+
143
+ return (it->second .type == std::type_index (typeid (handle)));
144
+ }
119
145
120
146
void logInvalidReferences () {
121
147
for (auto &[ptr, refRuntimeInfo] : counts) {
0 commit comments