Skip to content

Commit 55425b3

Browse files
committed
Wrap urEventSetCallback when ran through loader
The events returned by the loader do not match the events returned by openCL (or other backends), which causes issues when adding a callback handler. This adds an intermediate wrapper to replace the event with the loader event.
1 parent 3472b5b commit 55425b3

File tree

3 files changed

+64
-10
lines changed

3 files changed

+64
-10
lines changed

scripts/templates/ldrddi.cpp.mako

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,37 @@ from templates import helper as th
2424
namespace ur_loader
2525
{
2626
%for obj in th.get_adapter_functions(specs):
27+
<%
28+
func_name = th.make_func_name(n, tags, obj)
29+
if func_name.startswith(x):
30+
func_basename = func_name[len(x):]
31+
else:
32+
func_basename = func_name
33+
%>
34+
%if func_basename == "EventSetCallback":
35+
namespace {
36+
struct event_callback_wrapper_data_t {
37+
${x}_event_callback_t fn;
38+
${x}_event_handle_t event;
39+
void *userData;
40+
};
41+
42+
void event_callback_wrapper([[maybe_unused]] ${x}_event_handle_t hEvent,
43+
${x}_execution_info_t execStatus, void *pUserData) {
44+
auto *wrapper =
45+
reinterpret_cast<event_callback_wrapper_data_t *>(pUserData);
46+
(wrapper->fn)(wrapper->event, execStatus, wrapper->userData);
47+
delete wrapper;
48+
}
49+
}
50+
51+
%endif
2752
///////////////////////////////////////////////////////////////////////////////
28-
/// @brief Intercept function for ${th.make_func_name(n, tags, obj)}
53+
/// @brief Intercept function for ${func_name}
2954
%if 'condition' in obj:
3055
#if ${th.subt(n, tags, obj['condition'])}
3156
%endif
32-
__${x}dlllocal ${x}_result_t ${X}_APICALL
33-
${th.make_func_name(n, tags, obj)}(
57+
__${x}dlllocal ${x}_result_t ${X}_APICALL ${func_name}(
3458
%for line in th.make_param_lines(n, tags, obj):
3559
${line}
3660
%endfor
@@ -41,7 +65,16 @@ namespace ur_loader
4165
%>${th.get_initial_null_set(obj)}
4266

4367
[[maybe_unused]] auto context = getContext();
44-
%if re.match(r"\w+AdapterGet$", th.make_func_name(n, tags, obj)):
68+
%if func_basename == "EventSetCallback":
69+
70+
// Replace the callback with a wrapper function that gives the callback the loader event rather than a
71+
// backend-specific event
72+
auto wrapper_data =
73+
new event_callback_wrapper_data_t{pfnNotify, hEvent, pUserData};
74+
pUserData = wrapper_data;
75+
pfnNotify = event_callback_wrapper;
76+
%endif
77+
%if func_basename == "AdapterGet":
4578

4679
size_t adapterIndex = 0;
4780
if( nullptr != ${obj['params'][1]['name']} && ${obj['params'][0]['name']} !=0)
@@ -74,7 +107,7 @@ namespace ur_loader
74107
*${obj['params'][2]['name']} = static_cast<uint32_t>(context->platforms.size());
75108
}
76109

77-
%elif re.match(r"\w+PlatformGet$", th.make_func_name(n, tags, obj)):
110+
%elif func_basename == "PlatformGet":
78111
uint32_t total_platform_handle_count = 0;
79112

80113
for( uint32_t adapter_index = 0; adapter_index < ${obj['params'][1]['name']}; adapter_index++)
@@ -263,7 +296,7 @@ namespace ur_loader
263296
%for i, item in enumerate(epilogue):
264297
%if 0 == i and not item['release'] and not item['retain'] and not th.always_wrap_outputs(obj):
265298
## TODO: Remove once we have a concrete way for submitting warnings in place.
266-
%if re.match(r"urEnqueue\w+", th.make_func_name(n, tags, obj)):
299+
%if re.match(r"Enqueue\w+", func_basename):
267300
// In the event of ERROR_ADAPTER_SPECIFIC we should still attempt to wrap any output handles below.
268301
if( ${X}_RESULT_SUCCESS != result && ${X}_RESULT_ERROR_ADAPTER_SPECIFIC != result )
269302
return result;
@@ -278,7 +311,7 @@ namespace ur_loader
278311
##%if item['release']:
279312
##// release loader handle
280313
##${item['factory']}.release( ${item['name']} );
281-
%if not item['release'] and not item['retain'] and not '_native_object_' in item['obj'] or th.make_func_name(n, tags, obj) == 'urPlatformCreateWithNativeHandle':
314+
%if not item['release'] and not item['retain'] and not '_native_object_' in item['obj'] or func_name == 'urPlatformCreateWithNativeHandle':
282315
try
283316
{
284317
%if 'typename' in item:

source/loader/ur_ldrddi.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "ur_loader.hpp"
1414

1515
namespace ur_loader {
16+
1617
///////////////////////////////////////////////////////////////////////////////
1718
/// @brief Intercept function for urAdapterGet
1819
__urdlllocal ur_result_t UR_APICALL urAdapterGet(
@@ -4410,6 +4411,22 @@ __urdlllocal ur_result_t UR_APICALL urEventCreateWithNativeHandle(
44104411
return result;
44114412
}
44124413

4414+
namespace {
4415+
struct event_callback_wrapper_data_t {
4416+
ur_event_callback_t fn;
4417+
ur_event_handle_t event;
4418+
void *userData;
4419+
};
4420+
4421+
void event_callback_wrapper([[maybe_unused]] ur_event_handle_t hEvent,
4422+
ur_execution_info_t execStatus, void *pUserData) {
4423+
auto *wrapper =
4424+
reinterpret_cast<event_callback_wrapper_data_t *>(pUserData);
4425+
(wrapper->fn)(wrapper->event, execStatus, wrapper->userData);
4426+
delete wrapper;
4427+
}
4428+
} // namespace
4429+
44134430
///////////////////////////////////////////////////////////////////////////////
44144431
/// @brief Intercept function for urEventSetCallback
44154432
__urdlllocal ur_result_t UR_APICALL urEventSetCallback(
@@ -4423,6 +4440,13 @@ __urdlllocal ur_result_t UR_APICALL urEventSetCallback(
44234440

44244441
[[maybe_unused]] auto context = getContext();
44254442

4443+
// Replace the callback with a wrapper function that gives the callback the loader event rather than a
4444+
// backend-specific event
4445+
auto wrapper_data =
4446+
new event_callback_wrapper_data_t{pfnNotify, hEvent, pUserData};
4447+
pUserData = wrapper_data;
4448+
pfnNotify = event_callback_wrapper;
4449+
44264450
// extract platform's function pointer table
44274451
auto dditable = reinterpret_cast<ur_event_object_t *>(hEvent)->dditable;
44284452
auto pfnSetCallback = dditable->ur.Event.pfnSetCallback;

test/conformance/event/event_adapter_opencl.match

Lines changed: 0 additions & 3 deletions
This file was deleted.

0 commit comments

Comments
 (0)