Skip to content

Commit 0cbc511

Browse files
committed
Add a basic pool manager for memory pools
1 parent b3cc9ae commit 0cbc511

File tree

1 file changed

+97
-24
lines changed

1 file changed

+97
-24
lines changed

source/common/ur_pool_manager.hpp

Lines changed: 97 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,17 @@
1111
#ifndef USM_POOL_MANAGER_HPP
1212
#define USM_POOL_MANAGER_HPP 1
1313

14+
#include "logger/ur_logger.hpp"
15+
#include "umf_helpers.hpp"
16+
#include "umf_pools/disjoint_pool.hpp"
1417
#include "ur_api.h"
15-
#include "ur_pool_manager.hpp"
1618
#include "ur_util.hpp"
1719

20+
#include <umf/memory_pool.h>
21+
#include <umf/memory_provider.h>
22+
1823
#include <functional>
24+
#include <unordered_map>
1925
#include <vector>
2026

2127
namespace usm {
@@ -29,8 +35,9 @@ struct pool_descriptor {
2935
ur_usm_type_t type;
3036
bool deviceReadOnly;
3137

32-
static bool equal(const pool_descriptor &lhs, const pool_descriptor &rhs);
33-
static std::size_t hash(const pool_descriptor &desc);
38+
bool operator==(const pool_descriptor &other) const;
39+
friend std::ostream &operator<<(std::ostream &os,
40+
const pool_descriptor &desc);
3441
static std::pair<ur_result_t, std::vector<pool_descriptor>>
3542
create(ur_usm_pool_handle_t poolHandle, ur_context_handle_t hContext);
3643
};
@@ -75,10 +82,10 @@ urGetSubDevices(ur_device_handle_t hDevice) {
7582

7683
inline std::pair<ur_result_t, std::vector<ur_device_handle_t>>
7784
urGetAllDevicesAndSubDevices(ur_context_handle_t hContext) {
78-
size_t deviceCount;
85+
size_t deviceCount = 0;
7986
auto ret = urContextGetInfo(hContext, UR_CONTEXT_INFO_NUM_DEVICES,
8087
sizeof(deviceCount), &deviceCount, nullptr);
81-
if (ret != UR_RESULT_SUCCESS) {
88+
if (ret != UR_RESULT_SUCCESS || deviceCount == 0) {
8289
return {ret, {}};
8390
}
8491

@@ -122,22 +129,28 @@ isSharedAllocationReadOnlyOnDevice(const pool_descriptor &desc) {
122129
return desc.type == UR_USM_TYPE_SHARED && desc.deviceReadOnly;
123130
}
124131

125-
inline bool pool_descriptor::equal(const pool_descriptor &lhs,
126-
const pool_descriptor &rhs) {
127-
ur_native_handle_t lhsNative, rhsNative;
132+
inline bool pool_descriptor::operator==(const pool_descriptor &other) const {
133+
const pool_descriptor &lhs = *this;
134+
const pool_descriptor &rhs = other;
135+
ur_native_handle_t lhsNative = nullptr, rhsNative = nullptr;
128136

129137
// We want to share a memory pool for sub-devices and sub-sub devices.
130138
// Sub-devices and sub-sub-devices might be represented by different ur_device_handle_t but
131139
// they share the same native_handle_t (which is used by UMF provider).
132140
// Ref: https://github.com/intel/llvm/commit/86511c5dc84b5781dcfd828caadcb5cac157eae1
133141
// TODO: is this L0 specific?
134-
auto ret = urDeviceGetNativeHandle(lhs.hDevice, &lhsNative);
135-
if (ret != UR_RESULT_SUCCESS) {
136-
throw ret;
142+
if (lhs.hDevice) {
143+
auto ret = urDeviceGetNativeHandle(lhs.hDevice, &lhsNative);
144+
if (ret != UR_RESULT_SUCCESS) {
145+
throw ret;
146+
}
137147
}
138-
ret = urDeviceGetNativeHandle(rhs.hDevice, &rhsNative);
139-
if (ret != UR_RESULT_SUCCESS) {
140-
throw ret;
148+
149+
if (rhs.hDevice) {
150+
auto ret = urDeviceGetNativeHandle(rhs.hDevice, &rhsNative);
151+
if (ret != UR_RESULT_SUCCESS) {
152+
throw ret;
153+
}
141154
}
142155

143156
return lhsNative == rhsNative && lhs.type == rhs.type &&
@@ -146,16 +159,12 @@ inline bool pool_descriptor::equal(const pool_descriptor &lhs,
146159
lhs.poolHandle == rhs.poolHandle;
147160
}
148161

149-
inline std::size_t pool_descriptor::hash(const pool_descriptor &desc) {
150-
ur_native_handle_t native;
151-
auto ret = urDeviceGetNativeHandle(desc.hDevice, &native);
152-
if (ret != UR_RESULT_SUCCESS) {
153-
throw ret;
154-
}
155-
156-
return combine_hashes(0, desc.type, native,
157-
isSharedAllocationReadOnlyOnDevice(desc),
158-
desc.poolHandle);
162+
inline std::ostream &operator<<(std::ostream &os, const pool_descriptor &desc) {
163+
os << "pool handle: " << desc.poolHandle
164+
<< " context handle: " << desc.hContext
165+
<< " device handle: " << desc.hDevice << " memory type: " << desc.type
166+
<< " is read only: " << desc.deviceReadOnly;
167+
return os;
159168
}
160169

161170
inline std::pair<ur_result_t, std::vector<pool_descriptor>>
@@ -177,6 +186,7 @@ pool_descriptor::create(ur_usm_pool_handle_t poolHandle,
177186
pool_descriptor &desc = descriptors.emplace_back();
178187
desc.poolHandle = poolHandle;
179188
desc.hContext = hContext;
189+
desc.hDevice = device;
180190
desc.type = UR_USM_TYPE_DEVICE;
181191
}
182192
{
@@ -200,6 +210,69 @@ pool_descriptor::create(ur_usm_pool_handle_t poolHandle,
200210
return {ret, descriptors};
201211
}
202212

213+
template <typename D> struct pool_manager {
214+
private:
215+
using desc_to_pool_map_t = std::unordered_map<D, umf::pool_unique_handle_t>;
216+
217+
desc_to_pool_map_t descToPoolMap;
218+
219+
public:
220+
static std::pair<ur_result_t, pool_manager>
221+
create(desc_to_pool_map_t descToHandleMap = {}) {
222+
auto manager = pool_manager();
223+
224+
for (auto &[desc, hPool] : descToHandleMap) {
225+
auto ret = manager.addPool(desc, hPool);
226+
if (ret != UR_RESULT_SUCCESS) {
227+
return {ret, pool_manager()};
228+
}
229+
}
230+
231+
return {UR_RESULT_SUCCESS, std::move(manager)};
232+
}
233+
234+
ur_result_t addPool(const D &desc,
235+
umf::pool_unique_handle_t &hPool) noexcept {
236+
if (!descToPoolMap.try_emplace(desc, std::move(hPool)).second) {
237+
logger::error("Pool for pool descriptor: {}, already exists", desc);
238+
return UR_RESULT_ERROR_INVALID_ARGUMENT;
239+
}
240+
241+
return UR_RESULT_SUCCESS;
242+
}
243+
244+
std::optional<umf_memory_pool_handle_t> getPool(const D &desc) noexcept {
245+
auto it = descToPoolMap.find(desc);
246+
if (it == descToPoolMap.end()) {
247+
logger::error("Pool descriptor doesn't match any existing pool: {}",
248+
desc);
249+
return std::nullopt;
250+
}
251+
252+
return it->second.get();
253+
}
254+
};
255+
203256
} // namespace usm
204257

258+
namespace std {
259+
/// @brief hash specialization for usm::pool_descriptor
260+
template <> struct hash<usm::pool_descriptor> {
261+
inline size_t operator()(const usm::pool_descriptor &desc) const {
262+
ur_native_handle_t native = nullptr;
263+
if (desc.hDevice) {
264+
auto ret = urDeviceGetNativeHandle(desc.hDevice, &native);
265+
if (ret != UR_RESULT_SUCCESS) {
266+
throw ret;
267+
}
268+
}
269+
270+
return combine_hashes(0, desc.type, native,
271+
isSharedAllocationReadOnlyOnDevice(desc),
272+
desc.poolHandle);
273+
}
274+
};
275+
276+
} // namespace std
277+
205278
#endif /* USM_POOL_MANAGER_HPP */

0 commit comments

Comments
 (0)