Skip to content

Commit ccab45f

Browse files
committed
Use reference counting on factories
Previously the factories used by ur_ldrddi (used when there are multiple backends) would add newly created objects to a map, but never release them. This patch adds reference counting semantics to the allocation, retention and release methods. A lot of changes were also made to fix use-after-free issues, specifically: * The `release` functions now no longer use the handle after freeing it. * `urDeviceTest` no longer frees devices that it dosen't own. * Some tests for reference counting now explicitly retain an extra copy before releasing them. No tests were added; this should be covered by tests for urThingRetain. Closes: #1784 .
1 parent f9f71f1 commit ccab45f

File tree

8 files changed

+182
-60
lines changed

8 files changed

+182
-60
lines changed

scripts/templates/ldrddi.cpp.mako

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -273,11 +273,17 @@ namespace ur_loader
273273
274274
%endif
275275
%endif
276-
## Before we can re-enable the releases we will need ref-counted object_t.
277-
## See unified-runtime github issue #1784
278-
##%if item['release']:
279-
##// release loader handle
280-
##${item['factory']}.release( ${item['name']} );
276+
## Possibly handle release/retain ref counting - there are no ur_exp-image factories
277+
%if 'factory' in item and '_exp_image_' not in item['factory']:
278+
%if item['release']:
279+
// release loader handle
280+
context->factories.${item['factory']}.release( ${item['name']} );
281+
%endif
282+
%if item['retain']:
283+
// increment refcount of handle
284+
context->factories.${item['factory']}.retain( ${item['name']} );
285+
%endif
286+
%endif
281287
%if not item['release'] and not item['retain'] and not '_native_object_' in item['obj'] or th.make_func_name(n, tags, obj) == 'urPlatformCreateWithNativeHandle':
282288
try
283289
{

scripts/templates/valddi.cpp.mako

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,19 @@ namespace ur_validation_layer
9494
%endif
9595
%endfor
9696

97+
%for tp in tracked_params:
98+
<%
99+
tp_handle_funcs = next((hf for hf in handle_create_get_retain_release_funcs if th.subt(n, tags, tp['type']) in [hf['handle'], hf['handle'] + "*"]), None)
100+
is_handle_to_adapter = ("_adapter_handle_t" in tp['type'])
101+
%>
102+
%if func_name in tp_handle_funcs['release']:
103+
if( getContext()->enableLeakChecking )
104+
{
105+
getContext()->refCountContext->decrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()});
106+
}
107+
%endif
108+
%endfor
109+
97110
${x}_result_t result = ${th.make_pfn_name(n, tags, obj)}( ${", ".join(th.make_param_lines(n, tags, obj, format=["name"]))} );
98111

99112
%for tp in tracked_params:
@@ -114,15 +127,10 @@ namespace ur_validation_layer
114127
}
115128
}
116129
%elif func_name in tp_handle_funcs['retain']:
117-
if( getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS )
130+
if( getContext()->enableLeakChecking )
118131
{
119132
getContext()->refCountContext->incrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()});
120133
}
121-
%elif func_name in tp_handle_funcs['release']:
122-
if( getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS )
123-
{
124-
getContext()->refCountContext->decrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()});
125-
}
126134
%endif
127135
%endfor
128136

source/common/ur_singleton.hpp

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,26 @@
1111
#ifndef UR_SINGLETON_H
1212
#define UR_SINGLETON_H 1
1313

14+
#include <cassert>
1415
#include <memory>
1516
#include <mutex>
1617
#include <unordered_map>
1718

1819
//////////////////////////////////////////////////////////////////////////
1920
/// a abstract factory for creation of singleton objects
2021
template <typename singleton_tn, typename key_tn> class singleton_factory_t {
22+
struct entry_t {
23+
std::unique_ptr<singleton_tn> ptr;
24+
size_t ref_count;
25+
};
26+
2127
protected:
2228
using singleton_t = singleton_tn;
2329
using key_t = typename std::conditional<std::is_pointer<key_tn>::value,
2430
size_t, key_tn>::type;
2531

2632
using ptr_t = std::unique_ptr<singleton_t>;
27-
using map_t = std::unordered_map<key_t, ptr_t>;
33+
using map_t = std::unordered_map<key_t, entry_t>;
2834

2935
std::mutex mut; ///< lock for thread-safety
3036
map_t map; ///< single instance of singleton for each unique key
@@ -60,16 +66,31 @@ template <typename singleton_tn, typename key_tn> class singleton_factory_t {
6066
if (map.end() == iter) {
6167
auto ptr =
6268
std::make_unique<singleton_t>(std::forward<Ts>(params)...);
63-
iter = map.emplace(key, std::move(ptr)).first;
69+
iter = map.emplace(key, entry_t{std::move(ptr), 0}).first;
70+
} else {
71+
iter->second.ref_count++;
6472
}
65-
return iter->second.get();
73+
return iter->second.ptr.get();
74+
}
75+
76+
void retain(key_tn key) {
77+
std::lock_guard<std::mutex> lk(mut);
78+
auto iter = map.find(getKey(key));
79+
assert(iter != map.end());
80+
iter->second.ref_count++;
6681
}
6782

6883
//////////////////////////////////////////////////////////////////////////
6984
/// once the key is no longer valid, release the singleton
7085
void release(key_tn key) {
7186
std::lock_guard<std::mutex> lk(mut);
72-
map.erase(getKey(key));
87+
auto iter = map.find(getKey(key));
88+
assert(iter != map.end());
89+
if (iter->second.ref_count == 0) {
90+
map.erase(iter);
91+
} else {
92+
iter->second.ref_count--;
93+
}
7394
}
7495

7596
void clear() {

source/loader/layers/validation/ur_valddi.cpp

Lines changed: 44 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,12 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRelease(
7171
}
7272
}
7373

74-
ur_result_t result = pfnAdapterRelease(hAdapter);
75-
76-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
74+
if (getContext()->enableLeakChecking) {
7775
getContext()->refCountContext->decrementRefCount(hAdapter, true);
7876
}
7977

78+
ur_result_t result = pfnAdapterRelease(hAdapter);
79+
8080
return result;
8181
}
8282

@@ -99,7 +99,7 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRetain(
9999

100100
ur_result_t result = pfnAdapterRetain(hAdapter);
101101

102-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
102+
if (getContext()->enableLeakChecking) {
103103
getContext()->refCountContext->incrementRefCount(hAdapter, true);
104104
}
105105

@@ -558,7 +558,7 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRetain(
558558

559559
ur_result_t result = pfnRetain(hDevice);
560560

561-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
561+
if (getContext()->enableLeakChecking) {
562562
getContext()->refCountContext->incrementRefCount(hDevice, false);
563563
}
564564

@@ -583,12 +583,12 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRelease(
583583
}
584584
}
585585

586-
ur_result_t result = pfnRelease(hDevice);
587-
588-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
586+
if (getContext()->enableLeakChecking) {
589587
getContext()->refCountContext->decrementRefCount(hDevice, false);
590588
}
591589

590+
ur_result_t result = pfnRelease(hDevice);
591+
592592
return result;
593593
}
594594

@@ -861,7 +861,7 @@ __urdlllocal ur_result_t UR_APICALL urContextRetain(
861861

862862
ur_result_t result = pfnRetain(hContext);
863863

864-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
864+
if (getContext()->enableLeakChecking) {
865865
getContext()->refCountContext->incrementRefCount(hContext, false);
866866
}
867867

@@ -886,12 +886,12 @@ __urdlllocal ur_result_t UR_APICALL urContextRelease(
886886
}
887887
}
888888

889-
ur_result_t result = pfnRelease(hContext);
890-
891-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
889+
if (getContext()->enableLeakChecking) {
892890
getContext()->refCountContext->decrementRefCount(hContext, false);
893891
}
894892

893+
ur_result_t result = pfnRelease(hContext);
894+
895895
return result;
896896
}
897897

@@ -1248,7 +1248,7 @@ __urdlllocal ur_result_t UR_APICALL urMemRetain(
12481248

12491249
ur_result_t result = pfnRetain(hMem);
12501250

1251-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
1251+
if (getContext()->enableLeakChecking) {
12521252
getContext()->refCountContext->incrementRefCount(hMem, false);
12531253
}
12541254

@@ -1273,12 +1273,12 @@ __urdlllocal ur_result_t UR_APICALL urMemRelease(
12731273
}
12741274
}
12751275

1276-
ur_result_t result = pfnRelease(hMem);
1277-
1278-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
1276+
if (getContext()->enableLeakChecking) {
12791277
getContext()->refCountContext->decrementRefCount(hMem, false);
12801278
}
12811279

1280+
ur_result_t result = pfnRelease(hMem);
1281+
12821282
return result;
12831283
}
12841284

@@ -1657,7 +1657,7 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRetain(
16571657

16581658
ur_result_t result = pfnRetain(hSampler);
16591659

1660-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
1660+
if (getContext()->enableLeakChecking) {
16611661
getContext()->refCountContext->incrementRefCount(hSampler, false);
16621662
}
16631663

@@ -1682,12 +1682,12 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRelease(
16821682
}
16831683
}
16841684

1685-
ur_result_t result = pfnRelease(hSampler);
1686-
1687-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
1685+
if (getContext()->enableLeakChecking) {
16881686
getContext()->refCountContext->decrementRefCount(hSampler, false);
16891687
}
16901688

1689+
ur_result_t result = pfnRelease(hSampler);
1690+
16911691
return result;
16921692
}
16931693

@@ -2154,7 +2154,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRetain(
21542154

21552155
ur_result_t result = pfnPoolRetain(pPool);
21562156

2157-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
2157+
if (getContext()->enableLeakChecking) {
21582158
getContext()->refCountContext->incrementRefCount(pPool, false);
21592159
}
21602160

@@ -2178,12 +2178,12 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRelease(
21782178
}
21792179
}
21802180

2181-
ur_result_t result = pfnPoolRelease(pPool);
2182-
2183-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
2181+
if (getContext()->enableLeakChecking) {
21842182
getContext()->refCountContext->decrementRefCount(pPool, false);
21852183
}
21862184

2185+
ur_result_t result = pfnPoolRelease(pPool);
2186+
21872187
return result;
21882188
}
21892189

@@ -2631,7 +2631,7 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRetain(
26312631

26322632
ur_result_t result = pfnRetain(hPhysicalMem);
26332633

2634-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
2634+
if (getContext()->enableLeakChecking) {
26352635
getContext()->refCountContext->incrementRefCount(hPhysicalMem, false);
26362636
}
26372637

@@ -2656,12 +2656,12 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRelease(
26562656
}
26572657
}
26582658

2659-
ur_result_t result = pfnRelease(hPhysicalMem);
2660-
2661-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
2659+
if (getContext()->enableLeakChecking) {
26622660
getContext()->refCountContext->decrementRefCount(hPhysicalMem, false);
26632661
}
26642662

2663+
ur_result_t result = pfnRelease(hPhysicalMem);
2664+
26652665
return result;
26662666
}
26672667

@@ -2952,7 +2952,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramRetain(
29522952

29532953
ur_result_t result = pfnRetain(hProgram);
29542954

2955-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
2955+
if (getContext()->enableLeakChecking) {
29562956
getContext()->refCountContext->incrementRefCount(hProgram, false);
29572957
}
29582958

@@ -2977,12 +2977,12 @@ __urdlllocal ur_result_t UR_APICALL urProgramRelease(
29772977
}
29782978
}
29792979

2980-
ur_result_t result = pfnRelease(hProgram);
2981-
2982-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
2980+
if (getContext()->enableLeakChecking) {
29832981
getContext()->refCountContext->decrementRefCount(hProgram, false);
29842982
}
29852983

2984+
ur_result_t result = pfnRelease(hProgram);
2985+
29862986
return result;
29872987
}
29882988

@@ -3618,7 +3618,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelRetain(
36183618

36193619
ur_result_t result = pfnRetain(hKernel);
36203620

3621-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
3621+
if (getContext()->enableLeakChecking) {
36223622
getContext()->refCountContext->incrementRefCount(hKernel, false);
36233623
}
36243624

@@ -3643,12 +3643,12 @@ __urdlllocal ur_result_t UR_APICALL urKernelRelease(
36433643
}
36443644
}
36453645

3646-
ur_result_t result = pfnRelease(hKernel);
3647-
3648-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
3646+
if (getContext()->enableLeakChecking) {
36493647
getContext()->refCountContext->decrementRefCount(hKernel, false);
36503648
}
36513649

3650+
ur_result_t result = pfnRelease(hKernel);
3651+
36523652
return result;
36533653
}
36543654

@@ -4138,7 +4138,7 @@ __urdlllocal ur_result_t UR_APICALL urQueueRetain(
41384138

41394139
ur_result_t result = pfnRetain(hQueue);
41404140

4141-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
4141+
if (getContext()->enableLeakChecking) {
41424142
getContext()->refCountContext->incrementRefCount(hQueue, false);
41434143
}
41444144

@@ -4163,12 +4163,12 @@ __urdlllocal ur_result_t UR_APICALL urQueueRelease(
41634163
}
41644164
}
41654165

4166-
ur_result_t result = pfnRelease(hQueue);
4167-
4168-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
4166+
if (getContext()->enableLeakChecking) {
41694167
getContext()->refCountContext->decrementRefCount(hQueue, false);
41704168
}
41714169

4170+
ur_result_t result = pfnRelease(hQueue);
4171+
41724172
return result;
41734173
}
41744174

@@ -4454,7 +4454,7 @@ __urdlllocal ur_result_t UR_APICALL urEventRetain(
44544454

44554455
ur_result_t result = pfnRetain(hEvent);
44564456

4457-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
4457+
if (getContext()->enableLeakChecking) {
44584458
getContext()->refCountContext->incrementRefCount(hEvent, false);
44594459
}
44604460

@@ -4478,12 +4478,12 @@ __urdlllocal ur_result_t UR_APICALL urEventRelease(
44784478
}
44794479
}
44804480

4481-
ur_result_t result = pfnRelease(hEvent);
4482-
4483-
if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
4481+
if (getContext()->enableLeakChecking) {
44844482
getContext()->refCountContext->decrementRefCount(hEvent, false);
44854483
}
44864484

4485+
ur_result_t result = pfnRelease(hEvent);
4486+
44874487
return result;
44884488
}
44894489

0 commit comments

Comments
 (0)