@@ -28,18 +28,17 @@ UR_APIEXPORT ur_result_t UR_APICALL
28
28
urUSMHostAlloc (ur_context_handle_t hContext, const ur_usm_desc_t *pUSMDesc,
29
29
ur_usm_pool_handle_t hPool, size_t size, void **ppMem) {
30
30
auto alignment = pUSMDesc ? pUSMDesc->align : 0u ;
31
- UR_ASSERT (!pUSMDesc ||
32
- (alignment == 0 || ((alignment & (alignment - 1 )) == 0 )),
33
- UR_RESULT_ERROR_INVALID_VALUE);
34
31
35
- if (!hPool) {
36
- return USMHostAllocImpl (ppMem, hContext, /* flags */ 0 , size, alignment);
32
+ auto pool = hPool ? hPool->HostMemPool .get () : hContext->MemoryPoolHost ;
33
+ if (alignment) {
34
+ UR_ASSERT (isPowerOf2 (alignment), UR_RESULT_ERROR_INVALID_VALUE);
35
+ *ppMem = umfPoolAlignedMalloc (pool, size, alignment);
36
+ } else {
37
+ *ppMem = umfPoolMalloc (pool, size);
37
38
}
38
39
39
- auto UMFPool = hPool->HostMemPool .get ();
40
- *ppMem = umfPoolAlignedMalloc (UMFPool, size, alignment);
41
40
if (*ppMem == nullptr ) {
42
- auto umfErr = umfPoolGetLastAllocationError (UMFPool );
41
+ auto umfErr = umfPoolGetLastAllocationError (pool );
43
42
return umf::umf2urResult (umfErr);
44
43
}
45
44
return UR_RESULT_SUCCESS;
@@ -48,23 +47,22 @@ urUSMHostAlloc(ur_context_handle_t hContext, const ur_usm_desc_t *pUSMDesc,
48
47
// / USM: Implements USM device allocations using a normal CUDA device pointer
49
48
// /
50
49
UR_APIEXPORT ur_result_t UR_APICALL
51
- urUSMDeviceAlloc (ur_context_handle_t hContext , ur_device_handle_t hDevice,
50
+ urUSMDeviceAlloc (ur_context_handle_t , ur_device_handle_t hDevice,
52
51
const ur_usm_desc_t *pUSMDesc, ur_usm_pool_handle_t hPool,
53
52
size_t size, void **ppMem) {
54
53
auto alignment = pUSMDesc ? pUSMDesc->align : 0u ;
55
- UR_ASSERT (!pUSMDesc ||
56
- (alignment == 0 || ((alignment & (alignment - 1 )) == 0 )),
57
- UR_RESULT_ERROR_INVALID_VALUE);
58
54
59
- if (!hPool) {
60
- return USMDeviceAllocImpl (ppMem, hContext, hDevice, /* flags */ 0 , size,
61
- alignment);
55
+ ScopedContext SC (hDevice);
56
+ auto pool = hPool ? hPool->DeviceMemPool .get () : hDevice->MemoryPoolDevice ;
57
+ if (alignment) {
58
+ UR_ASSERT (isPowerOf2 (alignment), UR_RESULT_ERROR_INVALID_VALUE);
59
+ *ppMem = umfPoolAlignedMalloc (pool, size, alignment);
60
+ } else {
61
+ *ppMem = umfPoolMalloc (pool, size);
62
62
}
63
63
64
- auto UMFPool = hPool->DeviceMemPool .get ();
65
- *ppMem = umfPoolAlignedMalloc (UMFPool, size, alignment);
66
64
if (*ppMem == nullptr ) {
67
- auto umfErr = umfPoolGetLastAllocationError (UMFPool );
65
+ auto umfErr = umfPoolGetLastAllocationError (pool );
68
66
return umf::umf2urResult (umfErr);
69
67
}
70
68
return UR_RESULT_SUCCESS;
@@ -73,23 +71,22 @@ urUSMDeviceAlloc(ur_context_handle_t hContext, ur_device_handle_t hDevice,
73
71
// / USM: Implements USM Shared allocations using CUDA Managed Memory
74
72
// /
75
73
UR_APIEXPORT ur_result_t UR_APICALL
76
- urUSMSharedAlloc (ur_context_handle_t hContext , ur_device_handle_t hDevice,
74
+ urUSMSharedAlloc (ur_context_handle_t , ur_device_handle_t hDevice,
77
75
const ur_usm_desc_t *pUSMDesc, ur_usm_pool_handle_t hPool,
78
76
size_t size, void **ppMem) {
79
77
auto alignment = pUSMDesc ? pUSMDesc->align : 0u ;
80
- UR_ASSERT (!pUSMDesc ||
81
- (alignment == 0 || ((alignment & (alignment - 1 )) == 0 )),
82
- UR_RESULT_ERROR_INVALID_VALUE);
83
78
84
- if (!hPool) {
85
- return USMSharedAllocImpl (ppMem, hContext, hDevice, /* host flags*/ 0 ,
86
- /* device flags*/ 0 , size, alignment);
79
+ ScopedContext SC (hDevice);
80
+ auto pool = hPool ? hPool->SharedMemPool .get () : hDevice->MemoryPoolShared ;
81
+ if (alignment) {
82
+ UR_ASSERT (isPowerOf2 (alignment), UR_RESULT_ERROR_INVALID_VALUE);
83
+ *ppMem = umfPoolAlignedMalloc (pool, size, alignment);
84
+ } else {
85
+ *ppMem = umfPoolMalloc (pool, size);
87
86
}
88
87
89
- auto UMFPool = hPool->SharedMemPool .get ();
90
- *ppMem = umfPoolAlignedMalloc (UMFPool, size, alignment);
91
88
if (*ppMem == nullptr ) {
92
- auto umfErr = umfPoolGetLastAllocationError (UMFPool );
89
+ auto umfErr = umfPoolGetLastAllocationError (pool );
93
90
return umf::umf2urResult (umfErr);
94
91
}
95
92
return UR_RESULT_SUCCESS;
@@ -103,56 +100,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMFree(ur_context_handle_t hContext,
103
100
return umf::umf2urResult (umfFree (pMem));
104
101
}
105
102
106
- ur_result_t USMDeviceAllocImpl (void **ResultPtr, ur_context_handle_t ,
107
- ur_device_handle_t Device,
108
- ur_usm_device_mem_flags_t , size_t Size,
109
- [[maybe_unused]] uint32_t Alignment) {
110
- try {
111
- ScopedContext Active (Device);
112
- *ResultPtr = umfPoolMalloc (Device->MemoryPoolDevice , Size);
113
- UMF_CHECK_PTR (*ResultPtr);
114
- } catch (ur_result_t Err) {
115
- return Err;
116
- }
117
-
118
- assert ((Alignment == 0 ||
119
- reinterpret_cast <std::uintptr_t >(*ResultPtr) % Alignment == 0 ));
120
- return UR_RESULT_SUCCESS;
121
- }
122
-
123
- ur_result_t USMSharedAllocImpl (void **ResultPtr, ur_context_handle_t ,
124
- ur_device_handle_t Device,
125
- ur_usm_host_mem_flags_t ,
126
- ur_usm_device_mem_flags_t , size_t Size,
127
- [[maybe_unused]] uint32_t Alignment) {
128
- try {
129
- ScopedContext Active (Device);
130
- *ResultPtr = umfPoolMalloc (Device->MemoryPoolShared , Size);
131
- UMF_CHECK_PTR (*ResultPtr);
132
- } catch (ur_result_t Err) {
133
- return Err;
134
- }
135
-
136
- assert ((Alignment == 0 ||
137
- reinterpret_cast <std::uintptr_t >(*ResultPtr) % Alignment == 0 ));
138
- return UR_RESULT_SUCCESS;
139
- }
140
-
141
- ur_result_t USMHostAllocImpl (void **ResultPtr, ur_context_handle_t hContext,
142
- ur_usm_host_mem_flags_t , size_t Size,
143
- [[maybe_unused]] uint32_t Alignment) {
144
- try {
145
- *ResultPtr = umfPoolMalloc (hContext->MemoryPoolHost , Size);
146
- UMF_CHECK_PTR (*ResultPtr);
147
- } catch (ur_result_t Err) {
148
- return Err;
149
- }
150
-
151
- assert ((Alignment == 0 ||
152
- reinterpret_cast <std::uintptr_t >(*ResultPtr) % Alignment == 0 ));
153
- return UR_RESULT_SUCCESS;
154
- }
155
-
156
103
UR_APIEXPORT ur_result_t UR_APICALL
157
104
urUSMGetMemAllocInfo (ur_context_handle_t hContext, const void *pMem,
158
105
ur_usm_alloc_info_t propName, size_t propValueSize,
0 commit comments