Skip to content

Commit 45d76b7

Browse files
author
Hugh Delaney
committed
Refactor read write funcs
1 parent de02e99 commit 45d76b7

File tree

1 file changed

+34
-39
lines changed

1 file changed

+34
-39
lines changed

source/adapters/hip/enqueue.cpp

Lines changed: 34 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1544,19 +1544,22 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
15441544
return Result;
15451545
}
15461546

1547-
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableWrite(
1547+
namespace {
1548+
1549+
enum class GlobalVariableCopy { Read, Write };
1550+
1551+
ur_result_t deviceGlobalCopyHelper(
15481552
ur_queue_handle_t hQueue, ur_program_handle_t hProgram, const char *name,
1549-
bool blockingWrite, size_t count, size_t offset, const void *pSrc,
1553+
bool blocking, size_t count, size_t offset, void *ptr,
15501554
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
1551-
ur_event_handle_t *phEvent) {
1552-
// Since HIP requires the global variable to be referenced by name, we use
1555+
ur_event_handle_t *phEvent, GlobalVariableCopy CopyType) {
1556+
// Since HIP requires a the global variable to be referenced by name, we use
15531557
// metadata to find the correct name to access it by.
15541558
auto DeviceGlobalNameIt = hProgram->GlobalIDMD.find(name);
15551559
if (DeviceGlobalNameIt == hProgram->GlobalIDMD.end())
15561560
return UR_RESULT_ERROR_INVALID_VALUE;
15571561
std::string DeviceGlobalName = DeviceGlobalNameIt->second;
15581562

1559-
ur_result_t Result = UR_RESULT_SUCCESS;
15601563
try {
15611564
hipDeviceptr_t DeviceGlobal = 0;
15621565
size_t DeviceGlobalSize = 0;
@@ -1567,49 +1570,41 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableWrite(
15671570
if (offset + count > DeviceGlobalSize)
15681571
return UR_RESULT_ERROR_INVALID_VALUE;
15691572

1570-
return urEnqueueUSMMemcpy(
1571-
hQueue, blockingWrite,
1572-
reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(DeviceGlobal) +
1573-
offset),
1574-
pSrc, count, numEventsInWaitList, phEventWaitList, phEvent);
1573+
void *pSrc, *pDst;
1574+
if (CopyType == GlobalVariableCopy::Write) {
1575+
pSrc = ptr;
1576+
pDst = reinterpret_cast<uint8_t *>(DeviceGlobal) + offset;
1577+
} else {
1578+
pSrc = reinterpret_cast<uint8_t *>(DeviceGlobal) + offset;
1579+
pDst = ptr;
1580+
}
1581+
return urEnqueueUSMMemcpy(hQueue, blocking, pDst, pSrc, count,
1582+
numEventsInWaitList, phEventWaitList, phEvent);
15751583
} catch (ur_result_t Err) {
1576-
Result = Err;
1584+
return Err;
15771585
}
1578-
return Result;
1586+
}
1587+
} // namespace
1588+
1589+
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableWrite(
1590+
ur_queue_handle_t hQueue, ur_program_handle_t hProgram, const char *name,
1591+
bool blockingWrite, size_t count, size_t offset, const void *pSrc,
1592+
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
1593+
ur_event_handle_t *phEvent) {
1594+
return deviceGlobalCopyHelper(hQueue, hProgram, name, blockingWrite, count,
1595+
offset, const_cast<void *>(pSrc),
1596+
numEventsInWaitList, phEventWaitList, phEvent,
1597+
GlobalVariableCopy::Write);
15791598
}
15801599

15811600
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableRead(
15821601
ur_queue_handle_t hQueue, ur_program_handle_t hProgram, const char *name,
15831602
bool blockingRead, size_t count, size_t offset, void *pDst,
15841603
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
15851604
ur_event_handle_t *phEvent) {
1586-
// Since HIP requires the global variable to be referenced by name, we use
1587-
// metadata to find the correct name to access it by.
1588-
auto DeviceGlobalNameIt = hProgram->GlobalIDMD.find(name);
1589-
if (DeviceGlobalNameIt == hProgram->GlobalIDMD.end())
1590-
return UR_RESULT_ERROR_INVALID_VALUE;
1591-
std::string DeviceGlobalName = DeviceGlobalNameIt->second;
1592-
1593-
ur_result_t Result = UR_RESULT_SUCCESS;
1594-
try {
1595-
hipDeviceptr_t DeviceGlobal = 0;
1596-
size_t DeviceGlobalSize = 0;
1597-
UR_CHECK_ERROR(hipModuleGetGlobal(&DeviceGlobal, &DeviceGlobalSize,
1598-
hProgram->get(),
1599-
DeviceGlobalName.c_str()));
1600-
1601-
if (offset + count > DeviceGlobalSize)
1602-
return UR_RESULT_ERROR_INVALID_VALUE;
1603-
1604-
return urEnqueueUSMMemcpy(
1605-
hQueue, blockingRead, pDst,
1606-
reinterpret_cast<const void *>(
1607-
reinterpret_cast<uint8_t *>(DeviceGlobal) + offset),
1608-
count, numEventsInWaitList, phEventWaitList, phEvent);
1609-
} catch (ur_result_t Err) {
1610-
Result = Err;
1611-
}
1612-
return Result;
1605+
return deviceGlobalCopyHelper(
1606+
hQueue, hProgram, name, blockingRead, count, offset, pDst,
1607+
numEventsInWaitList, phEventWaitList, phEvent, GlobalVariableCopy::Read);
16131608
}
16141609

16151610
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueReadHostPipe(

0 commit comments

Comments
 (0)