@@ -1432,8 +1432,8 @@ ggml_tensor * llm_graph_context::build_attn(
 
         v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens);
 
-        // note: MLA with flash attention now uses the last 512 elements of K in place of V
-        if (v_trans || !v_mla) {
+        // note: MLA with flash attention now uses the last 512 elements of K-cache in place of a V-cache
+        if (!v_mla || v_trans) {
             ggml_tensor * v_cache_view = nullptr;
 
             if (!v_trans) {
@@ -1474,7 +1474,9 @@ ggml_tensor * llm_graph_context::build_attn(
                     0);
         // cb(k, "k", il);
 
+        // note: MLA with flash attention now uses the last 512 elements of K-cache in place of a V-cache
         ggml_tensor * v = nullptr;
+
         if (v_trans) {
             v = ggml_view_3d(ctx0, kv_self->v_l[il],
                     n_kv, n_embd_head_v, n_head_kv,
@@ -1487,14 +1489,6 @@ ggml_tensor * llm_graph_context::build_attn(
                     ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
                     ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v),
                     0);
-        } else {
-            // note: MLA with flash attention now uses the last 512 elements of K in place of V
-            v = ggml_view_3d(ctx0, kv_self->k_l[il],
-                    n_embd_head_v, n_kv, n_head_kv,
-                    ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
-                    ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
-                    n_embd_head_k-n_embd_head_v); // offset by n_rot elements
-            v = ggml_cont(ctx0, v);
         }
 
         ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, v_trans, kq_scale);
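For readers following the note above: with MLA, each cached K row holds the RoPE part followed by the compressed-KV part, and that trailing part (the "last 512 elements" referenced in the comment) is exactly what would otherwise be stored as V, so the flash-attention path can treat V as a view into the K cache instead of writing a separate V cache. Below is a minimal standalone sketch of that layout and offset arithmetic, using assumed DeepSeek-style sizes (n_rot = 64, n_embd_head_v = 512) and plain buffers rather than ggml tensors; it is an illustration, not part of the patch.

```cpp
// Standalone illustration (not ggml code): with MLA, each cached K row stores
// n_rot RoPE values followed by n_embd_head_v compressed-KV values, and the
// trailing part doubles as the V row, so no separate V cache is needed.
#include <cstdio>
#include <vector>

int main() {
    // assumed DeepSeek-style MLA head sizes (illustrative only)
    const int n_rot         = 64;                    // RoPE part, K-only
    const int n_embd_head_v = 512;                   // compressed KV, shared with V
    const int n_embd_head_k = n_rot + n_embd_head_v; // 576 values per cached K row
    const int n_kv          = 4;                     // number of cached tokens (toy value)

    // toy "K cache": n_kv rows of n_embd_head_k floats
    std::vector<float> k_cache((size_t) n_kv * n_embd_head_k);
    for (size_t i = 0; i < k_cache.size(); ++i) {
        k_cache[i] = (float) i;
    }

    // a "V row" is the same storage with a per-row offset of n_rot values;
    // the removed ggml_view_3d branch expressed the same idea as a strided
    // view offset by n_embd_head_k - n_embd_head_v (i.e. n_rot) into k_l
    for (int t = 0; t < n_kv; ++t) {
        const float * k_row = k_cache.data() + (size_t) t * n_embd_head_k;
        const float * v_row = k_row + n_rot; // last n_embd_head_v elements of the K row
        printf("token %d: v starts at element %d of its K row, v[0] = %.0f\n",
               t, n_rot, v_row[0]);
    }

    return 0;
}
```

Leaving `v` as `nullptr` in the MLA + flash-attention case (instead of materializing a contiguous copy with `ggml_cont`) defers the K-as-V handling to `build_attn_mha`, which avoids duplicating the compressed KV data.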