@@ -1001,10 +1001,10 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");

-        const auto n_kv = inp->mctx->get_attn()->get_n_kv();
+        const auto n_kv   = inp->mctx->get_attn()->get_n_kv();
+        const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;

-        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        inp->self_kq_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs);
         ggml_set_input(inp->self_kq_mask);

         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
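For context, a minimal standalone sketch (not part of the patch) of how the new 3D mask shape compares to the old 2D one: the second dimension shrinks from the whole ubatch to the per-sequence token count, and the unique sequences become the third dimension. All numeric values are illustrative assumptions, the `GGML_PAD` definition here just mirrors ggml's round-up-to-multiple behavior, and 64 stands in for `GGML_KQ_MASK_PAD`.

```cpp
#include <cstdint>
#include <cstdio>

// round x up to a multiple of n (stand-in for ggml's GGML_PAD)
#define GGML_PAD(x, n) (((x) + (n) - 1) / (n) * (n))

int main() {
    const int64_t n_kv     = 1024; // example: KV cells visible to this ubatch
    const int64_t n_tokens = 32;   // example: tokens in the ubatch
    const int64_t n_seqs   = 4;    // example: unique sequences (n_seq_virt > 1 case)

    // old: one mask slab covering the whole ubatch
    printf("2d mask: [%lld, %lld]\n",
           (long long) n_kv, (long long) GGML_PAD(n_tokens, 64));

    // new: one slab per sequence, each padded independently
    printf("3d mask: [%lld, %lld, %lld]\n",
           (long long) n_kv,
           (long long) GGML_PAD(n_tokens/n_seqs, 64),
           (long long) n_seqs);
    return 0;
}
```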
@@ -1206,14 +1206,13 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");

-        const auto n_kv = mctx_cur->get_n_kv();
+        const auto n_kv   = mctx_cur->get_n_kv();
         const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;

         inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
         ggml_set_input(inp->self_kv_idxs);

         inp->self_kq_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs);
-        //cb(inp->self_kq_mask, "KQ_mask", -1);
         ggml_set_input(inp->self_kq_mask);

         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
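The unified builder also feeds `self_kv_idxs`, a 1D I64 graph input with one entry per ubatch token. As a hedged sketch of how such an input is typically populated from the host once the graph is allocated: the helper name and the identity mapping below are placeholders for illustration, not the actual llama.cpp fill logic.

```cpp
#include "ggml.h"
#include "ggml-backend.h"
#include <cstdint>
#include <vector>

// hypothetical helper: write one KV-cell index per ubatch token into
// an allocated I64 input tensor such as self_kv_idxs
void set_kv_idxs(struct ggml_tensor * self_kv_idxs, int64_t n_tokens) {
    std::vector<int64_t> idxs(n_tokens);
    for (int64_t i = 0; i < n_tokens; ++i) {
        idxs[i] = i; // placeholder: real code maps each token to its destination cell
    }
    // copy the host buffer into the backend tensor (offset 0, full size)
    ggml_backend_tensor_set(self_kv_idxs, idxs.data(), 0, n_tokens*sizeof(int64_t));
}
```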
@@ -1440,14 +1439,15 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa()

     auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);

+    const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
+
     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();

         inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
         ggml_set_input(inp->self_kv_idxs);

-        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        inp->self_kq_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs);
         ggml_set_input(inp->self_kq_mask);

         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1461,8 +1461,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa()
         inp->self_kv_idxs_swa = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
         ggml_set_input(inp->self_kv_idxs_swa);

-        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
+        inp->self_kq_mask_swa = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs);
         ggml_set_input(inp->self_kq_mask_swa);

         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
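Since both the base and SWA masks now carry a per-sequence third dimension (with `n_seqs` hoisted above the two blocks so they share it), downstream code can address one sequence's mask as a 2D view. A hypothetical helper, not part of the patch, using only standard ggml view arithmetic:

```cpp
#include "ggml.h"

// hypothetical helper: 2D view of sequence s's slab inside a 3D KQ mask
// shaped [n_kv, padded tokens-per-seq, n_seqs]
static struct ggml_tensor * kq_mask_view_seq(
        struct ggml_context * ctx,
        struct ggml_tensor  * mask,
        int64_t               s) {   // slab index, 0 <= s < n_seqs
    return ggml_view_2d(ctx, mask,
            mask->ne[0],             // n_kv
            mask->ne[1],             // padded tokens per sequence
            mask->nb[1],             // row stride is unchanged within a slab
            s*mask->nb[2]);          // byte offset of slab s
}
```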