@@ -257,24 +257,39 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
257
257
&CopyEvent));
258
258
259
259
struct DeleteCallbackInfo {
260
+ DeleteCallbackInfo (clMemBlockingFreeINTEL_fn USMFree, cl_context CLContext,
261
+ void *HostBuffer)
262
+ : USMFree(USMFree), CLContext(CLContext), HostBuffer(HostBuffer) {
263
+ clRetainContext (CLContext);
264
+ }
265
+ ~DeleteCallbackInfo () {
266
+ USMFree (CLContext, HostBuffer);
267
+ clReleaseContext (CLContext);
268
+ }
269
+ DeleteCallbackInfo (const DeleteCallbackInfo &) = delete ;
270
+ DeleteCallbackInfo &operator =(const DeleteCallbackInfo &) = delete ;
271
+
260
272
clMemBlockingFreeINTEL_fn USMFree;
261
273
cl_context CLContext;
262
274
void *HostBuffer;
263
- void execute () {
264
- USMFree (CLContext, HostBuffer);
265
- delete this ;
266
- }
267
275
};
268
276
269
- auto Info = new DeleteCallbackInfo{ USMFree, CLContext, HostBuffer} ;
277
+ auto Info = new DeleteCallbackInfo ( USMFree, CLContext, HostBuffer) ;
270
278
271
279
auto DeleteCallback = [](cl_event, cl_int, void *pUserData) {
272
- static_cast <DeleteCallbackInfo *>(pUserData)->execute ();
280
+ auto Info = static_cast <DeleteCallbackInfo *>(pUserData);
281
+ delete Info;
273
282
};
274
283
275
- CL_RETURN_ON_FAILURE (
276
- clSetEventCallback (CopyEvent, CL_COMPLETE, DeleteCallback, Info));
277
-
284
+ ClErr = clSetEventCallback (CopyEvent, CL_COMPLETE, DeleteCallback, Info);
285
+ if (ClErr != CL_SUCCESS) {
286
+ // We can attempt to recover gracefully by attempting to wait for the copy
287
+ // to finish and deleting the info struct here.
288
+ clWaitForEvents (1 , &CopyEvent);
289
+ delete Info;
290
+ clReleaseEvent (CopyEvent);
291
+ CL_RETURN_ON_FAILURE (ClErr);
292
+ }
278
293
if (phEvent) {
279
294
*phEvent = cl_adapter::cast<ur_event_handle_t >(CopyEvent);
280
295
} else {
@@ -426,7 +441,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
426
441
return RetVal;
427
442
}
428
443
429
- std::vector<cl_event> Events;
444
+ std::vector<cl_event> Events (height) ;
430
445
for (size_t HeightIndex = 0 ; HeightIndex < height; HeightIndex++) {
431
446
cl_event Event = nullptr ;
432
447
auto ClResult =
@@ -435,7 +450,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
435
450
static_cast <const uint8_t *>(pSrc) + srcPitch * HeightIndex,
436
451
width, numEventsInWaitList,
437
452
cl_adapter::cast<const cl_event *>(phEventWaitList), &Event);
438
- Events. push_back ( Event) ;
453
+ Events[HeightIndex] = Event;
439
454
if (ClResult != CL_SUCCESS) {
440
455
for (const auto &E : Events) {
441
456
clReleaseEvent (E);
@@ -453,7 +468,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
453
468
Events.data (), cl_adapter::cast<cl_event *>(phEvent));
454
469
}
455
470
for (const auto &E : Events) {
456
- clReleaseEvent (E);
471
+ CL_RETURN_ON_FAILURE ( clReleaseEvent (E) );
457
472
}
458
473
CL_RETURN_ON_FAILURE (ClResult)
459
474
return UR_RESULT_SUCCESS;
0 commit comments