@@ -1009,8 +1009,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
1009
1009
inp->self_k_idxs = mctx_cur->get_attn ()->build_input_k_idxs (ctx0, ubatch);
1010
1010
inp->self_v_idxs = mctx_cur->get_attn ()->build_input_v_idxs (ctx0, ubatch);
1011
1011
1012
- inp->self_kq_mask = ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_kv, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD));
1013
- // cb(inp->self_kq_mask, "KQ_mask", -1);
1012
+ inp->self_kq_mask = ggml_new_tensor_4d (ctx0, GGML_TYPE_F32, n_kv, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD), 1 , 1 );
1014
1013
ggml_set_input (inp->self_kq_mask );
1015
1014
1016
1015
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast (ctx0, inp->self_kq_mask , GGML_TYPE_F16) : inp->self_kq_mask ;
@@ -1147,8 +1146,7 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
1147
1146
auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
1148
1147
1149
1148
// note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
1150
- inp->kq_mask = ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD));
1151
- // cb(inp_kq_mask, "KQ_mask", -1);
1149
+ inp->kq_mask = ggml_new_tensor_4d (ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD), 1 , 1 );
1152
1150
ggml_set_input (inp->kq_mask );
1153
1151
1154
1152
inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast (ctx0, inp->kq_mask , GGML_TYPE_F16) : inp->kq_mask ;
@@ -1213,7 +1211,7 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
1213
1211
inp->self_k_idxs = mctx_cur->build_input_k_idxs (ctx0, ubatch);
1214
1212
inp->self_v_idxs = mctx_cur->build_input_v_idxs (ctx0, ubatch);
1215
1213
1216
- inp->self_kq_mask = ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_kv, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD));
1214
+ inp->self_kq_mask = ggml_new_tensor_4d (ctx0, GGML_TYPE_F32, n_kv, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD), 1 , 1 );
1217
1215
ggml_set_input (inp->self_kq_mask );
1218
1216
1219
1217
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast (ctx0, inp->self_kq_mask , GGML_TYPE_F16) : inp->self_kq_mask ;
@@ -1347,7 +1345,7 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
1347
1345
1348
1346
const int32_t n_enc = !cross->v_embd .empty () ? cross->n_enc : hparams.n_ctx_train ;
1349
1347
1350
- inp->cross_kq_mask = ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_enc, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD));
1348
+ inp->cross_kq_mask = ggml_new_tensor_4d (ctx0, GGML_TYPE_F32, n_enc, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD), 1 , 1 );
1351
1349
ggml_set_input (inp->cross_kq_mask );
1352
1350
1353
1351
inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast (ctx0, inp->cross_kq_mask , GGML_TYPE_F16) : inp->cross_kq_mask ;
@@ -1461,7 +1459,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
1461
1459
inp->self_k_idxs = mctx_cur->get_base ()->build_input_k_idxs (ctx0, ubatch);
1462
1460
inp->self_v_idxs = mctx_cur->get_base ()->build_input_v_idxs (ctx0, ubatch);
1463
1461
1464
- inp->self_kq_mask = ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_kv, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD));
1462
+ inp->self_kq_mask = ggml_new_tensor_4d (ctx0, GGML_TYPE_F32, n_kv, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD), 1 , 1 );
1465
1463
ggml_set_input (inp->self_kq_mask );
1466
1464
1467
1465
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast (ctx0, inp->self_kq_mask , GGML_TYPE_F16) : inp->self_kq_mask ;
@@ -1475,7 +1473,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
1475
1473
inp->self_k_idxs_swa = mctx_cur->get_swa ()->build_input_k_idxs (ctx0, ubatch);
1476
1474
inp->self_v_idxs_swa = mctx_cur->get_swa ()->build_input_v_idxs (ctx0, ubatch);
1477
1475
1478
- inp->self_kq_mask_swa = ggml_new_tensor_2d (ctx0, GGML_TYPE_F32, n_kv, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD));
1476
+ inp->self_kq_mask_swa = ggml_new_tensor_4d (ctx0, GGML_TYPE_F32, n_kv, GGML_PAD (n_tokens, GGML_KQ_MASK_PAD), 1 , 1 );
1479
1477
ggml_set_input (inp->self_kq_mask_swa );
1480
1478
1481
1479
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast (ctx0, inp->self_kq_mask_swa , GGML_TYPE_F16) : inp->self_kq_mask_swa ;
0 commit comments