
Commit 4a1054b

context : reuse build_attn_mha
ggml-ci
1 parent a5a85a3 commit 4a1054b
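This commit removes the output-projection arguments (wo, wo_b) and the explicit n_tokens parameter from build_attn and moves the core attention math into a shared helper, build_attn_mha. The graph-building call sites live in a file that is not part of this excerpt; as a rough, hypothetical sketch of what a call site looks like after this change (wo, wo_b and build_lora_mm are taken from the removed lines below, but the exact caller-side code is an assumption):

    // hypothetical call site after this commit: build_attn() now returns the merged
    // attention output without the output projection, so the caller applies it itself
    ggml_tensor * cur = build_attn(ctx0, gf, q_cur, k_cur, v_cur, kq_b, kq_scale, il);

    cur = build_lora_mm(ctx0, wo, cur);  // output projection, previously done inside build_attn
    if (wo_b) {
        cur = ggml_add(ctx0, cur, wo_b); // output bias, previously done inside build_attn
    }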

File tree

5 files changed: +107 -165 lines changed


src/llama-context.cpp

Lines changed: 64 additions & 146 deletions
@@ -1721,50 +1721,67 @@ void llama_context::build_attn_inp(
 ggml_tensor * llama_context::build_attn(
         ggml_context * ctx0,
         ggml_cgraph * gf,
-        ggml_tensor * wo,
-        ggml_tensor * wo_b,
         ggml_tensor * q_cur,
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
-        int32_t n_tokens,
         float kq_scale,
         int il) {
-    const auto & hparams = model.hparams;
+    GGML_UNUSED(il);

-    const auto & n_ctx = cparams.n_ctx;
+    const auto & kq_mask = inp.kq_mask_cnv;

-    //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-    const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+    ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
+    //cb(q, "q", il);

-    const auto & kq_mask = inp.kq_mask_cnv;
+    ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
+    //cb(k, "k", il);

-    const int64_t n_head = hparams.n_head(il);
-    const int64_t n_head_kv = hparams.n_head_kv(il);
+    ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
+    //cb(k, "v", il);

-    //const auto & n_embd_head_k = hparams.n_embd_head_k;
-    const auto & n_embd_head_v = hparams.n_embd_head_v;
+    ggml_tensor * cur = build_attn_mha(ctx0, gf, q, k, v, kq_b, kq_mask, false, kq_scale);

-    // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
-    const auto n_kv = n_tokens;
+    return cur;
+}

-    struct ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
-    //cb(q, "q", il);
+ggml_tensor * llama_context::build_attn_mha(
+        ggml_context * ctx0,
+        ggml_cgraph * gf,
+        ggml_tensor * q,
+        ggml_tensor * k,
+        ggml_tensor * v,
+        ggml_tensor * kq_b,
+        ggml_tensor * kq_mask,
+        bool v_trans,
+        float kq_scale) {
+    const auto & hparams = model.hparams;

-    struct ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, k_cur, 0, 2, 1, 3));
-    //cb(k, "k", il);
+    //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+    //const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+
+    //const int64_t n_head = hparams.n_head(il);
+    //const int64_t n_head_kv = hparams.n_head_kv(il);
+
+    //const auto & n_embd_head_k = hparams.n_embd_head_k;
+    //const auto & n_embd_head_v = hparams.n_embd_head_v;
+
+    const auto n_embd_head_v = v_trans ? v->ne[1] : v->ne[0];
+
+    const auto n_tokens = q->ne[1];
+    const auto n_head = q->ne[2];
+    const auto n_kv = k->ne[1];

     struct ggml_tensor * cur;

-    //if (cparams.flash_attn) {
-    if (false) { // TODO: need to pad the batch size to a multiple of GGML_KQ_MASK_PAD
+    if (cparams.flash_attn && (n_kv % 256 == 0) && kq_b == nullptr) {
         GGML_UNUSED(model);
-        GGML_UNUSED(n_ctx);

-        GGML_ASSERT(kq_b == nullptr);
+        GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet");

-        struct ggml_tensor * v = ggml_cont(ctx0, ggml_permute(ctx0, v_cur, 0, 2, 1, 3));
-        v = ggml_reshape_3d(ctx0, v, n_embd_head_v, n_kv, n_head_kv);
+        if (v_trans) {
+            v = ggml_transpose(ctx0, v);
+        }

         cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                 hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
@@ -1774,7 +1791,6 @@ ggml_tensor * llama_context::build_attn(
         cur = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
     } else {
         struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
-        //cb(kq, "kq", il);

         // note: this op tends to require high floating point range
         // while for some models F16 is enough, for others it is not, so we default to F32 here
@@ -1802,22 +1818,17 @@ ggml_tensor * llama_context::build_attn(
         }

         kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
-        //cb(kq, "kq_soft_max_ext", il);
-
-        // split cached v into n_head heads
-        struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens)));

-        v = ggml_reshape_3d(ctx0, v, n_kv, n_embd_head_v, n_head_kv);
-        //cb(v, "v", il);
+        if (!v_trans) {
+            // note: avoid this branch
+            v = ggml_cont(ctx0, ggml_transpose(ctx0, v));
+        }

         struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
-        //cb(kqv, "kqv", il);

         struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
-        //cb(kqv_merged, "kqv_merged", il);

         cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
-        //cb(cur, "kqv_merged_cont", il);

         if (!cparams.offload_kqv) {
             // all nodes between the KV store and the attention output are run on the CPU
@@ -1827,18 +1838,6 @@ ggml_tensor * llama_context::build_attn(

     ggml_build_forward_expand(gf, cur);

-    if (wo) {
-        cur = build_lora_mm(ctx0, wo, cur);
-    }
-
-    if (wo_b) {
-        //cb(cur, "kqv_wo", il);
-    }
-
-    if (wo_b) {
-        cur = ggml_add(ctx0, cur, wo_b);
-    }
-
     return cur;
 }
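For reference, the shape conventions the new helper relies on, reconstructed from the added lines above (this note is not part of the commit; ne[] is ggml's dimension array, fastest-varying dimension first). They also explain why the explicit n_tokens parameter could be dropped from the interface:

    // q : [n_embd_head_k, n_tokens, n_head   ]  ->  n_tokens = q->ne[1], n_head = q->ne[2]
    // k : [n_embd_head_k, n_kv,     n_head_kv]  ->  n_kv     = k->ne[1]
    // v : [n_embd_head_v, n_kv,     n_head_kv]  when v_trans == false
    //     [n_kv, n_embd_head_v,     n_head_kv]  when v_trans == true (transposed V cache)
    const auto n_embd_head_v = v_trans ? v->ne[1] : v->ne[0];
    const auto n_tokens = q->ne[1];
    const auto n_head   = q->ne[2];
    const auto n_kv     = k->ne[1];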

@@ -3274,13 +3273,10 @@ void llama_context_kv_self::build_attn_inp(
 ggml_tensor * llama_context_kv_self::build_attn(
         ggml_context * ctx0,
         ggml_cgraph * gf,
-        ggml_tensor * wo,
-        ggml_tensor * wo_b,
         ggml_tensor * q_cur,
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
-        int32_t n_tokens,
         float kq_scale,
         int il) {
     const auto & hparams = model.hparams;
@@ -3290,6 +3286,10 @@ ggml_tensor * llama_context_kv_self::build_attn(
     const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
     const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);

+    const auto n_tokens = q_cur->ne[2];
+
+    const bool v_trans = !cparams.flash_attn;
+
     // store to KV cache
     {
         GGML_ASSERT(!kv_self.recurrent);
@@ -3308,7 +3308,7 @@ ggml_tensor * llama_context_kv_self::build_attn(

         struct ggml_tensor * v_cache_view = nullptr;

-        if (cparams.flash_attn) {
+        if (!v_trans) {
             v_cache_view = ggml_view_1d(ctx0, kv_self.v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa)*kv_head);
         } else {
             // note: the V cache is transposed when not using flash attention
@@ -3351,117 +3351,35 @@ ggml_tensor * llama_context_kv_self::build_attn(

     const auto n_kv = kv_self.n;

-    const int64_t n_head = hparams.n_head(il);
     const int64_t n_head_kv = hparams.n_head_kv(il);

     const auto & n_embd_head_k = hparams.n_embd_head_k;
     const auto & n_embd_head_v = hparams.n_embd_head_v;

-    struct ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
+    ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
     //cb(q, "q", il);

-    struct ggml_tensor * k =
+    ggml_tensor * k =
         ggml_view_3d(ctx0, kv_self.k_l[il],
                 n_embd_head_k, n_kv, n_head_kv,
                 ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
                 ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
                 0);
     //cb(k, "k", il);

-    struct ggml_tensor * cur;
-
-    if (cparams.flash_attn) {
-        GGML_UNUSED(model);
-        GGML_UNUSED(n_ctx);
-
-        GGML_ASSERT(kq_b == nullptr);
-
-        // split cached v into n_head heads (not transposed)
-        struct ggml_tensor * v =
-            ggml_view_3d(ctx0, kv_self.v_l[il],
-                    n_embd_head_v, n_kv, n_head_kv,
-                    ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
-                    ggml_row_size(kv_self.v_l[il]->type, n_embd_head_v),
-                    0);
-        //cb(v, "v", il);
-
-        cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
-                hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
-
-        ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
-
-        cur = ggml_reshape_2d(ctx0, cur, n_embd_head_v*n_head, n_tokens);
-    } else {
-        struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
-        //cb(kq, "kq", il);
-
-        // note: this op tends to require high floating point range
-        // while for some models F16 is enough, for others it is not, so we default to F32 here
-        ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
-
-        if (model.arch == LLM_ARCH_GROK) {
-            // need to do the following:
-            // multiply by attn_output_multiplyer of 0.08838834764831845
-            // and then :
-            // kq = 30 * tanh(kq / 30)
-            // before the softmax below
-
-            kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, 0.08838834764831845f/30.0f));
-            kq = ggml_scale(ctx0, kq, 30);
-        }
-
-        if (hparams.attn_soft_cap) {
-            kq = ggml_scale(ctx0, kq, 1.0f / hparams.f_attn_logit_softcapping);
-            kq = ggml_tanh (ctx0, kq);
-            kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
-        }
-
-        if (kq_b) {
-            kq = ggml_add(ctx0, kq, kq_b);
-        }
-
-        kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
-        //cb(kq, "kq_soft_max_ext", il);
-
-        GGML_ASSERT(kv_self.size == n_ctx);
-
-        // split cached v into n_head heads
-        struct ggml_tensor * v =
-            ggml_view_3d(ctx0, kv_self.v_l[il],
-                    n_kv, n_embd_head_v, n_head_kv,
-                    ggml_element_size(kv_self.v_l[il])*n_ctx,
-                    ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
-                    0);
-        //cb(v, "v", il);
-
-        struct ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
-        //cb(kqv, "kqv", il);
-
-        struct ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
-        //cb(kqv_merged, "kqv_merged", il);
-
-        cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_head_v*n_head, n_tokens);
-        //cb(cur, "kqv_merged_cont", il);
-
-        if (!cparams.offload_kqv) {
-            // all nodes between the KV store and the attention output are run on the CPU
-            ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend_cpu);
-        }
-    }
-
-    ggml_build_forward_expand(gf, cur);
-
-    if (wo) {
-        cur = build_lora_mm(ctx0, wo, cur);
-    }
-
-    if (wo_b) {
-        //cb(cur, "kqv_wo", il);
-    }
+    ggml_tensor * v = !v_trans ?
+        ggml_view_3d(ctx0, kv_self.v_l[il],
+                n_embd_head_v, n_kv, n_head_kv,
+                ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
+                ggml_row_size(kv_self.v_l[il]->type, n_embd_head_v),
+                0) :
+        ggml_view_3d(ctx0, kv_self.v_l[il],
+                n_kv, n_embd_head_v, n_head_kv,
+                ggml_element_size(kv_self.v_l[il])*n_ctx,
+                ggml_element_size(kv_self.v_l[il])*n_ctx*n_embd_head_v,
+                0);

-    if (wo_b) {
-        cur = ggml_add(ctx0, cur, wo_b);
-    }
+    struct ggml_tensor * cur = build_attn_mha(ctx0, gf, q, k, v, kq_b, kq_mask, v_trans, kq_scale);

     return cur;
 }
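A short note on the v_trans flag used above (illustrative only, not part of the commit): with flash attention the V cache keeps its natural layout and v_trans is false; without flash attention the V cache is stored transposed and v_trans is true, which lets build_attn_mha feed the view straight into ggml_mul_mat(v, kq) without an extra ggml_cont/ggml_transpose. The strides in the two ggml_view_3d calls above encode exactly this:

    // !v_trans (flash attention): per-head view is [n_embd_head_v, n_kv, n_head_kv]
    //     nb1 = ggml_row_size(type, n_embd_v_gqa)   // stride between tokens
    //     nb2 = ggml_row_size(type, n_embd_head_v)  // stride between heads
    //
    // v_trans (default path): the cache itself is transposed, view is [n_kv, n_embd_head_v, n_head_kv]
    //     nb1 = ggml_element_size(v_l)*n_ctx
    //     nb2 = ggml_element_size(v_l)*n_ctx*n_embd_head_v
    //
    // the transposed layout goes directly into ggml_mul_mat(ctx0, v, kq), while the
    // !v_trans layout takes the "avoid this branch" ggml_cont(ggml_transpose(v)) path.
    const bool v_trans = !cparams.flash_attn;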

src/llama-context.h

Lines changed: 11 additions & 6 deletions
@@ -261,17 +261,25 @@ struct llama_context : public llama_graph_i {
     ggml_tensor * build_attn(
             ggml_context * ctx0,
             ggml_cgraph * gf,
-            ggml_tensor * wo,
-            ggml_tensor * wo_b,
             ggml_tensor * q_cur,
             ggml_tensor * k_cur,
             ggml_tensor * v_cur,
             ggml_tensor * kq_b,
-            int32_t n_tokens,
             float kq_scale,
             int il) override;

 protected:
+    virtual ggml_tensor * build_attn_mha(
+            ggml_context * ctx0,
+            ggml_cgraph * gf,
+            ggml_tensor * q,
+            ggml_tensor * k,
+            ggml_tensor * v,
+            ggml_tensor * kq_b,
+            ggml_tensor * kq_mask,
+            bool v_trans,
+            float kq_scale);
+
     virtual ggml_tensor * build_inp_self_k_shift(
             ggml_context * ctx0);

@@ -472,13 +480,10 @@ class llama_context_kv_self : public llama_context {
     ggml_tensor * build_attn(
             ggml_context * ctx0,
             ggml_cgraph * gf,
-            ggml_tensor * wo,
-            ggml_tensor * wo_b,
             ggml_tensor * q_cur,
             ggml_tensor * k_cur,
             ggml_tensor * v_cur,
             ggml_tensor * kq_b,
-            int32_t n_tokens,
             float kq_scale,
             int il) override;
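Because build_attn_mha is declared protected and virtual, a derived context can override build_attn, prepare Q/K/V however it needs to (for example from its own cache), and still reuse the shared helper. A hypothetical sketch, not code from this commit (my_context and its members are placeholders):

    ggml_tensor * my_context::build_attn(
            ggml_context * ctx0,
            ggml_cgraph * gf,
            ggml_tensor * q_cur,
            ggml_tensor * k_cur,
            ggml_tensor * v_cur,
            ggml_tensor * kq_b,
            float kq_scale,
            int il) {
        GGML_UNUSED(il);

        ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
        ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
        ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);

        // no transposed V cache in this sketch, so v_trans == false
        return build_attn_mha(ctx0, gf, q, k, v, kq_b, inp.kq_mask_cnv, /*v_trans=*/false, kq_scale);
    }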

src/llama-graph.cpp

Lines changed: 0 additions & 6 deletions
@@ -7,24 +7,18 @@ llama_graph_i::llama_graph_i(llama_graph_type type) : type(type) {}
 ggml_tensor * llama_graph_i::build_attn(
         ggml_context * ctx0,
         ggml_cgraph * gf,
-        ggml_tensor * wo,
-        ggml_tensor * wo_b,
         ggml_tensor * q_cur,
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
-        int32_t n_tokens,
         float kq_scale,
         int il) {
     GGML_UNUSED(ctx0);
     GGML_UNUSED(gf);
-    GGML_UNUSED(wo);
-    GGML_UNUSED(wo_b);
     GGML_UNUSED(q_cur);
     GGML_UNUSED(k_cur);
     GGML_UNUSED(v_cur);
     GGML_UNUSED(kq_b);
-    GGML_UNUSED(n_tokens);
     GGML_UNUSED(kq_scale);
     GGML_UNUSED(il);

src/llama-graph.h

Lines changed: 0 additions & 3 deletions
@@ -107,13 +107,10 @@ class llama_graph_i {
     virtual ggml_tensor * build_attn(
             ggml_context * ctx0,
             ggml_cgraph * gf,
-            ggml_tensor * wo,
-            ggml_tensor * wo_b,
             ggml_tensor * q_cur,
             ggml_tensor * k_cur,
             ggml_tensor * v_cur,
             ggml_tensor * kq_b,
-            int32_t n_tokens,
             float kq_scale,
             int il);
