Skip to content

[UR][CUDA] Refactor reference counting in UR CUDA adapter using new ur::RefCount class #19287

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: sycl
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions unified-runtime/source/adapters/cuda/adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters,
std::call_once(InitFlag,
[=]() { ur::cuda::adapter = new ur_adapter_handle_t_; });

ur::cuda::adapter->RefCount++;
ur::cuda::adapter->RefCount.retain();
*phAdapters = ur::cuda::adapter;
}

Expand All @@ -78,13 +78,13 @@ urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters,
}

UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) {
ur::cuda::adapter->RefCount++;
ur::cuda::adapter->RefCount.retain();

return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) {
if (--ur::cuda::adapter->RefCount == 0) {
if (ur::cuda::adapter->RefCount.release()) {
delete ur::cuda::adapter;
}
return UR_RESULT_SUCCESS;
Expand All @@ -108,7 +108,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t,
case UR_ADAPTER_INFO_BACKEND:
return ReturnValue(UR_BACKEND_CUDA);
case UR_ADAPTER_INFO_REFERENCE_COUNT:
return ReturnValue(ur::cuda::adapter->RefCount.load());
return ReturnValue(ur::cuda::adapter->RefCount.getCount());
case UR_ADAPTER_INFO_VERSION:
return ReturnValue(uint32_t{1});
default:
Expand Down
3 changes: 2 additions & 1 deletion unified-runtime/source/adapters/cuda/adapter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#ifndef UR_CUDA_ADAPTER_HPP_INCLUDED
#define UR_CUDA_ADAPTER_HPP_INCLUDED

#include "common/ur_ref_count.hpp"
#include "logger/ur_logger.hpp"
#include "platform.hpp"
#include "tracing.hpp"
Expand All @@ -20,7 +21,7 @@
#include <memory>

struct ur_adapter_handle_t_ : ur::cuda::handle_base {
std::atomic<uint32_t> RefCount = 0;
ur::RefCount RefCount;
struct cuda_tracing_context_t_ *TracingCtx = nullptr;
logger::Logger &logger;
std::unique_ptr<ur_platform_handle_t_> Platform;
Expand Down
6 changes: 3 additions & 3 deletions unified-runtime/source/adapters/cuda/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -382,13 +382,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferCreateExp(

UR_APIEXPORT ur_result_t UR_APICALL
urCommandBufferRetainExp(ur_exp_command_buffer_handle_t hCommandBuffer) {
hCommandBuffer->incrementReferenceCount();
hCommandBuffer->RefCount.retain();
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL
urCommandBufferReleaseExp(ur_exp_command_buffer_handle_t hCommandBuffer) {
if (hCommandBuffer->decrementReferenceCount() == 0) {
if (hCommandBuffer->RefCount.release()) {
// Ref count has reached zero, release of created commands
for (auto &Command : hCommandBuffer->CommandHandles) {
commandHandleDestroy(Command);
Expand Down Expand Up @@ -1478,7 +1478,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferGetInfoExp(

switch (propName) {
case UR_EXP_COMMAND_BUFFER_INFO_REFERENCE_COUNT:
return ReturnValue(hCommandBuffer->getReferenceCount());
return ReturnValue(hCommandBuffer->RefCount.getCount());
case UR_EXP_COMMAND_BUFFER_INFO_DESCRIPTOR: {
ur_exp_command_buffer_desc_t Descriptor{};
Descriptor.stype = UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_DESC;
Expand Down
10 changes: 3 additions & 7 deletions unified-runtime/source/adapters/cuda/command_buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <ur_api.h>
#include <ur_print.hpp>

#include "common/ur_ref_count.hpp"
#include "context.hpp"
#include "logger/ur_logger.hpp"
#include <cuda.h>
Expand Down Expand Up @@ -173,10 +174,6 @@ struct ur_exp_command_buffer_handle_t_ : ur::cuda::handle_base {
return SyncPoint;
}

uint32_t incrementReferenceCount() noexcept { return ++RefCount; }
uint32_t decrementReferenceCount() noexcept { return --RefCount; }
uint32_t getReferenceCount() const noexcept { return RefCount; }

// UR context associated with this command-buffer
ur_context_handle_t Context;
// Device associated with this command-buffer
Expand All @@ -189,9 +186,8 @@ struct ur_exp_command_buffer_handle_t_ : ur::cuda::handle_base {
CUgraph CudaGraph;
// Cuda Graph Exec handle
CUgraphExec CudaGraphExec = nullptr;
// Atomic variable counting the number of reference to this command_buffer
// using std::atomic prevents data race when incrementing/decrementing.
std::atomic_uint32_t RefCount;

ur::RefCount RefCount;

// Ordered map of sync_points to ur_events, so that we can find the last
// node added to an in-order command-buffer.
Expand Down
8 changes: 4 additions & 4 deletions unified-runtime/source/adapters/cuda/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextGetInfo(
return ReturnValue(hContext->getDevices().data(),
hContext->getDevices().size());
case UR_CONTEXT_INFO_REFERENCE_COUNT:
return ReturnValue(hContext->getReferenceCount());
return ReturnValue(hContext->RefCount.getCount());
case UR_CONTEXT_INFO_USM_MEMCPY2D_SUPPORT:
// 2D USM memcpy is supported.
return ReturnValue(true);
Expand All @@ -83,7 +83,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextGetInfo(

UR_APIEXPORT ur_result_t UR_APICALL
urContextRelease(ur_context_handle_t hContext) {
if (hContext->decrementReferenceCount() > 0) {
if (!hContext->RefCount.release()) {
return UR_RESULT_SUCCESS;
}
hContext->invokeExtendedDeleters();
Expand All @@ -94,9 +94,9 @@ urContextRelease(ur_context_handle_t hContext) {

UR_APIEXPORT ur_result_t UR_APICALL
urContextRetain(ur_context_handle_t hContext) {
assert(hContext->getReferenceCount() > 0);
assert(hContext->RefCount.getCount() > 0);

hContext->incrementReferenceCount();
hContext->RefCount.retain();
return UR_RESULT_SUCCESS;
}

Expand Down
9 changes: 2 additions & 7 deletions unified-runtime/source/adapters/cuda/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include "adapter.hpp"
#include "common.hpp"
#include "common/ur_ref_count.hpp"
#include "device.hpp"
#include "umf_helpers.hpp"

Expand Down Expand Up @@ -88,7 +89,7 @@ struct ur_context_handle_t_ : ur::cuda::handle_base {
};

std::vector<ur_device_handle_t> Devices;
std::atomic_uint32_t RefCount;
ur::RefCount RefCount;

// UMF CUDA memory provider and pool for the host memory
// (UMF_MEMORY_TYPE_HOST)
Expand Down Expand Up @@ -140,12 +141,6 @@ struct ur_context_handle_t_ : ur::cuda::handle_base {
return std::distance(Devices.begin(), It);
}

uint32_t incrementReferenceCount() noexcept { return ++RefCount; }

uint32_t decrementReferenceCount() noexcept { return --RefCount; }

uint32_t getReferenceCount() const noexcept { return RefCount; }

void addPool(ur_usm_pool_handle_t Pool);

void removePool(ur_usm_pool_handle_t Pool);
Expand Down
2 changes: 1 addition & 1 deletion unified-runtime/source/adapters/cuda/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
return ReturnValue("CUDA");
}
case UR_DEVICE_INFO_REFERENCE_COUNT: {
return ReturnValue(hDevice->getReferenceCount());
return ReturnValue(hDevice->RefCount.getCount());
}
case UR_DEVICE_INFO_VERSION: {
std::stringstream SS;
Expand Down
8 changes: 4 additions & 4 deletions unified-runtime/source/adapters/cuda/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <umf/memory_provider.h>

#include "common.hpp"
#include "common/ur_ref_count.hpp"

struct ur_device_handle_t_ : ur::cuda::handle_base {
private:
Expand All @@ -23,7 +24,6 @@ struct ur_device_handle_t_ : ur::cuda::handle_base {
native_type CuDevice;
CUcontext CuContext;
CUevent EvBase; // CUDA event used as base counter
std::atomic_uint32_t RefCount;
ur_platform_handle_t Platform;
uint32_t DeviceIndex;

Expand All @@ -42,7 +42,7 @@ struct ur_device_handle_t_ : ur::cuda::handle_base {
ur_device_handle_t_(native_type cuDevice, CUcontext cuContext, CUevent evBase,
ur_platform_handle_t platform, uint32_t DevIndex)
: handle_base(), CuDevice(cuDevice), CuContext(cuContext), EvBase(evBase),
RefCount{1}, Platform(platform), DeviceIndex{DevIndex} {
Platform(platform), DeviceIndex{DevIndex} {
UR_CHECK_ERROR(cuDeviceGetAttribute(
&MaxRegsPerBlock, CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK,
cuDevice));
Expand Down Expand Up @@ -136,8 +136,6 @@ struct ur_device_handle_t_ : ur::cuda::handle_base {

CUcontext getNativeContext() const noexcept { return CuContext; };

uint32_t getReferenceCount() const noexcept { return RefCount; }

ur_platform_handle_t getPlatform() const noexcept { return Platform; };

// Returns the index of the device relative to the other devices in the same
Expand Down Expand Up @@ -178,6 +176,8 @@ struct ur_device_handle_t_ : ur::cuda::handle_base {
// (UMF_MEMORY_TYPE_SHARED)
umf_memory_provider_handle_t MemoryProviderShared;
umf_memory_pool_handle_t MemoryPoolShared;

ur::RefCount RefCount;
};

int getAttribute(ur_device_handle_t Device, CUdevice_attribute Attribute);
8 changes: 4 additions & 4 deletions unified-runtime/source/adapters/cuda/event.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventGetInfo(ur_event_handle_t hEvent,
case UR_EVENT_INFO_COMMAND_TYPE:
return ReturnValue(hEvent->getCommandType());
case UR_EVENT_INFO_REFERENCE_COUNT:
return ReturnValue(hEvent->getReferenceCount());
return ReturnValue(hEvent->RefCount.getCount());
case UR_EVENT_INFO_COMMAND_EXECUTION_STATUS:
return ReturnValue(hEvent->getExecutionStatus());
case UR_EVENT_INFO_CONTEXT:
Expand Down Expand Up @@ -248,7 +248,7 @@ urEventWait(uint32_t numEvents, const ur_event_handle_t *phEventWaitList) {
}

UR_APIEXPORT ur_result_t UR_APICALL urEventRetain(ur_event_handle_t hEvent) {
const auto RefCount = hEvent->incrementReferenceCount();
const auto RefCount = hEvent->RefCount.retain();

if (RefCount == 0) {
return UR_RESULT_ERROR_OUT_OF_RESOURCES;
Expand All @@ -260,12 +260,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventRetain(ur_event_handle_t hEvent) {
UR_APIEXPORT ur_result_t UR_APICALL urEventRelease(ur_event_handle_t hEvent) {
// double delete or someone is messing with the ref count.
// either way, cannot safely proceed.
if (hEvent->getReferenceCount() == 0) {
if (hEvent->RefCount.getCount() == 0) {
return UR_RESULT_ERROR_INVALID_EVENT;
}

// decrement ref count. If it is 0, delete the event.
if (hEvent->decrementReferenceCount() == 0) {
if (hEvent->RefCount.release()) {
std::unique_ptr<ur_event_handle_t_> event_ptr{hEvent};
ur_result_t Result = UR_RESULT_ERROR_INVALID_EVENT;
try {
Expand Down
8 changes: 2 additions & 6 deletions unified-runtime/source/adapters/cuda/event.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <ur/ur.hpp>

#include "common.hpp"
#include "common/ur_ref_count.hpp"
#include "queue.hpp"

/// UR Event mapping to CUevent
Expand Down Expand Up @@ -60,16 +61,11 @@ struct ur_event_handle_t_ : ur::cuda::handle_base {
ur_context_handle_t getContext() const noexcept { return Context; };
uint32_t getEventID() const noexcept { return EventID; }

// Reference counting.
uint32_t getReferenceCount() const noexcept { return RefCount; }
uint32_t incrementReferenceCount() { return ++RefCount; }
uint32_t decrementReferenceCount() { return --RefCount; }
ur::RefCount RefCount;

private:
ur_command_t CommandType; // The type of command associated with event.

std::atomic_uint32_t RefCount{1}; // Event reference count.

bool HasOwnership{true}; // Signifies if event owns the native type.
bool HasProfiling{false}; // Signifies if event has profiling information.

Expand Down
10 changes: 5 additions & 5 deletions unified-runtime/source/adapters/cuda/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,20 +127,20 @@ urKernelGetGroupInfo(ur_kernel_handle_t hKernel, ur_device_handle_t hDevice,
}

UR_APIEXPORT ur_result_t UR_APICALL urKernelRetain(ur_kernel_handle_t hKernel) {
UR_ASSERT(hKernel->getReferenceCount() > 0u, UR_RESULT_ERROR_INVALID_KERNEL);
UR_ASSERT(hKernel->RefCount.getCount() > 0u, UR_RESULT_ERROR_INVALID_KERNEL);

hKernel->incrementReferenceCount();
hKernel->RefCount.retain();
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL
urKernelRelease(ur_kernel_handle_t hKernel) {
// double delete or someone is messing with the ref count.
// either way, cannot safely proceed.
UR_ASSERT(hKernel->getReferenceCount() != 0, UR_RESULT_ERROR_INVALID_KERNEL);
UR_ASSERT(hKernel->RefCount.getCount() != 0, UR_RESULT_ERROR_INVALID_KERNEL);

// decrement ref count. If it is 0, delete the program.
if (hKernel->decrementReferenceCount() == 0) {
if (hKernel->RefCount.release()) {
// no internal cuda resources to clean up. Just delete it.
delete hKernel;
return UR_RESULT_SUCCESS;
Expand Down Expand Up @@ -248,7 +248,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetInfo(ur_kernel_handle_t hKernel,
case UR_KERNEL_INFO_NUM_ARGS:
return ReturnValue(hKernel->getNumArgs());
case UR_KERNEL_INFO_REFERENCE_COUNT:
return ReturnValue(hKernel->getReferenceCount());
return ReturnValue(hKernel->RefCount.getCount());
case UR_KERNEL_INFO_CONTEXT:
return ReturnValue(hKernel->getContext());
case UR_KERNEL_INFO_PROGRAM:
Expand Down
9 changes: 2 additions & 7 deletions unified-runtime/source/adapters/cuda/kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <cassert>
#include <numeric>

#include "common/ur_ref_count.hpp"
#include "program.hpp"

/// Implementation of a UR Kernel for CUDA
Expand All @@ -42,7 +43,7 @@ struct ur_kernel_handle_t_ : ur::cuda::handle_base {
std::string Name;
ur_context_handle_t Context;
ur_program_handle_t Program;
std::atomic_uint32_t RefCount;
ur::RefCount RefCount;

static constexpr uint32_t ReqdThreadsPerBlockDimensions = 3u;
size_t ReqdThreadsPerBlock[ReqdThreadsPerBlockDimensions];
Expand Down Expand Up @@ -304,12 +305,6 @@ struct ur_kernel_handle_t_ : ur::cuda::handle_base {
urContextRelease(Context);
}

uint32_t incrementReferenceCount() noexcept { return ++RefCount; }

uint32_t decrementReferenceCount() noexcept { return --RefCount; }

uint32_t getReferenceCount() const noexcept { return RefCount; }

native_type get() const noexcept { return Function; };

ur_program_handle_t getProgram() const noexcept { return Program; };
Expand Down
8 changes: 4 additions & 4 deletions unified-runtime/source/adapters/cuda/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferCreate(
}

UR_APIEXPORT ur_result_t UR_APICALL urMemRetain(ur_mem_handle_t hMem) {
UR_ASSERT(hMem->getReferenceCount() > 0, UR_RESULT_ERROR_INVALID_MEM_OBJECT);
hMem->incrementReferenceCount();
UR_ASSERT(hMem->RefCount.getCount() > 0, UR_RESULT_ERROR_INVALID_MEM_OBJECT);
hMem->RefCount.retain();
return UR_RESULT_SUCCESS;
}

Expand All @@ -89,7 +89,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemRetain(ur_mem_handle_t hMem) {
UR_APIEXPORT ur_result_t UR_APICALL urMemRelease(ur_mem_handle_t hMem) {
try {
// Do nothing if there are other references
if (hMem->decrementReferenceCount() > 0) {
if (!hMem->RefCount.release()) {
return UR_RESULT_SUCCESS;
}

Expand Down Expand Up @@ -162,7 +162,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemGetInfo(ur_mem_handle_t hMemory,
return ReturnValue(hMemory->getContext());
}
case UR_MEM_INFO_REFERENCE_COUNT: {
return ReturnValue(hMemory->getReferenceCount());
return ReturnValue(hMemory->RefCount.getCount());
}

default:
Expand Down
Loading
Loading