Skip to content

Commit 77b705d

Browse files
committed
[OpenCL] Implement urEventSetCallback and urContextSetExtendedDeleter.
1 parent f0de2f4 commit 77b705d

File tree

3 files changed

+121
-13
lines changed

3 files changed

+121
-13
lines changed

source/adapters/opencl/context.cpp

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010

1111
#include "context.hpp"
1212

13+
#include <mutex>
14+
#include <set>
15+
#include <unordered_map>
16+
1317
ur_result_t cl_adapter::getDevicesFromContext(
1418
ur_context_handle_t hContext,
1519
std::unique_ptr<std::vector<cl_device_id>> &DevicesInCtx) {
@@ -130,8 +134,53 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextCreateWithNativeHandle(
130134
}
131135

132136
// Registers pfnDeleter (with pUserData) on the OpenCL context behind hContext
// through clSetContextDestructorCallback.
//
// Each (context, deleter) pair is registered at most once: a repeated call
// with the same pair is a no-op that returns UR_RESULT_SUCCESS. The pair is
// forgotten again after the deleter has run.
//
// Returns UR_RESULT_SUCCESS on success (or on a duplicate registration), or
// the result produced by CL_RETURN_ON_FAILURE when the OpenCL call fails. On
// failure nothing stays registered, so the caller may retry.
UR_APIEXPORT ur_result_t UR_APICALL urContextSetExtendedDeleter(
    ur_context_handle_t hContext, ur_context_extended_deleter_t pfnDeleter,
    void *pUserData) {
  using CallbackMapT =
      std::unordered_map<ur_context_handle_t,
                         std::set<ur_context_extended_deleter_t>>;
  static CallbackMapT ContextCallbackMap;
  static std::mutex ContextCallbackMutex;

  {
    std::lock_guard<std::mutex> Lock(ContextCallbackMutex);
    // Callbacks can only be registered once and we need to avoid double
    // allocating. insert() reports whether the deleter was newly added, so a
    // single lookup both checks for and records the registration.
    if (!ContextCallbackMap[hContext].insert(pfnDeleter).second) {
      return UR_RESULT_SUCCESS;
    }
  }

  // Heap-allocated payload handed to OpenCL as the callback's user data. It
  // deletes itself once the deleter has run.
  struct ContextCallback {
    // Removes this (context, deleter) pair from the bookkeeping map and drops
    // the context's entry once it has no deleters left. Uses find() so that a
    // missing entry is not re-created just to be erased again.
    void unregister() {
      std::lock_guard<std::mutex> Lock(*CallbackMutex);
      auto Entry = CallbackMap->find(hContext);
      if (Entry == CallbackMap->end()) {
        return;
      }
      Entry->second.erase(pfnDeleter);
      if (Entry->second.empty()) {
        CallbackMap->erase(Entry);
      }
    }

    // Runs the user deleter, clears the bookkeeping and frees this payload.
    void execute() {
      pfnDeleter(pUserData);
      unregister();
      delete this;
    }

    ur_context_handle_t hContext;
    ur_context_extended_deleter_t pfnDeleter;
    void *pUserData;
    CallbackMapT *CallbackMap;
    std::mutex *CallbackMutex;
  };

  auto *Callback =
      new ContextCallback{hContext, pfnDeleter, pUserData, &ContextCallbackMap,
                          &ContextCallbackMutex};
  // Captureless trampoline: converts to the plain function pointer OpenCL
  // expects and forwards to the payload.
  auto ClCallback = [](cl_context, void *pUserData) {
    auto *C = static_cast<ContextCallback *>(pUserData);
    C->execute();
  };

  const cl_int ClResult = clSetContextDestructorCallback(
      cl_adapter::cast<cl_context>(hContext), ClCallback, Callback);
  if (ClResult != CL_SUCCESS) {
    // OpenCL never took ownership of the payload: undo the bookkeeping and
    // free it, otherwise the payload leaks and every retry with the same
    // deleter would be silently treated as "already registered".
    Callback->unregister();
    delete Callback;
    // Reuse the adapter's macro so the error is translated exactly as before.
    CL_RETURN_ON_FAILURE(ClResult);
  }

  return UR_RESULT_SUCCESS;
}

source/adapters/opencl/enqueue.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -350,9 +350,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueReadHostPipe(
350350
return mapCLErrorToUR(CLErr);
351351
}
352352

353-
clEnqueueReadHostPipeINTEL_fn FuncPtr = nullptr;
353+
cl_ext::clEnqueueReadHostPipeINTEL_fn FuncPtr = nullptr;
354354
ur_result_t RetVal =
355-
cl_ext::getExtFuncFromContext<clEnqueueReadHostPipeINTEL_fn>(
355+
cl_ext::getExtFuncFromContext<cl_ext::clEnqueueReadHostPipeINTEL_fn>(
356356
CLContext, cl_ext::ExtFuncPtrCache->clEnqueueReadHostPipeINTELCache,
357357
cl_ext::EnqueueReadHostPipeName, &FuncPtr);
358358

@@ -382,9 +382,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueWriteHostPipe(
382382
return mapCLErrorToUR(CLErr);
383383
}
384384

385-
clEnqueueWriteHostPipeINTEL_fn FuncPtr = nullptr;
385+
cl_ext::clEnqueueWriteHostPipeINTEL_fn FuncPtr = nullptr;
386386
ur_result_t RetVal =
387-
cl_ext::getExtFuncFromContext<clEnqueueWriteHostPipeINTEL_fn>(
387+
cl_ext::getExtFuncFromContext<cl_ext::clEnqueueWriteHostPipeINTEL_fn>(
388388
CLContext, cl_ext::ExtFuncPtrCache->clEnqueueWriteHostPipeINTELCache,
389389
cl_ext::EnqueueWriteHostPipeName, &FuncPtr);
390390

source/adapters/opencl/event.cpp

Lines changed: 64 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010

1111
#include "common.hpp"
1212

13+
#include <mutex>
14+
#include <set>
15+
#include <unordered_map>
16+
1317
cl_event_info convertUREventInfoToCL(const ur_event_info_t PropName) {
1418
switch (PropName) {
1519
case UR_EVENT_INFO_COMMAND_QUEUE:
@@ -128,9 +132,64 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventGetProfilingInfo(
128132
// Registers pfnNotify (with pUserData) on the OpenCL event behind hEvent
// through clSetEventCallback, for the execution state named by execStatus.
//
// execStatus is translated to CL_SUBMITTED / CL_RUNNING / CL_COMPLETE; any
// other value yields UR_RESULT_ERROR_INVALID_ENUMERATION and registers
// nothing.
//
// Each (event, callback) pair is registered at most once: a repeated call with
// the same pair is a no-op that returns UR_RESULT_SUCCESS. The pair is
// forgotten again after the callback has run.
// NOTE(review): the de-duplication key does not include execStatus, so the
// same callback cannot be attached to two different states of one event —
// confirm this is intended.
//
// Returns UR_RESULT_SUCCESS on success (or on a duplicate registration), or
// the result produced by CL_RETURN_ON_FAILURE when the OpenCL call fails. On
// failure nothing stays registered, so the caller may retry.
UR_APIEXPORT ur_result_t UR_APICALL
urEventSetCallback(ur_event_handle_t hEvent, ur_execution_info_t execStatus,
                   ur_event_callback_t pfnNotify, void *pUserData) {
  using CallbackMapT =
      std::unordered_map<ur_event_handle_t, std::set<ur_event_callback_t>>;
  static CallbackMapT EventCallbackMap;
  static std::mutex EventCallbackMutex;

  // Validate and translate the requested state before touching the map, so a
  // rejected call cannot leave a stale entry that would block a later, valid
  // registration of the same callback.
  cl_int CallbackType = 0;
  switch (execStatus) {
  case UR_EXECUTION_INFO_EXECUTION_INFO_SUBMITTED:
    CallbackType = CL_SUBMITTED;
    break;
  case UR_EXECUTION_INFO_EXECUTION_INFO_RUNNING:
    CallbackType = CL_RUNNING;
    break;
  case UR_EXECUTION_INFO_EXECUTION_INFO_COMPLETE:
    CallbackType = CL_COMPLETE;
    break;
  default:
    return UR_RESULT_ERROR_INVALID_ENUMERATION;
  }

  {
    std::lock_guard<std::mutex> Lock(EventCallbackMutex);
    // Callbacks can only be registered once and we need to avoid double
    // allocating. insert() reports whether the callback was newly added, so a
    // single lookup both checks for and records the registration.
    if (!EventCallbackMap[hEvent].insert(pfnNotify).second) {
      return UR_RESULT_SUCCESS;
    }
  }

  // Heap-allocated payload handed to OpenCL as the callback's user data. It
  // deletes itself once the user callback has run.
  struct EventCallback {
    // Removes this (event, callback) pair from the bookkeeping map and drops
    // the event's entry once it has no callbacks left. Uses find() so that a
    // missing entry is not re-created just to be erased again.
    void unregister() {
      std::lock_guard<std::mutex> Lock(*CallbackMutex);
      auto Entry = CallbackMap->find(hEvent);
      if (Entry == CallbackMap->end()) {
        return;
      }
      Entry->second.erase(pfnNotify);
      if (Entry->second.empty()) {
        CallbackMap->erase(Entry);
      }
    }

    // Runs the user callback with the state it was registered for, clears the
    // bookkeeping and frees this payload.
    void execute() {
      pfnNotify(hEvent, execStatus, pUserData);
      unregister();
      delete this;
    }

    ur_event_handle_t hEvent;
    ur_execution_info_t execStatus;
    ur_event_callback_t pfnNotify;
    void *pUserData;
    CallbackMapT *CallbackMap;
    std::mutex *CallbackMutex;
  };

  auto *Callback = new EventCallback{hEvent,    execStatus,        pfnNotify,
                                     pUserData, &EventCallbackMap,
                                     &EventCallbackMutex};
  // Captureless trampoline: converts to the plain function pointer OpenCL
  // expects and forwards to the payload. The cl_int status reported by OpenCL
  // is ignored; the user callback receives the registered execStatus instead.
  auto ClCallback = [](cl_event, cl_int, void *pUserData) {
    auto *C = static_cast<EventCallback *>(pUserData);
    C->execute();
  };

  const cl_int ClResult = clSetEventCallback(
      cl_adapter::cast<cl_event>(hEvent), CallbackType, ClCallback, Callback);
  if (ClResult != CL_SUCCESS) {
    // OpenCL never took ownership of the payload: undo the bookkeeping and
    // free it, otherwise the payload leaks and every retry with the same
    // callback would be silently treated as "already registered".
    Callback->unregister();
    delete Callback;
    // Reuse the adapter's macro so the error is translated exactly as before.
    CL_RETURN_ON_FAILURE(ClResult);
  }

  return UR_RESULT_SUCCESS;
}

0 commit comments

Comments
 (0)