Commit 809e853

Merge pull request #2001 from aarongreig/aaron/supportCLInterDeviceCopy
Support device to device USM copies with the CL adapter.
2 parents: 9d42e93 + 78c33ce
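
For context, this change makes the OpenCL adapter's urEnqueueUSMMemcpy detect when the source and destination are device USM allocations owned by two different devices in the same context, and route the copy through a temporary host allocation. Below is a minimal sketch of the call pattern this enables at the UR API level, mirroring the new conformance test further down; the helper name is made up, and handle setup, adapter initialization, and error checking are omitted.

#include <ur_api.h>

// Sketch only: `ctx` is assumed to span two USM-capable devices `dev0` and
// `dev1`, and `queue` is a queue on either of them.
void DeviceToDeviceCopySketch(ur_context_handle_t ctx, ur_device_handle_t dev0,
                              ur_device_handle_t dev1, ur_queue_handle_t queue,
                              size_t size) {
  void *src = nullptr, *dst = nullptr;
  // Device allocations that live on two different devices.
  urUSMDeviceAlloc(ctx, dev0, nullptr, nullptr, size, &src);
  urUSMDeviceAlloc(ctx, dev1, nullptr, nullptr, size, &dst);
  // With this commit the CL adapter accepts this cross-device copy and stages
  // it through a host allocation internally.
  urEnqueueUSMMemcpy(queue, /*blocking=*/true, dst, src, size, 0, nullptr,
                     nullptr);
  urUSMFree(ctx, src);
  urUSMFree(ctx, dst);
}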

4 files changed: +299 additions, −39 deletions

source/adapters/opencl/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -37,6 +37,7 @@ add_ur_adapter(${TARGET_NAME} SHARED
   ${CMAKE_CURRENT_SOURCE_DIR}/program.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/queue.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/sampler.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/usm.hpp
   ${CMAKE_CURRENT_SOURCE_DIR}/usm.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/usm_p2p.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/virtual_mem.cpp

source/adapters/opencl/usm.cpp

Lines changed: 140 additions & 39 deletions
@@ -11,6 +11,14 @@
 #include <ur/ur.hpp>
 
 #include "common.hpp"
+#include "usm.hpp"
+
+template <class T>
+void AllocDeleterCallback(cl_event event, cl_int, void *pUserData) {
+  clReleaseEvent(event);
+  auto Info = static_cast<T *>(pUserData);
+  delete Info;
+}
 
 namespace umf {
 ur_result_t getProviderNativeError(const char *, int32_t) {
@@ -312,32 +320,19 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
       numEventsInWaitList, cl_adapter::cast<const cl_event *>(phEventWaitList),
       &CopyEvent));
 
-  struct DeleteCallbackInfo {
-    DeleteCallbackInfo(clMemBlockingFreeINTEL_fn USMFree, cl_context CLContext,
-                       void *HostBuffer)
-        : USMFree(USMFree), CLContext(CLContext), HostBuffer(HostBuffer) {
-      clRetainContext(CLContext);
-    }
-    ~DeleteCallbackInfo() {
-      USMFree(CLContext, HostBuffer);
-      clReleaseContext(CLContext);
-    }
-    DeleteCallbackInfo(const DeleteCallbackInfo &) = delete;
-    DeleteCallbackInfo &operator=(const DeleteCallbackInfo &) = delete;
-
-    clMemBlockingFreeINTEL_fn USMFree;
-    cl_context CLContext;
-    void *HostBuffer;
-  };
-
-  auto Info = new DeleteCallbackInfo(USMFree, CLContext, HostBuffer);
+  if (phEvent) {
+    // Since we're releasing this in the callback above we need to retain it
+    // here to keep the user copy alive.
+    CL_RETURN_ON_FAILURE(clRetainEvent(CopyEvent));
+    *phEvent = cl_adapter::cast<ur_event_handle_t>(CopyEvent);
+  }
 
-  auto DeleteCallback = [](cl_event, cl_int, void *pUserData) {
-    auto Info = static_cast<DeleteCallbackInfo *>(pUserData);
-    delete Info;
-  };
+  // This self destructs taking the event and allocation with it.
+  auto Info = new AllocDeleterCallbackInfo(USMFree, CLContext, HostBuffer);
 
-  ClErr = clSetEventCallback(CopyEvent, CL_COMPLETE, DeleteCallback, Info);
+  ClErr =
+      clSetEventCallback(CopyEvent, CL_COMPLETE,
+                         AllocDeleterCallback<AllocDeleterCallbackInfo>, Info);
   if (ClErr != CL_SUCCESS) {
     // We can attempt to recover gracefully by attempting to wait for the copy
     // to finish and deleting the info struct here.
@@ -346,11 +341,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
     clReleaseEvent(CopyEvent);
     CL_RETURN_ON_FAILURE(ClErr);
   }
-  if (phEvent) {
-    *phEvent = cl_adapter::cast<ur_event_handle_t>(CopyEvent);
-  } else {
-    CL_RETURN_ON_FAILURE(clReleaseEvent(CopyEvent));
-  }
 
   return UR_RESULT_SUCCESS;
 }
@@ -369,20 +359,131 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy(
     return mapCLErrorToUR(CLErr);
   }
 
-  clEnqueueMemcpyINTEL_fn FuncPtr = nullptr;
-  ur_result_t RetVal = cl_ext::getExtFuncFromContext<clEnqueueMemcpyINTEL_fn>(
+  clGetMemAllocInfoINTEL_fn GetMemAllocInfo = nullptr;
+  UR_RETURN_ON_FAILURE(cl_ext::getExtFuncFromContext<clGetMemAllocInfoINTEL_fn>(
+      CLContext, cl_ext::ExtFuncPtrCache->clGetMemAllocInfoINTELCache,
+      cl_ext::GetMemAllocInfoName, &GetMemAllocInfo));
+
+  clEnqueueMemcpyINTEL_fn USMMemcpy = nullptr;
+  UR_RETURN_ON_FAILURE(cl_ext::getExtFuncFromContext<clEnqueueMemcpyINTEL_fn>(
       CLContext, cl_ext::ExtFuncPtrCache->clEnqueueMemcpyINTELCache,
-      cl_ext::EnqueueMemcpyName, &FuncPtr);
+      cl_ext::EnqueueMemcpyName, &USMMemcpy));
 
-  if (FuncPtr) {
-    RetVal = mapCLErrorToUR(
-        FuncPtr(cl_adapter::cast<cl_command_queue>(hQueue), blocking, pDst,
-                pSrc, size, numEventsInWaitList,
-                cl_adapter::cast<const cl_event *>(phEventWaitList),
-                cl_adapter::cast<cl_event *>(phEvent)));
+  clMemBlockingFreeINTEL_fn USMFree = nullptr;
+  UR_RETURN_ON_FAILURE(cl_ext::getExtFuncFromContext<clMemBlockingFreeINTEL_fn>(
+      CLContext, cl_ext::ExtFuncPtrCache->clMemBlockingFreeINTELCache,
+      cl_ext::MemBlockingFreeName, &USMFree));
+
+  // Check if the two allocations are DEVICE allocations from different
+  // devices, if they are we need to do the copy indirectly via a host
+  // allocation.
+  cl_device_id SrcDevice = 0, DstDevice = 0;
+  CL_RETURN_ON_FAILURE(
+      GetMemAllocInfo(CLContext, pSrc, CL_MEM_ALLOC_DEVICE_INTEL,
+                      sizeof(cl_device_id), &SrcDevice, nullptr));
+  CL_RETURN_ON_FAILURE(
+      GetMemAllocInfo(CLContext, pDst, CL_MEM_ALLOC_DEVICE_INTEL,
+                      sizeof(cl_device_id), &DstDevice, nullptr));
+
+  if ((SrcDevice && DstDevice) && SrcDevice != DstDevice) {
+    // We need a queue associated with each device, so first figure out which
+    // one we weren't given.
+    cl_device_id QueueDevice = nullptr;
+    CL_RETURN_ON_FAILURE(clGetCommandQueueInfo(
+        cl_adapter::cast<cl_command_queue>(hQueue), CL_QUEUE_DEVICE,
+        sizeof(QueueDevice), &QueueDevice, nullptr));
+
+    cl_command_queue MissingQueue = nullptr, SrcQueue = nullptr,
+                     DstQueue = nullptr;
+    if (QueueDevice == SrcDevice) {
+      MissingQueue = clCreateCommandQueue(CLContext, DstDevice, 0, &CLErr);
+      SrcQueue = cl_adapter::cast<cl_command_queue>(hQueue);
+      DstQueue = MissingQueue;
+    } else {
+      MissingQueue = clCreateCommandQueue(CLContext, SrcDevice, 0, &CLErr);
+      DstQueue = cl_adapter::cast<cl_command_queue>(hQueue);
+      SrcQueue = MissingQueue;
+    }
+    CL_RETURN_ON_FAILURE(CLErr);
+
+    cl_event HostCopyEvent = nullptr, FinalCopyEvent = nullptr;
+    clHostMemAllocINTEL_fn HostMemAlloc = nullptr;
+    UR_RETURN_ON_FAILURE(cl_ext::getExtFuncFromContext<clHostMemAllocINTEL_fn>(
+        CLContext, cl_ext::ExtFuncPtrCache->clHostMemAllocINTELCache,
+        cl_ext::HostMemAllocName, &HostMemAlloc));
+
+    auto HostAlloc = HostMemAlloc(CLContext, nullptr, size, 0, &CLErr);
+    CL_RETURN_ON_FAILURE(CLErr);
+
+    // Now that we've successfully allocated we should try to clean it up if we
+    // hit an error somewhere.
+    auto checkCLErr = [&](cl_int CLErr) -> ur_result_t {
+      if (CLErr != CL_SUCCESS) {
+        if (HostCopyEvent) {
+          clReleaseEvent(HostCopyEvent);
+        }
+        if (FinalCopyEvent) {
+          clReleaseEvent(FinalCopyEvent);
+        }
+        USMFree(CLContext, HostAlloc);
+        CL_RETURN_ON_FAILURE(CLErr);
+      }
+      return UR_RESULT_SUCCESS;
+    };
+
+    UR_RETURN_ON_FAILURE(checkCLErr(USMMemcpy(
+        SrcQueue, blocking, HostAlloc, pSrc, size, numEventsInWaitList,
+        cl_adapter::cast<const cl_event *>(phEventWaitList), &HostCopyEvent)));
+
+    UR_RETURN_ON_FAILURE(
+        checkCLErr(USMMemcpy(DstQueue, blocking, pDst, HostAlloc, size, 1,
+                             &HostCopyEvent, &FinalCopyEvent)));
+
+    // If this is a blocking operation we can do our cleanup immediately,
+    // otherwise we need to defer it to an event callback.
+    if (blocking) {
+      CL_RETURN_ON_FAILURE(USMFree(CLContext, HostAlloc));
+      CL_RETURN_ON_FAILURE(clReleaseEvent(HostCopyEvent));
+      CL_RETURN_ON_FAILURE(clReleaseCommandQueue(MissingQueue));
+      if (phEvent) {
+        *phEvent = cl_adapter::cast<ur_event_handle_t>(FinalCopyEvent);
+      } else {
+        CL_RETURN_ON_FAILURE(clReleaseEvent(FinalCopyEvent));
+      }
+    } else {
+      if (phEvent) {
+        *phEvent = cl_adapter::cast<ur_event_handle_t>(FinalCopyEvent);
+        // We are going to release this event in our callback so we need to
+        // retain if the user wants a copy.
+        CL_RETURN_ON_FAILURE(clRetainEvent(FinalCopyEvent));
+      }
+
+      // This self destructs taking the event and allocation with it.
+      auto DeleterInfo = new AllocDeleterCallbackInfoWithQueue(
+          USMFree, CLContext, HostAlloc, MissingQueue);
+
+      CLErr = clSetEventCallback(
+          HostCopyEvent, CL_COMPLETE,
+          AllocDeleterCallback<AllocDeleterCallbackInfoWithQueue>, DeleterInfo);
+
+      if (CLErr != CL_SUCCESS) {
+        // We can attempt to recover gracefully by attempting to wait for the
+        // copy to finish and deleting the info struct here.
+        clWaitForEvents(1, &HostCopyEvent);
+        delete DeleterInfo;
+        clReleaseEvent(HostCopyEvent);
+        CL_RETURN_ON_FAILURE(CLErr);
+      }
+    }
+  } else {
+    CL_RETURN_ON_FAILURE(
+        USMMemcpy(cl_adapter::cast<cl_command_queue>(hQueue), blocking, pDst,
+                  pSrc, size, numEventsInWaitList,
+                  cl_adapter::cast<const cl_event *>(phEventWaitList),
+                  cl_adapter::cast<cl_event *>(phEvent)));
   }
 
-  return RetVal;
+  return UR_RESULT_SUCCESS;
 }
 
 UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(

source/adapters/opencl/usm.hpp

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+//===--------------------- usm.hpp - OpenCL Adapter -----------------------===//
+//
+// Copyright (C) 2024 Intel Corporation
+//
+// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM
+// Exceptions. See LICENSE.TXT
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "CL/cl_ext.h"
+#include <CL/cl.h>
+
+// This struct is intended to be used in conjunction with the below callback via
+// clSetEventCallback to release temporary allocations created by the adapter to
+// implement certain USM operations.
+//
+// Example usage:
+//
+// auto Info = new AllocDeleterCallbackInfo(USMFreeFuncPtr, Context,
+// Allocation); clSetEventCallback(USMOpEvent, CL_COMPLETE,
+// AllocDeleterCallback, Info);
+struct AllocDeleterCallbackInfo {
+  AllocDeleterCallbackInfo(clMemBlockingFreeINTEL_fn USMFree,
+                           cl_context CLContext, void *Allocation)
+      : USMFree(USMFree), CLContext(CLContext), Allocation(Allocation) {
+    clRetainContext(CLContext);
+  }
+  ~AllocDeleterCallbackInfo() {
+    USMFree(CLContext, Allocation);
+    clReleaseContext(CLContext);
+  }
+  AllocDeleterCallbackInfo(const AllocDeleterCallbackInfo &) = delete;
+  AllocDeleterCallbackInfo &
+  operator=(const AllocDeleterCallbackInfo &) = delete;
+
+  clMemBlockingFreeINTEL_fn USMFree;
+  cl_context CLContext;
+  void *Allocation;
+};
+
+struct AllocDeleterCallbackInfoWithQueue : AllocDeleterCallbackInfo {
+  AllocDeleterCallbackInfoWithQueue(clMemBlockingFreeINTEL_fn USMFree,
+                                    cl_context CLContext, void *Allocation,
+                                    cl_command_queue CLQueue)
+      : AllocDeleterCallbackInfo(USMFree, CLContext, Allocation),
+        CLQueue(CLQueue) {
+    clRetainContext(CLContext);
+  }
+  ~AllocDeleterCallbackInfoWithQueue() { clReleaseCommandQueue(CLQueue); }
+  AllocDeleterCallbackInfoWithQueue(const AllocDeleterCallbackInfoWithQueue &) =
+      delete;
+  AllocDeleterCallbackInfoWithQueue &
+  operator=(const AllocDeleterCallbackInfoWithQueue &) = delete;
+
+  cl_command_queue CLQueue;
+};
+
+template <class T>
+void AllocDeleterCallback(cl_event event, cl_int, void *pUserData);

test/conformance/enqueue/urEnqueueUSMMemcpy.cpp

Lines changed: 98 additions & 0 deletions
@@ -167,3 +167,101 @@ TEST_P(urEnqueueUSMMemcpyTest, InvalidNullPtrEventWaitList) {
 }
 
 UUR_INSTANTIATE_DEVICE_TEST_SUITE_P(urEnqueueUSMMemcpyTest);
+
+struct urEnqueueUSMMemcpyMultiDeviceTest : uur::urAllDevicesTest {
+  void SetUp() override {
+    uur::urAllDevicesTest::SetUp();
+    for (auto &device : devices) {
+      ur_device_usm_access_capability_flags_t device_usm = 0;
+      ASSERT_SUCCESS(uur::GetDeviceUSMDeviceSupport(device, device_usm));
+      if (device_usm) {
+        usm_devices.push_back(device);
+        if (usm_devices.size() == 2) {
+          break;
+        }
+      }
+    }
+
+    if (usm_devices.size() < 2) {
+      GTEST_SKIP() << "Not enough devices in platform with USM support";
+    }
+
+    ASSERT_SUCCESS(urContextCreate(usm_devices.size(), usm_devices.data(),
+                                   nullptr, &context));
+    ASSERT_SUCCESS(
+        urQueueCreate(context, usm_devices[0], nullptr, &src_queue));
+    ASSERT_SUCCESS(
+        urQueueCreate(context, usm_devices[1], nullptr, &dst_queue));
+
+    ASSERT_SUCCESS(
+        urUSMHostAlloc(context, nullptr, nullptr, alloc_size, &host_alloc));
+    ASSERT_SUCCESS(urUSMDeviceAlloc(context, usm_devices[0], nullptr,
+                                    nullptr, alloc_size, &src_alloc));
+    ASSERT_SUCCESS(urUSMDeviceAlloc(context, usm_devices[1], nullptr,
+                                    nullptr, alloc_size, &dst_alloc));
+
+    ASSERT_SUCCESS(urEnqueueUSMFill(src_queue, src_alloc,
+                                    sizeof(fill_pattern), &fill_pattern,
+                                    alloc_size, 0, nullptr, nullptr));
+    ASSERT_SUCCESS(urQueueFinish(src_queue));
+  }
+
+  void TearDown() override {
+    if (src_alloc) {
+      ASSERT_SUCCESS(urUSMFree(context, src_alloc));
+    }
+    if (dst_alloc) {
+      ASSERT_SUCCESS(urUSMFree(context, dst_alloc));
+    }
+    if (host_alloc) {
+      ASSERT_SUCCESS(urUSMFree(context, host_alloc));
+    }
+    if (src_queue) {
+      ASSERT_SUCCESS(urQueueRelease(src_queue));
+    }
+    if (dst_queue) {
+      ASSERT_SUCCESS(urQueueRelease(dst_queue));
+    }
+    if (context) {
+      ASSERT_SUCCESS(urContextRelease(context));
+    }
+    uur::urAllDevicesTest::TearDown();
+  }
+
+  void verifyData() {
+    for (size_t i = 0; i < alloc_size; i++) {
+      EXPECT_EQ(static_cast<uint8_t *>(host_alloc)[i], fill_pattern);
+    }
+  }
+
+  std::vector<ur_device_handle_t> usm_devices;
+  ur_context_handle_t context = nullptr;
+  ur_queue_handle_t src_queue = nullptr;
+  ur_queue_handle_t dst_queue = nullptr;
+  void *src_alloc = nullptr;
+  void *dst_alloc = nullptr;
+  void *host_alloc = nullptr;
+  size_t alloc_size = 64;
+  uint8_t fill_pattern = 42;
+};
+
+TEST_F(urEnqueueUSMMemcpyMultiDeviceTest, DeviceToDeviceCopyBlocking) {
+  ASSERT_SUCCESS(urEnqueueUSMMemcpy(src_queue, true, dst_alloc, src_alloc,
+                                    alloc_size, 0, nullptr, nullptr));
+  ASSERT_SUCCESS(urEnqueueUSMMemcpy(dst_queue, true, host_alloc, dst_alloc,
+                                    alloc_size, 0, nullptr, nullptr));
+  verifyData();
+}
+
+TEST_F(urEnqueueUSMMemcpyMultiDeviceTest, DeviceToDeviceCopyNonBlocking) {
+  ur_event_handle_t device_copy_event = nullptr;
+  ASSERT_SUCCESS(urEnqueueUSMMemcpy(src_queue, false, dst_alloc, src_alloc,
+                                    alloc_size, 0, nullptr,
+                                    &device_copy_event));
+  ASSERT_SUCCESS(urQueueFlush(src_queue));
+  ASSERT_SUCCESS(urEventWait(1, &device_copy_event));
+  ASSERT_SUCCESS(urEventRelease(device_copy_event));
+  ASSERT_SUCCESS(urEnqueueUSMMemcpy(dst_queue, true, host_alloc, dst_alloc,
+                                    alloc_size, 0, nullptr, nullptr));
+  verifyData();
+}
