Skip to content

Commit 01b85ac

Browse files
committed
Refactor reference counting in UR CUDA adapter using new ur::RefCount class.
1 parent b42de17 commit 01b85ac

24 files changed

+77
-118
lines changed

unified-runtime/source/adapters/cuda/adapter.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters,
6666
std::call_once(InitFlag,
6767
[=]() { ur::cuda::adapter = new ur_adapter_handle_t_; });
6868

69-
ur::cuda::adapter->RefCount++;
69+
ur::cuda::adapter->RefCount.retain();
7070
*phAdapters = ur::cuda::adapter;
7171
}
7272

@@ -78,13 +78,13 @@ urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters,
7878
}
7979

8080
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) {
81-
ur::cuda::adapter->RefCount++;
81+
ur::cuda::adapter->RefCount.retain();
8282

8383
return UR_RESULT_SUCCESS;
8484
}
8585

8686
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) {
87-
if (--ur::cuda::adapter->RefCount == 0) {
87+
if (ur::cuda::adapter->RefCount.release()) {
8888
delete ur::cuda::adapter;
8989
}
9090
return UR_RESULT_SUCCESS;
@@ -108,7 +108,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t,
108108
case UR_ADAPTER_INFO_BACKEND:
109109
return ReturnValue(UR_BACKEND_CUDA);
110110
case UR_ADAPTER_INFO_REFERENCE_COUNT:
111-
return ReturnValue(ur::cuda::adapter->RefCount.load());
111+
return ReturnValue(ur::cuda::adapter->RefCount.getCount());
112112
case UR_ADAPTER_INFO_VERSION:
113113
return ReturnValue(uint32_t{1});
114114
default:

unified-runtime/source/adapters/cuda/adapter.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#ifndef UR_CUDA_ADAPTER_HPP_INCLUDED
1212
#define UR_CUDA_ADAPTER_HPP_INCLUDED
1313

14+
#include "common/ur_ref_count.hpp"
1415
#include "logger/ur_logger.hpp"
1516
#include "platform.hpp"
1617
#include "tracing.hpp"
@@ -20,7 +21,7 @@
2021
#include <memory>
2122

2223
struct ur_adapter_handle_t_ : ur::cuda::handle_base {
23-
std::atomic<uint32_t> RefCount = 0;
24+
ur::RefCount RefCount;
2425
struct cuda_tracing_context_t_ *TracingCtx = nullptr;
2526
logger::Logger &logger;
2627
std::unique_ptr<ur_platform_handle_t_> Platform;

unified-runtime/source/adapters/cuda/command_buffer.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -382,13 +382,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferCreateExp(
382382

383383
UR_APIEXPORT ur_result_t UR_APICALL
384384
urCommandBufferRetainExp(ur_exp_command_buffer_handle_t hCommandBuffer) {
385-
hCommandBuffer->incrementReferenceCount();
385+
hCommandBuffer->RefCount.retain();
386386
return UR_RESULT_SUCCESS;
387387
}
388388

389389
UR_APIEXPORT ur_result_t UR_APICALL
390390
urCommandBufferReleaseExp(ur_exp_command_buffer_handle_t hCommandBuffer) {
391-
if (hCommandBuffer->decrementReferenceCount() == 0) {
391+
if (hCommandBuffer->RefCount.release()) {
392392
// Ref count has reached zero, release of created commands
393393
for (auto &Command : hCommandBuffer->CommandHandles) {
394394
commandHandleDestroy(Command);
@@ -1478,7 +1478,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferGetInfoExp(
14781478

14791479
switch (propName) {
14801480
case UR_EXP_COMMAND_BUFFER_INFO_REFERENCE_COUNT:
1481-
return ReturnValue(hCommandBuffer->getReferenceCount());
1481+
return ReturnValue(hCommandBuffer->RefCount.getCount());
14821482
case UR_EXP_COMMAND_BUFFER_INFO_DESCRIPTOR: {
14831483
ur_exp_command_buffer_desc_t Descriptor{};
14841484
Descriptor.stype = UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_DESC;

unified-runtime/source/adapters/cuda/command_buffer.hpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <ur_api.h>
1313
#include <ur_print.hpp>
1414

15+
#include "common/ur_ref_count.hpp"
1516
#include "context.hpp"
1617
#include "logger/ur_logger.hpp"
1718
#include <cuda.h>
@@ -173,10 +174,6 @@ struct ur_exp_command_buffer_handle_t_ : ur::cuda::handle_base {
173174
return SyncPoint;
174175
}
175176

176-
uint32_t incrementReferenceCount() noexcept { return ++RefCount; }
177-
uint32_t decrementReferenceCount() noexcept { return --RefCount; }
178-
uint32_t getReferenceCount() const noexcept { return RefCount; }
179-
180177
// UR context associated with this command-buffer
181178
ur_context_handle_t Context;
182179
// Device associated with this command-buffer
@@ -189,9 +186,8 @@ struct ur_exp_command_buffer_handle_t_ : ur::cuda::handle_base {
189186
CUgraph CudaGraph;
190187
// Cuda Graph Exec handle
191188
CUgraphExec CudaGraphExec = nullptr;
192-
// Atomic variable counting the number of reference to this command_buffer
193-
// using std::atomic prevents data race when incrementing/decrementing.
194-
std::atomic_uint32_t RefCount;
189+
190+
ur::RefCount RefCount;
195191

196192
// Ordered map of sync_points to ur_events, so that we can find the last
197193
// node added to an in-order command-buffer.

unified-runtime/source/adapters/cuda/context.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextGetInfo(
6666
return ReturnValue(hContext->getDevices().data(),
6767
hContext->getDevices().size());
6868
case UR_CONTEXT_INFO_REFERENCE_COUNT:
69-
return ReturnValue(hContext->getReferenceCount());
69+
return ReturnValue(hContext->RefCount.getCount());
7070
case UR_CONTEXT_INFO_USM_MEMCPY2D_SUPPORT:
7171
// 2D USM memcpy is supported.
7272
return ReturnValue(true);
@@ -83,7 +83,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextGetInfo(
8383

8484
UR_APIEXPORT ur_result_t UR_APICALL
8585
urContextRelease(ur_context_handle_t hContext) {
86-
if (hContext->decrementReferenceCount() > 0) {
86+
if (!hContext->RefCount.release()) {
8787
return UR_RESULT_SUCCESS;
8888
}
8989
hContext->invokeExtendedDeleters();
@@ -94,9 +94,9 @@ urContextRelease(ur_context_handle_t hContext) {
9494

9595
UR_APIEXPORT ur_result_t UR_APICALL
9696
urContextRetain(ur_context_handle_t hContext) {
97-
assert(hContext->getReferenceCount() > 0);
97+
assert(hContext->RefCount.getCount() > 0);
9898

99-
hContext->incrementReferenceCount();
99+
hContext->RefCount.retain();
100100
return UR_RESULT_SUCCESS;
101101
}
102102

unified-runtime/source/adapters/cuda/context.hpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
#include "adapter.hpp"
2222
#include "common.hpp"
23+
#include "common/ur_ref_count.hpp"
2324
#include "device.hpp"
2425
#include "umf_helpers.hpp"
2526

@@ -88,7 +89,7 @@ struct ur_context_handle_t_ : ur::cuda::handle_base {
8889
};
8990

9091
std::vector<ur_device_handle_t> Devices;
91-
std::atomic_uint32_t RefCount;
92+
ur::RefCount RefCount;
9293

9394
// UMF CUDA memory provider and pool for the host memory
9495
// (UMF_MEMORY_TYPE_HOST)
@@ -140,12 +141,6 @@ struct ur_context_handle_t_ : ur::cuda::handle_base {
140141
return std::distance(Devices.begin(), It);
141142
}
142143

143-
uint32_t incrementReferenceCount() noexcept { return ++RefCount; }
144-
145-
uint32_t decrementReferenceCount() noexcept { return --RefCount; }
146-
147-
uint32_t getReferenceCount() const noexcept { return RefCount; }
148-
149144
void addPool(ur_usm_pool_handle_t Pool);
150145

151146
void removePool(ur_usm_pool_handle_t Pool);

unified-runtime/source/adapters/cuda/device.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -593,7 +593,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
593593
return ReturnValue("CUDA");
594594
}
595595
case UR_DEVICE_INFO_REFERENCE_COUNT: {
596-
return ReturnValue(hDevice->getReferenceCount());
596+
return ReturnValue(hDevice->RefCount.getCount());
597597
}
598598
case UR_DEVICE_INFO_VERSION: {
599599
std::stringstream SS;

unified-runtime/source/adapters/cuda/device.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <umf/memory_provider.h>
1616

1717
#include "common.hpp"
18+
#include "common/ur_ref_count.hpp"
1819

1920
struct ur_device_handle_t_ : ur::cuda::handle_base {
2021
private:
@@ -23,7 +24,6 @@ struct ur_device_handle_t_ : ur::cuda::handle_base {
2324
native_type CuDevice;
2425
CUcontext CuContext;
2526
CUevent EvBase; // CUDA event used as base counter
26-
std::atomic_uint32_t RefCount;
2727
ur_platform_handle_t Platform;
2828
uint32_t DeviceIndex;
2929

@@ -42,7 +42,7 @@ struct ur_device_handle_t_ : ur::cuda::handle_base {
4242
ur_device_handle_t_(native_type cuDevice, CUcontext cuContext, CUevent evBase,
4343
ur_platform_handle_t platform, uint32_t DevIndex)
4444
: handle_base(), CuDevice(cuDevice), CuContext(cuContext), EvBase(evBase),
45-
RefCount{1}, Platform(platform), DeviceIndex{DevIndex} {
45+
Platform(platform), DeviceIndex{DevIndex} {
4646
UR_CHECK_ERROR(cuDeviceGetAttribute(
4747
&MaxRegsPerBlock, CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK,
4848
cuDevice));
@@ -136,8 +136,6 @@ struct ur_device_handle_t_ : ur::cuda::handle_base {
136136

137137
CUcontext getNativeContext() const noexcept { return CuContext; };
138138

139-
uint32_t getReferenceCount() const noexcept { return RefCount; }
140-
141139
ur_platform_handle_t getPlatform() const noexcept { return Platform; };
142140

143141
// Returns the index of the device relative to the other devices in the same
@@ -178,6 +176,8 @@ struct ur_device_handle_t_ : ur::cuda::handle_base {
178176
// (UMF_MEMORY_TYPE_SHARED)
179177
umf_memory_provider_handle_t MemoryProviderShared;
180178
umf_memory_pool_handle_t MemoryPoolShared;
179+
180+
ur::RefCount RefCount;
181181
};
182182

183183
int getAttribute(ur_device_handle_t Device, CUdevice_attribute Attribute);

unified-runtime/source/adapters/cuda/event.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventGetInfo(ur_event_handle_t hEvent,
179179
case UR_EVENT_INFO_COMMAND_TYPE:
180180
return ReturnValue(hEvent->getCommandType());
181181
case UR_EVENT_INFO_REFERENCE_COUNT:
182-
return ReturnValue(hEvent->getReferenceCount());
182+
return ReturnValue(hEvent->RefCount.getCount());
183183
case UR_EVENT_INFO_COMMAND_EXECUTION_STATUS:
184184
return ReturnValue(hEvent->getExecutionStatus());
185185
case UR_EVENT_INFO_CONTEXT:
@@ -248,7 +248,7 @@ urEventWait(uint32_t numEvents, const ur_event_handle_t *phEventWaitList) {
248248
}
249249

250250
UR_APIEXPORT ur_result_t UR_APICALL urEventRetain(ur_event_handle_t hEvent) {
251-
const auto RefCount = hEvent->incrementReferenceCount();
251+
const auto RefCount = hEvent->RefCount.retain();
252252

253253
if (RefCount == 0) {
254254
return UR_RESULT_ERROR_OUT_OF_RESOURCES;
@@ -260,12 +260,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventRetain(ur_event_handle_t hEvent) {
260260
UR_APIEXPORT ur_result_t UR_APICALL urEventRelease(ur_event_handle_t hEvent) {
261261
// double delete or someone is messing with the ref count.
262262
// either way, cannot safely proceed.
263-
if (hEvent->getReferenceCount() == 0) {
263+
if (hEvent->RefCount.getCount() == 0) {
264264
return UR_RESULT_ERROR_INVALID_EVENT;
265265
}
266266

267267
// decrement ref count. If it is 0, delete the event.
268-
if (hEvent->decrementReferenceCount() == 0) {
268+
if (hEvent->release()) {
269269
std::unique_ptr<ur_event_handle_t_> event_ptr{hEvent};
270270
ur_result_t Result = UR_RESULT_ERROR_INVALID_EVENT;
271271
try {

unified-runtime/source/adapters/cuda/event.hpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <ur/ur.hpp>
1414

1515
#include "common.hpp"
16+
#include "common/ur_ref_count.hpp"
1617
#include "queue.hpp"
1718

1819
/// UR Event mapping to CUevent
@@ -60,16 +61,11 @@ struct ur_event_handle_t_ : ur::cuda::handle_base {
6061
ur_context_handle_t getContext() const noexcept { return Context; };
6162
uint32_t getEventID() const noexcept { return EventID; }
6263

63-
// Reference counting.
64-
uint32_t getReferenceCount() const noexcept { return RefCount; }
65-
uint32_t incrementReferenceCount() { return ++RefCount; }
66-
uint32_t decrementReferenceCount() { return --RefCount; }
64+
ur::RefCount RefCount;
6765

6866
private:
6967
ur_command_t CommandType; // The type of command associated with event.
7068

71-
std::atomic_uint32_t RefCount{1}; // Event reference count.
72-
7369
bool HasOwnership{true}; // Signifies if event owns the native type.
7470
bool HasProfiling{false}; // Signifies if event has profiling information.
7571

unified-runtime/source/adapters/cuda/kernel.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -127,20 +127,20 @@ urKernelGetGroupInfo(ur_kernel_handle_t hKernel, ur_device_handle_t hDevice,
127127
}
128128

129129
UR_APIEXPORT ur_result_t UR_APICALL urKernelRetain(ur_kernel_handle_t hKernel) {
130-
UR_ASSERT(hKernel->getReferenceCount() > 0u, UR_RESULT_ERROR_INVALID_KERNEL);
130+
UR_ASSERT(hKernel->RefCount.getCount() > 0u, UR_RESULT_ERROR_INVALID_KERNEL);
131131

132-
hKernel->incrementReferenceCount();
132+
hKernel->RefCount.retain();
133133
return UR_RESULT_SUCCESS;
134134
}
135135

136136
UR_APIEXPORT ur_result_t UR_APICALL
137137
urKernelRelease(ur_kernel_handle_t hKernel) {
138138
// double delete or someone is messing with the ref count.
139139
// either way, cannot safely proceed.
140-
UR_ASSERT(hKernel->getReferenceCount() != 0, UR_RESULT_ERROR_INVALID_KERNEL);
140+
UR_ASSERT(hKernel->RefCount.getCount() != 0, UR_RESULT_ERROR_INVALID_KERNEL);
141141

142142
// decrement ref count. If it is 0, delete the program.
143-
if (hKernel->decrementReferenceCount() == 0) {
143+
if (hKernel->RefCount.release()) {
144144
// no internal cuda resources to clean up. Just delete it.
145145
delete hKernel;
146146
return UR_RESULT_SUCCESS;
@@ -248,7 +248,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetInfo(ur_kernel_handle_t hKernel,
248248
case UR_KERNEL_INFO_NUM_ARGS:
249249
return ReturnValue(hKernel->getNumArgs());
250250
case UR_KERNEL_INFO_REFERENCE_COUNT:
251-
return ReturnValue(hKernel->getReferenceCount());
251+
return ReturnValue(hKernel->RefCount.getCount());
252252
case UR_KERNEL_INFO_CONTEXT:
253253
return ReturnValue(hKernel->getContext());
254254
case UR_KERNEL_INFO_PROGRAM:

unified-runtime/source/adapters/cuda/kernel.hpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <cassert>
1818
#include <numeric>
1919

20+
#include "common/ur_ref_count.hpp"
2021
#include "program.hpp"
2122

2223
/// Implementation of a UR Kernel for CUDA
@@ -42,7 +43,7 @@ struct ur_kernel_handle_t_ : ur::cuda::handle_base {
4243
std::string Name;
4344
ur_context_handle_t Context;
4445
ur_program_handle_t Program;
45-
std::atomic_uint32_t RefCount;
46+
ur::RefCount RefCount;
4647

4748
static constexpr uint32_t ReqdThreadsPerBlockDimensions = 3u;
4849
size_t ReqdThreadsPerBlock[ReqdThreadsPerBlockDimensions];
@@ -304,12 +305,6 @@ struct ur_kernel_handle_t_ : ur::cuda::handle_base {
304305
urContextRelease(Context);
305306
}
306307

307-
uint32_t incrementReferenceCount() noexcept { return ++RefCount; }
308-
309-
uint32_t decrementReferenceCount() noexcept { return --RefCount; }
310-
311-
uint32_t getReferenceCount() const noexcept { return RefCount; }
312-
313308
native_type get() const noexcept { return Function; };
314309

315310
ur_program_handle_t getProgram() const noexcept { return Program; };

unified-runtime/source/adapters/cuda/memory.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferCreate(
7878
}
7979

8080
UR_APIEXPORT ur_result_t UR_APICALL urMemRetain(ur_mem_handle_t hMem) {
81-
UR_ASSERT(hMem->getReferenceCount() > 0, UR_RESULT_ERROR_INVALID_MEM_OBJECT);
82-
hMem->incrementReferenceCount();
81+
UR_ASSERT(hMem->RefCount.getCount() > 0, UR_RESULT_ERROR_INVALID_MEM_OBJECT);
82+
hMem->RefCount.retain();
8383
return UR_RESULT_SUCCESS;
8484
}
8585

@@ -89,7 +89,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemRetain(ur_mem_handle_t hMem) {
8989
UR_APIEXPORT ur_result_t UR_APICALL urMemRelease(ur_mem_handle_t hMem) {
9090
try {
9191
// Do nothing if there are other references
92-
if (hMem->decrementReferenceCount() > 0) {
92+
if (!hMem->RefCount.release()) {
9393
return UR_RESULT_SUCCESS;
9494
}
9595

@@ -162,7 +162,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemGetInfo(ur_mem_handle_t hMemory,
162162
return ReturnValue(hMemory->getContext());
163163
}
164164
case UR_MEM_INFO_REFERENCE_COUNT: {
165-
return ReturnValue(hMemory->getReferenceCount());
165+
return ReturnValue(hMemory->RefCount.getCount());
166166
}
167167

168168
default:

0 commit comments

Comments
 (0)