#ifndef USM_POOL_MANAGER_HPP
#define USM_POOL_MANAGER_HPP 1

+ #include "logger/ur_logger.hpp"
+ #include "umf_helpers.hpp"
+ #include "umf_pools/disjoint_pool.hpp"
#include "ur_api.h"
- #include "ur_pool_manager.hpp"
#include "ur_util.hpp"

+ #include <umf/memory_pool.h>
+ #include <umf/memory_provider.h>
+
#include <functional>
+ #include <unordered_map>
#include <vector>

namespace usm {
@@ -29,8 +35,9 @@ struct pool_descriptor {
    ur_usm_type_t type;
    bool deviceReadOnly;

-     static bool equal(const pool_descriptor &lhs, const pool_descriptor &rhs);
-     static std::size_t hash(const pool_descriptor &desc);
+     bool operator==(const pool_descriptor &other) const;
+     friend std::ostream &operator<<(std::ostream &os,
+                                     const pool_descriptor &desc);

    static std::pair<ur_result_t, std::vector<pool_descriptor>>
    create(ur_usm_pool_handle_t poolHandle, ur_context_handle_t hContext);
};
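Replacing the static equal/hash helpers with a member operator== (paired with the std::hash<usm::pool_descriptor> specialization added at the bottom of this diff) lets a pool_descriptor act directly as an std::unordered_map key, which is what the new pool_manager further down relies on. A minimal sketch of what this enables; the handles are left null here purely for illustration:

// Illustration only, not part of the diff: pool_descriptor as a hash-map key.
// With all handles null, hashing and comparison skip urDeviceGetNativeHandle().
std::unordered_map<usm::pool_descriptor, int> poolIndex;

usm::pool_descriptor desc{};         // value-initialized: null handles
desc.type = UR_USM_TYPE_DEVICE;
desc.deviceReadOnly = false;

poolIndex[desc] = 0;                            // hashed via std::hash<usm::pool_descriptor>
bool found = poolIndex.count(desc) == 1;        // matched via pool_descriptor::operator==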
@@ -45,8 +52,8 @@ urGetSubDevices(ur_device_handle_t hDevice) {
    }

    ur_device_partition_property_t prop;
-     prop.type = UR_DEVICE_PARTITION_EQUALLY;
-     prop.value.equally = nComputeUnits;
+     prop.type = UR_DEVICE_PARTITION_BY_CSLICE;
+     prop.value.affinity_domain = 0;

    ur_device_partition_properties_t properties{
        UR_STRUCTURE_TYPE_DEVICE_PARTITION_PROPERTIES,
@@ -75,10 +82,10 @@ urGetSubDevices(ur_device_handle_t hDevice) {

inline std::pair<ur_result_t, std::vector<ur_device_handle_t>>
urGetAllDevicesAndSubDevices(ur_context_handle_t hContext) {
-     size_t deviceCount;
+     size_t deviceCount = 0;
    auto ret = urContextGetInfo(hContext, UR_CONTEXT_INFO_NUM_DEVICES,
                                sizeof(deviceCount), &deviceCount, nullptr);
-     if (ret != UR_RESULT_SUCCESS) {
+     if (ret != UR_RESULT_SUCCESS || deviceCount == 0) {
        return {ret, {}};
    }

@@ -110,6 +117,11 @@ urGetAllDevicesAndSubDevices(ur_context_handle_t hContext) {
    for (size_t i = 0; i < deviceCount; i++) {
        ret = addPoolsForDevicesRec(devices[i]);
        if (ret != UR_RESULT_SUCCESS) {
+             if (ret == UR_RESULT_ERROR_UNSUPPORTED_FEATURE) {
+                 // Return main devices when sub-devices are unsupported.
+                 return {ret, std::move(devices)};
+             }
+
            return {ret, {}};
        }
    }
@@ -122,22 +134,28 @@ isSharedAllocationReadOnlyOnDevice(const pool_descriptor &desc) {
    return desc.type == UR_USM_TYPE_SHARED && desc.deviceReadOnly;
}

- inline bool pool_descriptor::equal(const pool_descriptor &lhs,
-                                    const pool_descriptor &rhs) {
-     ur_native_handle_t lhsNative, rhsNative;
+ inline bool pool_descriptor::operator==(const pool_descriptor &other) const {
+     const pool_descriptor &lhs = *this;
+     const pool_descriptor &rhs = other;
+     ur_native_handle_t lhsNative = nullptr, rhsNative = nullptr;

    // We want to share a memory pool for sub-devices and sub-sub devices.
    // Sub-devices and sub-sub-devices might be represented by different ur_device_handle_t but
    // they share the same native_handle_t (which is used by UMF provider).
    // Ref: https://github.com/intel/llvm/commit/86511c5dc84b5781dcfd828caadcb5cac157eae1
    // TODO: is this L0 specific?
-     auto ret = urDeviceGetNativeHandle(lhs.hDevice, &lhsNative);
-     if (ret != UR_RESULT_SUCCESS) {
-         throw ret;
+     if (lhs.hDevice) {
+         auto ret = urDeviceGetNativeHandle(lhs.hDevice, &lhsNative);
+         if (ret != UR_RESULT_SUCCESS) {
+             throw ret;
+         }
    }
-     ret = urDeviceGetNativeHandle(rhs.hDevice, &rhsNative);
-     if (ret != UR_RESULT_SUCCESS) {
-         throw ret;
+
+     if (rhs.hDevice) {
+         auto ret = urDeviceGetNativeHandle(rhs.hDevice, &rhsNative);
+         if (ret != UR_RESULT_SUCCESS) {
+             throw ret;
+         }
    }

    return lhsNative == rhsNative && lhs.type == rhs.type &&
@@ -146,16 +164,12 @@ inline bool pool_descriptor::equal(const pool_descriptor &lhs,
           lhs.poolHandle == rhs.poolHandle;
}

- inline std::size_t pool_descriptor::hash(const pool_descriptor &desc) {
-     ur_native_handle_t native;
-     auto ret = urDeviceGetNativeHandle(desc.hDevice, &native);
-     if (ret != UR_RESULT_SUCCESS) {
-         throw ret;
-     }
-
-     return combine_hashes(0, desc.type, native,
-                           isSharedAllocationReadOnlyOnDevice(desc),
-                           desc.poolHandle);
+ inline std::ostream &operator<<(std::ostream &os, const pool_descriptor &desc) {
+     os << "pool handle: " << desc.poolHandle
+        << " context handle: " << desc.hContext
+        << " device handle: " << desc.hDevice << " memory type: " << desc.type
+        << " is read only: " << desc.deviceReadOnly;
+     return os;
}

inline std::pair<ur_result_t, std::vector<pool_descriptor>>
@@ -177,6 +191,7 @@ pool_descriptor::create(ur_usm_pool_handle_t poolHandle,
        pool_descriptor &desc = descriptors.emplace_back();
        desc.poolHandle = poolHandle;
        desc.hContext = hContext;
+         desc.hDevice = device;
        desc.type = UR_USM_TYPE_DEVICE;
    }
    {
@@ -200,6 +215,69 @@ pool_descriptor::create(ur_usm_pool_handle_t poolHandle,
    return {ret, descriptors};
}

+ template <typename D> struct pool_manager {
+   private:
+     using desc_to_pool_map_t = std::unordered_map<D, umf::pool_unique_handle_t>;
+
+     desc_to_pool_map_t descToPoolMap;
+
+   public:
+     static std::pair<ur_result_t, pool_manager>
+     create(desc_to_pool_map_t descToHandleMap = {}) {
+         auto manager = pool_manager();
+
+         for (auto &[desc, hPool] : descToHandleMap) {
+             auto ret = manager.addPool(desc, hPool);
+             if (ret != UR_RESULT_SUCCESS) {
+                 return {ret, pool_manager()};
+             }
+         }
+
+         return {UR_RESULT_SUCCESS, std::move(manager)};
+     }
+
+     ur_result_t addPool(const D &desc,
+                         umf::pool_unique_handle_t &hPool) noexcept {
+         if (!descToPoolMap.try_emplace(desc, std::move(hPool)).second) {
+             logger::error("Pool for pool descriptor: {}, already exists", desc);
+             return UR_RESULT_ERROR_INVALID_ARGUMENT;
+         }
+
+         return UR_RESULT_SUCCESS;
+     }
+
+     std::optional<umf_memory_pool_handle_t> getPool(const D &desc) noexcept {
+         auto it = descToPoolMap.find(desc);
+         if (it == descToPoolMap.end()) {
+             logger::error("Pool descriptor doesn't match any existing pool: {}",
+                           desc);
+             return std::nullopt;
+         }
+
+         return it->second.get();
+     }
+ };
+
} // namespace usm

+ namespace std {
+ /// @brief hash specialization for usm::pool_descriptor
+ template <> struct hash<usm::pool_descriptor> {
+     inline size_t operator()(const usm::pool_descriptor &desc) const {
+         ur_native_handle_t native = nullptr;
+         if (desc.hDevice) {
+             auto ret = urDeviceGetNativeHandle(desc.hDevice, &native);
+             if (ret != UR_RESULT_SUCCESS) {
+                 throw ret;
+             }
+         }
+
+         return combine_hashes(0, desc.type, native,
+                               isSharedAllocationReadOnlyOnDevice(desc),
+                               desc.poolHandle);
+     }
+ };
+
+ } // namespace std
+
#endif /* USM_POOL_MANAGER_HPP */
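For context, here is a rough sketch of how the new pool_manager is meant to be driven together with pool_descriptor::create. It is an illustration only: makePoolForDescriptor is a hypothetical helper, since this diff does not show how a umf::pool_unique_handle_t is actually constructed for a given descriptor.

// Sketch, not part of the diff. makePoolForDescriptor() is hypothetical.
inline ur_result_t setupPools(ur_usm_pool_handle_t poolHandle,
                              ur_context_handle_t hContext) {
    auto [ret, descriptors] = usm::pool_descriptor::create(poolHandle, hContext);
    if (ret != UR_RESULT_SUCCESS) {
        return ret;
    }

    auto [managerRet, manager] =
        usm::pool_manager<usm::pool_descriptor>::create();
    if (managerRet != UR_RESULT_SUCCESS) {
        return managerRet;
    }

    for (auto &desc : descriptors) {
        // Hypothetical: build a UMF pool (e.g. a disjoint pool on top of a
        // device memory provider) for this descriptor.
        umf::pool_unique_handle_t pool = makePoolForDescriptor(desc);
        ret = manager.addPool(desc, pool); // manager takes ownership on success
        if (ret != UR_RESULT_SUCCESS) {
            return ret;
        }
    }

    // Later, at allocation time, look up the pool matching a descriptor.
    if (!descriptors.empty()) {
        if (auto hPool = manager.getPool(descriptors.front())) {
            void *ptr = umfPoolMalloc(*hPool, 1024);
            (void)ptr;
        }
    }

    return UR_RESULT_SUCCESS;
}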