Skip to content

Commit ecbe035

Browse files
committed
[L0 v2] use UMF for memory info queries
This avoid going to the driver and speeds up UR_USM_ALLOC_INFO_DEVICE query ~2 times in a microbenchmark.
1 parent 230d19e commit ecbe035

File tree

3 files changed

+72
-44
lines changed

3 files changed

+72
-44
lines changed

source/adapters/level_zero/v2/usm.cpp

Lines changed: 33 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -127,16 +127,16 @@ makePool(usm::umf_disjoint_pool_config_t *poolParams,
127127

128128
if (!poolParams) {
129129
auto [ret, poolHandle] = umf::poolMakeUniqueFromOps(
130-
umfProxyPoolOps(), std::move(provider), nullptr);
130+
umfProxyPoolOps(), std::move(provider), nullptr, poolDescriptor);
131131
if (ret != UMF_RESULT_SUCCESS)
132132
throw umf::umf2urResult(ret);
133133
return std::move(poolHandle);
134134
} else {
135135
auto umfParams = getUmfParamsHandle(*poolParams);
136136

137-
auto [ret, poolHandle] =
138-
umf::poolMakeUniqueFromOps(umfDisjointPoolOps(), std::move(provider),
139-
static_cast<void *>(umfParams.get()));
137+
auto [ret, poolHandle] = umf::poolMakeUniqueFromOps(
138+
umfDisjointPoolOps(), std::move(provider),
139+
static_cast<void *>(umfParams.get()), poolDescriptor);
140140
if (ret != UMF_RESULT_SUCCESS)
141141
throw umf::umf2urResult(ret);
142142
return std::move(poolHandle);
@@ -356,6 +356,19 @@ urUSMFree(ur_context_handle_t hContext, ///< [in] handle of the context object
356356
return exceptionToResult(std::current_exception());
357357
}
358358

359+
static usm::pool_descriptor *getPoolDescriptor(const void *ptr) {
360+
auto umfPool = umfPoolByPtr(ptr);
361+
if (!umfPool) {
362+
logger::error("urUSMGetMemAllocInfo: no memory associated with given ptr");
363+
throw UR_RESULT_ERROR_INVALID_VALUE;
364+
}
365+
366+
usm::pool_descriptor *poolDesc;
367+
UMF_CALL_THROWS(umfPoolGetTag(umfPool, reinterpret_cast<void **>(&poolDesc)));
368+
369+
return poolDesc;
370+
}
371+
359372
ur_result_t urUSMGetMemAllocInfo(
360373
ur_context_handle_t hContext, ///< [in] handle of the context object
361374
const void *ptr, ///< [in] pointer to USM memory object
@@ -367,48 +380,24 @@ ur_result_t urUSMGetMemAllocInfo(
367380
size_t *pPropValueSizeRet ///< [out][optional] bytes returned in USM
368381
///< allocation property
369382
) try {
370-
ze_device_handle_t zeDeviceHandle;
371-
ZeStruct<ze_memory_allocation_properties_t> zeMemoryAllocationProperties;
372-
373-
// TODO: implement this using UMF once
374-
// https://github.com/oneapi-src/unified-memory-framework/issues/686
375-
// https://github.com/oneapi-src/unified-memory-framework/issues/687
376-
// are implemented
377-
ZE2UR_CALL(zeMemGetAllocProperties,
378-
(hContext->getZeHandle(), ptr, &zeMemoryAllocationProperties,
379-
&zeDeviceHandle));
380383

381384
UrReturnHelper ReturnValue(propValueSize, pPropValue, pPropValueSizeRet);
382385
switch (propName) {
383386
case UR_USM_ALLOC_INFO_TYPE: {
384-
ur_usm_type_t memAllocType;
385-
switch (zeMemoryAllocationProperties.type) {
386-
case ZE_MEMORY_TYPE_UNKNOWN:
387-
memAllocType = UR_USM_TYPE_UNKNOWN;
388-
break;
389-
case ZE_MEMORY_TYPE_HOST:
390-
memAllocType = UR_USM_TYPE_HOST;
391-
break;
392-
case ZE_MEMORY_TYPE_DEVICE:
393-
memAllocType = UR_USM_TYPE_DEVICE;
394-
break;
395-
case ZE_MEMORY_TYPE_SHARED:
396-
memAllocType = UR_USM_TYPE_SHARED;
397-
break;
398-
default:
399-
logger::error("urUSMGetMemAllocInfo: unexpected usm memory type");
400-
return UR_RESULT_ERROR_INVALID_VALUE;
387+
try {
388+
auto poolDesc = getPoolDescriptor(ptr);
389+
return ReturnValue(poolDesc->type);
390+
} catch (...) {
391+
return ReturnValue(UR_USM_TYPE_UNKNOWN);
401392
}
402-
return ReturnValue(memAllocType);
403393
}
404-
case UR_USM_ALLOC_INFO_DEVICE:
405-
if (zeDeviceHandle) {
406-
auto Platform = hContext->getPlatform();
407-
auto Device = Platform->getDeviceFromNativeHandle(zeDeviceHandle);
408-
return Device ? ReturnValue(Device) : UR_RESULT_ERROR_INVALID_VALUE;
409-
} else {
410-
return UR_RESULT_ERROR_INVALID_VALUE;
411-
}
394+
case UR_USM_ALLOC_INFO_DEVICE: {
395+
auto poolDesc = getPoolDescriptor(ptr);
396+
return ReturnValue(poolDesc->hDevice);
397+
}
398+
// TODO: implement this using UMF once
399+
// https://github.com/oneapi-src/unified-memory-framework/issues/686
400+
// is implemented
412401
case UR_USM_ALLOC_INFO_BASE_PTR: {
413402
void *base;
414403
ZE2UR_CALL(zeMemGetAddressRange,
@@ -422,9 +411,10 @@ ur_result_t urUSMGetMemAllocInfo(
422411
return ReturnValue(size);
423412
}
424413
case UR_USM_ALLOC_INFO_POOL: {
425-
// TODO
426-
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
427-
default:
414+
auto poolDesc = getPoolDescriptor(ptr);
415+
return ReturnValue(poolDesc->poolHandle);
416+
}
417+
default: {
428418
logger::error("urUSMGetMemAllocInfo: unsupported ParamName");
429419
return UR_RESULT_ERROR_INVALID_VALUE;
430420
}

source/common/umf_helpers.hpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,45 @@ static inline auto poolMakeUniqueFromOps(umf_memory_pool_ops_t *ops,
222222
UMF_RESULT_SUCCESS, pool_unique_handle_t(hPool, umfPoolDestroy)};
223223
}
224224

225+
template <typename Tag>
226+
static inline auto poolMakeUniqueFromOps(umf_memory_pool_ops_t *ops,
227+
provider_unique_handle_t provider,
228+
void *params, const Tag &tag) {
229+
auto poolTag = new Tag(tag);
230+
231+
umf_memory_pool_handle_t hPool;
232+
auto ret = umfPoolCreate(ops, provider.get(), params,
233+
UMF_POOL_CREATE_FLAG_OWN_PROVIDER, &hPool);
234+
if (ret != UMF_RESULT_SUCCESS) {
235+
return std::pair<umf_result_t, pool_unique_handle_t>{
236+
ret, pool_unique_handle_t(nullptr, nullptr)};
237+
}
238+
239+
ret = umfPoolSetTag(hPool, poolTag, nullptr);
240+
if (ret != UMF_RESULT_SUCCESS) {
241+
umfPoolDestroy(hPool);
242+
return std::pair<umf_result_t, pool_unique_handle_t>{
243+
ret, pool_unique_handle_t(nullptr, nullptr)};
244+
}
245+
246+
provider.release(); // pool now owns the provider
247+
248+
return std::pair<umf_result_t, pool_unique_handle_t>{
249+
UMF_RESULT_SUCCESS,
250+
pool_unique_handle_t(hPool, [](umf_memory_pool_handle_t hPool) {
251+
Tag *tag = nullptr;
252+
umfPoolGetTag(hPool, reinterpret_cast<void **>(&tag));
253+
254+
if (tag) {
255+
delete tag;
256+
} else {
257+
logger::error("Failed to get tag from pool");
258+
}
259+
260+
umfPoolDestroy(hPool);
261+
})};
262+
}
263+
225264
static inline auto providerMakeUniqueFromOps(umf_memory_provider_ops_t *ops,
226265
void *params) {
227266
umf_memory_provider_handle_t hProvider;

test/conformance/usm/usm_adapter_level_zero_v2.match

Lines changed: 0 additions & 1 deletion
This file was deleted.

0 commit comments

Comments
 (0)