@@ -1544,19 +1544,22 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
1544
1544
return Result;
1545
1545
}
1546
1546
1547
- UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableWrite (
1547
+ namespace {
1548
+
1549
+ enum class GlobalVariableCopy { Read, Write };
1550
+
1551
+ ur_result_t deviceGlobalCopyHelper (
1548
1552
ur_queue_handle_t hQueue, ur_program_handle_t hProgram, const char *name,
1549
- bool blockingWrite , size_t count, size_t offset, const void *pSrc ,
1553
+ bool blocking , size_t count, size_t offset, void *ptr ,
1550
1554
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
1551
- ur_event_handle_t *phEvent) {
1552
- // Since HIP requires the global variable to be referenced by name, we use
1555
+ ur_event_handle_t *phEvent, GlobalVariableCopy CopyType ) {
1556
+ // Since HIP requires a the global variable to be referenced by name, we use
1553
1557
// metadata to find the correct name to access it by.
1554
1558
auto DeviceGlobalNameIt = hProgram->GlobalIDMD .find (name);
1555
1559
if (DeviceGlobalNameIt == hProgram->GlobalIDMD .end ())
1556
1560
return UR_RESULT_ERROR_INVALID_VALUE;
1557
1561
std::string DeviceGlobalName = DeviceGlobalNameIt->second ;
1558
1562
1559
- ur_result_t Result = UR_RESULT_SUCCESS;
1560
1563
try {
1561
1564
hipDeviceptr_t DeviceGlobal = 0 ;
1562
1565
size_t DeviceGlobalSize = 0 ;
@@ -1567,49 +1570,41 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableWrite(
1567
1570
if (offset + count > DeviceGlobalSize)
1568
1571
return UR_RESULT_ERROR_INVALID_VALUE;
1569
1572
1570
- return urEnqueueUSMMemcpy (
1571
- hQueue, blockingWrite,
1572
- reinterpret_cast <void *>(reinterpret_cast <uint8_t *>(DeviceGlobal) +
1573
- offset),
1574
- pSrc, count, numEventsInWaitList, phEventWaitList, phEvent);
1573
+ void *pSrc, *pDst;
1574
+ if (CopyType == GlobalVariableCopy::Write) {
1575
+ pSrc = ptr;
1576
+ pDst = reinterpret_cast <uint8_t *>(DeviceGlobal) + offset;
1577
+ } else {
1578
+ pSrc = reinterpret_cast <uint8_t *>(DeviceGlobal) + offset;
1579
+ pDst = ptr;
1580
+ }
1581
+ return urEnqueueUSMMemcpy (hQueue, blocking, pDst, pSrc, count,
1582
+ numEventsInWaitList, phEventWaitList, phEvent);
1575
1583
} catch (ur_result_t Err) {
1576
- Result = Err;
1584
+ return Err;
1577
1585
}
1578
- return Result;
1586
+ }
1587
+ } // namespace
1588
+
1589
+ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableWrite (
1590
+ ur_queue_handle_t hQueue, ur_program_handle_t hProgram, const char *name,
1591
+ bool blockingWrite, size_t count, size_t offset, const void *pSrc,
1592
+ uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
1593
+ ur_event_handle_t *phEvent) {
1594
+ return deviceGlobalCopyHelper (hQueue, hProgram, name, blockingWrite, count,
1595
+ offset, const_cast <void *>(pSrc),
1596
+ numEventsInWaitList, phEventWaitList, phEvent,
1597
+ GlobalVariableCopy::Write);
1579
1598
}
1580
1599
1581
1600
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableRead (
1582
1601
ur_queue_handle_t hQueue, ur_program_handle_t hProgram, const char *name,
1583
1602
bool blockingRead, size_t count, size_t offset, void *pDst,
1584
1603
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
1585
1604
ur_event_handle_t *phEvent) {
1586
- // Since HIP requires the global variable to be referenced by name, we use
1587
- // metadata to find the correct name to access it by.
1588
- auto DeviceGlobalNameIt = hProgram->GlobalIDMD .find (name);
1589
- if (DeviceGlobalNameIt == hProgram->GlobalIDMD .end ())
1590
- return UR_RESULT_ERROR_INVALID_VALUE;
1591
- std::string DeviceGlobalName = DeviceGlobalNameIt->second ;
1592
-
1593
- ur_result_t Result = UR_RESULT_SUCCESS;
1594
- try {
1595
- hipDeviceptr_t DeviceGlobal = 0 ;
1596
- size_t DeviceGlobalSize = 0 ;
1597
- UR_CHECK_ERROR (hipModuleGetGlobal (&DeviceGlobal, &DeviceGlobalSize,
1598
- hProgram->get (),
1599
- DeviceGlobalName.c_str ()));
1600
-
1601
- if (offset + count > DeviceGlobalSize)
1602
- return UR_RESULT_ERROR_INVALID_VALUE;
1603
-
1604
- return urEnqueueUSMMemcpy (
1605
- hQueue, blockingRead, pDst,
1606
- reinterpret_cast <const void *>(
1607
- reinterpret_cast <uint8_t *>(DeviceGlobal) + offset),
1608
- count, numEventsInWaitList, phEventWaitList, phEvent);
1609
- } catch (ur_result_t Err) {
1610
- Result = Err;
1611
- }
1612
- return Result;
1605
+ return deviceGlobalCopyHelper (
1606
+ hQueue, hProgram, name, blockingRead, count, offset, pDst,
1607
+ numEventsInWaitList, phEventWaitList, phEvent, GlobalVariableCopy::Read);
1613
1608
}
1614
1609
1615
1610
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueReadHostPipe (
0 commit comments