
Commit 125ef32

Initial commit
1 parent ab3971f commit 125ef32

2 files changed: +34 -21 lines changed

src/llama-graph.cpp

Lines changed: 32 additions & 19 deletions
@@ -1432,21 +1432,24 @@ ggml_tensor * llm_graph_context::build_attn(
 
         v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens);
 
-        ggml_tensor * v_cache_view = nullptr;
+        // note: MLA flash attention now uses the last 512 elements of K in place of V
+        if (v_trans || !v_mla) {
+            ggml_tensor * v_cache_view = nullptr;
 
-        if (!v_trans) {
-            v_cache_view = ggml_view_1d(ctx0, kv_self->v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa)*kv_head);
-        } else {
-            // note: the V cache is transposed when not using flash attention
-            v_cache_view = ggml_view_2d(ctx0, kv_self->v_l[il], n_tokens, n_embd_v_gqa,
-                    (  n_ctx)*ggml_element_size(kv_self->v_l[il]),
-                    (kv_head)*ggml_element_size(kv_self->v_l[il]));
+            if (!v_trans) {
+                v_cache_view = ggml_view_1d(ctx0, kv_self->v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa)*kv_head);
+            } else {
+                // note: the V cache is transposed when not using flash attention
+                v_cache_view = ggml_view_2d(ctx0, kv_self->v_l[il], n_tokens, n_embd_v_gqa,
+                        (  n_ctx)*ggml_element_size(kv_self->v_l[il]),
+                        (kv_head)*ggml_element_size(kv_self->v_l[il]));
 
-            v_cur = ggml_transpose(ctx0, v_cur);
-        }
-        //cb(v_cache_view, "v_cache_view", il);
+                v_cur = ggml_transpose(ctx0, v_cur);
+            }
+            //cb(v_cache_view, "v_cache_view", il);
 
-        ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
+            ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
+        }
     }
 
     const bool is_swa = hparams.is_swa(il);
const bool is_swa = hparams.is_swa(il);
@@ -1471,17 +1474,27 @@ ggml_tensor * llm_graph_context::build_attn(
                 0);
         //cb(k, "k", il);
 
-        ggml_tensor * v = !v_trans ?
-            ggml_view_3d(ctx0, kv_self->v_l[il],
-                    n_embd_head_v, n_kv, n_head_kv,
-                    ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
-                    ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v),
-                    0) :
-            ggml_view_3d(ctx0, kv_self->v_l[il],
+        ggml_tensor * v = nullptr;
+        if (v_trans) {
+            v = ggml_view_3d(ctx0, kv_self->v_l[il],
                     n_kv, n_embd_head_v, n_head_kv,
                     ggml_element_size(kv_self->v_l[il])*n_ctx,
                     ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
                     0);
+        } else if (!v_mla) {
+            v = ggml_view_3d(ctx0, kv_self->v_l[il],
+                    n_embd_head_v, n_kv, n_head_kv,
+                    ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
+                    ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v),
+                    0);
+        } else {
+            // note: MLA flash attention now uses the last 512 elements of K in place of V
+            v = ggml_view_3d(ctx0, kv_self->k_l[il],
+                    n_embd_head_v, n_kv, n_head_kv,
+                    ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
+                    ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v),
+                    n_embd_head_k-n_embd_head_v); // offset by n_rot elements
+        }
 
         ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, v_trans, kq_scale);
         cb(cur, "kqv_out", il);
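For context on the llama-graph.cpp change: a ggml view aliases its parent tensor's memory rather than copying it, so the MLA flash-attention path can hand the attention kernel a V tensor that is really a strided window into the K cache. The standalone sketch below illustrates that aliasing on a toy tensor; the 576/512/64 split mirrors DeepSeek2 MLA head sizes, but the shapes, strides, and offset handling here are illustrative assumptions, not the commit's exact code.

#include "ggml.h"

// Toy illustration: take a per-token "V" view into a "K" buffer by skipping the
// leading RoPE section of each K row, so no separate V tensor needs to be stored.
int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    const int64_t n_embd_head_k = 576; // 64 RoPE elements + 512 compressed-KV elements (DeepSeek2 MLA-like)
    const int64_t n_embd_head_v = 512;
    const int64_t n_kv          = 8;   // toy number of cached tokens
    const int64_t n_head_kv     = 1;   // MLA keeps a single latent KV head

    // K cache for one layer: one row of n_embd_head_k elements per cached token.
    struct ggml_tensor * k = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd_head_k, n_kv*n_head_kv);

    // V is not a copy: it is a strided view into K starting after the RoPE part.
    // ggml_view_3d takes the row/plane strides and a byte offset into the parent tensor.
    struct ggml_tensor * v = ggml_view_3d(ctx, k,
            n_embd_head_v, n_kv, n_head_kv,
            ggml_row_size(k->type, n_embd_head_k),                    // stride between cached tokens
            ggml_row_size(k->type, n_embd_head_k)*n_kv,               // stride between KV heads
            ggml_row_size(k->type, n_embd_head_k - n_embd_head_v));   // skip the RoPE elements

    // The view points into the K buffer; no extra memory was allocated for V.
    GGML_ASSERT(v->data == (char *) k->data + ggml_row_size(k->type, n_embd_head_k - n_embd_head_v));

    ggml_free(ctx);
    return 0;
}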

src/llama-kv-cache.cpp

Lines changed: 2 additions & 2 deletions
@@ -33,7 +33,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
     const int32_t n_layer = hparams.n_layer;
 
     has_shift = false;
-    can_shift = true;
+    can_shift = !(model.arch == LLM_ARCH_DEEPSEEK2 && cparams.flash_attn);
 
     LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d, padding = %d\n",
             __func__, kv_size, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift, padding);
@@ -101,7 +101,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         }
 
         ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
-        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
+        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, model.arch == LLM_ARCH_DEEPSEEK2 && cparams.flash_attn ? 0 : n_embd_v_gqa*kv_size);
         ggml_format_name(k, "cache_k_l%d", i);
         ggml_format_name(v, "cache_v_l%d", i);
         k_l.push_back(k);
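Read together with the graph change, the llama-kv-cache.cpp hunks mean that when DeepSeek2 runs with flash attention the per-layer V cache needs no storage of its own (V is served from the K cache), and K-cache shifting is disabled for that configuration via can_shift. A hypothetical helper capturing the sizing rule; the function name and parameters are illustrative and not part of the commit:

#include <cstdint>

// Hypothetical helper: number of elements to allocate for one layer's V cache.
// Zero when DeepSeek2 MLA runs with flash attention, because V aliases the K cache.
static int64_t kv_cache_v_elements(bool is_deepseek2, bool flash_attn,
                                   int64_t n_embd_v_gqa, int64_t kv_size) {
    if (is_deepseek2 && flash_attn) {
        return 0; // no dedicated V storage; the attention graph views into K instead
    }
    return n_embd_v_gqa*kv_size;
}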
