@@ -187,8 +187,15 @@ static ur_result_t USMDeviceAllocImpl(void **ResultPtr,
187
187
ZeDesc.pNext = &RelaxedDesc;
188
188
}
189
189
190
- ZE2UR_CALL (zeMemAllocDevice, (Context->ZeContext , &ZeDesc, Size, Alignment,
191
- Device->ZeDevice , ResultPtr));
190
+ ze_result_t ZeResult =
191
+ zeMemAllocDevice (Context->ZeContext , &ZeDesc, Size, Alignment,
192
+ Device->ZeDevice , ResultPtr);
193
+ if (ZeResult != ZE_RESULT_SUCCESS) {
194
+ if (ZeResult == ZE_RESULT_ERROR_UNSUPPORTED_SIZE) {
195
+ return UR_RESULT_ERROR_INVALID_USM_SIZE;
196
+ }
197
+ return ze2urResult (ZeResult);
198
+ }
192
199
193
200
UR_ASSERT (Alignment == 0 ||
194
201
reinterpret_cast <std::uintptr_t >(*ResultPtr) % Alignment == 0 ,
@@ -226,8 +233,15 @@ static ur_result_t USMSharedAllocImpl(void **ResultPtr,
226
233
ZeDevDesc.pNext = &RelaxedDesc;
227
234
}
228
235
229
- ZE2UR_CALL (zeMemAllocShared, (Context->ZeContext , &ZeDevDesc, &ZeHostDesc,
230
- Size, Alignment, Device->ZeDevice , ResultPtr));
236
+ ze_result_t ZeResult =
237
+ zeMemAllocShared (Context->ZeContext , &ZeDevDesc, &ZeHostDesc, Size,
238
+ Alignment, Device->ZeDevice , ResultPtr);
239
+ if (ZeResult != ZE_RESULT_SUCCESS) {
240
+ if (ZeResult == ZE_RESULT_ERROR_UNSUPPORTED_SIZE) {
241
+ return UR_RESULT_ERROR_INVALID_USM_SIZE;
242
+ }
243
+ return ze2urResult (ZeResult);
244
+ }
231
245
232
246
UR_ASSERT (Alignment == 0 ||
233
247
reinterpret_cast <std::uintptr_t >(*ResultPtr) % Alignment == 0 ,
@@ -254,8 +268,14 @@ static ur_result_t USMHostAllocImpl(void **ResultPtr,
254
268
// TODO: translate PI properties to Level Zero flags
255
269
ZeStruct<ze_host_mem_alloc_desc_t > ZeHostDesc;
256
270
ZeHostDesc.flags = 0 ;
257
- ZE2UR_CALL (zeMemAllocHost,
258
- (Context->ZeContext , &ZeHostDesc, Size, Alignment, ResultPtr));
271
+ ze_result_t ZeResult = zeMemAllocHost (Context->ZeContext , &ZeHostDesc, Size,
272
+ Alignment, ResultPtr);
273
+ if (ZeResult != ZE_RESULT_SUCCESS) {
274
+ if (ZeResult == ZE_RESULT_ERROR_UNSUPPORTED_SIZE) {
275
+ return UR_RESULT_ERROR_INVALID_USM_SIZE;
276
+ }
277
+ return ze2urResult (ZeResult);
278
+ }
259
279
260
280
UR_ASSERT (Alignment == 0 ||
261
281
reinterpret_cast <std::uintptr_t >(*ResultPtr) % Alignment == 0 ,
@@ -599,6 +619,40 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMGetMemAllocInfo(
599
619
ZE2UR_CALL (zeMemGetAddressRange, (Context->ZeContext , Ptr, nullptr , &Size));
600
620
return ReturnValue (Size);
601
621
}
622
+ case UR_USM_ALLOC_INFO_POOL: {
623
+ auto UMFPool = umfPoolByPtr (Ptr);
624
+ if (!UMFPool) {
625
+ return UR_RESULT_ERROR_INVALID_VALUE;
626
+ }
627
+
628
+ std::shared_lock<ur_shared_mutex> ContextLock (Context->Mutex );
629
+
630
+ auto SearchMatchingPool =
631
+ [](std::unordered_map<ur_device_handle_t , umf::pool_unique_handle_t >
632
+ &PoolMap,
633
+ umf_memory_pool_handle_t UMFPool) {
634
+ for (auto &PoolPair : PoolMap) {
635
+ if (PoolPair.second .get () == UMFPool) {
636
+ return true ;
637
+ }
638
+ }
639
+ return false ;
640
+ };
641
+
642
+ for (auto &Pool : Context->UsmPoolHandles ) {
643
+ if (SearchMatchingPool (Pool->DeviceMemPools , UMFPool)) {
644
+ return ReturnValue (Pool);
645
+ }
646
+ if (SearchMatchingPool (Pool->SharedMemPools , UMFPool)) {
647
+ return ReturnValue (Pool);
648
+ }
649
+ if (Pool->HostMemPool .get () == UMFPool) {
650
+ return ReturnValue (Pool);
651
+ }
652
+ }
653
+
654
+ return UR_RESULT_ERROR_INVALID_VALUE;
655
+ }
602
656
default :
603
657
urPrint (" urUSMGetMemAllocInfo: unsupported ParamName\n " );
604
658
return UR_RESULT_ERROR_INVALID_VALUE;
@@ -748,6 +802,7 @@ ur_result_t L0HostMemoryProvider::allocateImpl(void **ResultPtr, size_t Size,
748
802
ur_usm_pool_handle_t_::ur_usm_pool_handle_t_ (ur_context_handle_t Context,
749
803
ur_usm_pool_desc_t *PoolDesc) {
750
804
805
+ this ->Context = Context;
751
806
zeroInit = static_cast <uint32_t >(PoolDesc->flags &
752
807
UR_USM_POOL_FLAG_ZERO_INITIALIZE_BLOCK);
753
808
@@ -831,6 +886,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMPoolCreate(
831
886
try {
832
887
*Pool = reinterpret_cast <ur_usm_pool_handle_t >(
833
888
new ur_usm_pool_handle_t_ (Context, PoolDesc));
889
+
890
+ std::shared_lock<ur_shared_mutex> ContextLock (Context->Mutex );
891
+ Context->UsmPoolHandles .insert (Context->UsmPoolHandles .cend (), *Pool);
892
+
834
893
} catch (const UsmAllocationException &Ex) {
835
894
return Ex.getError ();
836
895
}
@@ -848,6 +907,8 @@ ur_result_t
848
907
urUSMPoolRelease (ur_usm_pool_handle_t Pool // /< [in] pointer to USM memory pool
849
908
) {
850
909
if (Pool->RefCount .decrementAndTest ()) {
910
+ std::shared_lock<ur_shared_mutex> ContextLock (Pool->Context ->Mutex );
911
+ Pool->Context ->UsmPoolHandles .remove (Pool);
851
912
delete Pool;
852
913
}
853
914
return UR_RESULT_SUCCESS;
@@ -861,13 +922,19 @@ ur_result_t urUSMPoolGetInfo(
861
922
// /< property
862
923
size_t *PropSizeRet // /< [out] size in bytes returned in pool property value
863
924
) {
864
- std::ignore = Pool;
865
- std::ignore = PropName;
866
- std::ignore = PropSize;
867
- std::ignore = PropValue;
868
- std::ignore = PropSizeRet;
869
- urPrint (" [UR][L0] %s function not implemented!\n " , __FUNCTION__);
870
- return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
925
+ UrReturnHelper ReturnValue (PropSize, PropValue, PropSizeRet);
926
+
927
+ switch (PropName) {
928
+ case UR_USM_POOL_INFO_REFERENCE_COUNT: {
929
+ return ReturnValue (Pool->RefCount .load ());
930
+ }
931
+ case UR_USM_POOL_INFO_CONTEXT: {
932
+ return ReturnValue (Pool->Context );
933
+ }
934
+ default : {
935
+ return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION;
936
+ }
937
+ }
871
938
}
872
939
873
940
// If indirect access tracking is not enabled then this functions just performs
0 commit comments