Skip to content

Commit 4fbfbc4

Browse files
committed
Fix compilation
1 parent 7cc5b39 commit 4fbfbc4

File tree

2 files changed

+11
-58
lines changed

2 files changed

+11
-58
lines changed

source/adapters/cuda/usm.cpp

Lines changed: 6 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -379,9 +379,6 @@ ur_result_t USMHostMemoryProvider::allocateImpl(void **ResultPtr, size_t Size,
379379
ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t Context,
380380
ur_usm_pool_desc_t *PoolDesc)
381381
: Context{Context} {
382-
if (PoolDesc->flags & UR_USM_POOL_FLAG_USE_NATIVE_MEMORY_POOL_EXP) {
383-
// TODO: this should only use the host
384-
}
385382
const void *pNext = PoolDesc->pNext;
386383
while (pNext != nullptr) {
387384
const ur_base_desc_t *BaseDesc = static_cast<const ur_base_desc_t *>(pNext);
@@ -436,62 +433,13 @@ ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t Context,
436433
ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t Context,
437434
ur_device_handle_t Device,
438435
ur_usm_pool_desc_t *PoolDesc)
439-
: Context{Context} {
440-
if (PoolDesc->flags & UR_USM_POOL_FLAG_USE_NATIVE_MEMORY_POOL_EXP) {
441-
// TODO: this should only use the host
442-
}
443-
const void *pNext = PoolDesc->pNext;
444-
while (pNext != nullptr) {
445-
const ur_base_desc_t *BaseDesc = static_cast<const ur_base_desc_t *>(pNext);
446-
switch (BaseDesc->stype) {
447-
case UR_STRUCTURE_TYPE_USM_POOL_LIMITS_DESC: {
448-
const ur_usm_pool_limits_desc_t *Limits =
449-
reinterpret_cast<const ur_usm_pool_limits_desc_t *>(BaseDesc);
450-
for (auto &config : DisjointPoolConfigs.Configs) {
451-
config.MaxPoolableSize = Limits->maxPoolableSize;
452-
config.SlabMinSize = Limits->minDriverAllocSize;
453-
}
454-
break;
455-
}
456-
default: {
457-
throw UsmAllocationException(UR_RESULT_ERROR_INVALID_ARGUMENT);
458-
}
459-
}
460-
pNext = BaseDesc->pNext;
461-
}
462-
463-
auto MemProvider =
464-
umf::memoryProviderMakeUnique<USMHostMemoryProvider>(Context, nullptr)
465-
.second;
466-
467-
auto UmfHostParamsHandle = getUmfParamsHandle(
468-
DisjointPoolConfigs.Configs[usm::DisjointPoolMemType::Host]);
469-
HostMemPool =
470-
umf::poolMakeUniqueFromOps(umfDisjointPoolOps(), std::move(MemProvider),
471-
UmfHostParamsHandle.get())
472-
.second;
436+
: Context{Context}, Device{Device} {
437+
if (!(PoolDesc->flags & UR_USM_POOL_FLAG_USE_NATIVE_MEMORY_POOL_EXP))
438+
throw;
473439

474-
for (const auto &Device : Context->getDevices()) {
475-
MemProvider =
476-
umf::memoryProviderMakeUnique<USMDeviceMemoryProvider>(Context, Device)
477-
.second;
478-
auto UmfDeviceParamsHandle = getUmfParamsHandle(
479-
DisjointPoolConfigs.Configs[usm::DisjointPoolMemType::Device]);
480-
DeviceMemPool =
481-
umf::poolMakeUniqueFromOps(umfDisjointPoolOps(), std::move(MemProvider),
482-
UmfDeviceParamsHandle.get())
483-
.second;
484-
MemProvider =
485-
umf::memoryProviderMakeUnique<USMSharedMemoryProvider>(Context, Device)
486-
.second;
487-
auto UmfSharedParamsHandle = getUmfParamsHandle(
488-
DisjointPoolConfigs.Configs[usm::DisjointPoolMemType::Shared]);
489-
SharedMemPool =
490-
umf::poolMakeUniqueFromOps(umfDisjointPoolOps(), std::move(MemProvider),
491-
UmfSharedParamsHandle.get())
492-
.second;
493-
Context->addPool(this);
494-
}
440+
// TODO: what flags should be used here. Moreover what flags should have
441+
// UR counterparts?
442+
UR_CHECK_ERROR(cuMemPoolCreate(&CUmemPool, 0));
495443
}
496444

497445
bool ur_usm_pool_handle_t_::hasUMFPool(umf_memory_pool_t *umf_pool) {

source/adapters/cuda/usm.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ struct ur_usm_pool_handle_t_ {
2121
std::atomic_uint32_t RefCount = 1;
2222

2323
ur_context_handle_t Context = nullptr;
24+
ur_device_handle_t Device = nullptr;
2425

2526
usm::DisjointPoolAllConfigs DisjointPoolConfigs =
2627
usm::DisjointPoolAllConfigs();
@@ -34,6 +35,10 @@ struct ur_usm_pool_handle_t_ {
3435
ur_usm_pool_handle_t_(ur_context_handle_t Context,
3536
ur_usm_pool_desc_t *PoolDesc);
3637

38+
// TODO: do we need the context param?
39+
ur_usm_pool_handle_t_(ur_context_handle_t Context, ur_device_handle_t Device,
40+
ur_usm_pool_desc_t *PoolDesc);
41+
3742
uint32_t incrementReferenceCount() noexcept { return ++RefCount; }
3843

3944
uint32_t decrementReferenceCount() noexcept { return --RefCount; }

0 commit comments

Comments
 (0)