Skip to content

Commit c84a81a

Browse files
committed
Added the profiling stream for CUDA backend
1 parent a1e9b22 commit c84a81a

File tree

8 files changed

+47
-19
lines changed

8 files changed

+47
-19
lines changed

source/adapters/cuda/event.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,7 @@ ur_result_t ur_event_handle_t_::start() {
5555

5656
try {
5757
if (Queue->URFlags & UR_QUEUE_FLAG_PROFILING_ENABLE || isTimestampEvent()) {
58-
// NOTE: This relies on the default stream to be unused.
59-
UR_CHECK_ERROR(cuEventRecord(EvQueued, 0));
58+
UR_CHECK_ERROR(cuEventRecord(EvQueued, Queue->getProfilingStream()));
6059
UR_CHECK_ERROR(cuEventRecord(EvStart, Stream));
6160
}
6261
} catch (ur_result_t Err) {

source/adapters/cuda/event.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ struct ur_event_handle_t_ {
9090
const bool RequiresTimings =
9191
Queue->URFlags & UR_QUEUE_FLAG_PROFILING_ENABLE ||
9292
Type == UR_COMMAND_TIMESTAMP_RECORDING_EXP;
93+
if (RequiresTimings) {
94+
Queue->createProfilingStream();
95+
}
9396
native_type EvEnd = nullptr, EvQueued = nullptr, EvStart = nullptr;
9497
UR_CHECK_ERROR(cuEventCreate(
9598
&EvEnd, RequiresTimings ? CU_EVENT_DEFAULT : CU_EVENT_DISABLE_TIMING));

source/adapters/cuda/queue.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueRelease(ur_queue_handle_t hQueue) {
201201
UR_CHECK_ERROR(cuStreamDestroy(S));
202202
});
203203

204+
if (hQueue->IsProfStreamCreated) {
205+
UR_CHECK_ERROR(cuStreamSynchronize(hQueue->getProfilingStream()));
206+
UR_CHECK_ERROR(cuStreamDestroy(hQueue->getProfilingStream()));
207+
}
208+
204209
return UR_RESULT_SUCCESS;
205210
} catch (ur_result_t Err) {
206211
return Err;

source/adapters/cuda/queue.hpp

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99
//===----------------------------------------------------------------------===//
1010
#pragma once
1111

12+
#include "common.hpp"
1213
#include <ur/ur.hpp>
1314

1415
#include <algorithm>
1516
#include <cuda.h>
17+
#include <mutex>
1618
#include <vector>
1719

1820
using ur_stream_guard_ = std::unique_lock<std::mutex>;
@@ -27,6 +29,9 @@ struct ur_queue_handle_t_ {
2729

2830
std::vector<native_type> ComputeStreams;
2931
std::vector<native_type> TransferStreams;
32+
// Stream used solely when profiling is enabled
33+
native_type ProfStream;
34+
bool IsProfStreamCreated{false};
3035
// delay_compute_ keeps track of which streams have been recently reused and
3136
// their next use should be delayed. If a stream has been recently reused it
3237
// will be skipped the next time it would be selected round-robin style. When
@@ -64,8 +69,8 @@ struct ur_queue_handle_t_ {
6469
ur_context_handle_t_ *Context, ur_device_handle_t_ *Device,
6570
unsigned int Flags, ur_queue_flags_t URFlags, int Priority,
6671
bool BackendOwns = true)
67-
: ComputeStreams{std::move(ComputeStreams)}, TransferStreams{std::move(
68-
TransferStreams)},
72+
: ComputeStreams{std::move(ComputeStreams)},
73+
TransferStreams{std::move(TransferStreams)},
6974
DelayCompute(this->ComputeStreams.size(), false),
7075
ComputeAppliedBarrier(this->ComputeStreams.size()),
7176
TransferAppliedBarrier(this->TransferStreams.size()), Context{Context},
@@ -99,6 +104,19 @@ struct ur_queue_handle_t_ {
99104
native_type get() { return getNextComputeStream(); };
100105
ur_device_handle_t getDevice() const noexcept { return Device; };
101106

107+
// Function which creates the profiling stream. Called only if profiling is
108+
// enabled.
109+
void createProfilingStream() {
110+
static std::once_flag ProfStreamFlag;
111+
std::call_once(ProfStreamFlag, [&]() {
112+
UR_CHECK_ERROR(
113+
cuStreamCreateWithPriority(&ProfStream, CU_STREAM_NON_BLOCKING, 0));
114+
IsProfStreamCreated = true;
115+
});
116+
}
117+
118+
native_type getProfilingStream() { return ProfStream; }
119+
102120
bool hasBeenSynchronized(uint32_t StreamToken) {
103121
// stream token not associated with one of the compute streams
104122
if (StreamToken == std::numeric_limits<uint32_t>::max()) {

source/adapters/hip/event.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ ur_event_handle_t_::ur_event_handle_t_(ur_command_t Type,
2121
uint32_t StreamToken)
2222
: CommandType{Type}, RefCount{1}, HasOwnership{true},
2323
HasBeenWaitedOn{false}, IsRecorded{false}, IsStarted{false},
24-
StreamToken{StreamToken}, EventId{0}, EvEnd{EvEnd}, EvQueued{EvQueued},
25-
EvStart{EvStart}, Queue{Queue}, Stream{Stream}, Context{Context} {
24+
StreamToken{StreamToken}, EventId{0}, EvEnd{EvEnd}, EvStart{EvStart},
25+
EvQueued{EvQueued}, Queue{Queue}, Stream{Stream}, Context{Context} {
2626
urQueueRetain(Queue);
2727
urContextRetain(Context);
2828
}
@@ -50,7 +50,6 @@ ur_result_t ur_event_handle_t_::start() {
5050

5151
try {
5252
if (Queue->URFlags & UR_QUEUE_FLAG_PROFILING_ENABLE || isTimestampEvent()) {
53-
// NOTE: This relies on the default stream to be unused.
5453
UR_CHECK_ERROR(hipEventRecord(EvQueued, Queue->getProfilingStream()));
5554
UR_CHECK_ERROR(hipEventRecord(EvStart, Stream));
5655
}

source/adapters/hip/event.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,9 @@ struct ur_event_handle_t_ {
112112
// This constructor is private to force programmers to use the makeNative /
113113
// make_user static members in order to create a ur_event_handle_t for HIP.
114114
ur_event_handle_t_(ur_command_t Type, ur_context_handle_t Context,
115-
ur_queue_handle_t Queue, hipStream_t Stream,
116-
uint32_t StreamToken);
115+
ur_queue_handle_t Queue, native_type EvEnd,
116+
native_type EvQueued, native_type EvStart,
117+
hipStream_t Stream, uint32_t StreamToken);
117118

118119
// This constructor is private to force programmers to use the
119120
// makeWithNative for event interop

source/adapters/hip/queue.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -222,9 +222,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueRelease(ur_queue_handle_t hQueue) {
222222
UR_CHECK_ERROR(hipStreamDestroy(S));
223223
});
224224

225-
if (hQueue->ProfStreamCreated) {
226-
UR_CHECK_ERROR(hipStreamSynchronize(ProfStream));
227-
UR_CHECK_ERROR(hipStreamDestroy(ProfStream));
225+
if (hQueue->IsProfStreamCreated) {
226+
UR_CHECK_ERROR(hipStreamSynchronize(hQueue->getProfilingStream()));
227+
UR_CHECK_ERROR(hipStreamDestroy(hQueue->getProfilingStream()));
228228
}
229229

230230
return UR_RESULT_SUCCESS;

source/adapters/hip/queue.hpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
#pragma once
1111

1212
#include "common.hpp"
13+
#include <hip/hip_runtime.h>
14+
#include <mutex>
15+
#include <vector>
1316

1417
using ur_stream_quard = std::unique_lock<std::mutex>;
1518

@@ -24,8 +27,7 @@ struct ur_queue_handle_t_ {
2427
std::vector<native_type> TransferStreams;
2528
// Stream used solely when profiling is enabled
2629
native_type ProfStream;
27-
static std::once_flag ProfStreamFlag;
28-
bool ProfStreamCreated;
30+
bool IsProfStreamCreated{false};
2931
// DelayCompute keeps track of which streams have been recently reused and
3032
// their next use should be delayed. If a stream has been recently reused it
3133
// will be skipped the next time it would be selected round-robin style. When
@@ -99,13 +101,14 @@ struct ur_queue_handle_t_ {
99101
native_type getNextTransferStream();
100102
native_type get() { return getNextComputeStream(); };
101103

102-
// Function which creates the profiling stream. Called only from makeNative
103-
// in event handle, if the profiling is enabled.
104+
// Function which creates the profiling stream. Called only if profiling is
105+
// enabled.
104106
void createProfilingStream() {
105-
std::call_once(ProfStreamFlag, []() {
107+
static std::once_flag ProfStreamFlag;
108+
std::call_once(ProfStreamFlag, [&]() {
106109
UR_CHECK_ERROR(
107-
hipEventCreateWithFlags(&ProfStream, hipStreamNonBlocking));
108-
ProfStreamCreated = true;
110+
hipStreamCreateWithFlags(&ProfStream, hipStreamNonBlocking));
111+
IsProfStreamCreated = true;
109112
});
110113
}
111114
native_type getProfilingStream() { return ProfStream; }

0 commit comments

Comments
 (0)