Skip to content

Commit a14e712

Browse files
committed
Wrap urEventSetCallback when ran through loader
The events returned by the loader do not match the events returned by openCL (or other backends), which causes issues when adding a callback handler. This adds an intermediate wrapper to replace the event with the loader event.
1 parent 778085f commit a14e712

File tree

3 files changed

+64
-7
lines changed

3 files changed

+64
-7
lines changed

scripts/templates/ldrddi.cpp.mako

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,37 @@ from templates import helper as th
2424
namespace ur_loader
2525
{
2626
%for obj in th.get_adapter_functions(specs):
27+
<%
28+
func_name = th.make_func_name(n, tags, obj)
29+
if func_name.startswith(x):
30+
func_basename = func_name[len(x):]
31+
else:
32+
func_basename = func_name
33+
%>
34+
%if func_basename == "EventSetCallback":
35+
namespace {
36+
struct event_callback_wrapper_data_t {
37+
${x}_event_callback_t fn;
38+
${x}_event_handle_t event;
39+
void *userData;
40+
};
41+
42+
void event_callback_wrapper([[maybe_unused]] ${x}_event_handle_t hEvent,
43+
${x}_execution_info_t execStatus, void *pUserData) {
44+
auto *wrapper =
45+
reinterpret_cast<event_callback_wrapper_data_t *>(pUserData);
46+
(wrapper->fn)(wrapper->event, execStatus, wrapper->userData);
47+
delete wrapper;
48+
}
49+
}
50+
51+
%endif
2752
///////////////////////////////////////////////////////////////////////////////
28-
/// @brief Intercept function for ${th.make_func_name(n, tags, obj)}
53+
/// @brief Intercept function for ${func_name}
2954
%if 'condition' in obj:
3055
#if ${th.subt(n, tags, obj['condition'])}
3156
%endif
32-
__${x}dlllocal ${x}_result_t ${X}_APICALL
33-
${th.make_func_name(n, tags, obj)}(
57+
__${x}dlllocal ${x}_result_t ${X}_APICALL ${func_name}(
3458
%for line in th.make_param_lines(n, tags, obj):
3559
${line}
3660
%endfor
@@ -41,7 +65,16 @@ namespace ur_loader
4165
%>${th.get_initial_null_set(obj)}
4266

4367
[[maybe_unused]] auto context = getContext();
44-
%if re.match(r"\w+AdapterGet$", th.make_func_name(n, tags, obj)):
68+
%if func_basename == "EventSetCallback":
69+
70+
// Replace the callback with a wrapper function that gives the callback the loader event rather than a
71+
// backend-specific event
72+
auto wrapper_data =
73+
new event_callback_wrapper_data_t{pfnNotify, hEvent, pUserData};
74+
pUserData = wrapper_data;
75+
pfnNotify = event_callback_wrapper;
76+
%endif
77+
%if func_basename == "AdapterGet":
4578

4679
size_t adapterIndex = 0;
4780
if( nullptr != ${obj['params'][1]['name']} && ${obj['params'][0]['name']} !=0)
@@ -74,7 +107,7 @@ namespace ur_loader
74107
*${obj['params'][2]['name']} = static_cast<uint32_t>(context->platforms.size());
75108
}
76109

77-
%elif re.match(r"\w+PlatformGet$", th.make_func_name(n, tags, obj)):
110+
%elif func_basename == "PlatformGet":
78111
uint32_t total_platform_handle_count = 0;
79112

80113
for( uint32_t adapter_index = 0; adapter_index < ${obj['params'][1]['name']}; adapter_index++)
@@ -263,7 +296,7 @@ namespace ur_loader
263296
%for i, item in enumerate(epilogue):
264297
%if 0 == i and not item['release'] and not item['retain'] and not th.always_wrap_outputs(obj):
265298
## TODO: Remove once we have a concrete way for submitting warnings in place.
266-
%if re.match(r"urEnqueue\w+", th.make_func_name(n, tags, obj)):
299+
%if re.match(r"Enqueue\w+", func_basename):
267300
// In the event of ERROR_ADAPTER_SPECIFIC we should still attempt to wrap any output handles below.
268301
if( ${X}_RESULT_SUCCESS != result && ${X}_RESULT_ERROR_ADAPTER_SPECIFIC != result )
269302
return result;

source/loader/ur_ldrddi.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "ur_loader.hpp"
1414

1515
namespace ur_loader {
16+
1617
///////////////////////////////////////////////////////////////////////////////
1718
/// @brief Intercept function for urAdapterGet
1819
__urdlllocal ur_result_t UR_APICALL urAdapterGet(
@@ -4476,6 +4477,22 @@ __urdlllocal ur_result_t UR_APICALL urEventCreateWithNativeHandle(
44764477
return result;
44774478
}
44784479

4480+
namespace {
4481+
struct event_callback_wrapper_data_t {
4482+
ur_event_callback_t fn;
4483+
ur_event_handle_t event;
4484+
void *userData;
4485+
};
4486+
4487+
void event_callback_wrapper([[maybe_unused]] ur_event_handle_t hEvent,
4488+
ur_execution_info_t execStatus, void *pUserData) {
4489+
auto *wrapper =
4490+
reinterpret_cast<event_callback_wrapper_data_t *>(pUserData);
4491+
(wrapper->fn)(wrapper->event, execStatus, wrapper->userData);
4492+
delete wrapper;
4493+
}
4494+
} // namespace
4495+
44794496
///////////////////////////////////////////////////////////////////////////////
44804497
/// @brief Intercept function for urEventSetCallback
44814498
__urdlllocal ur_result_t UR_APICALL urEventSetCallback(
@@ -4489,6 +4506,13 @@ __urdlllocal ur_result_t UR_APICALL urEventSetCallback(
44894506

44904507
[[maybe_unused]] auto context = getContext();
44914508

4509+
// Replace the callback with a wrapper function that gives the callback the loader event rather than a
4510+
// backend-specific event
4511+
auto wrapper_data =
4512+
new event_callback_wrapper_data_t{pfnNotify, hEvent, pUserData};
4513+
pUserData = wrapper_data;
4514+
pfnNotify = event_callback_wrapper;
4515+
44924516
// extract platform's function pointer table
44934517
auto dditable = reinterpret_cast<ur_event_object_t *>(hEvent)->dditable;
44944518
auto pfnSetCallback = dditable->ur.Event.pfnSetCallback;

test/conformance/event/urEventSetCallback.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ TEST_P(urEventSetCallbackTest, Success) {
4141
*/
4242
TEST_P(urEventSetCallbackTest, ValidateParameters) {
4343
UUR_KNOWN_FAILURE_ON(uur::CUDA{}, uur::HIP{}, uur::LevelZero{},
44-
uur::LevelZeroV2{}, uur::OpenCL{}, uur::NativeCPU{});
44+
uur::LevelZeroV2{}, uur::NativeCPU{});
4545

4646
struct CallbackParameters {
4747
ur_event_handle_t event;

0 commit comments

Comments
 (0)