Skip to content

Commit af586ec

Browse files
author
Hugh Delaney
committed
Remove some multi dev ctx shenanigans
As is done in #1711, we don't need to wait on events that are not given directly to UR by the user through the phEventWaitList param.
1 parent 68328c4 commit af586ec

File tree

1 file changed

+8
-41
lines changed

1 file changed

+8
-41
lines changed

source/adapters/cuda/enqueue_native.cpp

Lines changed: 8 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -23,67 +23,34 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueNativeCommandExp(
2323
uint32_t NumEventsInWaitList, const ur_event_handle_t *phEventWaitList,
2424
ur_event_handle_t *phEvent) {
2525

26-
std::vector<ur_event_handle_t> MemMigrationEvents;
27-
std::vector<std::pair<ur_mem_handle_t, ur_lock>> MemMigrationLocks;
28-
29-
// phEventWaitList only contains events that are handed to UR by the SYCL
30-
// runtime. However since UR handles memory dependencies within a context
31-
// we may need to add more events to our dependent events list if the UR
32-
// context contains multiple devices
33-
if (NumMemsInMemList > 0 && hQueue->getContext()->Devices.size() > 1) {
34-
for (auto i = 0u; i < NumMemsInMemList; ++i) {
35-
auto Mem = phMemList[i];
36-
if (auto MemDepEvent = Mem->LastEventWritingToMemObj;
37-
MemDepEvent &&
38-
std::find(MemMigrationEvents.begin(), MemMigrationEvents.end(),
39-
MemDepEvent) == MemMigrationEvents.end()) {
40-
MemMigrationEvents.push_back(MemDepEvent);
41-
MemMigrationLocks.emplace_back(
42-
std::pair{Mem, ur_lock{Mem->MemoryMigrationMutex}});
43-
}
44-
}
45-
}
46-
4726
try {
4827
ScopedContext ActiveContext(hQueue->getDevice());
4928
ScopedStream ActiveStream(hQueue, NumEventsInWaitList, phEventWaitList);
5029
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
5130

52-
if (phEvent || MemMigrationEvents.size()) {
31+
if (phEvent) {
5332
RetImplEvent =
5433
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
5534
UR_COMMAND_ENQUEUE_NATIVE_EXP, hQueue, ActiveStream.getStream()));
5635
UR_CHECK_ERROR(RetImplEvent->start());
5736
}
5837

59-
if (MemMigrationEvents.size()) {
60-
UR_CHECK_ERROR(
61-
urEnqueueEventsWaitWithBarrier(hQueue, MemMigrationEvents.size(),
62-
MemMigrationEvents.data(), nullptr));
38+
if (hQueue->getContext()->getDevices().size() > 1) {
6339
for (auto i = 0u; i < NumMemsInMemList; ++i) {
64-
auto Mem = phMemList[i];
65-
migrateMemoryToDeviceIfNeeded(Mem, hQueue->getDevice());
66-
Mem->setLastEventWritingToMemObj(RetImplEvent.get());
40+
// FIXME: Update to enqueueMigrateMemory and also using
41+
// setLastQueueWritingToMemObj when #1711 has merged
42+
migrateMemoryToDeviceIfNeeded(phMemList[i], hQueue->getDevice());
43+
phMemList[i]->setLastEventWritingToMemObj(RetImplEvent.get());
6744
}
68-
MemMigrationLocks.clear();
6945
}
7046

7147
pfnNativeEnqueue(hQueue, data); // This is using urQueueGetNativeHandle to
7248
// get the CUDA stream. It must be the
7349
// same stream as is used before and after
7450

75-
if (phEvent || MemMigrationEvents.size()) {
51+
if (phEvent) {
7652
UR_CHECK_ERROR(RetImplEvent->record());
77-
if (phEvent) {
78-
*phEvent = RetImplEvent.release();
79-
} else {
80-
// Give ownership of the event to the mem
81-
for (auto i = 0u; i < NumMemsInMemList; ++i) {
82-
auto Mem = phMemList[i];
83-
migrateMemoryToDeviceIfNeeded(Mem, hQueue->getDevice());
84-
Mem->setLastEventWritingToMemObj(RetImplEvent.release());
85-
}
86-
}
53+
*phEvent = RetImplEvent.release();
8754
}
8855
} catch (ur_result_t Err) {
8956
return Err;

0 commit comments

Comments
 (0)