@@ -1432,21 +1432,24 @@ ggml_tensor * llm_graph_context::build_attn(

         v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens);

-        ggml_tensor * v_cache_view = nullptr;
+        // note: MLA flash attention now uses the last 512 elements of K in place of V
+        if (v_trans || !v_mla) {
+            ggml_tensor * v_cache_view = nullptr;

-        if (!v_trans) {
-            v_cache_view = ggml_view_1d(ctx0, kv_self->v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa)*kv_head);
-        } else {
-            // note: the V cache is transposed when not using flash attention
-            v_cache_view = ggml_view_2d(ctx0, kv_self->v_l[il], n_tokens, n_embd_v_gqa,
-                    (  n_ctx)*ggml_element_size(kv_self->v_l[il]),
-                    (kv_head)*ggml_element_size(kv_self->v_l[il]));
+            if (!v_trans) {
+                v_cache_view = ggml_view_1d(ctx0, kv_self->v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa)*kv_head);
+            } else {
+                // note: the V cache is transposed when not using flash attention
+                v_cache_view = ggml_view_2d(ctx0, kv_self->v_l[il], n_tokens, n_embd_v_gqa,
+                        (  n_ctx)*ggml_element_size(kv_self->v_l[il]),
+                        (kv_head)*ggml_element_size(kv_self->v_l[il]));

-            v_cur = ggml_transpose(ctx0, v_cur);
-        }
-        // cb(v_cache_view, "v_cache_view", il);
+                v_cur = ggml_transpose(ctx0, v_cur);
+            }
+            // cb(v_cache_view, "v_cache_view", il);

-        ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
+            ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
+        }
     }

     const bool is_swa = hparams.is_swa(il);
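The first hunk makes the V-cache write conditional: with flash attention enabled on an MLA model (`v_trans == false` and `v_mla` non-null), nothing is copied into `v_l[il]` at all, because the value data is already present inside the cached K rows. A minimal standalone sketch of that guard follows; the helper name and the boolean reading of `v_mla` are illustrative, not part of the patch.

```cpp
// Sketch only: the decision the new `if (v_trans || !v_mla)` block encodes.
// v_trans - true when flash attention is off (the V cache is stored transposed)
// v_mla   - MLA output projection tensor, non-null only for MLA models
struct ggml_tensor; // opaque here; the real definition lives in ggml.h

static bool needs_v_cache_write(bool v_trans, const ggml_tensor * v_mla) {
    // Only the MLA + flash-attention combination can read V back out of the
    // K cache, so every other combination still has to copy v_cur into V.
    return v_trans || v_mla == nullptr;
}
```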
@@ -1471,17 +1474,27 @@ ggml_tensor * llm_graph_context::build_attn(
                 0);
         // cb(k, "k", il);

-    ggml_tensor * v = !v_trans ?
-        ggml_view_3d(ctx0, kv_self->v_l[il],
-                n_embd_head_v, n_kv, n_head_kv,
-                ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
-                ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v),
-                0) :
-        ggml_view_3d(ctx0, kv_self->v_l[il],
+    ggml_tensor * v = nullptr;
+    if (v_trans) {
+        v = ggml_view_3d(ctx0, kv_self->v_l[il],
                 n_kv, n_embd_head_v, n_head_kv,
                 ggml_element_size(kv_self->v_l[il])*n_ctx,
                 ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
                 0);
+    } else if (!v_mla) {
+        v = ggml_view_3d(ctx0, kv_self->v_l[il],
+                n_embd_head_v, n_kv, n_head_kv,
+                ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
+                ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v),
+                0);
+    } else {
+        // note: MLA flash attention now uses the last 512 elements of K in place of V
+        v = ggml_view_3d(ctx0, kv_self->k_l[il],
+                n_embd_head_v, n_kv, n_head_kv,
+                ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
+                ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v),
+                n_embd_head_k-n_embd_head_v); // offset by n_rot elements
+    }

     ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, v_trans, kq_scale);
     cb(cur, "kqv_out", il);
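The final `else` branch builds `v` as a view into the K cache (`k_l[il]`) rather than the V cache. Per the patch's own comments, each cached K head stores the RoPE'd `n_rot` elements first, followed by the 512-element latent that doubles as V, so skipping the first `n_embd_head_k - n_embd_head_v` elements of a row lands on exactly the data that would otherwise have been written to V. The worked example below uses DeepSeek-V2/V3-style MLA sizes (`kv_lora_rank = 512`, `n_rot = 64`); these concrete numbers are assumptions for illustration, not values read from this patch.

```cpp
// Worked example of the sizes behind "last 512 elements of K" and
// "offset by n_rot elements". Real values come from the model hparams.
#include <cassert>
#include <cstdio>

int main() {
    const int n_rot         = 64;                   // RoPE'd slice of each cached K head (assumed)
    const int kv_lora_rank  = 512;                  // compressed KV latent per head (assumed)
    const int n_embd_head_k = n_rot + kv_lora_rank; // 576 elements stored per K head
    const int n_embd_head_v = kv_lora_rank;         // 512 elements needed as V

    // The V view into the K cache starts this many elements into each row,
    // i.e. it skips the RoPE'd part and keeps the trailing latent:
    const int offset_elems = n_embd_head_k - n_embd_head_v;
    assert(offset_elems == n_rot); // matches the "offset by n_rot elements" comment

    printf("K head: %d, V head: %d, V-view offset: %d elements\n",
           n_embd_head_k, n_embd_head_v, offset_elems);
    return 0;
}
```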