Skip to content

Commit 3d99145

Browse files
committed
[L0 v2] make device allocation resident
on all devices in the context to match the default behavior of the legacy adapter.
1 parent d9facf2 commit 3d99145

File tree

6 files changed

+135
-6
lines changed

6 files changed

+135
-6
lines changed

source/adapters/level_zero/v2/context.cpp

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,41 @@
1313
#include "context.hpp"
1414
#include "event_provider_normal.hpp"
1515

16+
static std::vector<ur_device_handle_t>
17+
filterP2PDevices(ur_device_handle_t hSourceDevice,
18+
const std::vector<ur_device_handle_t> &devices) {
19+
std::vector<ur_device_handle_t> p2pDevices;
20+
for (auto &device : devices) {
21+
if (device == hSourceDevice) {
22+
continue;
23+
}
24+
25+
ze_bool_t p2p;
26+
ZE2UR_CALL_THROWS(zeDeviceCanAccessPeer,
27+
(hSourceDevice->ZeDevice, device->ZeDevice, &p2p));
28+
29+
if (p2p) {
30+
p2pDevices.push_back(device);
31+
}
32+
}
33+
return p2pDevices;
34+
}
35+
36+
static std::vector<std::vector<ur_device_handle_t>>
37+
populateP2PDevices(size_t maxDevices,
38+
const std::vector<ur_device_handle_t> &devices) {
39+
std::vector<std::vector<ur_device_handle_t>> p2pDevices(maxDevices);
40+
for (auto &device : devices) {
41+
p2pDevices[device->Id.value()] = filterP2PDevices(device, devices);
42+
}
43+
return p2pDevices;
44+
}
45+
1646
ur_context_handle_t_::ur_context_handle_t_(ze_context_handle_t hContext,
1747
uint32_t numDevices,
1848
const ur_device_handle_t *phDevices,
1949
bool ownZeContext)
20-
: hContext(hContext, ownZeContext),
21-
hDevices(phDevices, phDevices + numDevices), commandListCache(hContext),
50+
: commandListCache(hContext),
2251
eventPoolCache(phDevices[0]->Platform->getNumDevices(),
2352
[context = this,
2453
platform = phDevices[0]->Platform](DeviceId deviceId) {
@@ -28,6 +57,10 @@ ur_context_handle_t_::ur_context_handle_t_(ze_context_handle_t hContext,
2857
context, device, v2::EVENT_COUNTER,
2958
v2::QUEUE_IMMEDIATE);
3059
}),
60+
hContext(hContext, ownZeContext),
61+
hDevices(phDevices, phDevices + numDevices),
62+
p2pAccessDevices(populateP2PDevices(
63+
phDevices[0]->Platform->getNumDevices(), this->hDevices)),
3164
defaultUSMPool(this, nullptr) {}
3265

3366
ur_result_t ur_context_handle_t_::retain() {
@@ -65,6 +98,11 @@ ur_usm_pool_handle_t ur_context_handle_t_::getDefaultUSMPool() {
6598
return &defaultUSMPool;
6699
}
67100

101+
const std::vector<ur_device_handle_t> &
102+
ur_context_handle_t_::getP2PDevices(ur_device_handle_t hDevice) const {
103+
return p2pAccessDevices[hDevice->Id.value()];
104+
}
105+
68106
namespace ur::level_zero {
69107
ur_result_t urContextCreate(uint32_t deviceCount,
70108
const ur_device_handle_t *phDevices,

source/adapters/level_zero/v2/context.hpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,22 @@ struct ur_context_handle_t_ : _ur_object {
2828
ur_platform_handle_t getPlatform() const;
2929
const std::vector<ur_device_handle_t> &getDevices() const;
3030
ur_usm_pool_handle_t getDefaultUSMPool();
31+
const std::vector<ur_device_handle_t> &
32+
getP2PDevices(ur_device_handle_t hDevice) const;
3133

3234
// Checks if Device is covered by this context.
3335
// For that the Device or its root devices need to be in the context.
3436
bool isValidDevice(ur_device_handle_t Device) const;
3537

36-
const v2::raii::ze_context_handle_t hContext;
37-
const std::vector<ur_device_handle_t> hDevices;
3838
v2::command_list_cache_t commandListCache;
3939
v2::event_pool_cache eventPoolCache;
40+
41+
private:
42+
const v2::raii::ze_context_handle_t hContext;
43+
const std::vector<ur_device_handle_t> hDevices;
44+
45+
// P2P devices for each device in the context, indexed by device id.
46+
const std::vector<std::vector<ur_device_handle_t>> p2pAccessDevices;
47+
4048
ur_usm_pool_handle_t_ defaultUSMPool;
4149
};

source/adapters/level_zero/v2/usm.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,22 @@ makePool(umf_disjoint_pool_params_t *poolParams,
8585
params.level_zero_device_handle =
8686
poolDescriptor.hDevice ? poolDescriptor.hDevice->ZeDevice : nullptr;
8787
params.memory_type = urToUmfMemoryType(poolDescriptor.type);
88-
// TODO: handle memory residency:
89-
// set resident_device_handles and resident_device_count
88+
89+
std::vector<ze_device_handle_t> residentZeHandles;
90+
91+
if (poolDescriptor.type == UR_USM_TYPE_DEVICE) {
92+
assert(params.level_zero_device_handle);
93+
auto residentHandles =
94+
poolDescriptor.hContext->getP2PDevices(poolDescriptor.hDevice);
95+
residentZeHandles.push_back(params.level_zero_device_handle);
96+
for (auto &device : residentHandles) {
97+
residentZeHandles.push_back(device->ZeDevice);
98+
}
99+
100+
params.resident_device_handles = residentZeHandles.data();
101+
params.resident_device_count = residentZeHandles.size();
102+
}
103+
90104
auto [ret, provider] =
91105
umf::providerMakeUniqueFromOps(umfLevelZeroMemoryProviderOps(), &params);
92106
if (ret != UMF_RESULT_SUCCESS) {

test/adapters/level_zero/v2/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,12 @@ add_unittest(level_zero_event_pool
4545
${PROJECT_SOURCE_DIR}/source/adapters/level_zero/v2/event_provider_counter.cpp
4646
${PROJECT_SOURCE_DIR}/source/adapters/level_zero/v2/event.cpp
4747
)
48+
49+
# Memory residency test for the Level Zero v2 adapter.
# UR_ADAPTERS_FORCE_LOAD points the test at the ur_adapter_level_zero_v2
# binary; ZES_ENABLE_SYSMAN=1 is presumably required so the free-memory query
# used by the test is available.
add_adapter_test(level_zero_memory_residency
    FIXTURE DEVICES
    SOURCES memory_residency.cpp
    ENVIRONMENT
        "UR_ADAPTERS_FORCE_LOAD=\"$<TARGET_FILE:ur_adapter_level_zero_v2>\""
        "ZES_ENABLE_SYSMAN=1"
)
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// Copyright (C) 2024 Intel Corporation
2+
// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions.
3+
// See LICENSE.TXT
4+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
5+
6+
#include "ur_print.hpp"
7+
#include "uur/fixtures.h"
8+
#include "uur/raii.h"
9+
#include "uur/utils.h"
10+
11+
#include <map>
12+
#include <string>
13+
14+
using urMemoryResidencyTest = uur::urMultiDeviceContextTestTemplate<1>;
15+
16+
// Allocates device USM on the first device and checks that the amount of free
// global memory reported for that device does not grow as a result, i.e. that
// the allocation is actually backed by (resident in) device memory.
// Skipped on non-PVC devices and when less than allocSize bytes are free.
TEST_F(urMemoryResidencyTest, allocatingDeviceMemoryWillResultInOOM) {
    static constexpr size_t allocSize = 1024 * 1024;

    ur_device_handle_t hDevice =
        uur::DevicesEnvironment::instance->devices[0];

    if (!uur::isPVC(hDevice)) {
        GTEST_SKIP() << "Test requires a PVC device";
    }

    // UR_DEVICE_INFO_GLOBAL_MEM_FREE is a uint64_t property, so query it with
    // that exact type instead of size_t.
    uint64_t initialMemFree = 0;
    ASSERT_SUCCESS(urDeviceGetInfo(hDevice, UR_DEVICE_INFO_GLOBAL_MEM_FREE,
                                   sizeof(initialMemFree), &initialMemFree,
                                   nullptr));

    if (initialMemFree < allocSize) {
        GTEST_SKIP() << "Not enough device memory available";
    }

    void *ptr = nullptr;
    ASSERT_SUCCESS(
        urUSMDeviceAlloc(context, hDevice, nullptr, nullptr, allocSize, &ptr));

    uint64_t currentMemFree = 0;
    const ur_result_t queryResult =
        urDeviceGetInfo(hDevice, UR_DEVICE_INFO_GLOBAL_MEM_FREE,
                        sizeof(currentMemFree), &currentMemFree, nullptr);

    // Release the allocation before any further fatal assertion: a failing
    // ASSERT_* returns from the test body and would otherwise leak `ptr`.
    ASSERT_SUCCESS(urUSMFree(context, ptr));

    ASSERT_SUCCESS(queryResult);

    // amount of free memory should not increase after making a memory
    // allocation resident
    ASSERT_LE(currentMemFree, initialMemFree);
}

test/conformance/testing/include/uur/utils.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,18 @@ getDriverVersion(ur_device_handle_t hDevice) {
483483
} \
484484
} while (0)
485485

486+
// Is this a Data Center GPU Max series (aka PVC)?
487+
// TODO: change to use
488+
// https://spec.oneapi.io/level-zero/latest/core/api.html#ze-device-ip-version-ext-t
489+
// when that is stable.
490+
static inline bool isPVC(ur_device_handle_t hDevice) {
491+
uint32_t deviceId;
492+
EXPECT_EQ(urDeviceGetInfo(hDevice, UR_DEVICE_INFO_DEVICE_ID,
493+
sizeof(uint32_t), &deviceId, nullptr),
494+
UR_RESULT_SUCCESS);
495+
return (deviceId & 0xff0) == 0xbd0 || (deviceId & 0xff0) == 0xb60;
496+
}
497+
486498
} // namespace uur
487499

488500
#endif // UR_CONFORMANCE_INCLUDE_UTILS_H_INCLUDED

0 commit comments

Comments
 (0)