Skip to content

Commit 7236161

Browse files
committed
[L0 v2] support multi-device memory buffers
1 parent 3d99145 commit 7236161

File tree

6 files changed

+87
-44
lines changed

6 files changed

+87
-44
lines changed

source/adapters/level_zero/v2/kernel.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -316,11 +316,16 @@ urKernelSetArgMemObj(ur_kernel_handle_t hKernel, uint32_t argIndex,
316316
auto zePtr = hArgValue->getPtr(kernelDevices.front());
317317
return hKernel->setArgPointer(argIndex, nullptr, zePtr);
318318
} else {
319-
// TODO: Implement this for multi-device kernels.
320-
// Do this the same way as in legacy (keep a pending Args vector and
321-
// do actual allocation on kernel submission) or allocate the memory
322-
// immediately (only for small allocations?)
323-
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
319+
// TODO: if devices do not have p2p capabilities, we need to have allocation
320+
// on each device. Do this the same way as in legacy (keep a pending Args
321+
// vector and do actual allocation on kernel submission) or allocate the
322+
// memory immediately (only for small allocations?).
323+
324+
// Get memory that is accessible by the first device.
325+
// If kernel is submitted to a different device the memory
326+
// will be accessed trough the link or migrated in enqueueKernelLaunch.
327+
auto zePtr = hArgValue->getPtr(kernelDevices.front());
328+
return hKernel->setArgPointer(argIndex, nullptr, zePtr);
324329
}
325330
}
326331

source/adapters/level_zero/v2/memory.cpp

Lines changed: 62 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,8 @@ ur_host_mem_handle_t::ur_host_mem_handle_t(ur_context_handle_t hContext,
2828
}
2929

3030
if (!hostPtrImported) {
31-
// TODO: use UMF
32-
ZeStruct<ze_host_mem_alloc_desc_t> hostDesc;
33-
ZE2UR_CALL_THROWS(zeMemAllocHost, (hContext->getZeHandle(), &hostDesc, size,
34-
0, &this->ptr));
31+
UR_CALL_THROWS(hContext->getDefaultUSMPool()->allocate(
32+
hContext, nullptr, nullptr, UR_USM_TYPE_HOST, size, &this->ptr));
3533

3634
if (hostPtr) {
3735
std::memcpy(this->ptr, hostPtr, size);
@@ -40,9 +38,11 @@ ur_host_mem_handle_t::ur_host_mem_handle_t(ur_context_handle_t hContext,
4038
}
4139

4240
ur_host_mem_handle_t::~ur_host_mem_handle_t() {
43-
// TODO: use UMF API here
4441
if (ptr) {
45-
ZE_CALL_NOCHECK(zeMemFree, (hContext->getZeHandle(), ptr));
42+
auto ret = hContext->getDefaultUSMPool()->free(ptr);
43+
if (ret != UR_RESULT_SUCCESS) {
44+
logger::error("Failed to free host memory: {}", ret);
45+
}
4646
}
4747
}
4848

@@ -51,55 +51,80 @@ void *ur_host_mem_handle_t::getPtr(ur_device_handle_t hDevice) {
5151
return ptr;
5252
}
5353

54+
ur_result_t ur_device_mem_handle_t::migrateBufferTo(ur_device_handle_t hDevice,
55+
void *src, size_t size) {
56+
auto Id = hDevice->Id.value();
57+
58+
if (!deviceAllocations[Id]) {
59+
UR_CALL(hContext->getDefaultUSMPool()->allocate(hContext, hDevice, nullptr,
60+
UR_USM_TYPE_DEVICE, size,
61+
&deviceAllocations[Id]));
62+
}
63+
64+
auto commandList = hContext->commandListCache.getImmediateCommandList(
65+
hDevice->ZeDevice, true,
66+
hDevice
67+
->QueueGroup[ur_device_handle_t_::queue_group_info_t::type::Compute]
68+
.ZeOrdinal,
69+
ZE_COMMAND_QUEUE_MODE_SYNCHRONOUS, ZE_COMMAND_QUEUE_PRIORITY_NORMAL,
70+
std::nullopt);
71+
72+
ZE2UR_CALL(zeCommandListAppendMemoryCopy,
73+
(commandList.get(), deviceAllocations[Id], src, size, nullptr, 0,
74+
nullptr));
75+
76+
activeAllocationDevice = hDevice;
77+
78+
return UR_RESULT_SUCCESS;
79+
}
80+
5481
ur_device_mem_handle_t::ur_device_mem_handle_t(ur_context_handle_t hContext,
5582
void *hostPtr, size_t size)
5683
: ur_mem_handle_t_(hContext, size),
57-
deviceAllocations(hContext->getPlatform()->getNumDevices()) {
58-
// Legacy adapter allocated the memory directly on a device (first on the
59-
// contxt) and if the buffer is used on another device, memory is migrated
60-
// (depending on an env var setting).
61-
//
62-
// TODO: port this behavior or figure out if it makes sense to keep the memory
63-
// in a host buffer (e.g. for smaller sizes).
84+
deviceAllocations(hContext->getPlatform()->getNumDevices()),
85+
activeAllocationDevice(nullptr) {
6486
if (hostPtr) {
65-
buffer.assign(reinterpret_cast<char *>(hostPtr),
66-
reinterpret_cast<char *>(hostPtr) + size);
87+
auto initialDevice = hContext->getDevices()[0];
88+
UR_CALL_THROWS(migrateBufferTo(initialDevice, hostPtr, size));
6789
}
6890
}
6991

7092
ur_device_mem_handle_t::~ur_device_mem_handle_t() {
71-
// TODO: use UMF API here
7293
for (auto &ptr : deviceAllocations) {
7394
if (ptr) {
74-
ZE_CALL_NOCHECK(zeMemFree, (hContext->getZeHandle(), ptr));
95+
auto ret = hContext->getDefaultUSMPool()->free(ptr);
96+
if (ret != UR_RESULT_SUCCESS) {
97+
logger::error("Failed to free device memory: {}", ret);
98+
}
7599
}
76100
}
77101
}
78102

79103
void *ur_device_mem_handle_t::getPtr(ur_device_handle_t hDevice) {
80104
std::lock_guard lock(this->Mutex);
81105

82-
auto &ptr = deviceAllocations[hDevice->Id.value()];
83-
if (!ptr) {
84-
ZeStruct<ze_device_mem_alloc_desc_t> deviceDesc;
85-
ZE2UR_CALL_THROWS(zeMemAllocDevice, (hContext->getZeHandle(), &deviceDesc,
86-
size, 0, hDevice->ZeDevice, &ptr));
87-
88-
if (!buffer.empty()) {
89-
auto commandList = hContext->commandListCache.getImmediateCommandList(
90-
hDevice->ZeDevice, true,
91-
hDevice
92-
->QueueGroup
93-
[ur_device_handle_t_::queue_group_info_t::type::Compute]
94-
.ZeOrdinal,
95-
ZE_COMMAND_QUEUE_MODE_SYNCHRONOUS, ZE_COMMAND_QUEUE_PRIORITY_NORMAL,
96-
std::nullopt);
97-
ZE2UR_CALL_THROWS(
98-
zeCommandListAppendMemoryCopy,
99-
(commandList.get(), ptr, buffer.data(), size, nullptr, 0, nullptr));
100-
}
106+
if (!activeAllocationDevice) {
107+
UR_CALL_THROWS(hContext->getDefaultUSMPool()->allocate(
108+
hContext, hDevice, nullptr, UR_USM_TYPE_DEVICE, getSize(),
109+
&deviceAllocations[hDevice->Id.value()]));
110+
activeAllocationDevice = hDevice;
101111
}
102-
return ptr;
112+
113+
if (activeAllocationDevice == hDevice) {
114+
return deviceAllocations[hDevice->Id.value()];
115+
}
116+
117+
auto &p2pDevices = hContext->getP2PDevices(hDevice);
118+
auto p2pAccessible = std::find(p2pDevices.begin(), p2pDevices.end(),
119+
activeAllocationDevice) != p2pDevices.end();
120+
121+
if (!p2pAccessible) {
122+
// TODO: migrate buffer through the host
123+
throw UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
124+
}
125+
126+
// TODO: see if it's better to migrate the memory to the specified device
127+
return deviceAllocations[activeAllocationDevice->Id.value()];
103128
}
104129

105130
namespace ur::level_zero {

source/adapters/level_zero/v2/memory.hpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include <ur_api.h>
1414

15+
#include "../device.hpp"
1516
#include "common.hpp"
1617

1718
struct ur_mem_handle_t_ : _ur_object {
@@ -48,8 +49,13 @@ struct ur_device_mem_handle_t : public ur_mem_handle_t_ {
4849
void *getPtr(ur_device_handle_t) override;
4950

5051
private:
51-
std::vector<char> buffer;
52-
5352
// Vector of per-device allocations indexed by device->Id
5453
std::vector<void *> deviceAllocations;
54+
55+
// Specifies device on which the latest allocation resides.
56+
// If null, there is no allocation.
57+
ur_device_handle_t activeAllocationDevice;
58+
59+
ur_result_t migrateBufferTo(ur_device_handle_t hDevice, void *src,
60+
size_t size);
5561
};

source/adapters/level_zero/v2/queue_immediate_in_order.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,8 @@ ur_result_t ur_queue_immediate_in_order_t::enqueueKernelLaunch(
239239
auto [pWaitEvents, numWaitEvents] =
240240
getWaitListView(phEventWaitList, numEventsInWaitList, handler);
241241

242+
// TODO: consider migrating memory to the device if memory buffers are used
243+
242244
TRACK_SCOPE_LATENCY(
243245
"ur_queue_immediate_in_order_t::zeCommandListAppendLaunchKernel");
244246
ZE2UR_CALL(zeCommandListAppendLaunchKernel,

source/adapters/level_zero/v2/usm.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,10 @@ ur_result_t ur_usm_pool_handle_t_::allocate(
184184
return UR_RESULT_SUCCESS;
185185
}
186186

187+
ur_result_t ur_usm_pool_handle_t_::free(void *ptr) {
188+
return umf::umf2urResult(umfFree(ptr));
189+
}
190+
187191
namespace ur::level_zero {
188192
ur_result_t urUSMPoolCreate(
189193
ur_context_handle_t hContext, ///< [in] handle of the context object

source/adapters/level_zero/v2/usm.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ struct ur_usm_pool_handle_t_ : _ur_object {
2424
ur_result_t allocate(ur_context_handle_t hContext, ur_device_handle_t hDevice,
2525
const ur_usm_desc_t *pUSMDesc, ur_usm_type_t type,
2626
size_t size, void **ppRetMem);
27+
ur_result_t free(void *ptr);
2728

2829
private:
2930
ur_context_handle_t hContext;

0 commit comments

Comments
 (0)