Skip to content

Commit f5c907a

Browse files
authored
Merge pull request #1830 from JackAKirk/hip-set-device
[hip] Remove deprecated hip APIs, simplify urContext
2 parents 9ca3ec7 + be38e56 commit f5c907a

21 files changed

+99
-141
lines changed

source/adapters/hip/command_buffer.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -789,7 +789,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferEnqueueExp(
789789
ur_event_handle_t *phEvent) {
790790
try {
791791
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
792-
ScopedContext Active(hQueue->getDevice());
792+
ScopedDevice Active(hQueue->getDevice());
793793
uint32_t StreamToken;
794794
ur_stream_guard Guard;
795795
hipStream_t HIPStream = hQueue->getNextComputeStream(

source/adapters/hip/context.cpp

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,7 @@ ur_context_handle_t_::getOwningURPool(umf_memory_pool_t *UMFPool) {
3232
return nullptr;
3333
}
3434

35-
/// Create a UR HIP context.
36-
///
37-
/// By default creates a scoped context and keeps the last active HIP context
38-
/// on top of the HIP context stack.
35+
/// Create a UR context.
3936
///
4037
UR_APIEXPORT ur_result_t UR_APICALL urContextCreate(
4138
uint32_t DeviceCount, const ur_device_handle_t *phDevices,
@@ -44,7 +41,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urContextCreate(
4441

4542
std::unique_ptr<ur_context_handle_t_> ContextPtr{nullptr};
4643
try {
47-
// Create a scoped context.
44+
// Create a context.
4845
ContextPtr = std::unique_ptr<ur_context_handle_t_>(
4946
new ur_context_handle_t_{phDevices, DeviceCount});
5047
*phContext = ContextPtr.release();
@@ -111,13 +108,15 @@ urContextRetain(ur_context_handle_t hContext) {
111108
return UR_RESULT_SUCCESS;
112109
}
113110

114-
UR_APIEXPORT ur_result_t UR_APICALL urContextGetNativeHandle(
115-
ur_context_handle_t hContext, ur_native_handle_t *phNativeContext) {
116-
// FIXME: this entry point has been deprecated in the SYCL RT and should be
117-
// changed to unsupported once the deprecation period has elapsed
118-
*phNativeContext = reinterpret_cast<ur_native_handle_t>(
119-
hContext->getDevices()[0]->getNativeContext());
120-
return UR_RESULT_SUCCESS;
111+
// urContextGetNativeHandle should not be implemented in the HIP backend.
112+
// hipCtx_t is not natively supported by amd devices, and more importantly does
113+
// not map to ur_context_handle_t in any way.
114+
UR_APIEXPORT ur_result_t UR_APICALL
115+
urContextGetNativeHandle([[maybe_unused]] ur_context_handle_t hContext,
116+
[[maybe_unused]] ur_native_handle_t *phNativeContext) {
117+
std::ignore = hContext;
118+
std::ignore = phNativeContext;
119+
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
121120
}
122121

123122
UR_APIEXPORT ur_result_t UR_APICALL urContextCreateWithNativeHandle(

source/adapters/hip/context.hpp

Lines changed: 10 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,15 @@ typedef void (*ur_context_extended_deleter_t)(void *UserData);
5757
/// See proposal for details.
5858
/// https://github.com/codeplaysoftware/standards-proposals/blob/master/extended-context-destruction/index.md
5959
///
60+
///
61+
/// <b> Destructor callback </b>
62+
///
63+
/// Required to implement CP023, SYCL Extended Context Destruction,
64+
/// the UR Context can store a number of callback functions that will be
65+
/// called upon destruction of the UR Context.
66+
/// See proposal for details.
67+
/// https://github.com/codeplaysoftware/standards-proposals/blob/master/extended-context-destruction/index.md
68+
///
6069
/// <b> Memory Management for Devices in a Context <\b>
6170
///
6271
/// A \c ur_mem_handle_t is associated with a \c ur_context_handle_t_, which
@@ -76,8 +85,6 @@ struct ur_context_handle_t_ {
7685
void operator()() { Function(UserData); }
7786
};
7887

79-
using native_type = hipCtx_t;
80-
8188
std::vector<ur_device_handle_t> Devices;
8289

8390
std::atomic_uint32_t RefCount;
@@ -89,11 +96,7 @@ struct ur_context_handle_t_ {
8996
}
9097
};
9198

92-
~ur_context_handle_t_() {
93-
for (auto &Dev : Devices) {
94-
urDeviceRelease(Dev);
95-
}
96-
}
99+
~ur_context_handle_t_() {}
97100

98101
void invokeExtendedDeleters() {
99102
std::lock_guard<std::mutex> Guard(Mutex);
@@ -136,28 +139,3 @@ struct ur_context_handle_t_ {
136139
std::vector<deleter_data> ExtendedDeleters;
137140
std::set<ur_usm_pool_handle_t> PoolHandles;
138141
};
139-
140-
namespace {
141-
/// Scoped context is used across all UR HIP plugin implementation to activate
142-
/// the native Context on the current thread. The ScopedContext does not
143-
/// reinstate the previous context as all operations in the hip adapter that
144-
/// require an active context, set the active context and don't rely on context
145-
/// reinstation
146-
class ScopedContext {
147-
public:
148-
ScopedContext(ur_device_handle_t hDevice) {
149-
hipCtx_t Original{};
150-
151-
if (!hDevice) {
152-
throw UR_RESULT_ERROR_INVALID_DEVICE;
153-
}
154-
155-
hipCtx_t Desired = hDevice->getNativeContext();
156-
UR_CHECK_ERROR(hipCtxGetCurrent(&Original));
157-
if (Original != Desired) {
158-
// Sets the desired context as the active one for the thread
159-
UR_CHECK_ERROR(hipCtxSetCurrent(Desired));
160-
}
161-
}
162-
};
163-
} // namespace

source/adapters/hip/device.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1068,7 +1068,7 @@ ur_result_t UR_APICALL urDeviceGetGlobalTimestamps(ur_device_handle_t hDevice,
10681068
return UR_RESULT_SUCCESS;
10691069

10701070
ur_event_handle_t_::native_type Event;
1071-
ScopedContext Active(hDevice);
1071+
ScopedDevice Active(hDevice);
10721072

10731073
if (pDeviceTimestamp) {
10741074
UR_CHECK_ERROR(hipEventCreateWithFlags(&Event, hipEventDefault));

source/adapters/hip/device.hpp

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ struct ur_device_handle_t_ {
2424
native_type HIPDevice;
2525
std::atomic_uint32_t RefCount;
2626
ur_platform_handle_t Platform;
27-
hipCtx_t HIPContext;
2827
hipEvent_t EvBase; // HIP event used as base counter
2928
uint32_t DeviceIndex;
3029

@@ -37,11 +36,10 @@ struct ur_device_handle_t_ {
3736
int ConcurrentManagedAccess{0};
3837

3938
public:
40-
ur_device_handle_t_(native_type HipDevice, hipCtx_t Context,
41-
hipEvent_t EvBase, ur_platform_handle_t Platform,
42-
uint32_t DeviceIndex)
43-
: HIPDevice(HipDevice), RefCount{1}, Platform(Platform),
44-
HIPContext(Context), EvBase(EvBase), DeviceIndex(DeviceIndex) {
39+
ur_device_handle_t_(native_type HipDevice, hipEvent_t EvBase,
40+
ur_platform_handle_t Platform, uint32_t DeviceIndex)
41+
: HIPDevice(HipDevice), RefCount{1}, Platform(Platform), EvBase(EvBase),
42+
DeviceIndex(DeviceIndex) {
4543

4644
UR_CHECK_ERROR(hipDeviceGetAttribute(
4745
&MaxWorkGroupSize, hipDeviceAttributeMaxThreadsPerBlock, HIPDevice));
@@ -61,9 +59,7 @@ struct ur_device_handle_t_ {
6159
HIPDevice));
6260
}
6361

64-
~ur_device_handle_t_() noexcept(false) {
65-
UR_CHECK_ERROR(hipDevicePrimaryCtxRelease(HIPDevice));
66-
}
62+
~ur_device_handle_t_() noexcept(false) {}
6763

6864
native_type get() const noexcept { return HIPDevice; };
6965

@@ -73,8 +69,6 @@ struct ur_device_handle_t_ {
7369

7470
uint64_t getElapsedTime(hipEvent_t) const;
7571

76-
hipCtx_t getNativeContext() const noexcept { return HIPContext; };
77-
7872
// Returns the index of the device relative to the other devices in the same
7973
// platform
8074
uint32_t getIndex() const noexcept { return DeviceIndex; };
@@ -97,3 +91,20 @@ struct ur_device_handle_t_ {
9791
};
9892

9993
int getAttribute(ur_device_handle_t Device, hipDeviceAttribute_t Attribute);
94+
95+
namespace {
96+
/// Scoped Device is used across all UR HIP plugin implementation to activate
97+
/// the native Device on the current thread. The ScopedDevice does not
98+
/// reinstate the previous device as all operations in the HIP adapter that
99+
/// require an active device, set the active device and don't rely on device
100+
/// reinstation
101+
class ScopedDevice {
102+
public:
103+
ScopedDevice(ur_device_handle_t hDevice) {
104+
if (!hDevice) {
105+
throw UR_RESULT_ERROR_INVALID_DEVICE;
106+
}
107+
UR_CHECK_ERROR(hipSetDevice(hDevice->getIndex()));
108+
}
109+
};
110+
} // namespace

source/adapters/hip/enqueue.cpp

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ ur_result_t enqueueEventsWait(ur_queue_handle_t Queue, hipStream_t Stream,
3131
auto Result = forLatestEvents(
3232
EventWaitList, NumEventsInWaitList,
3333
[Stream, Queue](ur_event_handle_t Event) -> ur_result_t {
34-
ScopedContext Active(Queue->getDevice());
34+
ScopedDevice Active(Queue->getDevice());
3535
if (Event->isCompleted() || Event->getStream() == Stream) {
3636
return UR_RESULT_SUCCESS;
3737
} else {
@@ -164,7 +164,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
164164
hBuffer->setLastQueueWritingToMemObj(hQueue);
165165

166166
try {
167-
ScopedContext Active(hQueue->getDevice());
167+
ScopedDevice Active(hQueue->getDevice());
168168
hipStream_t HIPStream = hQueue->getNextTransferStream();
169169
UR_CHECK_ERROR(enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
170170
phEventWaitList));
@@ -220,7 +220,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead(
220220
}
221221

222222
auto Device = hQueue->getDevice();
223-
ScopedContext Active(Device);
223+
ScopedDevice Active(Device);
224224
hipStream_t HIPStream = hQueue->getNextTransferStream();
225225

226226
// Use the default stream if copying from another device
@@ -290,7 +290,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
290290
pGlobalWorkSize, pLocalWorkSize, hKernel,
291291
HIPFunc, ThreadsPerBlock, BlocksPerGrid));
292292

293-
ScopedContext Active(Dev);
293+
ScopedDevice Active(Dev);
294294

295295
uint32_t StreamToken;
296296
ur_stream_guard Guard;
@@ -378,7 +378,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueEventsWaitWithBarrier(
378378
UR_RESULT_ERROR_INVALID_EVENT_WAIT_LIST)
379379

380380
try {
381-
ScopedContext Active(hQueue->getDevice());
381+
ScopedDevice Active(hQueue->getDevice());
382382
uint32_t StreamToken;
383383
ur_stream_guard Guard;
384384
hipStream_t HIPStream = hQueue->getNextComputeStream(
@@ -533,7 +533,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferReadRect(
533533
}
534534

535535
auto Device = hQueue->getDevice();
536-
ScopedContext Active(Device);
536+
ScopedDevice Active(Device);
537537
hipStream_t HIPStream = hQueue->getNextTransferStream();
538538

539539
UR_CHECK_ERROR(enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
@@ -582,7 +582,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
582582
hBuffer->setLastQueueWritingToMemObj(hQueue);
583583

584584
try {
585-
ScopedContext Active(hQueue->getDevice());
585+
ScopedDevice Active(hQueue->getDevice());
586586
hipStream_t HIPStream = hQueue->getNextTransferStream();
587587
UR_CHECK_ERROR(enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
588588
phEventWaitList));
@@ -629,7 +629,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopy(
629629
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
630630

631631
try {
632-
ScopedContext Active(hQueue->getDevice());
632+
ScopedDevice Active(hQueue->getDevice());
633633
ur_result_t Result = UR_RESULT_SUCCESS;
634634
auto Stream = hQueue->getNextTransferStream();
635635

@@ -680,7 +680,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
680680
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
681681

682682
try {
683-
ScopedContext Active(hQueue->getDevice());
683+
ScopedDevice Active(hQueue->getDevice());
684684
hipStream_t HIPStream = hQueue->getNextTransferStream();
685685
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
686686
phEventWaitList);
@@ -794,7 +794,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
794794
hBuffer->setLastQueueWritingToMemObj(hQueue);
795795

796796
try {
797-
ScopedContext Active(hQueue->getDevice());
797+
ScopedDevice Active(hQueue->getDevice());
798798

799799
auto Stream = hQueue->getNextTransferStream();
800800
if (phEventWaitList) {
@@ -941,7 +941,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageRead(
941941
}
942942

943943
auto Device = hQueue->getDevice();
944-
ScopedContext Active(Device);
944+
ScopedDevice Active(Device);
945945
hipStream_t HIPStream = hQueue->getNextTransferStream();
946946

947947
if (phEventWaitList) {
@@ -1001,7 +1001,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageWrite(
10011001
UR_ASSERT(hImage->isImage(), UR_RESULT_ERROR_INVALID_MEM_OBJECT);
10021002

10031003
try {
1004-
ScopedContext Active(hQueue->getDevice());
1004+
ScopedDevice Active(hQueue->getDevice());
10051005
hipStream_t HIPStream = hQueue->getNextTransferStream();
10061006

10071007
if (phEventWaitList) {
@@ -1066,7 +1066,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageCopy(
10661066
ur_result_t Result = UR_RESULT_SUCCESS;
10671067

10681068
try {
1069-
ScopedContext Active(hQueue->getDevice());
1069+
ScopedDevice Active(hQueue->getDevice());
10701070
hipStream_t HIPStream = hQueue->getNextTransferStream();
10711071
if (phEventWaitList) {
10721072
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
@@ -1161,7 +1161,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferMap(
11611161
hQueue, hBuffer, blockingMap, offset, size, MapPtr,
11621162
numEventsInWaitList, phEventWaitList, phEvent));
11631163
} else {
1164-
ScopedContext Active(hQueue->getDevice());
1164+
ScopedDevice Active(hQueue->getDevice());
11651165

11661166
if (IsPinned) {
11671167
UR_CHECK_ERROR(urEnqueueEventsWait(hQueue, numEventsInWaitList,
@@ -1211,7 +1211,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemUnmap(
12111211
hQueue, hMem, true, Map->getMapOffset(), Map->getMapSize(),
12121212
pMappedPtr, numEventsInWaitList, phEventWaitList, phEvent));
12131213
} else {
1214-
ScopedContext Active(hQueue->getDevice());
1214+
ScopedDevice Active(hQueue->getDevice());
12151215

12161216
if (IsPinned) {
12171217
UR_CHECK_ERROR(urEnqueueEventsWait(hQueue, numEventsInWaitList,
@@ -1241,7 +1241,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
12411241
std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr};
12421242

12431243
try {
1244-
ScopedContext Active(hQueue->getDevice());
1244+
ScopedDevice Active(hQueue->getDevice());
12451245
uint32_t StreamToken;
12461246
ur_stream_guard Guard;
12471247
hipStream_t HIPStream = hQueue->getNextComputeStream(
@@ -1299,7 +1299,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy(
12991299
std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr};
13001300

13011301
try {
1302-
ScopedContext Active(hQueue->getDevice());
1302+
ScopedDevice Active(hQueue->getDevice());
13031303
hipStream_t HIPStream = hQueue->getNextTransferStream();
13041304
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
13051305
phEventWaitList);
@@ -1348,7 +1348,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
13481348
ur_result_t Result = UR_RESULT_SUCCESS;
13491349

13501350
try {
1351-
ScopedContext Active(hQueue->getDevice());
1351+
ScopedDevice Active(hQueue->getDevice());
13521352
hipStream_t HIPStream = hQueue->getNextTransferStream();
13531353
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
13541354
phEventWaitList);
@@ -1425,7 +1425,7 @@ urEnqueueUSMAdvise(ur_queue_handle_t hQueue, const void *pMem, size_t size,
14251425
#endif
14261426

14271427
try {
1428-
ScopedContext Active(Device);
1428+
ScopedDevice Active(Device);
14291429
std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr};
14301430

14311431
if (phEvent) {
@@ -1561,7 +1561,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
15611561
ur_result_t Result = UR_RESULT_SUCCESS;
15621562

15631563
try {
1564-
ScopedContext Active(hQueue->getDevice());
1564+
ScopedDevice Active(hQueue->getDevice());
15651565
hipStream_t HIPStream = hQueue->getNextTransferStream();
15661566
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
15671567
phEventWaitList);
@@ -1762,7 +1762,7 @@ setKernelParams(const ur_device_handle_t Device, const uint32_t WorkDim,
17621762
size_t MaxWorkGroupSize = 0;
17631763
ur_result_t Result = UR_RESULT_SUCCESS;
17641764
try {
1765-
ScopedContext Active(Device);
1765+
ScopedDevice Active(Device);
17661766
{
17671767
size_t MaxThreadsPerBlock[3] = {
17681768
static_cast<size_t>(Device->getMaxBlockDimX()),
@@ -1906,7 +1906,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueTimestampRecordingExp(
19061906
ur_result_t Result = UR_RESULT_SUCCESS;
19071907
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
19081908
try {
1909-
ScopedContext Active(hQueue->getDevice());
1909+
ScopedDevice Active(hQueue->getDevice());
19101910

19111911
uint32_t StreamToken;
19121912
ur_stream_guard Guard;

0 commit comments

Comments
 (0)