Skip to content

Commit 93f82f2

Browse files
committed
refactor loader lifetime management
This patch implements an atomic singleton class for managing the lifecycle of the context objects inside of the loader. This class ensures that the contexts always exist and lets the loader manually destroy them on user request (during teardown). Thanks to this change, the loader no longer relies on the order of library constructors and destructors. It also gets us 90% towards allowing the loader to be statically linked with the application.
1 parent 167ddf9 commit 93f82f2

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+4380
-3595
lines changed

scripts/templates/ldrddi.cpp.mako

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,6 @@ from templates import helper as th
2323

2424
namespace ur_loader
2525
{
26-
///////////////////////////////////////////////////////////////////////////////
27-
%for obj in th.get_adapter_handles(specs):
28-
%if 'class' in obj:
29-
<%
30-
_handle_t = th.subt(n, tags, obj['name'])
31-
_factory_t = re.sub(r"(\w+)_handle_t", r"\1_factory_t", _handle_t)
32-
_factory = re.sub(r"(\w+)_handle_t", r"\1_factory", _handle_t)
33-
%>${th.append_ws(_factory_t, 35)} ${_factory};
34-
%endif
35-
%endfor
36-
3726
%for obj in th.get_adapter_functions(specs):
3827
///////////////////////////////////////////////////////////////////////////////
3928
/// @brief Intercept function for ${th.make_func_name(n, tags, obj)}
@@ -51,6 +40,7 @@ namespace ur_loader
5140
add_local = False
5241
%>${th.get_initial_null_set(obj)}
5342

43+
[[maybe_unused]] auto context = getContext();
5444
%if re.match(r"\w+AdapterGet$", th.make_func_name(n, tags, obj)):
5545

5646
size_t adapterIndex = 0;
@@ -63,7 +53,7 @@ namespace ur_loader
6353
platform.dditable.${n}.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)}( 1, &${obj['params'][1]['name']}[adapterIndex], nullptr );
6454
try
6555
{
66-
${obj['params'][1]['name']}[adapterIndex] = reinterpret_cast<${n}_adapter_handle_t>(${n}_adapter_factory.getInstance(
56+
${obj['params'][1]['name']}[adapterIndex] = reinterpret_cast<${n}_adapter_handle_t>(context->factories.${n}_adapter_factory.getInstance(
6757
${obj['params'][1]['name']}[adapterIndex], &platform.dditable
6858
));
6959
}
@@ -114,7 +104,7 @@ namespace ur_loader
114104
for( uint32_t i = 0; i < library_platform_handle_count; ++i ) {
115105
uint32_t platform_index = total_platform_handle_count + i;
116106
${obj['params'][3]['name']}[ platform_index ] = reinterpret_cast<${n}_platform_handle_t>(
117-
${n}_platform_factory.getInstance( ${obj['params'][3]['name']}[ platform_index ], dditable ) );
107+
context->factories.${n}_platform_factory.getInstance( ${obj['params'][3]['name']}[ platform_index ], dditable ) );
118108
}
119109
}
120110
catch( std::bad_alloc& )
@@ -294,7 +284,7 @@ namespace ur_loader
294284
for (size_t i = 0; i < nelements; ++i) {
295285
if (handles[i] != nullptr) {
296286
handles[i] = reinterpret_cast<${etor['type']}>(
297-
${etor['factory']}.getInstance( handles[i], dditable ) );
287+
context->factories.${etor['factory']}.getInstance( handles[i], dditable ) );
298288
}
299289
}
300290
} break;
@@ -306,16 +296,16 @@ namespace ur_loader
306296
// convert platform handles to loader handles
307297
for( size_t i = ${item['range'][0]}; ( nullptr != ${item['name']} ) && ( i < ${item['range'][1]} ); ++i )
308298
${item['name']}[ i ] = reinterpret_cast<${item['type']}>(
309-
${item['factory']}.getInstance( ${item['name']}[ i ], dditable ) );
299+
context->factories.${item['factory']}.getInstance( ${item['name']}[ i ], dditable ) );
310300
%else:
311301
// convert platform handle to loader handle
312302
%if item['optional'] or th.always_wrap_outputs(obj):
313303
if( nullptr != ${item['name']} )
314304
*${item['name']} = reinterpret_cast<${item['type']}>(
315-
${item['factory']}.getInstance( *${item['name']}, dditable ) );
305+
context->factories.${item['factory']}.getInstance( *${item['name']}, dditable ) );
316306
%else:
317307
*${item['name']} = reinterpret_cast<${item['type']}>(
318-
${item['factory']}.getInstance( *${item['name']}, dditable ) );
308+
context->factories.${item['factory']}.getInstance( *${item['name']}, dditable ) );
319309
%endif
320310
%endif
321311
}
@@ -360,13 +350,13 @@ ${tbl['export']['name']}(
360350
if( nullptr == pDdiTable )
361351
return ${X}_RESULT_ERROR_INVALID_NULL_POINTER;
362352
363-
if( ur_loader::context->version < version )
353+
if( ur_loader::getContext()->version < version )
364354
return ${X}_RESULT_ERROR_UNSUPPORTED_VERSION;
365355
366356
${x}_result_t result = ${X}_RESULT_SUCCESS;
367357
368358
// Load the device-platform DDI tables
369-
for( auto& platform : ur_loader::context->platforms )
359+
for( auto& platform : ur_loader::getContext()->platforms )
370360
{
371361
if(platform.initStatus != ${X}_RESULT_SUCCESS)
372362
continue;
@@ -379,7 +369,7 @@ ${tbl['export']['name']}(
379369
380370
if( ${X}_RESULT_SUCCESS == result )
381371
{
382-
if( ur_loader::context->platforms.size() != 1 || ur_loader::context->forceIntercept )
372+
if( ur_loader::getContext()->platforms.size() != 1 || ur_loader::getContext()->forceIntercept )
383373
{
384374
// return pointers to loader's DDIs
385375
%for obj in tbl['functions']:
@@ -397,7 +387,7 @@ ${tbl['export']['name']}(
397387
else
398388
{
399389
// return pointers directly to platform's DDIs
400-
*pDdiTable = ur_loader::context->platforms.front().dditable.${n}.${tbl['name']};
390+
*pDdiTable = ur_loader::getContext()->platforms.front().dditable.${n}.${tbl['name']};
401391
}
402392
}
403393

scripts/templates/ldrddi.hpp.mako

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,29 @@ from templates import helper as th
2727
namespace ur_loader
2828
{
2929
///////////////////////////////////////////////////////////////////////////////
30+
<%
31+
factories = []
32+
%>
3033
%for obj in th.get_adapter_handles(specs):
3134
%if 'class' in obj:
3235
<%
3336
_handle_t = th.subt(n, tags, obj['name'])
3437
_object_t = re.sub(r"(\w+)_handle_t", r"\1_object_t", _handle_t)
3538
_factory_t = re.sub(r"(\w+)_handle_t", r"\1_factory_t", _handle_t)
39+
_factory = re.sub(r"(\w+)_handle_t", r"\1_factory", _handle_t)
40+
factories.append((_factory_t, _factory))
3641
%>using ${th.append_ws(_object_t, 35)} = object_t < ${_handle_t} >;
3742
using ${th.append_ws(_factory_t, 35)} = singleton_factory_t < ${_object_t}, ${_handle_t} >;
3843

3944
%endif
4045
%endfor
46+
47+
struct handle_factories {
48+
%for (f_t, f) in factories:
49+
${f_t} ${f};
50+
%endfor
51+
};
52+
4153
}
4254

4355
#endif /* UR_LOADER_LDRDDI_H */

scripts/templates/libapi.cpp.mako

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -56,28 +56,10 @@ ${th.make_func_name(n, tags, obj)}(
5656
%endfor
5757
)
5858
try {
59-
%if re.match("Init", obj['name']):
60-
<%
61-
param_checks=th.make_param_checks(n, tags, obj, meta=meta).items()
62-
%>
63-
%for key, values in param_checks:
64-
%for val in values:
65-
if( ${val} )
66-
return ${key};
67-
68-
%endfor
69-
%endfor
70-
71-
static ${x}_result_t result = ${X}_RESULT_SUCCESS;
72-
std::call_once(${x}_lib::context->initOnce, [device_flags, hLoaderConfig]() {
73-
result = ${x}_lib::context->Init(device_flags, hLoaderConfig);
74-
});
75-
76-
return result;
77-
%elif th.obj_traits.is_loader_only(obj):
59+
%if th.obj_traits.is_loader_only(obj):
7860
return ur_lib::${th.make_func_name(n, tags, obj)}(${", ".join(th.make_param_lines(n, tags, obj, format=["name"]))} );
7961
%else:
80-
${th.get_initial_null_set(obj)}auto ${th.make_pfn_name(n, tags, obj)} = ${x}_lib::context->${n}DdiTable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};
62+
${th.get_initial_null_set(obj)}auto ${th.make_pfn_name(n, tags, obj)} = ${x}_lib::getContext()->${n}DdiTable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};
8163
if( nullptr == ${th.make_pfn_name(n, tags, obj)} )
8264
return ${X}_RESULT_ERROR_UNINITIALIZED;
8365

scripts/templates/libddi.cpp.mako

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ namespace ${x}_lib
2828
///////////////////////////////////////////////////////////////////////////////
2929

3030

31-
__${x}dlllocal ${x}_result_t context_t::${n}LoaderInit()
31+
__${x}dlllocal ${x}_result_t context_t::ddiInit()
3232
{
3333
${x}_result_t result = ${X}_RESULT_SUCCESS;
3434

scripts/templates/trcddi.cpp.mako

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,23 +37,23 @@ namespace ur_tracing_layer
3737
%endfor
3838
)
3939
{${th.get_initial_null_set(obj)}
40-
auto ${th.make_pfn_name(n, tags, obj)} = context.${n}DdiTable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};
40+
auto ${th.make_pfn_name(n, tags, obj)} = getContext()->${n}DdiTable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};
4141

4242
if( nullptr == ${th.make_pfn_name(n, tags, obj)} )
4343
return ${X}_RESULT_ERROR_UNSUPPORTED_FEATURE;
4444

4545
${th.make_pfncb_param_type(n, tags, obj)} params = { &${",&".join(th.make_param_lines(n, tags, obj, format=["name"]))} };
46-
uint64_t instance = context.notify_begin(${th.make_func_etor(n, tags, obj)}, "${th.make_func_name(n, tags, obj)}", &params);
46+
uint64_t instance = getContext()->notify_begin(${th.make_func_etor(n, tags, obj)}, "${th.make_func_name(n, tags, obj)}", &params);
4747

48-
context.logger.info("---> ${th.make_func_name(n, tags, obj)}");
48+
getContext()->logger.info("---> ${th.make_func_name(n, tags, obj)}");
4949

5050
${x}_result_t result = ${th.make_pfn_name(n, tags, obj)}( ${", ".join(th.make_param_lines(n, tags, obj, format=["name"]))} );
5151

52-
context.notify_end(${th.make_func_etor(n, tags, obj)}, "${th.make_func_name(n, tags, obj)}", &params, &result, instance);
52+
getContext()->notify_end(${th.make_func_etor(n, tags, obj)}, "${th.make_func_name(n, tags, obj)}", &params, &result, instance);
5353

5454
std::ostringstream args_str;
5555
ur::extras::printFunctionParams(args_str, ${th.make_func_etor(n, tags, obj)}, &params);
56-
context.logger.info("({}) -> {};\n", args_str.str(), result);
56+
getContext()->logger.info("({}) -> {};\n", args_str.str(), result);
5757

5858
return result;
5959
}
@@ -79,13 +79,13 @@ namespace ur_tracing_layer
7979
%endfor
8080
)
8181
{
82-
auto& dditable = ur_tracing_layer::context.${n}DdiTable.${tbl['name']};
82+
auto& dditable = ur_tracing_layer::getContext()->${n}DdiTable.${tbl['name']};
8383

8484
if( nullptr == pDdiTable )
8585
return ${X}_RESULT_ERROR_INVALID_NULL_POINTER;
8686

87-
if (UR_MAJOR_VERSION(ur_tracing_layer::context.version) != UR_MAJOR_VERSION(version) ||
88-
UR_MINOR_VERSION(ur_tracing_layer::context.version) > UR_MINOR_VERSION(version))
87+
if (UR_MAJOR_VERSION(ur_tracing_layer::getContext()->version) != UR_MAJOR_VERSION(version) ||
88+
UR_MINOR_VERSION(ur_tracing_layer::getContext()->version) > UR_MINOR_VERSION(version))
8989
return ${X}_RESULT_ERROR_UNSUPPORTED_VERSION;
9090

9191
${x}_result_t result = ${X}_RESULT_SUCCESS;
@@ -122,7 +122,7 @@ namespace ur_tracing_layer
122122
// program launch and the call to `urLoaderInit`
123123
logger = logger::create_logger("tracing", true, true);
124124

125-
ur_tracing_layer::context.codelocData = codelocData;
125+
ur_tracing_layer::getContext()->codelocData = codelocData;
126126

127127
%for tbl in th.get_pfntables(specs, meta, n, tags):
128128
if( ${X}_RESULT_SUCCESS == result )

scripts/templates/ur_api.hpp.mako

Whitespace-only changes.

scripts/templates/valddi.cpp.mako

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,13 @@ namespace ur_validation_layer
4747
%endfor
4848
)
4949
{${th.get_initial_null_set(obj)}
50-
auto ${th.make_pfn_name(n, tags, obj)} = context.${n}DdiTable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};
50+
auto ${th.make_pfn_name(n, tags, obj)} = getContext()->${n}DdiTable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};
5151

5252
if( nullptr == ${th.make_pfn_name(n, tags, obj)} ) {
5353
return ${X}_RESULT_ERROR_UNINITIALIZED;
5454
}
5555

56-
if( context.enableParameterValidation )
56+
if( getContext()->enableParameterValidation )
5757
{
5858
%for key, values in sorted_param_checks:
5959
%for val in values:
@@ -80,8 +80,8 @@ namespace ur_validation_layer
8080
is_related_create_get_retain_release_func = any(func_name in funcs for funcs in tp_input_handle_funcs.values())
8181
%>
8282
%if tp_input_handle_funcs and not is_related_create_get_retain_release_func:
83-
if (context.enableLifetimeValidation && !refCountContext.isReferenceValid(${tp['name']})) {
84-
refCountContext.logInvalidReference(${tp['name']});
83+
if (getContext()->enableLifetimeValidation && !getContext()->refCountContext->isReferenceValid(${tp['name']})) {
84+
getContext()->refCountContext->logInvalidReference(${tp['name']});
8585
}
8686
%endif
8787
%endfor
@@ -94,26 +94,26 @@ namespace ur_validation_layer
9494
is_handle_to_adapter = ("_adapter_handle_t" in tp['type'])
9595
%>
9696
%if func_name in tp_handle_funcs['create']:
97-
if( context.enableLeakChecking && result == UR_RESULT_SUCCESS )
97+
if( getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS )
9898
{
99-
refCountContext.createRefCount(*${tp['name']});
99+
getContext()->refCountContext->createRefCount(*${tp['name']});
100100
}
101101
%elif func_name in tp_handle_funcs['get']:
102-
if( context.enableLeakChecking && ${tp['name']} && result == UR_RESULT_SUCCESS )
102+
if( getContext()->enableLeakChecking && ${tp['name']} && result == UR_RESULT_SUCCESS )
103103
{
104104
for (uint32_t i = ${th.param_traits.range_start(tp)}; i < ${th.param_traits.range_end(tp)}; i++) {
105-
refCountContext.createOrIncrementRefCount(${tp['name']}[i], ${str(is_handle_to_adapter).lower()});
105+
getContext()->refCountContext->createOrIncrementRefCount(${tp['name']}[i], ${str(is_handle_to_adapter).lower()});
106106
}
107107
}
108108
%elif func_name in tp_handle_funcs['retain']:
109-
if( context.enableLeakChecking && result == UR_RESULT_SUCCESS )
109+
if( getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS )
110110
{
111-
refCountContext.incrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()});
111+
getContext()->refCountContext->incrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()});
112112
}
113113
%elif func_name in tp_handle_funcs['release']:
114-
if( context.enableLeakChecking && result == UR_RESULT_SUCCESS )
114+
if( getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS )
115115
{
116-
refCountContext.decrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()});
116+
getContext()->refCountContext->decrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()});
117117
}
118118
%endif
119119
%endfor
@@ -141,13 +141,13 @@ namespace ur_validation_layer
141141
%endfor
142142
)
143143
{
144-
auto& dditable = ur_validation_layer::context.${n}DdiTable.${tbl['name']};
144+
auto& dditable = ur_validation_layer::getContext()->${n}DdiTable.${tbl['name']};
145145

146146
if( nullptr == pDdiTable )
147147
return ${X}_RESULT_ERROR_INVALID_NULL_POINTER;
148148

149-
if (UR_MAJOR_VERSION(ur_validation_layer::context.version) != UR_MAJOR_VERSION(version) ||
150-
UR_MINOR_VERSION(ur_validation_layer::context.version) > UR_MINOR_VERSION(version))
149+
if (UR_MAJOR_VERSION(ur_validation_layer::getContext()->version) != UR_MAJOR_VERSION(version) ||
150+
UR_MINOR_VERSION(ur_validation_layer::getContext()->version) > UR_MINOR_VERSION(version))
151151
return ${X}_RESULT_ERROR_UNSUPPORTED_VERSION;
152152

153153
${x}_result_t result = ${X}_RESULT_SUCCESS;
@@ -212,8 +212,8 @@ namespace ur_validation_layer
212212
${x}_result_t result = ${X}_RESULT_SUCCESS;
213213

214214
if (enableLeakChecking) {
215-
refCountContext.logInvalidReferences();
216-
refCountContext.clear();
215+
getContext()->refCountContext->logInvalidReferences();
216+
getContext()->refCountContext->clear();
217217
}
218218
return result;
219219
}

source/common/ur_singleton.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@ template <typename singleton_tn, typename key_tn> class singleton_factory_t {
7171
std::lock_guard<std::mutex> lk(mut);
7272
map.erase(getKey(key));
7373
}
74+
75+
void clear() {
76+
std::lock_guard<std::mutex> lk(mut);
77+
map.clear();
78+
}
7479
};
7580

7681
#endif /* UR_SINGLETON_H */

0 commit comments

Comments
 (0)