@@ -1000,13 +1000,13 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");

-        const auto n_kv   = inp->mctx->get_attn()->get_n_kv();
-        const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
+        const auto n_kv     = inp->mctx->get_attn()->get_n_kv();
+        const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;

         inp->self_k_idxs = mctx_cur->get_attn()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->get_attn()->build_input_v_idxs(ctx0, ubatch);

-        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs,   GGML_KQ_MASK_PAD), 1, n_seqs);
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
         ggml_set_input(inp->self_kq_mask);

         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
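The mask tensor above is now laid out per stream: with a unified KV cache there is a single stream covering all ubatch tokens, while with a per-sequence cache the ubatch is divided evenly across `n_stream` streams and each stream gets its own padded token slice. A minimal standalone sketch of that shape arithmetic (not part of this patch; the sizes and the pad granularity are made-up stand-ins, including the substitute for `GGML_KQ_MASK_PAD`):

```cpp
#include <cstdint>
#include <cstdio>

// stand-in for the GGML_PAD(x, GGML_KQ_MASK_PAD) used in llama.cpp; the real pad value may differ
static int64_t pad_to(int64_t x, int64_t n) { return ((x + n - 1) / n) * n; }

int main() {
    const int64_t n_kv        = 256;   // KV cells visible to this ubatch (made up)
    const int64_t n_tokens    = 8;     // tokens in the ubatch
    const int64_t n_seqs_unq  = 4;     // unique sequences in the ubatch
    const bool    kv_unified  = false;
    const int64_t kq_mask_pad = 32;    // assumed pad granularity

    // one stream for a unified cache, one stream per unique sequence otherwise
    const int64_t n_stream = kv_unified ? 1 : n_seqs_unq;

    // shape of self_kq_mask: [n_kv, padded tokens per stream, 1, n_stream]
    const int64_t ne[4] = { n_kv, pad_to(n_tokens/n_stream, kq_mask_pad), 1, n_stream };

    printf("self_kq_mask: [%lld, %lld, %lld, %lld]\n",
           (long long) ne[0], (long long) ne[1], (long long) ne[2], (long long) ne[3]);
    return 0;
}
```

With `kv_unified = true` the same expression collapses back to the old single-stream shape, `[n_kv, pad(n_tokens), 1, 1]`.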
@@ -1033,9 +1033,10 @@ ggml_tensor * llm_graph_context::build_attn_mha(
               float   kq_scale) const {
     const bool v_trans = v->nb[1] > v->nb[2];

-    const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
+    // split the batch into streams if needed
+    const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;

-    q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_seqs, n_seqs);
+    q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream);

     q = ggml_permute(ctx0, q, 0, 2, 1, 3);
     k = ggml_permute(ctx0, k, 0, 2, 1, 3);
@@ -1085,7 +1086,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 #endif
         }

-        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_seqs);
+        // recombine streams
+        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_stream);
     } else {
         ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);

@@ -1130,7 +1132,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(

         cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);

-        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_seqs);
+        // recombine streams
+        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_stream);

         if (!cparams.offload_kqv) {
             // all nodes between the KV store and the attention output are run on the CPU
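Taken together, the `build_attn_mha` hunks are a split/recombine round trip: "split the batch into streams if needed" marks where Q gains a stream dimension, and "recombine streams" marks where the attention output is flattened back to 2-D. A shape-only sketch of that round trip, assuming `ggml.h` is available; the dimensions are invented and the attention itself is elided, so this is an illustration rather than the upstream code:

```cpp
#include "ggml.h"
#include <cstdio>

int main() {
    // metadata-only context: we only build views and print shapes, no data is computed
    ggml_init_params params = { /*.mem_size =*/ 16*1024*1024, /*.mem_buffer =*/ nullptr, /*.no_alloc =*/ true };
    ggml_context * ctx = ggml_init(params);

    const int64_t n_embd_head = 64;
    const int64_t n_head      = 8;
    const int64_t n_stream    = 2;                  // e.g. non-unified cache, 2 unique sequences
    const int64_t n_tokens    = 6;                  // tokens per stream
    const int64_t n_tok_all   = n_tokens*n_stream;  // tokens in the whole ubatch

    // Q as built by the graph: [n_embd_head, n_head, all tokens in the ubatch]
    ggml_tensor * q = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, n_tok_all);

    // split the batch into streams: [n_embd_head, n_head, n_tokens, n_stream]
    q = ggml_reshape_4d(ctx, q, n_embd_head, n_head, n_tok_all/n_stream, n_stream);

    // move heads past tokens: [n_embd_head, n_tokens, n_head, n_stream]
    q = ggml_permute(ctx, q, 0, 2, 1, 3);

    // (attention would run here; its output keeps the [*, n_tokens, n_head, n_stream] layout)
    ggml_tensor * kqv = q;

    // recombine streams: back to [n_embd_head, n_head, n_tokens, n_stream], then flatten to 2-D
    ggml_tensor * cur = ggml_permute(ctx, kqv, 0, 2, 1, 3);
    cur = ggml_cont_2d(ctx, cur, n_embd_head*n_head, n_tok_all);

    printf("cur: [%lld, %lld]\n", (long long) cur->ne[0], (long long) cur->ne[1]);

    ggml_free(ctx);
    return 0;
}
```

Building with `no_alloc = true` keeps this metadata-only, which is enough to see that the final 2-D shape is `[n_embd_head*n_head, n_tokens*n_stream]` no matter how many streams the batch was split into.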
@@ -1207,13 +1210,13 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");

-        const auto n_kv   = mctx_cur->get_n_kv();
-        const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
+        const auto n_kv     = mctx_cur->get_n_kv();
+        const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;

         inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);

-        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs,   GGML_KQ_MASK_PAD), 1, n_seqs);
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
         ggml_set_input(inp->self_kq_mask);

         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1455,15 +1458,15 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif

     auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);

-    const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
+    const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;

     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();

         inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);

-        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs,   GGML_KQ_MASK_PAD), 1, n_seqs);
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
         ggml_set_input(inp->self_kq_mask);

         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1477,7 +1480,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
         inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);

-        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs,   GGML_KQ_MASK_PAD), 1, n_seqs);
+        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
         ggml_set_input(inp->self_kq_mask_swa);

         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;