Skip to content

Commit 0942e74

Browse files
committed
Fix getArrayDesc on ROCm 6 and make it return error codes
1 parent 8d1486a commit 0942e74

File tree

1 file changed

+25
-10
lines changed

1 file changed

+25
-10
lines changed

source/adapters/hip/common.hpp

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,39 @@
1515
#include <hip/hip_runtime.h>
1616
#include <ur/ur.hpp>
1717

18-
// Hipify doesn't support cuArrayGetDescriptor, on AMD the hipArray can just be
19-
// indexed, but on NVidia it is an opaque type and needs to go through
20-
// cuArrayGetDescriptor so implement a utility function to get the array
21-
// properties
22-
inline void getArrayDesc(hipArray *Array, hipArray_Format &Format,
23-
size_t &Channels) {
18+
// Before ROCm 6, hipify doesn't support cuArrayGetDescriptor, on AMD the
19+
// hipArray can just be indexed, but on NVidia it is an opaque type and needs to
20+
// go through cuArrayGetDescriptor so implement a utility function to get the
21+
// array properties
22+
inline static hipError_t getArrayDesc(hipArray *Array, hipArray_Format &Format,
23+
size_t &Channels) {
24+
#if HIP_VERSION_MAJOR >= 6
25+
HIP_ARRAY_DESCRIPTOR ArrayDesc;
26+
hipError_t err = hipArrayGetDescriptor(&ArrayDesc, Array);
27+
if (err == hipSuccess) {
28+
Format = ArrayDesc.Format;
29+
Channels = ArrayDesc.NumChannels;
30+
}
31+
return err;
32+
#else
2433
#if defined(__HIP_PLATFORM_AMD__)
2534
Format = Array->Format;
2635
Channels = Array->NumChannels;
36+
return hipSuccess;
2737
#elif defined(__HIP_PLATFORM_NVIDIA__)
2838
CUDA_ARRAY_DESCRIPTOR ArrayDesc;
29-
cuArrayGetDescriptor(&ArrayDesc, (CUarray)Array);
30-
31-
Format = ArrayDesc.Format;
32-
Channels = ArrayDesc.NumChannels;
39+
CUresult err = cuArrayGetDescriptor(&ArrayDesc, (CUarray)Array);
40+
if (err == CUDA_SUCCESS) {
41+
Format = ArrayDesc.Format;
42+
Channels = ArrayDesc.NumChannels;
43+
return hipSuccess;
44+
} else {
45+
return hipErrorUnknown; // No easy way to map CUerror to hipError
46+
}
3347
#else
3448
#error("Must define exactly one of __HIP_PLATFORM_AMD__ or __HIP_PLATFORM_NVIDIA__");
3549
#endif
50+
#endif
3651
}
3752

3853
// HIP on NVIDIA headers guard hipArray3DCreate behind __CUDACC__, this does not

0 commit comments

Comments
 (0)