|
10 | 10 | #include "rocblas/rocblas.h"
|
11 | 11 | #endif // __HIP_PLATFORM_AMD__
|
12 | 12 |
|
13 |
| -#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F |
14 |
| -#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F |
15 |
| -#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F |
16 | 13 | #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
|
17 | 14 | #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
|
18 | 15 | #define CUBLAS_OP_N HIPBLAS_OP_N
|
|
30 | 27 | #define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
|
31 | 28 | #define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
|
32 | 29 | #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
|
33 |
| -#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6 |
34 | 30 | #define cublasCreate hipblasCreate
|
35 | 31 | #define cublasDestroy hipblasDestroy
|
36 | 32 | #define cublasGemmEx hipblasGemmEx
|
|
42 | 38 | #define cublasSgemm hipblasSgemm
|
43 | 39 | #define cublasStatus_t hipblasStatus_t
|
44 | 40 | #define cublasOperation_t hipblasOperation_t
|
45 |
| -#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6 |
46 | 41 | #define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
|
47 | 42 | #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
|
48 | 43 | #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
|
|
144 | 139 | #define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
|
145 | 140 | #define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
|
146 | 141 |
|
| 142 | +#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION >= 70000000 |
| 143 | +#define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F |
| 144 | +#define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F |
| 145 | +#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_COMPUTE_32F_FAST_16F |
| 146 | +#define cublasComputeType_t hipblasComputeType_t |
| 147 | +#define cudaDataType_t hipDataType |
| 148 | +#else |
| 149 | +#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F |
| 150 | +#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F |
| 151 | +#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F |
| 152 | +#define cublasComputeType_t hipblasDatatype_t |
| 153 | +#define cudaDataType_t hipblasDatatype_t |
| 154 | +#endif |
| 155 | + |
147 | 156 | #define __CUDA_ARCH__ 1300
|
148 | 157 |
|
149 | 158 | #if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)
|
|
0 commit comments