
Commit 5f0fc46

Reapply parts of "CUDA: faster Deepseek FA, add Turing support (ggml-org#13435)"
1 parent 99204eb · commit 5f0fc46

3 files changed: +16 additions, -6 deletions

ggml/src/ggml-cuda/fattn-common.cuh (10 additions, 5 deletions)

@@ -835,6 +835,10 @@ void launch_fattn(
     GGML_ASSERT(Q->type == GGML_TYPE_F32);
     GGML_ASSERT(KQV->type == GGML_TYPE_F32);
 
+    GGML_ASSERT(      Q->nb[0] == ggml_element_size(Q));
+    GGML_ASSERT(      K->nb[0] == ggml_element_size(K));
+    GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V));
+
     GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
     GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
         "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
@@ -859,10 +863,10 @@ void launch_fattn(
     size_t nb12 = K->nb[2];
     size_t nb13 = K->nb[3];
 
-    const char * V_data = (const char *) V->data;
-    size_t nb21 = V->nb[1];
-    size_t nb22 = V->nb[2];
-    size_t nb23 = V->nb[3];
+    const char * V_data = V ? (const char *) V->data : nullptr;
+    size_t nb21 = V ? V->nb[1] : nb11;
+    size_t nb22 = V ? V->nb[2] : nb12;
+    size_t nb23 = V ? V->nb[3] : nb13;
 
     if (need_f16_K && K->type != GGML_TYPE_F16) {
         K_f16.alloc(ggml_nelements(K));
@@ -878,7 +882,8 @@
         nb13 = nb13*bs*sizeof(half)/ts;
     }
 
-    if (need_f16_V && V->type != GGML_TYPE_F16) {
+    if (V && need_f16_V && V->type != GGML_TYPE_F16) {
+        // GGML_ASSERT(ggml_is_contiguous(V));
         V_f16.alloc(ggml_nelements(V));
         to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
         to_fp16(V_data, V_f16.ptr, 1, ggml_nelements(V), main_stream);
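
The launch_fattn changes above make the V tensor optional: the new asserts only check V's element stride when V is present, and when it is not, the V pointer becomes null and the V strides fall back to K's (presumably the DeepSeek flash-attention path, where V is read from the K data). Below is a small standalone C++ sketch of that stride-fallback pattern; tensor_t and print_v_view are hypothetical stand-ins for illustration, not ggml API.

    #include <cstddef>
    #include <cstdio>

    // Hypothetical stand-in for ggml_tensor: data pointer plus byte strides.
    struct tensor_t {
        const char * data;
        std::size_t  nb[4]; // nb[0] is the element size, nb[1..3] are row/plane/batch strides
    };

    // Mirrors the stride selection in the patched launch_fattn: when no separate V
    // tensor exists, reuse K's strides so V indexing stays consistent with K.
    static void print_v_view(const tensor_t * K, const tensor_t * V) {
        const std::size_t nb11 = K->nb[1], nb12 = K->nb[2], nb13 = K->nb[3];

        const char *      V_data = V ? V->data  : nullptr;
        const std::size_t nb21   = V ? V->nb[1] : nb11;
        const std::size_t nb22   = V ? V->nb[2] : nb12;
        const std::size_t nb23   = V ? V->nb[3] : nb13;

        std::printf("V_data=%p nb21=%zu nb22=%zu nb23=%zu\n",
                    (const void *) V_data, nb21, nb22, nb23);
    }

    int main() {
        const tensor_t K = {nullptr, {2, 256, 256*64, 256*64*8}};
        print_v_view(&K, nullptr); // no separate V: all strides fall back to K's
        return 0;
    }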

ggml/src/ggml-cuda/fattn.cu (2 additions, 1 deletion)

@@ -10,6 +10,7 @@
 
 template <int D, int ncols2>
 static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const ggml_tensor * Q = dst->src[0];
 
     if (Q->ne[1] <= 8/ncols2) {
@@ -26,7 +27,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     // ggml_cuda_flash_attn_ext_mma_f16_case<D, 32/ncols2, ncols2>(ctx, dst);
 
     if (ggml_cuda_highest_compiled_arch(cc) <= GGML_CUDA_CC_TURING || Q->ne[1] <= 32/ncols2) {
-        ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 32/ncols2, ncols2>(ctx, dst);
+        ggml_cuda_flash_attn_ext_mma_f16_case<D, 32/ncols2, ncols2>(ctx, dst);
         return;
     }
 
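
In the fattn.cu hunks above, the compute capability is read from the active device at the top of the dispatch function, and the template argument list is brought back in line with this file's single head-size parameter D (rather than the separate DKQ/DV parameters of the removed line). The following standalone sketch shows the same kind of arch-and-batch-size dispatch; the kernel stubs, the wider 64-column fallback, and the Turing threshold of 750 are illustrative assumptions, not the actual ggml definitions.

    #include <cstdio>

    // Assumed Turing threshold for this sketch only (sm_75 encoded as 750).
    static constexpr int CC_TURING_STUB = 750;

    // Stub standing in for a templated flash-attention kernel launcher.
    template <int D, int ncols1, int ncols2>
    static void fattn_case_stub(int n_queries) {
        std::printf("D=%d ncols1=%d ncols2=%d n_queries=%d\n", D, ncols1, ncols2, n_queries);
    }

    // Same dispatch shape as the hunk above: on Turing-or-older compiled arches,
    // or for small batches, pick the narrower 32-column tile; otherwise go wider.
    template <int D, int ncols2>
    static void switch_ncols1_stub(int cc, int n_queries) {
        if (cc <= CC_TURING_STUB || n_queries <= 32/ncols2) {
            fattn_case_stub<D, 32/ncols2, ncols2>(n_queries);
            return;
        }
        fattn_case_stub<D, 64/ncols2, ncols2>(n_queries); // assumed wider variant
    }

    int main() {
        switch_ncols1_stub<128, 4>(750, 64); // Turing: 32-column tile
        switch_ncols1_stub<128, 4>(890, 64); // newer arch, larger batch: 64-column tile
        return 0;
    }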

ggml/src/ggml-cuda/ggml-cuda.cu (4 additions, 0 deletions)

@@ -3537,6 +3537,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
 #ifndef FLASH_ATTN_AVAILABLE
             return false;
 #endif // FLASH_ATTN_AVAILABLE
+            const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
+            if (!new_mma_available(cc)) {
+                return false;
+            }
             if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
                 // different head sizes of K and V are not supported yet
                 return false;
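
The ggml-cuda.cu hunk adds an early gate in the flash-attention branch of the device supports-op check: if new_mma_available(cc) reports that the device lacks the required MMA instructions, the op is rejected before the remaining shape checks run. A minimal standalone sketch of such a gate follows; the stub names and the 750 (Turing) threshold are assumptions for illustration, not the real new_mma_available definition.

    #include <cstdio>

    // Assumption for this sketch: treat Turing (750) and newer as MMA-capable.
    static bool mma_available_stub(int cc) {
        return cc >= 750;
    }

    // Early capability gate in a supports-op style check, mirroring the hunk above:
    // reject the op on unsupported devices before any per-shape checks.
    static bool supports_flash_attn_stub(int cc, int head_size_k, int head_size_v) {
        if (!mma_available_stub(cc)) {
            return false;
        }
        if (head_size_k != head_size_v) {
            return false; // different head sizes of K and V are not supported (as above)
        }
        return true;
    }

    int main() {
        std::printf("cc=610: %d\n", supports_flash_attn_stub(610, 128, 128)); // pre-Turing -> 0
        std::printf("cc=890: %d\n", supports_flash_attn_stub(890, 128, 128)); // newer -> 1
        return 0;
    }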
