Skip to content

Commit 68328c4

Browse files
author
Hugh Delaney
committed
Add ur_mem_handle_t * arg to entrypoint
In order to manage native memory migration across ur_mem_handle_ts, we need to know which ur_mem_handle_ts are wrapped up in the void * function data.
1 parent 86a2db9 commit 68328c4

File tree

15 files changed

+173
-37
lines changed

15 files changed

+173
-37
lines changed

include/ur_api.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9562,6 +9562,11 @@ urEnqueueNativeCommandExp(
95629562
ur_exp_enqueue_native_command_function_t pfnNativeEnqueue, ///< [in] function calling the native underlying API, to be executed
95639563
///< immediately.
95649564
void *data, ///< [in][optional] data used by pfnNativeEnqueue
9565+
uint32_t numMemsInMemList, ///< [in] size of the mem list
9566+
const ur_mem_handle_t *phMemList, ///< [in][optional][range(0, numMemsInMemList)] mems that are used within
9567+
///< pfnNativeEnqueue using ::urMemGetNativeHandle.
9568+
///< If nullptr, the numMemsInMemList must be 0, indicating that no mems
9569+
///< are accessed with ::urMemGetNativeHandle within pfnNativeEnqueue.
95659570
const ur_exp_enqueue_native_command_properties_t *pProperties, ///< [in][optional] pointer to the native enqueue properties
95669571
uint32_t numEventsInWaitList, ///< [in] size of the event wait list
95679572
const ur_event_handle_t *phEventWaitList, ///< [in][optional][range(0, numEventsInWaitList)] pointer to a list of
@@ -10998,6 +11003,8 @@ typedef struct ur_enqueue_native_command_exp_params_t {
1099811003
ur_queue_handle_t *phQueue;
1099911004
ur_exp_enqueue_native_command_function_t *ppfnNativeEnqueue;
1100011005
void **pdata;
11006+
uint32_t *pnumMemsInMemList;
11007+
const ur_mem_handle_t **pphMemList;
1100111008
const ur_exp_enqueue_native_command_properties_t **ppProperties;
1100211009
uint32_t *pnumEventsInWaitList;
1100311010
const ur_event_handle_t **pphEventWaitList;

include/ur_ddi.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1489,6 +1489,8 @@ typedef ur_result_t(UR_APICALL *ur_pfnEnqueueNativeCommandExp_t)(
14891489
ur_queue_handle_t,
14901490
ur_exp_enqueue_native_command_function_t,
14911491
void *,
1492+
uint32_t,
1493+
const ur_mem_handle_t *,
14921494
const ur_exp_enqueue_native_command_properties_t *,
14931495
uint32_t,
14941496
const ur_event_handle_t *,

include/ur_print.hpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14547,6 +14547,23 @@ inline std::ostream &operator<<(std::ostream &os, [[maybe_unused]] const struct
1454714547
ur::details::printPtr(os,
1454814548
*(params->pdata));
1454914549

14550+
os << ", ";
14551+
os << ".numMemsInMemList = ";
14552+
14553+
os << *(params->pnumMemsInMemList);
14554+
14555+
os << ", ";
14556+
os << ".phMemList = {";
14557+
for (size_t i = 0; *(params->pphMemList) != NULL && i < *params->pnumMemsInMemList; ++i) {
14558+
if (i != 0) {
14559+
os << ", ";
14560+
}
14561+
14562+
ur::details::printPtr(os,
14563+
(*(params->pphMemList))[i]);
14564+
}
14565+
os << "}";
14566+
1455014567
os << ", ";
1455114568
os << ".pProperties = ";
1455214569

scripts/core/EXP-NATIVE-ENQUEUE.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,13 @@ queue accessed through ${x}QueueGetNativeHandle. Use of a native queue that is
3434
not the native queue returned by ${x}QueueGetNativeHandle results in undefined
3535
behaviour.
3636

37+
Any args that are needed by the func must be passed through a void* and unpacked
38+
within the func. If ${x}_mem_handle_t arguments are to be used within
39+
pfnNativeEnqueue, they must be accessed using ${x}MemGetNativeHandle.
40+
${x}_mem_handle_t arguments must be packed in the void* argument that will be
41+
used in pfnNativeEnqueue, as well as ${x}EnqueueNativeCommandExp's phMemList
42+
argument.
43+
3744
API
3845
--------------------------------------------------------------------------------
3946

scripts/core/exp-native-enqueue.yml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,14 @@ params:
8989
- type: void*
9090
name: data
9191
desc: "[in][optional] data used by pfnNativeEnqueue"
92+
- type: uint32_t
93+
name: numMemsInMemList
94+
desc: "[in] size of the mem list"
95+
- type: const $x_mem_handle_t*
96+
name: phMemList
97+
desc: |
98+
[in][optional][range(0, numMemsInMemList)] mems that are used within pfnNativeEnqueue using $xMemGetNativeHandle.
99+
If nullptr, the numMemsInMemList must be 0, indicating that no mems are accessed with $xMemGetNativeHandle within pfnNativeEnqueue.
92100
- type: const $x_exp_enqueue_native_command_properties_t*
93101
name: pProperties
94102
desc: "[in][optional] pointer to the native enqueue properties"

source/adapters/cuda/enqueue_native.cpp

Lines changed: 48 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,38 +12,79 @@
1212

1313
#include "context.hpp"
1414
#include "event.hpp"
15+
#include "memory.hpp"
1516
#include "queue.hpp"
1617

1718
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueNativeCommandExp(
1819
ur_queue_handle_t hQueue,
1920
ur_exp_enqueue_native_command_function_t pfnNativeEnqueue, void *data,
21+
uint32_t NumMemsInMemList, const ur_mem_handle_t *phMemList,
2022
const ur_exp_enqueue_native_command_properties_t *,
2123
uint32_t NumEventsInWaitList, const ur_event_handle_t *phEventWaitList,
2224
ur_event_handle_t *phEvent) {
23-
// TODO: how should mem migration work across a context here?
24-
// Perhaps we will need to add a phMemObjArgs so that we are able to make
25-
// sure memory migration happens across devices in the same context
25+
26+
std::vector<ur_event_handle_t> MemMigrationEvents;
27+
std::vector<std::pair<ur_mem_handle_t, ur_lock>> MemMigrationLocks;
28+
29+
// phEventWaitList only contains events that are handed to UR by the SYCL
30+
// runtime. However since UR handles memory dependencies within a context
31+
// we may need to add more events to our dependent events list if the UR
32+
// context contains multiple devices
33+
if (NumMemsInMemList > 0 && hQueue->getContext()->Devices.size() > 1) {
34+
for (auto i = 0u; i < NumMemsInMemList; ++i) {
35+
auto Mem = phMemList[i];
36+
if (auto MemDepEvent = Mem->LastEventWritingToMemObj;
37+
MemDepEvent &&
38+
std::find(MemMigrationEvents.begin(), MemMigrationEvents.end(),
39+
MemDepEvent) == MemMigrationEvents.end()) {
40+
MemMigrationEvents.push_back(MemDepEvent);
41+
MemMigrationLocks.emplace_back(
42+
std::pair{Mem, ur_lock{Mem->MemoryMigrationMutex}});
43+
}
44+
}
45+
}
2646

2747
try {
2848
ScopedContext ActiveContext(hQueue->getDevice());
2949
ScopedStream ActiveStream(hQueue, NumEventsInWaitList, phEventWaitList);
3050
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
3151

32-
if (phEvent) {
52+
if (phEvent || MemMigrationEvents.size()) {
3353
RetImplEvent =
3454
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
3555
UR_COMMAND_ENQUEUE_NATIVE_EXP, hQueue, ActiveStream.getStream()));
3656
UR_CHECK_ERROR(RetImplEvent->start());
3757
}
3858

59+
if (MemMigrationEvents.size()) {
60+
UR_CHECK_ERROR(
61+
urEnqueueEventsWaitWithBarrier(hQueue, MemMigrationEvents.size(),
62+
MemMigrationEvents.data(), nullptr));
63+
for (auto i = 0u; i < NumMemsInMemList; ++i) {
64+
auto Mem = phMemList[i];
65+
migrateMemoryToDeviceIfNeeded(Mem, hQueue->getDevice());
66+
Mem->setLastEventWritingToMemObj(RetImplEvent.get());
67+
}
68+
MemMigrationLocks.clear();
69+
}
70+
3971
pfnNativeEnqueue(hQueue, data); // This is using urQueueGetNativeHandle to
4072
// get the CUDA stream. It must be the
4173
// same stream as is used before and after
42-
if (phEvent) {
74+
75+
if (phEvent || MemMigrationEvents.size()) {
4376
UR_CHECK_ERROR(RetImplEvent->record());
44-
*phEvent = RetImplEvent.release();
77+
if (phEvent) {
78+
*phEvent = RetImplEvent.release();
79+
} else {
80+
// Give ownership of the event to the mem
81+
for (auto i = 0u; i < NumMemsInMemList; ++i) {
82+
auto Mem = phMemList[i];
83+
migrateMemoryToDeviceIfNeeded(Mem, hQueue->getDevice());
84+
Mem->setLastEventWritingToMemObj(RetImplEvent.release());
85+
}
86+
}
4587
}
46-
4788
} catch (ur_result_t Err) {
4889
return Err;
4990
} catch (CUresult CuErr) {

source/adapters/level_zero/enqueue_native.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
UR_APICALL UR_APIEXPORT ur_result_t urEnqueueNativeCommandExp(
1414
ur_queue_handle_t, ur_exp_enqueue_native_command_function_t, void *,
15+
uint32_t, const ur_mem_handle_t *,
1516
const ur_exp_enqueue_native_command_properties_t *, uint32_t,
1617
const ur_event_handle_t *, ur_event_handle_t *) {
1718
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;

source/adapters/null/ur_nullddi.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5867,7 +5867,13 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueNativeCommandExp(
58675867
ur_exp_enqueue_native_command_function_t
58685868
pfnNativeEnqueue, ///< [in] function calling the native underlying API, to be executed
58695869
///< immediately.
5870-
void *data, ///< [in][optional] data used by pfnNativeEnqueue
5870+
void *data, ///< [in][optional] data used by pfnNativeEnqueue
5871+
uint32_t numMemsInMemList, ///< [in] size of the mem list
5872+
const ur_mem_handle_t *
5873+
phMemList, ///< [in][optional][range(0, numMemsInMemList)] mems that are used within
5874+
///< pfnNativeEnqueue using ::urMemGetNativeHandle.
5875+
///< If nullptr, the numMemsInMemList must be 0, indicating that no mems
5876+
///< are accessed with ::urMemGetNativeHandle within pfnNativeEnqueue.
58715877
const ur_exp_enqueue_native_command_properties_t *
58725878
pProperties, ///< [in][optional] pointer to the native enqueue properties
58735879
uint32_t numEventsInWaitList, ///< [in] size of the event wait list
@@ -5885,9 +5891,9 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueNativeCommandExp(
58855891
auto pfnNativeCommandExp =
58865892
d_context.urDdiTable.EnqueueExp.pfnNativeCommandExp;
58875893
if (nullptr != pfnNativeCommandExp) {
5888-
result =
5889-
pfnNativeCommandExp(hQueue, pfnNativeEnqueue, data, pProperties,
5890-
numEventsInWaitList, phEventWaitList, phEvent);
5894+
result = pfnNativeCommandExp(
5895+
hQueue, pfnNativeEnqueue, data, numMemsInMemList, phMemList,
5896+
pProperties, numEventsInWaitList, phEventWaitList, phEvent);
58915897
} else {
58925898
// generic implementation
58935899
*phEvent = reinterpret_cast<ur_event_handle_t>(d_context.get());

source/adapters/opencl/enqueue_native.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
UR_APICALL UR_APIEXPORT ur_result_t urEnqueueNativeCommandExp(
1414
ur_queue_handle_t, ur_exp_enqueue_native_command_function_t, void *,
15+
uint32_t, const ur_mem_handle_t *,
1516
const ur_exp_enqueue_native_command_properties_t *, uint32_t,
1617
const ur_event_handle_t *, ur_event_handle_t *) {
1718
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;

source/loader/layers/tracing/ur_trcddi.cpp

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7851,7 +7851,13 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueNativeCommandExp(
78517851
ur_exp_enqueue_native_command_function_t
78527852
pfnNativeEnqueue, ///< [in] function calling the native underlying API, to be executed
78537853
///< immediately.
7854-
void *data, ///< [in][optional] data used by pfnNativeEnqueue
7854+
void *data, ///< [in][optional] data used by pfnNativeEnqueue
7855+
uint32_t numMemsInMemList, ///< [in] size of the mem list
7856+
const ur_mem_handle_t *
7857+
phMemList, ///< [in][optional][range(0, numMemsInMemList)] mems that are used within
7858+
///< pfnNativeEnqueue using ::urMemGetNativeHandle.
7859+
///< If nullptr, the numMemsInMemList must be 0, indicating that no mems
7860+
///< are accessed with ::urMemGetNativeHandle within pfnNativeEnqueue.
78557861
const ur_exp_enqueue_native_command_properties_t *
78567862
pProperties, ///< [in][optional] pointer to the native enqueue properties
78577863
uint32_t numEventsInWaitList, ///< [in] size of the event wait list
@@ -7870,19 +7876,24 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueNativeCommandExp(
78707876
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
78717877
}
78727878

7873-
ur_enqueue_native_command_exp_params_t params = {
7874-
&hQueue, &pfnNativeEnqueue, &data,
7875-
&pProperties, &numEventsInWaitList, &phEventWaitList,
7876-
&phEvent};
7879+
ur_enqueue_native_command_exp_params_t params = {&hQueue,
7880+
&pfnNativeEnqueue,
7881+
&data,
7882+
&numMemsInMemList,
7883+
&phMemList,
7884+
&pProperties,
7885+
&numEventsInWaitList,
7886+
&phEventWaitList,
7887+
&phEvent};
78777888
uint64_t instance =
78787889
context.notify_begin(UR_FUNCTION_ENQUEUE_NATIVE_COMMAND_EXP,
78797890
"urEnqueueNativeCommandExp", &params);
78807891

78817892
context.logger.info("---> urEnqueueNativeCommandExp");
78827893

7883-
ur_result_t result =
7884-
pfnNativeCommandExp(hQueue, pfnNativeEnqueue, data, pProperties,
7885-
numEventsInWaitList, phEventWaitList, phEvent);
7894+
ur_result_t result = pfnNativeCommandExp(
7895+
hQueue, pfnNativeEnqueue, data, numMemsInMemList, phMemList,
7896+
pProperties, numEventsInWaitList, phEventWaitList, phEvent);
78867897

78877898
context.notify_end(UR_FUNCTION_ENQUEUE_NATIVE_COMMAND_EXP,
78887899
"urEnqueueNativeCommandExp", &params, &result, instance);

0 commit comments

Comments
 (0)