Skip to content

Wrap urEventSetCallback when ran through loader #2527

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 39 additions & 6 deletions scripts/templates/ldrddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,37 @@ from templates import helper as th
namespace ur_loader
{
%for obj in th.get_adapter_functions(specs):
<%
func_name = th.make_func_name(n, tags, obj)
if func_name.startswith(x):
func_basename = func_name[len(x):]
else:
func_basename = func_name
%>
%if func_basename == "EventSetCallback":
namespace {
struct event_callback_wrapper_data_t {
${x}_event_callback_t fn;
${x}_event_handle_t event;
void *userData;
};

void event_callback_wrapper([[maybe_unused]] ${x}_event_handle_t hEvent,
${x}_execution_info_t execStatus, void *pUserData) {
auto *wrapper =
reinterpret_cast<event_callback_wrapper_data_t *>(pUserData);
(wrapper->fn)(wrapper->event, execStatus, wrapper->userData);
delete wrapper;
}
}

%endif
///////////////////////////////////////////////////////////////////////////////
/// @brief Intercept function for ${th.make_func_name(n, tags, obj)}
/// @brief Intercept function for ${func_name}
%if 'condition' in obj:
#if ${th.subt(n, tags, obj['condition'])}
%endif
__${x}dlllocal ${x}_result_t ${X}_APICALL
${th.make_func_name(n, tags, obj)}(
__${x}dlllocal ${x}_result_t ${X}_APICALL ${func_name}(
%for line in th.make_param_lines(n, tags, obj):
${line}
%endfor
Expand All @@ -41,7 +65,16 @@ namespace ur_loader
%>${th.get_initial_null_set(obj)}

[[maybe_unused]] auto context = getContext();
%if re.match(r"\w+AdapterGet$", th.make_func_name(n, tags, obj)):
%if func_basename == "EventSetCallback":

// Replace the callback with a wrapper function that gives the callback the loader event rather than a
// backend-specific event
auto wrapper_data =
new event_callback_wrapper_data_t{pfnNotify, hEvent, pUserData};
pUserData = wrapper_data;
pfnNotify = event_callback_wrapper;
%endif
%if func_basename == "AdapterGet":

size_t adapterIndex = 0;
if( nullptr != ${obj['params'][1]['name']} && ${obj['params'][0]['name']} !=0)
Expand Down Expand Up @@ -74,7 +107,7 @@ namespace ur_loader
*${obj['params'][2]['name']} = static_cast<uint32_t>(context->platforms.size());
}

%elif re.match(r"\w+PlatformGet$", th.make_func_name(n, tags, obj)):
%elif func_basename == "PlatformGet":
uint32_t total_platform_handle_count = 0;

for( uint32_t adapter_index = 0; adapter_index < ${obj['params'][1]['name']}; adapter_index++)
Expand Down Expand Up @@ -263,7 +296,7 @@ namespace ur_loader
%for i, item in enumerate(epilogue):
%if 0 == i and not item['release'] and not item['retain'] and not th.always_wrap_outputs(obj):
## TODO: Remove once we have a concrete way for submitting warnings in place.
%if re.match(r"urEnqueue\w+", th.make_func_name(n, tags, obj)):
%if re.match(r"Enqueue\w+", func_basename):
// In the event of ERROR_ADAPTER_SPECIFIC we should still attempt to wrap any output handles below.
if( ${X}_RESULT_SUCCESS != result && ${X}_RESULT_ERROR_ADAPTER_SPECIFIC != result )
return result;
Expand Down
24 changes: 24 additions & 0 deletions source/loader/ur_ldrddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "ur_loader.hpp"

namespace ur_loader {

///////////////////////////////////////////////////////////////////////////////
/// @brief Intercept function for urAdapterGet
__urdlllocal ur_result_t UR_APICALL urAdapterGet(
Expand Down Expand Up @@ -4476,6 +4477,22 @@ __urdlllocal ur_result_t UR_APICALL urEventCreateWithNativeHandle(
return result;
}

namespace {
struct event_callback_wrapper_data_t {
ur_event_callback_t fn;
ur_event_handle_t event;
void *userData;
};

void event_callback_wrapper([[maybe_unused]] ur_event_handle_t hEvent,
ur_execution_info_t execStatus, void *pUserData) {
auto *wrapper =
reinterpret_cast<event_callback_wrapper_data_t *>(pUserData);
(wrapper->fn)(wrapper->event, execStatus, wrapper->userData);
delete wrapper;
}
} // namespace

///////////////////////////////////////////////////////////////////////////////
/// @brief Intercept function for urEventSetCallback
__urdlllocal ur_result_t UR_APICALL urEventSetCallback(
Expand All @@ -4489,6 +4506,13 @@ __urdlllocal ur_result_t UR_APICALL urEventSetCallback(

[[maybe_unused]] auto context = getContext();

// Replace the callback with a wrapper function that gives the callback the loader event rather than a
// backend-specific event
auto wrapper_data =
new event_callback_wrapper_data_t{pfnNotify, hEvent, pUserData};
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I know. Manual memory management.

But I'm not sure if there's a better C++-y way of doing this, does anyone have any ideas?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it possible to use a smart pointer instead maybe?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so. It needs to survive a roundtrip to and from a bare pointer to be used in the API.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, i don't think we can rely on raii here, unless you'd like to manually play around with constructing/destroying unique pointers, but I don't see much point.

pUserData = wrapper_data;
pfnNotify = event_callback_wrapper;

// extract platform's function pointer table
auto dditable = reinterpret_cast<ur_event_object_t *>(hEvent)->dditable;
auto pfnSetCallback = dditable->ur.Event.pfnSetCallback;
Expand Down
2 changes: 1 addition & 1 deletion test/conformance/event/urEventSetCallback.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ TEST_P(urEventSetCallbackTest, Success) {
*/
TEST_P(urEventSetCallbackTest, ValidateParameters) {
UUR_KNOWN_FAILURE_ON(uur::CUDA{}, uur::HIP{}, uur::LevelZero{},
uur::LevelZeroV2{}, uur::OpenCL{}, uur::NativeCPU{});
uur::LevelZeroV2{}, uur::NativeCPU{});

struct CallbackParameters {
ur_event_handle_t event;
Expand Down
Loading