Skip to content

Commit 6083ba0

Browse files
committed
Use reference counting on factories
Previously the factories used by ur_ldrddi (used when there are multiple backends) would add newly created objects to a map, but never release them. This patch adds reference counting semantics to the allocation, retention and release methods.

A lot of changes were also made to fix use-after-free issues, specifically:
* The `release` functions no longer use the handle after freeing it.
* `urDeviceTest` no longer frees devices that it doesn't own.
* Some tests for reference counting now explicitly retain an extra copy before releasing them.

No tests were added; this should be covered by tests for urThingRetain.

Closes: #1784.
1 parent a07352d commit 6083ba0

File tree

3 files changed

+121
-9
lines changed

3 files changed

+121
-9
lines changed

scripts/templates/ldrddi.cpp.mako

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -273,11 +273,17 @@ namespace ur_loader
273273
274274
%endif
275275
%endif
276-
## Before we can re-enable the releases we will need ref-counted object_t.
277-
## See unified-runtime github issue #1784
278-
##%if item['release']:
279-
##// release loader handle
280-
##${item['factory']}.release( ${item['name']} );
276+
## Possibly handle release/retain ref counting - there are no ur_exp-image factories
277+
%if 'factory' in item and '_exp_image_' not in item['factory']:
278+
%if item['release']:
279+
// release loader handle
280+
context->factories.${item['factory']}.release( ${item['name']} );
281+
%endif
282+
%if item['retain']:
283+
// increment refcount of handle
284+
context->factories.${item['factory']}.retain( ${item['name']} );
285+
%endif
286+
%endif
281287
%if not item['release'] and not item['retain'] and not '_native_object_' in item['obj'] or th.make_func_name(n, tags, obj) == 'urPlatformCreateWithNativeHandle':
282288
try
283289
{

source/common/ur_singleton.hpp

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,26 @@
1111
#ifndef UR_SINGLETON_H
1212
#define UR_SINGLETON_H 1
1313

14+
#include <cassert>
1415
#include <memory>
1516
#include <mutex>
1617
#include <unordered_map>
1718

1819
//////////////////////////////////////////////////////////////////////////
1920
/// a abstract factory for creation of singleton objects
2021
template <typename singleton_tn, typename key_tn> class singleton_factory_t {
22+
struct entry_t {
23+
std::unique_ptr<singleton_tn> ptr;
24+
size_t ref_count;
25+
};
26+
2127
protected:
2228
using singleton_t = singleton_tn;
2329
using key_t = typename std::conditional<std::is_pointer<key_tn>::value,
2430
size_t, key_tn>::type;
2531

2632
using ptr_t = std::unique_ptr<singleton_t>;
27-
using map_t = std::unordered_map<key_t, ptr_t>;
33+
using map_t = std::unordered_map<key_t, entry_t>;
2834

2935
std::mutex mut; ///< lock for thread-safety
3036
map_t map; ///< single instance of singleton for each unique key
@@ -60,16 +66,31 @@ template <typename singleton_tn, typename key_tn> class singleton_factory_t {
6066
if (map.end() == iter) {
6167
auto ptr =
6268
std::make_unique<singleton_t>(std::forward<Ts>(params)...);
63-
iter = map.emplace(key, std::move(ptr)).first;
69+
iter = map.emplace(key, entry_t{std::move(ptr), 0}).first;
70+
} else {
71+
iter->second.ref_count++;
6472
}
65-
return iter->second.get();
73+
return iter->second.ptr.get();
74+
}
75+
76+
void retain(key_tn key) {
77+
std::lock_guard<std::mutex> lk(mut);
78+
auto iter = map.find(getKey(key));
79+
assert(iter != map.end());
80+
iter->second.ref_count++;
6681
}
6782

6883
//////////////////////////////////////////////////////////////////////////
6984
/// once the key is no longer valid, release the singleton
7085
void release(key_tn key) {
7186
std::lock_guard<std::mutex> lk(mut);
72-
map.erase(getKey(key));
87+
auto iter = map.find(getKey(key));
88+
assert(iter != map.end());
89+
if (iter->second.ref_count == 0) {
90+
map.erase(iter);
91+
} else {
92+
iter->second.ref_count--;
93+
}
7394
}
7495

7596
void clear() {

source/loader/ur_ldrddi.cpp

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,9 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRelease(
8585
// forward to device-platform
8686
result = pfnAdapterRelease(hAdapter);
8787

88+
// release loader handle
89+
context->factories.ur_adapter_factory.release(hAdapter);
90+
8891
return result;
8992
}
9093

@@ -110,6 +113,9 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRetain(
110113
// forward to device-platform
111114
result = pfnAdapterRetain(hAdapter);
112115

116+
// increment refcount of handle
117+
context->factories.ur_adapter_factory.retain(hAdapter);
118+
113119
return result;
114120
}
115121

@@ -614,6 +620,9 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRetain(
614620
// forward to device-platform
615621
result = pfnRetain(hDevice);
616622

623+
// increment refcount of handle
624+
context->factories.ur_device_factory.retain(hDevice);
625+
617626
return result;
618627
}
619628

@@ -640,6 +649,9 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRelease(
640649
// forward to device-platform
641650
result = pfnRelease(hDevice);
642651

652+
// release loader handle
653+
context->factories.ur_device_factory.release(hDevice);
654+
643655
return result;
644656
}
645657

@@ -910,6 +922,9 @@ __urdlllocal ur_result_t UR_APICALL urContextRetain(
910922
// forward to device-platform
911923
result = pfnRetain(hContext);
912924

925+
// increment refcount of handle
926+
context->factories.ur_context_factory.retain(hContext);
927+
913928
return result;
914929
}
915930

@@ -936,6 +951,9 @@ __urdlllocal ur_result_t UR_APICALL urContextRelease(
936951
// forward to device-platform
937952
result = pfnRelease(hContext);
938953

954+
// release loader handle
955+
context->factories.ur_context_factory.release(hContext);
956+
939957
return result;
940958
}
941959

@@ -1238,6 +1256,9 @@ __urdlllocal ur_result_t UR_APICALL urMemRetain(
12381256
// forward to device-platform
12391257
result = pfnRetain(hMem);
12401258

1259+
// increment refcount of handle
1260+
context->factories.ur_mem_factory.retain(hMem);
1261+
12411262
return result;
12421263
}
12431264

@@ -1264,6 +1285,9 @@ __urdlllocal ur_result_t UR_APICALL urMemRelease(
12641285
// forward to device-platform
12651286
result = pfnRelease(hMem);
12661287

1288+
// release loader handle
1289+
context->factories.ur_mem_factory.release(hMem);
1290+
12671291
return result;
12681292
}
12691293

@@ -1615,6 +1639,9 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRetain(
16151639
// forward to device-platform
16161640
result = pfnRetain(hSampler);
16171641

1642+
// increment refcount of handle
1643+
context->factories.ur_sampler_factory.retain(hSampler);
1644+
16181645
return result;
16191646
}
16201647

@@ -1641,6 +1668,9 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRelease(
16411668
// forward to device-platform
16421669
result = pfnRelease(hSampler);
16431670

1671+
// release loader handle
1672+
context->factories.ur_sampler_factory.release(hSampler);
1673+
16441674
return result;
16451675
}
16461676

@@ -2074,6 +2104,9 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRetain(
20742104
// forward to device-platform
20752105
result = pfnPoolRetain(pPool);
20762106

2107+
// increment refcount of handle
2108+
context->factories.ur_usm_pool_factory.retain(pPool);
2109+
20772110
return result;
20782111
}
20792112

@@ -2099,6 +2132,9 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRelease(
20992132
// forward to device-platform
21002133
result = pfnPoolRelease(pPool);
21012134

2135+
// release loader handle
2136+
context->factories.ur_usm_pool_factory.release(pPool);
2137+
21022138
return result;
21032139
}
21042140

@@ -2484,6 +2520,9 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRetain(
24842520
// forward to device-platform
24852521
result = pfnRetain(hPhysicalMem);
24862522

2523+
// increment refcount of handle
2524+
context->factories.ur_physical_mem_factory.retain(hPhysicalMem);
2525+
24872526
return result;
24882527
}
24892528

@@ -2512,6 +2551,9 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRelease(
25122551
// forward to device-platform
25132552
result = pfnRelease(hPhysicalMem);
25142553

2554+
// release loader handle
2555+
context->factories.ur_physical_mem_factory.release(hPhysicalMem);
2556+
25152557
return result;
25162558
}
25172559

@@ -2759,6 +2801,9 @@ __urdlllocal ur_result_t UR_APICALL urProgramRetain(
27592801
// forward to device-platform
27602802
result = pfnRetain(hProgram);
27612803

2804+
// increment refcount of handle
2805+
context->factories.ur_program_factory.retain(hProgram);
2806+
27622807
return result;
27632808
}
27642809

@@ -2785,6 +2830,9 @@ __urdlllocal ur_result_t UR_APICALL urProgramRelease(
27852830
// forward to device-platform
27862831
result = pfnRelease(hProgram);
27872832

2833+
// release loader handle
2834+
context->factories.ur_program_factory.release(hProgram);
2835+
27882836
return result;
27892837
}
27902838

@@ -3382,6 +3430,9 @@ __urdlllocal ur_result_t UR_APICALL urKernelRetain(
33823430
// forward to device-platform
33833431
result = pfnRetain(hKernel);
33843432

3433+
// increment refcount of handle
3434+
context->factories.ur_kernel_factory.retain(hKernel);
3435+
33853436
return result;
33863437
}
33873438

@@ -3408,6 +3459,9 @@ __urdlllocal ur_result_t UR_APICALL urKernelRelease(
34083459
// forward to device-platform
34093460
result = pfnRelease(hKernel);
34103461

3462+
// release loader handle
3463+
context->factories.ur_kernel_factory.release(hKernel);
3464+
34113465
return result;
34123466
}
34133467

@@ -3858,6 +3912,9 @@ __urdlllocal ur_result_t UR_APICALL urQueueRetain(
38583912
// forward to device-platform
38593913
result = pfnRetain(hQueue);
38603914

3915+
// increment refcount of handle
3916+
context->factories.ur_queue_factory.retain(hQueue);
3917+
38613918
return result;
38623919
}
38633920

@@ -3884,6 +3941,9 @@ __urdlllocal ur_result_t UR_APICALL urQueueRelease(
38843941
// forward to device-platform
38853942
result = pfnRelease(hQueue);
38863943

3944+
// release loader handle
3945+
context->factories.ur_queue_factory.release(hQueue);
3946+
38873947
return result;
38883948
}
38893949

@@ -4188,6 +4248,9 @@ __urdlllocal ur_result_t UR_APICALL urEventRetain(
41884248
// forward to device-platform
41894249
result = pfnRetain(hEvent);
41904250

4251+
// increment refcount of handle
4252+
context->factories.ur_event_factory.retain(hEvent);
4253+
41914254
return result;
41924255
}
41934256

@@ -4213,6 +4276,9 @@ __urdlllocal ur_result_t UR_APICALL urEventRelease(
42134276
// forward to device-platform
42144277
result = pfnRelease(hEvent);
42154278

4279+
// release loader handle
4280+
context->factories.ur_event_factory.release(hEvent);
4281+
42164282
return result;
42174283
}
42184284

@@ -6745,6 +6811,9 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesReleaseExternalMemoryExp(
67456811
// forward to device-platform
67466812
result = pfnReleaseExternalMemoryExp(hContext, hDevice, hExternalMem);
67476813

6814+
// release loader handle
6815+
context->factories.ur_exp_external_mem_factory.release(hExternalMem);
6816+
67486817
return result;
67496818
}
67506819

@@ -6835,6 +6904,10 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesReleaseExternalSemaphoreExp(
68356904
result =
68366905
pfnReleaseExternalSemaphoreExp(hContext, hDevice, hExternalSemaphore);
68376906

6907+
// release loader handle
6908+
context->factories.ur_exp_external_semaphore_factory.release(
6909+
hExternalSemaphore);
6910+
68386911
return result;
68396912
}
68406913

@@ -7062,6 +7135,9 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainExp(
70627135
// forward to device-platform
70637136
result = pfnRetainExp(hCommandBuffer);
70647137

7138+
// increment refcount of handle
7139+
context->factories.ur_exp_command_buffer_factory.retain(hCommandBuffer);
7140+
70657141
return result;
70667142
}
70677143

@@ -7092,6 +7168,9 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferReleaseExp(
70927168
// forward to device-platform
70937169
result = pfnReleaseExp(hCommandBuffer);
70947170

7171+
// release loader handle
7172+
context->factories.ur_exp_command_buffer_factory.release(hCommandBuffer);
7173+
70957174
return result;
70967175
}
70977176

@@ -8408,6 +8487,9 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainCommandExp(
84088487
// forward to device-platform
84098488
result = pfnRetainCommandExp(hCommand);
84108489

8490+
// increment refcount of handle
8491+
context->factories.ur_exp_command_buffer_command_factory.retain(hCommand);
8492+
84118493
return result;
84128494
}
84138495

@@ -8439,6 +8521,9 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferReleaseCommandExp(
84398521
// forward to device-platform
84408522
result = pfnReleaseCommandExp(hCommand);
84418523

8524+
// release loader handle
8525+
context->factories.ur_exp_command_buffer_command_factory.release(hCommand);
8526+
84428527
return result;
84438528
}
84448529

0 commit comments

Comments
 (0)