Skip to content

Commit 1da316c

Browse files
committed
[HIP] Refactor error handling in enqueue.cpp
Mostly by taking the existing try/catch/UR_CHECK_ERROR based approach and making sure it's used consistently so as not to drop any errors.
1 parent 7aba70b commit 1da316c

File tree

1 file changed

+48
-58
lines changed

1 file changed

+48
-58
lines changed

source/adapters/hip/enqueue.cpp

Lines changed: 48 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ ur_result_t enqueueEventsWait(ur_queue_handle_t Queue, hipStream_t Stream,
2828
return UR_RESULT_SUCCESS;
2929
}
3030
try {
31-
auto Result = forLatestEvents(
31+
UR_CHECK_ERROR(forLatestEvents(
3232
EventWaitList, NumEventsInWaitList,
3333
[Stream, Queue](ur_event_handle_t Event) -> ur_result_t {
3434
ScopedDevice Active(Queue->getDevice());
@@ -38,17 +38,13 @@ ur_result_t enqueueEventsWait(ur_queue_handle_t Queue, hipStream_t Stream,
3838
UR_CHECK_ERROR(hipStreamWaitEvent(Stream, Event->get(), 0));
3939
return UR_RESULT_SUCCESS;
4040
}
41-
});
42-
43-
if (Result != UR_RESULT_SUCCESS) {
44-
return Result;
45-
}
46-
return UR_RESULT_SUCCESS;
41+
}));
4742
} catch (ur_result_t Err) {
4843
return Err;
4944
} catch (...) {
5045
return UR_RESULT_ERROR_UNKNOWN;
5146
}
47+
return UR_RESULT_SUCCESS;
5248
}
5349

5450
// Determine local work sizes that result in uniform work groups.
@@ -629,13 +625,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopy(
629625
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
630626

631627
try {
632-
ScopedDevice Active(hQueue->getDevice());
633-
ur_result_t Result = UR_RESULT_SUCCESS;
628+
ScopedContext Active(hQueue->getDevice());
634629
auto Stream = hQueue->getNextTransferStream();
635630

636631
if (phEventWaitList) {
637-
Result = enqueueEventsWait(hQueue, Stream, numEventsInWaitList,
638-
phEventWaitList);
632+
UR_CHECK_ERROR(enqueueEventsWait(hQueue, Stream, numEventsInWaitList,
633+
phEventWaitList));
639634
}
640635

641636
if (phEvent) {
@@ -657,12 +652,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopy(
657652
*phEvent = RetImplEvent.release();
658653
}
659654

660-
return Result;
661655
} catch (ur_result_t Err) {
662656
return Err;
663657
} catch (...) {
664658
return UR_RESULT_ERROR_UNKNOWN;
665659
}
660+
return UR_RESULT_SUCCESS;
666661
}
667662

668663
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
@@ -672,7 +667,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
672667
size_t srcSlicePitch, size_t dstRowPitch, size_t dstSlicePitch,
673668
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
674669
ur_event_handle_t *phEvent) {
675-
ur_result_t Result = UR_RESULT_SUCCESS;
676670
void *SrcPtr =
677671
std::get<BufferMem>(hBufferSrc->Mem).getVoid(hQueue->getDevice());
678672
void *DstPtr =
@@ -682,8 +676,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
682676
try {
683677
ScopedDevice Active(hQueue->getDevice());
684678
hipStream_t HIPStream = hQueue->getNextTransferStream();
685-
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
686-
phEventWaitList);
679+
UR_CHECK_ERROR(enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
680+
phEventWaitList));
687681

688682
if (phEvent) {
689683
RetImplEvent =
@@ -692,20 +686,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
692686
UR_CHECK_ERROR(RetImplEvent->start());
693687
}
694688

695-
Result = commonEnqueueMemBufferCopyRect(
689+
UR_CHECK_ERROR(commonEnqueueMemBufferCopyRect(
696690
HIPStream, region, &SrcPtr, hipMemoryTypeDevice, srcOrigin, srcRowPitch,
697691
srcSlicePitch, &DstPtr, hipMemoryTypeDevice, dstOrigin, dstRowPitch,
698-
dstSlicePitch);
692+
dstSlicePitch));
699693

700694
if (phEvent) {
701695
UR_CHECK_ERROR(RetImplEvent->record());
702696
*phEvent = RetImplEvent.release();
703697
}
704698

705699
} catch (ur_result_t Err) {
706-
Result = Err;
700+
return Err;
707701
}
708-
return Result;
702+
return UR_RESULT_SUCCESS;
709703
}
710704

711705
static inline void memsetRemainPattern(hipStream_t Stream, uint32_t PatternSize,
@@ -1063,14 +1057,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageCopy(
10631057
std::get<SurfaceMem>(hImageDst->Mem).getImageType(),
10641058
UR_RESULT_ERROR_INVALID_MEM_OBJECT);
10651059

1066-
ur_result_t Result = UR_RESULT_SUCCESS;
1067-
10681060
try {
10691061
ScopedDevice Active(hQueue->getDevice());
10701062
hipStream_t HIPStream = hQueue->getNextTransferStream();
10711063
if (phEventWaitList) {
1072-
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
1073-
phEventWaitList);
1064+
UR_CHECK_ERROR(enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
1065+
phEventWaitList));
10741066
}
10751067

10761068
hipArray *SrcArray =
@@ -1110,13 +1102,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageCopy(
11101102
UR_CHECK_ERROR(RetImplEvent->start());
11111103
}
11121104

1113-
Result = commonEnqueueMemImageNDCopy(
1105+
UR_CHECK_ERROR(commonEnqueueMemImageNDCopy(
11141106
HIPStream, ImgType, AdjustedRegion, SrcArray, hipMemoryTypeArray,
1115-
SrcOffset, DstArray, hipMemoryTypeArray, DstOffset);
1116-
1117-
if (Result != UR_RESULT_SUCCESS) {
1118-
return Result;
1119-
}
1107+
SrcOffset, DstArray, hipMemoryTypeArray, DstOffset));
11201108

11211109
if (phEvent) {
11221110
UR_CHECK_ERROR(RetImplEvent->record());
@@ -1237,7 +1225,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
12371225
ur_queue_handle_t hQueue, void *ptr, size_t patternSize,
12381226
const void *pPattern, size_t size, uint32_t numEventsInWaitList,
12391227
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
1240-
ur_result_t Result = UR_RESULT_SUCCESS;
12411228
std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr};
12421229

12431230
try {
@@ -1246,8 +1233,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
12461233
ur_stream_guard Guard;
12471234
hipStream_t HIPStream = hQueue->getNextComputeStream(
12481235
numEventsInWaitList, phEventWaitList, Guard, &StreamToken);
1249-
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
1250-
phEventWaitList);
1236+
UR_CHECK_ERROR(enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
1237+
phEventWaitList));
12511238
if (phEvent) {
12521239
EventPtr =
12531240
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
@@ -1274,8 +1261,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
12741261
break;
12751262

12761263
default:
1277-
Result = commonMemSetLargePattern(HIPStream, patternSize, size, pPattern,
1278-
reinterpret_cast<hipDeviceptr_t>(ptr));
1264+
UR_CHECK_ERROR(
1265+
commonMemSetLargePattern(HIPStream, patternSize, size, pPattern,
1266+
reinterpret_cast<hipDeviceptr_t>(ptr)));
12791267
break;
12801268
}
12811269

@@ -1284,25 +1272,23 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
12841272
*phEvent = EventPtr.release();
12851273
}
12861274
} catch (ur_result_t Err) {
1287-
Result = Err;
1275+
return Err;
12881276
}
12891277

1290-
return Result;
1278+
return UR_RESULT_SUCCESS;
12911279
}
12921280

12931281
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy(
12941282
ur_queue_handle_t hQueue, bool blocking, void *pDst, const void *pSrc,
12951283
size_t size, uint32_t numEventsInWaitList,
12961284
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
1297-
ur_result_t Result = UR_RESULT_SUCCESS;
1298-
12991285
std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr};
13001286

13011287
try {
13021288
ScopedDevice Active(hQueue->getDevice());
13031289
hipStream_t HIPStream = hQueue->getNextTransferStream();
1304-
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
1305-
phEventWaitList);
1290+
UR_CHECK_ERROR(enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
1291+
phEventWaitList));
13061292
if (phEvent) {
13071293
EventPtr =
13081294
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
@@ -1321,9 +1307,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy(
13211307
*phEvent = EventPtr.release();
13221308
}
13231309
} catch (ur_result_t Err) {
1324-
Result = Err;
1310+
return Err;
13251311
}
1326-
return Result;
1312+
return UR_RESULT_SUCCESS;
13271313
}
13281314

13291315
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
@@ -1345,13 +1331,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
13451331
UR_ASSERT(size <= PointerRangeSize, UR_RESULT_ERROR_INVALID_SIZE);
13461332
#endif
13471333

1348-
ur_result_t Result = UR_RESULT_SUCCESS;
1349-
13501334
try {
13511335
ScopedDevice Active(hQueue->getDevice());
13521336
hipStream_t HIPStream = hQueue->getNextTransferStream();
1353-
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
1354-
phEventWaitList);
1337+
UR_CHECK_ERROR(enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
1338+
phEventWaitList));
13551339

13561340
std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr};
13571341

@@ -1399,10 +1383,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
13991383
hipMemPrefetchAsync(pMem, size, hQueue->getDevice()->get(), HIPStream));
14001384
releaseEvent();
14011385
} catch (ur_result_t Err) {
1402-
Result = Err;
1386+
return Err;
14031387
}
14041388

1405-
return Result;
1389+
return UR_RESULT_SUCCESS;
14061390
}
14071391

14081392
/// USM: memadvise API to govern behavior of automatic migration mechanisms
@@ -1516,10 +1500,18 @@ urEnqueueUSMAdvise(ur_queue_handle_t hQueue, const void *pMem, size_t size,
15161500
// the runtime.
15171501
if (Result == UR_RESULT_ERROR_INVALID_ENUMERATION) {
15181502
releaseEvent();
1519-
setErrorMessage("mem_advise is ignored as the advice argument is not "
1520-
"supported by this device",
1521-
UR_RESULT_SUCCESS);
1522-
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
1503+
// UR_RESULT_ERROR_INVALID_ENUMERATION is returned when using a valid
1504+
// but currently unmapped advice arguments as not supported by this
1505+
// platform. Therefore, warn the user instead of throwing and aborting
1506+
// the runtime.
1507+
if (Result == UR_RESULT_ERROR_INVALID_ENUMERATION) {
1508+
setErrorMessage("mem_advise is ignored as the advice argument is not "
1509+
"supported by this device",
1510+
UR_RESULT_SUCCESS);
1511+
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
1512+
} else {
1513+
throw Result;
1514+
}
15231515
}
15241516
}
15251517

@@ -1558,13 +1550,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
15581550
const void *pSrc, size_t srcPitch, size_t width, size_t height,
15591551
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
15601552
ur_event_handle_t *phEvent) {
1561-
ur_result_t Result = UR_RESULT_SUCCESS;
1562-
15631553
try {
15641554
ScopedDevice Active(hQueue->getDevice());
15651555
hipStream_t HIPStream = hQueue->getNextTransferStream();
1566-
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
1567-
phEventWaitList);
1556+
UR_CHECK_ERROR(enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
1557+
phEventWaitList));
15681558

15691559
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
15701560
if (phEvent) {
@@ -1668,10 +1658,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
16681658
UR_CHECK_ERROR(hipStreamSynchronize(HIPStream));
16691659
}
16701660
} catch (ur_result_t Err) {
1671-
Result = Err;
1661+
return Err;
16721662
}
16731663

1674-
return Result;
1664+
return UR_RESULT_SUCCESS;
16751665
}
16761666

16771667
namespace {

0 commit comments

Comments
 (0)