Skip to content

Commit 2b7a24a

Browse files
authored
Merge pull request #630 from kswiecicki/usm-pool-manager
Add a basic pool manager for memory pools
2 parents 3a85197 + 0657f06 commit 2b7a24a

File tree

3 files changed

+180
-35
lines changed

3 files changed

+180
-35
lines changed

source/common/ur_pool_manager.hpp

Lines changed: 104 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,17 @@
1111
#ifndef USM_POOL_MANAGER_HPP
1212
#define USM_POOL_MANAGER_HPP 1
1313

14+
#include "logger/ur_logger.hpp"
15+
#include "umf_helpers.hpp"
16+
#include "umf_pools/disjoint_pool.hpp"
1417
#include "ur_api.h"
15-
#include "ur_pool_manager.hpp"
1618
#include "ur_util.hpp"
1719

20+
#include <umf/memory_pool.h>
21+
#include <umf/memory_provider.h>
22+
1823
#include <functional>
24+
#include <unordered_map>
1925
#include <vector>
2026

2127
namespace usm {
@@ -29,8 +35,9 @@ struct pool_descriptor {
2935
ur_usm_type_t type;
3036
bool deviceReadOnly;
3137

32-
static bool equal(const pool_descriptor &lhs, const pool_descriptor &rhs);
33-
static std::size_t hash(const pool_descriptor &desc);
38+
bool operator==(const pool_descriptor &other) const;
39+
friend std::ostream &operator<<(std::ostream &os,
40+
const pool_descriptor &desc);
3441
static std::pair<ur_result_t, std::vector<pool_descriptor>>
3542
create(ur_usm_pool_handle_t poolHandle, ur_context_handle_t hContext);
3643
};
@@ -45,8 +52,8 @@ urGetSubDevices(ur_device_handle_t hDevice) {
4552
}
4653

4754
ur_device_partition_property_t prop;
48-
prop.type = UR_DEVICE_PARTITION_EQUALLY;
49-
prop.value.equally = nComputeUnits;
55+
prop.type = UR_DEVICE_PARTITION_BY_CSLICE;
56+
prop.value.affinity_domain = 0;
5057

5158
ur_device_partition_properties_t properties{
5259
UR_STRUCTURE_TYPE_DEVICE_PARTITION_PROPERTIES,
@@ -75,10 +82,10 @@ urGetSubDevices(ur_device_handle_t hDevice) {
7582

7683
inline std::pair<ur_result_t, std::vector<ur_device_handle_t>>
7784
urGetAllDevicesAndSubDevices(ur_context_handle_t hContext) {
78-
size_t deviceCount;
85+
size_t deviceCount = 0;
7986
auto ret = urContextGetInfo(hContext, UR_CONTEXT_INFO_NUM_DEVICES,
8087
sizeof(deviceCount), &deviceCount, nullptr);
81-
if (ret != UR_RESULT_SUCCESS) {
88+
if (ret != UR_RESULT_SUCCESS || deviceCount == 0) {
8289
return {ret, {}};
8390
}
8491

@@ -110,6 +117,11 @@ urGetAllDevicesAndSubDevices(ur_context_handle_t hContext) {
110117
for (size_t i = 0; i < deviceCount; i++) {
111118
ret = addPoolsForDevicesRec(devices[i]);
112119
if (ret != UR_RESULT_SUCCESS) {
120+
if (ret == UR_RESULT_ERROR_UNSUPPORTED_FEATURE) {
121+
// Return main devices when sub-devices are unsupported.
122+
return {ret, std::move(devices)};
123+
}
124+
113125
return {ret, {}};
114126
}
115127
}
@@ -122,22 +134,28 @@ isSharedAllocationReadOnlyOnDevice(const pool_descriptor &desc) {
122134
return desc.type == UR_USM_TYPE_SHARED && desc.deviceReadOnly;
123135
}
124136

125-
inline bool pool_descriptor::equal(const pool_descriptor &lhs,
126-
const pool_descriptor &rhs) {
127-
ur_native_handle_t lhsNative, rhsNative;
137+
inline bool pool_descriptor::operator==(const pool_descriptor &other) const {
138+
const pool_descriptor &lhs = *this;
139+
const pool_descriptor &rhs = other;
140+
ur_native_handle_t lhsNative = nullptr, rhsNative = nullptr;
128141

129142
// We want to share a memory pool for sub-devices and sub-sub devices.
130143
// Sub-devices and sub-sub-devices might be represented by different ur_device_handle_t but
131144
// they share the same native_handle_t (which is used by UMF provider).
132145
// Ref: https://github.com/intel/llvm/commit/86511c5dc84b5781dcfd828caadcb5cac157eae1
133146
// TODO: is this L0 specific?
134-
auto ret = urDeviceGetNativeHandle(lhs.hDevice, &lhsNative);
135-
if (ret != UR_RESULT_SUCCESS) {
136-
throw ret;
147+
if (lhs.hDevice) {
148+
auto ret = urDeviceGetNativeHandle(lhs.hDevice, &lhsNative);
149+
if (ret != UR_RESULT_SUCCESS) {
150+
throw ret;
151+
}
137152
}
138-
ret = urDeviceGetNativeHandle(rhs.hDevice, &rhsNative);
139-
if (ret != UR_RESULT_SUCCESS) {
140-
throw ret;
153+
154+
if (rhs.hDevice) {
155+
auto ret = urDeviceGetNativeHandle(rhs.hDevice, &rhsNative);
156+
if (ret != UR_RESULT_SUCCESS) {
157+
throw ret;
158+
}
141159
}
142160

143161
return lhsNative == rhsNative && lhs.type == rhs.type &&
@@ -146,16 +164,12 @@ inline bool pool_descriptor::equal(const pool_descriptor &lhs,
146164
lhs.poolHandle == rhs.poolHandle;
147165
}
148166

149-
inline std::size_t pool_descriptor::hash(const pool_descriptor &desc) {
150-
ur_native_handle_t native;
151-
auto ret = urDeviceGetNativeHandle(desc.hDevice, &native);
152-
if (ret != UR_RESULT_SUCCESS) {
153-
throw ret;
154-
}
155-
156-
return combine_hashes(0, desc.type, native,
157-
isSharedAllocationReadOnlyOnDevice(desc),
158-
desc.poolHandle);
167+
inline std::ostream &operator<<(std::ostream &os, const pool_descriptor &desc) {
168+
os << "pool handle: " << desc.poolHandle
169+
<< " context handle: " << desc.hContext
170+
<< " device handle: " << desc.hDevice << " memory type: " << desc.type
171+
<< " is read only: " << desc.deviceReadOnly;
172+
return os;
159173
}
160174

161175
inline std::pair<ur_result_t, std::vector<pool_descriptor>>
@@ -177,6 +191,7 @@ pool_descriptor::create(ur_usm_pool_handle_t poolHandle,
177191
pool_descriptor &desc = descriptors.emplace_back();
178192
desc.poolHandle = poolHandle;
179193
desc.hContext = hContext;
194+
desc.hDevice = device;
180195
desc.type = UR_USM_TYPE_DEVICE;
181196
}
182197
{
@@ -200,6 +215,69 @@ pool_descriptor::create(ur_usm_pool_handle_t poolHandle,
200215
return {ret, descriptors};
201216
}
202217

218+
template <typename D> struct pool_manager {
219+
private:
220+
using desc_to_pool_map_t = std::unordered_map<D, umf::pool_unique_handle_t>;
221+
222+
desc_to_pool_map_t descToPoolMap;
223+
224+
public:
225+
static std::pair<ur_result_t, pool_manager>
226+
create(desc_to_pool_map_t descToHandleMap = {}) {
227+
auto manager = pool_manager();
228+
229+
for (auto &[desc, hPool] : descToHandleMap) {
230+
auto ret = manager.addPool(desc, hPool);
231+
if (ret != UR_RESULT_SUCCESS) {
232+
return {ret, pool_manager()};
233+
}
234+
}
235+
236+
return {UR_RESULT_SUCCESS, std::move(manager)};
237+
}
238+
239+
ur_result_t addPool(const D &desc,
240+
umf::pool_unique_handle_t &hPool) noexcept {
241+
if (!descToPoolMap.try_emplace(desc, std::move(hPool)).second) {
242+
logger::error("Pool for pool descriptor: {}, already exists", desc);
243+
return UR_RESULT_ERROR_INVALID_ARGUMENT;
244+
}
245+
246+
return UR_RESULT_SUCCESS;
247+
}
248+
249+
std::optional<umf_memory_pool_handle_t> getPool(const D &desc) noexcept {
250+
auto it = descToPoolMap.find(desc);
251+
if (it == descToPoolMap.end()) {
252+
logger::error("Pool descriptor doesn't match any existing pool: {}",
253+
desc);
254+
return std::nullopt;
255+
}
256+
257+
return it->second.get();
258+
}
259+
};
260+
203261
} // namespace usm
204262

263+
namespace std {
264+
/// @brief hash specialization for usm::pool_descriptor
265+
template <> struct hash<usm::pool_descriptor> {
266+
inline size_t operator()(const usm::pool_descriptor &desc) const {
267+
ur_native_handle_t native = nullptr;
268+
if (desc.hDevice) {
269+
auto ret = urDeviceGetNativeHandle(desc.hDevice, &native);
270+
if (ret != UR_RESULT_SUCCESS) {
271+
throw ret;
272+
}
273+
}
274+
275+
return combine_hashes(0, desc.type, native,
276+
isSharedAllocationReadOnlyOnDevice(desc),
277+
desc.poolHandle);
278+
}
279+
};
280+
281+
} // namespace std
282+
205283
#endif /* USM_POOL_MANAGER_HPP */

test/usm/CMakeLists.txt

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,21 @@ function(add_usm_test name)
1010
add_ur_executable(${TEST_TARGET_NAME}
1111
${UR_USM_TEST_DIR}/../conformance/source/environment.cpp
1212
${UR_USM_TEST_DIR}/../conformance/source/main.cpp
13+
${UR_USM_TEST_DIR}/../unified_malloc_framework/common/provider.c
14+
${UR_USM_TEST_DIR}/../unified_malloc_framework/common/pool.c
1315
${ARGN})
1416
target_link_libraries(${TEST_TARGET_NAME}
1517
PRIVATE
1618
${PROJECT_NAME}::common
1719
${PROJECT_NAME}::loader
1820
ur_testing
1921
GTest::gtest_main)
20-
add_test(NAME usm-${name}
22+
add_test(NAME usm-${name}
2123
COMMAND ${TEST_TARGET_NAME}
2224
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
23-
set_tests_properties(usm-${name} PROPERTIES LABELS "usm")
25+
set_tests_properties(usm-${name} PROPERTIES
26+
LABELS "usm"
27+
ENVIRONMENT "UR_ADAPTERS_FORCE_LOAD=\"$<TARGET_FILE:ur_adapter_null>\"")
2428
target_compile_definitions("usm_test-${name}" PRIVATE DEVICES_ENVIRONMENT)
2529
endfunction()
2630

test/usm/usmPoolManager.cpp

Lines changed: 70 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,18 @@
33
// See LICENSE.TXT
44
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
55

6-
#include "../unified_malloc_framework/common/pool.hpp"
7-
#include "../unified_malloc_framework/common/provider.hpp"
86
#include "ur_pool_manager.hpp"
97

10-
#include <uur/fixtures.h>
8+
#include "../unified_malloc_framework/common/pool.h"
9+
#include "../unified_malloc_framework/common/provider.h"
1110

12-
#include <unordered_set>
11+
#include <uur/fixtures.h>
1312

14-
struct urUsmPoolManagerTest
13+
struct urUsmPoolDescriptorTest
1514
: public uur::urMultiDeviceContextTest,
1615
::testing::WithParamInterface<ur_usm_pool_handle_t> {};
1716

18-
TEST_P(urUsmPoolManagerTest, poolIsPerContextTypeAndDevice) {
17+
TEST_P(urUsmPoolDescriptorTest, poolIsPerContextTypeAndDevice) {
1918
auto &devices = uur::DevicesEnvironment::instance->devices;
2019
auto poolHandle = this->GetParam();
2120

@@ -49,7 +48,71 @@ TEST_P(urUsmPoolManagerTest, poolIsPerContextTypeAndDevice) {
4948
ASSERT_EQ(sharedPools, devices.size() * 2);
5049
}
5150

52-
INSTANTIATE_TEST_SUITE_P(urUsmPoolManagerTest, urUsmPoolManagerTest,
51+
INSTANTIATE_TEST_SUITE_P(urUsmPoolDescriptorTest, urUsmPoolDescriptorTest,
5352
::testing::Values(nullptr));
5453

5554
// TODO: add test with sub-devices
55+
56+
struct urUsmPoolManagerTest : public uur::urContextTest {
57+
void SetUp() override {
58+
UUR_RETURN_ON_FATAL_FAILURE(urContextTest::SetUp());
59+
auto [ret, descs] = usm::pool_descriptor::create(nullptr, context);
60+
ASSERT_EQ(ret, UR_RESULT_SUCCESS);
61+
poolDescriptors = descs;
62+
}
63+
64+
std::vector<usm::pool_descriptor> poolDescriptors;
65+
};
66+
67+
TEST_P(urUsmPoolManagerTest, poolManagerPopulate) {
68+
auto [ret, manager] = usm::pool_manager<usm::pool_descriptor>::create();
69+
ASSERT_EQ(ret, UR_RESULT_SUCCESS);
70+
71+
for (auto &desc : poolDescriptors) {
72+
// Populate the pool manager
73+
auto pool = nullPoolCreate();
74+
ASSERT_NE(pool, nullptr);
75+
auto poolUnique = umf::pool_unique_handle_t(pool, umfPoolDestroy);
76+
ASSERT_NE(poolUnique, nullptr);
77+
ret = manager.addPool(desc, poolUnique);
78+
ASSERT_EQ(ret, UR_RESULT_SUCCESS);
79+
}
80+
81+
for (auto &desc : poolDescriptors) {
82+
// Confirm that there is a pool for each descriptor
83+
auto hPoolOpt = manager.getPool(desc);
84+
ASSERT_TRUE(hPoolOpt.has_value());
85+
ASSERT_NE(hPoolOpt.value(), nullptr);
86+
}
87+
}
88+
89+
TEST_P(urUsmPoolManagerTest, poolManagerInsertExisting) {
90+
auto [ret, manager] = usm::pool_manager<usm::pool_descriptor>::create();
91+
ASSERT_EQ(ret, UR_RESULT_SUCCESS);
92+
93+
auto desc = poolDescriptors[0];
94+
95+
auto pool = nullPoolCreate();
96+
ASSERT_NE(pool, nullptr);
97+
auto poolUnique = umf::pool_unique_handle_t(pool, umfPoolDestroy);
98+
ASSERT_NE(poolUnique, nullptr);
99+
100+
ret = manager.addPool(desc, poolUnique);
101+
ASSERT_EQ(ret, UR_RESULT_SUCCESS);
102+
103+
// Inserting an existing key should return an error
104+
ret = manager.addPool(desc, poolUnique);
105+
ASSERT_EQ(ret, UR_RESULT_ERROR_INVALID_ARGUMENT);
106+
}
107+
108+
TEST_P(urUsmPoolManagerTest, poolManagerGetNonexistant) {
109+
auto [ret, manager] = usm::pool_manager<usm::pool_descriptor>::create();
110+
ASSERT_EQ(ret, UR_RESULT_SUCCESS);
111+
112+
for (auto &desc : poolDescriptors) {
113+
auto hPool = manager.getPool(desc);
114+
ASSERT_FALSE(hPool.has_value());
115+
}
116+
}
117+
118+
UUR_INSTANTIATE_DEVICE_TEST_SUITE_P(urUsmPoolManagerTest);

0 commit comments

Comments (0)