From 01b85acc521a0ae59ea72fdc43b9db332c9c8460 Mon Sep 17 00:00:00 2001 From: Martin Morrison-Grant Date: Thu, 3 Jul 2025 10:22:21 +0100 Subject: [PATCH] Refactor reference counting in UR CUDA adapter using new ur::RefCount class. --- unified-runtime/source/adapters/cuda/adapter.cpp | 8 ++++---- unified-runtime/source/adapters/cuda/adapter.hpp | 3 ++- .../source/adapters/cuda/command_buffer.cpp | 6 +++--- .../source/adapters/cuda/command_buffer.hpp | 10 +++------- unified-runtime/source/adapters/cuda/context.cpp | 8 ++++---- unified-runtime/source/adapters/cuda/context.hpp | 9 ++------- unified-runtime/source/adapters/cuda/device.cpp | 2 +- unified-runtime/source/adapters/cuda/device.hpp | 8 ++++---- unified-runtime/source/adapters/cuda/event.cpp | 8 ++++---- unified-runtime/source/adapters/cuda/event.hpp | 8 ++------ unified-runtime/source/adapters/cuda/kernel.cpp | 10 +++++----- unified-runtime/source/adapters/cuda/kernel.hpp | 9 ++------- unified-runtime/source/adapters/cuda/memory.cpp | 8 ++++---- unified-runtime/source/adapters/cuda/memory.hpp | 10 ++-------- .../source/adapters/cuda/physical_mem.cpp | 6 +++--- .../source/adapters/cuda/physical_mem.hpp | 13 ++++--------- unified-runtime/source/adapters/cuda/program.cpp | 10 +++++----- unified-runtime/source/adapters/cuda/program.hpp | 9 ++------- unified-runtime/source/adapters/cuda/queue.cpp | 8 ++++---- unified-runtime/source/adapters/cuda/sampler.cpp | 8 ++++---- unified-runtime/source/adapters/cuda/sampler.hpp | 9 ++------- unified-runtime/source/adapters/cuda/usm.cpp | 6 +++--- unified-runtime/source/adapters/cuda/usm.hpp | 9 ++------- .../source/common/cuda-hip/stream_queue.hpp | 10 ++++++---- 24 files changed, 77 insertions(+), 118 deletions(-) diff --git a/unified-runtime/source/adapters/cuda/adapter.cpp b/unified-runtime/source/adapters/cuda/adapter.cpp index dca627c87fc19..35944e6b03682 100644 --- a/unified-runtime/source/adapters/cuda/adapter.cpp +++ b/unified-runtime/source/adapters/cuda/adapter.cpp 
@@ -66,7 +66,7 @@ urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters, std::call_once(InitFlag, [=]() { ur::cuda::adapter = new ur_adapter_handle_t_; }); - ur::cuda::adapter->RefCount++; + ur::cuda::adapter->RefCount.retain(); *phAdapters = ur::cuda::adapter; } @@ -78,13 +78,13 @@ urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters, } UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) { - ur::cuda::adapter->RefCount++; + ur::cuda::adapter->RefCount.retain(); return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) { - if (--ur::cuda::adapter->RefCount == 0) { + if (ur::cuda::adapter->RefCount.release()) { delete ur::cuda::adapter; } return UR_RESULT_SUCCESS; @@ -108,7 +108,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t, case UR_ADAPTER_INFO_BACKEND: return ReturnValue(UR_BACKEND_CUDA); case UR_ADAPTER_INFO_REFERENCE_COUNT: - return ReturnValue(ur::cuda::adapter->RefCount.load()); + return ReturnValue(ur::cuda::adapter->RefCount.getCount()); case UR_ADAPTER_INFO_VERSION: return ReturnValue(uint32_t{1}); default: diff --git a/unified-runtime/source/adapters/cuda/adapter.hpp b/unified-runtime/source/adapters/cuda/adapter.hpp index 6ec9007bceacf..5208c3dea216c 100644 --- a/unified-runtime/source/adapters/cuda/adapter.hpp +++ b/unified-runtime/source/adapters/cuda/adapter.hpp @@ -11,6 +11,7 @@ #ifndef UR_CUDA_ADAPTER_HPP_INCLUDED #define UR_CUDA_ADAPTER_HPP_INCLUDED +#include "common/ur_ref_count.hpp" #include "logger/ur_logger.hpp" #include "platform.hpp" #include "tracing.hpp" @@ -20,7 +21,7 @@ #include struct ur_adapter_handle_t_ : ur::cuda::handle_base { - std::atomic RefCount = 0; + ur::RefCount RefCount{0}; struct cuda_tracing_context_t_ *TracingCtx = nullptr; logger::Logger &logger; std::unique_ptr Platform; diff --git a/unified-runtime/source/adapters/cuda/command_buffer.cpp b/unified-runtime/source/adapters/cuda/command_buffer.cpp index 
b8567d02cd00f..52afeafd76e79 100644 --- a/unified-runtime/source/adapters/cuda/command_buffer.cpp +++ b/unified-runtime/source/adapters/cuda/command_buffer.cpp @@ -382,13 +382,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferCreateExp( UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferRetainExp(ur_exp_command_buffer_handle_t hCommandBuffer) { - hCommandBuffer->incrementReferenceCount(); + hCommandBuffer->RefCount.retain(); return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferReleaseExp(ur_exp_command_buffer_handle_t hCommandBuffer) { - if (hCommandBuffer->decrementReferenceCount() == 0) { + if (hCommandBuffer->RefCount.release()) { // Ref count has reached zero, release of created commands for (auto &Command : hCommandBuffer->CommandHandles) { commandHandleDestroy(Command); @@ -1478,7 +1478,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferGetInfoExp( switch (propName) { case UR_EXP_COMMAND_BUFFER_INFO_REFERENCE_COUNT: - return ReturnValue(hCommandBuffer->getReferenceCount()); + return ReturnValue(hCommandBuffer->RefCount.getCount()); case UR_EXP_COMMAND_BUFFER_INFO_DESCRIPTOR: { ur_exp_command_buffer_desc_t Descriptor{}; Descriptor.stype = UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_DESC; diff --git a/unified-runtime/source/adapters/cuda/command_buffer.hpp b/unified-runtime/source/adapters/cuda/command_buffer.hpp index e11b9ab74969a..49eb1eb48688a 100644 --- a/unified-runtime/source/adapters/cuda/command_buffer.hpp +++ b/unified-runtime/source/adapters/cuda/command_buffer.hpp @@ -12,6 +12,7 @@ #include #include +#include "common/ur_ref_count.hpp" #include "context.hpp" #include "logger/ur_logger.hpp" #include @@ -173,10 +174,6 @@ struct ur_exp_command_buffer_handle_t_ : ur::cuda::handle_base { return SyncPoint; } - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } - uint32_t decrementReferenceCount() noexcept { return --RefCount; } - uint32_t getReferenceCount() const noexcept { return RefCount; } - // UR context 
associated with this command-buffer ur_context_handle_t Context; // Device associated with this command-buffer @@ -189,9 +186,8 @@ struct ur_exp_command_buffer_handle_t_ : ur::cuda::handle_base { CUgraph CudaGraph; // Cuda Graph Exec handle CUgraphExec CudaGraphExec = nullptr; - // Atomic variable counting the number of reference to this command_buffer - // using std::atomic prevents data race when incrementing/decrementing. - std::atomic_uint32_t RefCount; + + ur::RefCount RefCount; // Ordered map of sync_points to ur_events, so that we can find the last // node added to an in-order command-buffer. diff --git a/unified-runtime/source/adapters/cuda/context.cpp b/unified-runtime/source/adapters/cuda/context.cpp index 074bbeb440b2e..e68adbe524c3a 100644 --- a/unified-runtime/source/adapters/cuda/context.cpp +++ b/unified-runtime/source/adapters/cuda/context.cpp @@ -66,7 +66,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextGetInfo( return ReturnValue(hContext->getDevices().data(), hContext->getDevices().size()); case UR_CONTEXT_INFO_REFERENCE_COUNT: - return ReturnValue(hContext->getReferenceCount()); + return ReturnValue(hContext->RefCount.getCount()); case UR_CONTEXT_INFO_USM_MEMCPY2D_SUPPORT: // 2D USM memcpy is supported. 
return ReturnValue(true); @@ -83,7 +83,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextGetInfo( UR_APIEXPORT ur_result_t UR_APICALL urContextRelease(ur_context_handle_t hContext) { - if (hContext->decrementReferenceCount() > 0) { + if (!hContext->RefCount.release()) { return UR_RESULT_SUCCESS; } hContext->invokeExtendedDeleters(); @@ -94,9 +94,9 @@ urContextRelease(ur_context_handle_t hContext) { UR_APIEXPORT ur_result_t UR_APICALL urContextRetain(ur_context_handle_t hContext) { - assert(hContext->getReferenceCount() > 0); + assert(hContext->RefCount.getCount() > 0); - hContext->incrementReferenceCount(); + hContext->RefCount.retain(); return UR_RESULT_SUCCESS; } diff --git a/unified-runtime/source/adapters/cuda/context.hpp b/unified-runtime/source/adapters/cuda/context.hpp index 1a24d163c9a50..fe1d7e5558a8c 100644 --- a/unified-runtime/source/adapters/cuda/context.hpp +++ b/unified-runtime/source/adapters/cuda/context.hpp @@ -20,6 +20,7 @@ #include "adapter.hpp" #include "common.hpp" +#include "common/ur_ref_count.hpp" #include "device.hpp" #include "umf_helpers.hpp" @@ -88,7 +89,7 @@ struct ur_context_handle_t_ : ur::cuda::handle_base { }; std::vector Devices; - std::atomic_uint32_t RefCount; + ur::RefCount RefCount; // UMF CUDA memory provider and pool for the host memory // (UMF_MEMORY_TYPE_HOST) @@ -140,12 +141,6 @@ struct ur_context_handle_t_ : ur::cuda::handle_base { return std::distance(Devices.begin(), It); } - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } - - uint32_t decrementReferenceCount() noexcept { return --RefCount; } - - uint32_t getReferenceCount() const noexcept { return RefCount; } - void addPool(ur_usm_pool_handle_t Pool); void removePool(ur_usm_pool_handle_t Pool); diff --git a/unified-runtime/source/adapters/cuda/device.cpp b/unified-runtime/source/adapters/cuda/device.cpp index 6f3f450877412..97b2430748d3f 100644 --- a/unified-runtime/source/adapters/cuda/device.cpp +++ 
b/unified-runtime/source/adapters/cuda/device.cpp @@ -593,7 +593,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice, return ReturnValue("CUDA"); } case UR_DEVICE_INFO_REFERENCE_COUNT: { - return ReturnValue(hDevice->getReferenceCount()); + return ReturnValue(hDevice->RefCount.getCount()); } case UR_DEVICE_INFO_VERSION: { std::stringstream SS; diff --git a/unified-runtime/source/adapters/cuda/device.hpp b/unified-runtime/source/adapters/cuda/device.hpp index 3a28d54b17d21..833c05e141b17 100644 --- a/unified-runtime/source/adapters/cuda/device.hpp +++ b/unified-runtime/source/adapters/cuda/device.hpp @@ -15,6 +15,7 @@ #include #include "common.hpp" +#include "common/ur_ref_count.hpp" struct ur_device_handle_t_ : ur::cuda::handle_base { private: @@ -23,7 +24,6 @@ struct ur_device_handle_t_ : ur::cuda::handle_base { native_type CuDevice; CUcontext CuContext; CUevent EvBase; // CUDA event used as base counter - std::atomic_uint32_t RefCount; ur_platform_handle_t Platform; uint32_t DeviceIndex; @@ -42,7 +42,7 @@ struct ur_device_handle_t_ : ur::cuda::handle_base { ur_device_handle_t_(native_type cuDevice, CUcontext cuContext, CUevent evBase, ur_platform_handle_t platform, uint32_t DevIndex) : handle_base(), CuDevice(cuDevice), CuContext(cuContext), EvBase(evBase), - RefCount{1}, Platform(platform), DeviceIndex{DevIndex} { + Platform(platform), DeviceIndex{DevIndex} { UR_CHECK_ERROR(cuDeviceGetAttribute( &MaxRegsPerBlock, CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK, cuDevice)); @@ -136,8 +136,6 @@ struct ur_device_handle_t_ : ur::cuda::handle_base { CUcontext getNativeContext() const noexcept { return CuContext; }; - uint32_t getReferenceCount() const noexcept { return RefCount; } - ur_platform_handle_t getPlatform() const noexcept { return Platform; }; // Returns the index of the device relative to the other devices in the same @@ -178,6 +176,8 @@ struct ur_device_handle_t_ : ur::cuda::handle_base { // (UMF_MEMORY_TYPE_SHARED) 
umf_memory_provider_handle_t MemoryProviderShared; umf_memory_pool_handle_t MemoryPoolShared; + + ur::RefCount RefCount; }; int getAttribute(ur_device_handle_t Device, CUdevice_attribute Attribute); diff --git a/unified-runtime/source/adapters/cuda/event.cpp b/unified-runtime/source/adapters/cuda/event.cpp index f9343a6b6f751..9ad85b99e6356 100644 --- a/unified-runtime/source/adapters/cuda/event.cpp +++ b/unified-runtime/source/adapters/cuda/event.cpp @@ -179,7 +179,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventGetInfo(ur_event_handle_t hEvent, case UR_EVENT_INFO_COMMAND_TYPE: return ReturnValue(hEvent->getCommandType()); case UR_EVENT_INFO_REFERENCE_COUNT: - return ReturnValue(hEvent->getReferenceCount()); + return ReturnValue(hEvent->RefCount.getCount()); case UR_EVENT_INFO_COMMAND_EXECUTION_STATUS: return ReturnValue(hEvent->getExecutionStatus()); case UR_EVENT_INFO_CONTEXT: @@ -248,7 +248,7 @@ urEventWait(uint32_t numEvents, const ur_event_handle_t *phEventWaitList) { } UR_APIEXPORT ur_result_t UR_APICALL urEventRetain(ur_event_handle_t hEvent) { - const auto RefCount = hEvent->incrementReferenceCount(); + const auto RefCount = hEvent->RefCount.retain(); if (RefCount == 0) { return UR_RESULT_ERROR_OUT_OF_RESOURCES; @@ -260,12 +260,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventRetain(ur_event_handle_t hEvent) { UR_APIEXPORT ur_result_t UR_APICALL urEventRelease(ur_event_handle_t hEvent) { // double delete or someone is messing with the ref count. // either way, cannot safely proceed. - if (hEvent->getReferenceCount() == 0) { + if (hEvent->RefCount.getCount() == 0) { return UR_RESULT_ERROR_INVALID_EVENT; } // decrement ref count. If it is 0, delete the event. 
- if (hEvent->decrementReferenceCount() == 0) { + if (hEvent->RefCount.release()) { std::unique_ptr event_ptr{hEvent}; ur_result_t Result = UR_RESULT_ERROR_INVALID_EVENT; try { diff --git a/unified-runtime/source/adapters/cuda/event.hpp b/unified-runtime/source/adapters/cuda/event.hpp index 92f74349f9b3e..f63a90d471b25 100644 --- a/unified-runtime/source/adapters/cuda/event.hpp +++ b/unified-runtime/source/adapters/cuda/event.hpp @@ -13,6 +13,7 @@ #include #include "common.hpp" +#include "common/ur_ref_count.hpp" #include "queue.hpp" /// UR Event mapping to CUevent @@ -60,16 +61,11 @@ struct ur_event_handle_t_ : ur::cuda::handle_base { ur_context_handle_t getContext() const noexcept { return Context; }; uint32_t getEventID() const noexcept { return EventID; } - // Reference counting. - uint32_t getReferenceCount() const noexcept { return RefCount; } - uint32_t incrementReferenceCount() { return ++RefCount; } - uint32_t decrementReferenceCount() { return --RefCount; } + ur::RefCount RefCount; private: ur_command_t CommandType; // The type of command associated with event. - std::atomic_uint32_t RefCount{1}; // Event reference count. - bool HasOwnership{true}; // Signifies if event owns the native type. bool HasProfiling{false}; // Signifies if event has profiling information. 
diff --git a/unified-runtime/source/adapters/cuda/kernel.cpp b/unified-runtime/source/adapters/cuda/kernel.cpp index f296c74611462..78d2de1f90398 100644 --- a/unified-runtime/source/adapters/cuda/kernel.cpp +++ b/unified-runtime/source/adapters/cuda/kernel.cpp @@ -127,9 +127,9 @@ urKernelGetGroupInfo(ur_kernel_handle_t hKernel, ur_device_handle_t hDevice, } UR_APIEXPORT ur_result_t UR_APICALL urKernelRetain(ur_kernel_handle_t hKernel) { - UR_ASSERT(hKernel->getReferenceCount() > 0u, UR_RESULT_ERROR_INVALID_KERNEL); + UR_ASSERT(hKernel->RefCount.getCount() > 0u, UR_RESULT_ERROR_INVALID_KERNEL); - hKernel->incrementReferenceCount(); + hKernel->RefCount.retain(); return UR_RESULT_SUCCESS; } @@ -137,10 +137,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelRelease(ur_kernel_handle_t hKernel) { // double delete or someone is messing with the ref count. // either way, cannot safely proceed. - UR_ASSERT(hKernel->getReferenceCount() != 0, UR_RESULT_ERROR_INVALID_KERNEL); + UR_ASSERT(hKernel->RefCount.getCount() != 0, UR_RESULT_ERROR_INVALID_KERNEL); // decrement ref count. If it is 0, delete the program. - if (hKernel->decrementReferenceCount() == 0) { + if (hKernel->RefCount.release()) { // no internal cuda resources to clean up. Just delete it. 
delete hKernel; return UR_RESULT_SUCCESS; @@ -248,7 +248,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetInfo(ur_kernel_handle_t hKernel, case UR_KERNEL_INFO_NUM_ARGS: return ReturnValue(hKernel->getNumArgs()); case UR_KERNEL_INFO_REFERENCE_COUNT: - return ReturnValue(hKernel->getReferenceCount()); + return ReturnValue(hKernel->RefCount.getCount()); case UR_KERNEL_INFO_CONTEXT: return ReturnValue(hKernel->getContext()); case UR_KERNEL_INFO_PROGRAM: diff --git a/unified-runtime/source/adapters/cuda/kernel.hpp b/unified-runtime/source/adapters/cuda/kernel.hpp index 6898527e8df30..8035a5260604f 100644 --- a/unified-runtime/source/adapters/cuda/kernel.hpp +++ b/unified-runtime/source/adapters/cuda/kernel.hpp @@ -17,6 +17,7 @@ #include #include +#include "common/ur_ref_count.hpp" #include "program.hpp" /// Implementation of a UR Kernel for CUDA @@ -42,7 +43,7 @@ struct ur_kernel_handle_t_ : ur::cuda::handle_base { std::string Name; ur_context_handle_t Context; ur_program_handle_t Program; - std::atomic_uint32_t RefCount; + ur::RefCount RefCount; static constexpr uint32_t ReqdThreadsPerBlockDimensions = 3u; size_t ReqdThreadsPerBlock[ReqdThreadsPerBlockDimensions]; @@ -304,12 +305,6 @@ struct ur_kernel_handle_t_ : ur::cuda::handle_base { urContextRelease(Context); } - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } - - uint32_t decrementReferenceCount() noexcept { return --RefCount; } - - uint32_t getReferenceCount() const noexcept { return RefCount; } - native_type get() const noexcept { return Function; }; ur_program_handle_t getProgram() const noexcept { return Program; }; diff --git a/unified-runtime/source/adapters/cuda/memory.cpp b/unified-runtime/source/adapters/cuda/memory.cpp index d673ad06c09b9..519a6622f5ec7 100644 --- a/unified-runtime/source/adapters/cuda/memory.cpp +++ b/unified-runtime/source/adapters/cuda/memory.cpp @@ -78,8 +78,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferCreate( } UR_APIEXPORT ur_result_t UR_APICALL 
urMemRetain(ur_mem_handle_t hMem) { - UR_ASSERT(hMem->getReferenceCount() > 0, UR_RESULT_ERROR_INVALID_MEM_OBJECT); - hMem->incrementReferenceCount(); + UR_ASSERT(hMem->RefCount.getCount() > 0, UR_RESULT_ERROR_INVALID_MEM_OBJECT); + hMem->RefCount.retain(); return UR_RESULT_SUCCESS; } @@ -89,7 +89,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemRetain(ur_mem_handle_t hMem) { UR_APIEXPORT ur_result_t UR_APICALL urMemRelease(ur_mem_handle_t hMem) { try { // Do nothing if there are other references - if (hMem->decrementReferenceCount() > 0) { + if (!hMem->RefCount.release()) { return UR_RESULT_SUCCESS; } @@ -162,7 +162,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemGetInfo(ur_mem_handle_t hMemory, return ReturnValue(hMemory->getContext()); } case UR_MEM_INFO_REFERENCE_COUNT: { - return ReturnValue(hMemory->getReferenceCount()); + return ReturnValue(hMemory->RefCount.getCount()); } default: diff --git a/unified-runtime/source/adapters/cuda/memory.hpp b/unified-runtime/source/adapters/cuda/memory.hpp index 92aeb5878b952..1f427dd896bfa 100644 --- a/unified-runtime/source/adapters/cuda/memory.hpp +++ b/unified-runtime/source/adapters/cuda/memory.hpp @@ -16,6 +16,7 @@ #include #include "common.hpp" +#include "common/ur_ref_count.hpp" #include "context.hpp" #include "queue.hpp" @@ -314,8 +315,7 @@ struct ur_mem_handle_t_ : ur::cuda::handle_base { // Context where the memory object is accessible ur_context_handle_t Context; - /// Reference counting of the handler - std::atomic_uint32_t RefCount; + ur::RefCount RefCount; // Original mem flags passed ur_mem_flags_t MemFlags; @@ -424,12 +424,6 @@ struct ur_mem_handle_t_ : ur::cuda::handle_base { ur_context_handle_t getContext() const noexcept { return Context; } - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } - - uint32_t decrementReferenceCount() noexcept { return --RefCount; } - - uint32_t getReferenceCount() const noexcept { return RefCount; } - void setLastQueueWritingToMemObj(ur_queue_handle_t 
WritingQueue) { urQueueRetain(WritingQueue); if (LastQueueWritingToMemObj != nullptr) { diff --git a/unified-runtime/source/adapters/cuda/physical_mem.cpp b/unified-runtime/source/adapters/cuda/physical_mem.cpp index 71bf596acb09b..9b49dcfc9e399 100644 --- a/unified-runtime/source/adapters/cuda/physical_mem.cpp +++ b/unified-runtime/source/adapters/cuda/physical_mem.cpp @@ -46,13 +46,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urPhysicalMemCreate( UR_APIEXPORT ur_result_t UR_APICALL urPhysicalMemRetain(ur_physical_mem_handle_t hPhysicalMem) { - hPhysicalMem->incrementReferenceCount(); + hPhysicalMem->RefCount.retain(); return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urPhysicalMemRelease(ur_physical_mem_handle_t hPhysicalMem) { - if (hPhysicalMem->decrementReferenceCount() > 0) + if (!hPhysicalMem->RefCount.release()) return UR_RESULT_SUCCESS; try { @@ -88,7 +88,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urPhysicalMemGetInfo( return ReturnValue(hPhysicalMem->getProperties()); } case UR_PHYSICAL_MEM_INFO_REFERENCE_COUNT: { - return ReturnValue(hPhysicalMem->getReferenceCount()); + return ReturnValue(hPhysicalMem->RefCount.getCount()); } default: return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION; diff --git a/unified-runtime/source/adapters/cuda/physical_mem.hpp b/unified-runtime/source/adapters/cuda/physical_mem.hpp index d9abe587b1b5f..b4f2da82a192f 100644 --- a/unified-runtime/source/adapters/cuda/physical_mem.hpp +++ b/unified-runtime/source/adapters/cuda/physical_mem.hpp @@ -14,6 +14,7 @@ #include #include "adapter.hpp" +#include "common/ur_ref_count.hpp" #include "device.hpp" #include "platform.hpp" @@ -23,7 +24,7 @@ struct ur_physical_mem_handle_t_ : ur::cuda::handle_base { using native_type = CUmemGenericAllocationHandle; - std::atomic_uint32_t RefCount; + ur::RefCount RefCount; native_type PhysicalMem; ur_context_handle_t_ *Context; ur_device_handle_t Device; @@ -33,8 +34,8 @@ struct ur_physical_mem_handle_t_ : ur::cuda::handle_base { 
ur_physical_mem_handle_t_(native_type PhysMem, ur_context_handle_t_ *Ctx, ur_device_handle_t Device, size_t Size, ur_physical_mem_properties_t Properties) - : handle_base(), RefCount(1), PhysicalMem(PhysMem), Context(Ctx), - Device(Device), Size(Size), Properties(Properties) { + : handle_base(), PhysicalMem(PhysMem), Context(Ctx), Device(Device), + Size(Size), Properties(Properties) { urContextRetain(Context); } @@ -46,12 +47,6 @@ struct ur_physical_mem_handle_t_ : ur::cuda::handle_base { ur_device_handle_t_ *getDevice() const noexcept { return Device; } - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } - - uint32_t decrementReferenceCount() noexcept { return --RefCount; } - - uint32_t getReferenceCount() const noexcept { return RefCount; } - size_t getSize() const noexcept { return Size; } ur_physical_mem_properties_t getProperties() const noexcept { diff --git a/unified-runtime/source/adapters/cuda/program.cpp b/unified-runtime/source/adapters/cuda/program.cpp index 0600d371d5261..15afa125a80a2 100644 --- a/unified-runtime/source/adapters/cuda/program.cpp +++ b/unified-runtime/source/adapters/cuda/program.cpp @@ -350,7 +350,7 @@ urProgramGetInfo(ur_program_handle_t hProgram, ur_program_info_t propName, switch (propName) { case UR_PROGRAM_INFO_REFERENCE_COUNT: - return ReturnValue(hProgram->getReferenceCount()); + return ReturnValue(hProgram->RefCount.getCount()); case UR_PROGRAM_INFO_CONTEXT: return ReturnValue(hProgram->Context); case UR_PROGRAM_INFO_NUM_DEVICES: @@ -383,8 +383,8 @@ urProgramGetInfo(ur_program_handle_t hProgram, ur_program_info_t propName, UR_APIEXPORT ur_result_t UR_APICALL urProgramRetain(ur_program_handle_t hProgram) { - UR_ASSERT(hProgram->getReferenceCount() > 0, UR_RESULT_ERROR_INVALID_PROGRAM); - hProgram->incrementReferenceCount(); + UR_ASSERT(hProgram->RefCount.getCount() > 0, UR_RESULT_ERROR_INVALID_PROGRAM); + hProgram->RefCount.retain(); return UR_RESULT_SUCCESS; } @@ -395,11 +395,11 @@ UR_APIEXPORT ur_result_t 
UR_APICALL urProgramRelease(ur_program_handle_t hProgram) { // double delete or someone is messing with the ref count. // either way, cannot safely proceed. - UR_ASSERT(hProgram->getReferenceCount() != 0, + UR_ASSERT(hProgram->RefCount.getCount() != 0, UR_RESULT_ERROR_INVALID_PROGRAM); // decrement ref count. If it is 0, delete the program. - if (hProgram->decrementReferenceCount() == 0) { + if (hProgram->RefCount.release()) { std::unique_ptr ProgramPtr{hProgram}; try { ScopedContext Active(hProgram->getDevice()); diff --git a/unified-runtime/source/adapters/cuda/program.hpp b/unified-runtime/source/adapters/cuda/program.hpp index 7371283c274d1..ae7005bea6a58 100644 --- a/unified-runtime/source/adapters/cuda/program.hpp +++ b/unified-runtime/source/adapters/cuda/program.hpp @@ -15,6 +15,7 @@ #include #include +#include "common/ur_ref_count.hpp" #include "context.hpp" struct ur_program_handle_t_ : ur::cuda::handle_base { @@ -22,7 +23,7 @@ struct ur_program_handle_t_ : ur::cuda::handle_base { native_type Module; const char *Binary; size_t BinarySizeInBytes; - std::atomic_uint32_t RefCount; + ur::RefCount RefCount; ur_context_handle_t Context; ur_device_handle_t Device; @@ -71,12 +72,6 @@ struct ur_program_handle_t_ : ur::cuda::handle_base { native_type get() const noexcept { return Module; }; - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } - - uint32_t decrementReferenceCount() noexcept { return --RefCount; } - - uint32_t getReferenceCount() const noexcept { return RefCount; } - ur_result_t getGlobalVariablePointer(const char *name, CUdeviceptr *DeviceGlobal, size_t *DeviceGlobalSize); diff --git a/unified-runtime/source/adapters/cuda/queue.cpp b/unified-runtime/source/adapters/cuda/queue.cpp index 7c0f7b09f3a42..5797d79bb0d57 100644 --- a/unified-runtime/source/adapters/cuda/queue.cpp +++ b/unified-runtime/source/adapters/cuda/queue.cpp @@ -107,14 +107,14 @@ urQueueCreate(ur_context_handle_t hContext, ur_device_handle_t hDevice, } UR_APIEXPORT 
ur_result_t UR_APICALL urQueueRetain(ur_queue_handle_t hQueue) { - assert(hQueue->getReferenceCount() > 0); + assert(hQueue->RefCount.getCount() > 0); - hQueue->incrementReferenceCount(); + hQueue->RefCount.retain(); return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urQueueRelease(ur_queue_handle_t hQueue) { - if (hQueue->decrementReferenceCount() > 0) { + if (!hQueue->RefCount.release()) { return UR_RESULT_SUCCESS; } @@ -229,7 +229,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueGetInfo(ur_queue_handle_t hQueue, case UR_QUEUE_INFO_DEVICE: return ReturnValue(hQueue->Device); case UR_QUEUE_INFO_REFERENCE_COUNT: - return ReturnValue(hQueue->getReferenceCount()); + return ReturnValue(hQueue->RefCount.getCount()); case UR_QUEUE_INFO_FLAGS: return ReturnValue(hQueue->URFlags); case UR_QUEUE_INFO_EMPTY: { diff --git a/unified-runtime/source/adapters/cuda/sampler.cpp b/unified-runtime/source/adapters/cuda/sampler.cpp index f17c94cfa1b07..7b59d59b07f23 100644 --- a/unified-runtime/source/adapters/cuda/sampler.cpp +++ b/unified-runtime/source/adapters/cuda/sampler.cpp @@ -73,7 +73,7 @@ urSamplerGetInfo(ur_sampler_handle_t hSampler, ur_sampler_info_t propName, switch (propName) { case UR_SAMPLER_INFO_REFERENCE_COUNT: - return ReturnValue(hSampler->getReferenceCount()); + return ReturnValue(hSampler->RefCount.getCount()); case UR_SAMPLER_INFO_CONTEXT: return ReturnValue(hSampler->Context); case UR_SAMPLER_INFO_NORMALIZED_COORDS: { @@ -95,7 +95,7 @@ urSamplerGetInfo(ur_sampler_handle_t hSampler, ur_sampler_info_t propName, UR_APIEXPORT ur_result_t UR_APICALL urSamplerRetain(ur_sampler_handle_t hSampler) { - hSampler->incrementReferenceCount(); + hSampler->RefCount.retain(); return UR_RESULT_SUCCESS; } @@ -103,12 +103,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urSamplerRelease(ur_sampler_handle_t hSampler) { // double delete or someone is messing with the ref count. // either way, cannot safely proceed. 
- if (hSampler->getReferenceCount() == 0) { + if (hSampler->RefCount.getCount() == 0) { return UR_RESULT_ERROR_INVALID_SAMPLER; } // decrement ref count. If it is 0, delete the sampler. - if (hSampler->decrementReferenceCount() == 0) { + if (hSampler->RefCount.release()) { delete hSampler; } diff --git a/unified-runtime/source/adapters/cuda/sampler.hpp b/unified-runtime/source/adapters/cuda/sampler.hpp index e429439848e06..475a49690c941 100644 --- a/unified-runtime/source/adapters/cuda/sampler.hpp +++ b/unified-runtime/source/adapters/cuda/sampler.hpp @@ -9,6 +9,7 @@ //===----------------------------------------------------------------------===// #include "common.hpp" +#include "common/ur_ref_count.hpp" #include /// Implementation of samplers for CUDA @@ -25,7 +26,7 @@ /// | 1 | filter mode /// | 0 | normalize coords struct ur_sampler_handle_t_ : ur::cuda::handle_base { - std::atomic_uint32_t RefCount; + ur::RefCount RefCount; uint32_t Props; float MinMipmapLevelClamp; float MaxMipmapLevelClamp; @@ -36,12 +37,6 @@ struct ur_sampler_handle_t_ : ur::cuda::handle_base { : handle_base(), RefCount(1), Props(0), MinMipmapLevelClamp(0.0f), MaxMipmapLevelClamp(0.0f), MaxAnisotropy(0.0f), Context(Context) {} - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } - - uint32_t decrementReferenceCount() noexcept { return --RefCount; } - - uint32_t getReferenceCount() const noexcept { return RefCount; } - ur_bool_t isNormalizedCoords() const noexcept { return static_cast(Props & 0b1); } diff --git a/unified-runtime/source/adapters/cuda/usm.cpp b/unified-runtime/source/adapters/cuda/usm.cpp index a1d8b9455c2f6..5ba2db0cfd03b 100644 --- a/unified-runtime/source/adapters/cuda/usm.cpp +++ b/unified-runtime/source/adapters/cuda/usm.cpp @@ -290,14 +290,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMPoolCreate( UR_APIEXPORT ur_result_t UR_APICALL urUSMPoolRetain( /// [in] pointer to USM memory pool ur_usm_pool_handle_t Pool) { - Pool->incrementReferenceCount(); + 
Pool->RefCount.retain(); return UR_RESULT_SUCCESS; } UR_APIEXPORT ur_result_t UR_APICALL urUSMPoolRelease( /// [in] pointer to USM memory pool ur_usm_pool_handle_t Pool) { - if (Pool->decrementReferenceCount() > 0) { + if (!Pool->RefCount.release()) { return UR_RESULT_SUCCESS; } Pool->Context->removePool(Pool); @@ -320,7 +320,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMPoolGetInfo( switch (propName) { case UR_USM_POOL_INFO_REFERENCE_COUNT: { - return ReturnValue(hPool->getReferenceCount()); + return ReturnValue(hPool->RefCount.getCount()); } case UR_USM_POOL_INFO_CONTEXT: { return ReturnValue(hPool->Context); diff --git a/unified-runtime/source/adapters/cuda/usm.hpp b/unified-runtime/source/adapters/cuda/usm.hpp index 27e1beb6b606c..713c25045976f 100644 --- a/unified-runtime/source/adapters/cuda/usm.hpp +++ b/unified-runtime/source/adapters/cuda/usm.hpp @@ -9,6 +9,7 @@ //===----------------------------------------------------------------------===// #include "common.hpp" +#include "common/ur_ref_count.hpp" #include #include @@ -18,7 +19,7 @@ usm::DisjointPoolAllConfigs InitializeDisjointPoolConfig(); // A ur_usm_pool_handle_t can represent different types of memory pools. It may // sit on top of a UMF pool or a CUmemoryPool, but not both. struct ur_usm_pool_handle_t_ : ur::cuda::handle_base { - std::atomic_uint32_t RefCount = 1; + ur::RefCount RefCount; ur_context_handle_t Context = nullptr; ur_device_handle_t Device = nullptr; @@ -44,12 +45,6 @@ struct ur_usm_pool_handle_t_ : ur::cuda::handle_base { ur_usm_pool_handle_t_(ur_context_handle_t Context, ur_device_handle_t Device, CUmemoryPool CUmemPool); - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } - - uint32_t decrementReferenceCount() noexcept { return --RefCount; } - - uint32_t getReferenceCount() const noexcept { return RefCount; } - bool hasUMFPool(umf_memory_pool_t *umf_pool); // To be used if ur_usm_pool_handle_t represents a CUmemoryPool. 
diff --git a/unified-runtime/source/common/cuda-hip/stream_queue.hpp b/unified-runtime/source/common/cuda-hip/stream_queue.hpp index 0ead67e1d8729..44f6066a5c380 100644 --- a/unified-runtime/source/common/cuda-hip/stream_queue.hpp +++ b/unified-runtime/source/common/cuda-hip/stream_queue.hpp @@ -11,6 +11,7 @@ #pragma once +#include "common/ur_ref_count.hpp" #include #include #include @@ -44,7 +45,8 @@ struct stream_queue_t { std::vector TransferAppliedBarrier; ur_context_handle_t_ *Context; ur_device_handle_t_ *Device; - std::atomic_uint32_t RefCount{1}; + std::atomic_uint32_t RefCountOld{1}; + ur::RefCount RefCount; std::atomic_uint32_t EventCount{0}; std::atomic_uint32_t ComputeStreamIndex{0}; std::atomic_uint32_t TransferStreamIndex{0}; @@ -344,11 +346,11 @@ struct stream_queue_t { ur_context_handle_t_ *getContext() const { return Context; }; - uint32_t incrementReferenceCount() noexcept { return ++RefCount; } + uint32_t incrementReferenceCount() noexcept { return ++RefCountOld; } - uint32_t decrementReferenceCount() noexcept { return --RefCount; } + uint32_t decrementReferenceCount() noexcept { return --RefCountOld; } - uint32_t getReferenceCount() const noexcept { return RefCount; } + uint32_t getReferenceCount() const noexcept { return RefCountOld; } uint32_t getNextEventId() noexcept { return ++EventCount; }