File tree Expand file tree Collapse file tree 4 files changed +19
-17
lines changed Expand file tree Collapse file tree 4 files changed +19
-17
lines changed Original file line number Diff line number Diff line change @@ -28,22 +28,21 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueNativeCommandExp(
28
28
ScopedStream ActiveStream (hQueue, NumEventsInWaitList, phEventWaitList);
29
29
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr };
30
30
31
+ if (hQueue->getContext ()->getDevices ().size () > 1 ) {
32
+ for (auto i = 0u ; i < NumMemsInMemList; ++i) {
33
+ enqueueMigrateMemoryToDeviceIfNeeded (phMemList[i], hQueue->getDevice (),
34
+ ActiveStream.getStream ());
35
+ phMemList[i]->setLastQueueWritingToMemObj (hQueue);
36
+ }
37
+ }
38
+
31
39
if (phEvent) {
32
40
RetImplEvent =
33
41
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative (
34
42
UR_COMMAND_ENQUEUE_NATIVE_EXP, hQueue, ActiveStream.getStream ()));
35
43
UR_CHECK_ERROR (RetImplEvent->start ());
36
44
}
37
45
38
- if (hQueue->getContext ()->getDevices ().size () > 1 ) {
39
- for (auto i = 0u ; i < NumMemsInMemList; ++i) {
40
- // FIXME: Update to enqueueMigrateMemory and also using
41
- // setLastQueueWritingToMemObj when #1711 has merged
42
- migrateMemoryToDeviceIfNeeded (phMemList[i], hQueue->getDevice ());
43
- phMemList[i]->setLastEventWritingToMemObj (RetImplEvent.get ());
44
- }
45
- }
46
-
47
46
pfnNativeEnqueue (hQueue, data); // This is using urQueueGetNativeHandle to
48
47
// get the CUDA stream. It must be the
49
48
// same stream as is used before and after
Original file line number Diff line number Diff line change 17
17
18
18
#include " common.hpp"
19
19
#include " context.hpp"
20
- #include " device.hpp"
21
- #include " event.hpp"
20
+ #include " queue.hpp"
22
21
23
22
ur_result_t allocateMemObjOnDeviceIfNeeded (ur_mem_handle_t ,
24
23
const ur_device_handle_t );
@@ -443,6 +442,3 @@ struct ur_mem_handle_t_ {
443
442
}
444
443
}
445
444
};
446
-
447
- ur_result_t migrateMemoryToDeviceIfNeeded (ur_mem_handle_t ,
448
- const ur_device_handle_t );
Original file line number Diff line number Diff line change 12
12
13
13
#include " context.hpp"
14
14
#include " event.hpp"
15
+ #include " memory.hpp"
15
16
#include " queue.hpp"
16
17
17
18
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueNativeCommandExp (
18
19
ur_queue_handle_t hQueue,
19
20
ur_exp_enqueue_native_command_function_t pfnNativeEnqueue, void *data,
21
+ uint32_t NumMemsInMemList, const ur_mem_handle_t *phMemList,
20
22
const ur_exp_enqueue_native_command_properties_t *,
21
23
uint32_t NumEventsInWaitList, const ur_event_handle_t *phEventWaitList,
22
24
ur_event_handle_t *phEvent) {
@@ -29,6 +31,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueNativeCommandExp(
29
31
ScopedStream ActiveStream (hQueue, NumEventsInWaitList, phEventWaitList);
30
32
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr };
31
33
34
+ if (hQueue->getContext ()->getDevices ().size () > 1 ) {
35
+ for (auto i = 0u ; i < NumMemsInMemList; ++i) {
36
+ enqueueMigrateMemoryToDeviceIfNeeded (phMemList[i], hQueue->getDevice (),
37
+ ActiveStream.getStream ());
38
+ phMemList[i]->setLastQueueWritingToMemObj (hQueue);
39
+ }
40
+ }
41
+
32
42
if (phEvent) {
33
43
RetImplEvent =
34
44
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative (
Original file line number Diff line number Diff line change @@ -437,6 +437,3 @@ struct ur_mem_handle_t_ {
437
437
}
438
438
}
439
439
};
440
-
441
- ur_result_t migrateMemoryToDeviceIfNeeded (ur_mem_handle_t ,
442
- const ur_device_handle_t );
You can’t perform that action at this time.
0 commit comments