Skip to content

Commit 18e36c5

Browse files
authored
[UR][CUDA][HIP] Allow adapter objects to block full adapter teardown. (#17571)
In the CUDA adapter, the adapter struct itself is currently an extern global defined in adapter.cpp. This means fully tearing down the adapter is subject to the same destructor ordering as all other static and global variables: first in, last out. This presents a problem because an application can declare a static SYCL object, such as a buffer, right at the top before doing anything else, which results in the SYCL object being destroyed after the CUDA adapter struct. The UR spec doesn't put the onus on users to keep their parent object lifetimes in order, i.e. there is no statement about "the context you use to create a ur_mem_handle_t must not be released until after the mem_handle". It's assumed (by omission rather than explicitly) that adapters will have their objects hold a reference to any parent objects, keeping them alive for the duration of their own lifetime. This change moves the CUDA adapter struct's ownership into a global shared_ptr, which allows child objects of the adapter to keep their own references to it alive past the point where its initial definition goes out of scope. It also adjusts how some other objects track parent object references so that the destructors correctly cascade back to the top: the mem handle releases its context, which releases its adapter, which releases the platform + devices, etc. All of this also applies to the HIP adapter, although it seems something in HIP itself prevents this change from fixing the static_buffer_dtor test - see #17571 (comment). Fixes #17450.
1 parent 78a557d commit 18e36c5

File tree

16 files changed

+137
-85
lines changed

16 files changed

+137
-85
lines changed

sycl/test-e2e/Regression/static-buffer-dtor.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,6 @@
2121
// UNSUPPORTED: windows && arch-intel_gpu_bmg_g21
2222
// UNSUPPORTED-TRACKER: https://github.com/intel/llvm/issues/17255
2323

24-
// UNSUPPORTED: cuda
25-
// UNSUPPORTED-TRACKER: https://github.com/intel/llvm/issues/17450
26-
2724
#include <sycl/detail/core.hpp>
2825

2926
int main() {

unified-runtime/source/adapters/cuda/adapter.cpp

Lines changed: 28 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,16 @@
88
//
99
//===----------------------------------------------------------------------===//
1010

11-
#include <ur_api.h>
12-
11+
#include "adapter.hpp"
1312
#include "common.hpp"
14-
#include "logger/ur_logger.hpp"
13+
#include "platform.hpp"
1514
#include "tracing.hpp"
1615

17-
struct ur_adapter_handle_t_ {
18-
std::atomic<uint32_t> RefCount = 0;
19-
std::mutex Mutex;
20-
struct cuda_tracing_context_t_ *TracingCtx = nullptr;
21-
logger::Logger &logger;
22-
ur_adapter_handle_t_();
23-
};
16+
#include <memory>
17+
18+
namespace ur::cuda {
19+
ur_adapter_handle_t adapter;
20+
} // namespace ur::cuda
2421

2522
class ur_legacy_sink : public logger::Sink {
2623
public:
@@ -43,28 +40,33 @@ class ur_legacy_sink : public logger::Sink {
4340
ur_adapter_handle_t_::ur_adapter_handle_t_()
4441
: logger(logger::get_logger("cuda",
4542
/*default_log_level*/ logger::Level::ERR)) {
43+
Platform = std::make_unique<ur_platform_handle_t_>();
4644

47-
if (std::getenv("UR_LOG_CUDA") != nullptr)
48-
return;
49-
50-
if (std::getenv("SYCL_PI_SUPPRESS_ERROR_MESSAGE") != nullptr ||
51-
std::getenv("UR_SUPPRESS_ERROR_MESSAGE") != nullptr) {
45+
if (std::getenv("UR_LOG_CUDA") == nullptr &&
46+
(std::getenv("SYCL_PI_SUPPRESS_ERROR_MESSAGE") != nullptr ||
47+
std::getenv("UR_SUPPRESS_ERROR_MESSAGE") != nullptr)) {
5248
logger.setLegacySink(std::make_unique<ur_legacy_sink>());
5349
}
50+
51+
TracingCtx = createCUDATracingContext();
52+
enableCUDATracing(TracingCtx);
53+
}
54+
55+
ur_adapter_handle_t_::~ur_adapter_handle_t_() {
56+
disableCUDATracing(TracingCtx);
57+
freeCUDATracingContext(TracingCtx);
5458
}
55-
ur_adapter_handle_t_ adapter{};
5659

5760
UR_APIEXPORT ur_result_t UR_APICALL
5861
urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters,
5962
uint32_t *pNumAdapters) {
6063
if (NumEntries > 0 && phAdapters) {
61-
std::lock_guard<std::mutex> Lock{adapter.Mutex};
62-
if (adapter.RefCount++ == 0) {
63-
adapter.TracingCtx = createCUDATracingContext();
64-
enableCUDATracing(adapter.TracingCtx);
65-
}
64+
static std::once_flag InitFlag;
65+
std::call_once(InitFlag,
66+
[=]() { ur::cuda::adapter = new ur_adapter_handle_t_; });
6667

67-
*phAdapters = &adapter;
68+
ur::cuda::adapter->RefCount++;
69+
*phAdapters = ur::cuda::adapter;
6870
}
6971

7072
if (pNumAdapters) {
@@ -75,17 +77,14 @@ urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters,
7577
}
7678

7779
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) {
78-
adapter.RefCount++;
80+
ur::cuda::adapter->RefCount++;
7981

8082
return UR_RESULT_SUCCESS;
8183
}
8284

8385
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) {
84-
std::lock_guard<std::mutex> Lock{adapter.Mutex};
85-
if (--adapter.RefCount == 0) {
86-
disableCUDATracing(adapter.TracingCtx);
87-
freeCUDATracingContext(adapter.TracingCtx);
88-
adapter.TracingCtx = nullptr;
86+
if (--ur::cuda::adapter->RefCount == 0) {
87+
delete ur::cuda::adapter;
8988
}
9089
return UR_RESULT_SUCCESS;
9190
}
@@ -108,7 +107,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t,
108107
case UR_ADAPTER_INFO_BACKEND:
109108
return ReturnValue(UR_ADAPTER_BACKEND_CUDA);
110109
case UR_ADAPTER_INFO_REFERENCE_COUNT:
111-
return ReturnValue(adapter.RefCount.load());
110+
return ReturnValue(ur::cuda::adapter->RefCount.load());
112111
case UR_ADAPTER_INFO_VERSION:
113112
return ReturnValue(uint32_t{1});
114113
default:

unified-runtime/source/adapters/cuda/adapter.hpp

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,30 @@
88
//
99
//===----------------------------------------------------------------------===//
1010

11-
struct ur_adapter_handle_t_;
11+
#ifndef UR_CUDA_ADAPTER_HPP_INCLUDED
12+
#define UR_CUDA_ADAPTER_HPP_INCLUDED
1213

13-
extern ur_adapter_handle_t_ adapter;
14+
#include "logger/ur_logger.hpp"
15+
#include "platform.hpp"
16+
#include "tracing.hpp"
17+
#include <ur_api.h>
18+
19+
#include <atomic>
20+
#include <memory>
21+
22+
struct ur_adapter_handle_t_ {
23+
std::atomic<uint32_t> RefCount = 0;
24+
struct cuda_tracing_context_t_ *TracingCtx = nullptr;
25+
logger::Logger &logger;
26+
std::unique_ptr<ur_platform_handle_t_> Platform;
27+
ur_adapter_handle_t_();
28+
~ur_adapter_handle_t_();
29+
ur_adapter_handle_t_(const ur_adapter_handle_t_ &) = delete;
30+
};
31+
32+
// Keep the global namespace'd
33+
namespace ur::cuda {
34+
extern ur_adapter_handle_t adapter;
35+
} // namespace ur::cuda
36+
37+
#endif // UR_CUDA_ADAPTER_HPP_INCLUDED

unified-runtime/source/adapters/cuda/context.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
//===----------------------------------------------------------------------===//
1010

1111
#include "context.hpp"
12+
#include "platform.hpp"
1213
#include "usm.hpp"
1314

1415
#include <cassert>

unified-runtime/source/adapters/cuda/context.hpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,15 @@
1010
#pragma once
1111

1212
#include <cuda.h>
13+
#include <memory>
1314
#include <ur_api.h>
1415

1516
#include <atomic>
1617
#include <mutex>
1718
#include <set>
1819
#include <vector>
1920

21+
#include "adapter.hpp"
2022
#include "common.hpp"
2123
#include "device.hpp"
2224
#include "umf_helpers.hpp"
@@ -127,15 +129,12 @@ struct ur_context_handle_t_ {
127129

128130
ur_context_handle_t_(const ur_device_handle_t *Devs, uint32_t NumDevices)
129131
: Devices{Devs, Devs + NumDevices}, RefCount{1} {
130-
for (auto &Dev : Devices) {
131-
urDeviceRetain(Dev);
132-
}
133-
134132
// Create UMF CUDA memory provider for the host memory
135133
// (UMF_MEMORY_TYPE_HOST) from any device (Devices[0] is used here, because
136134
// it is guaranteed to exist).
137135
UR_CHECK_ERROR(CreateHostMemoryProviderPool(Devices[0], &MemoryProviderHost,
138136
&MemoryPoolHost));
137+
UR_CHECK_ERROR(urAdapterRetain(ur::cuda::adapter));
139138
};
140139

141140
~ur_context_handle_t_() {
@@ -145,9 +144,7 @@ struct ur_context_handle_t_ {
145144
if (MemoryProviderHost) {
146145
umfMemoryProviderDestroy(MemoryProviderHost);
147146
}
148-
for (auto &Dev : Devices) {
149-
urDeviceRelease(Dev);
150-
}
147+
UR_CHECK_ERROR(urAdapterRelease(ur::cuda::adapter));
151148
}
152149

153150
void invokeExtendedDeleters() {

unified-runtime/source/adapters/cuda/device.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1284,7 +1284,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(
12841284

12851285
// Get list of platforms
12861286
uint32_t NumPlatforms = 0;
1287-
ur_adapter_handle_t AdapterHandle = &adapter;
1287+
ur_adapter_handle_t AdapterHandle = ur::cuda::adapter;
12881288
ur_result_t Result = urPlatformGet(AdapterHandle, 0, nullptr, &NumPlatforms);
12891289
if (Result != UR_RESULT_SUCCESS)
12901290
return Result;

unified-runtime/source/adapters/cuda/memory.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,9 @@ struct ur_mem_handle_t_ {
393393
urMemRelease(std::get<BufferMem>(Mem).Parent);
394394
return;
395395
}
396+
if (LastQueueWritingToMemObj != nullptr) {
397+
urQueueRelease(LastQueueWritingToMemObj);
398+
}
396399
urContextRelease(Context);
397400
}
398401

unified-runtime/source/adapters/cuda/platform.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformGetInfo(
100100
return ReturnValue(UR_PLATFORM_BACKEND_CUDA);
101101
}
102102
case UR_PLATFORM_INFO_ADAPTER: {
103-
return ReturnValue(&adapter);
103+
return ReturnValue(ur::cuda::adapter);
104104
}
105105
default:
106106
return UR_RESULT_ERROR_INVALID_ENUMERATION;
@@ -116,11 +116,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformGetInfo(
116116
UR_APIEXPORT ur_result_t UR_APICALL
117117
urPlatformGet(ur_adapter_handle_t, uint32_t NumEntries,
118118
ur_platform_handle_t *phPlatforms, uint32_t *pNumPlatforms) {
119-
120119
try {
121120
static std::once_flag InitFlag;
122121
static uint32_t NumPlatforms = 1;
123-
static ur_platform_handle_t_ Platform;
124122

125123
UR_ASSERT(phPlatforms || pNumPlatforms, UR_RESULT_ERROR_INVALID_VALUE);
126124
UR_ASSERT(!phPlatforms || NumEntries > 0, UR_RESULT_ERROR_INVALID_SIZE);
@@ -151,22 +149,24 @@ urPlatformGet(ur_adapter_handle_t, uint32_t NumEntries,
151149
// Use default stream to record base event counter
152150
UR_CHECK_ERROR(cuEventRecord(EvBase, 0));
153151

154-
Platform.Devices.emplace_back(
155-
new ur_device_handle_t_{Device, Context, EvBase, &Platform,
152+
ur::cuda::adapter->Platform->Devices.emplace_back(
153+
new ur_device_handle_t_{Device, Context, EvBase,
154+
ur::cuda::adapter->Platform.get(),
156155
static_cast<uint32_t>(i)});
157156
}
158157

159-
UR_CHECK_ERROR(CreateDeviceMemoryProvidersPools(&Platform));
158+
UR_CHECK_ERROR(CreateDeviceMemoryProvidersPools(
159+
ur::cuda::adapter->Platform.get()));
160160
} catch (const std::bad_alloc &) {
161161
// Signal out-of-memory situation
162162
for (int i = 0; i < NumDevices; ++i) {
163-
Platform.Devices.clear();
163+
ur::cuda::adapter->Platform->Devices.clear();
164164
}
165165
Result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
166166
} catch (ur_result_t Err) {
167167
// Clear and rethrow to allow retry
168168
for (int i = 0; i < NumDevices; ++i) {
169-
Platform.Devices.clear();
169+
ur::cuda::adapter->Platform->Devices.clear();
170170
}
171171
Result = Err;
172172
throw Err;
@@ -182,7 +182,7 @@ urPlatformGet(ur_adapter_handle_t, uint32_t NumEntries,
182182
}
183183

184184
if (phPlatforms != nullptr) {
185-
*phPlatforms = &Platform;
185+
*phPlatforms = ur::cuda::adapter->Platform.get();
186186
}
187187

188188
return Result;

unified-runtime/source/adapters/cuda/platform.hpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,18 @@
77
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
88
//
99
//===----------------------------------------------------------------------===//
10-
#pragma once
1110

11+
#ifndef UR_CUDA_PLATFORM_HPP_INCLUDED
12+
#define UR_CUDA_PLATFORM_HPP_INCLUDED
13+
14+
#include "device.hpp"
1215
#include <ur/ur.hpp>
16+
17+
#include <memory>
1318
#include <vector>
1419

1520
struct ur_platform_handle_t_ {
1621
std::vector<std::unique_ptr<ur_device_handle_t_>> Devices;
1722
};
23+
24+
#endif // UR_CUDA_PLATFORM_HPP_INCLUDED

unified-runtime/source/adapters/cuda/usm.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ urUSMGetMemAllocInfo(ur_context_handle_t hContext, const void *pMem,
242242

243243
// cuda backend has only one platform containing all devices
244244
ur_platform_handle_t platform;
245-
ur_adapter_handle_t AdapterHandle = &adapter;
245+
ur_adapter_handle_t AdapterHandle = ur::cuda::adapter;
246246
Result = urPlatformGet(AdapterHandle, 1, &platform, nullptr);
247247

248248
// get the device from the platform

0 commit comments

Comments
 (0)