Commit 676b2db

Tidied up to use is_mla
1 parent c3cc463 commit 676b2db

File tree

src/llama-graph.cpp
src/llama-kv-cache.cpp

2 files changed: +11 -13 lines changed


src/llama-graph.cpp

Lines changed: 4 additions & 10 deletions

@@ -1432,8 +1432,8 @@ ggml_tensor * llm_graph_context::build_attn(
 
     v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens);
 
-    // note: MLA with flash attention now uses the last 512 elements of K in place of V
-    if (v_trans || !v_mla) {
+    // note: MLA with flash attention now uses the last 512 elements of K-cache in place of a V-cache
+    if (!v_mla || v_trans) {
         ggml_tensor * v_cache_view = nullptr;
 
         if (!v_trans) {
@@ -1474,7 +1474,9 @@ ggml_tensor * llm_graph_context::build_attn(
             0);
     //cb(k, "k", il);
 
+    // note: MLA with flash attention now uses the last 512 elements of K-cache in place of a V-cache
     ggml_tensor * v = nullptr;
+
     if (v_trans) {
         v = ggml_view_3d(ctx0, kv_self->v_l[il],
                 n_kv, n_embd_head_v, n_head_kv,
@@ -1487,14 +1489,6 @@ ggml_tensor * llm_graph_context::build_attn(
                 ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
                 ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v),
                 0);
-    } else {
-        // note: MLA with flash attention now uses the last 512 elements of K in place of V
-        v = ggml_view_3d(ctx0, kv_self->k_l[il],
-                n_embd_head_v, n_kv, n_head_kv,
-                ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
-                ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
-                n_embd_head_k - n_embd_head_v); // offset by n_rot elements
-        v = ggml_cont(ctx0, v);
     }
 
     ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, v_trans, kq_scale);
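
For context on the note above: under MLA the per-head V values are a suffix of the per-head K row, so a V view can be taken directly from the K-cache at a fixed element offset instead of keeping a separate V-cache, which is why the removed else-branch (and now the flash-attention path) can read V out of k_l. Below is a minimal standalone sketch of that offset arithmetic, not llama.cpp code; the 576/512 head sizes are assumed DeepSeek2-style values, the layout is simplified to a single head, and all names are illustrative.

// Minimal standalone sketch (not llama.cpp code): why the last 512 elements of a
// K-cache row can stand in for a V-cache row under MLA + flash attention.
// Assumed DeepSeek2-style sizes: n_embd_head_k = 576, n_embd_head_v = 512, so the
// skipped prefix is n_embd_head_k - n_embd_head_v = 64 (the RoPE part, n_rot).
#include <cstdio>
#include <vector>

int main() {
    const int n_embd_head_k = 576; // assumed: RoPE part (64) + compressed KV part (512)
    const int n_embd_head_v = 512; // assumed: value portion shared with the K row
    const int n_kv          = 4;   // toy number of cached tokens

    // One K row per cached token; fill with recognizable values.
    std::vector<float> k_cache(n_kv * n_embd_head_k);
    for (int t = 0; t < n_kv; ++t) {
        for (int i = 0; i < n_embd_head_k; ++i) {
            k_cache[t*n_embd_head_k + i] = 1000.0f*t + i;
        }
    }

    // "V view": same buffer, same row stride, but each row starts
    // n_embd_head_k - n_embd_head_v elements in (mirrors the offset of the
    // removed ggml_view_3d call, i.e. the "offset by n_rot elements" comment).
    const int v_offset = n_embd_head_k - n_embd_head_v; // = 64

    for (int t = 0; t < n_kv; ++t) {
        const float * v_row = &k_cache[t*n_embd_head_k + v_offset];
        printf("token %d: v[0] = %.0f ... v[%d] = %.0f\n",
               t, v_row[0], n_embd_head_v - 1, v_row[n_embd_head_v - 1]);
    }
    return 0;
}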

src/llama-kv-cache.cpp

Lines changed: 7 additions & 3 deletions

@@ -32,8 +32,12 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         uint32_t padding) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding) {
     const int32_t n_layer = hparams.n_layer;
 
+    const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0);
+
+    is_mla_with_fa = model.arch != LLM_ARCH_DEEPSEEK2 || v_trans;
+
     has_shift = false;
-    can_shift = model.arch != LLM_ARCH_DEEPSEEK2 || v_trans; // TODO: allow context shifting for MLA with flash attention
+    can_shift = !is_mla || v_trans; // TODO: allow context shifting for MLA with flash attention
 
     LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d, padding = %d\n",
             __func__, kv_size, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift, padding);
@@ -100,9 +104,9 @@ llama_kv_cache_unified::llama_kv_cache_unified(
             throw std::runtime_error("failed to create ggml context for kv cache");
         }
 
-        // note: MLA with flash attention now uses the last 512 elements of K in place of V
+        // note: MLA with flash attention now uses the last 512 elements of K-cache in place of a V-cache
        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
-        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, model.arch != LLM_ARCH_DEEPSEEK2 || v_trans ? n_embd_v_gqa*kv_size : 0);
+        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, !is_mla || v_trans ? n_embd_v_gqa*kv_size : 0);
         ggml_format_name(k, "cache_k_l%d", i);
         ggml_format_name(v, "cache_v_l%d", i);
         k_l.push_back(k);
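
The sizing change above means a separate V-cache is only allocated when the model is not MLA, or when V is stored transposed (v_trans, i.e. flash attention is disabled); otherwise the per-layer V tensor is created with zero elements because V is read from the K-cache instead. A minimal standalone sketch of that sizing rule, with illustrative names and toy sizes (not llama.cpp code):

// Minimal sketch (not llama.cpp code) of the V-cache sizing rule from the diff:
// a separate V-cache is only needed when the model is not MLA, or when the cache
// uses the transposed-V layout (v_trans), i.e. flash attention is off.
#include <cstdint>
#include <cstdio>

// Returns the number of elements to allocate for one layer's V-cache.
// Mirrors the diff's ternary: !is_mla || v_trans ? n_embd_v_gqa*kv_size : 0
uint64_t v_cache_elements(bool is_mla, bool v_trans, uint64_t n_embd_v_gqa, uint64_t kv_size) {
    return (!is_mla || v_trans) ? n_embd_v_gqa*kv_size : 0;
}

int main() {
    const uint64_t n_embd_v_gqa = 512;   // toy value
    const uint64_t kv_size      = 4096;  // toy value

    // Non-MLA model: always gets a V-cache.
    printf("non-MLA, FA on : %llu elements\n",
           (unsigned long long) v_cache_elements(false, false, n_embd_v_gqa, kv_size));
    // MLA model with flash attention (v_trans == false): no separate V-cache.
    printf("MLA,     FA on : %llu elements\n",
           (unsigned long long) v_cache_elements(true,  false, n_embd_v_gqa, kv_size));
    // MLA model without flash attention (v_trans == true): V-cache still needed.
    printf("MLA,     FA off: %llu elements\n",
           (unsigned long long) v_cache_elements(true,  true,  n_embd_v_gqa, kv_size));
    return 0;
}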
