Skip to content

Commit 4b1dd79

Browse files
committed
Merge branch 'aaron/ClCallbackEntrypoints' into aaron/clCTSFixMegaBranch
2 parents 5ad3f0a + 77b705d commit 4b1dd79

File tree

2 files changed

+117
-9
lines changed

2 files changed

+117
-9
lines changed

source/adapters/opencl/context.cpp

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010

1111
#include "context.hpp"
1212

13+
#include <mutex>
14+
#include <set>
15+
#include <unordered_map>
16+
1317
ur_result_t cl_adapter::getDevicesFromContext(
1418
ur_context_handle_t hContext,
1519
std::unique_ptr<std::vector<cl_device_id>> &DevicesInCtx) {
@@ -130,8 +134,53 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextCreateWithNativeHandle(
130134
}
131135

132136
/// Registers \p pfnDeleter as a destructor callback on the native OpenCL
/// context behind \p hContext via clSetContextDestructorCallback.
///
/// A given deleter is tracked at most once per context: registering a
/// (context, deleter) pair that is still pending is a no-op that reports
/// success. The pair is forgotten again once the deleter has run.
///
/// \param hContext   Context whose native handle receives the callback.
/// \param pfnDeleter Function invoked with \p pUserData by the callback.
/// \param pUserData  Opaque pointer forwarded unchanged to \p pfnDeleter.
/// \return UR_RESULT_SUCCESS on success (including the already-registered
///         case), otherwise whatever CL_RETURN_ON_FAILURE yields for the
///         failing OpenCL call.
UR_APIEXPORT ur_result_t UR_APICALL urContextSetExtendedDeleter(
    ur_context_handle_t hContext, ur_context_extended_deleter_t pfnDeleter,
    void *pUserData) {
  using CallbackMapT =
      std::unordered_map<ur_context_handle_t,
                         std::set<ur_context_extended_deleter_t>>;
  static CallbackMapT ContextCallbackMap;
  static std::mutex ContextCallbackMutex;

  // Heap-allocated payload handed to OpenCL as the callback's user data. It
  // deletes itself after running, so nobody may touch it once OpenCL has
  // accepted the registration.
  struct ContextCallback {
    // Removes this (context, deleter) pair from the bookkeeping map and drops
    // the per-context set once it becomes empty. Uses find() so a missing
    // entry is never default-inserted.
    void unregister() {
      std::lock_guard<std::mutex> Lock(*CallbackMutex);
      auto It = CallbackMap->find(hContext);
      if (It == CallbackMap->end()) {
        return;
      }
      It->second.erase(pfnDeleter);
      if (It->second.empty()) {
        CallbackMap->erase(It);
      }
    }

    // Runs the user deleter, forgets the registration and frees the payload.
    void execute() {
      pfnDeleter(pUserData);
      unregister();
      delete this;
    }

    ur_context_handle_t hContext;
    ur_context_extended_deleter_t pfnDeleter;
    void *pUserData;
    CallbackMapT *CallbackMap;
    std::mutex *CallbackMutex;
  };

  {
    std::lock_guard<std::mutex> Lock(ContextCallbackMutex);
    // Callbacks can only be registered once and we need to avoid double
    // allocating. insert() reports whether the pair was already present.
    if (!ContextCallbackMap[hContext].insert(pfnDeleter).second) {
      return UR_RESULT_SUCCESS;
    }
  }

  auto *Callback =
      new ContextCallback{hContext, pfnDeleter, pUserData, &ContextCallbackMap,
                          &ContextCallbackMutex};
  auto ClCallback = [](cl_context, void *pUserData) {
    static_cast<ContextCallback *>(pUserData)->execute();
  };

  // CL_RETURN_ON_FAILURE returns from the enclosing function, so it is wrapped
  // in a lambda: that lets us roll back below instead of leaking the payload
  // and leaving a stale map entry that would block every later retry.
  const ur_result_t Result = [&]() -> ur_result_t {
    CL_RETURN_ON_FAILURE(clSetContextDestructorCallback(
        cl_adapter::cast<cl_context>(hContext), ClCallback, Callback));
    return UR_RESULT_SUCCESS;
  }();

  if (Result != UR_RESULT_SUCCESS) {
    // OpenCL rejected the registration, so the payload is still ours.
    Callback->unregister();
    delete Callback;
  }
  return Result;
}

source/adapters/opencl/event.cpp

Lines changed: 64 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010

1111
#include "common.hpp"
1212

13+
#include <mutex>
14+
#include <set>
15+
#include <unordered_map>
16+
1317
cl_event_info convertUREventInfoToCL(const ur_event_info_t PropName) {
1418
switch (PropName) {
1519
case UR_EVENT_INFO_COMMAND_QUEUE:
@@ -128,9 +132,64 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventGetProfilingInfo(
128132
/// Registers \p pfnNotify on the native OpenCL event behind \p hEvent via
/// clSetEventCallback, for the execution status named by \p execStatus.
///
/// A given callback function is tracked at most once per event: registering
/// an (event, callback) pair that is still pending is a no-op that reports
/// success. The pair is forgotten again once the callback has run.
///
/// \param hEvent     Event whose native handle receives the callback.
/// \param execStatus Status to be notified about; only SUBMITTED, RUNNING and
///                   COMPLETE are accepted.
/// \param pfnNotify  Function invoked as pfnNotify(hEvent, execStatus,
///                   pUserData) by the callback.
/// \param pUserData  Opaque pointer forwarded unchanged to \p pfnNotify.
/// \return UR_RESULT_SUCCESS on success (including the already-registered
///         case), UR_RESULT_ERROR_INVALID_ENUMERATION for an unsupported
///         \p execStatus, otherwise whatever CL_RETURN_ON_FAILURE yields for
///         the failing OpenCL call.
UR_APIEXPORT ur_result_t UR_APICALL
urEventSetCallback(ur_event_handle_t hEvent, ur_execution_info_t execStatus,
                   ur_event_callback_t pfnNotify, void *pUserData) {
  using CallbackMapT =
      std::unordered_map<ur_event_handle_t, std::set<ur_event_callback_t>>;
  static CallbackMapT EventCallbackMap;
  static std::mutex EventCallbackMutex;

  // Validate the status before touching the bookkeeping map. Otherwise an
  // invalid enumerator would leave a stale entry behind and every later,
  // valid registration of the same callback would be silently dropped.
  cl_int CallbackType = 0;
  switch (execStatus) {
  case UR_EXECUTION_INFO_EXECUTION_INFO_SUBMITTED:
    CallbackType = CL_SUBMITTED;
    break;
  case UR_EXECUTION_INFO_EXECUTION_INFO_RUNNING:
    CallbackType = CL_RUNNING;
    break;
  case UR_EXECUTION_INFO_EXECUTION_INFO_COMPLETE:
    CallbackType = CL_COMPLETE;
    break;
  default:
    return UR_RESULT_ERROR_INVALID_ENUMERATION;
  }

  // Heap-allocated payload handed to OpenCL as the callback's user data. It
  // deletes itself after running, so nobody may touch it once OpenCL has
  // accepted the registration.
  struct EventCallback {
    // Removes this (event, callback) pair from the bookkeeping map and drops
    // the per-event set once it becomes empty. Uses find() so a missing entry
    // is never default-inserted.
    void unregister() {
      std::lock_guard<std::mutex> Lock(*CallbackMutex);
      auto It = CallbackMap->find(hEvent);
      if (It == CallbackMap->end()) {
        return;
      }
      It->second.erase(pfnNotify);
      if (It->second.empty()) {
        CallbackMap->erase(It);
      }
    }

    // Runs the user callback with the status it was registered for, forgets
    // the registration and frees the payload.
    void execute() {
      pfnNotify(hEvent, execStatus, pUserData);
      unregister();
      delete this;
    }

    ur_event_handle_t hEvent;
    ur_execution_info_t execStatus;
    ur_event_callback_t pfnNotify;
    void *pUserData;
    CallbackMapT *CallbackMap;
    std::mutex *CallbackMutex;
  };

  {
    std::lock_guard<std::mutex> Lock(EventCallbackMutex);
    // Callbacks can only be registered once and we need to avoid double
    // allocating. insert() reports whether the pair was already present.
    // NOTE(review): the key ignores execStatus, so the same function cannot be
    // pending for two different statuses of one event - confirm this is
    // intended.
    if (!EventCallbackMap[hEvent].insert(pfnNotify).second) {
      return UR_RESULT_SUCCESS;
    }
  }

  auto *Callback = new EventCallback{hEvent,    execStatus,        pfnNotify,
                                     pUserData, &EventCallbackMap,
                                     &EventCallbackMutex};
  auto ClCallback = [](cl_event, cl_int, void *pUserData) {
    static_cast<EventCallback *>(pUserData)->execute();
  };

  // CL_RETURN_ON_FAILURE returns from the enclosing function, so it is wrapped
  // in a lambda: that lets us roll back below instead of leaking the payload
  // and leaving a stale map entry that would block every later retry.
  const ur_result_t Result = [&]() -> ur_result_t {
    CL_RETURN_ON_FAILURE(clSetEventCallback(cl_adapter::cast<cl_event>(hEvent),
                                            CallbackType, ClCallback,
                                            Callback));
    return UR_RESULT_SUCCESS;
  }();

  if (Result != UR_RESULT_SUCCESS) {
    // OpenCL rejected the registration, so the payload is still ours.
    Callback->unregister();
    delete Callback;
  }
  return Result;
}

0 commit comments

Comments
 (0)