Skip to content

Commit f336e6e

Browse files
ggerganovMinh141120
authored andcommitted
ggml : fix FA mask dim 2 and 3 (ggml-org#14505)
* ggml : fix FA mask dim 2 and 3 ggml-ci * backends : unsupport batched FA in CUDA and Vulkan ggml-ci * vulkan : disable FA for mask->ne[2] != 1
1 parent 03d390a commit f336e6e

File tree

9 files changed

+26
-15
lines changed

9 files changed

+26
-15
lines changed

ggml/include/ggml.h

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1983,15 +1983,16 @@ extern "C" {
19831983

19841984
#define GGML_KQ_MASK_PAD 64
19851985

1986-
// q: [n_embd_k, n_batch, n_head, ne3]
1987-
// k: [n_embd_k, n_kv, n_head_kv, ne3]
1988-
// v: [n_embd_v, n_kv, n_head_kv, ne3] !! not transposed !!
1989-
// mask: [n_kv, n_batch_pad, ne32, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
1990-
// res: [n_embd_v, n_head, n_batch, ne3] !! permuted !!
1986+
// q: [n_embd_k, n_batch, n_head, ne3 ]
1987+
// k: [n_embd_k, n_kv, n_head_kv, ne3 ]
1988+
// v: [n_embd_v, n_kv, n_head_kv, ne3 ] !! not transposed !!
1989+
// mask: [n_kv, n_batch_pad, ne32, ne33] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
1990+
// res: [n_embd_v, n_head, n_batch, ne3 ] !! permuted !!
19911991
//
19921992
// broadcast:
19931993
// n_head % n_head_kv == 0
1994-
// ne3 % ne32 == 0
1994+
// n_head % ne32 == 0
1995+
// ne3 % ne33 == 0
19951996
//
19961997
GGML_API struct ggml_tensor * ggml_flash_attn_ext(
19971998
struct ggml_context * ctx,

ggml/src/ggml-cpu/ops.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7799,7 +7799,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
77997799
memset(VKQ32, 0, DV*sizeof(float));
78007800
}
78017801

7802-
const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq3%mask->ne[2])*mask->nb[2]) : NULL;
7802+
const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]) : NULL;
78037803

78047804
// k indices
78057805
const int ik3 = iq3 / rk3;

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3390,7 +3390,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
33903390
return false;
33913391
}
33923392
// TODO: support broadcast
3393-
// ref: https://github.com/ggml-org/llama.cpp/pull/14435
3393+
// note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14500, but
3394+
// the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
33943395
if (op->src[0]->ne[3] != 1) {
33953396
return false;
33963397
}

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,8 +230,10 @@ typedef struct {
230230
uint64_t nb22;
231231
uint64_t nb23;
232232
int32_t ne32;
233+
int32_t ne33;
233234
uint64_t nb31;
234235
uint64_t nb32;
236+
uint64_t nb33;
235237
int32_t ne1;
236238
int32_t ne2;
237239
float scale;

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5019,8 +5019,10 @@ static bool ggml_metal_encode_node(
50195019
/*.nb22 =*/ nb22,
50205020
/*.nb23 =*/ nb23,
50215021
/*.ne32 =*/ ne32,
5022+
/*.ne33 =*/ ne33,
50225023
/*.nb31 =*/ nb31,
50235024
/*.nb32 =*/ nb32,
5025+
/*.nb33 =*/ nb33,
50245026
/*.ne1 =*/ ne1,
50255027
/*.ne2 =*/ ne2,
50265028
/*.scale =*/ scale,

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3850,7 +3850,7 @@ kernel void kernel_flash_attn_ext(
38503850
// load the mask in shared memory
38513851
#pragma unroll(Q)
38523852
for (short j = 0; j < Q; ++j) {
3853-
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq3%args.ne32)*args.nb32);
3853+
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
38543854

38553855
const float m = pm[ic + tiisg];
38563856

@@ -4336,7 +4336,7 @@ kernel void kernel_flash_attn_ext_vec(
43364336
const bool has_mask = mask != q;
43374337

43384338
// pointer to the mask
4339-
device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq3%args.ne32)*args.nb32);
4339+
device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
43404340

43414341
float slope = 1.0f;
43424342

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10268,6 +10268,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1026810268
if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
1026910269
return false;
1027010270
}
10271+
// TODO: support broadcast
10272+
// note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14449, but
10273+
// the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
10274+
if (op->src[0]->ne[3] != 1 || (op->src[3] && op->src[3]->ne[2] != 1)) {
10275+
return false;
10276+
}
1027110277
// It's straightforward to support different K/V dequant, but would
1027210278
// significantly increase the number of pipelines
1027310279
if (op->src[1]->type != op->src[2]->type) {

ggml/src/ggml.c

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3675,7 +3675,6 @@ static struct ggml_tensor * ggml_soft_max_impl(
36753675
if (mask) {
36763676
GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
36773677
GGML_ASSERT(ggml_is_contiguous(mask));
3678-
GGML_ASSERT(ggml_is_3d(mask));
36793678
GGML_ASSERT(mask->ne[0] == a->ne[0]);
36803679
GGML_ASSERT(mask->ne[1] >= a->ne[1]);
36813680
GGML_ASSERT(a->ne[2]%mask->ne[2] == 0);
@@ -4706,12 +4705,12 @@ struct ggml_tensor * ggml_flash_attn_ext(
47064705

47074706
if (mask) {
47084707
GGML_ASSERT(ggml_is_contiguous(mask));
4709-
GGML_ASSERT(mask->ne[2] == q->ne[3]);
47104708
GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
47114709
"the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
47124710
//GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
47134711

4714-
GGML_ASSERT(q->ne[3] % mask->ne[2] == 0);
4712+
GGML_ASSERT(q->ne[2] % mask->ne[2] == 0);
4713+
GGML_ASSERT(q->ne[3] % mask->ne[3] == 0);
47154714
}
47164715

47174716
if (max_bias > 0.0f) {

tests/test-backend-ops.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3666,7 +3666,7 @@ struct test_flash_attn_ext : public test_case {
36663666

36673667
ggml_tensor * m = nullptr;
36683668
if (mask) {
3669-
m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), nr23[1], 1);
3669+
m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), nr23[0], nr23[1]);
36703670
ggml_set_name(m, "m");
36713671
}
36723672

@@ -4780,7 +4780,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
47804780
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, {1, 1}, scale, max_bias));
47814781

47824782
if (ne0 <= 32 && ne1 <= 32) {
4783-
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, m_prec, {3, 1}, scale, max_bias));
4783+
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 3}, mask, m_prec, {3, 1}, scale, max_bias));
47844784
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, {2, 3}, scale, max_bias));
47854785
}
47864786
}

0 commit comments

Comments
 (0)