Skip to content

Commit 756aa10

Browse files
authored
HIP : Add HIP 7.0+ compatibility for hipBLAS compute types (#14634)
1 parent aaa088d commit 756aa10

File tree

1 file changed

+14
-5
lines changed
  • ggml/src/ggml-cuda/vendors

1 file changed

+14
-5
lines changed

ggml/src/ggml-cuda/vendors/hip.h

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,6 @@
1010
#include "rocblas/rocblas.h"
1111
#endif // __HIP_PLATFORM_AMD__
1212

13-
#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
14-
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
15-
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
1613
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
1714
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
1815
#define CUBLAS_OP_N HIPBLAS_OP_N
@@ -30,7 +27,6 @@
3027
#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
3128
#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
3229
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
33-
#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
3430
#define cublasCreate hipblasCreate
3531
#define cublasDestroy hipblasDestroy
3632
#define cublasGemmEx hipblasGemmEx
@@ -42,7 +38,6 @@
4238
#define cublasSgemm hipblasSgemm
4339
#define cublasStatus_t hipblasStatus_t
4440
#define cublasOperation_t hipblasOperation_t
45-
#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
4641
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
4742
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
4843
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
@@ -144,6 +139,20 @@
144139
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
145140
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
146141

142+
#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION >= 70000000
143+
#define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F
144+
#define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F
145+
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_COMPUTE_32F_FAST_16F
146+
#define cublasComputeType_t hipblasComputeType_t
147+
#define cudaDataType_t hipDataType
148+
#else
149+
#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
150+
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
151+
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
152+
#define cublasComputeType_t hipblasDatatype_t
153+
#define cudaDataType_t hipblasDatatype_t
154+
#endif
155+
147156
#define __CUDA_ARCH__ 1300
148157

149158
#if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)

0 commit comments

Comments
 (0)