Commit 0dccd87

Fixed wrong stride. Use v_trans to detect use of FA.
1 parent 125ef32 commit 0dccd87

2 files changed: 7 additions & 6 deletions

src/llama-graph.cpp

Lines changed: 4 additions & 4 deletions
@@ -1432,7 +1432,7 @@ ggml_tensor * llm_graph_context::build_attn(
 
         v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens);
 
-        // note: MLA flash attention now uses the last 512 elements of K in place of V
+        // note: MLA with flash attention now uses the last 512 elements of K in place of V
         if (v_trans || !v_mla) {
             ggml_tensor * v_cache_view = nullptr;
 
@@ -1488,11 +1488,11 @@ ggml_tensor * llm_graph_context::build_attn(
                     ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v),
                     0);
         } else {
-            // note: MLA flash attention now uses the last 512 elements of K in place of V
+            // note: MLA with flash attention now uses the last 512 elements of K in place of V
             v = ggml_view_3d(ctx0, kv_self->k_l[il],
                     n_embd_head_v, n_kv, n_head_kv,
-                    ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
-                    ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v),
+                    ggml_row_size(kv_self->v_l[il]->type, n_embd_k_gqa),
+                    ggml_row_size(kv_self->v_l[il]->type, n_embd_head_k),
                     n_embd_head_k-n_embd_head_v); // offset by n_rot elements
         }
 
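Not part of the patch, but a standalone sketch of the stride arithmetic the second hunk corrects: when V is only a view into the K cache, the byte stride between cache cells has to follow the K row layout (n_embd_k_gqa), not the narrower layout V would have on its own. The sizes below are assumptions picked to match DeepSeek2-style MLA (512-element V heads, n_rot = 64 giving 576-element K heads, a single compressed KV head, f32 cells).

    // sketch only: computes the byte offset of each cache cell inside the K buffer
    // under the old (V-sized) and fixed (K-sized) strides; all sizes are assumptions
    #include <cstddef>
    #include <cstdio>

    int main() {
        const size_t elem_size     = sizeof(float);          // assume an f32 cache
        const size_t n_embd_head_v = 512;                     // last 512 elements of K double as V
        const size_t n_rot         = 64;                      // assumed rotary size for DeepSeek2
        const size_t n_embd_head_k = n_embd_head_v + n_rot;   // 576
        const size_t n_head_kv     = 1;                       // assumed single compressed KV head
        const size_t n_embd_k_gqa  = n_embd_head_k*n_head_kv; // K cache row length
        const size_t n_embd_v_gqa  = n_embd_head_v*n_head_kv; // row length V would have on its own

        const size_t nb1_fixed = elem_size*n_embd_k_gqa;      // stride after the fix
        const size_t nb1_old   = elem_size*n_embd_v_gqa;      // stride before the fix
        const size_t offset    = elem_size*n_rot;             // view starts n_rot elements into each row

        for (size_t i = 0; i < 4; ++i) {
            printf("cell %zu: fixed stride -> byte %zu, old stride -> byte %zu, drift %zu\n",
                   i, offset + i*nb1_fixed, offset + i*nb1_old, i*(nb1_fixed - nb1_old));
        }
        return 0;
    }

With these sizes the old stride falls 256 bytes short per cell, so from the second cache cell onward the view reads from inside the wrong K row; that is the wrong stride the commit title refers to.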

src/llama-kv-cache.cpp

Lines changed: 3 additions & 2 deletions
@@ -33,7 +33,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
     const int32_t n_layer = hparams.n_layer;
 
     has_shift = false;
-    can_shift = !(model.arch == LLM_ARCH_DEEPSEEK2 && cparams.flash_attn);
+    can_shift = model.arch != LLM_ARCH_DEEPSEEK2 || v_trans; // TODO: allow context shifting for MLA with flash attention
 
     LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d, padding = %d\n",
             __func__, kv_size, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift, padding);

@@ -100,8 +100,9 @@ llama_kv_cache_unified::llama_kv_cache_unified(
             throw std::runtime_error("failed to create ggml context for kv cache");
         }
 
+        // note: MLA with flash attention now uses the last 512 elements of K in place of V
         ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
-        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, model.arch == LLM_ARCH_DEEPSEEK2 && cparams.flash_attn ? 0 : n_embd_v_gqa*kv_size);
+        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, model.arch != LLM_ARCH_DEEPSEEK2 || v_trans ? n_embd_v_gqa*kv_size : 0);
         ggml_format_name(k, "cache_k_l%d", i);
         ggml_format_name(v, "cache_v_l%d", i);
         k_l.push_back(k);
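For reference, a standalone sketch of the allocation rule the second hunk encodes, with plain booleans standing in for model.arch == LLM_ARCH_DEEPSEEK2 and the cache's v_trans flag and purely illustrative sizes. The V cache is only stored transposed when flash attention is off, so v_trans doubles as a signal that flash attention is disabled; with MLA plus flash attention the per-layer V tensor can be empty because V is read out of the K cache.

    // sketch only: mirrors the ternary from the hunk above; names and sizes are
    // stand-ins, not the actual llama.cpp constructor code
    #include <cstdint>
    #include <cstdio>

    // elements to reserve for one layer's V cache tensor; with DeepSeek2 MLA and
    // flash attention (v_trans == false) V is a view into K, so nothing is reserved
    static int64_t v_cache_elems(bool is_deepseek2, bool v_trans,
                                 int64_t n_embd_v_gqa, int64_t kv_size) {
        return (!is_deepseek2 || v_trans) ? n_embd_v_gqa*kv_size : 0;
    }

    int main() {
        const int64_t n_embd_v_gqa = 512, kv_size = 4096; // illustrative sizes
        printf("other arch, FA on  (v_trans=false): %lld elements\n", (long long) v_cache_elems(false, false, n_embd_v_gqa, kv_size));
        printf("DeepSeek2,  FA on  (v_trans=false): %lld elements\n", (long long) v_cache_elems(true,  false, n_embd_v_gqa, kv_size));
        printf("DeepSeek2,  FA off (v_trans=true) : %lld elements\n", (long long) v_cache_elems(true,  true,  n_embd_v_gqa, kv_size));
        return 0;
    }

The same predicate feeds can_shift in the first hunk, which is why context shifting stays disabled for the MLA plus flash-attention combination until the TODO is resolved.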
