11
11
#ifndef USM_POOL_MANAGER_HPP
12
12
#define USM_POOL_MANAGER_HPP 1
13
13
14
+ #include " logger/ur_logger.hpp"
15
+ #include " umf_helpers.hpp"
16
+ #include " umf_pools/disjoint_pool.hpp"
14
17
#include " ur_api.h"
15
- #include " ur_pool_manager.hpp"
16
18
#include " ur_util.hpp"
17
19
20
+ #include < umf/memory_pool.h>
21
+ #include < umf/memory_provider.h>
22
+
18
23
#include < functional>
24
+ #include < unordered_map>
19
25
#include < vector>
20
26
21
27
namespace usm {
@@ -29,8 +35,9 @@ struct pool_descriptor {
29
35
ur_usm_type_t type;
30
36
bool deviceReadOnly;
31
37
32
- static bool equal (const pool_descriptor &lhs, const pool_descriptor &rhs);
33
- static std::size_t hash (const pool_descriptor &desc);
38
+ bool operator ==(const pool_descriptor &other) const ;
39
+ friend std::ostream &operator <<(std::ostream &os,
40
+ const pool_descriptor &desc);
34
41
static std::pair<ur_result_t , std::vector<pool_descriptor>>
35
42
create (ur_usm_pool_handle_t poolHandle, ur_context_handle_t hContext);
36
43
};
@@ -75,10 +82,10 @@ urGetSubDevices(ur_device_handle_t hDevice) {
75
82
76
83
inline std::pair<ur_result_t , std::vector<ur_device_handle_t >>
77
84
urGetAllDevicesAndSubDevices (ur_context_handle_t hContext) {
78
- size_t deviceCount;
85
+ size_t deviceCount = 0 ;
79
86
auto ret = urContextGetInfo (hContext, UR_CONTEXT_INFO_NUM_DEVICES,
80
87
sizeof (deviceCount), &deviceCount, nullptr );
81
- if (ret != UR_RESULT_SUCCESS) {
88
+ if (ret != UR_RESULT_SUCCESS || deviceCount == 0 ) {
82
89
return {ret, {}};
83
90
}
84
91
@@ -122,22 +129,28 @@ isSharedAllocationReadOnlyOnDevice(const pool_descriptor &desc) {
122
129
return desc.type == UR_USM_TYPE_SHARED && desc.deviceReadOnly ;
123
130
}
124
131
125
- inline bool pool_descriptor::equal (const pool_descriptor &lhs,
126
- const pool_descriptor &rhs) {
127
- ur_native_handle_t lhsNative, rhsNative;
132
+ inline bool pool_descriptor::operator ==(const pool_descriptor &other) const {
133
+ const pool_descriptor &lhs = *this ;
134
+ const pool_descriptor &rhs = other;
135
+ ur_native_handle_t lhsNative = nullptr , rhsNative = nullptr ;
128
136
129
137
// We want to share a memory pool for sub-devices and sub-sub devices.
130
138
// Sub-devices and sub-sub-devices might be represented by different ur_device_handle_t but
131
139
// they share the same native_handle_t (which is used by UMF provider).
132
140
// Ref: https://github.com/intel/llvm/commit/86511c5dc84b5781dcfd828caadcb5cac157eae1
133
141
// TODO: is this L0 specific?
134
- auto ret = urDeviceGetNativeHandle (lhs.hDevice , &lhsNative);
135
- if (ret != UR_RESULT_SUCCESS) {
136
- throw ret;
142
+ if (lhs.hDevice ) {
143
+ auto ret = urDeviceGetNativeHandle (lhs.hDevice , &lhsNative);
144
+ if (ret != UR_RESULT_SUCCESS) {
145
+ throw ret;
146
+ }
137
147
}
138
- ret = urDeviceGetNativeHandle (rhs.hDevice , &rhsNative);
139
- if (ret != UR_RESULT_SUCCESS) {
140
- throw ret;
148
+
149
+ if (rhs.hDevice ) {
150
+ auto ret = urDeviceGetNativeHandle (rhs.hDevice , &rhsNative);
151
+ if (ret != UR_RESULT_SUCCESS) {
152
+ throw ret;
153
+ }
141
154
}
142
155
143
156
return lhsNative == rhsNative && lhs.type == rhs.type &&
@@ -146,16 +159,12 @@ inline bool pool_descriptor::equal(const pool_descriptor &lhs,
146
159
lhs.poolHandle == rhs.poolHandle ;
147
160
}
148
161
149
- inline std::size_t pool_descriptor::hash (const pool_descriptor &desc) {
150
- ur_native_handle_t native;
151
- auto ret = urDeviceGetNativeHandle (desc.hDevice , &native);
152
- if (ret != UR_RESULT_SUCCESS) {
153
- throw ret;
154
- }
155
-
156
- return combine_hashes (0 , desc.type , native,
157
- isSharedAllocationReadOnlyOnDevice (desc),
158
- desc.poolHandle );
162
+ inline std::ostream &operator <<(std::ostream &os, const pool_descriptor &desc) {
163
+ os << " pool handle: " << desc.poolHandle
164
+ << " context handle: " << desc.hContext
165
+ << " device handle: " << desc.hDevice << " memory type: " << desc.type
166
+ << " is read only: " << desc.deviceReadOnly ;
167
+ return os;
159
168
}
160
169
161
170
inline std::pair<ur_result_t , std::vector<pool_descriptor>>
@@ -177,6 +186,7 @@ pool_descriptor::create(ur_usm_pool_handle_t poolHandle,
177
186
pool_descriptor &desc = descriptors.emplace_back ();
178
187
desc.poolHandle = poolHandle;
179
188
desc.hContext = hContext;
189
+ desc.hDevice = device;
180
190
desc.type = UR_USM_TYPE_DEVICE;
181
191
}
182
192
{
@@ -200,6 +210,69 @@ pool_descriptor::create(ur_usm_pool_handle_t poolHandle,
200
210
return {ret, descriptors};
201
211
}
202
212
213
+ template <typename D> struct pool_manager {
214
+ private:
215
+ using desc_to_pool_map_t = std::unordered_map<D, umf::pool_unique_handle_t >;
216
+
217
+ desc_to_pool_map_t descToPoolMap;
218
+
219
+ public:
220
+ static std::pair<ur_result_t , pool_manager>
221
+ create (desc_to_pool_map_t descToHandleMap = {}) {
222
+ auto manager = pool_manager ();
223
+
224
+ for (auto &[desc, hPool] : descToHandleMap) {
225
+ auto ret = manager.addPool (desc, hPool);
226
+ if (ret != UR_RESULT_SUCCESS) {
227
+ return {ret, pool_manager ()};
228
+ }
229
+ }
230
+
231
+ return {UR_RESULT_SUCCESS, std::move (manager)};
232
+ }
233
+
234
+ ur_result_t addPool (const D &desc,
235
+ umf::pool_unique_handle_t &hPool) noexcept {
236
+ if (!descToPoolMap.try_emplace (desc, std::move (hPool)).second ) {
237
+ logger::error (" Pool for pool descriptor: {}, already exists" , desc);
238
+ return UR_RESULT_ERROR_INVALID_ARGUMENT;
239
+ }
240
+
241
+ return UR_RESULT_SUCCESS;
242
+ }
243
+
244
+ std::optional<umf_memory_pool_handle_t > getPool (const D &desc) noexcept {
245
+ auto it = descToPoolMap.find (desc);
246
+ if (it == descToPoolMap.end ()) {
247
+ logger::error (" Pool descriptor doesn't match any existing pool: {}" ,
248
+ desc);
249
+ return std::nullopt;
250
+ }
251
+
252
+ return it->second .get ();
253
+ }
254
+ };
255
+
203
256
} // namespace usm
204
257
258
+ namespace std {
259
+ // / @brief hash specialization for usm::pool_descriptor
260
+ template <> struct hash <usm::pool_descriptor> {
261
+ inline size_t operator ()(const usm::pool_descriptor &desc) const {
262
+ ur_native_handle_t native = nullptr ;
263
+ if (desc.hDevice ) {
264
+ auto ret = urDeviceGetNativeHandle (desc.hDevice , &native);
265
+ if (ret != UR_RESULT_SUCCESS) {
266
+ throw ret;
267
+ }
268
+ }
269
+
270
+ return combine_hashes (0 , desc.type , native,
271
+ isSharedAllocationReadOnlyOnDevice (desc),
272
+ desc.poolHandle );
273
+ }
274
+ };
275
+
276
+ } // namespace std
277
+
205
278
#endif /* USM_POOL_MANAGER_HPP */
0 commit comments