Skip to content

Commit 67e7da3

Browse files
authored
Merge pull request #999 from hdelan/hip-adapter-multi-dev-ctx
[HIP] Hip adapter multi dev ctx
2 parents 9b1fb4e + f0e0be2 commit 67e7da3

File tree

19 files changed

+959
-476
lines changed

19 files changed

+959
-476
lines changed

source/adapters/hip/context.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,13 @@ ur_context_handle_t_::getOwningURPool(umf_memory_pool_t *UMFPool) {
4040
UR_APIEXPORT ur_result_t UR_APICALL urContextCreate(
4141
uint32_t DeviceCount, const ur_device_handle_t *phDevices,
4242
const ur_context_properties_t *, ur_context_handle_t *phContext) {
43-
std::ignore = DeviceCount;
44-
assert(DeviceCount == 1);
4543
ur_result_t RetErr = UR_RESULT_SUCCESS;
4644

4745
std::unique_ptr<ur_context_handle_t_> ContextPtr{nullptr};
4846
try {
4947
// Create a scoped context.
5048
ContextPtr = std::unique_ptr<ur_context_handle_t_>(
51-
new ur_context_handle_t_{*phDevices});
49+
new ur_context_handle_t_{phDevices, DeviceCount});
5250

5351
static std::once_flag InitFlag;
5452
std::call_once(
@@ -78,9 +76,9 @@ urContextGetInfo(ur_context_handle_t hContext, ur_context_info_t propName,
7876

7977
switch (uint32_t{propName}) {
8078
case UR_CONTEXT_INFO_NUM_DEVICES:
81-
return ReturnValue(1);
79+
return ReturnValue(static_cast<uint32_t>(hContext->Devices.size()));
8280
case UR_CONTEXT_INFO_DEVICES:
83-
return ReturnValue(hContext->getDevice());
81+
return ReturnValue(hContext->getDevices());
8482
case UR_CONTEXT_INFO_REFERENCE_COUNT:
8583
return ReturnValue(hContext->getReferenceCount());
8684
case UR_CONTEXT_INFO_ATOMIC_MEMORY_ORDER_CAPABILITIES:
@@ -124,8 +122,10 @@ urContextRetain(ur_context_handle_t hContext) {
124122

125123
UR_APIEXPORT ur_result_t UR_APICALL urContextGetNativeHandle(
126124
ur_context_handle_t hContext, ur_native_handle_t *phNativeContext) {
125+
// FIXME: this entry point has been deprecated in the SYCL RT and should be
126+
// changed to unsupported once the deprecation period has elapsed
127127
*phNativeContext = reinterpret_cast<ur_native_handle_t>(
128-
hContext->getDevice()->getNativeContext());
128+
hContext->getDevices()[0]->getNativeContext());
129129
return UR_RESULT_SUCCESS;
130130
}
131131

source/adapters/hip/context.hpp

Lines changed: 40 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -28,26 +28,26 @@ typedef void (*ur_context_extended_deleter_t)(void *UserData);
2828
///
2929
/// One of the main differences between the UR API and the HIP driver API is
3030
/// that the second modifies the state of the threads by assigning
31-
/// `hipCtx_t` objects to threads. `hipCtx_t` objects store data associated
31+
/// \c hipCtx_t objects to threads. \c hipCtx_t objects store data associated
3232
/// with a given device and control access to said device from the user side.
3333
/// UR API context are objects that are passed to functions, and not bound
3434
/// to threads.
35-
/// The ur_context_handle_t_ object doesn't implement this behavior. It only
36-
/// holds the HIP context data. The RAII object \ref ScopedContext implements
37-
/// the active context behavior.
3835
///
39-
/// <b> Primary vs UserDefined context </b>
36+
/// Since the \c ur_context_handle_t can contain multiple devices, and a \c
37+
/// hipCtx_t refers to only a single device, the \c hipCtx_t is more tightly
38+
/// coupled to a \c ur_device_handle_t than a \c ur_context_handle_t. In order
39+
/// to remove some ambiguities about the different semantics of \c
40+
/// \c ur_context_handle_t and native \c hipCtx_t, we access the native \c
41+
/// hipCtx_t solely through the \c ur_device_handle_t class, by using the object
42+
/// \ref ScopedContext, which sets the active device (by setting the active
43+
/// native \c hipCtx_t).
4044
///
41-
/// HIP has two different types of context, the Primary context,
42-
/// which is usable by all threads on a given process for a given device, and
43-
/// the aforementioned custom contexts.
44-
/// The HIP documentation, and performance analysis, suggest using the Primary
45-
/// context whenever possible. The Primary context is also used by the HIP
46-
/// Runtime API. For UR applications to interop with HIP Runtime API, they have
47-
/// to use the primary context - and make that active in the thread. The
48-
/// `ur_context_handle_t_` object can be constructed with a `kind` parameter
49-
/// that allows to construct a Primary or `UserDefined` context, so that
50-
/// the UR object interface is always the same.
45+
/// <b> Primary vs User-defined \c hipCtx_t </b>
46+
///
47+
/// HIP has two different types of \c hipCtx_t, the Primary context, which is
48+
/// usable by all threads on a given process for a given device, and the
49+
/// aforementioned custom \c hipCtx_t s. The HIP documentation, confirmed with
50+
/// performance analysis, suggest using the Primary context whenever possible.
5151
///
5252
/// <b> Destructor callback </b>
5353
///
@@ -57,6 +57,16 @@ typedef void (*ur_context_extended_deleter_t)(void *UserData);
5757
/// See proposal for details.
5858
/// https://github.com/codeplaysoftware/standards-proposals/blob/master/extended-context-destruction/index.md
5959
///
60+
/// <b> Memory Management for Devices in a Context <\b>
61+
///
62+
/// A \c ur_mem_handle_t is associated with a \c ur_context_handle_t_, which
63+
/// may refer to multiple devices. Therefore the \c ur_mem_handle_t must
64+
/// handle a native allocation for each device in the context. UR is
65+
/// responsible for automatically handling event dependencies for kernels
66+
/// writing to or reading from the same \c ur_mem_handle_t and migrating memory
67+
/// between native allocations for devices in the same \c ur_context_handle_t_
68+
/// if necessary.
69+
///
6070
struct ur_context_handle_t_ {
6171

6272
struct deleter_data {
@@ -68,15 +78,22 @@ struct ur_context_handle_t_ {
6878

6979
using native_type = hipCtx_t;
7080

71-
ur_device_handle_t DeviceId;
81+
std::vector<ur_device_handle_t> Devices;
82+
7283
std::atomic_uint32_t RefCount;
7384

74-
ur_context_handle_t_(ur_device_handle_t DevId)
75-
: DeviceId{DevId}, RefCount{1} {
76-
urDeviceRetain(DeviceId);
85+
ur_context_handle_t_(const ur_device_handle_t *Devs, uint32_t NumDevices)
86+
: Devices{Devs, Devs + NumDevices}, RefCount{1} {
87+
for (auto &Dev : Devices) {
88+
urDeviceRetain(Dev);
89+
}
7790
};
7891

79-
~ur_context_handle_t_() { urDeviceRelease(DeviceId); }
92+
~ur_context_handle_t_() {
93+
for (auto &Dev : Devices) {
94+
urDeviceRelease(Dev);
95+
}
96+
}
8097

8198
void invokeExtendedDeleters() {
8299
std::lock_guard<std::mutex> Guard(Mutex);
@@ -91,7 +108,9 @@ struct ur_context_handle_t_ {
91108
ExtendedDeleters.emplace_back(deleter_data{Function, UserData});
92109
}
93110

94-
ur_device_handle_t getDevice() const noexcept { return DeviceId; }
111+
const std::vector<ur_device_handle_t> &getDevices() const noexcept {
112+
return Devices;
113+
}
95114

96115
uint32_t incrementReferenceCount() noexcept { return ++RefCount; }
97116

source/adapters/hip/device.hpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,13 @@ struct ur_device_handle_t_ {
2525
std::atomic_uint32_t RefCount;
2626
ur_platform_handle_t Platform;
2727
hipCtx_t HIPContext;
28+
uint32_t DeviceIndex;
2829

2930
public:
3031
ur_device_handle_t_(native_type HipDevice, hipCtx_t Context,
31-
ur_platform_handle_t Platform)
32+
ur_platform_handle_t Platform, uint32_t DeviceIndex)
3233
: HIPDevice(HipDevice), RefCount{1}, Platform(Platform),
33-
HIPContext(Context) {}
34+
HIPContext(Context), DeviceIndex(DeviceIndex) {}
3435

3536
~ur_device_handle_t_() {
3637
UR_CHECK_ERROR(hipDevicePrimaryCtxRelease(HIPDevice));
@@ -42,7 +43,11 @@ struct ur_device_handle_t_ {
4243

4344
ur_platform_handle_t getPlatform() const noexcept { return Platform; };
4445

45-
hipCtx_t getNativeContext() { return HIPContext; };
46+
hipCtx_t getNativeContext() const noexcept { return HIPContext; };
47+
48+
// Returns the index of the device relative to the other devices in the same
49+
// platform
50+
uint32_t getIndex() const noexcept { return DeviceIndex; };
4651
};
4752

4853
int getAttribute(ur_device_handle_t Device, hipDeviceAttribute_t Attribute);

0 commit comments

Comments
 (0)