Skip to content

Commit f616029

Browse files
author
Hugh Delaney
committed
Change migrateMemoryToDevice to enqueue*
In line with changing from a sync op to async for memory migration across devices in a context.
1 parent a2bec63 commit f616029

File tree

4 files changed

+19
-17
lines changed

4 files changed

+19
-17
lines changed

source/adapters/cuda/enqueue_native.cpp

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,22 +28,21 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueNativeCommandExp(
2828
ScopedStream ActiveStream(hQueue, NumEventsInWaitList, phEventWaitList);
2929
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
3030

31+
if (hQueue->getContext()->getDevices().size() > 1) {
32+
for (auto i = 0u; i < NumMemsInMemList; ++i) {
33+
enqueueMigrateMemoryToDeviceIfNeeded(phMemList[i], hQueue->getDevice(),
34+
ActiveStream.getStream());
35+
phMemList[i]->setLastQueueWritingToMemObj(hQueue);
36+
}
37+
}
38+
3139
if (phEvent) {
3240
RetImplEvent =
3341
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
3442
UR_COMMAND_ENQUEUE_NATIVE_EXP, hQueue, ActiveStream.getStream()));
3543
UR_CHECK_ERROR(RetImplEvent->start());
3644
}
3745

38-
if (hQueue->getContext()->getDevices().size() > 1) {
39-
for (auto i = 0u; i < NumMemsInMemList; ++i) {
40-
// FIXME: Update to enqueueMigrateMemory and also using
41-
// setLastQueueWritingToMemObj when #1711 has merged
42-
migrateMemoryToDeviceIfNeeded(phMemList[i], hQueue->getDevice());
43-
phMemList[i]->setLastEventWritingToMemObj(RetImplEvent.get());
44-
}
45-
}
46-
4746
pfnNativeEnqueue(hQueue, data); // This is using urQueueGetNativeHandle to
4847
// get the CUDA stream. It must be the
4948
// same stream as is used before and after

source/adapters/cuda/memory.hpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@
1717

1818
#include "common.hpp"
1919
#include "context.hpp"
20-
#include "device.hpp"
21-
#include "event.hpp"
20+
#include "queue.hpp"
2221

2322
ur_result_t allocateMemObjOnDeviceIfNeeded(ur_mem_handle_t,
2423
const ur_device_handle_t);
@@ -443,6 +442,3 @@ struct ur_mem_handle_t_ {
443442
}
444443
}
445444
};
446-
447-
ur_result_t migrateMemoryToDeviceIfNeeded(ur_mem_handle_t,
448-
const ur_device_handle_t);

source/adapters/hip/enqueue_native.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,13 @@
1212

1313
#include "context.hpp"
1414
#include "event.hpp"
15+
#include "memory.hpp"
1516
#include "queue.hpp"
1617

1718
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueNativeCommandExp(
1819
ur_queue_handle_t hQueue,
1920
ur_exp_enqueue_native_command_function_t pfnNativeEnqueue, void *data,
21+
uint32_t NumMemsInMemList, const ur_mem_handle_t *phMemList,
2022
const ur_exp_enqueue_native_command_properties_t *,
2123
uint32_t NumEventsInWaitList, const ur_event_handle_t *phEventWaitList,
2224
ur_event_handle_t *phEvent) {
@@ -29,6 +31,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueNativeCommandExp(
2931
ScopedStream ActiveStream(hQueue, NumEventsInWaitList, phEventWaitList);
3032
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
3133

34+
if (hQueue->getContext()->getDevices().size() > 1) {
35+
for (auto i = 0u; i < NumMemsInMemList; ++i) {
36+
enqueueMigrateMemoryToDeviceIfNeeded(phMemList[i], hQueue->getDevice(),
37+
ActiveStream.getStream());
38+
phMemList[i]->setLastQueueWritingToMemObj(hQueue);
39+
}
40+
}
41+
3242
if (phEvent) {
3343
RetImplEvent =
3444
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(

source/adapters/hip/memory.hpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,3 @@ struct ur_mem_handle_t_ {
437437
}
438438
}
439439
};
440-
441-
ur_result_t migrateMemoryToDeviceIfNeeded(ur_mem_handle_t,
442-
const ur_device_handle_t);

0 commit comments

Comments
 (0)