
Commit 6036177

ggml : fix FA mask dim 2 and 3

ggml-ci

1 parent 55a1c5a

File tree

7 files changed: +18 −14 lines changed

  ggml/include/ggml.h
  ggml/src/ggml-cpu/ops.cpp
  ggml/src/ggml-metal/ggml-metal-impl.h
  ggml/src/ggml-metal/ggml-metal.m
  ggml/src/ggml-metal/ggml-metal.metal
  ggml/src/ggml.c
  tests/test-backend-ops.cpp

ggml/include/ggml.h

Lines changed: 7 additions & 6 deletions

@@ -1980,15 +1980,16 @@ extern "C" {
 
 #define GGML_KQ_MASK_PAD 64
 
-    // q:    [n_embd_k, n_batch,     n_head,    ne3]
-    // k:    [n_embd_k, n_kv,        n_head_kv, ne3]
-    // v:    [n_embd_v, n_kv,        n_head_kv, ne3] !! not transposed !!
-    // mask: [n_kv,     n_batch_pad, ne32,      1]   !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
-    // res:  [n_embd_v, n_head,      n_batch,   ne3] !! permuted !!
+    // q:    [n_embd_k, n_batch,     n_head,    ne3 ]
+    // k:    [n_embd_k, n_kv,        n_head_kv, ne3 ]
+    // v:    [n_embd_v, n_kv,        n_head_kv, ne3 ] !! not transposed !!
+    // mask: [n_kv,     n_batch_pad, ne32,      ne33] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
+    // res:  [n_embd_v, n_head,      n_batch,   ne3 ] !! permuted !!
     //
     // broadcast:
     //   n_head % n_head_kv == 0
-    //   ne3    % ne32      == 0
+    //   n_head % ne32      == 0
+    //   ne3    % ne33      == 0
     //
     GGML_API struct ggml_tensor * ggml_flash_attn_ext(
             struct ggml_context * ctx,
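For orientation (not part of the diff), here is a minimal usage sketch matching the updated comment: the mask is now a genuine 4-D tensor whose dims 2 and 3 broadcast against the query's head dim and ne3. All sizes are hypothetical, and the trailing scale/max_bias/logit_softcap arguments assume the full ggml_flash_attn_ext signature, which is only partially shown in the hunk above.

#include <math.h>
#include "ggml.h"

// Sketch only: hypothetical sizes, assuming a valid ggml_context.
static struct ggml_tensor * build_fattn(struct ggml_context * ctx) {
    const int64_t n_embd_k = 128, n_embd_v  = 128;
    const int64_t n_batch  = 32,  n_kv      = 256;
    const int64_t n_head   = 16,  n_head_kv = 4, ne3 = 2;

    struct ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd_k, n_batch, n_head,    ne3);
    struct ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, n_embd_k, n_kv,    n_head_kv, ne3);
    struct ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, n_embd_v, n_kv,    n_head_kv, ne3);

    // ne32 = ne33 = 1: one mask slice broadcast over all heads and all of ne3
    struct ggml_tensor * m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16,
            n_kv, GGML_PAD(n_batch, GGML_KQ_MASK_PAD), 1, 1);

    // res: [n_embd_v, n_head, n_batch, ne3] (permuted, per the comment above)
    return ggml_flash_attn_ext(ctx, q, k, v, m,
            1.0f/sqrtf((float) n_embd_k), /*max_bias =*/ 0.0f, /*logit_softcap =*/ 0.0f);
}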

ggml/src/ggml-cpu/ops.cpp

Lines changed: 1 addition & 1 deletion

@@ -7799,7 +7799,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
             memset(VKQ32, 0, DV*sizeof(float));
         }
 
-        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq3%mask->ne[2])*mask->nb[2]) : NULL;
+        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]) : NULL;
 
         // k indices
         const int ik3 = iq3 / rk3;
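In effect, the CPU path now applies the usual ggml modulo broadcasting to both mask dims 2 and 3: with mask->ne[2] == 1 every head reads the same slice, with mask->ne[2] == n_head each head gets its own, and likewise for ne[3] across ne3. A hypothetical stand-alone restatement of that addressing (not a helper in the tree):

// iq1 = query row, iq2 = head index, iq3 = index in dim 3
static inline const ggml_fp16_t * fattn_mask_row(
        const struct ggml_tensor * mask, int64_t iq1, int64_t iq2, int64_t iq3) {
    return (const ggml_fp16_t *) ((const char *) mask->data
            + iq1*mask->nb[1]                    // row within the padded batch
            + (iq2 % mask->ne[2])*mask->nb[2]    // broadcast over heads
            + (iq3 % mask->ne[3])*mask->nb[3]);  // broadcast over dim 3
}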

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 2 additions & 0 deletions

@@ -230,8 +230,10 @@ typedef struct {
     uint64_t nb22;
     uint64_t nb23;
     int32_t  ne32;
+    int32_t  ne33;
     uint64_t nb31;
     uint64_t nb32;
+    uint64_t nb33;
     int32_t  ne1;
     int32_t  ne2;
     float    scale;

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 2 additions & 0 deletions

@@ -4989,8 +4989,10 @@ static bool ggml_metal_encode_node(
             /*.nb22  =*/ nb22,
             /*.nb23  =*/ nb23,
             /*.ne32  =*/ ne32,
+            /*.ne33  =*/ ne33,
             /*.nb31  =*/ nb31,
             /*.nb32  =*/ nb32,
+            /*.nb33  =*/ nb33,
             /*.ne1   =*/ ne1,
             /*.ne2   =*/ ne2,
             /*.scale =*/ scale,

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 2 additions & 2 deletions

@@ -3784,7 +3784,7 @@ kernel void kernel_flash_attn_ext(
             // load the mask in shared memory
             #pragma unroll(Q)
             for (short j = 0; j < Q; ++j) {
-                device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq3%args.ne32)*args.nb32);
+                device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
 
                 const float m = pm[ic + tiisg];
 
@@ -4270,7 +4270,7 @@ kernel void kernel_flash_attn_ext_vec(
     const bool has_mask = mask != q;
 
     // pointer to the mask
-    device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq3%args.ne32)*args.nb32);
+    device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
 
     float slope = 1.0f;
 
ggml/src/ggml.c

Lines changed: 2 additions & 3 deletions

@@ -3666,7 +3666,6 @@ static struct ggml_tensor * ggml_soft_max_impl(
     if (mask) {
         GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
         GGML_ASSERT(ggml_is_contiguous(mask));
-        GGML_ASSERT(ggml_is_3d(mask));
         GGML_ASSERT(mask->ne[0] == a->ne[0]);
         GGML_ASSERT(mask->ne[1] >= a->ne[1]);
         GGML_ASSERT(a->ne[2]%mask->ne[2] == 0);
@@ -4696,12 +4695,12 @@ struct ggml_tensor * ggml_flash_attn_ext(
 
     if (mask) {
         GGML_ASSERT(ggml_is_contiguous(mask));
-        GGML_ASSERT(mask->ne[2] == q->ne[3]);
         GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
                 "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
         //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
 
-        GGML_ASSERT(q->ne[3] % mask->ne[2] == 0);
+        GGML_ASSERT(q->ne[2] % mask->ne[2] == 0);
+        GGML_ASSERT(q->ne[3] % mask->ne[3] == 0);
     }
 
     if (max_bias > 0.0f) {
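The relaxed asserts are what make per-head masks possible: mask->ne[2] now only has to divide q->ne[2] (heads) instead of matching q->ne[3]. A hypothetical example of a mask shape that passes the new checks (sizes are made up):

// Sketch only: per-head mask (ne32 == n_head) shared across dim 3 (ne33 == 1)
static struct ggml_tensor * make_per_head_mask(struct ggml_context * ctx,
        int64_t n_kv, int64_t n_batch, int64_t n_head) {
    return ggml_new_tensor_4d(ctx, GGML_TYPE_F16,
            n_kv, GGML_PAD(n_batch, GGML_KQ_MASK_PAD),
            n_head,  // satisfies q->ne[2] % mask->ne[2] == 0
            1);      // satisfies q->ne[3] % mask->ne[3] == 0
}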

tests/test-backend-ops.cpp

Lines changed: 2 additions & 2 deletions

@@ -3607,7 +3607,7 @@ struct test_flash_attn_ext : public test_case {
 
         ggml_tensor * m = nullptr;
         if (mask) {
-            m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), nr23[1], 1);
+            m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), nr23[0], nr23[1]);
             ggml_set_name(m, "m");
         }
 
@@ -4720,7 +4720,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                 test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, {1, 1}, scale, max_bias));
 
                 if (ne0 <= 32 && ne1 <= 32) {
-                    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0,   ne1,   1, 1}, mask, m_prec, {3, 1}, scale, max_bias));
+                    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0,   ne1,   1, 3}, mask, m_prec, {3, 1}, scale, max_bias));
                     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, {2, 3}, scale, max_bias));
                 }
             }
