Skip to content

Commit 5313f82

Browse files
IMbackK authored and qnixsynapse committed
HIP: enable vec fattn on RDNA4 (ggml-org#14323)
1 parent 8644697 commit 5313f82

File tree

2 files changed

+14
-26
lines changed

2 files changed

+14
-26
lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,11 @@
7676
#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
7777

7878
// Moore Threads
79-
#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
80-
#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
81-
#define GGML_CUDA_CC_NG (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD
79+
#define GGML_CUDA_MUSA_ARCH_IS_QY1 (__MUSA_ARCH__ <= 210)
80+
81+
#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
82+
#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
83+
#define GGML_CUDA_CC_NG (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD
8284

8385
#define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD)
8486
#define GGML_CUDA_CC_IS_QY1(cc) (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2)
@@ -201,9 +203,9 @@ typedef float2 dfloat2;
201203
#define FAST_FP16_AVAILABLE
202204
#endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
203205

204-
#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
206+
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
205207
#define FP16_MMA_AVAILABLE
206-
#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
208+
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
207209

208210
#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
209211
#define FP16_MMA_AVAILABLE
@@ -217,9 +219,9 @@ typedef float2 dfloat2;
217219
#define CP_ASYNC_AVAILABLE
218220
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
219221

220-
#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)
222+
#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1)
221223
#define FLASH_ATTN_AVAILABLE
222-
#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)
224+
#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1)
223225

224226
static bool fp16_available(const int cc) {
225227
return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
@@ -231,8 +233,7 @@ static bool fast_fp16_available(const int cc) {
231233

232234
// To be used for feature selection of external libraries, e.g. cuBLAS.
233235
static bool fast_fp16_hardware_available(const int cc) {
234-
return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) || GGML_CUDA_CC_IS_AMD(cc) ||
235-
(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
236+
return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
236237
}
237238

238239
// Any FP16 tensor core instructions are available for ggml code.
@@ -241,8 +242,7 @@ static bool fp16_mma_available(const int cc) {
241242
return false;
242243
#else
243244
if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
244-
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) ||
245-
GGML_CUDA_CC_IS_MTHREADS(cc)) {
245+
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) {
246246
return true;
247247
} else if (GGML_CUDA_CC_IS_RDNA4(cc)) {
248248
#if defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
@@ -259,16 +259,7 @@ static bool fp16_mma_available(const int cc) {
259259
// To be used for feature selection of external libraries, e.g. cuBLAS.
260260
static bool fp16_mma_hardware_available(const int cc) {
261261
return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) ||
262-
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc) ||
263-
(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
264-
}
265-
266-
static bool bf16_mma_hardware_available(const int cc) {
267-
return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE) || GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3;
268-
}
269-
270-
static bool fp32_mma_hardware_available(const int cc) {
271-
return GGML_CUDA_CC_IS_CDNA(cc);
262+
GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
272263
}
273264

274265
// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,7 @@ int ggml_cuda_get_device() {
100100
static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
101101
ggml_cuda_set_device(device);
102102
cudaError_t err;
103-
if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr)
104-
{
103+
if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr) {
105104
err = cudaMallocManaged(ptr, size);
106105
#if defined(GGML_USE_HIP)
107106
if (err == hipSuccess) {
@@ -119,9 +118,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
119118
err = cudaMalloc(ptr, size);
120119
}
121120
#endif // defined(GGML_USE_HIP)
122-
}
123-
else
124-
{
121+
} else {
125122
err = cudaMalloc(ptr, size);
126123
}
127124
return err;

0 commit comments

Comments (0)