Commit 165d822

graph : support iSWA virtual sequences
ggml-ci
1 parent 1b74b9d · commit 165d822

1 file changed (+8, -9)

src/llama-graph.cpp

Lines changed: 8 additions & 9 deletions
@@ -1001,10 +1001,10 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
 
-        const auto n_kv = inp->mctx->get_attn()->get_n_kv();
+        const auto n_kv   = inp->mctx->get_attn()->get_n_kv();
+        const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
 
-        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        inp->self_kq_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1206,14 +1206,13 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
 
-        const auto n_kv = mctx_cur->get_n_kv();
+        const auto n_kv   = mctx_cur->get_n_kv();
         const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
 
         inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
         ggml_set_input(inp->self_kv_idxs);
 
         inp->self_kq_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs);
-        //cb(inp->self_kq_mask, "KQ_mask", -1);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1440,14 +1439,15 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
 
+    const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
+
     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();
 
         inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
         ggml_set_input(inp->self_kv_idxs);
 
-        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        inp->self_kq_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1461,8 +1461,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
         inp->self_kv_idxs_swa = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
         ggml_set_input(inp->self_kv_idxs_swa);
 
-        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
+        inp->self_kq_mask_swa = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs);
         ggml_set_input(inp->self_kq_mask_swa);
 
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
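Editorial note (not part of the commit): the recurring change in every hunk above is that a 2D self_kq_mask of shape [n_kv, PAD(n_tokens)] becomes a 3D tensor of shape [n_kv, PAD(n_tokens/n_seqs), n_seqs], where n_seqs is ubatch.n_seqs_unq when cparams.n_seq_virt > 1 and 1 otherwise, so each virtual sequence gets its own mask slice over an evenly split ubatch. The standalone C++ sketch below only mirrors that shape arithmetic; the sizes and the pad constant are assumed example values, not taken from llama.cpp.

    // shape_sketch.cpp - standalone illustration of the KQ mask shapes (assumed sizes)
    #include <cstdint>
    #include <cstdio>

    // same rounding as ggml's GGML_PAD(x, n) for power-of-two n
    static int64_t pad_to(int64_t x, int64_t n) {
        return ((x + n - 1) / n) * n;
    }

    int main() {
        const int64_t n_kv     = 4096; // example number of KV cache cells (assumed)
        const int64_t n_tokens = 512;  // example ubatch size (assumed)
        const int64_t kq_pad   = 64;   // stand-in for GGML_KQ_MASK_PAD (assumed)

        // before this commit: one 2D mask for the whole ubatch -> [n_kv, PAD(n_tokens)]
        printf("2D mask: [%lld, %lld]\n",
               (long long) n_kv, (long long) pad_to(n_tokens, kq_pad));

        // after this commit: with n_seq_virt > 1, the mask becomes
        // [n_kv, PAD(n_tokens/n_seqs), n_seqs] -> one slice per unique sequence
        const int64_t n_seqs = 4; // example ubatch.n_seqs_unq (assumed)
        printf("3D mask: [%lld, %lld, %lld]\n",
               (long long) n_kv,
               (long long) pad_to(n_tokens / n_seqs, kq_pad),
               (long long) n_seqs);

        return 0;
    }

With n_seqs == 1 the third dimension degenerates to 1 and the second dimension is the old PAD(n_tokens), which is why the non-virtual path keeps its previous behavior.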
