@@ -281,19 +281,22 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
-    if (self_kq_mask) {
-        mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
-    }
+    mctx->set_input_k_idxs(self_k_idxs, ubatch);
+    mctx->set_input_v_idxs(self_v_idxs, ubatch);
+
+    mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
 }
 
 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
-    if (self_kq_mask) {
-        mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
-    }
+    mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
+    mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
 
-    if (self_kq_mask_swa) {
-        mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
-    }
+    mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+
+    mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
+    mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
+
+    mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
 }
 
 void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
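The set_input() overrides above drop the old null checks on the masks and gain two new per-cache inputs, self_k_idxs and self_v_idxs: for every token in the ubatch they carry the destination cell in the KV cache, turning the KV write into a data-dependent scatter rather than a copy into a fixed view. Below is a minimal sketch of the host-side fill, assuming an I64 one-index-per-token layout; the function name is invented for illustration and the hypothetical "head + i" stands in for whatever cell-assignment logic the unified cache actually uses.

    #include <cstdint>
    #include <vector>
    #include "ggml.h"
    #include "ggml-backend.h"

    // Sketch only: write one destination cell index per ubatch token into
    // the I64 graph input. "head + i" assumes a contiguous slot; the point
    // of the index input is that the cells do not have to be contiguous.
    static void set_input_k_idxs_sketch(ggml_tensor * dst, uint32_t head, int64_t n_tokens) {
        std::vector<int64_t> idxs(n_tokens);
        for (int64_t i = 0; i < n_tokens; ++i) {
            idxs[i] = (int64_t) head + i; // assumed destination cell for token i
        }
        ggml_backend_tensor_set(dst, idxs.data(), 0, n_tokens*sizeof(int64_t));
    }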
@@ -333,9 +336,10 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
-    if (self_kq_mask) {
-        mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
-    }
+    mctx->get_attn()->set_input_k_idxs(self_k_idxs, ubatch);
+    mctx->get_attn()->set_input_v_idxs(self_v_idxs, ubatch);
+
+    mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
 
     const int64_t n_rs = mctx->get_recr()->get_n_rs();
 
@@ -350,7 +354,8 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-void llm_graph_input_one::set_input(const llama_ubatch *) {
+void llm_graph_input_one::set_input(const llama_ubatch * ubatch) {
+    GGML_UNUSED(ubatch);
     GGML_ASSERT(one && ggml_nelements(one) == 1);
     float f_one = 1.0f;
     ggml_backend_tensor_set(one, &f_one, 0, sizeof(float));
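The llm_graph_input_one change is mechanical: the previously anonymous parameter is now named and explicitly discarded. GGML_UNUSED is ggml's cast-to-void macro, which silences unused-parameter warnings while keeping the name visible in the signature:

    // from ggml.h
    #define GGML_UNUSED(x) (void)(x)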
@@ -997,6 +1002,9 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
 
     const auto n_kv = inp->mctx->get_attn()->get_n_kv();
 
+    inp->self_k_idxs = mctx_cur->get_attn()->build_input_k_idxs(ctx0, ubatch);
+    inp->self_v_idxs = mctx_cur->get_attn()->build_input_v_idxs(ctx0, ubatch);
+
     inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
     //cb(inp->self_kq_mask, "KQ_mask", -1);
     ggml_set_input(inp->self_kq_mask);
@@ -1198,8 +1206,10 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
 
     const auto n_kv = mctx_cur->get_n_kv();
 
+    inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
+    inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
+
     inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-    //cb(inp->self_kq_mask, "KQ_mask", -1);
     ggml_set_input(inp->self_kq_mask);
 
     inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
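Each build_* site now asks the cache to create the index inputs right before the mask (the stale //cb() debug-callback comment is dropped while touching the line). A sketch of what such a helper plausibly returns, assuming one I64 entry per ubatch token; the real build_input_k_idxs() belongs to the unified KV cache and its exact shape may differ:

    // Hypothetical stand-in for the cache-side helper: an I64 tensor with
    // one destination index per token, flagged as a graph input so that
    // set_input_k_idxs() can fill it once the ubatch is scheduled.
    static ggml_tensor * build_input_k_idxs_sketch(ggml_context * ctx0, const llama_ubatch & ubatch) {
        ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, ubatch.n_tokens);
        ggml_set_input(k_idxs);
        return k_idxs;
    }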
@@ -1230,8 +1240,11 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+        const auto & k_idxs = inp->get_k_idxs();
+        const auto & v_idxs = inp->get_v_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }
 
     const auto & kq_mask = inp->get_kq_mask();
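With the indices threaded through, cpy_k()/cpy_v() take the index tensor in addition to the layer number, and the cache write becomes a row scatter. A hedged sketch of the idea, assuming ggml_set_rows(dst, src, idxs) semantics; the real cpy_k() also deals with reshaping and with the transposed V layout used when flash attention is disabled:

    // Scatter one K row per token into the layer's cache tensor at the
    // cells named by k_idxs, instead of ggml_cpy() into a view at a fixed
    // head offset.
    static ggml_tensor * cpy_k_sketch(ggml_context * ctx0,
            ggml_tensor * k_cache,   // K cache of layer il, one row per cell
            ggml_tensor * k_cur,     // current K, one row per ubatch token
            ggml_tensor * k_idxs) {  // I64 destination cells, one per token
        return ggml_set_rows(ctx0, k_cache, k_cur, k_idxs);
    }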
@@ -1290,11 +1303,15 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // optionally store to KV cache
     if (k_cur) {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+        const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
     }
 
     if (v_cur) {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+        const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }
 
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
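In the iSWA overload each layer writes to exactly one of the two caches, so the invariant is that the index stream, the mask, and the target cache all agree; picking get_k_idxs() together with the SWA mask (or vice versa) would scatter rows into the wrong cache. The is_swa flag is presumably the per-layer property (hparams.is_swa(il) in llama.cpp):

    // Index stream and mask must come from the same cache:
    const auto & k_idxs  = is_swa ? inp->get_k_idxs_swa()  : inp->get_k_idxs();
    const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();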
@@ -1398,8 +1415,11 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
+        const auto & k_idxs = inp->get_k_idxs();
+        const auto & v_idxs = inp->get_v_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }
 
     const auto & kq_mask = inp->get_kq_mask();
@@ -1434,8 +1454,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();
 
+        inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
+        inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
+
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask, "KQ_mask", -1);
         ggml_set_input(inp->self_kq_mask);
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1446,8 +1468,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
 
         const auto n_kv = mctx_cur->get_swa()->get_n_kv();
 
+        inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
+        inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
+
         inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
         ggml_set_input(inp->self_kq_mask_swa);
 
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;