
Commit 3465e6a

IMbackK authored and qnixsynapse committed
HIP: enable vec fattn on RDNA4 (ggml-org#14323)
1 parent 9a70c5d · commit 3465e6a

2 files changed: 14 additions & 26 deletions

ggml/src/ggml-cuda/common.cuh

Lines changed: 12 additions & 21 deletions
@@ -76,9 +76,11 @@
 #define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
 
 // Moore Threads
-#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
-#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
-#define GGML_CUDA_CC_NG (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD
+#define GGML_CUDA_MUSA_ARCH_IS_QY1 (__MUSA_ARCH__ <= 210)
+
+#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
+#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
+#define GGML_CUDA_CC_NG (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD
 
 #define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD)
 #define GGML_CUDA_CC_IS_QY1(cc) (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2)
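
The new macro gives the raw `__MUSA_ARCH__` version test a name that the `FLASH_ATTN_AVAILABLE` guard further down can reuse. A minimal sketch of how it evaluates, assuming the MUSA toolchain reports 210 on QY1 and 220 on QY2 (consistent with the CC comments above):

```cpp
// Sketch, not part of the diff: per-target expansion of the new macro.
#define GGML_CUDA_MUSA_ARCH_IS_QY1 (__MUSA_ARCH__ <= 210)

// QY1 (MTT S80/S3000): __MUSA_ARCH__ == 210 -> (210 <= 210) -> 1
// QY2 (MTT S4000):     __MUSA_ARCH__ == 220 -> (220 <= 210) -> 0
```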
@@ -215,9 +217,9 @@ typedef float2 dfloat2;
 #define FAST_FP16_AVAILABLE
 #endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
 
-#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 #define FP16_MMA_AVAILABLE
-#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 
 #if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
 #define FP16_MMA_AVAILABLE
@@ -231,9 +233,9 @@ typedef float2 dfloat2;
 #define CP_ASYNC_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 
-#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)
+#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1)
 #define FLASH_ATTN_AVAILABLE
-#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)
+#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1)
 
 static bool fp16_available(const int cc) {
     return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
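
Device code then checks the resulting `FLASH_ATTN_AVAILABLE` flag instead of repeating the arch test. A hedged sketch of the usual kernel-side pattern (the kernel name and signature are invented; `NO_DEVICE_CODE` is ggml-cuda's existing stub for targets a kernel was compiled out on):

```cpp
// Illustrative kernel skeleton, not code from this commit.
static __global__ void flash_attn_sketch(const float * Q, const float * K,
                                         const float * V, float * dst) {
#ifdef FLASH_ATTN_AVAILABLE
    // ... actual flash-attention math for supported targets ...
#else
    NO_DEVICE_CODE; // unsupported targets get a trapping stub body
#endif // FLASH_ATTN_AVAILABLE
}
```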
@@ -245,8 +247,7 @@ static bool fast_fp16_available(const int cc) {
 
 // To be used for feature selection of external libraries, e.g. cuBLAS.
 static bool fast_fp16_hardware_available(const int cc) {
-    return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) || GGML_CUDA_CC_IS_AMD(cc) ||
-           (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
+    return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
 }
 
 // Any FP16 tensor core instructions are available for ggml code.
@@ -255,8 +256,7 @@ static bool fp16_mma_available(const int cc) {
     return false;
 #else
     if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
-        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) ||
-        GGML_CUDA_CC_IS_MTHREADS(cc)) {
+        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) {
         return true;
     } else if (GGML_CUDA_CC_IS_RDNA4(cc)) {
 #if defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12)
@@ -273,16 +273,7 @@ static bool fp16_mma_available(const int cc) {
 // To be used for feature selection of external libraries, e.g. cuBLAS.
 static bool fp16_mma_hardware_available(const int cc) {
     return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) ||
-           GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc) ||
-           (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
-}
-
-static bool bf16_mma_hardware_available(const int cc) {
-    return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE) || GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3;
-}
-
-static bool fp32_mma_hardware_available(const int cc) {
-    return GGML_CUDA_CC_IS_CDNA(cc);
+           GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
 }
 
 // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
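
These helpers are host-side feature queries consulted when picking kernels. A condensed sketch of how a dispatcher might branch on them (`choose_fattn_kernel()` and the returned labels are invented for illustration; `fp16_mma_available()` is the helper trimmed above):

```cpp
// Illustrative host-side selection, not code from this commit.
static const char * choose_fattn_kernel(const int cc) {
    if (fp16_mma_available(cc)) {
        return "fattn-mma"; // tensor-core path: Volta+, CDNA, RDNA3, gated RDNA4
    }
    return "fattn-vec";     // vector path, e.g. RDNA4 built without rocWMMA-GFX12
}
```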

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 2 additions & 5 deletions
@@ -100,8 +100,7 @@ int ggml_cuda_get_device() {
 static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
     ggml_cuda_set_device(device);
     cudaError_t err;
-    if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr)
-    {
+    if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr) {
         err = cudaMallocManaged(ptr, size);
 #if defined(GGML_USE_HIP)
         if (err == hipSuccess) {
@@ -119,9 +118,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
         err = cudaMalloc(ptr, size);
     }
 #endif // defined(GGML_USE_HIP)
-    }
-    else
-    {
+    } else {
         err = cudaMalloc(ptr, size);
     }
     return err;
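
The brace-style cleanup leaves the behavior unchanged: one environment variable switches allocations to managed memory. A self-contained sketch of the resulting control flow, with device selection and the HIP-only fixup branch left out:

```cpp
#include <cstdlib>        // getenv
#include <cuda_runtime.h> // cudaError_t, cudaMalloc, cudaMallocManaged

// Condensed sketch of ggml_cuda_device_malloc() after this change,
// not a drop-in replacement for the real function.
static cudaError_t device_malloc_sketch(void ** ptr, size_t size) {
    if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr) {
        return cudaMallocManaged(ptr, size); // unified (managed) memory
    }
    return cudaMalloc(ptr, size);            // default: plain device allocation
}
```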
