Skip to content

Commit 2acb536

Browse files
author
Hugh Delaney
committed
Add basic CUDA impl using ScopedStream
Use ScopedStream to return the same stream during the lifetime of the RAII object. This allows us to create events outside a user submitted func, and submit work within the user submitted func, since the stream given to the user from urQueueGetNativeHandle is guaranteed to be the same stream that we record events on.
1 parent d7a18c1 commit 2acb536

File tree

3 files changed

+75
-6
lines changed

3 files changed

+75
-6
lines changed
Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===--------- native_enqueue.cpp - CUDA Adapter --------------------------===//
1+
//===--------- enqueue_native.cpp - CUDA Adapter --------------------------===//
22
//
33
// Copyright (C) 2024 Intel Corporation
44
//
@@ -10,9 +10,44 @@
1010

1111
#include <ur_api.h>
1212

13-
ur_result_t urNativeEnqueueExp(ur_queue_handle_t,
14-
ur_exp_enqueue_native_command_function_t, void *,
15-
uint32_t, const ur_event_handle_t *,
16-
ur_event_handle_t *) {
17-
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
13+
#include "context.hpp"
14+
#include "event.hpp"
15+
#include "queue.hpp"
16+
17+
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueNativeCommandExp(
18+
ur_queue_handle_t hQueue,
19+
ur_exp_enqueue_native_command_function_t pfnNativeEnqueue, void *data,
20+
const ur_exp_enqueue_native_command_properties_t *,
21+
uint32_t NumEventsInWaitList, const ur_event_handle_t *phEventWaitList,
22+
ur_event_handle_t *phEvent) {
23+
// TODO: how should mem migration work across a context here?
24+
// Perhaps we will need to add a phMemObjArgs so that we are able to make
25+
// sure memory migration happens across devices in the same context
26+
27+
try {
28+
ScopedContext ActiveContext(hQueue->getDevice());
29+
ScopedStream ActiveStream(hQueue, NumEventsInWaitList, phEventWaitList);
30+
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
31+
32+
if (phEvent) {
33+
RetImplEvent =
34+
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
35+
UR_COMMAND_ENQUEUE_NATIVE_EXP, hQueue, ActiveStream.getStream()));
36+
UR_CHECK_ERROR(RetImplEvent->start());
37+
}
38+
39+
pfnNativeEnqueue(hQueue, data); // This is using urQueueGetNativeHandle to
40+
// get the CUDA stream. It must be the
41+
// same stream as is used before and after
42+
if (phEvent) {
43+
UR_CHECK_ERROR(RetImplEvent->record());
44+
*phEvent = RetImplEvent.release();
45+
}
46+
47+
} catch (ur_result_t Err) {
48+
return Err;
49+
} catch (CUresult CuErr) {
50+
return mapErrorUR(CuErr);
51+
}
52+
return UR_RESULT_SUCCESS;
1853
}

source/adapters/cuda/queue.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ void ur_queue_handle_t_::transferStreamWaitForBarrierIfNeeded(
3333
}
3434

3535
CUstream ur_queue_handle_t_::getNextComputeStream(uint32_t *StreamToken) {
36+
if (getThreadLocalStream() != CUstream{0})
37+
return getThreadLocalStream();
3638
uint32_t StreamI;
3739
uint32_t Token;
3840
while (true) {
@@ -68,6 +70,8 @@ CUstream ur_queue_handle_t_::getNextComputeStream(uint32_t *StreamToken) {
6870
CUstream ur_queue_handle_t_::getNextComputeStream(
6971
uint32_t NumEventsInWaitList, const ur_event_handle_t *EventWaitList,
7072
ur_stream_guard_ &Guard, uint32_t *StreamToken) {
73+
if (getThreadLocalStream() != CUstream{0})
74+
return getThreadLocalStream();
7175
for (uint32_t i = 0; i < NumEventsInWaitList; i++) {
7276
uint32_t Token = EventWaitList[i]->getComputeStreamToken();
7377
if (reinterpret_cast<ur_queue_handle_t>(EventWaitList[i]->getQueue()) ==
@@ -94,6 +98,8 @@ CUstream ur_queue_handle_t_::getNextComputeStream(
9498
}
9599

96100
CUstream ur_queue_handle_t_::getNextTransferStream() {
101+
if (getThreadLocalStream() != CUstream{0})
102+
return getThreadLocalStream();
97103
if (TransferStreams.empty()) { // for example in in-order queue
98104
return getNextComputeStream();
99105
}

source/adapters/cuda/queue.hpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,13 @@ struct ur_queue_handle_t_ {
101101
const ur_event_handle_t *EventWaitList,
102102
ur_stream_guard_ &Guard,
103103
uint32_t *StreamToken = nullptr);
104+
105+
// Thread local stream will be used if ScopedStream is active
106+
static CUstream &getThreadLocalStream() {
107+
static thread_local CUstream stream{0};
108+
return stream;
109+
}
110+
104111
native_type getNextTransferStream();
105112
native_type get() { return getNextComputeStream(); };
106113
ur_device_handle_t getDevice() const noexcept { return Device; };
@@ -265,3 +272,24 @@ struct ur_queue_handle_t_ {
265272

266273
bool backendHasOwnership() const noexcept { return HasOwnership; }
267274
};
275+
276+
// RAII object to make hQueue stream getter methods all return the same stream
277+
// within the lifetime of this object.
278+
//
279+
// This is useful for urEnqueueNativeCommandExp where we want guarantees that
280+
// the user submitted native calls will be dispatched to a known stream, which
281+
// must be "got" within the user submitted fuction.
282+
class ScopedStream {
283+
ur_queue_handle_t hQueue;
284+
285+
public:
286+
ScopedStream(ur_queue_handle_t hQueue, uint32_t NumEventsInWaitList,
287+
const ur_event_handle_t *EventWaitList)
288+
: hQueue{hQueue} {
289+
ur_stream_guard_ Guard;
290+
hQueue->getThreadLocalStream() =
291+
hQueue->getNextComputeStream(NumEventsInWaitList, EventWaitList, Guard);
292+
}
293+
CUstream getStream() { return hQueue->getThreadLocalStream(); }
294+
~ScopedStream() { hQueue->getThreadLocalStream() = CUstream{0}; }
295+
};

0 commit comments

Comments
 (0)