Skip to content

Commit e3910da

Browse files
authored
Merge pull request #2166 from igchor/memory_migration
[L0 v2] make device allocation resident and support multi-device buffers
2 parents c043566 + 28db1fd commit e3910da

File tree

22 files changed

+582
-361
lines changed

22 files changed

+582
-361
lines changed

.github/workflows/multi_device.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ jobs:
1717
strategy:
1818
matrix:
1919
adapter: [
20-
{name: L0}
20+
{name: L0},
21+
{name: L0_V2}
2122
]
2223
build_type: [Debug, Release]
2324
compiler: [{c: gcc, cxx: g++}] # TODO: investigate why memory-adapter-level_zero hangs with clang

source/adapters/cuda/usm.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,13 @@
2222

2323
#include <cuda.h>
2424

25+
namespace umf {
26+
ur_result_t getProviderNativeError(const char *, int32_t) {
27+
// TODO: implement when UMF supports CUDA
28+
return UR_RESULT_ERROR_UNKNOWN;
29+
}
30+
} // namespace umf
31+
2532
/// USM: Implements USM Host allocations using CUDA Pinned Memory
2633
/// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#page-locked-host-memory
2734
UR_APIEXPORT ur_result_t UR_APICALL

source/adapters/hip/usm.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,13 @@
1818
#include "ur_util.hpp"
1919
#include "usm.hpp"
2020

21+
namespace umf {
22+
ur_result_t getProviderNativeError(const char *, int32_t) {
23+
// TODO: implement when UMF supports HIP
24+
return UR_RESULT_ERROR_UNKNOWN;
25+
}
26+
} // namespace umf
27+
2128
/// USM: Implements USM Host allocations using HIP Pinned Memory
2229
UR_APIEXPORT ur_result_t UR_APICALL
2330
urUSMHostAlloc(ur_context_handle_t hContext, const ur_usm_desc_t *pUSMDesc,

source/adapters/level_zero/usm.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,17 @@
2323

2424
#include <umf_helpers.hpp>
2525

26+
namespace umf {
27+
ur_result_t getProviderNativeError(const char *providerName,
28+
int32_t nativeError) {
29+
if (strcmp(providerName, "Level Zero") == 0) {
30+
return ze2urResult(static_cast<ze_result_t>(nativeError));
31+
}
32+
33+
return UR_RESULT_ERROR_UNKNOWN;
34+
}
35+
} // namespace umf
36+
2637
usm::DisjointPoolAllConfigs DisjointPoolConfigInstance =
2738
InitializeDisjointPoolConfig();
2839

source/adapters/level_zero/v2/api.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,6 @@ ur_result_t urMemImageCreateWithNativeHandle(
6464
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
6565
}
6666

67-
ur_result_t urMemGetInfo(ur_mem_handle_t hMemory, ur_mem_info_t propName,
68-
size_t propSize, void *pPropValue,
69-
size_t *pPropSizeRet) {
70-
logger::error("{} function not implemented!", __FUNCTION__);
71-
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
72-
}
73-
7467
ur_result_t urMemImageGetInfo(ur_mem_handle_t hMemory, ur_image_info_t propName,
7568
size_t propSize, void *pPropValue,
7669
size_t *pPropSizeRet) {

source/adapters/level_zero/v2/context.cpp

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,41 @@
1313
#include "context.hpp"
1414
#include "event_provider_normal.hpp"
1515

16+
static std::vector<ur_device_handle_t>
17+
filterP2PDevices(ur_device_handle_t hSourceDevice,
18+
const std::vector<ur_device_handle_t> &devices) {
19+
std::vector<ur_device_handle_t> p2pDevices;
20+
for (auto &device : devices) {
21+
if (device == hSourceDevice) {
22+
continue;
23+
}
24+
25+
ze_bool_t p2p;
26+
ZE2UR_CALL_THROWS(zeDeviceCanAccessPeer,
27+
(hSourceDevice->ZeDevice, device->ZeDevice, &p2p));
28+
29+
if (p2p) {
30+
p2pDevices.push_back(device);
31+
}
32+
}
33+
return p2pDevices;
34+
}
35+
36+
static std::vector<std::vector<ur_device_handle_t>>
37+
populateP2PDevices(size_t maxDevices,
38+
const std::vector<ur_device_handle_t> &devices) {
39+
std::vector<std::vector<ur_device_handle_t>> p2pDevices(maxDevices);
40+
for (auto &device : devices) {
41+
p2pDevices[device->Id.value()] = filterP2PDevices(device, devices);
42+
}
43+
return p2pDevices;
44+
}
45+
1646
ur_context_handle_t_::ur_context_handle_t_(ze_context_handle_t hContext,
1747
uint32_t numDevices,
1848
const ur_device_handle_t *phDevices,
1949
bool ownZeContext)
20-
: hContext(hContext, ownZeContext),
21-
hDevices(phDevices, phDevices + numDevices), commandListCache(hContext),
50+
: commandListCache(hContext),
2251
eventPoolCache(phDevices[0]->Platform->getNumDevices(),
2352
[context = this,
2453
platform = phDevices[0]->Platform](DeviceId deviceId) {
@@ -28,6 +57,10 @@ ur_context_handle_t_::ur_context_handle_t_(ze_context_handle_t hContext,
2857
context, device, v2::EVENT_COUNTER,
2958
v2::QUEUE_IMMEDIATE);
3059
}),
60+
hContext(hContext, ownZeContext),
61+
hDevices(phDevices, phDevices + numDevices),
62+
p2pAccessDevices(populateP2PDevices(
63+
phDevices[0]->Platform->getNumDevices(), this->hDevices)),
3164
defaultUSMPool(this, nullptr) {}
3265

3366
ur_result_t ur_context_handle_t_::retain() {
@@ -65,6 +98,11 @@ ur_usm_pool_handle_t ur_context_handle_t_::getDefaultUSMPool() {
6598
return &defaultUSMPool;
6699
}
67100

101+
const std::vector<ur_device_handle_t> &
102+
ur_context_handle_t_::getP2PDevices(ur_device_handle_t hDevice) const {
103+
return p2pAccessDevices[hDevice->Id.value()];
104+
}
105+
68106
namespace ur::level_zero {
69107
ur_result_t urContextCreate(uint32_t deviceCount,
70108
const ur_device_handle_t *phDevices,

source/adapters/level_zero/v2/context.hpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,22 @@ struct ur_context_handle_t_ : _ur_object {
2828
ur_platform_handle_t getPlatform() const;
2929
const std::vector<ur_device_handle_t> &getDevices() const;
3030
ur_usm_pool_handle_t getDefaultUSMPool();
31+
const std::vector<ur_device_handle_t> &
32+
getP2PDevices(ur_device_handle_t hDevice) const;
3133

3234
// Checks if Device is covered by this context.
3335
// For that the Device or its root devices need to be in the context.
3436
bool isValidDevice(ur_device_handle_t Device) const;
3537

36-
const v2::raii::ze_context_handle_t hContext;
37-
const std::vector<ur_device_handle_t> hDevices;
3838
v2::command_list_cache_t commandListCache;
3939
v2::event_pool_cache eventPoolCache;
40+
41+
private:
42+
const v2::raii::ze_context_handle_t hContext;
43+
const std::vector<ur_device_handle_t> hDevices;
44+
45+
// P2P devices for each device in the context, indexed by device id.
46+
const std::vector<std::vector<ur_device_handle_t>> p2pAccessDevices;
47+
4048
ur_usm_pool_handle_t_ defaultUSMPool;
4149
};

source/adapters/level_zero/v2/kernel.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -316,11 +316,16 @@ urKernelSetArgMemObj(ur_kernel_handle_t hKernel, uint32_t argIndex,
316316
auto zePtr = hArgValue->getPtr(kernelDevices.front());
317317
return hKernel->setArgPointer(argIndex, nullptr, zePtr);
318318
} else {
319-
// TODO: Implement this for multi-device kernels.
320-
// Do this the same way as in legacy (keep a pending Args vector and
321-
// do actual allocation on kernel submission) or allocate the memory
322-
// immediately (only for small allocations?)
323-
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
319+
// TODO: if devices do not have p2p capabilities, we need to have allocation
320+
// on each device. Do this the same way as in legacy (keep a pending Args
321+
// vector and do actual allocation on kernel submission) or allocate the
322+
// memory immediately (only for small allocations?).
323+
324+
// Get memory that is accessible by the first device.
325+
// If kernel is submitted to a different device the memory
326+
// will be accessed trough the link or migrated in enqueueKernelLaunch.
327+
auto zePtr = hArgValue->getPtr(kernelDevices.front());
328+
return hKernel->setArgPointer(argIndex, nullptr, zePtr);
324329
}
325330
}
326331

source/adapters/level_zero/v2/memory.cpp

Lines changed: 84 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,8 @@ ur_host_mem_handle_t::ur_host_mem_handle_t(ur_context_handle_t hContext,
2828
}
2929

3030
if (!hostPtrImported) {
31-
// TODO: use UMF
32-
ZeStruct<ze_host_mem_alloc_desc_t> hostDesc;
33-
ZE2UR_CALL_THROWS(zeMemAllocHost, (hContext->getZeHandle(), &hostDesc, size,
34-
0, &this->ptr));
31+
UR_CALL_THROWS(hContext->getDefaultUSMPool()->allocate(
32+
hContext, nullptr, nullptr, UR_USM_TYPE_HOST, size, &this->ptr));
3533

3634
if (hostPtr) {
3735
std::memcpy(this->ptr, hostPtr, size);
@@ -40,9 +38,11 @@ ur_host_mem_handle_t::ur_host_mem_handle_t(ur_context_handle_t hContext,
4038
}
4139

4240
ur_host_mem_handle_t::~ur_host_mem_handle_t() {
43-
// TODO: use UMF API here
4441
if (ptr) {
45-
ZE_CALL_NOCHECK(zeMemFree, (hContext->getZeHandle(), ptr));
42+
auto ret = hContext->getDefaultUSMPool()->free(ptr);
43+
if (ret != UR_RESULT_SUCCESS) {
44+
logger::error("Failed to free host memory: {}", ret);
45+
}
4646
}
4747
}
4848

@@ -51,55 +51,80 @@ void *ur_host_mem_handle_t::getPtr(ur_device_handle_t hDevice) {
5151
return ptr;
5252
}
5353

54+
ur_result_t ur_device_mem_handle_t::migrateBufferTo(ur_device_handle_t hDevice,
55+
void *src, size_t size) {
56+
auto Id = hDevice->Id.value();
57+
58+
if (!deviceAllocations[Id]) {
59+
UR_CALL(hContext->getDefaultUSMPool()->allocate(hContext, hDevice, nullptr,
60+
UR_USM_TYPE_DEVICE, size,
61+
&deviceAllocations[Id]));
62+
}
63+
64+
auto commandList = hContext->commandListCache.getImmediateCommandList(
65+
hDevice->ZeDevice, true,
66+
hDevice
67+
->QueueGroup[ur_device_handle_t_::queue_group_info_t::type::Compute]
68+
.ZeOrdinal,
69+
ZE_COMMAND_QUEUE_MODE_SYNCHRONOUS, ZE_COMMAND_QUEUE_PRIORITY_NORMAL,
70+
std::nullopt);
71+
72+
ZE2UR_CALL(zeCommandListAppendMemoryCopy,
73+
(commandList.get(), deviceAllocations[Id], src, size, nullptr, 0,
74+
nullptr));
75+
76+
activeAllocationDevice = hDevice;
77+
78+
return UR_RESULT_SUCCESS;
79+
}
80+
5481
ur_device_mem_handle_t::ur_device_mem_handle_t(ur_context_handle_t hContext,
5582
void *hostPtr, size_t size)
5683
: ur_mem_handle_t_(hContext, size),
57-
deviceAllocations(hContext->getPlatform()->getNumDevices()) {
58-
// Legacy adapter allocated the memory directly on a device (first on the
59-
// contxt) and if the buffer is used on another device, memory is migrated
60-
// (depending on an env var setting).
61-
//
62-
// TODO: port this behavior or figure out if it makes sense to keep the memory
63-
// in a host buffer (e.g. for smaller sizes).
84+
deviceAllocations(hContext->getPlatform()->getNumDevices()),
85+
activeAllocationDevice(nullptr) {
6486
if (hostPtr) {
65-
buffer.assign(reinterpret_cast<char *>(hostPtr),
66-
reinterpret_cast<char *>(hostPtr) + size);
87+
auto initialDevice = hContext->getDevices()[0];
88+
UR_CALL_THROWS(migrateBufferTo(initialDevice, hostPtr, size));
6789
}
6890
}
6991

7092
ur_device_mem_handle_t::~ur_device_mem_handle_t() {
71-
// TODO: use UMF API here
7293
for (auto &ptr : deviceAllocations) {
7394
if (ptr) {
74-
ZE_CALL_NOCHECK(zeMemFree, (hContext->getZeHandle(), ptr));
95+
auto ret = hContext->getDefaultUSMPool()->free(ptr);
96+
if (ret != UR_RESULT_SUCCESS) {
97+
logger::error("Failed to free device memory: {}", ret);
98+
}
7599
}
76100
}
77101
}
78102

79103
void *ur_device_mem_handle_t::getPtr(ur_device_handle_t hDevice) {
80104
std::lock_guard lock(this->Mutex);
81105

82-
auto &ptr = deviceAllocations[hDevice->Id.value()];
83-
if (!ptr) {
84-
ZeStruct<ze_device_mem_alloc_desc_t> deviceDesc;
85-
ZE2UR_CALL_THROWS(zeMemAllocDevice, (hContext->getZeHandle(), &deviceDesc,
86-
size, 0, hDevice->ZeDevice, &ptr));
87-
88-
if (!buffer.empty()) {
89-
auto commandList = hContext->commandListCache.getImmediateCommandList(
90-
hDevice->ZeDevice, true,
91-
hDevice
92-
->QueueGroup
93-
[ur_device_handle_t_::queue_group_info_t::type::Compute]
94-
.ZeOrdinal,
95-
ZE_COMMAND_QUEUE_MODE_SYNCHRONOUS, ZE_COMMAND_QUEUE_PRIORITY_NORMAL,
96-
std::nullopt);
97-
ZE2UR_CALL_THROWS(
98-
zeCommandListAppendMemoryCopy,
99-
(commandList.get(), ptr, buffer.data(), size, nullptr, 0, nullptr));
100-
}
106+
if (!activeAllocationDevice) {
107+
UR_CALL_THROWS(hContext->getDefaultUSMPool()->allocate(
108+
hContext, hDevice, nullptr, UR_USM_TYPE_DEVICE, getSize(),
109+
&deviceAllocations[hDevice->Id.value()]));
110+
activeAllocationDevice = hDevice;
101111
}
102-
return ptr;
112+
113+
if (activeAllocationDevice == hDevice) {
114+
return deviceAllocations[hDevice->Id.value()];
115+
}
116+
117+
auto &p2pDevices = hContext->getP2PDevices(hDevice);
118+
auto p2pAccessible = std::find(p2pDevices.begin(), p2pDevices.end(),
119+
activeAllocationDevice) != p2pDevices.end();
120+
121+
if (!p2pAccessible) {
122+
// TODO: migrate buffer through the host
123+
throw UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
124+
}
125+
126+
// TODO: see if it's better to migrate the memory to the specified device
127+
return deviceAllocations[activeAllocationDevice->Id.value()];
103128
}
104129

105130
namespace ur::level_zero {
@@ -166,6 +191,28 @@ ur_result_t urMemBufferCreateWithNativeHandle(
166191
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
167192
}
168193

194+
ur_result_t urMemGetInfo(ur_mem_handle_t hMemory, ur_mem_info_t propName,
195+
size_t propSize, void *pPropValue,
196+
size_t *pPropSizeRet) {
197+
std::shared_lock<ur_shared_mutex> Lock(hMemory->Mutex);
198+
UrReturnHelper returnValue(propSize, pPropValue, pPropSizeRet);
199+
200+
switch (propName) {
201+
case UR_MEM_INFO_CONTEXT: {
202+
return returnValue(hMemory->getContext());
203+
}
204+
case UR_MEM_INFO_SIZE: {
205+
// Get size of the allocation
206+
return returnValue(size_t{hMemory->getSize()});
207+
}
208+
default: {
209+
return UR_RESULT_ERROR_INVALID_ENUMERATION;
210+
}
211+
}
212+
213+
return UR_RESULT_SUCCESS;
214+
}
215+
169216
ur_result_t urMemRetain(ur_mem_handle_t hMem) {
170217
hMem->RefCount.increment();
171218
return UR_RESULT_SUCCESS;

source/adapters/level_zero/v2/memory.hpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include <ur_api.h>
1414

15+
#include "../device.hpp"
1516
#include "common.hpp"
1617

1718
struct ur_mem_handle_t_ : _ur_object {
@@ -21,6 +22,7 @@ struct ur_mem_handle_t_ : _ur_object {
2122
virtual void *getPtr(ur_device_handle_t) = 0;
2223

2324
inline size_t getSize() { return size; }
25+
inline ur_context_handle_t getContext() { return hContext; }
2426

2527
protected:
2628
const ur_context_handle_t hContext;
@@ -48,8 +50,13 @@ struct ur_device_mem_handle_t : public ur_mem_handle_t_ {
4850
void *getPtr(ur_device_handle_t) override;
4951

5052
private:
51-
std::vector<char> buffer;
52-
5353
// Vector of per-device allocations indexed by device->Id
5454
std::vector<void *> deviceAllocations;
55+
56+
// Specifies device on which the latest allocation resides.
57+
// If null, there is no allocation.
58+
ur_device_handle_t activeAllocationDevice;
59+
60+
ur_result_t migrateBufferTo(ur_device_handle_t hDevice, void *src,
61+
size_t size);
5562
};

0 commit comments

Comments
 (0)