
Commit 567b16c

ggerganov authored and Minh141120 committed
graph : prepare for 4D mask (ggml-org#14515)
ggml-ci
1 parent 732d0ed · commit 567b16c

File tree

2 files changed, +18 -20 lines changed

src/llama-graph.cpp

Lines changed: 6 additions & 8 deletions
@@ -1009,8 +1009,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     inp->self_k_idxs = mctx_cur->get_attn()->build_input_k_idxs(ctx0, ubatch);
     inp->self_v_idxs = mctx_cur->get_attn()->build_input_v_idxs(ctx0, ubatch);
 
-    inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-    //cb(inp->self_kq_mask, "KQ_mask", -1);
+    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
     ggml_set_input(inp->self_kq_mask);
 
     inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;

@@ -1147,8 +1146,7 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
     auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
 
     // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
-    inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-    //cb(inp_kq_mask, "KQ_mask", -1);
+    inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
     ggml_set_input(inp->kq_mask);
 
     inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;

@@ -1213,7 +1211,7 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
     inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
     inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
 
-    inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
     ggml_set_input(inp->self_kq_mask);
 
     inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;

@@ -1347,7 +1345,7 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
 
     const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
 
-    inp->cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+    inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
     ggml_set_input(inp->cross_kq_mask);
 
     inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;

@@ -1461,7 +1459,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
     inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
     inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
 
-    inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+    inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
     ggml_set_input(inp->self_kq_mask);
 
     inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;

@@ -1475,7 +1473,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
     inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
     inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
 
-    inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+    inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
     ggml_set_input(inp->self_kq_mask_swa);
 
     inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
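
The functional change in llama-graph.cpp is deliberately mechanical: every KQ mask that used to be allocated with ggml_new_tensor_2d is now allocated with ggml_new_tensor_4d and explicit trailing dimensions of 1, and the stale //cb(...) debug lines are dropped. Because ggml tensors always carry four dimensions internally (unused ones are set to 1), this does not change the mask layout today; it only reserves the third and fourth dimensions for a real 4D mask later. A minimal sketch of that equivalence follows; it is not part of the commit, and the sizes and memory budget are illustrative assumptions:

// Sketch only: shows why the 2D -> 4D switch is currently a no-op in ggml.
#include "ggml.h"

static void kq_mask_shape_demo(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 1024*1024, // illustrative; metadata only
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,      // build tensor metadata without allocating data
    };
    struct ggml_context * ctx = ggml_init(params);

    const int64_t n_kv     = 256; // hypothetical KV length
    const int64_t n_tokens = 32;  // hypothetical (already padded) batch size

    struct ggml_tensor * mask_2d = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_kv, n_tokens);
    struct ggml_tensor * mask_4d = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens, 1, 1);

    // Identical shapes today: both are [n_kv, n_tokens, 1, 1].
    GGML_ASSERT(ggml_are_same_shape(mask_2d, mask_4d));

    // The 4D call only makes ne[2]/ne[3] explicit, so they can later hold e.g. a
    // per-head or per-sequence dimension without touching these call sites again.
    ggml_free(ctx);
}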

src/llama-graph.h

Lines changed: 12 additions & 12 deletions
@@ -229,8 +229,8 @@ class llm_graph_input_attn_no_cache : public llm_graph_input_i {
 
     ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }
 
-    ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch]
-    ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch]
+    ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch, 1, 1]
+    ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch, 1, 1]
 
     const llama_hparams & hparams;
     const llama_cparams & cparams;

@@ -258,8 +258,8 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
     ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
     ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
 
-    ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
-    ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
+    ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
+    ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]
 
     const llama_hparams & hparams;
     const llama_cparams & cparams;

@@ -294,10 +294,10 @@ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
     ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
     ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch]
 
-    ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
-    ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
-    ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch]
-    ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch]
+    ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
+    ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]
+    ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch, 1, 1]
+    ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch, 1, 1]
 
     const llama_hparams & hparams;
     const llama_cparams & cparams;

@@ -314,8 +314,8 @@ class llm_graph_input_attn_cross : public llm_graph_input_i {
 
     ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }
 
-    ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch]
-    ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch]
+    ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
+    ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
 
     const llama_cross * cross = nullptr;
 };

@@ -344,8 +344,8 @@ class llm_graph_input_mem_hybrid : public llm_graph_input_i {
     ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
     ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
 
-    ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
-    ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
+    ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
+    ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]
 
     const llama_hparams & hparams;
     const llama_cparams & cparams;
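
The header changes are comment-only: each mask member keeps its type and name, and only the documented shape grows from [n_kv, n_batch] (or [n_tokens, n_batch] / [n_outputs_enc, n_batch]) to the explicit 4D form, matching the new allocations above. For context, here is a hedged sketch of where such a mask is consumed; the helper name apply_kq_mask and the exact kq shape are illustrative assumptions, while ggml_soft_max_ext is the standard ggml call llama.cpp uses for masked attention:

// Hypothetical helper (not from this commit). Today a single [n_kv, n_batch, 1, 1] mask
// is shared by all attention heads; a genuine 4D mask would populate ne[2]/ne[3] instead.
static struct ggml_tensor * apply_kq_mask(
        struct ggml_context * ctx,
        struct ggml_tensor  * kq,       // F32 [n_kv, n_tokens, n_head, 1] attention scores (assumed shape)
        struct ggml_tensor  * kq_mask,  // F32 [n_kv, padded n_batch, 1, 1]; 0.0f or -INFINITY per position
        float                 kq_scale, // typically 1/sqrt(n_embd_head)
        float                 max_bias) {
    // softmax(kq*kq_scale + mask), with ALiBi slopes applied when max_bias > 0
    return ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, max_bias);
}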

0 commit comments
