Skip to content

Commit a2bec63

Browse files
author
Hugh Delaney
committed
Add urEnqueueNativeCommand impl for HIP adapter
Add urEnqueueNativeCommand impl for HIP adapter. Also cahnge typo 'quard' to 'guard'.
1 parent af586ec commit a2bec63

File tree

5 files changed

+84
-14
lines changed

5 files changed

+84
-14
lines changed

source/adapters/hip/command_buffer.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -884,7 +884,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp(
884884
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
885885
ScopedContext Active(hQueue->getDevice());
886886
uint32_t StreamToken;
887-
ur_stream_quard Guard;
887+
ur_stream_guard Guard;
888888
hipStream_t HIPStream = hQueue->getNextComputeStream(
889889
numEventsInWaitList, phEventWaitList, Guard, &StreamToken);
890890

source/adapters/hip/enqueue.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
293293
ScopedContext Active(Dev);
294294

295295
uint32_t StreamToken;
296-
ur_stream_quard Guard;
296+
ur_stream_guard Guard;
297297
hipStream_t HIPStream = hQueue->getNextComputeStream(
298298
numEventsInWaitList, phEventWaitList, Guard, &StreamToken);
299299

@@ -380,7 +380,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueEventsWaitWithBarrier(
380380
try {
381381
ScopedContext Active(hQueue->getDevice());
382382
uint32_t StreamToken;
383-
ur_stream_quard Guard;
383+
ur_stream_guard Guard;
384384
hipStream_t HIPStream = hQueue->getNextComputeStream(
385385
numEventsInWaitList,
386386
reinterpret_cast<const ur_event_handle_t *>(phEventWaitList), Guard,
@@ -1243,7 +1243,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
12431243
try {
12441244
ScopedContext Active(hQueue->getDevice());
12451245
uint32_t StreamToken;
1246-
ur_stream_quard Guard;
1246+
ur_stream_guard Guard;
12471247
hipStream_t HIPStream = hQueue->getNextComputeStream(
12481248
numEventsInWaitList, phEventWaitList, Guard, &StreamToken);
12491249
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
@@ -1893,7 +1893,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueTimestampRecordingExp(
18931893
ScopedContext Active(hQueue->getDevice());
18941894

18951895
uint32_t StreamToken;
1896-
ur_stream_quard Guard;
1896+
ur_stream_guard Guard;
18971897
hipStream_t HIPStream = hQueue->getNextComputeStream(
18981898
numEventsInWaitList, phEventWaitList, Guard, &StreamToken);
18991899
UR_CHECK_ERROR(enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,

source/adapters/hip/enqueue_native.cpp

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,44 @@
1010

1111
#include <ur_api.h>
1212

13-
UR_APICALL UR_APIEXPORT ur_result_t urEnqueueNativeCommandExp(
14-
ur_queue_handle_t, ur_exp_enqueue_native_command_function_t, void *,
15-
const ur_exp_enqueue_native_command_properties_t *, uint32_t,
16-
const ur_event_handle_t *, ur_event_handle_t *) {
17-
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
13+
#include "context.hpp"
14+
#include "event.hpp"
15+
#include "queue.hpp"
16+
17+
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueNativeCommandExp(
18+
ur_queue_handle_t hQueue,
19+
ur_exp_enqueue_native_command_function_t pfnNativeEnqueue, void *data,
20+
const ur_exp_enqueue_native_command_properties_t *,
21+
uint32_t NumEventsInWaitList, const ur_event_handle_t *phEventWaitList,
22+
ur_event_handle_t *phEvent) {
23+
// TODO: how should mem migration work across a context here?
24+
// Perhaps we will need to add a phMemObjArgs so that we are able to make
25+
// sure memory migration happens across devices in the same context
26+
27+
try {
28+
ScopedContext ActiveContext(hQueue->getDevice());
29+
ScopedStream ActiveStream(hQueue, NumEventsInWaitList, phEventWaitList);
30+
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
31+
32+
if (phEvent) {
33+
RetImplEvent =
34+
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
35+
UR_COMMAND_ENQUEUE_NATIVE_EXP, hQueue, ActiveStream.getStream()));
36+
UR_CHECK_ERROR(RetImplEvent->start());
37+
}
38+
39+
pfnNativeEnqueue(hQueue, data); // This is using urQueueGetNativeHandle to
40+
// get the CUDA stream. It must be the
41+
// same stream as is used before and after
42+
if (phEvent) {
43+
UR_CHECK_ERROR(RetImplEvent->record());
44+
*phEvent = RetImplEvent.release();
45+
}
46+
47+
} catch (ur_result_t Err) {
48+
return Err;
49+
} catch (hipError_t hipErr) {
50+
return mapErrorUR(hipErr);
51+
}
52+
return UR_RESULT_SUCCESS;
1853
}

source/adapters/hip/queue.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ void ur_queue_handle_t_::transferStreamWaitForBarrierIfNeeded(
2929
}
3030

3131
hipStream_t ur_queue_handle_t_::getNextComputeStream(uint32_t *StreamToken) {
32+
if (getThreadLocalStream() != hipStream_t{0})
33+
return getThreadLocalStream();
3234
uint32_t Stream_i;
3335
uint32_t Token;
3436
while (true) {
@@ -63,7 +65,9 @@ hipStream_t ur_queue_handle_t_::getNextComputeStream(uint32_t *StreamToken) {
6365

6466
hipStream_t ur_queue_handle_t_::getNextComputeStream(
6567
uint32_t NumEventsInWaitList, const ur_event_handle_t *EventWaitList,
66-
ur_stream_quard &Guard, uint32_t *StreamToken) {
68+
ur_stream_guard &Guard, uint32_t *StreamToken) {
69+
if (getThreadLocalStream() != hipStream_t{0})
70+
return getThreadLocalStream();
6771
for (uint32_t i = 0; i < NumEventsInWaitList; i++) {
6872
uint32_t Token = EventWaitList[i]->getComputeStreamToken();
6973
if (EventWaitList[i]->getQueue() == this && canReuseStream(Token)) {
@@ -76,7 +80,7 @@ hipStream_t ur_queue_handle_t_::getNextComputeStream(
7680
if (StreamToken) {
7781
*StreamToken = Token;
7882
}
79-
Guard = ur_stream_quard{std::move(ComputeSyncGuard)};
83+
Guard = ur_stream_guard{std::move(ComputeSyncGuard)};
8084
hipStream_t Res = EventWaitList[i]->getStream();
8185
computeStreamWaitForBarrierIfNeeded(Res, Stream_i);
8286
return Res;
@@ -88,6 +92,8 @@ hipStream_t ur_queue_handle_t_::getNextComputeStream(
8892
}
8993

9094
hipStream_t ur_queue_handle_t_::getNextTransferStream() {
95+
if (getThreadLocalStream() != hipStream_t{0})
96+
return getThreadLocalStream();
9197
if (TransferStreams.empty()) { // for example in in-order queue
9298
return getNextComputeStream();
9399
}

source/adapters/hip/queue.hpp

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
#include <mutex>
1515
#include <vector>
1616

17-
using ur_stream_quard = std::unique_lock<std::mutex>;
17+
using ur_stream_guard = std::unique_lock<std::mutex>;
1818

1919
/// UR queue mapping on to hipStream_t objects.
2020
///
@@ -97,7 +97,7 @@ struct ur_queue_handle_t_ {
9797
// returns a lock that needs to remain locked as long as the stream is in use
9898
native_type getNextComputeStream(uint32_t NumEventsInWaitList,
9999
const ur_event_handle_t *EventWaitList,
100-
ur_stream_quard &Guard,
100+
ur_stream_guard &Guard,
101101
uint32_t *StreamToken = nullptr);
102102
native_type getNextTransferStream();
103103
native_type get() { return getNextComputeStream(); };
@@ -247,6 +247,12 @@ struct ur_queue_handle_t_ {
247247
}
248248
}
249249

250+
// Thread local stream will be used if ScopedStream is active
251+
static hipStream_t &getThreadLocalStream() {
252+
static thread_local hipStream_t stream{0};
253+
return stream;
254+
}
255+
250256
ur_context_handle_t getContext() const { return Context; };
251257

252258
ur_device_handle_t getDevice() const { return Device; };
@@ -261,3 +267,26 @@ struct ur_queue_handle_t_ {
261267

262268
bool backendHasOwnership() const noexcept { return HasOwnership; }
263269
};
270+
271+
// RAII object to make hQueue stream getter methods all return the same stream
272+
// within the lifetime of this object.
273+
//
274+
// This is useful for urEnqueueNativeCommandExp where we want guarantees that
275+
// the user submitted native calls will be dispatched to a known stream, which
276+
// must be "got" within the user submitted function.
277+
//
278+
// TODO: Add a test that this scoping works
279+
class ScopedStream {
280+
ur_queue_handle_t hQueue;
281+
282+
public:
283+
ScopedStream(ur_queue_handle_t hQueue, uint32_t NumEventsInWaitList,
284+
const ur_event_handle_t *EventWaitList)
285+
: hQueue{hQueue} {
286+
ur_stream_guard Guard;
287+
hQueue->getThreadLocalStream() =
288+
hQueue->getNextComputeStream(NumEventsInWaitList, EventWaitList, Guard);
289+
}
290+
hipStream_t getStream() { return hQueue->getThreadLocalStream(); }
291+
~ScopedStream() { hQueue->getThreadLocalStream() = hipStream_t{0}; }
292+
};

0 commit comments

Comments
 (0)