diff --git a/scripts/templates/ldrddi.cpp.mako b/scripts/templates/ldrddi.cpp.mako index 1b7d19fa67..c2bc9968d2 100644 --- a/scripts/templates/ldrddi.cpp.mako +++ b/scripts/templates/ldrddi.cpp.mako @@ -24,13 +24,37 @@ from templates import helper as th namespace ur_loader { %for obj in th.get_adapter_functions(specs): + <% + func_name = th.make_func_name(n, tags, obj) + if func_name.startswith(x): + func_basename = func_name[len(x):] + else: + func_basename = func_name + %> + %if func_basename == "EventSetCallback": + namespace { + struct event_callback_wrapper_data_t { + ${x}_event_callback_t fn; + ${x}_event_handle_t event; + void *userData; + }; + + void event_callback_wrapper([[maybe_unused]] ${x}_event_handle_t hEvent, + ${x}_execution_info_t execStatus, void *pUserData) { + auto *wrapper = + reinterpret_cast(pUserData); + (wrapper->fn)(wrapper->event, execStatus, wrapper->userData); + delete wrapper; + } + } + + %endif /////////////////////////////////////////////////////////////////////////////// - /// @brief Intercept function for ${th.make_func_name(n, tags, obj)} + /// @brief Intercept function for ${func_name} %if 'condition' in obj: #if ${th.subt(n, tags, obj['condition'])} %endif - __${x}dlllocal ${x}_result_t ${X}_APICALL - ${th.make_func_name(n, tags, obj)}( + __${x}dlllocal ${x}_result_t ${X}_APICALL ${func_name}( %for line in th.make_param_lines(n, tags, obj): ${line} %endfor @@ -41,7 +65,16 @@ namespace ur_loader %>${th.get_initial_null_set(obj)} [[maybe_unused]] auto context = getContext(); - %if re.match(r"\w+AdapterGet$", th.make_func_name(n, tags, obj)): + %if func_basename == "EventSetCallback": + + // Replace the callback with a wrapper function that gives the callback the loader event rather than a + // backend-specific event + auto wrapper_data = + new event_callback_wrapper_data_t{pfnNotify, hEvent, pUserData}; + pUserData = wrapper_data; + pfnNotify = event_callback_wrapper; + %endif + %if func_basename == "AdapterGet": size_t adapterIndex = 0; if( nullptr != ${obj['params'][1]['name']} && ${obj['params'][0]['name']} !=0) @@ -74,7 +107,7 @@ namespace ur_loader *${obj['params'][2]['name']} = static_cast(context->platforms.size()); } - %elif re.match(r"\w+PlatformGet$", th.make_func_name(n, tags, obj)): + %elif func_basename == "PlatformGet": uint32_t total_platform_handle_count = 0; for( uint32_t adapter_index = 0; adapter_index < ${obj['params'][1]['name']}; adapter_index++) @@ -263,7 +296,7 @@ namespace ur_loader %for i, item in enumerate(epilogue): %if 0 == i and not item['release'] and not item['retain'] and not th.always_wrap_outputs(obj): ## TODO: Remove once we have a concrete way for submitting warnings in place. - %if re.match(r"urEnqueue\w+", th.make_func_name(n, tags, obj)): + %if re.match(r"Enqueue\w+", func_basename): // In the event of ERROR_ADAPTER_SPECIFIC we should still attempt to wrap any output handles below. if( ${X}_RESULT_SUCCESS != result && ${X}_RESULT_ERROR_ADAPTER_SPECIFIC != result ) return result; diff --git a/source/loader/ur_ldrddi.cpp b/source/loader/ur_ldrddi.cpp index 08e3a1b2a7..82d9fbe1c0 100644 --- a/source/loader/ur_ldrddi.cpp +++ b/source/loader/ur_ldrddi.cpp @@ -13,6 +13,7 @@ #include "ur_loader.hpp" namespace ur_loader { + /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urAdapterGet __urdlllocal ur_result_t UR_APICALL urAdapterGet( @@ -4476,6 +4477,22 @@ __urdlllocal ur_result_t UR_APICALL urEventCreateWithNativeHandle( return result; } +namespace { +struct event_callback_wrapper_data_t { + ur_event_callback_t fn; + ur_event_handle_t event; + void *userData; +}; + +void event_callback_wrapper([[maybe_unused]] ur_event_handle_t hEvent, + ur_execution_info_t execStatus, void *pUserData) { + auto *wrapper = + reinterpret_cast(pUserData); + (wrapper->fn)(wrapper->event, execStatus, wrapper->userData); + delete wrapper; +} +} // namespace + /////////////////////////////////////////////////////////////////////////////// /// @brief Intercept function for urEventSetCallback __urdlllocal ur_result_t UR_APICALL urEventSetCallback( @@ -4489,6 +4506,13 @@ __urdlllocal ur_result_t UR_APICALL urEventSetCallback( [[maybe_unused]] auto context = getContext(); + // Replace the callback with a wrapper function that gives the callback the loader event rather than a + // backend-specific event + auto wrapper_data = + new event_callback_wrapper_data_t{pfnNotify, hEvent, pUserData}; + pUserData = wrapper_data; + pfnNotify = event_callback_wrapper; + // extract platform's function pointer table auto dditable = reinterpret_cast(hEvent)->dditable; auto pfnSetCallback = dditable->ur.Event.pfnSetCallback; diff --git a/test/conformance/event/urEventSetCallback.cpp b/test/conformance/event/urEventSetCallback.cpp index ffcf70aff5..2279a6b2df 100644 --- a/test/conformance/event/urEventSetCallback.cpp +++ b/test/conformance/event/urEventSetCallback.cpp @@ -41,7 +41,7 @@ TEST_P(urEventSetCallbackTest, Success) { */ TEST_P(urEventSetCallbackTest, ValidateParameters) { UUR_KNOWN_FAILURE_ON(uur::CUDA{}, uur::HIP{}, uur::LevelZero{}, - uur::LevelZeroV2{}, uur::OpenCL{}, uur::NativeCPU{}); + uur::LevelZeroV2{}, uur::NativeCPU{}); struct CallbackParameters { ur_event_handle_t event;