Skip to content

Commit c806e99

Browse files
authored
Merge pull request #1826 from pbalcer/loader-lifetime-context
refactor loader lifetime management
2 parents a4a5b08 + 93f82f2 commit c806e99

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+4380
-3595
lines changed

scripts/templates/ldrddi.cpp.mako

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,6 @@ from templates import helper as th
2323

2424
namespace ur_loader
2525
{
26-
///////////////////////////////////////////////////////////////////////////////
27-
%for obj in th.get_adapter_handles(specs):
28-
%if 'class' in obj:
29-
<%
30-
_handle_t = th.subt(n, tags, obj['name'])
31-
_factory_t = re.sub(r"(\w+)_handle_t", r"\1_factory_t", _handle_t)
32-
_factory = re.sub(r"(\w+)_handle_t", r"\1_factory", _handle_t)
33-
%>${th.append_ws(_factory_t, 35)} ${_factory};
34-
%endif
35-
%endfor
36-
3726
%for obj in th.get_adapter_functions(specs):
3827
///////////////////////////////////////////////////////////////////////////////
3928
/// @brief Intercept function for ${th.make_func_name(n, tags, obj)}
@@ -51,6 +40,7 @@ namespace ur_loader
5140
add_local = False
5241
%>${th.get_initial_null_set(obj)}
5342

43+
[[maybe_unused]] auto context = getContext();
5444
%if re.match(r"\w+AdapterGet$", th.make_func_name(n, tags, obj)):
5545

5646
size_t adapterIndex = 0;
@@ -63,7 +53,7 @@ namespace ur_loader
6353
platform.dditable.${n}.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)}( 1, &${obj['params'][1]['name']}[adapterIndex], nullptr );
6454
try
6555
{
66-
${obj['params'][1]['name']}[adapterIndex] = reinterpret_cast<${n}_adapter_handle_t>(${n}_adapter_factory.getInstance(
56+
${obj['params'][1]['name']}[adapterIndex] = reinterpret_cast<${n}_adapter_handle_t>(context->factories.${n}_adapter_factory.getInstance(
6757
${obj['params'][1]['name']}[adapterIndex], &platform.dditable
6858
));
6959
}
@@ -114,7 +104,7 @@ namespace ur_loader
114104
for( uint32_t i = 0; i < library_platform_handle_count; ++i ) {
115105
uint32_t platform_index = total_platform_handle_count + i;
116106
${obj['params'][3]['name']}[ platform_index ] = reinterpret_cast<${n}_platform_handle_t>(
117-
${n}_platform_factory.getInstance( ${obj['params'][3]['name']}[ platform_index ], dditable ) );
107+
context->factories.${n}_platform_factory.getInstance( ${obj['params'][3]['name']}[ platform_index ], dditable ) );
118108
}
119109
}
120110
catch( std::bad_alloc& )
@@ -294,7 +284,7 @@ namespace ur_loader
294284
for (size_t i = 0; i < nelements; ++i) {
295285
if (handles[i] != nullptr) {
296286
handles[i] = reinterpret_cast<${etor['type']}>(
297-
${etor['factory']}.getInstance( handles[i], dditable ) );
287+
context->factories.${etor['factory']}.getInstance( handles[i], dditable ) );
298288
}
299289
}
300290
} break;
@@ -306,16 +296,16 @@ namespace ur_loader
306296
// convert platform handles to loader handles
307297
for( size_t i = ${item['range'][0]}; ( nullptr != ${item['name']} ) && ( i < ${item['range'][1]} ); ++i )
308298
${item['name']}[ i ] = reinterpret_cast<${item['type']}>(
309-
${item['factory']}.getInstance( ${item['name']}[ i ], dditable ) );
299+
context->factories.${item['factory']}.getInstance( ${item['name']}[ i ], dditable ) );
310300
%else:
311301
// convert platform handle to loader handle
312302
%if item['optional'] or th.always_wrap_outputs(obj):
313303
if( nullptr != ${item['name']} )
314304
*${item['name']} = reinterpret_cast<${item['type']}>(
315-
${item['factory']}.getInstance( *${item['name']}, dditable ) );
305+
context->factories.${item['factory']}.getInstance( *${item['name']}, dditable ) );
316306
%else:
317307
*${item['name']} = reinterpret_cast<${item['type']}>(
318-
${item['factory']}.getInstance( *${item['name']}, dditable ) );
308+
context->factories.${item['factory']}.getInstance( *${item['name']}, dditable ) );
319309
%endif
320310
%endif
321311
}
@@ -360,13 +350,13 @@ ${tbl['export']['name']}(
360350
if( nullptr == pDdiTable )
361351
return ${X}_RESULT_ERROR_INVALID_NULL_POINTER;
362352
363-
if( ur_loader::context->version < version )
353+
if( ur_loader::getContext()->version < version )
364354
return ${X}_RESULT_ERROR_UNSUPPORTED_VERSION;
365355
366356
${x}_result_t result = ${X}_RESULT_SUCCESS;
367357
368358
// Load the device-platform DDI tables
369-
for( auto& platform : ur_loader::context->platforms )
359+
for( auto& platform : ur_loader::getContext()->platforms )
370360
{
371361
if(platform.initStatus != ${X}_RESULT_SUCCESS)
372362
continue;
@@ -379,7 +369,7 @@ ${tbl['export']['name']}(
379369
380370
if( ${X}_RESULT_SUCCESS == result )
381371
{
382-
if( ur_loader::context->platforms.size() != 1 || ur_loader::context->forceIntercept )
372+
if( ur_loader::getContext()->platforms.size() != 1 || ur_loader::getContext()->forceIntercept )
383373
{
384374
// return pointers to loader's DDIs
385375
%for obj in tbl['functions']:
@@ -397,7 +387,7 @@ ${tbl['export']['name']}(
397387
else
398388
{
399389
// return pointers directly to platform's DDIs
400-
*pDdiTable = ur_loader::context->platforms.front().dditable.${n}.${tbl['name']};
390+
*pDdiTable = ur_loader::getContext()->platforms.front().dditable.${n}.${tbl['name']};
401391
}
402392
}
403393

scripts/templates/ldrddi.hpp.mako

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,29 @@ from templates import helper as th
2727
namespace ur_loader
2828
{
2929
///////////////////////////////////////////////////////////////////////////////
30+
<%
31+
factories = []
32+
%>
3033
%for obj in th.get_adapter_handles(specs):
3134
%if 'class' in obj:
3235
<%
3336
_handle_t = th.subt(n, tags, obj['name'])
3437
_object_t = re.sub(r"(\w+)_handle_t", r"\1_object_t", _handle_t)
3538
_factory_t = re.sub(r"(\w+)_handle_t", r"\1_factory_t", _handle_t)
39+
_factory = re.sub(r"(\w+)_handle_t", r"\1_factory", _handle_t)
40+
factories.append((_factory_t, _factory))
3641
%>using ${th.append_ws(_object_t, 35)} = object_t < ${_handle_t} >;
3742
using ${th.append_ws(_factory_t, 35)} = singleton_factory_t < ${_object_t}, ${_handle_t} >;
3843

3944
%endif
4045
%endfor
46+
47+
struct handle_factories {
48+
%for (f_t, f) in factories:
49+
${f_t} ${f};
50+
%endfor
51+
};
52+
4153
}
4254

4355
#endif /* UR_LOADER_LDRDDI_H */

scripts/templates/libapi.cpp.mako

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -56,28 +56,10 @@ ${th.make_func_name(n, tags, obj)}(
5656
%endfor
5757
)
5858
try {
59-
%if re.match("Init", obj['name']):
60-
<%
61-
param_checks=th.make_param_checks(n, tags, obj, meta=meta).items()
62-
%>
63-
%for key, values in param_checks:
64-
%for val in values:
65-
if( ${val} )
66-
return ${key};
67-
68-
%endfor
69-
%endfor
70-
71-
static ${x}_result_t result = ${X}_RESULT_SUCCESS;
72-
std::call_once(${x}_lib::context->initOnce, [device_flags, hLoaderConfig]() {
73-
result = ${x}_lib::context->Init(device_flags, hLoaderConfig);
74-
});
75-
76-
return result;
77-
%elif th.obj_traits.is_loader_only(obj):
59+
%if th.obj_traits.is_loader_only(obj):
7860
return ur_lib::${th.make_func_name(n, tags, obj)}(${", ".join(th.make_param_lines(n, tags, obj, format=["name"]))} );
7961
%else:
80-
${th.get_initial_null_set(obj)}auto ${th.make_pfn_name(n, tags, obj)} = ${x}_lib::context->${n}DdiTable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};
62+
${th.get_initial_null_set(obj)}auto ${th.make_pfn_name(n, tags, obj)} = ${x}_lib::getContext()->${n}DdiTable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};
8163
if( nullptr == ${th.make_pfn_name(n, tags, obj)} )
8264
return ${X}_RESULT_ERROR_UNINITIALIZED;
8365

scripts/templates/libddi.cpp.mako

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ namespace ${x}_lib
2828
///////////////////////////////////////////////////////////////////////////////
2929

3030

31-
__${x}dlllocal ${x}_result_t context_t::${n}LoaderInit()
31+
__${x}dlllocal ${x}_result_t context_t::ddiInit()
3232
{
3333
${x}_result_t result = ${X}_RESULT_SUCCESS;
3434

scripts/templates/trcddi.cpp.mako

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,23 +37,23 @@ namespace ur_tracing_layer
3737
%endfor
3838
)
3939
{${th.get_initial_null_set(obj)}
40-
auto ${th.make_pfn_name(n, tags, obj)} = context.${n}DdiTable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};
40+
auto ${th.make_pfn_name(n, tags, obj)} = getContext()->${n}DdiTable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};
4141

4242
if( nullptr == ${th.make_pfn_name(n, tags, obj)} )
4343
return ${X}_RESULT_ERROR_UNSUPPORTED_FEATURE;
4444

4545
${th.make_pfncb_param_type(n, tags, obj)} params = { &${",&".join(th.make_param_lines(n, tags, obj, format=["name"]))} };
46-
uint64_t instance = context.notify_begin(${th.make_func_etor(n, tags, obj)}, "${th.make_func_name(n, tags, obj)}", &params);
46+
uint64_t instance = getContext()->notify_begin(${th.make_func_etor(n, tags, obj)}, "${th.make_func_name(n, tags, obj)}", &params);
4747

48-
context.logger.info("---> ${th.make_func_name(n, tags, obj)}");
48+
getContext()->logger.info("---> ${th.make_func_name(n, tags, obj)}");
4949

5050
${x}_result_t result = ${th.make_pfn_name(n, tags, obj)}( ${", ".join(th.make_param_lines(n, tags, obj, format=["name"]))} );
5151

52-
context.notify_end(${th.make_func_etor(n, tags, obj)}, "${th.make_func_name(n, tags, obj)}", &params, &result, instance);
52+
getContext()->notify_end(${th.make_func_etor(n, tags, obj)}, "${th.make_func_name(n, tags, obj)}", &params, &result, instance);
5353

5454
std::ostringstream args_str;
5555
ur::extras::printFunctionParams(args_str, ${th.make_func_etor(n, tags, obj)}, &params);
56-
context.logger.info("({}) -> {};\n", args_str.str(), result);
56+
getContext()->logger.info("({}) -> {};\n", args_str.str(), result);
5757

5858
return result;
5959
}
@@ -79,13 +79,13 @@ namespace ur_tracing_layer
7979
%endfor
8080
)
8181
{
82-
auto& dditable = ur_tracing_layer::context.${n}DdiTable.${tbl['name']};
82+
auto& dditable = ur_tracing_layer::getContext()->${n}DdiTable.${tbl['name']};
8383

8484
if( nullptr == pDdiTable )
8585
return ${X}_RESULT_ERROR_INVALID_NULL_POINTER;
8686

87-
if (UR_MAJOR_VERSION(ur_tracing_layer::context.version) != UR_MAJOR_VERSION(version) ||
88-
UR_MINOR_VERSION(ur_tracing_layer::context.version) > UR_MINOR_VERSION(version))
87+
if (UR_MAJOR_VERSION(ur_tracing_layer::getContext()->version) != UR_MAJOR_VERSION(version) ||
88+
UR_MINOR_VERSION(ur_tracing_layer::getContext()->version) > UR_MINOR_VERSION(version))
8989
return ${X}_RESULT_ERROR_UNSUPPORTED_VERSION;
9090

9191
${x}_result_t result = ${X}_RESULT_SUCCESS;
@@ -122,7 +122,7 @@ namespace ur_tracing_layer
122122
// program launch and the call to `urLoaderInit`
123123
logger = logger::create_logger("tracing", true, true);
124124

125-
ur_tracing_layer::context.codelocData = codelocData;
125+
ur_tracing_layer::getContext()->codelocData = codelocData;
126126

127127
%for tbl in th.get_pfntables(specs, meta, n, tags):
128128
if( ${X}_RESULT_SUCCESS == result )

scripts/templates/ur_api.hpp.mako

Whitespace-only changes.

scripts/templates/valddi.cpp.mako

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,13 @@ namespace ur_validation_layer
4747
%endfor
4848
)
4949
{${th.get_initial_null_set(obj)}
50-
auto ${th.make_pfn_name(n, tags, obj)} = context.${n}DdiTable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};
50+
auto ${th.make_pfn_name(n, tags, obj)} = getContext()->${n}DdiTable.${th.get_table_name(n, tags, obj)}.${th.make_pfn_name(n, tags, obj)};
5151

5252
if( nullptr == ${th.make_pfn_name(n, tags, obj)} ) {
5353
return ${X}_RESULT_ERROR_UNINITIALIZED;
5454
}
5555

56-
if( context.enableParameterValidation )
56+
if( getContext()->enableParameterValidation )
5757
{
5858
%for key, values in sorted_param_checks:
5959
%for val in values:
@@ -80,8 +80,8 @@ namespace ur_validation_layer
8080
is_related_create_get_retain_release_func = any(func_name in funcs for funcs in tp_input_handle_funcs.values())
8181
%>
8282
%if tp_input_handle_funcs and not is_related_create_get_retain_release_func:
83-
if (context.enableLifetimeValidation && !refCountContext.isReferenceValid(${tp['name']})) {
84-
refCountContext.logInvalidReference(${tp['name']});
83+
if (getContext()->enableLifetimeValidation && !getContext()->refCountContext->isReferenceValid(${tp['name']})) {
84+
getContext()->refCountContext->logInvalidReference(${tp['name']});
8585
}
8686
%endif
8787
%endfor
@@ -94,26 +94,26 @@ namespace ur_validation_layer
9494
is_handle_to_adapter = ("_adapter_handle_t" in tp['type'])
9595
%>
9696
%if func_name in tp_handle_funcs['create']:
97-
if( context.enableLeakChecking && result == UR_RESULT_SUCCESS )
97+
if( getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS )
9898
{
99-
refCountContext.createRefCount(*${tp['name']});
99+
getContext()->refCountContext->createRefCount(*${tp['name']});
100100
}
101101
%elif func_name in tp_handle_funcs['get']:
102-
if( context.enableLeakChecking && ${tp['name']} && result == UR_RESULT_SUCCESS )
102+
if( getContext()->enableLeakChecking && ${tp['name']} && result == UR_RESULT_SUCCESS )
103103
{
104104
for (uint32_t i = ${th.param_traits.range_start(tp)}; i < ${th.param_traits.range_end(tp)}; i++) {
105-
refCountContext.createOrIncrementRefCount(${tp['name']}[i], ${str(is_handle_to_adapter).lower()});
105+
getContext()->refCountContext->createOrIncrementRefCount(${tp['name']}[i], ${str(is_handle_to_adapter).lower()});
106106
}
107107
}
108108
%elif func_name in tp_handle_funcs['retain']:
109-
if( context.enableLeakChecking && result == UR_RESULT_SUCCESS )
109+
if( getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS )
110110
{
111-
refCountContext.incrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()});
111+
getContext()->refCountContext->incrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()});
112112
}
113113
%elif func_name in tp_handle_funcs['release']:
114-
if( context.enableLeakChecking && result == UR_RESULT_SUCCESS )
114+
if( getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS )
115115
{
116-
refCountContext.decrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()});
116+
getContext()->refCountContext->decrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()});
117117
}
118118
%endif
119119
%endfor
@@ -141,13 +141,13 @@ namespace ur_validation_layer
141141
%endfor
142142
)
143143
{
144-
auto& dditable = ur_validation_layer::context.${n}DdiTable.${tbl['name']};
144+
auto& dditable = ur_validation_layer::getContext()->${n}DdiTable.${tbl['name']};
145145

146146
if( nullptr == pDdiTable )
147147
return ${X}_RESULT_ERROR_INVALID_NULL_POINTER;
148148

149-
if (UR_MAJOR_VERSION(ur_validation_layer::context.version) != UR_MAJOR_VERSION(version) ||
150-
UR_MINOR_VERSION(ur_validation_layer::context.version) > UR_MINOR_VERSION(version))
149+
if (UR_MAJOR_VERSION(ur_validation_layer::getContext()->version) != UR_MAJOR_VERSION(version) ||
150+
UR_MINOR_VERSION(ur_validation_layer::getContext()->version) > UR_MINOR_VERSION(version))
151151
return ${X}_RESULT_ERROR_UNSUPPORTED_VERSION;
152152

153153
${x}_result_t result = ${X}_RESULT_SUCCESS;
@@ -212,8 +212,8 @@ namespace ur_validation_layer
212212
${x}_result_t result = ${X}_RESULT_SUCCESS;
213213

214214
if (enableLeakChecking) {
215-
refCountContext.logInvalidReferences();
216-
refCountContext.clear();
215+
getContext()->refCountContext->logInvalidReferences();
216+
getContext()->refCountContext->clear();
217217
}
218218
return result;
219219
}

source/common/ur_singleton.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@ template <typename singleton_tn, typename key_tn> class singleton_factory_t {
7171
std::lock_guard<std::mutex> lk(mut);
7272
map.erase(getKey(key));
7373
}
74+
75+
void clear() {
76+
std::lock_guard<std::mutex> lk(mut);
77+
map.clear();
78+
}
7479
};
7580

7681
#endif /* UR_SINGLETON_H */

0 commit comments

Comments
 (0)