HIP: Enable Matrix cores for MMQ Kernels, Enable stream-K for CDNA 3 #14624

Open · wants to merge 17 commits into master
4 changes: 2 additions & 2 deletions .devops/rocm.Dockerfile
@@ -1,8 +1,8 @@
ARG UBUNTU_VERSION=24.04

# This needs to generally match the container host's environment.
ARG ROCM_VERSION=6.3
ARG AMDGPU_VERSION=6.3
ARG ROCM_VERSION=6.4
ARG AMDGPU_VERSION=6.4

# Target the CUDA build image
ARG BASE_ROCM_DEV_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-complete
16 changes: 13 additions & 3 deletions ggml/src/ggml-cuda/common.cuh
@@ -56,7 +56,7 @@
#define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16
#define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 0x900) // Vega56/64, minimum for fp16 dual issue
#define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 0x906) // MI50/Radeon VII, minimum for dp4a
#define GGML_CUDA_CC_CDNA (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers
#define GGML_CUDA_CC_CDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers
#define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x910) // MI210, minimum acc register renameing
#define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x942) // MI300

@@ -72,8 +72,9 @@
#define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA4)
#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)
#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA1)
#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_RDNA1)
#define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1)

// Moore Threads
#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
@@ -226,6 +227,10 @@ typedef float2 dfloat2;
#define FP16_MMA_AVAILABLE
#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))

#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && defined(CDNA3)
#define AMD_MMA_AVAILABLE
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && defined(CDNA3)

#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
#define NEW_MMA_AVAILABLE
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
@@ -288,6 +293,11 @@ static bool fp32_mma_hardware_available(const int cc) {
return GGML_CUDA_CC_IS_CDNA(cc);
}

// AMD CDNA3 matrix cores.. Will add support for other CDNA generations later.
static bool amd_mma_available(const int cc) {
return cc >= GGML_CUDA_CC_OFFSET_AMD && GGML_CUDA_CC_IS_CDNA3(cc);
}

// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
static bool new_mma_available(const int cc) {
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
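Note on the new capability gating: GGML_CUDA_CC_IS_CDNA3() and amd_mma_available() restrict the MFMA-based MMQ path to gfx942-class devices (MI300) for now. The standalone sketch below shows how those checks compose. It is illustration only, not part of the diff: CC_OFFSET_AMD and CC_RDNA1 are stand-ins for GGML_CUDA_CC_OFFSET_AMD and GGML_CUDA_CC_RDNA1, which are defined elsewhere in common.cuh and not shown in this hunk; only the CDNA device ids (0x908, 0x910, 0x942) are taken from the diff itself.

#include <cassert>

// Assumed stand-in values; only the CDNA ids below come from the hunk above.
constexpr int CC_OFFSET_AMD = 0x1000000;
constexpr int CC_CDNA1      = CC_OFFSET_AMD + 0x908;  // MI100
constexpr int CC_CDNA2      = CC_OFFSET_AMD + 0x910;  // MI210
constexpr int CC_CDNA3      = CC_OFFSET_AMD + 0x942;  // MI300
constexpr int CC_RDNA1      = CC_OFFSET_AMD + 0x1010; // assumed upper bound of the CDNA range

// Mirrors GGML_CUDA_CC_IS_CDNA(cc) and GGML_CUDA_CC_IS_CDNA3(cc) from the hunk above.
constexpr bool is_cdna (int cc) { return cc >= CC_CDNA1 && cc < CC_RDNA1; }
constexpr bool is_cdna3(int cc) { return cc >= CC_CDNA3 && cc < CC_RDNA1; }

// Mirrors amd_mma_available(): only CDNA3 enables the MFMA MMQ path for now.
constexpr bool amd_mma_available_sketch(int cc) { return cc >= CC_OFFSET_AMD && is_cdna3(cc); }

int main() {
    static_assert(is_cdna(CC_CDNA1) && is_cdna(CC_CDNA2) && is_cdna(CC_CDNA3), "all CDNA generations are CDNA");
    static_assert(!is_cdna3(CC_CDNA2) && is_cdna3(CC_CDNA3), "only MI300-class devices are CDNA3");
    assert(amd_mma_available_sketch(CC_CDNA3) && !amd_mma_available_sketch(CC_CDNA1));
    return 0;
}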
99 changes: 97 additions & 2 deletions ggml/src/ggml-cuda/mma.cuh
@@ -66,7 +66,40 @@ namespace ggml_cuda_mma {
struct tile {
static constexpr int I = I_;
static constexpr int J = J_;
static constexpr int ne = I * J / WARP_SIZE;

#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
static constexpr int ne = I * J / 64;
T x[ne] = {0};

static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 16 && J == 8) {
return threadIdx.x % 16;
} else if constexpr (I == 32 && J == 4) {
return threadIdx.x % 32;
} else if constexpr (I == 16 && J == 16) {
return 4 * (threadIdx.x / 16) + l;
} else if constexpr (I == 32 && J == 32) {
return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}

static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 8) {
return 2 * (threadIdx.x / 16) + l;
} else if constexpr (I == 32 && J == 4) {
return 2 * (threadIdx.x / 32) + l;
} else if constexpr (I == 16 && J == 16) {
return threadIdx.x % 16;
} else if constexpr (I == 32 && J == 32) {
return threadIdx.x % 32;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}
#else
static constexpr int ne = I * J / 32;
T x[ne] = {0};

static __device__ __forceinline__ int get_i(const int l) {
@@ -94,6 +127,7 @@ namespace ggml_cuda_mma {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}
#endif
};

template <int I_, int J_>
@@ -186,7 +220,11 @@ namespace ggml_cuda_mma {
template <typename T>
static __device__ __forceinline__ void load_ldmatrix(
tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
#ifdef NEW_MMA_AVAILABLE
#if defined(AMD_MMA_AVAILABLE)
int64_t* xi = (int64_t*) t.x;
const int64_t* xs = (int64_t*) ((const int*) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
xi[0] = xs[0];
#elif defined(NEW_MMA_AVAILABLE)
int * xi = (int * ) t.x;
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
@@ -197,6 +235,23 @@ namespace ggml_cuda_mma {
#endif // NEW_MMA_AVAILABLE
}

template <typename T>
static __device__ __forceinline__ void load_ldmatrix(
tile<32, 4, T> & t, const T * __restrict__ xs0, const int stride) {
#if defined(AMD_MMA_AVAILABLE)
int64_t* xi = (int64_t*) t.x;
const int64_t* xs = (int64_t*) ((const int*) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
xi[0] = xs[0];
#elif defined(NEW_MMA_AVAILABLE)
GGML_UNUSED(t);
GGML_UNUSED(xs0);
GGML_UNUSED(stride);
NO_DEVICE_CODE;
#else
load_generic(t, xs0, stride);
#endif // AMD_MMA_AVAILABLE
}

template <typename T>
static __device__ __forceinline__ void load_ldmatrix_trans(
tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
@@ -386,6 +441,46 @@ namespace ggml_cuda_mma {
: "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
GGML_UNUSED(D);
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}

static __device__ __forceinline__ void mma(
tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) {
#if defined(AMD_MMA_AVAILABLE)
#if defined(CDNA3)
using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
int32x4_t* acc = (int32x4_t*) D.x;
acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t*) A.x)[0],
((int64_t*) B.x)[0],
acc[0],
0, 0, 0);
#elif defined(CDNA2) || defined(CDNA)
#endif
#else
GGML_UNUSED(D);
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}

static __device__ __forceinline__ void mma(
tile<32, 32, int> & D, const tile<32, 4, int> & A, const tile<32, 4, int> & B) {
#if defined(AMD_MMA_AVAILABLE)
#if defined(CDNA3)
using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int;
int32x16_t* acc = (int32x16_t*) D.x;
acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t*) A.x)[0],
((int64_t*) B.x)[0],
acc[0],
0, 0, 0);
#elif defined(CDNA2) || defined(CDNA)
#endif
#else
GGML_UNUSED(D);
GGML_UNUSED(A);
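Background for the HIP branch of the tile struct above: CDNA wavefronts are 64 lanes wide, so a 16x16 tile holds ne = 16*16/64 = 4 elements per lane, and get_i()/get_j() place those four elements in four consecutive rows of one column; load_ldmatrix() correspondingly copies one 8-byte chunk (two packed ints) per lane, and the mma() overloads feed the tiles to the CDNA3 MFMA builtins used in the diff (__builtin_amdgcn_mfma_i32_16x16x32_i8 and __builtin_amdgcn_mfma_i32_32x32x16_i8). The plain host C++ sketch below is illustration only: it replays the I == 16 && J == 16 index formulas from the diff with threadIdx.x replaced by a loop variable and checks that every tile element is owned by exactly one (lane, element) pair.

#include <cstdio>

int main() {
    const int I = 16, J = 16, wavefront = 64;
    const int ne = I * J / wavefront;            // 4 elements per lane, as in the HIP branch above
    int owners[16][16] = {};
    for (int lane = 0; lane < wavefront; ++lane) {
        for (int l = 0; l < ne; ++l) {
            const int i = 4 * (lane / 16) + l;   // get_i() for the 16x16 tile
            const int j = lane % 16;             // get_j() for the 16x16 tile
            owners[i][j]++;
        }
    }
    for (int i = 0; i < I; ++i) {
        for (int j = 0; j < J; ++j) {
            if (owners[i][j] != 1) {
                printf("hole or overlap at (%d, %d)\n", i, j);
                return 1;
            }
        }
    }
    printf("16x16 tile fully covered, %d elements per lane\n", ne);
    return 0;
}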
10 changes: 6 additions & 4 deletions ggml/src/ggml-cuda/mmq.cu
@@ -109,7 +109,8 @@ void ggml_cuda_mul_mat_q(
const int64_t s03 = src0->nb[3] / ts_src0;
const int64_t s3 = dst->nb[3] / ts_dst;

const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA;
const bool use_stream_k = ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
|| (GGML_CUDA_CC_IS_AMD(cc) && GGML_CUDA_CC_IS_CDNA3(cc)));

if (!ids) {
const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 +
@@ -250,8 +251,9 @@ void ggml_cuda_op_mul_mat_q(
// The stream-k decomposition is only faster for recent NVIDIA GPUs.
// Also its fixup needs to allocate a temporary buffer in the memory pool.
// There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) &&
ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && src1_ncols == ne11;
const bool use_stream_k = ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
|| (GGML_CUDA_CC_IS_AMD(cc) && GGML_CUDA_CC_IS_CDNA3(cc)))
&& src1_ncols == ne11;
const mmq_args args = {
src0_dd_i, src0->type, (const int *) src1_ddq_i, nullptr, nullptr, dst_dd_i,
ne00, row_diff, src1_ncols, stride01, ne11, nrows_dst,
@@ -304,7 +306,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
return false;
}

if (new_mma_available(cc)) {
if (new_mma_available(cc) || amd_mma_available(cc)) {
return true;
}

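For context on the widened use_stream_k condition above: with the stream-K decomposition, the total number of K-iterations over all output tiles is divided evenly among the thread blocks instead of assigning whole tiles to blocks, and tiles whose iterations end up split across blocks are completed by the fixup pass that needs the temporary buffer mentioned in the comment above. The host-side sketch below only illustrates that partitioning idea with made-up sizes; the actual decomposition and fixup live elsewhere (mmq.cuh) and are not shown in this excerpt.

#include <cstdint>
#include <cstdio>

int main() {
    // Made-up sizes, chosen only to make the split visible.
    const int64_t n_tiles = 10, iters_per_tile = 8, n_blocks = 4;
    const int64_t total = n_tiles * iters_per_tile;
    for (int64_t b = 0; b < n_blocks; ++b) {
        const int64_t it0 = (b    ) * total / n_blocks;  // first global K-iteration of this block
        const int64_t it1 = (b + 1) * total / n_blocks;  // one past its last K-iteration
        const int64_t tile_first = it0 / iters_per_tile;
        const int64_t tile_last  = (it1 - 1) / iters_per_tile;
        const bool    partial    = (it0 % iters_per_tile) != 0 || (it1 % iters_per_tile) != 0;
        printf("block %lld: iterations [%lld, %lld) -> tiles %lld..%lld%s\n",
               (long long) b, (long long) it0, (long long) it1,
               (long long) tile_first, (long long) tile_last,
               partial ? " (partial tiles, fixup needed)" : "");
    }
    return 0;
}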