Skip to content

Commit 79c28d0

Browse files
authored
Merge pull request #1186 from hdelan/device-global-hip
[HIP] Add support for global variable read write
2 parents 25e0b60 + 45d76b7 commit 79c28d0

File tree

3 files changed

+77
-6
lines changed

3 files changed

+77
-6
lines changed

source/adapters/hip/enqueue.cpp

Lines changed: 57 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1697,16 +1697,67 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy2D(
16971697
return Result;
16981698
}
16991699

1700+
namespace {
1701+
1702+
enum class GlobalVariableCopy { Read, Write };
1703+
1704+
ur_result_t deviceGlobalCopyHelper(
1705+
ur_queue_handle_t hQueue, ur_program_handle_t hProgram, const char *name,
1706+
bool blocking, size_t count, size_t offset, void *ptr,
1707+
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
1708+
ur_event_handle_t *phEvent, GlobalVariableCopy CopyType) {
1709+
// Since HIP requires a the global variable to be referenced by name, we use
1710+
// metadata to find the correct name to access it by.
1711+
auto DeviceGlobalNameIt = hProgram->GlobalIDMD.find(name);
1712+
if (DeviceGlobalNameIt == hProgram->GlobalIDMD.end())
1713+
return UR_RESULT_ERROR_INVALID_VALUE;
1714+
std::string DeviceGlobalName = DeviceGlobalNameIt->second;
1715+
1716+
try {
1717+
hipDeviceptr_t DeviceGlobal = 0;
1718+
size_t DeviceGlobalSize = 0;
1719+
UR_CHECK_ERROR(hipModuleGetGlobal(&DeviceGlobal, &DeviceGlobalSize,
1720+
hProgram->get(),
1721+
DeviceGlobalName.c_str()));
1722+
1723+
if (offset + count > DeviceGlobalSize)
1724+
return UR_RESULT_ERROR_INVALID_VALUE;
1725+
1726+
void *pSrc, *pDst;
1727+
if (CopyType == GlobalVariableCopy::Write) {
1728+
pSrc = ptr;
1729+
pDst = reinterpret_cast<uint8_t *>(DeviceGlobal) + offset;
1730+
} else {
1731+
pSrc = reinterpret_cast<uint8_t *>(DeviceGlobal) + offset;
1732+
pDst = ptr;
1733+
}
1734+
return urEnqueueUSMMemcpy(hQueue, blocking, pDst, pSrc, count,
1735+
numEventsInWaitList, phEventWaitList, phEvent);
1736+
} catch (ur_result_t Err) {
1737+
return Err;
1738+
}
1739+
}
1740+
} // namespace
1741+
17001742
/// Enqueues a write of `count` bytes from `pSrc` into the device global
/// identified by `name` in `hProgram`, starting `offset` bytes into the
/// global. All work is delegated to deviceGlobalCopyHelper with the
/// `Write` direction.
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableWrite(
    ur_queue_handle_t hQueue, ur_program_handle_t hProgram, const char *name,
    bool blockingWrite, size_t count, size_t offset, const void *pSrc,
    uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
    ur_event_handle_t *phEvent) {
  // The helper takes a single non-const pointer that serves both directions;
  // for a Write it is only used as the copy source.
  void *HostPtr = const_cast<void *>(pSrc);
  return deviceGlobalCopyHelper(hQueue, hProgram, name, blockingWrite, count,
                                offset, HostPtr, numEventsInWaitList,
                                phEventWaitList, phEvent,
                                GlobalVariableCopy::Write);
}
17051752

17061753
/// Enqueues a read of `count` bytes from the device global identified by
/// `name` in `hProgram`, starting `offset` bytes into the global, into
/// `pDst`. All work is delegated to deviceGlobalCopyHelper with the `Read`
/// direction.
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableRead(
    ur_queue_handle_t hQueue, ur_program_handle_t hProgram, const char *name,
    bool blockingRead, size_t count, size_t offset, void *pDst,
    uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
    ur_event_handle_t *phEvent) {
  constexpr GlobalVariableCopy Direction = GlobalVariableCopy::Read;
  return deviceGlobalCopyHelper(hQueue, hProgram, name, blockingRead, count,
                                offset, pDst, numEventsInWaitList,
                                phEventWaitList, phEvent, Direction);
}
17111762

17121763
UR_APIEXPORT ur_result_t UR_APICALL urEnqueueReadHostPipe(

source/adapters/hip/program.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,17 +78,35 @@ void getCoMgrBuildLog(const amd_comgr_data_set_t BuildDataSet, char *BuildLog,
7878
} // namespace
7979
#endif
8080

81+
/// Splits a metadata entry name at its last '@' character.
///
/// @param metadataName Full metadata name, e.g. "<prefix>@<tag>".
/// @return A pair {prefix, tag}: `prefix` is everything before the last '@',
///         and `tag` is the remainder of the string *including* that '@'.
///         When no '@' is present, returns {metadataName, ""}.
std::pair<std::string, std::string>
splitMetadataName(const std::string &metadataName) {
  const std::string::size_type atPos = metadataName.find_last_of('@');
  if (atPos != std::string::npos) {
    // The tag keeps its leading '@' so callers can compare it directly
    // against tag strings that start with '@'.
    return {metadataName.substr(0, atPos), metadataName.substr(atPos)};
  }
  return {metadataName, std::string{}};
}
89+
8190
ur_result_t
8291
ur_program_handle_t_::setMetadata(const ur_program_metadata_t *Metadata,
8392
size_t Length) {
8493
for (size_t i = 0; i < Length; ++i) {
8594
const ur_program_metadata_t MetadataElement = Metadata[i];
8695
std::string MetadataElementName{MetadataElement.pName};
8796

97+
auto [Prefix, Tag] = splitMetadataName(MetadataElementName);
98+
8899
if (MetadataElementName ==
89100
__SYCL_UR_PROGRAM_METADATA_TAG_NEED_FINALIZATION) {
90101
assert(MetadataElement.type == UR_PROGRAM_METADATA_TYPE_UINT32);
91102
IsRelocatable = MetadataElement.value.data32;
103+
} else if (Tag == __SYCL_UR_PROGRAM_METADATA_GLOBAL_ID_MAPPING) {
104+
const char *MetadataValPtr =
105+
reinterpret_cast<const char *>(MetadataElement.value.pData) +
106+
sizeof(std::uint64_t);
107+
const char *MetadataValPtrEnd =
108+
MetadataValPtr + MetadataElement.size - sizeof(std::uint64_t);
109+
GlobalIDMD[Prefix] = std::string{MetadataValPtr, MetadataValPtrEnd};
92110
}
93111
}
94112
return UR_RESULT_SUCCESS;

source/adapters/hip/program.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ struct ur_program_handle_t_ {
2929
// Metadata
3030
bool IsRelocatable = false;
3131

32+
std::unordered_map<std::string, std::string> GlobalIDMD;
33+
3234
constexpr static size_t MAX_LOG_SIZE = 8192u;
3335

3436
char ErrorLog[MAX_LOG_SIZE], InfoLog[MAX_LOG_SIZE];

0 commit comments

Comments
 (0)