|
10 | 10 |
|
11 | 11 | #include "common.hpp"
|
12 | 12 |
|
| 13 | +#include <mutex> |
| 14 | +#include <set> |
| 15 | +#include <unordered_map> |
| 16 | + |
13 | 17 | cl_event_info convertUREventInfoToCL(const ur_event_info_t PropName) {
|
14 | 18 | switch (PropName) {
|
15 | 19 | case UR_EVENT_INFO_COMMAND_QUEUE:
|
@@ -128,9 +132,64 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventGetProfilingInfo(
|
128 | 132 | UR_APIEXPORT ur_result_t UR_APICALL
|
129 | 133 | urEventSetCallback(ur_event_handle_t hEvent, ur_execution_info_t execStatus,
|
130 | 134 | ur_event_callback_t pfnNotify, void *pUserData) {
|
131 |
| - std::ignore = hEvent; |
132 |
| - std::ignore = execStatus; |
133 |
| - std::ignore = pfnNotify; |
134 |
| - std::ignore = pUserData; |
135 |
| - return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; |
| 135 | + static std::unordered_map<ur_event_handle_t, std::set<ur_event_callback_t>> |
| 136 | + EventCallbackMap; |
| 137 | + static std::mutex EventCallbackMutex; |
| 138 | + |
| 139 | + { |
| 140 | + std::lock_guard<std::mutex> Lock(EventCallbackMutex); |
| 141 | + // Callbacks can only be registered once and we need to avoid double |
| 142 | + // allocating. |
| 143 | + if (EventCallbackMap.count(hEvent) && |
| 144 | + EventCallbackMap[hEvent].count(pfnNotify)) { |
| 145 | + return UR_RESULT_SUCCESS; |
| 146 | + } |
| 147 | + |
| 148 | + EventCallbackMap[hEvent].insert(pfnNotify); |
| 149 | + } |
| 150 | + |
| 151 | + cl_int CallbackType = 0; |
| 152 | + switch (execStatus) { |
| 153 | + case UR_EXECUTION_INFO_EXECUTION_INFO_SUBMITTED: |
| 154 | + CallbackType = CL_SUBMITTED; |
| 155 | + break; |
| 156 | + case UR_EXECUTION_INFO_EXECUTION_INFO_RUNNING: |
| 157 | + CallbackType = CL_RUNNING; |
| 158 | + break; |
| 159 | + case UR_EXECUTION_INFO_EXECUTION_INFO_COMPLETE: |
| 160 | + CallbackType = CL_COMPLETE; |
| 161 | + break; |
| 162 | + default: |
| 163 | + return UR_RESULT_ERROR_INVALID_ENUMERATION; |
| 164 | + } |
| 165 | + |
| 166 | + struct EventCallback { |
| 167 | + void execute() { |
| 168 | + pfnNotify(hEvent, execStatus, pUserData); |
| 169 | + { |
| 170 | + std::lock_guard<std::mutex> Lock(*CallbackMutex); |
| 171 | + (*CallbackMap)[hEvent].erase(pfnNotify); |
| 172 | + if ((*CallbackMap)[hEvent].empty()) { |
| 173 | + CallbackMap->erase(hEvent); |
| 174 | + } |
| 175 | + } |
| 176 | + delete this; |
| 177 | + } |
| 178 | + ur_event_handle_t hEvent; |
| 179 | + ur_execution_info_t execStatus; |
| 180 | + ur_event_callback_t pfnNotify; |
| 181 | + void *pUserData; |
| 182 | + std::unordered_map<ur_event_handle_t, std::set<ur_event_callback_t>> |
| 183 | + *CallbackMap; |
| 184 | + std::mutex *CallbackMutex; |
| 185 | + }; |
| 186 | + auto Callback = new EventCallback({hEvent, execStatus, pfnNotify, pUserData, |
| 187 | + &EventCallbackMap, &EventCallbackMutex}); |
| 188 | + auto ClCallback = [](cl_event, cl_int, void *pUserData) { |
| 189 | + auto *C = static_cast<EventCallback *>(pUserData); |
| 190 | + C->execute(); |
| 191 | + }; |
| 192 | + CL_RETURN_ON_FAILURE(clSetEventCallback(cl_adapter::cast<cl_event>(hEvent), |
| 193 | + CallbackType, ClCallback, Callback)); |
| 194 | + return UR_RESULT_SUCCESS; |
136 | 195 | }
|
0 commit comments