Skip to content

Commit 42c0b02

Browse files
authored
Merge pull request #1634 from konradkusiak97/RefactorHIPBaseEvent
[HIP][CUDA] Refactor using profiling events
2 parents 8d86f5b + d082057 commit 42c0b02

File tree

13 files changed

+114
-86
lines changed

13 files changed

+114
-86
lines changed

source/adapters/cuda/event.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,7 @@ ur_result_t ur_event_handle_t_::start() {
5555

5656
try {
5757
if (Queue->URFlags & UR_QUEUE_FLAG_PROFILING_ENABLE || isTimestampEvent()) {
58-
// NOTE: This relies on the default stream to be unused.
59-
UR_CHECK_ERROR(cuEventRecord(EvQueued, 0));
58+
UR_CHECK_ERROR(cuEventRecord(EvQueued, Queue->getHostSubmitTimeStream()));
6059
UR_CHECK_ERROR(cuEventRecord(EvStart, Stream));
6160
}
6261
} catch (ur_result_t Err) {

source/adapters/cuda/event.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ struct ur_event_handle_t_ {
9090
const bool RequiresTimings =
9191
Queue->URFlags & UR_QUEUE_FLAG_PROFILING_ENABLE ||
9292
Type == UR_COMMAND_TIMESTAMP_RECORDING_EXP;
93+
if (RequiresTimings) {
94+
Queue->createHostSubmitTimeStream();
95+
}
9396
native_type EvEnd = nullptr, EvQueued = nullptr, EvStart = nullptr;
9497
UR_CHECK_ERROR(cuEventCreate(
9598
&EvEnd, RequiresTimings ? CU_EVENT_DEFAULT : CU_EVENT_DISABLE_TIMING));

source/adapters/cuda/queue.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueRelease(ur_queue_handle_t hQueue) {
201201
UR_CHECK_ERROR(cuStreamDestroy(S));
202202
});
203203

204+
if (hQueue->getHostSubmitTimeStream() != CUstream{0}) {
205+
UR_CHECK_ERROR(cuStreamSynchronize(hQueue->getHostSubmitTimeStream()));
206+
UR_CHECK_ERROR(cuStreamDestroy(hQueue->getHostSubmitTimeStream()));
207+
}
208+
204209
return UR_RESULT_SUCCESS;
205210
} catch (ur_result_t Err) {
206211
return Err;

source/adapters/cuda/queue.hpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99
//===----------------------------------------------------------------------===//
1010
#pragma once
1111

12+
#include "common.hpp"
1213
#include <ur/ur.hpp>
1314

1415
#include <algorithm>
1516
#include <cuda.h>
17+
#include <mutex>
1618
#include <vector>
1719

1820
using ur_stream_guard_ = std::unique_lock<std::mutex>;
@@ -27,6 +29,10 @@ struct ur_queue_handle_t_ {
2729

2830
std::vector<native_type> ComputeStreams;
2931
std::vector<native_type> TransferStreams;
32+
// Stream used for recording EvQueue, which holds information about when the
33+
// command in question is enqueued on host, as opposed to started. It is
34+
// created only if profiling is enabled - either for queue or per event.
35+
native_type HostSubmitTimeStream{0};
3036
// delay_compute_ keeps track of which streams have been recently reused and
3137
// their next use should be delayed. If a stream has been recently reused it
3238
// will be skipped the next time it would be selected round-robin style. When
@@ -99,6 +105,18 @@ struct ur_queue_handle_t_ {
99105
native_type get() { return getNextComputeStream(); };
100106
ur_device_handle_t getDevice() const noexcept { return Device; };
101107

108+
// Function which creates the profiling stream. Called only from makeNative
109+
// event when profiling is required.
110+
void createHostSubmitTimeStream() {
111+
static std::once_flag HostSubmitTimeStreamFlag;
112+
std::call_once(HostSubmitTimeStreamFlag, [&]() {
113+
UR_CHECK_ERROR(cuStreamCreateWithPriority(&HostSubmitTimeStream,
114+
CU_STREAM_NON_BLOCKING, 0));
115+
});
116+
}
117+
118+
native_type getHostSubmitTimeStream() { return HostSubmitTimeStream; }
119+
102120
bool hasBeenSynchronized(uint32_t StreamToken) {
103121
// stream token not associated with one of the compute streams
104122
if (StreamToken == std::numeric_limits<uint32_t>::max()) {

source/adapters/hip/context.cpp

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -47,18 +47,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextCreate(
4747
// Create a scoped context.
4848
ContextPtr = std::unique_ptr<ur_context_handle_t_>(
4949
new ur_context_handle_t_{phDevices, DeviceCount});
50-
51-
static std::once_flag InitFlag;
52-
std::call_once(
53-
InitFlag,
54-
[](ur_result_t &) {
55-
// Use default stream to record base event counter
56-
UR_CHECK_ERROR(hipEventCreateWithFlags(&ur_platform_handle_t_::EvBase,
57-
hipEventDefault));
58-
UR_CHECK_ERROR(hipEventRecord(ur_platform_handle_t_::EvBase, 0));
59-
},
60-
RetErr);
61-
6250
*phContext = ContextPtr.release();
6351
} catch (ur_result_t Err) {
6452
RetErr = Err;

source/adapters/hip/device.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,18 @@ int getAttribute(ur_device_handle_t Device, hipDeviceAttribute_t Attribute) {
2121
return Value;
2222
}
2323

24+
uint64_t ur_device_handle_t_::getElapsedTime(hipEvent_t ev) const {
25+
float Milliseconds = 0.0f;
26+
27+
// hipEventSynchronize waits till the event is ready for call to
28+
// hipEventElapsedTime.
29+
UR_CHECK_ERROR(hipEventSynchronize(EvBase));
30+
UR_CHECK_ERROR(hipEventSynchronize(ev));
31+
UR_CHECK_ERROR(hipEventElapsedTime(&Milliseconds, EvBase, ev));
32+
33+
return static_cast<uint64_t>(Milliseconds * 1.0e6);
34+
}
35+
2436
UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
2537
ur_device_info_t propName,
2638
size_t propSize,
@@ -1049,11 +1061,7 @@ ur_result_t UR_APICALL urDeviceGetGlobalTimestamps(ur_device_handle_t hDevice,
10491061
if (pDeviceTimestamp) {
10501062
UR_CHECK_ERROR(hipEventCreateWithFlags(&Event, hipEventDefault));
10511063
UR_CHECK_ERROR(hipEventRecord(Event));
1052-
UR_CHECK_ERROR(hipEventSynchronize(Event));
1053-
float ElapsedTime = 0.0f;
1054-
UR_CHECK_ERROR(hipEventElapsedTime(&ElapsedTime,
1055-
ur_platform_handle_t_::EvBase, Event));
1056-
*pDeviceTimestamp = (uint64_t)(ElapsedTime * (double)1e6);
1064+
*pDeviceTimestamp = hDevice->getElapsedTime(Event);
10571065
}
10581066

10591067
if (pHostTimestamp) {

source/adapters/hip/device.hpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@ struct ur_device_handle_t_ {
2525
std::atomic_uint32_t RefCount;
2626
ur_platform_handle_t Platform;
2727
hipCtx_t HIPContext;
28+
hipEvent_t EvBase; // HIP event used as base counter
2829
uint32_t DeviceIndex;
30+
2931
int MaxWorkGroupSize{0};
3032
int MaxBlockDimX{0};
3133
int MaxBlockDimY{0};
@@ -36,9 +38,10 @@ struct ur_device_handle_t_ {
3638

3739
public:
3840
ur_device_handle_t_(native_type HipDevice, hipCtx_t Context,
39-
ur_platform_handle_t Platform, uint32_t DeviceIndex)
41+
hipEvent_t EvBase, ur_platform_handle_t Platform,
42+
uint32_t DeviceIndex)
4043
: HIPDevice(HipDevice), RefCount{1}, Platform(Platform),
41-
HIPContext(Context), DeviceIndex(DeviceIndex) {
44+
HIPContext(Context), EvBase(EvBase), DeviceIndex(DeviceIndex) {
4245

4346
UR_CHECK_ERROR(hipDeviceGetAttribute(
4447
&MaxWorkGroupSize, hipDeviceAttributeMaxThreadsPerBlock, HIPDevice));
@@ -68,6 +71,8 @@ struct ur_device_handle_t_ {
6871

6972
ur_platform_handle_t getPlatform() const noexcept { return Platform; };
7073

74+
uint64_t getElapsedTime(hipEvent_t) const;
75+
7176
hipCtx_t getNativeContext() const noexcept { return HIPContext; };
7277

7378
// Returns the index of the device relative to the other devices in the same

source/adapters/hip/event.cpp

Lines changed: 11 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -16,23 +16,13 @@
1616
ur_event_handle_t_::ur_event_handle_t_(ur_command_t Type,
1717
ur_context_handle_t Context,
1818
ur_queue_handle_t Queue,
19-
hipStream_t Stream, uint32_t StreamToken)
19+
hipEvent_t EvEnd, hipEvent_t EvQueued,
20+
hipEvent_t EvStart, hipStream_t Stream,
21+
uint32_t StreamToken)
2022
: CommandType{Type}, RefCount{1}, HasOwnership{true},
2123
HasBeenWaitedOn{false}, IsRecorded{false}, IsStarted{false},
22-
StreamToken{StreamToken}, EventId{0}, EvEnd{nullptr}, EvStart{nullptr},
23-
EvQueued{nullptr}, Queue{Queue}, Stream{Stream}, Context{Context} {
24-
25-
bool ProfilingEnabled =
26-
Queue->URFlags & UR_QUEUE_FLAG_PROFILING_ENABLE || isTimestampEvent();
27-
28-
UR_CHECK_ERROR(hipEventCreateWithFlags(
29-
&EvEnd, ProfilingEnabled ? hipEventDefault : hipEventDisableTiming));
30-
31-
if (ProfilingEnabled) {
32-
UR_CHECK_ERROR(hipEventCreateWithFlags(&EvQueued, hipEventDefault));
33-
UR_CHECK_ERROR(hipEventCreateWithFlags(&EvStart, hipEventDefault));
34-
}
35-
24+
StreamToken{StreamToken}, EventId{0}, EvEnd{EvEnd}, EvStart{EvStart},
25+
EvQueued{EvQueued}, Queue{Queue}, Stream{Stream}, Context{Context} {
3626
urQueueRetain(Queue);
3727
urContextRetain(Context);
3828
}
@@ -60,9 +50,9 @@ ur_result_t ur_event_handle_t_::start() {
6050

6151
try {
6252
if (Queue->URFlags & UR_QUEUE_FLAG_PROFILING_ENABLE || isTimestampEvent()) {
63-
// NOTE: This relies on the default stream to be unused.
64-
UR_CHECK_ERROR(hipEventRecord(EvQueued, 0));
65-
UR_CHECK_ERROR(hipEventRecord(EvStart, Queue->get()));
53+
UR_CHECK_ERROR(
54+
hipEventRecord(EvQueued, Queue->getHostSubmitTimeStream()));
55+
UR_CHECK_ERROR(hipEventRecord(EvStart, Stream));
6656
}
6757
} catch (ur_result_t Error) {
6858
Result = Error;
@@ -90,44 +80,18 @@ bool ur_event_handle_t_::isCompleted() const {
9080
}
9181

9282
uint64_t ur_event_handle_t_::getQueuedTime() const {
93-
float MilliSeconds = 0.0f;
9483
assert(isStarted());
95-
96-
// hipEventSynchronize waits till the event is ready for call to
97-
// hipEventElapsedTime.
98-
UR_CHECK_ERROR(hipEventSynchronize(EvStart));
99-
UR_CHECK_ERROR(hipEventSynchronize(EvEnd));
100-
101-
UR_CHECK_ERROR(hipEventElapsedTime(&MilliSeconds, EvStart, EvEnd));
102-
return static_cast<uint64_t>(MilliSeconds * 1.0e6);
84+
return Queue->getDevice()->getElapsedTime(EvQueued);
10385
}
10486

10587
uint64_t ur_event_handle_t_::getStartTime() const {
106-
float MiliSeconds = 0.0f;
10788
assert(isStarted());
108-
109-
// hipEventSynchronize waits till the event is ready for call to
110-
// hipEventElapsedTime.
111-
UR_CHECK_ERROR(hipEventSynchronize(ur_platform_handle_t_::EvBase));
112-
UR_CHECK_ERROR(hipEventSynchronize(EvStart));
113-
114-
UR_CHECK_ERROR(hipEventElapsedTime(&MiliSeconds,
115-
ur_platform_handle_t_::EvBase, EvStart));
116-
return static_cast<uint64_t>(MiliSeconds * 1.0e6);
89+
return Queue->getDevice()->getElapsedTime(EvStart);
11790
}
11891

11992
uint64_t ur_event_handle_t_::getEndTime() const {
120-
float MiliSeconds = 0.0f;
12193
assert(isStarted() && isRecorded());
122-
123-
// hipEventSynchronize waits till the event is ready for call to
124-
// hipEventElapsedTime.
125-
UR_CHECK_ERROR(hipEventSynchronize(ur_platform_handle_t_::EvBase));
126-
UR_CHECK_ERROR(hipEventSynchronize(EvEnd));
127-
128-
UR_CHECK_ERROR(
129-
hipEventElapsedTime(&MiliSeconds, ur_platform_handle_t_::EvBase, EvEnd));
130-
return static_cast<uint64_t>(MiliSeconds * 1.0e6);
94+
return Queue->getDevice()->getElapsedTime(EvEnd);
13195
}
13296

13397
ur_result_t ur_event_handle_t_::record() {

source/adapters/hip/event.hpp

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,23 @@ struct ur_event_handle_t_ {
8080
static ur_event_handle_t
8181
makeNative(ur_command_t Type, ur_queue_handle_t Queue, hipStream_t Stream,
8282
uint32_t StreamToken = std::numeric_limits<uint32_t>::max()) {
83-
return new ur_event_handle_t_(Type, Queue->getContext(), Queue, Stream,
84-
StreamToken);
83+
const bool RequiresTimings =
84+
Queue->URFlags & UR_QUEUE_FLAG_PROFILING_ENABLE ||
85+
Type == UR_COMMAND_TIMESTAMP_RECORDING_EXP;
86+
if (RequiresTimings) {
87+
Queue->createHostSubmitTimeStream();
88+
}
89+
native_type EvEnd{nullptr}, EvQueued{nullptr}, EvStart{nullptr};
90+
UR_CHECK_ERROR(hipEventCreateWithFlags(
91+
&EvEnd, RequiresTimings ? hipEventDefault : hipEventDisableTiming));
92+
93+
if (RequiresTimings) {
94+
UR_CHECK_ERROR(hipEventCreateWithFlags(&EvQueued, hipEventDefault));
95+
UR_CHECK_ERROR(hipEventCreateWithFlags(&EvStart, hipEventDefault));
96+
}
97+
98+
return new ur_event_handle_t_(Type, Queue->getContext(), Queue, EvEnd,
99+
EvQueued, EvStart, Stream, StreamToken);
85100
}
86101

87102
static ur_event_handle_t makeWithNative(ur_context_handle_t context,
@@ -97,8 +112,9 @@ struct ur_event_handle_t_ {
97112
// This constructor is private to force programmers to use the makeNative /
98113
// make_user static members in order to create a ur_event_handle_t for HIP.
99114
ur_event_handle_t_(ur_command_t Type, ur_context_handle_t Context,
100-
ur_queue_handle_t Queue, hipStream_t Stream,
101-
uint32_t StreamToken);
115+
ur_queue_handle_t Queue, native_type EvEnd,
116+
native_type EvQueued, native_type EvStart,
117+
hipStream_t Stream, uint32_t StreamToken);
102118

103119
// This constructor is private to force programmers to use the
104120
// makeWithNative for event interop

source/adapters/hip/platform.cpp

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111
#include "platform.hpp"
1212
#include "context.hpp"
1313

14-
hipEvent_t ur_platform_handle_t_::EvBase{nullptr};
15-
1614
UR_APIEXPORT ur_result_t UR_APICALL
1715
urPlatformGetInfo(ur_platform_handle_t, ur_platform_info_t propName,
1816
size_t propSize, void *pPropValue, size_t *pSizeRet) {
@@ -81,18 +79,15 @@ urPlatformGet(ur_adapter_handle_t *, uint32_t, uint32_t NumEntries,
8179
UR_CHECK_ERROR(hipDeviceGet(&Device, i));
8280
hipCtx_t Context;
8381
UR_CHECK_ERROR(hipDevicePrimaryCtxRetain(&Context, Device));
84-
Platform.Devices.emplace_back(
85-
new ur_device_handle_t_{Device, Context, &Platform, i});
86-
}
87-
88-
// Setup EvBase
89-
{
90-
ScopedContext Active(Platform.Devices.front().get());
9182
hipEvent_t EvBase;
9283
UR_CHECK_ERROR(hipEventCreate(&EvBase));
84+
85+
// Use the default stream to record base event counter
9386
UR_CHECK_ERROR(hipEventRecord(EvBase, 0));
87+
Platform.Devices.emplace_back(new ur_device_handle_t_{
88+
Device, Context, EvBase, &Platform, i});
9489

95-
ur_platform_handle_t_::EvBase = EvBase;
90+
ScopedContext Active(Platform.Devices.front().get());
9691
}
9792
} catch (const std::bad_alloc &) {
9893
// Signal out-of-memory situation

0 commit comments

Comments
 (0)