From 2f2fd158ba2d86abb490731e3a7b5349434e916b Mon Sep 17 00:00:00 2001
From: juk
Date: Thu, 12 Jun 2025 12:54:02 +0100
Subject: [PATCH] Revived PR

---
 src/llama-graph.cpp            | 12 ++++++++++--
 src/llama-kv-cache-unified.cpp | 35 +++++++++++++++++++++++++++--------
 2 files changed, 37 insertions(+), 10 deletions(-)

diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index e74c9ff53b05a..d78a8aad3189c 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -1257,7 +1257,11 @@ ggml_tensor * llm_graph_context::build_attn(
     // store to KV cache
     {
         ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
+
+        // note: MLA with flash attention now uses the last 512 elements of the K-cache in place of a V-cache
+        if (!v_mla || !cparams.flash_attn) {
+            ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
+        }
     }
 
     const auto & kq_mask = inp->get_kq_mask();
@@ -1341,7 +1345,11 @@ ggml_tensor * llm_graph_context::build_attn(
     // store to KV cache
     {
         ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
+
+        // note: MLA with flash attention now uses the last 512 elements of the K-cache in place of a V-cache
+        if (!v_mla || !cparams.flash_attn) {
+            ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
+        }
     }
 
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp
index 89606c598fc4f..46338f3f8eb25 100644
--- a/src/llama-kv-cache-unified.cpp
+++ b/src/llama-kv-cache-unified.cpp
@@ -62,6 +62,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
 
     cells.resize(kv_size);
 
+    const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0);
+
     for (uint32_t il = 0; il < hparams.n_layer; il++) {
         if (filter && !filter(il)) {
             LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
@@ -93,7 +95,9 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         ggml_tensor * v;
 
         k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size);
-        v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size);
+
+        // note: MLA with flash attention now uses the last 512 elements of the K-cache in place of a V-cache
+        v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, !is_mla || v_trans ? kv_size : 0);
 
         ggml_format_name(k, "cache_k_l%d", il);
         ggml_format_name(v, "cache_v_l%d", il);
@@ -700,7 +704,9 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
 }
 
 bool llama_kv_cache_unified::get_can_shift() const {
-    return true;
+    const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0);
+
+    return !is_mla || v_trans; // TODO: allow context shifting for MLA with flash attention
 }
 
 uint32_t llama_kv_cache_unified::get_size() const {
@@ -733,12 +739,25 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
     auto * v = layers[ikv].v;
 
     if (!v_trans) {
-        // note: v->nb[1] <= v->nb[2]
-        return ggml_view_3d(ctx, v,
-                hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv,
-                ggml_row_size(v->type, hparams.n_embd_head_v),    // v->nb[1]
-                ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
-                0);
+        const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0);
+
+        if (!is_mla) {
+            // note: v->nb[1] <= v->nb[2]
+            return ggml_view_3d(ctx, v,
+                    hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv,
+                    ggml_row_size(v->type, hparams.n_embd_head_v),    // v->nb[1]
+                    ggml_row_size(v->type, hparams.n_embd_v_gqa(il)), // v->nb[2]
+                    0);
+        } else {
+            auto * k = layers[ikv].k;
+
+            // note: v->nb[1] == v->nb[2] for MLA, as it is transformed into MQA
+            return ggml_view_3d(ctx, k,
+                    hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv,
+                    ggml_row_size(k->type, hparams.n_embd_head_k),    // v->nb[1]
+                    ggml_row_size(k->type, hparams.n_embd_k_gqa(il)), // v->nb[2]
+                    ggml_row_size(k->type, hparams.n_embd_head_k - hparams.n_embd_head_v)); // offset by n_rot elements (in bytes)
+        }
     }
 
     // note: v->nb[1] > v->nb[2]
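
Addendum (reviewer note, not part of the patch): a minimal standalone sketch of the idea behind the new get_v() branch above, showing how the MLA V "cache" can be exposed as a byte-offset ggml view over the K-cache instead of a separately allocated tensor. The dimensions (576 = 512 + 64, i.e. an assumed kv_lora_rank + n_rot split), the single KV head, the toy cache length, and the main() harness are illustrative assumptions, not values or code taken from llama.cpp.

#include "ggml.h"

#include <cstdio>

int main() {
    // assumed MLA head sizes: n_embd_head_k = kv_lora_rank + n_rot, n_embd_head_v = kv_lora_rank
    const int64_t n_embd_head_k = 576;
    const int64_t n_embd_head_v = 512;
    const int64_t n_head_kv     = 1;   // MLA decodes as MQA, so a single KV head
    const int64_t kv_size       = 32;  // toy cache length

    struct ggml_init_params params = {
        /*.mem_size   =*/ 16u*1024*1024,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // K-cache: one row of n_embd_head_k*n_head_kv elements per cached token
    struct ggml_tensor * k = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd_head_k*n_head_kv, kv_size);

    // V "cache": a view over the same buffer that skips the first n_rot elements of each K head,
    // so each V row is the last n_embd_head_v elements of the corresponding K row
    struct ggml_tensor * v = ggml_view_3d(ctx, k,
            n_embd_head_v, n_head_kv, kv_size,
            ggml_row_size(k->type, n_embd_head_k),                  // v->nb[1]
            ggml_row_size(k->type, n_embd_head_k*n_head_kv),        // v->nb[2]
            ggml_row_size(k->type, n_embd_head_k - n_embd_head_v)); // offset by n_rot elements (in bytes)

    printf("k = [%lld, %lld], v view = [%lld, %lld, %lld], offset = %zu bytes\n",
            (long long) k->ne[0], (long long) k->ne[1],
            (long long) v->ne[0], (long long) v->ne[1], (long long) v->ne[2],
            (size_t) ((const char *) v->data - (const char *) k->data));

    ggml_free(ctx);
    return 0;
}

Because the V view aliases the K-cache, storing K for a token also stores its V, which is why build_attn() skips cpy_v() in the MLA + flash-attention case above.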