Skip to content

Commit 0279879

Browse files
authored
[UR][CUDA] Cleanup USM allocations code (#18052)
The `Impl` functions were each called only once and did something very similar to the regular code path, so we can fold them into the UR calls. We can also use the `isPowerOf2` utility function for alignment validation.
1 parent df34a2d commit 0279879

File tree

1 file changed

+25
-78
lines changed
  • unified-runtime/source/adapters/cuda

1 file changed

+25
-78
lines changed

unified-runtime/source/adapters/cuda/usm.cpp

Lines changed: 25 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -28,18 +28,17 @@ UR_APIEXPORT ur_result_t UR_APICALL
2828
urUSMHostAlloc(ur_context_handle_t hContext, const ur_usm_desc_t *pUSMDesc,
2929
ur_usm_pool_handle_t hPool, size_t size, void **ppMem) {
3030
auto alignment = pUSMDesc ? pUSMDesc->align : 0u;
31-
UR_ASSERT(!pUSMDesc ||
32-
(alignment == 0 || ((alignment & (alignment - 1)) == 0)),
33-
UR_RESULT_ERROR_INVALID_VALUE);
3431

35-
if (!hPool) {
36-
return USMHostAllocImpl(ppMem, hContext, /* flags */ 0, size, alignment);
32+
auto pool = hPool ? hPool->HostMemPool.get() : hContext->MemoryPoolHost;
33+
if (alignment) {
34+
UR_ASSERT(isPowerOf2(alignment), UR_RESULT_ERROR_INVALID_VALUE);
35+
*ppMem = umfPoolAlignedMalloc(pool, size, alignment);
36+
} else {
37+
*ppMem = umfPoolMalloc(pool, size);
3738
}
3839

39-
auto UMFPool = hPool->HostMemPool.get();
40-
*ppMem = umfPoolAlignedMalloc(UMFPool, size, alignment);
4140
if (*ppMem == nullptr) {
42-
auto umfErr = umfPoolGetLastAllocationError(UMFPool);
41+
auto umfErr = umfPoolGetLastAllocationError(pool);
4342
return umf::umf2urResult(umfErr);
4443
}
4544
return UR_RESULT_SUCCESS;
@@ -48,23 +47,22 @@ urUSMHostAlloc(ur_context_handle_t hContext, const ur_usm_desc_t *pUSMDesc,
4847
/// USM: Implements USM device allocations using a normal CUDA device pointer
4948
///
5049
UR_APIEXPORT ur_result_t UR_APICALL
51-
urUSMDeviceAlloc(ur_context_handle_t hContext, ur_device_handle_t hDevice,
50+
urUSMDeviceAlloc(ur_context_handle_t, ur_device_handle_t hDevice,
5251
const ur_usm_desc_t *pUSMDesc, ur_usm_pool_handle_t hPool,
5352
size_t size, void **ppMem) {
5453
auto alignment = pUSMDesc ? pUSMDesc->align : 0u;
55-
UR_ASSERT(!pUSMDesc ||
56-
(alignment == 0 || ((alignment & (alignment - 1)) == 0)),
57-
UR_RESULT_ERROR_INVALID_VALUE);
5854

59-
if (!hPool) {
60-
return USMDeviceAllocImpl(ppMem, hContext, hDevice, /* flags */ 0, size,
61-
alignment);
55+
ScopedContext SC(hDevice);
56+
auto pool = hPool ? hPool->DeviceMemPool.get() : hDevice->MemoryPoolDevice;
57+
if (alignment) {
58+
UR_ASSERT(isPowerOf2(alignment), UR_RESULT_ERROR_INVALID_VALUE);
59+
*ppMem = umfPoolAlignedMalloc(pool, size, alignment);
60+
} else {
61+
*ppMem = umfPoolMalloc(pool, size);
6262
}
6363

64-
auto UMFPool = hPool->DeviceMemPool.get();
65-
*ppMem = umfPoolAlignedMalloc(UMFPool, size, alignment);
6664
if (*ppMem == nullptr) {
67-
auto umfErr = umfPoolGetLastAllocationError(UMFPool);
65+
auto umfErr = umfPoolGetLastAllocationError(pool);
6866
return umf::umf2urResult(umfErr);
6967
}
7068
return UR_RESULT_SUCCESS;
@@ -73,23 +71,22 @@ urUSMDeviceAlloc(ur_context_handle_t hContext, ur_device_handle_t hDevice,
7371
/// USM: Implements USM Shared allocations using CUDA Managed Memory
7472
///
7573
UR_APIEXPORT ur_result_t UR_APICALL
76-
urUSMSharedAlloc(ur_context_handle_t hContext, ur_device_handle_t hDevice,
74+
urUSMSharedAlloc(ur_context_handle_t, ur_device_handle_t hDevice,
7775
const ur_usm_desc_t *pUSMDesc, ur_usm_pool_handle_t hPool,
7876
size_t size, void **ppMem) {
7977
auto alignment = pUSMDesc ? pUSMDesc->align : 0u;
80-
UR_ASSERT(!pUSMDesc ||
81-
(alignment == 0 || ((alignment & (alignment - 1)) == 0)),
82-
UR_RESULT_ERROR_INVALID_VALUE);
8378

84-
if (!hPool) {
85-
return USMSharedAllocImpl(ppMem, hContext, hDevice, /*host flags*/ 0,
86-
/*device flags*/ 0, size, alignment);
79+
ScopedContext SC(hDevice);
80+
auto pool = hPool ? hPool->SharedMemPool.get() : hDevice->MemoryPoolShared;
81+
if (alignment) {
82+
UR_ASSERT(isPowerOf2(alignment), UR_RESULT_ERROR_INVALID_VALUE);
83+
*ppMem = umfPoolAlignedMalloc(pool, size, alignment);
84+
} else {
85+
*ppMem = umfPoolMalloc(pool, size);
8786
}
8887

89-
auto UMFPool = hPool->SharedMemPool.get();
90-
*ppMem = umfPoolAlignedMalloc(UMFPool, size, alignment);
9188
if (*ppMem == nullptr) {
92-
auto umfErr = umfPoolGetLastAllocationError(UMFPool);
89+
auto umfErr = umfPoolGetLastAllocationError(pool);
9390
return umf::umf2urResult(umfErr);
9491
}
9592
return UR_RESULT_SUCCESS;
@@ -103,56 +100,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMFree(ur_context_handle_t hContext,
103100
return umf::umf2urResult(umfFree(pMem));
104101
}
105102

106-
ur_result_t USMDeviceAllocImpl(void **ResultPtr, ur_context_handle_t,
107-
ur_device_handle_t Device,
108-
ur_usm_device_mem_flags_t, size_t Size,
109-
[[maybe_unused]] uint32_t Alignment) {
110-
try {
111-
ScopedContext Active(Device);
112-
*ResultPtr = umfPoolMalloc(Device->MemoryPoolDevice, Size);
113-
UMF_CHECK_PTR(*ResultPtr);
114-
} catch (ur_result_t Err) {
115-
return Err;
116-
}
117-
118-
assert((Alignment == 0 ||
119-
reinterpret_cast<std::uintptr_t>(*ResultPtr) % Alignment == 0));
120-
return UR_RESULT_SUCCESS;
121-
}
122-
123-
ur_result_t USMSharedAllocImpl(void **ResultPtr, ur_context_handle_t,
124-
ur_device_handle_t Device,
125-
ur_usm_host_mem_flags_t,
126-
ur_usm_device_mem_flags_t, size_t Size,
127-
[[maybe_unused]] uint32_t Alignment) {
128-
try {
129-
ScopedContext Active(Device);
130-
*ResultPtr = umfPoolMalloc(Device->MemoryPoolShared, Size);
131-
UMF_CHECK_PTR(*ResultPtr);
132-
} catch (ur_result_t Err) {
133-
return Err;
134-
}
135-
136-
assert((Alignment == 0 ||
137-
reinterpret_cast<std::uintptr_t>(*ResultPtr) % Alignment == 0));
138-
return UR_RESULT_SUCCESS;
139-
}
140-
141-
ur_result_t USMHostAllocImpl(void **ResultPtr, ur_context_handle_t hContext,
142-
ur_usm_host_mem_flags_t, size_t Size,
143-
[[maybe_unused]] uint32_t Alignment) {
144-
try {
145-
*ResultPtr = umfPoolMalloc(hContext->MemoryPoolHost, Size);
146-
UMF_CHECK_PTR(*ResultPtr);
147-
} catch (ur_result_t Err) {
148-
return Err;
149-
}
150-
151-
assert((Alignment == 0 ||
152-
reinterpret_cast<std::uintptr_t>(*ResultPtr) % Alignment == 0));
153-
return UR_RESULT_SUCCESS;
154-
}
155-
156103
UR_APIEXPORT ur_result_t UR_APICALL
157104
urUSMGetMemAllocInfo(ur_context_handle_t hContext, const void *pMem,
158105
ur_usm_alloc_info_t propName, size_t propValueSize,

0 commit comments

Comments
 (0)