Skip to content

Commit 09467b8

Browse files
authored
Merge pull request #1553 from hdelan/get-device-from-queue
[HIP] Get device from queue, not event
2 parents 1c1fa63 + 4cbf210 commit 09467b8

File tree

4 files changed

+32
-23
lines changed

4 files changed

+32
-23
lines changed

source/adapters/hip/enqueue.cpp

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
extern size_t imageElementByteSize(hipArray_Format ArrayFormat);
2222

23-
ur_result_t enqueueEventsWait(ur_queue_handle_t, hipStream_t Stream,
23+
ur_result_t enqueueEventsWait(ur_queue_handle_t Queue, hipStream_t Stream,
2424
uint32_t NumEventsInWaitList,
2525
const ur_event_handle_t *EventWaitList) {
2626
if (!EventWaitList) {
@@ -29,8 +29,8 @@ ur_result_t enqueueEventsWait(ur_queue_handle_t, hipStream_t Stream,
2929
try {
3030
auto Result = forLatestEvents(
3131
EventWaitList, NumEventsInWaitList,
32-
[Stream](ur_event_handle_t Event) -> ur_result_t {
33-
ScopedContext Active(Event->getDevice());
32+
[Stream, Queue](ur_event_handle_t Event) -> ur_result_t {
33+
ScopedContext Active(Queue->getDevice());
3434
if (Event->isCompleted() || Event->getStream() == Stream) {
3535
return UR_RESULT_SUCCESS;
3636
} else {
@@ -218,8 +218,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead(
218218
// last queue to write to the MemBuffer, meaning we must perform the copy
219219
// from a different device
220220
if (hBuffer->LastEventWritingToMemObj &&
221-
hBuffer->LastEventWritingToMemObj->getDevice() != hQueue->getDevice()) {
222-
Device = hBuffer->LastEventWritingToMemObj->getDevice();
221+
hBuffer->LastEventWritingToMemObj->getQueue()->getDevice() !=
222+
hQueue->getDevice()) {
223+
// This event is never created with interop so getQueue is never null
224+
hQueue = hBuffer->LastEventWritingToMemObj->getQueue();
225+
Device = hQueue->getDevice();
223226
ScopedContext Active(Device);
224227
HIPStream = hipStream_t{0}; // Default stream for different device
225228
// We may have to wait for an event on another queue if it is the last
@@ -584,8 +587,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferReadRect(
584587
// last queue to write to the MemBuffer, meaning we must perform the copy
585588
// from a different device
586589
if (hBuffer->LastEventWritingToMemObj &&
587-
hBuffer->LastEventWritingToMemObj->getDevice() != hQueue->getDevice()) {
588-
Device = hBuffer->LastEventWritingToMemObj->getDevice();
590+
hBuffer->LastEventWritingToMemObj->getQueue()->getDevice() !=
591+
hQueue->getDevice()) {
592+
// This event is never created with interop so getQueue is never null
593+
hQueue = hBuffer->LastEventWritingToMemObj->getQueue();
594+
Device = hQueue->getDevice();
589595
ScopedContext Active(Device);
590596
HIPStream = hipStream_t{0}; // Default stream for different device
591597
// We may have to wait for an event on another queue if it is the last
@@ -1017,8 +1023,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageRead(
10171023
// last queue to write to the MemBuffer, meaning we must perform the copy
10181024
// from a different device
10191025
if (hImage->LastEventWritingToMemObj &&
1020-
hImage->LastEventWritingToMemObj->getDevice() != hQueue->getDevice()) {
1021-
Device = hImage->LastEventWritingToMemObj->getDevice();
1026+
hImage->LastEventWritingToMemObj->getQueue()->getDevice() !=
1027+
hQueue->getDevice()) {
1028+
hQueue = hImage->LastEventWritingToMemObj->getQueue();
1029+
Device = hQueue->getDevice();
10221030
ScopedContext Active(Device);
10231031
HIPStream = hipStream_t{0}; // Default stream for different device
10241032
// We may have to wait for an event on another queue if it is the last

source/adapters/hip/event.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,6 @@ struct ur_event_handle_t_ {
2828

2929
ur_queue_handle_t getQueue() const noexcept { return Queue; }
3030

31-
ur_device_handle_t getDevice() const noexcept { return Queue->getDevice(); }
32-
3331
hipStream_t getStream() const noexcept { return Stream; }
3432

3533
uint32_t getComputeStreamToken() const noexcept { return StreamToken; }

source/adapters/hip/memory.cpp

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -525,11 +525,12 @@ inline ur_result_t migrateBufferToDevice(ur_mem_handle_t Mem,
525525
UR_CHECK_ERROR(
526526
hipMemcpyHtoD(Buffer.getPtr(hDevice), Buffer.HostPtr, Buffer.Size));
527527
}
528-
} else if (Mem->LastEventWritingToMemObj->getDevice() != hDevice) {
529-
UR_CHECK_ERROR(
530-
hipMemcpyDtoD(Buffer.getPtr(hDevice),
531-
Buffer.getPtr(Mem->LastEventWritingToMemObj->getDevice()),
532-
Buffer.Size));
528+
} else if (Mem->LastEventWritingToMemObj->getQueue()->getDevice() !=
529+
hDevice) {
530+
UR_CHECK_ERROR(hipMemcpyDtoD(
531+
Buffer.getPtr(hDevice),
532+
Buffer.getPtr(Mem->LastEventWritingToMemObj->getQueue()->getDevice()),
533+
Buffer.Size));
533534
}
534535
return UR_RESULT_SUCCESS;
535536
}
@@ -577,22 +578,24 @@ inline ur_result_t migrateImageToDevice(ur_mem_handle_t Mem,
577578
CpyDesc3D.srcHost = Image.HostPtr;
578579
UR_CHECK_ERROR(hipDrvMemcpy3D(&CpyDesc3D));
579580
}
580-
} else if (Mem->LastEventWritingToMemObj->getDevice() != hDevice) {
581+
} else if (Mem->LastEventWritingToMemObj->getQueue()->getDevice() !=
582+
hDevice) {
581583
if (Image.ImageDesc.type == UR_MEM_TYPE_IMAGE1D) {
582584
// FIXME: 1D memcpy from DtoD going through the host.
583585
UR_CHECK_ERROR(hipMemcpyAtoH(
584586
Image.HostPtr,
585-
Image.getArray(Mem->LastEventWritingToMemObj->getDevice()),
587+
Image.getArray(
588+
Mem->LastEventWritingToMemObj->getQueue()->getDevice()),
586589
0 /*srcOffset*/, ImageSizeBytes));
587590
UR_CHECK_ERROR(
588591
hipMemcpyHtoA(ImageArray, 0, Image.HostPtr, ImageSizeBytes));
589592
} else if (Image.ImageDesc.type == UR_MEM_TYPE_IMAGE2D) {
590-
CpyDesc2D.srcArray =
591-
Image.getArray(Mem->LastEventWritingToMemObj->getDevice());
593+
CpyDesc2D.srcArray = Image.getArray(
594+
Mem->LastEventWritingToMemObj->getQueue()->getDevice());
592595
UR_CHECK_ERROR(hipMemcpyParam2D(&CpyDesc2D));
593596
} else if (Image.ImageDesc.type == UR_MEM_TYPE_IMAGE3D) {
594-
CpyDesc3D.srcArray =
595-
Image.getArray(Mem->LastEventWritingToMemObj->getDevice());
597+
CpyDesc3D.srcArray = Image.getArray(
598+
Mem->LastEventWritingToMemObj->getQueue()->getDevice());
596599
UR_CHECK_ERROR(hipDrvMemcpy3D(&CpyDesc3D));
597600
}
598601
}

source/adapters/hip/memory.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,7 @@ struct ur_mem_handle_t_ {
498498
LastEventWritingToMemObj = NewEvent;
499499
for (const auto &Device : Context->getDevices()) {
500500
HaveMigratedToDeviceSinceLastWrite[Device->getIndex()] =
501-
Device == NewEvent->getDevice();
501+
Device == NewEvent->getQueue()->getDevice();
502502
}
503503
}
504504
};

0 commit comments

Comments
 (0)