Skip to content

Commit 26208a5

Browse files
authored
[UR] Use templates rather than virtual for stream_queue_t (#17858)
In a future change, UR is going to restrict handle types from having vtables. This patch changes the queue base type shared by CUDA and HIP to a template where each backend specialises its own backend-specific queue type.
1 parent 507001c commit 26208a5

File tree

5 files changed

+83
-95
lines changed

5 files changed

+83
-95
lines changed

unified-runtime/source/adapters/cuda/queue.cpp

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,32 +16,37 @@
1616
#include <cassert>
1717
#include <cuda.h>
1818

19-
void ur_queue_handle_t_::computeStreamWaitForBarrierIfNeeded(CUstream Stream,
20-
uint32_t StreamI) {
19+
template <>
20+
void cuda_stream_queue::computeStreamWaitForBarrierIfNeeded(CUstream Stream,
21+
uint32_t StreamI) {
2122
if (BarrierEvent && !ComputeAppliedBarrier[StreamI]) {
2223
UR_CHECK_ERROR(cuStreamWaitEvent(Stream, BarrierEvent, 0));
2324
ComputeAppliedBarrier[StreamI] = true;
2425
}
2526
}
2627

27-
void ur_queue_handle_t_::transferStreamWaitForBarrierIfNeeded(
28-
CUstream Stream, uint32_t StreamI) {
28+
template <>
29+
void cuda_stream_queue::transferStreamWaitForBarrierIfNeeded(CUstream Stream,
30+
uint32_t StreamI) {
2931
if (BarrierEvent && !TransferAppliedBarrier[StreamI]) {
3032
UR_CHECK_ERROR(cuStreamWaitEvent(Stream, BarrierEvent, 0));
3133
TransferAppliedBarrier[StreamI] = true;
3234
}
3335
}
3436

35-
ur_queue_handle_t ur_queue_handle_t_::getEventQueue(const ur_event_handle_t e) {
37+
template <>
38+
ur_queue_handle_t cuda_stream_queue::getEventQueue(const ur_event_handle_t e) {
3639
return e->getQueue();
3740
}
3841

42+
template <>
3943
uint32_t
40-
ur_queue_handle_t_::getEventComputeStreamToken(const ur_event_handle_t e) {
44+
cuda_stream_queue::getEventComputeStreamToken(const ur_event_handle_t e) {
4145
return e->getComputeStreamToken();
4246
}
4347

44-
CUstream ur_queue_handle_t_::getEventStream(const ur_event_handle_t e) {
48+
template <>
49+
CUstream cuda_stream_queue::getEventStream(const ur_event_handle_t e) {
4550
return e->getStream();
4651
}
4752

@@ -87,7 +92,7 @@ urQueueCreate(ur_context_handle_t hContext, ur_device_handle_t hDevice,
8792
}
8893

8994
Queue = std::unique_ptr<ur_queue_handle_t_>(new ur_queue_handle_t_{
90-
IsOutOfOrder, hContext, hDevice, Flags, URFlags, Priority});
95+
{IsOutOfOrder, hContext, hDevice, Flags, URFlags, Priority}});
9196

9297
*phQueue = Queue.release();
9398

@@ -203,8 +208,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueCreateWithNativeHandle(
203208
pProperties ? pProperties->isNativeHandleOwned : false;
204209

205210
// Create queue from a native stream
206-
*phQueue = new ur_queue_handle_t_{CuStream, hContext, hDevice,
207-
CuFlags, Flags, isNativeHandleOwned};
211+
*phQueue = new ur_queue_handle_t_{
212+
{CuStream, hContext, hDevice, CuFlags, Flags, isNativeHandleOwned}};
208213

209214
return UR_RESULT_SUCCESS;
210215
}

unified-runtime/source/adapters/cuda/queue.hpp

Lines changed: 19 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -19,38 +19,25 @@
1919

2020
#include <common/cuda-hip/stream_queue.hpp>
2121

22-
/// UR queue mapping on to CUstream objects.
23-
///
24-
struct ur_queue_handle_t_ : stream_queue_t<CUstream, 128, 64> {
25-
using stream_queue_t<CUstream, DefaultNumComputeStreams,
26-
DefaultNumTransferStreams>::stream_queue_t;
27-
28-
CUevent BarrierEvent = nullptr;
29-
CUevent BarrierTmpEvent = nullptr;
30-
31-
void computeStreamWaitForBarrierIfNeeded(CUstream Strean,
32-
uint32_t StreamI) override;
33-
void transferStreamWaitForBarrierIfNeeded(CUstream Stream,
34-
uint32_t StreamI) override;
35-
ur_queue_handle_t getEventQueue(const ur_event_handle_t) override;
36-
uint32_t getEventComputeStreamToken(const ur_event_handle_t) override;
37-
CUstream getEventStream(const ur_event_handle_t) override;
38-
39-
// Function which creates the profiling stream. Called only from makeNative
40-
// event when profiling is required.
41-
void createHostSubmitTimeStream() {
42-
static std::once_flag HostSubmitTimeStreamFlag;
43-
std::call_once(HostSubmitTimeStreamFlag, [&]() {
44-
UR_CHECK_ERROR(cuStreamCreateWithPriority(&HostSubmitTimeStream,
45-
CU_STREAM_NON_BLOCKING, 0));
46-
});
47-
}
48-
49-
void createStreamWithPriority(CUstream *Stream, unsigned int Flags,
50-
int Priority) override {
51-
UR_CHECK_ERROR(cuStreamCreateWithPriority(Stream, Flags, Priority));
52-
}
53-
};
22+
using cuda_stream_queue = stream_queue_t<CUstream, 128, 64, CUevent>;
23+
struct ur_queue_handle_t_ : public cuda_stream_queue {};
24+
25+
// Function which creates the profiling stream. Called only from makeNative
26+
// event when profiling is required.
27+
template <> inline void cuda_stream_queue::createHostSubmitTimeStream() {
28+
static std::once_flag HostSubmitTimeStreamFlag;
29+
std::call_once(HostSubmitTimeStreamFlag, [&]() {
30+
UR_CHECK_ERROR(cuStreamCreateWithPriority(&HostSubmitTimeStream,
31+
CU_STREAM_NON_BLOCKING, 0));
32+
});
33+
}
34+
35+
template <>
36+
inline void cuda_stream_queue::createStreamWithPriority(CUstream *Stream,
37+
unsigned int Flags,
38+
int Priority) {
39+
UR_CHECK_ERROR(cuStreamCreateWithPriority(Stream, Flags, Priority));
40+
}
5441

5542
// RAII object to make hQueue stream getter methods all return the same stream
5643
// within the lifetime of this object.

unified-runtime/source/adapters/hip/queue.cpp

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,32 +12,37 @@
1212
#include "context.hpp"
1313
#include "event.hpp"
1414

15-
void ur_queue_handle_t_::computeStreamWaitForBarrierIfNeeded(
16-
hipStream_t Stream, uint32_t Stream_i) {
15+
template <>
16+
void hip_stream_queue::computeStreamWaitForBarrierIfNeeded(hipStream_t Stream,
17+
uint32_t Stream_i) {
1718
if (BarrierEvent && !ComputeAppliedBarrier[Stream_i]) {
1819
UR_CHECK_ERROR(hipStreamWaitEvent(Stream, BarrierEvent, 0));
1920
ComputeAppliedBarrier[Stream_i] = true;
2021
}
2122
}
2223

23-
void ur_queue_handle_t_::transferStreamWaitForBarrierIfNeeded(
24-
hipStream_t Stream, uint32_t Stream_i) {
24+
template <>
25+
void hip_stream_queue::transferStreamWaitForBarrierIfNeeded(hipStream_t Stream,
26+
uint32_t Stream_i) {
2527
if (BarrierEvent && !TransferAppliedBarrier[Stream_i]) {
2628
UR_CHECK_ERROR(hipStreamWaitEvent(Stream, BarrierEvent, 0));
2729
TransferAppliedBarrier[Stream_i] = true;
2830
}
2931
}
3032

31-
ur_queue_handle_t ur_queue_handle_t_::getEventQueue(const ur_event_handle_t e) {
33+
template <>
34+
ur_queue_handle_t hip_stream_queue::getEventQueue(const ur_event_handle_t e) {
3235
return e->getQueue();
3336
}
3437

38+
template <>
3539
uint32_t
36-
ur_queue_handle_t_::getEventComputeStreamToken(const ur_event_handle_t e) {
40+
hip_stream_queue::getEventComputeStreamToken(const ur_event_handle_t e) {
3741
return e->getComputeStreamToken();
3842
}
3943

40-
hipStream_t ur_queue_handle_t_::getEventStream(const ur_event_handle_t e) {
44+
template <>
45+
hipStream_t hip_stream_queue::getEventStream(const ur_event_handle_t e) {
4146
return e->getStream();
4247
}
4348

@@ -76,7 +81,7 @@ urQueueCreate(ur_context_handle_t hContext, ur_device_handle_t hDevice,
7681
: false;
7782

7883
QueueImpl = std::unique_ptr<ur_queue_handle_t_>(new ur_queue_handle_t_{
79-
IsOutOfOrder, hContext, hDevice, Flags, URFlags, Priority});
84+
{IsOutOfOrder, hContext, hDevice, Flags, URFlags, Priority}});
8085

8186
*phQueue = QueueImpl.release();
8287

@@ -238,8 +243,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueCreateWithNativeHandle(
238243

239244
// Create queue and set num_compute_streams to 1, as computeHIPStreams has
240245
// valid stream
241-
*phQueue = new ur_queue_handle_t_{HIPStream, hContext, hDevice,
242-
HIPFlags, Flags, isNativeHandleOwned};
246+
*phQueue = new ur_queue_handle_t_{
247+
{HIPStream, hContext, hDevice, HIPFlags, Flags, isNativeHandleOwned}};
243248

244249
return UR_RESULT_SUCCESS;
245250
}

unified-runtime/source/adapters/hip/queue.hpp

Lines changed: 19 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -16,38 +16,25 @@
1616

1717
#include <common/cuda-hip/stream_queue.hpp>
1818

19-
/// UR queue mapping on to hipStream_t objects.
20-
///
21-
struct ur_queue_handle_t_ : stream_queue_t<hipStream_t, 64, 16> {
22-
using stream_queue_t<hipStream_t, DefaultNumComputeStreams,
23-
DefaultNumTransferStreams>::stream_queue_t;
24-
25-
hipEvent_t BarrierEvent = nullptr;
26-
hipEvent_t BarrierTmpEvent = nullptr;
27-
28-
void computeStreamWaitForBarrierIfNeeded(hipStream_t Strean,
29-
uint32_t StreamI) override;
30-
void transferStreamWaitForBarrierIfNeeded(hipStream_t Stream,
31-
uint32_t StreamI) override;
32-
ur_queue_handle_t getEventQueue(const ur_event_handle_t) override;
33-
uint32_t getEventComputeStreamToken(const ur_event_handle_t) override;
34-
hipStream_t getEventStream(const ur_event_handle_t) override;
35-
36-
// Function which creates the profiling stream. Called only from makeNative
37-
// event when profiling is required.
38-
void createHostSubmitTimeStream() {
39-
static std::once_flag HostSubmitTimeStreamFlag;
40-
std::call_once(HostSubmitTimeStreamFlag, [&]() {
41-
UR_CHECK_ERROR(hipStreamCreateWithFlags(&HostSubmitTimeStream,
42-
hipStreamNonBlocking));
43-
});
44-
}
45-
46-
void createStreamWithPriority(hipStream_t *Stream, unsigned int Flags,
47-
int Priority) override {
48-
UR_CHECK_ERROR(hipStreamCreateWithPriority(Stream, Flags, Priority));
49-
}
50-
};
19+
using hip_stream_queue = stream_queue_t<hipStream_t, 64, 16, hipEvent_t>;
20+
struct ur_queue_handle_t_ : public hip_stream_queue {};
21+
22+
template <>
23+
inline void hip_stream_queue::createStreamWithPriority(hipStream_t *Stream,
24+
unsigned int Flags,
25+
int Priority) {
26+
UR_CHECK_ERROR(hipStreamCreateWithPriority(Stream, Flags, Priority));
27+
}
28+
29+
// Function which creates the profiling stream. Called only from makeNative
30+
// event when profiling is required.
31+
template <> inline void hip_stream_queue::createHostSubmitTimeStream() {
32+
static std::once_flag HostSubmitTimeStreamFlag;
33+
std::call_once(HostSubmitTimeStreamFlag, [&]() {
34+
UR_CHECK_ERROR(
35+
hipStreamCreateWithFlags(&HostSubmitTimeStream, hipStreamNonBlocking));
36+
});
37+
}
5138

5239
// RAII object to make hQueue stream getter methods all return the same stream
5340
// within the lifetime of this object.

unified-runtime/source/common/cuda-hip/stream_queue.hpp

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ using ur_stream_guard = std::unique_lock<std::mutex>;
2121
/// backend 'stream' objects.
2222
///
2323
/// This class is specifically designed for the CUDA and HIP adapters.
24-
template <typename ST, int CS, int TS> struct stream_queue_t {
24+
template <typename ST, int CS, int TS, typename BarrierEventT>
25+
struct stream_queue_t {
2526
using native_type = ST;
2627
static constexpr int DefaultNumComputeStreams = CS;
2728
static constexpr int DefaultNumTransferStreams = TS;
@@ -61,6 +62,8 @@ template <typename ST, int CS, int TS> struct stream_queue_t {
6162
std::mutex TransferStreamMutex;
6263
std::mutex BarrierMutex;
6364
bool HasOwnership;
65+
BarrierEventT BarrierEvent = nullptr;
66+
BarrierEventT BarrierTmpEvent = nullptr;
6467

6568
stream_queue_t(bool IsOutOfOrder, ur_context_handle_t_ *Context,
6669
ur_device_handle_t_ *Device, unsigned int Flags,
@@ -88,17 +91,18 @@ template <typename ST, int CS, int TS> struct stream_queue_t {
8891
urContextRetain(Context);
8992
}
9093

91-
virtual ~stream_queue_t() { urContextRelease(Context); }
94+
~stream_queue_t() { urContextRelease(Context); }
9295

93-
virtual void computeStreamWaitForBarrierIfNeeded(native_type Strean,
94-
uint32_t StreamI) = 0;
95-
virtual void transferStreamWaitForBarrierIfNeeded(native_type Stream,
96-
uint32_t StreamI) = 0;
97-
virtual void createStreamWithPriority(native_type *Stream, unsigned int Flags,
98-
int Priority) = 0;
99-
virtual ur_queue_handle_t getEventQueue(const ur_event_handle_t) = 0;
100-
virtual uint32_t getEventComputeStreamToken(const ur_event_handle_t) = 0;
101-
virtual native_type getEventStream(const ur_event_handle_t) = 0;
96+
void computeStreamWaitForBarrierIfNeeded(native_type Strean,
97+
uint32_t StreamI);
98+
void transferStreamWaitForBarrierIfNeeded(native_type Stream,
99+
uint32_t StreamI);
100+
void createStreamWithPriority(native_type *Stream, unsigned int Flags,
101+
int Priority);
102+
ur_queue_handle_t getEventQueue(const ur_event_handle_t);
103+
uint32_t getEventComputeStreamToken(const ur_event_handle_t);
104+
native_type getEventStream(const ur_event_handle_t);
105+
void createHostSubmitTimeStream();
102106

103107
// get_next_compute/transfer_stream() functions return streams from
104108
// appropriate pools in round-robin fashion

0 commit comments

Comments
 (0)