Skip to content

Commit 3a6ed49

Browse files
committed
Use reference counting on factories
Previously the factories used by ur_ldrddi (used when there are multiple backends) would add newly created objects to a map, but never release them. This patch adds reference counting semantics to the allocation, retention and release methods. No tests were added; this should be covered by tests for urThingRetain. Closes: #1784.
1 parent e26bba5 commit 3a6ed49

File tree

3 files changed

+119
-9
lines changed

3 files changed

+119
-9
lines changed

scripts/templates/ldrddi.cpp.mako

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -273,11 +273,17 @@ namespace ur_loader
273273
274274
%endif
275275
%endif
276-
## Before we can re-enable the releases we will need ref-counted object_t.
277-
## See unified-runtime github issue #1784
278-
##%if item['release']:
279-
##// release loader handle
280-
##${item['factory']}.release( ${item['name']} );
276+
## Possibly handle release/retain ref counting - there are no ur_exp_image factories
277+
%if 'factory' in item and '_exp_image_' not in item['factory']:
278+
%if item['release']:
279+
// release loader handle
280+
context->factories.${item['factory']}.release( ${item['name']} );
281+
%endif
282+
%if item['retain']:
283+
// increment refcount of handle
284+
context->factories.${item['factory']}.retain( ${item['name']} );
285+
%endif
286+
%endif
281287
%if not item['release'] and not item['retain'] and not '_native_object_' in item['obj'] or th.make_func_name(n, tags, obj) == 'urPlatformCreateWithNativeHandle':
282288
try
283289
{

source/common/ur_singleton.hpp

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,26 @@
1111
#ifndef UR_SINGLETON_H
1212
#define UR_SINGLETON_H 1
1313

14+
#include <cassert>
1415
#include <memory>
1516
#include <mutex>
1617
#include <unordered_map>
1718

1819
//////////////////////////////////////////////////////////////////////////
1920
/// an abstract factory for creation of singleton objects
2021
template <typename singleton_tn, typename key_tn> class singleton_factory_t {
22+
struct entry_t {
23+
std::unique_ptr<singleton_tn> ptr;
24+
size_t ref_count;
25+
};
26+
2127
protected:
2228
using singleton_t = singleton_tn;
2329
using key_t = typename std::conditional<std::is_pointer<key_tn>::value,
2430
size_t, key_tn>::type;
2531

2632
using ptr_t = std::unique_ptr<singleton_t>;
27-
using map_t = std::unordered_map<key_t, ptr_t>;
33+
using map_t = std::unordered_map<key_t, entry_t>;
2834

2935
std::mutex mut; ///< lock for thread-safety
3036
map_t map; ///< single instance of singleton for each unique key
@@ -60,16 +66,29 @@ template <typename singleton_tn, typename key_tn> class singleton_factory_t {
6066
if (map.end() == iter) {
6167
auto ptr =
6268
std::make_unique<singleton_t>(std::forward<Ts>(params)...);
63-
iter = map.emplace(key, std::move(ptr)).first;
69+
iter = map.emplace(key, entry_t{std::move(ptr), 0}).first;
6470
}
65-
return iter->second.get();
71+
return iter->second.ptr.get();
72+
}
73+
74+
/// Add one reference to the singleton stored under `key`.
///
/// The key is expected to already be present in the map. A missing key is
/// a caller bug: it trips the assert in debug builds and is ignored when
/// assertions are compiled out, rather than dereferencing map.end().
void retain(key_tn key) {
    std::lock_guard<std::mutex> lk(mut);
    auto iter = map.find(getKey(key));
    assert(iter != map.end());
    if (iter == map.end()) {
        return; // unknown key with NDEBUG: nothing to retain
    }
    ++iter->second.ref_count;
}
6780

6881
//////////////////////////////////////////////////////////////////////////
6982
/// once the key is no longer valid, release the singleton
7083
void release(key_tn key) {
7184
std::lock_guard<std::mutex> lk(mut);
72-
map.erase(getKey(key));
85+
auto iter = map.find(getKey(key));
86+
assert(iter != map.end());
87+
if (iter->second.ref_count == 0) {
88+
map.erase(iter);
89+
} else {
90+
iter->second.ref_count--;
91+
}
7392
}
7493

7594
void clear() {

source/loader/ur_ldrddi.cpp

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,9 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRelease(
8585
// forward to device-platform
8686
result = pfnAdapterRelease(hAdapter);
8787

88+
// release loader handle
89+
context->factories.ur_adapter_factory.release(hAdapter);
90+
8891
return result;
8992
}
9093

@@ -110,6 +113,9 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRetain(
110113
// forward to device-platform
111114
result = pfnAdapterRetain(hAdapter);
112115

116+
// increment refcount of handle
117+
context->factories.ur_adapter_factory.retain(hAdapter);
118+
113119
return result;
114120
}
115121

@@ -614,6 +620,9 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRetain(
614620
// forward to device-platform
615621
result = pfnRetain(hDevice);
616622

623+
// increment refcount of handle
624+
context->factories.ur_device_factory.retain(hDevice);
625+
617626
return result;
618627
}
619628

@@ -640,6 +649,9 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRelease(
640649
// forward to device-platform
641650
result = pfnRelease(hDevice);
642651

652+
// release loader handle
653+
context->factories.ur_device_factory.release(hDevice);
654+
643655
return result;
644656
}
645657

@@ -910,6 +922,9 @@ __urdlllocal ur_result_t UR_APICALL urContextRetain(
910922
// forward to device-platform
911923
result = pfnRetain(hContext);
912924

925+
// increment refcount of handle
926+
context->factories.ur_context_factory.retain(hContext);
927+
913928
return result;
914929
}
915930

@@ -936,6 +951,9 @@ __urdlllocal ur_result_t UR_APICALL urContextRelease(
936951
// forward to device-platform
937952
result = pfnRelease(hContext);
938953

954+
// release loader handle
955+
context->factories.ur_context_factory.release(hContext);
956+
939957
return result;
940958
}
941959

@@ -1238,6 +1256,9 @@ __urdlllocal ur_result_t UR_APICALL urMemRetain(
12381256
// forward to device-platform
12391257
result = pfnRetain(hMem);
12401258

1259+
// increment refcount of handle
1260+
context->factories.ur_mem_factory.retain(hMem);
1261+
12411262
return result;
12421263
}
12431264

@@ -1264,6 +1285,9 @@ __urdlllocal ur_result_t UR_APICALL urMemRelease(
12641285
// forward to device-platform
12651286
result = pfnRelease(hMem);
12661287

1288+
// release loader handle
1289+
context->factories.ur_mem_factory.release(hMem);
1290+
12671291
return result;
12681292
}
12691293

@@ -1615,6 +1639,9 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRetain(
16151639
// forward to device-platform
16161640
result = pfnRetain(hSampler);
16171641

1642+
// increment refcount of handle
1643+
context->factories.ur_sampler_factory.retain(hSampler);
1644+
16181645
return result;
16191646
}
16201647

@@ -1641,6 +1668,9 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRelease(
16411668
// forward to device-platform
16421669
result = pfnRelease(hSampler);
16431670

1671+
// release loader handle
1672+
context->factories.ur_sampler_factory.release(hSampler);
1673+
16441674
return result;
16451675
}
16461676

@@ -2074,6 +2104,9 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRetain(
20742104
// forward to device-platform
20752105
result = pfnPoolRetain(pPool);
20762106

2107+
// increment refcount of handle
2108+
context->factories.ur_usm_pool_factory.retain(pPool);
2109+
20772110
return result;
20782111
}
20792112

@@ -2099,6 +2132,9 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRelease(
20992132
// forward to device-platform
21002133
result = pfnPoolRelease(pPool);
21012134

2135+
// release loader handle
2136+
context->factories.ur_usm_pool_factory.release(pPool);
2137+
21022138
return result;
21032139
}
21042140

@@ -2484,6 +2520,9 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRetain(
24842520
// forward to device-platform
24852521
result = pfnRetain(hPhysicalMem);
24862522

2523+
// increment refcount of handle
2524+
context->factories.ur_physical_mem_factory.retain(hPhysicalMem);
2525+
24872526
return result;
24882527
}
24892528

@@ -2512,6 +2551,9 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRelease(
25122551
// forward to device-platform
25132552
result = pfnRelease(hPhysicalMem);
25142553

2554+
// release loader handle
2555+
context->factories.ur_physical_mem_factory.release(hPhysicalMem);
2556+
25152557
return result;
25162558
}
25172559

@@ -2749,6 +2791,9 @@ __urdlllocal ur_result_t UR_APICALL urProgramRetain(
27492791
// forward to device-platform
27502792
result = pfnRetain(hProgram);
27512793

2794+
// increment refcount of handle
2795+
context->factories.ur_program_factory.retain(hProgram);
2796+
27522797
return result;
27532798
}
27542799

@@ -2775,6 +2820,9 @@ __urdlllocal ur_result_t UR_APICALL urProgramRelease(
27752820
// forward to device-platform
27762821
result = pfnRelease(hProgram);
27772822

2823+
// release loader handle
2824+
context->factories.ur_program_factory.release(hProgram);
2825+
27782826
return result;
27792827
}
27802828

@@ -3372,6 +3420,9 @@ __urdlllocal ur_result_t UR_APICALL urKernelRetain(
33723420
// forward to device-platform
33733421
result = pfnRetain(hKernel);
33743422

3423+
// increment refcount of handle
3424+
context->factories.ur_kernel_factory.retain(hKernel);
3425+
33753426
return result;
33763427
}
33773428

@@ -3398,6 +3449,9 @@ __urdlllocal ur_result_t UR_APICALL urKernelRelease(
33983449
// forward to device-platform
33993450
result = pfnRelease(hKernel);
34003451

3452+
// release loader handle
3453+
context->factories.ur_kernel_factory.release(hKernel);
3454+
34013455
return result;
34023456
}
34033457

@@ -3848,6 +3902,9 @@ __urdlllocal ur_result_t UR_APICALL urQueueRetain(
38483902
// forward to device-platform
38493903
result = pfnRetain(hQueue);
38503904

3905+
// increment refcount of handle
3906+
context->factories.ur_queue_factory.retain(hQueue);
3907+
38513908
return result;
38523909
}
38533910

@@ -3874,6 +3931,9 @@ __urdlllocal ur_result_t UR_APICALL urQueueRelease(
38743931
// forward to device-platform
38753932
result = pfnRelease(hQueue);
38763933

3934+
// release loader handle
3935+
context->factories.ur_queue_factory.release(hQueue);
3936+
38773937
return result;
38783938
}
38793939

@@ -4178,6 +4238,9 @@ __urdlllocal ur_result_t UR_APICALL urEventRetain(
41784238
// forward to device-platform
41794239
result = pfnRetain(hEvent);
41804240

4241+
// increment refcount of handle
4242+
context->factories.ur_event_factory.retain(hEvent);
4243+
41814244
return result;
41824245
}
41834246

@@ -4203,6 +4266,9 @@ __urdlllocal ur_result_t UR_APICALL urEventRelease(
42034266
// forward to device-platform
42044267
result = pfnRelease(hEvent);
42054268

4269+
// release loader handle
4270+
context->factories.ur_event_factory.release(hEvent);
4271+
42064272
return result;
42074273
}
42084274

@@ -6715,6 +6781,9 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesReleaseExternalMemoryExp(
67156781
// forward to device-platform
67166782
result = pfnReleaseExternalMemoryExp(hContext, hDevice, hExternalMem);
67176783

6784+
// release loader handle
6785+
context->factories.ur_exp_external_mem_factory.release(hExternalMem);
6786+
67186787
return result;
67196788
}
67206789

@@ -6805,6 +6874,10 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesReleaseExternalSemaphoreExp(
68056874
result =
68066875
pfnReleaseExternalSemaphoreExp(hContext, hDevice, hExternalSemaphore);
68076876

6877+
// release loader handle
6878+
context->factories.ur_exp_external_semaphore_factory.release(
6879+
hExternalSemaphore);
6880+
68086881
return result;
68096882
}
68106883

@@ -7030,6 +7103,9 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainExp(
70307103
// forward to device-platform
70317104
result = pfnRetainExp(hCommandBuffer);
70327105

7106+
// increment refcount of handle
7107+
context->factories.ur_exp_command_buffer_factory.retain(hCommandBuffer);
7108+
70337109
return result;
70347110
}
70357111

@@ -7060,6 +7136,9 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferReleaseExp(
70607136
// forward to device-platform
70617137
result = pfnReleaseExp(hCommandBuffer);
70627138

7139+
// release loader handle
7140+
context->factories.ur_exp_command_buffer_factory.release(hCommandBuffer);
7141+
70637142
return result;
70647143
}
70657144

@@ -7808,6 +7887,9 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainCommandExp(
78087887
// forward to device-platform
78097888
result = pfnRetainCommandExp(hCommand);
78107889

7890+
// increment refcount of handle
7891+
context->factories.ur_exp_command_buffer_command_factory.retain(hCommand);
7892+
78117893
return result;
78127894
}
78137895

@@ -7839,6 +7921,9 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferReleaseCommandExp(
78397921
// forward to device-platform
78407922
result = pfnReleaseCommandExp(hCommand);
78417923

7924+
// release loader handle
7925+
context->factories.ur_exp_command_buffer_command_factory.release(hCommand);
7926+
78427927
return result;
78437928
}
78447929

0 commit comments

Comments
 (0)