Skip to content

Commit d216eb4

Browse files
authored
Merge pull request #1226 from hdelan/get-native-mem-on-device2
[UR] Add extra param to urMemGetNativeHandle
2 parents 40517d2 + fc1f306 commit d216eb4

File tree

19 files changed

+87
-30
lines changed

19 files changed

+87
-30
lines changed

include/ur_api.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2723,13 +2723,15 @@ urMemBufferPartition(
27232723
/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC
27242724
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
27252725
/// + `NULL == hMem`
2726+
/// + `NULL == hDevice`
27262727
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
27272728
/// + `NULL == phNativeMem`
27282729
/// - ::UR_RESULT_ERROR_UNSUPPORTED_FEATURE
27292730
/// + If the adapter has no underlying equivalent handle.
27302731
UR_APIEXPORT ur_result_t UR_APICALL
27312732
urMemGetNativeHandle(
27322733
ur_mem_handle_t hMem, ///< [in] handle of the mem.
2734+
ur_device_handle_t hDevice, ///< [in] handle of the device that the native handle will be resident on.
27332735
ur_native_handle_t *phNativeMem ///< [out] a pointer to the native handle of the mem.
27342736
);
27352737

@@ -9488,6 +9490,7 @@ typedef struct ur_mem_buffer_partition_params_t {
94889490
/// allowing the callback the ability to modify the parameter's value
94899491
typedef struct ur_mem_get_native_handle_params_t {
94909492
ur_mem_handle_t *phMem;
9493+
ur_device_handle_t *phDevice;
94919494
ur_native_handle_t **pphNativeMem;
94929495
} ur_mem_get_native_handle_params_t;
94939496

include/ur_ddi.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -770,6 +770,7 @@ typedef ur_result_t(UR_APICALL *ur_pfnMemBufferPartition_t)(
770770
/// @brief Function-pointer for urMemGetNativeHandle
771771
typedef ur_result_t(UR_APICALL *ur_pfnMemGetNativeHandle_t)(
772772
ur_mem_handle_t,
773+
ur_device_handle_t,
773774
ur_native_handle_t *);
774775

775776
///////////////////////////////////////////////////////////////////////////////

include/ur_print.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11174,6 +11174,12 @@ inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct
1117411174
ur::details::printPtr(os,
1117511175
*(params->phMem));
1117611176

11177+
os << ", ";
11178+
os << ".hDevice = ";
11179+
11180+
ur::details::printPtr(os,
11181+
*(params->phDevice));
11182+
1117711183
os << ", ";
1117811184
os << ".phNativeMem = ";
1117911185

scripts/core/memory.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,10 @@ params:
432432
name: hMem
433433
desc: |
434434
[in] handle of the mem.
435+
- type: $x_device_handle_t
436+
name: hDevice
437+
desc: |
438+
[in] handle of the device that the native handle will be resident on.
435439
- type: $x_native_handle_t*
436440
name: phNativeMem
437441
desc: |

source/adapters/cuda/memory.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemRelease(ur_mem_handle_t hMem) {
161161
/// \param[out] phNativeMem Set to the native handle of the UR mem object.
162162
///
163163
/// \return UR_RESULT_SUCCESS
164-
UR_APIEXPORT ur_result_t UR_APICALL
165-
urMemGetNativeHandle(ur_mem_handle_t hMem, ur_native_handle_t *phNativeMem) {
164+
UR_APIEXPORT ur_result_t UR_APICALL urMemGetNativeHandle(
165+
ur_mem_handle_t hMem, ur_device_handle_t, ur_native_handle_t *phNativeMem) {
166166
*phNativeMem = reinterpret_cast<ur_native_handle_t>(
167167
std::get<BufferMem>(hMem->Mem).get());
168168
return UR_RESULT_SUCCESS;

source/adapters/hip/memory.cpp

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -279,16 +279,32 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemGetInfo(ur_mem_handle_t hMemory,
279279
/// \param[out] phNativeMem Set to the native handle of the UR mem object.
280280
///
281281
/// \return UR_RESULT_SUCCESS
282-
UR_APIEXPORT ur_result_t UR_APICALL urMemGetNativeHandle(ur_mem_handle_t,
283-
ur_native_handle_t *) {
284-
// FIXME: there is no good way of doing this with a multi device context.
285-
// If we return a single pointer, how would we know which device's allocation
286-
// it should be?
287-
// If we return a vector of pointers, this is OK for read only access but if
288-
// we write to a buffer, how would we know which one had been written to?
289-
// Should unused allocations be updated afterwards? We have no way of knowing
290-
// any of these things in the current API design.
291-
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
282+
UR_APIEXPORT ur_result_t UR_APICALL
283+
urMemGetNativeHandle(ur_mem_handle_t hMem, ur_device_handle_t Device,
284+
ur_native_handle_t *phNativeMem) {
285+
#if defined(__HIP_PLATFORM_NVIDIA__)
286+
if (sizeof(BufferMem::native_type) > sizeof(ur_native_handle_t)) {
287+
// Check that all the upper bits that cannot be represented by
288+
// ur_native_handle_t are empty.
289+
// NOTE: The following shift might trigger a warning, but the check in the
290+
// if above makes sure that the shift amount is smaller than the operand width.
291+
BufferMem::native_type UpperBits =
292+
std::get<BufferMem>(hMem->Mem).getPtr(Device) >>
293+
(sizeof(ur_native_handle_t) * CHAR_BIT);
294+
if (UpperBits) {
295+
// Return an error if any of the remaining bits is non-zero.
296+
return UR_RESULT_ERROR_INVALID_MEM_OBJECT;
297+
}
298+
}
299+
*phNativeMem = reinterpret_cast<ur_native_handle_t>(
300+
std::get<BufferMem>(hMem->Mem).getPtr(Device));
301+
#elif defined(__HIP_PLATFORM_AMD__)
302+
*phNativeMem = reinterpret_cast<ur_native_handle_t>(
303+
std::get<BufferMem>(hMem->Mem).getPtr(Device));
304+
#else
305+
#error("Must define exactly one of __HIP_PLATFORM_AMD__ or __HIP_PLATFORM_NVIDIA__");
306+
#endif
307+
return UR_RESULT_SUCCESS;
292308
}
293309

294310
UR_APIEXPORT ur_result_t UR_APICALL urMemBufferCreateWithNativeHandle(

source/adapters/level_zero/memory.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1856,6 +1856,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferPartition(
18561856

18571857
UR_APIEXPORT ur_result_t UR_APICALL urMemGetNativeHandle(
18581858
ur_mem_handle_t Mem, ///< [in] handle of the mem.
1859+
ur_device_handle_t, ///< [in] handle of the device.
18591860
ur_native_handle_t
18601861
*NativeMem ///< [out] a pointer to the native handle of the mem.
18611862
) {

source/adapters/native_cpu/memory.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferPartition(
111111
}
112112

113113
UR_APIEXPORT ur_result_t UR_APICALL
114-
urMemGetNativeHandle(ur_mem_handle_t hMem, ur_native_handle_t *phNativeMem) {
114+
urMemGetNativeHandle(ur_mem_handle_t hMem, ur_device_handle_t hDevice,
115+
ur_native_handle_t *phNativeMem) {
115116
std::ignore = hMem;
117+
std::ignore = hDevice;
116118
std::ignore = phNativeMem;
117119

118120
DIE_NO_IMPLEMENTATION

source/adapters/null/ur_nullddi.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -916,6 +916,8 @@ __urdlllocal ur_result_t UR_APICALL urMemBufferPartition(
916916
/// @brief Intercept function for urMemGetNativeHandle
917917
__urdlllocal ur_result_t UR_APICALL urMemGetNativeHandle(
918918
ur_mem_handle_t hMem, ///< [in] handle of the mem.
919+
ur_device_handle_t
920+
hDevice, ///< [in] handle of the device that the native handle will be resident on.
919921
ur_native_handle_t
920922
*phNativeMem ///< [out] a pointer to the native handle of the mem.
921923
) try {
@@ -924,7 +926,7 @@ __urdlllocal ur_result_t UR_APICALL urMemGetNativeHandle(
924926
// if the driver has created a custom function, then call it instead of using the generic path
925927
auto pfnGetNativeHandle = d_context.urDdiTable.Mem.pfnGetNativeHandle;
926928
if (nullptr != pfnGetNativeHandle) {
927-
result = pfnGetNativeHandle(hMem, phNativeMem);
929+
result = pfnGetNativeHandle(hMem, hDevice, phNativeMem);
928930
} else {
929931
// generic implementation
930932
*phNativeMem = reinterpret_cast<ur_native_handle_t>(d_context.get());

source/adapters/opencl/memory.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -331,8 +331,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferPartition(
331331
return mapCLErrorToUR(RetErr);
332332
}
333333

334-
UR_APIEXPORT ur_result_t UR_APICALL
335-
urMemGetNativeHandle(ur_mem_handle_t hMem, ur_native_handle_t *phNativeMem) {
334+
UR_APIEXPORT ur_result_t UR_APICALL urMemGetNativeHandle(
335+
ur_mem_handle_t hMem, ur_device_handle_t, ur_native_handle_t *phNativeMem) {
336336
return getNativeHandle(hMem, phNativeMem);
337337
}
338338

0 commit comments

Comments
 (0)