@@ -379,9 +379,6 @@ ur_result_t USMHostMemoryProvider::allocateImpl(void **ResultPtr, size_t Size,
379
379
ur_usm_pool_handle_t_::ur_usm_pool_handle_t_ (ur_context_handle_t Context,
380
380
ur_usm_pool_desc_t *PoolDesc)
381
381
: Context{Context} {
382
- if (PoolDesc->flags & UR_USM_POOL_FLAG_USE_NATIVE_MEMORY_POOL_EXP) {
383
- // TODO: this should only use the host
384
- }
385
382
const void *pNext = PoolDesc->pNext ;
386
383
while (pNext != nullptr ) {
387
384
const ur_base_desc_t *BaseDesc = static_cast <const ur_base_desc_t *>(pNext);
@@ -436,62 +433,13 @@ ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t Context,
436
433
ur_usm_pool_handle_t_::ur_usm_pool_handle_t_ (ur_context_handle_t Context,
437
434
ur_device_handle_t Device,
438
435
ur_usm_pool_desc_t *PoolDesc)
439
- : Context{Context} {
440
- if (PoolDesc->flags & UR_USM_POOL_FLAG_USE_NATIVE_MEMORY_POOL_EXP) {
441
- // TODO: this should only use the host
442
- }
443
- const void *pNext = PoolDesc->pNext ;
444
- while (pNext != nullptr ) {
445
- const ur_base_desc_t *BaseDesc = static_cast <const ur_base_desc_t *>(pNext);
446
- switch (BaseDesc->stype ) {
447
- case UR_STRUCTURE_TYPE_USM_POOL_LIMITS_DESC: {
448
- const ur_usm_pool_limits_desc_t *Limits =
449
- reinterpret_cast <const ur_usm_pool_limits_desc_t *>(BaseDesc);
450
- for (auto &config : DisjointPoolConfigs.Configs ) {
451
- config.MaxPoolableSize = Limits->maxPoolableSize ;
452
- config.SlabMinSize = Limits->minDriverAllocSize ;
453
- }
454
- break ;
455
- }
456
- default : {
457
- throw UsmAllocationException (UR_RESULT_ERROR_INVALID_ARGUMENT);
458
- }
459
- }
460
- pNext = BaseDesc->pNext ;
461
- }
462
-
463
- auto MemProvider =
464
- umf::memoryProviderMakeUnique<USMHostMemoryProvider>(Context, nullptr )
465
- .second ;
466
-
467
- auto UmfHostParamsHandle = getUmfParamsHandle (
468
- DisjointPoolConfigs.Configs [usm::DisjointPoolMemType::Host]);
469
- HostMemPool =
470
- umf::poolMakeUniqueFromOps (umfDisjointPoolOps (), std::move (MemProvider),
471
- UmfHostParamsHandle.get ())
472
- .second ;
436
+ : Context{Context}, Device{Device} {
437
+ if (!(PoolDesc->flags & UR_USM_POOL_FLAG_USE_NATIVE_MEMORY_POOL_EXP))
438
+ throw ;
473
439
474
- for (const auto &Device : Context->getDevices ()) {
475
- MemProvider =
476
- umf::memoryProviderMakeUnique<USMDeviceMemoryProvider>(Context, Device)
477
- .second ;
478
- auto UmfDeviceParamsHandle = getUmfParamsHandle (
479
- DisjointPoolConfigs.Configs [usm::DisjointPoolMemType::Device]);
480
- DeviceMemPool =
481
- umf::poolMakeUniqueFromOps (umfDisjointPoolOps (), std::move (MemProvider),
482
- UmfDeviceParamsHandle.get ())
483
- .second ;
484
- MemProvider =
485
- umf::memoryProviderMakeUnique<USMSharedMemoryProvider>(Context, Device)
486
- .second ;
487
- auto UmfSharedParamsHandle = getUmfParamsHandle (
488
- DisjointPoolConfigs.Configs [usm::DisjointPoolMemType::Shared]);
489
- SharedMemPool =
490
- umf::poolMakeUniqueFromOps (umfDisjointPoolOps (), std::move (MemProvider),
491
- UmfSharedParamsHandle.get ())
492
- .second ;
493
- Context->addPool (this );
494
- }
440
+ // TODO: what flags should be used here. Moreover what flags should have
441
+ // UR counterparts?
442
+ UR_CHECK_ERROR (cuMemPoolCreate (&CUmemPool, 0 ));
495
443
}
496
444
497
445
bool ur_usm_pool_handle_t_::hasUMFPool (umf_memory_pool_t *umf_pool) {
0 commit comments