11
11
#include < ur/ur.hpp>
12
12
13
13
#include " common.hpp"
14
+ #include " usm.hpp"
15
+
16
// Event callback that releases the triggering event and destroys the
// heap-allocated payload handed over through `pUserData`.
//
// Template parameter T is the concrete payload type; in this file it is
// instantiated with AllocDeleterCallbackInfo and
// AllocDeleterCallbackInfoWithQueue (presumably declared in usm.hpp; confirm).
//
// Parameters:
//   event     - the event the callback was registered on; one reference to it
//               is dropped here, so a caller that also returns the event to
//               the user must retain it beforehand.
//   (cl_int)  - unnamed status argument, not inspected.
//   pUserData - pointer to a T created with `new`; ownership passes to this
//               callback, which deletes it.
+ template <class T >
17
+ void AllocDeleterCallback (cl_event event, cl_int, void *pUserData) {
18
+ clReleaseEvent (event); // return value is not checked
19
+ auto Info = static_cast <T *>(pUserData);
20
+ delete Info; // runs T's destructor, which performs whatever cleanup T owns
21
+ }
14
22
15
23
namespace umf {
16
24
ur_result_t getProviderNativeError (const char *, int32_t ) {
@@ -312,32 +320,19 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
312
320
numEventsInWaitList, cl_adapter::cast<const cl_event *>(phEventWaitList),
313
321
&CopyEvent));
314
322
315
- struct DeleteCallbackInfo {
316
- DeleteCallbackInfo (clMemBlockingFreeINTEL_fn USMFree, cl_context CLContext,
317
- void *HostBuffer)
318
- : USMFree(USMFree), CLContext(CLContext), HostBuffer(HostBuffer) {
319
- clRetainContext (CLContext);
320
- }
321
- ~DeleteCallbackInfo () {
322
- USMFree (CLContext, HostBuffer);
323
- clReleaseContext (CLContext);
324
- }
325
- DeleteCallbackInfo (const DeleteCallbackInfo &) = delete ;
326
- DeleteCallbackInfo &operator =(const DeleteCallbackInfo &) = delete ;
327
-
328
- clMemBlockingFreeINTEL_fn USMFree;
329
- cl_context CLContext;
330
- void *HostBuffer;
331
- };
332
-
333
- auto Info = new DeleteCallbackInfo (USMFree, CLContext, HostBuffer);
323
+ if (phEvent) {
324
+ // Since we're releasing this in the callback above we need to retain it
325
+ // here to keep the user copy alive.
326
+ CL_RETURN_ON_FAILURE (clRetainEvent (CopyEvent));
327
+ *phEvent = cl_adapter::cast<ur_event_handle_t >(CopyEvent);
328
+ }
334
329
335
- auto DeleteCallback = [](cl_event, cl_int, void *pUserData) {
336
- auto Info = static_cast <DeleteCallbackInfo *>(pUserData);
337
- delete Info;
338
- };
330
+ // This self destructs taking the event and allocation with it.
331
+ auto Info = new AllocDeleterCallbackInfo (USMFree, CLContext, HostBuffer);
339
332
340
- ClErr = clSetEventCallback (CopyEvent, CL_COMPLETE, DeleteCallback, Info);
333
+ ClErr =
334
+ clSetEventCallback (CopyEvent, CL_COMPLETE,
335
+ AllocDeleterCallback<AllocDeleterCallbackInfo>, Info);
341
336
if (ClErr != CL_SUCCESS) {
342
337
// We can attempt to recover gracefully by attempting to wait for the copy
343
338
// to finish and deleting the info struct here.
@@ -346,11 +341,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
346
341
clReleaseEvent (CopyEvent);
347
342
CL_RETURN_ON_FAILURE (ClErr);
348
343
}
349
- if (phEvent) {
350
- *phEvent = cl_adapter::cast<ur_event_handle_t >(CopyEvent);
351
- } else {
352
- CL_RETURN_ON_FAILURE (clReleaseEvent (CopyEvent));
353
- }
354
344
355
345
return UR_RESULT_SUCCESS;
356
346
}
@@ -369,20 +359,131 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy(
369
359
return mapCLErrorToUR (CLErr);
370
360
}
371
361
372
- clEnqueueMemcpyINTEL_fn FuncPtr = nullptr ;
373
- ur_result_t RetVal = cl_ext::getExtFuncFromContext<clEnqueueMemcpyINTEL_fn>(
362
+ clGetMemAllocInfoINTEL_fn GetMemAllocInfo = nullptr ;
363
+ UR_RETURN_ON_FAILURE (cl_ext::getExtFuncFromContext<clGetMemAllocInfoINTEL_fn>(
364
+ CLContext, cl_ext::ExtFuncPtrCache->clGetMemAllocInfoINTELCache ,
365
+ cl_ext::GetMemAllocInfoName, &GetMemAllocInfo));
366
+
367
+ clEnqueueMemcpyINTEL_fn USMMemcpy = nullptr ;
368
+ UR_RETURN_ON_FAILURE (cl_ext::getExtFuncFromContext<clEnqueueMemcpyINTEL_fn>(
374
369
CLContext, cl_ext::ExtFuncPtrCache->clEnqueueMemcpyINTELCache ,
375
- cl_ext::EnqueueMemcpyName, &FuncPtr );
370
+ cl_ext::EnqueueMemcpyName, &USMMemcpy) );
376
371
377
- if (FuncPtr) {
378
- RetVal = mapCLErrorToUR (
379
- FuncPtr (cl_adapter::cast<cl_command_queue>(hQueue), blocking, pDst,
380
- pSrc, size, numEventsInWaitList,
381
- cl_adapter::cast<const cl_event *>(phEventWaitList),
382
- cl_adapter::cast<cl_event *>(phEvent)));
372
+ clMemBlockingFreeINTEL_fn USMFree = nullptr ;
373
+ UR_RETURN_ON_FAILURE (cl_ext::getExtFuncFromContext<clMemBlockingFreeINTEL_fn>(
374
+ CLContext, cl_ext::ExtFuncPtrCache->clMemBlockingFreeINTELCache ,
375
+ cl_ext::MemBlockingFreeName, &USMFree));
376
+
377
+ // Check if the two allocations are DEVICE allocations from different
378
+ // devices, if they are we need to do the copy indirectly via a host
379
+ // allocation.
380
+ cl_device_id SrcDevice = 0 , DstDevice = 0 ;
381
+ CL_RETURN_ON_FAILURE (
382
+ GetMemAllocInfo (CLContext, pSrc, CL_MEM_ALLOC_DEVICE_INTEL,
383
+ sizeof (cl_device_id), &SrcDevice, nullptr ));
384
+ CL_RETURN_ON_FAILURE (
385
+ GetMemAllocInfo (CLContext, pDst, CL_MEM_ALLOC_DEVICE_INTEL,
386
+ sizeof (cl_device_id), &DstDevice, nullptr ));
387
+
388
+ if ((SrcDevice && DstDevice) && SrcDevice != DstDevice) {
389
+ // We need a queue associated with each device, so first figure out which
390
+ // one we weren't given.
391
+ cl_device_id QueueDevice = nullptr ;
392
+ CL_RETURN_ON_FAILURE (clGetCommandQueueInfo (
393
+ cl_adapter::cast<cl_command_queue>(hQueue), CL_QUEUE_DEVICE,
394
+ sizeof (QueueDevice), &QueueDevice, nullptr ));
395
+
396
+ cl_command_queue MissingQueue = nullptr , SrcQueue = nullptr ,
397
+ DstQueue = nullptr ;
398
+ if (QueueDevice == SrcDevice) {
399
+ MissingQueue = clCreateCommandQueue (CLContext, DstDevice, 0 , &CLErr);
400
+ SrcQueue = cl_adapter::cast<cl_command_queue>(hQueue);
401
+ DstQueue = MissingQueue;
402
+ } else {
403
+ MissingQueue = clCreateCommandQueue (CLContext, SrcDevice, 0 , &CLErr);
404
+ DstQueue = cl_adapter::cast<cl_command_queue>(hQueue);
405
+ SrcQueue = MissingQueue;
406
+ }
407
+ CL_RETURN_ON_FAILURE (CLErr);
408
+
409
+ cl_event HostCopyEvent = nullptr , FinalCopyEvent = nullptr ;
410
+ clHostMemAllocINTEL_fn HostMemAlloc = nullptr ;
411
+ UR_RETURN_ON_FAILURE (cl_ext::getExtFuncFromContext<clHostMemAllocINTEL_fn>(
412
+ CLContext, cl_ext::ExtFuncPtrCache->clHostMemAllocINTELCache ,
413
+ cl_ext::HostMemAllocName, &HostMemAlloc));
414
+
415
+ auto HostAlloc = HostMemAlloc (CLContext, nullptr , size, 0 , &CLErr);
416
+ CL_RETURN_ON_FAILURE (CLErr);
417
+
418
+ // Now that we've successfully allocated we should try to clean it up if we
419
+ // hit an error somewhere.
420
+ auto checkCLErr = [&](cl_int CLErr) -> ur_result_t {
421
+ if (CLErr != CL_SUCCESS) {
422
+ if (HostCopyEvent) {
423
+ clReleaseEvent (HostCopyEvent);
424
+ }
425
+ if (FinalCopyEvent) {
426
+ clReleaseEvent (FinalCopyEvent);
427
+ }
428
+ USMFree (CLContext, HostAlloc);
429
+ CL_RETURN_ON_FAILURE (CLErr);
430
+ }
431
+ return UR_RESULT_SUCCESS;
432
+ };
433
+
434
+ UR_RETURN_ON_FAILURE (checkCLErr (USMMemcpy (
435
+ SrcQueue, blocking, HostAlloc, pSrc, size, numEventsInWaitList,
436
+ cl_adapter::cast<const cl_event *>(phEventWaitList), &HostCopyEvent)));
437
+
438
+ UR_RETURN_ON_FAILURE (
439
+ checkCLErr (USMMemcpy (DstQueue, blocking, pDst, HostAlloc, size, 1 ,
440
+ &HostCopyEvent, &FinalCopyEvent)));
441
+
442
+ // If this is a blocking operation we can do our cleanup immediately,
443
+ // otherwise we need to defer it to an event callback.
444
+ if (blocking) {
445
+ CL_RETURN_ON_FAILURE (USMFree (CLContext, HostAlloc));
446
+ CL_RETURN_ON_FAILURE (clReleaseEvent (HostCopyEvent));
447
+ CL_RETURN_ON_FAILURE (clReleaseCommandQueue (MissingQueue));
448
+ if (phEvent) {
449
+ *phEvent = cl_adapter::cast<ur_event_handle_t >(FinalCopyEvent);
450
+ } else {
451
+ CL_RETURN_ON_FAILURE (clReleaseEvent (FinalCopyEvent));
452
+ }
453
+ } else {
454
+ if (phEvent) {
455
+ *phEvent = cl_adapter::cast<ur_event_handle_t >(FinalCopyEvent);
456
+ // We are going to release this event in our callback so we need to
457
+ // retain if the user wants a copy.
458
+ CL_RETURN_ON_FAILURE (clRetainEvent (FinalCopyEvent));
459
+ }
460
+
461
+ // This self destructs taking the event and allocation with it.
462
+ auto DeleterInfo = new AllocDeleterCallbackInfoWithQueue (
463
+ USMFree, CLContext, HostAlloc, MissingQueue);
464
+
465
+ CLErr = clSetEventCallback (
466
+ HostCopyEvent, CL_COMPLETE,
467
+ AllocDeleterCallback<AllocDeleterCallbackInfoWithQueue>, DeleterInfo);
468
+
469
+ if (CLErr != CL_SUCCESS) {
470
+ // We can attempt to recover gracefully by attempting to wait for the
471
+ // copy to finish and deleting the info struct here.
472
+ clWaitForEvents (1 , &HostCopyEvent);
473
+ delete DeleterInfo;
474
+ clReleaseEvent (HostCopyEvent);
475
+ CL_RETURN_ON_FAILURE (CLErr);
476
+ }
477
+ }
478
+ } else {
479
+ CL_RETURN_ON_FAILURE (
480
+ USMMemcpy (cl_adapter::cast<cl_command_queue>(hQueue), blocking, pDst,
481
+ pSrc, size, numEventsInWaitList,
482
+ cl_adapter::cast<const cl_event *>(phEventWaitList),
483
+ cl_adapter::cast<cl_event *>(phEvent)));
383
484
}
384
485
385
- return RetVal ;
486
+ return UR_RESULT_SUCCESS ;
386
487
}
387
488
388
489
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch (
0 commit comments