@@ -127,16 +127,16 @@ makePool(usm::umf_disjoint_pool_config_t *poolParams,
127
127
128
128
if (!poolParams) {
129
129
auto [ret, poolHandle] = umf::poolMakeUniqueFromOps (
130
- umfProxyPoolOps (), std::move (provider), nullptr );
130
+ umfProxyPoolOps (), std::move (provider), nullptr , poolDescriptor );
131
131
if (ret != UMF_RESULT_SUCCESS)
132
132
throw umf::umf2urResult (ret);
133
133
return std::move (poolHandle);
134
134
} else {
135
135
auto umfParams = getUmfParamsHandle (*poolParams);
136
136
137
- auto [ret, poolHandle] =
138
- umf::poolMakeUniqueFromOps ( umfDisjointPoolOps (), std::move (provider),
139
- static_cast <void *>(umfParams.get ()));
137
+ auto [ret, poolHandle] = umf::poolMakeUniqueFromOps (
138
+ umfDisjointPoolOps (), std::move (provider),
139
+ static_cast <void *>(umfParams.get ()), poolDescriptor );
140
140
if (ret != UMF_RESULT_SUCCESS)
141
141
throw umf::umf2urResult (ret);
142
142
return std::move (poolHandle);
@@ -356,6 +356,19 @@ urUSMFree(ur_context_handle_t hContext, ///< [in] handle of the context object
356
356
return exceptionToResult (std::current_exception ());
357
357
}
358
358
359
+ static usm::pool_descriptor *getPoolDescriptor (const void *ptr) {
360
+ auto umfPool = umfPoolByPtr (ptr);
361
+ if (!umfPool) {
362
+ logger::error (" urUSMGetMemAllocInfo: no memory associated with given ptr" );
363
+ throw UR_RESULT_ERROR_INVALID_VALUE;
364
+ }
365
+
366
+ usm::pool_descriptor *poolDesc;
367
+ UMF_CALL_THROWS (umfPoolGetTag (umfPool, reinterpret_cast <void **>(&poolDesc)));
368
+
369
+ return poolDesc;
370
+ }
371
+
359
372
ur_result_t urUSMGetMemAllocInfo (
360
373
ur_context_handle_t hContext, // /< [in] handle of the context object
361
374
const void *ptr, // /< [in] pointer to USM memory object
@@ -367,48 +380,22 @@ ur_result_t urUSMGetMemAllocInfo(
367
380
size_t *pPropValueSizeRet // /< [out][optional] bytes returned in USM
368
381
// /< allocation property
369
382
) try {
370
- ze_device_handle_t zeDeviceHandle;
371
- ZeStruct<ze_memory_allocation_properties_t > zeMemoryAllocationProperties;
372
-
373
- // TODO: implement this using UMF once
374
- // https://github.com/oneapi-src/unified-memory-framework/issues/686
375
- // https://github.com/oneapi-src/unified-memory-framework/issues/687
376
- // are implemented
377
- ZE2UR_CALL (zeMemGetAllocProperties,
378
- (hContext->getZeHandle (), ptr, &zeMemoryAllocationProperties,
379
- &zeDeviceHandle));
380
383
381
384
UrReturnHelper ReturnValue (propValueSize, pPropValue, pPropValueSizeRet);
382
385
switch (propName) {
383
386
case UR_USM_ALLOC_INFO_TYPE: {
384
- ur_usm_type_t memAllocType;
385
- switch (zeMemoryAllocationProperties.type ) {
386
- case ZE_MEMORY_TYPE_UNKNOWN:
387
- memAllocType = UR_USM_TYPE_UNKNOWN;
388
- break ;
389
- case ZE_MEMORY_TYPE_HOST:
390
- memAllocType = UR_USM_TYPE_HOST;
391
- break ;
392
- case ZE_MEMORY_TYPE_DEVICE:
393
- memAllocType = UR_USM_TYPE_DEVICE;
394
- break ;
395
- case ZE_MEMORY_TYPE_SHARED:
396
- memAllocType = UR_USM_TYPE_SHARED;
397
- break ;
398
- default :
399
- logger::error (" urUSMGetMemAllocInfo: unexpected usm memory type" );
400
- return UR_RESULT_ERROR_INVALID_VALUE;
401
- }
402
- return ReturnValue (memAllocType);
387
+ auto poolDesc = getPoolDescriptor (ptr);
388
+
389
+ assert (poolDesc->type != UR_USM_TYPE_UNKNOWN);
390
+ return ReturnValue (poolDesc->type );
403
391
}
404
- case UR_USM_ALLOC_INFO_DEVICE:
405
- if (zeDeviceHandle) {
406
- auto Platform = hContext->getPlatform ();
407
- auto Device = Platform->getDeviceFromNativeHandle (zeDeviceHandle);
408
- return Device ? ReturnValue (Device) : UR_RESULT_ERROR_INVALID_VALUE;
409
- } else {
410
- return UR_RESULT_ERROR_INVALID_VALUE;
411
- }
392
+ case UR_USM_ALLOC_INFO_DEVICE: {
393
+ auto poolDesc = getPoolDescriptor (ptr);
394
+ return ReturnValue (poolDesc->hDevice );
395
+ }
396
+ // TODO: implement this using UMF once
397
+ // https://github.com/oneapi-src/unified-memory-framework/issues/686
398
+ // is implemented
412
399
case UR_USM_ALLOC_INFO_BASE_PTR: {
413
400
void *base;
414
401
ZE2UR_CALL (zeMemGetAddressRange,
@@ -422,9 +409,10 @@ ur_result_t urUSMGetMemAllocInfo(
422
409
return ReturnValue (size);
423
410
}
424
411
case UR_USM_ALLOC_INFO_POOL: {
425
- // TODO
426
- return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
427
- default :
412
+ auto poolDesc = getPoolDescriptor (ptr);
413
+ return ReturnValue (poolDesc->poolHandle );
414
+ }
415
+ default : {
428
416
logger::error (" urUSMGetMemAllocInfo: unsupported ParamName" );
429
417
return UR_RESULT_ERROR_INVALID_VALUE;
430
418
}
0 commit comments