
Commit 3d69396

Reapply "CUDA: FA support for Deepseek (Ampere or newer) (ggml-org#13306)"
Update fattn-mma-f16.cuh
1 parent b9559a0 · commit 3d69396

32 files changed (+855, -522 lines)

ggml/src/ggml-cuda/common.cuh

Lines changed: 19 additions & 0 deletions
@@ -326,6 +326,25 @@ static __device__ void no_device_code(
 #define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.")
 #endif // __CUDA_ARCH__
 
+// The compiler is not always able to unroll loops if they contain continue expressions.
+// In such cases loop unrolling can still be achieved via recursion:
+template <int n>
+struct ggml_cuda_unroll {
+    template <typename Func, typename... Args>
+    __device__ void operator()(const Func & f, Args... args) const {
+        f(n - 1, args...);
+        ggml_cuda_unroll<n - 1>{}(f, args...);
+    }
+};
+
+template <>
+struct ggml_cuda_unroll<1> {
+    template <typename Func, typename... Args>
+    __device__ void operator()(const Func & f, Args... args) const {
+        f(0, args...);
+    }
+};
+
 template<int width = WARP_SIZE>
 static __device__ __forceinline__ int warp_reduce_sum(int x) {
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
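For reference, the new helper replaces #pragma unroll with compile-time recursion: ggml_cuda_unroll<n>{}(f, args...) invokes f with the indices n-1 down to 0, so a loop body that needs continue can be written as a functor whose early return plays the same role. A minimal usage sketch, with an invented device function and names that are not part of the commit:

    // Illustrative only: unroll a fixed-trip-count loop whose body needs `continue`.
    // scale_nonzero and nbatch are made-up names for this sketch.
    template <int nbatch>
    static __device__ void scale_nonzero(float * x, const float scale) {
        auto body = [&](const int i) {
            if (x[i] == 0.0f) {
                return; // plays the role of `continue` in a plain loop
            }
            x[i] *= scale;
        };
        ggml_cuda_unroll<nbatch>{}(body); // expands to body(nbatch-1), ..., body(0)
    }

Each call site then gets a fully unrolled sequence of body(i) calls, which #pragma unroll cannot guarantee when the loop body contains continue.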

ggml/src/ggml-cuda/cp-async.cuh

Lines changed: 11 additions & 0 deletions
@@ -2,6 +2,17 @@
 
 #include "common.cuh"
 
+
+static __device__ __forceinline__ unsigned int ggml_cuda_cvta_generic_to_shared(void * generic_ptr) {
+#ifdef CP_ASYNC_AVAILABLE
+    return __cvta_generic_to_shared(generic_ptr);
+#else
+    GGML_UNUSED(generic_ptr);
+    NO_DEVICE_CODE;
+    return 0;
+#endif // CP_ASYNC_AVAILABLE
+}
+
 // Copies data from global to shared memory, cg == cache global.
 // Both the src and dst pointers must be aligned to 16 bit.
 // Shared memory uses 32 bit addressing, the pointer is passed as unsigned int.
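The wrapper lets code that stages cp.async transfers convert a generic shared memory pointer into the 32 bit address form described in the comments above, while still compiling (as dead device code) when CP_ASYNC_AVAILABLE is not defined. A minimal sketch of the intended call pattern; the kernel and buffer names below are invented for illustration:

    // Illustrative only: obtain the 32 bit shared memory address expected by the
    // cp.async copy helpers in this header.
    __global__ void example_kernel() {
        __shared__ float tile[256];

        // Generic pointer -> 32 bit shared memory address (stubbed out without cp.async).
        const unsigned int tile_u32 = ggml_cuda_cvta_generic_to_shared(tile);

        // A real kernel would pass tile_u32 as the destination of a cp.async copy.
        (void) tile_u32;
    }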

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 12 additions & 12 deletions
@@ -655,7 +655,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
         nullptr;
 }
 
-template<int D, int ncols1, int ncols2, int KQ_stride> // D == head size
+template<int D, int ncols1, int ncols2> // D == head size
 __launch_bounds__(D, 1)
 static __global__ void flash_attn_stream_k_fixup(
         float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
@@ -813,13 +813,13 @@ static void on_no_fattn_vec_case(const int D) {
         fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, iq4_nl, q5_0, q5_1, q6_0, q8_0, and f16.\n");
         GGML_ABORT("fatal error");
     } else {
-        fprintf(stderr, "Unsupported KV type combination for head_size 256.\n");
+        fprintf(stderr, "Unsupported KV type combination for head_size %d.\n", D);
         fprintf(stderr, "Only f16 is supported.\n");
         GGML_ABORT("fatal error");
     }
 }
 
-template <int D, int ncols1, int ncols2, int KQ_stride>
+template <int DV, int ncols1, int ncols2>
 void launch_fattn(
         ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
         const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
@@ -905,10 +905,13 @@ void launch_fattn(
     const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
 
     const dim3 block_dim(warp_size, nwarps, 1);
+    int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
+    CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
+
     dim3 blocks_num;
     if (stream_k) {
         // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
-        const int max_blocks = 2*nsm;
+        const int max_blocks = max_blocks_per_sm*nsm;
         const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
         const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
 
@@ -920,14 +923,11 @@ void launch_fattn(
         blocks_num.y = 1;
         blocks_num.z = 1;
 
-        dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
+        dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float));
     } else {
        GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
        const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
 
-        int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
-        CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
-
         // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
         parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
 
@@ -1005,19 +1005,19 @@ void launch_fattn(
 
     if (stream_k) {
         if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
-            const dim3 block_dim_combine(D, 1, 1);
+            const dim3 block_dim_combine(DV, 1, 1);
             const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
 
-            flash_attn_stream_k_fixup<D, ncols1, ncols2, KQ_stride>
+            flash_attn_stream_k_fixup<DV, ncols1, ncols2>
                 <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
                 ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
         }
     } else if (parallel_blocks > 1) {
-        const dim3 block_dim_combine(D, 1, 1);
+        const dim3 block_dim_combine(DV, 1, 1);
         const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
         const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
 
-        flash_attn_combine_results<D>
+        flash_attn_combine_results<DV>
             <<<blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream>>>
             (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks);
     }
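The functional change in launch_fattn is that the occupancy query now runs before the stream-k branch, so the stream-k grid is sized from the measured max_blocks_per_sm*nsm instead of the previous hard-coded 2*nsm. A standalone sketch of that query, using an assumed kernel, block size, and shared memory footprint rather than anything from this commit:

    // Standalone sketch; dummy_kernel, block_size, and nbytes_shared are assumptions.
    #include <cuda_runtime.h>
    #include <cstdio>

    __global__ void dummy_kernel(float * x) {
        x[blockIdx.x*blockDim.x + threadIdx.x] *= 2.0f;
    }

    int main() {
        int nsm = 0; // number of streaming multiprocessors on device 0
        cudaDeviceGetAttribute(&nsm, cudaDevAttrMultiProcessorCount, 0);

        const int    block_size    = 256;
        const size_t nbytes_shared = 16*1024; // dynamic shared memory per block (assumed)

        int max_blocks_per_sm = 1; // max. number of active blocks limited by occupancy
        cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, dummy_kernel, block_size, nbytes_shared);

        // Same formula as the patched stream-k path: one full wave of resident blocks.
        const int max_blocks = max_blocks_per_sm*nsm;
        printf("SMs: %d, blocks/SM: %d, stream-k max_blocks: %d\n", nsm, max_blocks_per_sm, max_blocks);
        return 0;
    }

Depending on register and shared memory use, this can allow more (or fewer) than the 2 blocks per SM the old code assumed.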
