Skip to content

Commit 3439c78

Browse files
authored
[UR][CUDA][HIP] Cleanup and share queue interop code (#17927)
This patch should essentially be NFC. The way it's currently setup is still very hacky and could use further refactoring, but this patch just isolates the hack a little bit and makes it more obvious what's going on, as well as removing some code duplication.
1 parent 81b09d0 commit 3439c78

File tree

8 files changed

+40
-62
lines changed

8 files changed

+40
-62
lines changed

unified-runtime/source/adapters/cuda/enqueue_native.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueNativeCommandExp(
2525

2626
try {
2727
ScopedContext ActiveContext(hQueue->getDevice());
28-
ScopedStream ActiveStream(hQueue, NumEventsInWaitList, phEventWaitList);
28+
InteropGuard ActiveStream(hQueue, NumEventsInWaitList, phEventWaitList);
2929
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
3030

3131
if (hQueue->getContext()->getDevices().size() > 1) {

unified-runtime/source/adapters/cuda/queue.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ urQueueGetNativeHandle(ur_queue_handle_t hQueue, ur_queue_native_desc_t *pDesc,
180180

181181
ScopedContext Active(hQueue->getDevice());
182182
*phNativeQueue =
183-
reinterpret_cast<ur_native_handle_t>(hQueue->getNextComputeStream());
183+
reinterpret_cast<ur_native_handle_t>(hQueue->getInteropStream());
184184
return UR_RESULT_SUCCESS;
185185
}
186186

unified-runtime/source/adapters/cuda/queue.hpp

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
using cuda_stream_queue = stream_queue_t<CUstream, 128, 64, CUevent>;
2323
struct ur_queue_handle_t_ : public cuda_stream_queue {};
2424

25+
using InteropGuard = cuda_stream_queue::interop_guard;
26+
2527
// Function which creates the profiling stream. Called only from makeNative
2628
// event when profiling is required.
2729
template <> inline void cuda_stream_queue::createHostSubmitTimeStream() {
@@ -38,24 +40,3 @@ inline void cuda_stream_queue::createStreamWithPriority(CUstream *Stream,
3840
int Priority) {
3941
UR_CHECK_ERROR(cuStreamCreateWithPriority(Stream, Flags, Priority));
4042
}
41-
42-
// RAII object to make hQueue stream getter methods all return the same stream
43-
// within the lifetime of this object.
44-
//
45-
// This is useful for urEnqueueNativeCommandExp where we want guarantees that
46-
// the user submitted native calls will be dispatched to a known stream, which
47-
// must be "got" within the user submitted fuction.
48-
class ScopedStream {
49-
ur_queue_handle_t hQueue;
50-
51-
public:
52-
ScopedStream(ur_queue_handle_t hQueue, uint32_t NumEventsInWaitList,
53-
const ur_event_handle_t *EventWaitList)
54-
: hQueue{hQueue} {
55-
ur_stream_guard Guard;
56-
hQueue->getThreadLocalStream() =
57-
hQueue->getNextComputeStream(NumEventsInWaitList, EventWaitList, Guard);
58-
}
59-
CUstream getStream() { return hQueue->getThreadLocalStream(); }
60-
~ScopedStream() { hQueue->getThreadLocalStream() = CUstream{0}; }
61-
};

unified-runtime/source/adapters/hip/enqueue_native.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueNativeCommandExp(
2828

2929
try {
3030
ScopedDevice ActiveDevice(hQueue->getDevice());
31-
ScopedStream ActiveStream(hQueue, NumEventsInWaitList, phEventWaitList);
31+
InteropGuard ActiveStream(hQueue, NumEventsInWaitList, phEventWaitList);
3232
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
3333

3434
if (hQueue->getContext()->getDevices().size() > 1) {

unified-runtime/source/adapters/hip/queue.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ urQueueGetNativeHandle(ur_queue_handle_t hQueue, ur_queue_native_desc_t *,
208208
ur_native_handle_t *phNativeQueue) {
209209
ScopedDevice Active(hQueue->getDevice());
210210
*phNativeQueue =
211-
reinterpret_cast<ur_native_handle_t>(hQueue->getNextComputeStream());
211+
reinterpret_cast<ur_native_handle_t>(hQueue->getInteropStream());
212212
return UR_RESULT_SUCCESS;
213213
}
214214

unified-runtime/source/adapters/hip/queue.hpp

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
using hip_stream_queue = stream_queue_t<hipStream_t, 64, 16, hipEvent_t>;
2020
struct ur_queue_handle_t_ : public hip_stream_queue {};
2121

22+
using InteropGuard = hip_stream_queue::interop_guard;
23+
2224
template <>
2325
inline void hip_stream_queue::createStreamWithPriority(hipStream_t *Stream,
2426
unsigned int Flags,
@@ -35,26 +37,3 @@ template <> inline void hip_stream_queue::createHostSubmitTimeStream() {
3537
hipStreamCreateWithFlags(&HostSubmitTimeStream, hipStreamNonBlocking));
3638
});
3739
}
38-
39-
// RAII object to make hQueue stream getter methods all return the same stream
40-
// within the lifetime of this object.
41-
//
42-
// This is useful for urEnqueueNativeCommandExp where we want guarantees that
43-
// the user submitted native calls will be dispatched to a known stream, which
44-
// must be "got" within the user submitted function.
45-
//
46-
// TODO: Add a test that this scoping works
47-
class ScopedStream {
48-
ur_queue_handle_t hQueue;
49-
50-
public:
51-
ScopedStream(ur_queue_handle_t hQueue, uint32_t NumEventsInWaitList,
52-
const ur_event_handle_t *EventWaitList)
53-
: hQueue{hQueue} {
54-
ur_stream_guard Guard;
55-
hQueue->getThreadLocalStream() =
56-
hQueue->getNextComputeStream(NumEventsInWaitList, EventWaitList, Guard);
57-
}
58-
hipStream_t getStream() { return hQueue->getThreadLocalStream(); }
59-
~ScopedStream() { hQueue->getThreadLocalStream() = hipStream_t{0}; }
60-
};

unified-runtime/source/common/cuda-hip/stream_queue.hpp

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,6 @@ struct stream_queue_t {
107107
// get_next_compute/transfer_stream() functions return streams from
108108
// appropriate pools in round-robin fashion
109109
native_type getNextComputeStream(uint32_t *StreamToken = nullptr) {
110-
if (getThreadLocalStream() != native_type{0})
111-
return getThreadLocalStream();
112110
uint32_t StreamI;
113111
uint32_t Token;
114112
while (true) {
@@ -150,8 +148,6 @@ struct stream_queue_t {
150148
const ur_event_handle_t *EventWaitList,
151149
ur_stream_guard &Guard,
152150
uint32_t *StreamToken = nullptr) {
153-
if (getThreadLocalStream() != native_type{0})
154-
return getThreadLocalStream();
155151
for (uint32_t i = 0; i < NumEventsInWaitList; i++) {
156152
uint32_t Token = getEventComputeStreamToken(EventWaitList[i]);
157153
if (getEventQueue(EventWaitList[i]) == this && canReuseStream(Token)) {
@@ -175,15 +171,7 @@ struct stream_queue_t {
175171
return getNextComputeStream(StreamToken);
176172
}
177173

178-
// Thread local stream will be used if ScopedStream is active
179-
static native_type &getThreadLocalStream() {
180-
static thread_local native_type stream{0};
181-
return stream;
182-
}
183-
184174
native_type getNextTransferStream() {
185-
if (getThreadLocalStream() != native_type{0})
186-
return getThreadLocalStream();
187175
if (TransferStreams.empty()) { // for example in in-order queue
188176
return getNextComputeStream();
189177
}
@@ -354,4 +342,34 @@ struct stream_queue_t {
354342
uint32_t getNextEventId() noexcept { return ++EventCount; }
355343

356344
bool backendHasOwnership() const noexcept { return HasOwnership; }
345+
346+
// Interop handling, for regular interop we return the next compute stream,
347+
// for native commands we use the interop_guard and return a thread local
348+
// stream. Native commands require to only have one in-order stream to work.
349+
native_type getInteropStream() {
350+
if (getThreadLocalStream() != native_type{0})
351+
return getThreadLocalStream();
352+
353+
return getNextComputeStream();
354+
}
355+
356+
static native_type &getThreadLocalStream() {
357+
static thread_local native_type stream{0};
358+
return stream;
359+
}
360+
361+
class interop_guard {
362+
stream_queue_t *q;
363+
364+
public:
365+
interop_guard(stream_queue_t *q, uint32_t NumEventsInWaitList,
366+
const ur_event_handle_t *EventWaitList)
367+
: q{q} {
368+
ur_stream_guard Guard;
369+
q->getThreadLocalStream() =
370+
q->getNextComputeStream(NumEventsInWaitList, EventWaitList, Guard);
371+
}
372+
native_type getStream() { return q->getThreadLocalStream(); }
373+
~interop_guard() { q->getThreadLocalStream() = native_type{0}; }
374+
};
357375
};

unified-runtime/test/adapters/cuda/urQueueGetNativeHandle.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ TEST_P(urCudaQueueGetNativeHandleTest, OutOfOrder) {
3030
ASSERT_SUCCESS_CUDA(cuStreamSynchronize(Stream));
3131
}
3232

33-
TEST_P(urCudaQueueGetNativeHandleTest, ScopedStream) {
33+
TEST_P(urCudaQueueGetNativeHandleTest, InteropGuard) {
3434
CUstream Stream1, Stream2;
3535
ur_queue_properties_t props = {
3636
/*.stype =*/UR_STRUCTURE_TYPE_QUEUE_PROPERTIES,
@@ -50,7 +50,7 @@ TEST_P(urCudaQueueGetNativeHandleTest, ScopedStream) {
5050
ASSERT_NE(Stream1, Stream2);
5151

5252
{
53-
ScopedStream ActiveStream(OutOfOrderQueue, 0, nullptr);
53+
InteropGuard ActiveStream(OutOfOrderQueue, 0, nullptr);
5454

5555
ASSERT_SUCCESS(urQueueGetNativeHandle(OutOfOrderQueue, nullptr,
5656
(ur_native_handle_t *)&Stream1));

0 commit comments

Comments
 (0)