Skip to content

Commit 920a968

Browse files
committed
Merge branch 'main' into sanitizer-pr-cpu-local
2 parents 255f7fa + 9b97a5f commit 920a968

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+719
-215
lines changed

scripts/generate_docs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def generate_html(dstpath):
262262
if result.returncode != 0:
263263
print("sphinx-build returned non-zero error code.")
264264
print("--- output ---")
265-
print(result.stderr.read().decode())
265+
print(result.stderr.decode())
266266
raise Exception("Failed to generate html documentation.")
267267

268268
"""
@@ -277,7 +277,7 @@ def generate_pdf(dstpath):
277277
if result.returncode != 0:
278278
print("sphinx-build returned non-zero error code.")
279279
print("--- output ---")
280-
print(result.stderr.read().decode())
280+
print(result.stderr.decode())
281281
raise Exception("Failed to generate pdf documentation.")
282282

283283
"""

source/adapters/cuda/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,10 @@ else()
7676
message(WARNING "CUDA adapter USM pools are disabled, set UMF_ENABLE_POOL_TRACKING to enable them")
7777
endif()
7878

79+
if (CUDA_cupti_LIBRARY)
80+
target_compile_definitions("ur_adapter_cuda" PRIVATE CUPTI_LIB_PATH="${CUDA_cupti_LIBRARY}")
81+
endif()
82+
7983
target_link_libraries(${TARGET_NAME} PRIVATE
8084
${PROJECT_NAME}::headers
8185
${PROJECT_NAME}::common

source/adapters/cuda/adapter.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,12 @@
1111
#include <ur_api.h>
1212

1313
#include "common.hpp"
14-
15-
void enableCUDATracing();
16-
void disableCUDATracing();
14+
#include "tracing.hpp"
1715

1816
// State behind the single global CUDA adapter handle defined in this file.
struct ur_adapter_handle_t_ {
1917
// Reference count; urAdapterGet increments it and urAdapterRelease
// decrements it.
std::atomic<uint32_t> RefCount = 0;
2018
// Held by urAdapterGet and urAdapterRelease while they update RefCount and
// TracingCtx.
std::mutex Mutex;
19+
// Tracing context created when RefCount goes from 0 to 1 and freed (then
// reset to nullptr) when it drops back to 0. May be nullptr:
// createCUDATracingContext() returns nullptr when tracing is unavailable.
struct cuda_tracing_context_t_ *TracingCtx = nullptr;
2120
};
2221

2322
ur_adapter_handle_t_ adapter{};
@@ -28,7 +27,8 @@ urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters,
2827
if (NumEntries > 0 && phAdapters) {
2928
std::lock_guard<std::mutex> Lock{adapter.Mutex};
3029
if (adapter.RefCount++ == 0) {
31-
enableCUDATracing();
30+
adapter.TracingCtx = createCUDATracingContext();
31+
enableCUDATracing(adapter.TracingCtx);
3232
}
3333

3434
*phAdapters = &adapter;
@@ -50,7 +50,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) {
5050
UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) {
5151
std::lock_guard<std::mutex> Lock{adapter.Mutex};
5252
if (--adapter.RefCount == 0) {
53-
disableCUDATracing();
53+
disableCUDATracing(adapter.TracingCtx);
54+
freeCUDATracingContext(adapter.TracingCtx);
55+
adapter.TracingCtx = nullptr;
5456
}
5557
return UR_RESULT_SUCCESS;
5658
}

source/adapters/cuda/tracing.cpp

Lines changed: 150 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -18,27 +18,77 @@
1818
#include <cupti.h>
1919
#endif // XPTI_ENABLE_INSTRUMENTATION
2020

21+
#include "tracing.hpp"
22+
#include "ur_lib_loader.hpp"
2123
#include <exception>
2224
#include <iostream>
2325

26+
// Type aliases used by the tracing structs defined in this file. With
// XPTI_ENABLE_INSTRUMENTATION they name the real XPTI/CUPTI types plus
// function-pointer types for the four CUPTI entry points stored in
// cupti_table_t_; without it every alias is void * so the same struct
// definitions compile in both configurations.
#ifdef XPTI_ENABLE_INSTRUMENTATION
27+
using tracing_event_t = xpti_td *;
28+
using subscriber_handle_t = CUpti_SubscriberHandle;
29+
30+
using cuptiSubscribe_fn = CUPTIAPI
31+
CUptiResult (*)(CUpti_SubscriberHandle *subscriber, CUpti_CallbackFunc callback,
32+
void *userdata);
33+
34+
using cuptiUnsubscribe_fn = CUPTIAPI
35+
CUptiResult (*)(CUpti_SubscriberHandle subscriber);
36+
37+
using cuptiEnableDomain_fn = CUPTIAPI
38+
CUptiResult (*)(uint32_t enable, CUpti_SubscriberHandle subscriber,
39+
CUpti_CallbackDomain domain);
40+
41+
using cuptiEnableCallback_fn = CUPTIAPI
42+
CUptiResult (*)(uint32_t enable, CUpti_SubscriberHandle subscriber,
43+
CUpti_CallbackDomain domain, CUpti_CallbackId cbid);
44+
45+
// Looks up the symbol named "cupti<x>" in library `lib` via
// ur_loader::LibLoader::getFunctionPtr and stores the result, cast to
// cupti<x>_fn, in member `x` of table `p`. The expansion already ends with a
// semicolon, so call sites do not add one.
#define LOAD_CUPTI_SYM(p, lib, x) \
46+
p.x = (cupti##x##_fn)ur_loader::LibLoader::getFunctionPtr(lib.get(), \
47+
"cupti" #x);
48+
49+
#else
50+
using tracing_event_t = void *;
51+
using subscriber_handle_t = void *;
52+
using cuptiSubscribe_fn = void *;
53+
using cuptiUnsubscribe_fn = void *;
54+
using cuptiEnableDomain_fn = void *;
55+
using cuptiEnableCallback_fn = void *;
56+
#endif // XPTI_ENABLE_INSTRUMENTATION
57+
58+
// Table of the CUPTI entry points this file calls through pointers instead of
// calling the cupti* functions directly. Every pointer starts out null and is
// filled in by loadCUDATracingLibrary().
struct cupti_table_t_ {
59+
cuptiSubscribe_fn Subscribe = nullptr;
60+
cuptiUnsubscribe_fn Unsubscribe = nullptr;
61+
cuptiEnableDomain_fn EnableDomain = nullptr;
62+
cuptiEnableCallback_fn EnableCallback = nullptr;
63+
64+
// Returns true only when all four pointers above are non-null.
bool isInitialized() const;
65+
};
66+
67+
// Per-adapter tracing state. This replaces the file-level GCallEvent and
// GDebugEvent globals; a pointer to this struct is passed to CUPTI as the
// callback's user data so cuptiCallback can reach the events.
struct cuda_tracing_context_t_ {
68+
// Event passed to xptiNotifySubscribers on the CUDA call stream.
tracing_event_t CallEvent = nullptr;
69+
// Event passed to xptiNotifySubscribers on the CUDA debug stream.
tracing_event_t DebugEvent = nullptr;
70+
// Subscriber handle written by Cupti.Subscribe in enableCUDATracing() and
// reset to nullptr after Cupti.Unsubscribe in disableCUDATracing().
subscriber_handle_t Subscriber = nullptr;
71+
// Handle of the loaded CUPTI library; empty until loadCUDATracingLibrary()
// succeeds and reset again by unloadCUDATracingLibrary().
ur_loader::LibLoader::Lib Library;
72+
// Entry points resolved from Library.
cupti_table_t_ Cupti;
73+
};
74+
2475
#ifdef XPTI_ENABLE_INSTRUMENTATION
2576
constexpr auto CUDA_CALL_STREAM_NAME = "sycl.experimental.cuda.call";
2677
constexpr auto CUDA_DEBUG_STREAM_NAME = "sycl.experimental.cuda.debug";
2778

2879
thread_local uint64_t CallCorrelationID = 0;
2980
thread_local uint64_t DebugCorrelationID = 0;
3081

31-
static xpti_td *GCallEvent = nullptr;
32-
static xpti_td *GDebugEvent = nullptr;
33-
3482
constexpr auto GVerStr = "0.1";
3583
constexpr int GMajVer = 0;
3684
constexpr int GMinVer = 1;
3785

38-
static void cuptiCallback(void *, CUpti_CallbackDomain, CUpti_CallbackId CBID,
39-
const void *CBData) {
86+
static void cuptiCallback(void *UserData, CUpti_CallbackDomain,
87+
CUpti_CallbackId CBID, const void *CBData) {
4088
if (xptiTraceEnabled()) {
4189
const auto *CBInfo = static_cast<const CUpti_CallbackData *>(CBData);
90+
cuda_tracing_context_t_ *Ctx =
91+
static_cast<cuda_tracing_context_t_ *>(UserData);
4292

4393
if (CBInfo->callbackSite == CUPTI_API_ENTER) {
4494
CallCorrelationID = xptiGetUniqueId();
@@ -57,22 +107,95 @@ static void cuptiCallback(void *, CUpti_CallbackDomain, CUpti_CallbackId CBID,
57107
uint8_t CallStreamID = xptiRegisterStream(CUDA_CALL_STREAM_NAME);
58108
uint8_t DebugStreamID = xptiRegisterStream(CUDA_DEBUG_STREAM_NAME);
59109

60-
xptiNotifySubscribers(CallStreamID, TraceType, GCallEvent, nullptr,
110+
xptiNotifySubscribers(CallStreamID, TraceType, Ctx->CallEvent, nullptr,
61111
CallCorrelationID, FuncName);
62112

63113
xpti::function_with_args_t Payload{
64114
FuncID, FuncName, const_cast<void *>(CBInfo->functionParams),
65115
CBInfo->functionReturnValue, CBInfo->context};
66-
xptiNotifySubscribers(DebugStreamID, TraceTypeArgs, GDebugEvent, nullptr,
67-
DebugCorrelationID, &Payload);
116+
xptiNotifySubscribers(DebugStreamID, TraceTypeArgs, Ctx->DebugEvent,
117+
nullptr, DebugCorrelationID, &Payload);
68118
}
69119
}
70120
#endif
71121

122+
// Allocates a new tracing context.
//
// Returns nullptr when the file is built without XPTI_ENABLE_INSTRUMENTATION
// or when xptiTraceEnabled() reports false; otherwise returns a
// default-initialized, heap-allocated context. The matching release function
// is freeCUDATracingContext(), which deletes the pointer.
cuda_tracing_context_t_ *createCUDATracingContext() {
123+
#ifdef XPTI_ENABLE_INSTRUMENTATION
124+
if (!xptiTraceEnabled())
125+
return nullptr;
126+
return new cuda_tracing_context_t_;
127+
#else
128+
return nullptr;
129+
#endif // XPTI_ENABLE_INSTRUMENTATION
130+
}
131+
132+
// Releases a context obtained from createCUDATracingContext().
//
// With instrumentation enabled this first unloads the tracing library held by
// the context and then deletes the context. A null Ctx is tolerated:
// unloadCUDATracingLibrary() returns early for null and deleting a null
// pointer has no effect. Without instrumentation the function does nothing.
//
// Note: this does not unsubscribe from CUPTI; adapter.cpp calls
// disableCUDATracing() before calling this function.
void freeCUDATracingContext(cuda_tracing_context_t_ *Ctx) {
133+
#ifdef XPTI_ENABLE_INSTRUMENTATION
134+
unloadCUDATracingLibrary(Ctx);
135+
delete Ctx;
136+
#else
137+
(void)Ctx;
138+
#endif // XPTI_ENABLE_INSTRUMENTATION
139+
}
140+
141+
// Returns true only when every entry point in the table (Subscribe,
// Unsubscribe, EnableDomain, EnableCallback) is non-null.
bool cupti_table_t_::isInitialized() const {
142+
return Subscribe && Unsubscribe && EnableDomain && EnableCallback;
143+
}
144+
145+
// Loads the CUPTI library from CUPTI_LIB_PATH and resolves the entry points
// stored in Ctx->Cupti.
//
// Returns true when the context already holds a library, or when the library
// was loaded and all four symbols were resolved. Returns false when Ctx is
// null, the library cannot be loaded, any symbol is missing, or the file was
// built without both XPTI_ENABLE_INSTRUMENTATION and CUPTI_LIB_PATH.
bool loadCUDATracingLibrary(cuda_tracing_context_t_ *Ctx) {
146+
#if defined(XPTI_ENABLE_INSTRUMENTATION) && defined(CUPTI_LIB_PATH)
147+
if (!Ctx)
148+
return false;
149+
// Already loaded: nothing to do.
if (Ctx->Library)
150+
return true;
151+
auto Lib{ur_loader::LibLoader::loadAdapterLibrary(CUPTI_LIB_PATH)};
152+
if (!Lib)
153+
return false;
154+
// Resolve into a local table first so that Ctx is left untouched if any
// symbol turns out to be missing.
cupti_table_t_ Table;
155+
LOAD_CUPTI_SYM(Table, Lib, Subscribe)
156+
LOAD_CUPTI_SYM(Table, Lib, Unsubscribe)
157+
LOAD_CUPTI_SYM(Table, Lib, EnableDomain)
158+
LOAD_CUPTI_SYM(Table, Lib, EnableCallback)
159+
if (!Table.isInitialized()) {
160+
// NOTE(review): presumably Lib releases the library when it goes out of
// scope here; confirm against ur_lib_loader.hpp.
return false;
161+
}
162+
// Commit: the context takes ownership of the library handle and the table.
Ctx->Library = std::move(Lib);
163+
Ctx->Cupti = Table;
164+
return true;
165+
#else
166+
(void)Ctx;
167+
return false;
168+
#endif // XPTI_ENABLE_INSTRUMENTATION && CUPTI_LIB_PATH
169+
}
170+
171+
// Drops the library handle held by Ctx and resets its CUPTI function table to
// the default (all-null) state, so Ctx->Cupti.isInitialized() is false
// afterwards. Does nothing for a null Ctx or when built without
// XPTI_ENABLE_INSTRUMENTATION. Ctx->Subscriber is not touched here.
void unloadCUDATracingLibrary(cuda_tracing_context_t_ *Ctx) {
172+
#ifdef XPTI_ENABLE_INSTRUMENTATION
173+
if (!Ctx)
174+
return;
175+
Ctx->Library.reset();
176+
Ctx->Cupti = cupti_table_t_();
177+
#else
178+
(void)Ctx;
179+
#endif // XPTI_ENABLE_INSTRUMENTATION
180+
}
181+
72182
// Legacy overload that takes no context (declared as deprecated in
// tracing.hpp). It keeps one function-local static context, creates it on the
// first call where tracing is enabled, and forwards to
// enableCUDATracing(cuda_tracing_context_t_ *).
//
// Nothing in this function frees the static context. Without
// XPTI_ENABLE_INSTRUMENTATION the body is empty.
void enableCUDATracing() {
73183
#ifdef XPTI_ENABLE_INSTRUMENTATION
74184
if (!xptiTraceEnabled())
75185
return;
186+
static cuda_tracing_context_t_ *Ctx = nullptr;
187+
// Ctx stays null if createCUDATracingContext() returned nullptr, so creation
// is retried on the next call.
if (!Ctx)
188+
Ctx = createCUDATracingContext();
189+
enableCUDATracing(Ctx);
190+
#endif
191+
}
192+
193+
void enableCUDATracing(cuda_tracing_context_t_ *Ctx) {
194+
#ifdef XPTI_ENABLE_INSTRUMENTATION
195+
if (!Ctx || !xptiTraceEnabled())
196+
return;
197+
else if (!loadCUDATracingLibrary(Ctx))
198+
return;
76199

77200
xptiRegisterStream(CUDA_CALL_STREAM_NAME);
78201
xptiInitialize(CUDA_CALL_STREAM_NAME, GMajVer, GMinVer, GVerStr);
@@ -81,31 +204,39 @@ void enableCUDATracing() {
81204

82205
uint64_t Dummy;
83206
xpti::payload_t CUDAPayload("CUDA Plugin Layer");
84-
GCallEvent =
207+
Ctx->CallEvent =
85208
xptiMakeEvent("CUDA Plugin Layer", &CUDAPayload,
86209
xpti::trace_algorithm_event, xpti_at::active, &Dummy);
87210

88211
xpti::payload_t CUDADebugPayload("CUDA Plugin Debug Layer");
89-
GDebugEvent =
212+
Ctx->DebugEvent =
90213
xptiMakeEvent("CUDA Plugin Debug Layer", &CUDADebugPayload,
91214
xpti::trace_algorithm_event, xpti_at::active, &Dummy);
92215

93-
CUpti_SubscriberHandle Subscriber;
94-
cuptiSubscribe(&Subscriber, cuptiCallback, nullptr);
95-
cuptiEnableDomain(1, Subscriber, CUPTI_CB_DOMAIN_DRIVER_API);
96-
cuptiEnableCallback(0, Subscriber, CUPTI_CB_DOMAIN_DRIVER_API,
97-
CUPTI_DRIVER_TRACE_CBID_cuGetErrorString);
98-
cuptiEnableCallback(0, Subscriber, CUPTI_CB_DOMAIN_DRIVER_API,
99-
CUPTI_DRIVER_TRACE_CBID_cuGetErrorName);
216+
Ctx->Cupti.Subscribe(&Ctx->Subscriber, cuptiCallback, Ctx);
217+
Ctx->Cupti.EnableDomain(1, Ctx->Subscriber, CUPTI_CB_DOMAIN_DRIVER_API);
218+
Ctx->Cupti.EnableCallback(0, Ctx->Subscriber, CUPTI_CB_DOMAIN_DRIVER_API,
219+
CUPTI_DRIVER_TRACE_CBID_cuGetErrorString);
220+
Ctx->Cupti.EnableCallback(0, Ctx->Subscriber, CUPTI_CB_DOMAIN_DRIVER_API,
221+
CUPTI_DRIVER_TRACE_CBID_cuGetErrorName);
222+
#else
223+
(void)Ctx;
100224
#endif
101225
}
102226

103-
void disableCUDATracing() {
227+
void disableCUDATracing(cuda_tracing_context_t_ *Ctx) {
104228
#ifdef XPTI_ENABLE_INSTRUMENTATION
105-
if (!xptiTraceEnabled())
229+
if (!Ctx || !xptiTraceEnabled())
106230
return;
107231

232+
if (Ctx->Subscriber && Ctx->Cupti.isInitialized()) {
233+
Ctx->Cupti.Unsubscribe(Ctx->Subscriber);
234+
Ctx->Subscriber = nullptr;
235+
}
236+
108237
xptiFinalize(CUDA_CALL_STREAM_NAME);
109238
xptiFinalize(CUDA_DEBUG_STREAM_NAME);
239+
#else
240+
(void)Ctx;
110241
#endif // XPTI_ENABLE_INSTRUMENTATION
111242
}

source/adapters/cuda/tracing.hpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
//===--------- tracing.hpp - CUDA Host API Tracing -------------------------==//
2+
//
3+
// Copyright (C) 2023 Intel Corporation
4+
//
5+
// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM
6+
// Exceptions. See LICENSE.TXT
7+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8+
//
9+
//===----------------------------------------------------------------------===//
10+
11+
// Opaque tracing context; the definition lives in tracing.cpp.
struct cuda_tracing_context_t_;
12+
13+
// Create / destroy a tracing context. createCUDATracingContext() may return
// nullptr (tracing not compiled in or not enabled); the functions below that
// take a context accept a null pointer.
cuda_tracing_context_t_ *createCUDATracingContext();
14+
void freeCUDATracingContext(cuda_tracing_context_t_ *Ctx);
15+
16+
// Load / unload the CUPTI library used for tracing. loadCUDATracingLibrary()
// returns true on success or if the library is already loaded.
bool loadCUDATracingLibrary(cuda_tracing_context_t_ *Ctx);
17+
void unloadCUDATracingLibrary(cuda_tracing_context_t_ *Ctx);
18+
19+
// Start / stop CUDA driver API tracing for the given context.
void enableCUDATracing(cuda_tracing_context_t_ *Ctx);
20+
void disableCUDATracing(cuda_tracing_context_t_ *Ctx);
21+
22+
// Deprecated. Will be removed once pi_cuda has been updated to use the variant
23+
// that takes a context pointer.
24+
void enableCUDATracing();

0 commit comments

Comments
 (0)