Skip to content

Commit a9645bd

Browse files
authored
Merge pull request #1190 from aarongreig/aaron/refactorHipEnqueueErrors
[HIP] Refactor error handling in enqueue.cpp
2 parents cf90cb1 + 96468e2 commit a9645bd

File tree

1 file changed

+36
-53
lines changed

1 file changed

+36
-53
lines changed

source/adapters/hip/enqueue.cpp

Lines changed: 36 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ ur_result_t enqueueEventsWait(ur_queue_handle_t Queue, hipStream_t Stream,
2828
return UR_RESULT_SUCCESS;
2929
}
3030
try {
31-
auto Result = forLatestEvents(
31+
UR_CHECK_ERROR(forLatestEvents(
3232
EventWaitList, NumEventsInWaitList,
3333
[Stream, Queue](ur_event_handle_t Event) -> ur_result_t {
3434
ScopedDevice Active(Queue->getDevice());
@@ -38,17 +38,13 @@ ur_result_t enqueueEventsWait(ur_queue_handle_t Queue, hipStream_t Stream,
3838
UR_CHECK_ERROR(hipStreamWaitEvent(Stream, Event->get(), 0));
3939
return UR_RESULT_SUCCESS;
4040
}
41-
});
42-
43-
if (Result != UR_RESULT_SUCCESS) {
44-
return Result;
45-
}
46-
return UR_RESULT_SUCCESS;
41+
}));
4742
} catch (ur_result_t Err) {
4843
return Err;
4944
} catch (...) {
5045
return UR_RESULT_ERROR_UNKNOWN;
5146
}
47+
return UR_RESULT_SUCCESS;
5248
}
5349

5450
// Determine local work sizes that result in uniform work groups.
@@ -630,12 +626,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopy(
630626

631627
try {
632628
ScopedDevice Active(hQueue->getDevice());
633-
ur_result_t Result = UR_RESULT_SUCCESS;
634629
auto Stream = hQueue->getNextTransferStream();
635630

636631
if (phEventWaitList) {
637-
Result = enqueueEventsWait(hQueue, Stream, numEventsInWaitList,
638-
phEventWaitList);
632+
UR_CHECK_ERROR(enqueueEventsWait(hQueue, Stream, numEventsInWaitList,
633+
phEventWaitList));
639634
}
640635

641636
if (phEvent) {
@@ -657,12 +652,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopy(
657652
*phEvent = RetImplEvent.release();
658653
}
659654

660-
return Result;
661655
} catch (ur_result_t Err) {
662656
return Err;
663657
} catch (...) {
664658
return UR_RESULT_ERROR_UNKNOWN;
665659
}
660+
return UR_RESULT_SUCCESS;
666661
}
667662

668663
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
@@ -672,7 +667,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
672667
size_t srcSlicePitch, size_t dstRowPitch, size_t dstSlicePitch,
673668
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
674669
ur_event_handle_t *phEvent) {
675-
ur_result_t Result = UR_RESULT_SUCCESS;
676670
void *SrcPtr =
677671
std::get<BufferMem>(hBufferSrc->Mem).getVoid(hQueue->getDevice());
678672
void *DstPtr =
@@ -682,8 +676,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
682676
try {
683677
ScopedDevice Active(hQueue->getDevice());
684678
hipStream_t HIPStream = hQueue->getNextTransferStream();
685-
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
686-
phEventWaitList);
679+
UR_CHECK_ERROR(enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
680+
phEventWaitList));
687681

688682
if (phEvent) {
689683
RetImplEvent =
@@ -692,20 +686,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
692686
UR_CHECK_ERROR(RetImplEvent->start());
693687
}
694688

695-
Result = commonEnqueueMemBufferCopyRect(
689+
UR_CHECK_ERROR(commonEnqueueMemBufferCopyRect(
696690
HIPStream, region, &SrcPtr, hipMemoryTypeDevice, srcOrigin, srcRowPitch,
697691
srcSlicePitch, &DstPtr, hipMemoryTypeDevice, dstOrigin, dstRowPitch,
698-
dstSlicePitch);
692+
dstSlicePitch));
699693

700694
if (phEvent) {
701695
UR_CHECK_ERROR(RetImplEvent->record());
702696
*phEvent = RetImplEvent.release();
703697
}
704698

705699
} catch (ur_result_t Err) {
706-
Result = Err;
700+
return Err;
707701
}
708-
return Result;
702+
return UR_RESULT_SUCCESS;
709703
}
710704

711705
static inline void memsetRemainPattern(hipStream_t Stream, uint32_t PatternSize,
@@ -1063,14 +1057,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageCopy(
10631057
std::get<SurfaceMem>(hImageDst->Mem).getImageType(),
10641058
UR_RESULT_ERROR_INVALID_MEM_OBJECT);
10651059

1066-
ur_result_t Result = UR_RESULT_SUCCESS;
1067-
10681060
try {
10691061
ScopedDevice Active(hQueue->getDevice());
10701062
hipStream_t HIPStream = hQueue->getNextTransferStream();
10711063
if (phEventWaitList) {
1072-
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
1073-
phEventWaitList);
1064+
UR_CHECK_ERROR(enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
1065+
phEventWaitList));
10741066
}
10751067

10761068
hipArray *SrcArray =
@@ -1110,13 +1102,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageCopy(
11101102
UR_CHECK_ERROR(RetImplEvent->start());
11111103
}
11121104

1113-
Result = commonEnqueueMemImageNDCopy(
1105+
UR_CHECK_ERROR(commonEnqueueMemImageNDCopy(
11141106
HIPStream, ImgType, AdjustedRegion, SrcArray, hipMemoryTypeArray,
1115-
SrcOffset, DstArray, hipMemoryTypeArray, DstOffset);
1116-
1117-
if (Result != UR_RESULT_SUCCESS) {
1118-
return Result;
1119-
}
1107+
SrcOffset, DstArray, hipMemoryTypeArray, DstOffset));
11201108

11211109
if (phEvent) {
11221110
UR_CHECK_ERROR(RetImplEvent->record());
@@ -1237,7 +1225,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
12371225
ur_queue_handle_t hQueue, void *ptr, size_t patternSize,
12381226
const void *pPattern, size_t size, uint32_t numEventsInWaitList,
12391227
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
1240-
ur_result_t Result = UR_RESULT_SUCCESS;
12411228
std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr};
12421229

12431230
try {
@@ -1246,8 +1233,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
12461233
ur_stream_guard Guard;
12471234
hipStream_t HIPStream = hQueue->getNextComputeStream(
12481235
numEventsInWaitList, phEventWaitList, Guard, &StreamToken);
1249-
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
1250-
phEventWaitList);
1236+
UR_CHECK_ERROR(enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
1237+
phEventWaitList));
12511238
if (phEvent) {
12521239
EventPtr =
12531240
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
@@ -1274,8 +1261,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
12741261
break;
12751262

12761263
default:
1277-
Result = commonMemSetLargePattern(HIPStream, patternSize, size, pPattern,
1278-
reinterpret_cast<hipDeviceptr_t>(ptr));
1264+
UR_CHECK_ERROR(
1265+
commonMemSetLargePattern(HIPStream, patternSize, size, pPattern,
1266+
reinterpret_cast<hipDeviceptr_t>(ptr)));
12791267
break;
12801268
}
12811269

@@ -1284,25 +1272,23 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
12841272
*phEvent = EventPtr.release();
12851273
}
12861274
} catch (ur_result_t Err) {
1287-
Result = Err;
1275+
return Err;
12881276
}
12891277

1290-
return Result;
1278+
return UR_RESULT_SUCCESS;
12911279
}
12921280

12931281
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy(
12941282
ur_queue_handle_t hQueue, bool blocking, void *pDst, const void *pSrc,
12951283
size_t size, uint32_t numEventsInWaitList,
12961284
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
1297-
ur_result_t Result = UR_RESULT_SUCCESS;
1298-
12991285
std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr};
13001286

13011287
try {
13021288
ScopedDevice Active(hQueue->getDevice());
13031289
hipStream_t HIPStream = hQueue->getNextTransferStream();
1304-
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
1305-
phEventWaitList);
1290+
UR_CHECK_ERROR(enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
1291+
phEventWaitList));
13061292
if (phEvent) {
13071293
EventPtr =
13081294
std::unique_ptr<ur_event_handle_t_>(ur_event_handle_t_::makeNative(
@@ -1321,9 +1307,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy(
13211307
*phEvent = EventPtr.release();
13221308
}
13231309
} catch (ur_result_t Err) {
1324-
Result = Err;
1310+
return Err;
13251311
}
1326-
return Result;
1312+
return UR_RESULT_SUCCESS;
13271313
}
13281314

13291315
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
@@ -1345,13 +1331,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
13451331
UR_ASSERT(size <= PointerRangeSize, UR_RESULT_ERROR_INVALID_SIZE);
13461332
#endif
13471333

1348-
ur_result_t Result = UR_RESULT_SUCCESS;
1349-
13501334
try {
13511335
ScopedDevice Active(hQueue->getDevice());
13521336
hipStream_t HIPStream = hQueue->getNextTransferStream();
1353-
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
1354-
phEventWaitList);
1337+
UR_CHECK_ERROR(enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
1338+
phEventWaitList));
13551339

13561340
std::unique_ptr<ur_event_handle_t_> EventPtr{nullptr};
13571341

@@ -1399,10 +1383,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(
13991383
hipMemPrefetchAsync(pMem, size, hQueue->getDevice()->get(), HIPStream));
14001384
releaseEvent();
14011385
} catch (ur_result_t Err) {
1402-
Result = Err;
1386+
return Err;
14031387
}
14041388

1405-
return Result;
1389+
return UR_RESULT_SUCCESS;
14061390
}
14071391

14081392
/// USM: memadvise API to govern behavior of automatic migration mechanisms
@@ -1521,6 +1505,7 @@ urEnqueueUSMAdvise(ur_queue_handle_t hQueue, const void *pMem, size_t size,
15211505
UR_RESULT_SUCCESS);
15221506
return UR_RESULT_ERROR_ADAPTER_SPECIFIC;
15231507
}
1508+
UR_CHECK_ERROR(Result);
15241509
}
15251510

15261511
releaseEvent();
@@ -1558,13 +1543,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
15581543
const void *pSrc, size_t srcPitch, size_t width, size_t height,
15591544
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
15601545
ur_event_handle_t *phEvent) {
1561-
ur_result_t Result = UR_RESULT_SUCCESS;
1562-
15631546
try {
15641547
ScopedDevice Active(hQueue->getDevice());
15651548
hipStream_t HIPStream = hQueue->getNextTransferStream();
1566-
Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
1567-
phEventWaitList);
1549+
UR_CHECK_ERROR(enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList,
1550+
phEventWaitList));
15681551

15691552
std::unique_ptr<ur_event_handle_t_> RetImplEvent{nullptr};
15701553
if (phEvent) {
@@ -1668,10 +1651,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
16681651
UR_CHECK_ERROR(hipStreamSynchronize(HIPStream));
16691652
}
16701653
} catch (ur_result_t Err) {
1671-
Result = Err;
1654+
return Err;
16721655
}
16731656

1674-
return Result;
1657+
return UR_RESULT_SUCCESS;
16751658
}
16761659

16771660
namespace {

0 commit comments

Comments
 (0)