@@ -24,13 +24,37 @@ from templates import helper as th
24
24
namespace ur_loader
25
25
{
26
26
% for obj in th.get_adapter_functions(specs):
27
+ <%
28
+ func_name = th.make_func_name(n, tags, obj)
29
+ if func_name.startswith(x):
30
+ func_basename = func_name[len (x):]
31
+ else :
32
+ func_basename = func_name
33
+ %>
34
+ % if func_basename == " EventSetCallback" :
35
+ namespace {
36
+ struct event_callback_wrapper_data_t {
37
+ ${ x} _event_callback_t fn;
38
+ ${ x} _event_handle_t event;
39
+ void *userData;
40
+ };
41
+
42
+ void event_callback_wrapper([[maybe_unused]] ${ x} _event_handle_t hEvent,
43
+ ${ x} _execution_info_t execStatus, void *pUserData) {
44
+ auto *wrapper =
45
+ reinterpret_cast<event _callback_wrapper_data_t * >(pUserData);
46
+ (wrapper->fn)(wrapper->event, execStatus, wrapper->userData);
47
+ delete wrapper;
48
+ }
49
+ }
50
+
51
+ %endif
27
52
///////////////////////////////////////////////////////////////////////////////
28
- /// @brief Intercept function for ${ th.make_func_name(n, tags, obj) }
53
+ /// @brief Intercept function for ${ func_name }
29
54
% if ' condition' in obj:
30
55
#if ${ th.subt(n, tags, obj[' condition' ])}
31
56
%endif
32
- __${ x} dlllocal ${ x} _result_t ${ X} _APICALL
33
- ${ th.make_func_name(n, tags, obj)} (
57
+ __${ x} dlllocal ${ x} _result_t ${ X} _APICALL ${ func_name} (
34
58
% for line in th.make_param_lines(n, tags, obj):
35
59
${ line}
36
60
%endfor
@@ -41,7 +65,16 @@ namespace ur_loader
41
65
%> ${ th.get_initial_null_set(obj)}
42
66
43
67
[[maybe_unused]] auto context = getContext();
44
- % if re.match(r " \w + AdapterGet$ " , th.make_func_name(n, tags, obj)):
68
+ % if func_basename == " EventSetCallback" :
69
+
70
+ // Replace the callback with a wrapper function that gives the callback the loader event rather than a
71
+ // backend-specific event
72
+ auto wrapper_data =
73
+ new event_callback_wrapper_data_t{pfnNotify, hEvent, pUserData};
74
+ pUserData = wrapper_data;
75
+ pfnNotify = event_callback_wrapper;
76
+ %endif
77
+ % if func_basename == " AdapterGet" :
45
78
46
79
size_t adapterIndex = 0;
47
80
if( nullptr != ${ obj[' params' ][1 ][' name' ]} && ${ obj[' params' ][0 ][' name' ]} !=0)
@@ -74,7 +107,7 @@ namespace ur_loader
74
107
*${ obj[' params' ][2 ][' name' ]} = static_cast<uint32 _t >(context->platforms.size());
75
108
}
76
109
77
- % elif re.match( r " \w + PlatformGet$ " , th.make_func_name(n, tags, obj)) :
110
+ % elif func_basename == " PlatformGet" :
78
111
uint32_t total_platform_handle_count = 0;
79
112
80
113
for( uint32_t adapter_index = 0; adapter_index < ${ obj[' params' ][1 ][' name' ]} ; adapter_index++)
@@ -263,7 +296,7 @@ namespace ur_loader
263
296
% for i, item in enumerate (epilogue):
264
297
% if 0 == i and not item[' release' ] and not item[' retain' ] and not th.always_wrap_outputs(obj):
265
298
## TODO: Remove once we have a concrete way for submitting warnings in place.
266
- % if re.match(r " urEnqueue \w + " , th.make_func_name(n, tags, obj) ):
299
+ % if re.match(r " Enqueue \w + " , func_basename ):
267
300
// In the event of ERROR_ADAPTER_SPECIFIC we should still attempt to wrap any output handles below.
268
301
if( ${ X} _RESULT_SUCCESS != result && ${ X} _RESULT_ERROR_ADAPTER_SPECIFIC != result )
269
302
return result;
0 commit comments