@@ -23,67 +23,34 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueNativeCommandExp(
     uint32_t NumEventsInWaitList, const ur_event_handle_t *phEventWaitList,
     ur_event_handle_t *phEvent) {

-  std::vector<ur_event_handle_t> MemMigrationEvents;
-  std::vector<std::pair<ur_mem_handle_t, ur_lock>> MemMigrationLocks;
-
-  // phEventWaitList only contains events that are handed to UR by the SYCL
-  // runtime. However since UR handles memory dependencies within a context
-  // we may need to add more events to our dependent events list if the UR
-  // context contains multiple devices
-  if (NumMemsInMemList > 0 && hQueue->getContext()->Devices.size() > 1) {
-    for (auto i = 0u; i < NumMemsInMemList; ++i) {
-      auto Mem = phMemList[i];
-      if (auto MemDepEvent = Mem->LastEventWritingToMemObj;
-          MemDepEvent &&
-          std::find(MemMigrationEvents.begin(), MemMigrationEvents.end(),
-                    MemDepEvent) == MemMigrationEvents.end()) {
-        MemMigrationEvents.push_back(MemDepEvent);
-        MemMigrationLocks.emplace_back(
-            std::pair{Mem, ur_lock{Mem->MemoryMigrationMutex}});
-      }
-    }
-  }
-
   try {
     ScopedContext ActiveContext(hQueue->getDevice());
     ScopedStream ActiveStream(hQueue, NumEventsInWaitList, phEventWaitList);
     std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};

-    if (phEvent || MemMigrationEvents.size()) {
+    if (phEvent) {
       RetImplEvent =
           std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
               UR_COMMAND_ENQUEUE_NATIVE_EXP, hQueue, ActiveStream.getStream()));
       UR_CHECK_ERROR(RetImplEvent->start());
     }

-    if (MemMigrationEvents.size()) {
-      UR_CHECK_ERROR(
-          urEnqueueEventsWaitWithBarrier(hQueue, MemMigrationEvents.size(),
-                                         MemMigrationEvents.data(), nullptr));
+    if (hQueue->getContext()->getDevices().size() > 1) {
       for (auto i = 0u; i < NumMemsInMemList; ++i) {
-        auto Mem = phMemList[i];
-        migrateMemoryToDeviceIfNeeded(Mem, hQueue->getDevice());
-        Mem->setLastEventWritingToMemObj(RetImplEvent.get());
+        // FIXME: Update to enqueueMigrateMemory and also using
+        // setLastQueueWritingToMemObj when #1711 has merged
+        migrateMemoryToDeviceIfNeeded(phMemList[i], hQueue->getDevice());
+        phMemList[i]->setLastEventWritingToMemObj(RetImplEvent.get());
       }
-      MemMigrationLocks.clear();
     }

     pfnNativeEnqueue(hQueue, data); // This is using urQueueGetNativeHandle to
                                     // get the CUDA stream. It must be the
                                     // same stream as is used before and after

-    if (phEvent || MemMigrationEvents.size()) {
+    if (phEvent) {
       UR_CHECK_ERROR(RetImplEvent->record());
-      if (phEvent) {
-        *phEvent = RetImplEvent.release();
-      } else {
-        // Give ownership of the event to the mem
-        for (auto i = 0u; i < NumMemsInMemList; ++i) {
-          auto Mem = phMemList[i];
-          migrateMemoryToDeviceIfNeeded(Mem, hQueue->getDevice());
-          Mem->setLastEventWritingToMemObj(RetImplEvent.release());
-        }
-      }
+      *phEvent = RetImplEvent.release();
     }
   } catch (ur_result_t Err) {
     return Err;
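
To make the resulting control flow easier to read, the post-change body of the hunk above reassembles as follows (taken from the context and `+` lines, with two clarifying comments added; this is an excerpt rather than a self-contained compilable unit, and the hunk continues past `return Err;`):

  try {
    ScopedContext ActiveContext(hQueue->getDevice());
    ScopedStream ActiveStream(hQueue, NumEventsInWaitList, phEventWaitList);
    std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};

    // An event is only created and started when the caller asked for one.
    if (phEvent) {
      RetImplEvent =
          std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
              UR_COMMAND_ENQUEUE_NATIVE_EXP, hQueue, ActiveStream.getStream()));
      UR_CHECK_ERROR(RetImplEvent->start());
    }

    // In a multi-device context, each listed buffer is migrated to the
    // queue's device before the native work is enqueued.
    if (hQueue->getContext()->getDevices().size() > 1) {
      for (auto i = 0u; i < NumMemsInMemList; ++i) {
        // FIXME: Update to enqueueMigrateMemory and also using
        // setLastQueueWritingToMemObj when #1711 has merged
        migrateMemoryToDeviceIfNeeded(phMemList[i], hQueue->getDevice());
        phMemList[i]->setLastEventWritingToMemObj(RetImplEvent.get());
      }
    }

    pfnNativeEnqueue(hQueue, data); // This is using urQueueGetNativeHandle to
                                    // get the CUDA stream. It must be the
                                    // same stream as is used before and after

    if (phEvent) {
      UR_CHECK_ERROR(RetImplEvent->record());
      *phEvent = RetImplEvent.release();
    }
  } catch (ur_result_t Err) {
    return Err;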