@@ -281,8 +281,12 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
-    if (self_kv_idxs) {
-        mctx->set_input_kv_idxs(self_kv_idxs, ubatch);
+    if (self_k_idxs) {
+        mctx->set_input_k_idxs(self_k_idxs, ubatch);
+    }
+
+    if (self_v_idxs) {
+        mctx->set_input_v_idxs(self_v_idxs, ubatch);
     }
 
     if (self_kq_mask) {
@@ -291,12 +295,20 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
-    if (self_kv_idxs) {
-        mctx->get_base()->set_input_kv_idxs(self_kv_idxs, ubatch);
+    if (self_k_idxs) {
+        mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
+    }
+
+    if (self_v_idxs) {
+        mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
+    }
+
+    if (self_k_idxs_swa) {
+        mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
     }
 
-    if (self_kv_idxs_swa) {
-        mctx->get_swa()->set_input_kv_idxs(self_kv_idxs_swa, ubatch);
+    if (self_v_idxs_swa) {
+        mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
     }
 
     if (self_kq_mask) {
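The set_input overrides above now feed two index tensors per cache instead of one, presumably because K and V destinations can diverge (for example if the V cache is stored transposed), so each gets its own input. A rough sketch of what a host-side set_input_k_idxs might do, assuming tokens land in consecutive slots starting at the cache head; only the method name and call shape come from this diff, the body is an assumption:

```cpp
// Hedged sketch of set_input_k_idxs: fill the I64 input tensor with one
// destination cache slot per token. The consecutive-from-`head` layout is
// an assumption; the real slot assignment may differ.
void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));

    int64_t * data = (int64_t *) dst->data;
    for (uint32_t i = 0; i < ubatch->n_tokens; ++i) {
        data[i] = head + i; // hypothetical: one cache slot per token
    }
}
```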
@@ -1209,8 +1221,8 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
     const auto n_kv = mctx_cur->get_n_kv();
     const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
 
-    inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
-    ggml_set_input(inp->self_kv_idxs);
+    inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
+    inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
 
     inp->self_kq_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs);
     ggml_set_input(inp->self_kq_mask);
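Index-tensor creation now delegates to the memory context instead of being inlined in the graph builder. A minimal sketch of what build_input_k_idxs could encapsulate, assuming it simply wraps the code it replaces; the signature is inferred from the call sites and is not confirmed by this patch:

```cpp
// Hedged sketch, assuming build_input_k_idxs mirrors the inline tensor
// creation it replaces: one I64 destination index per token in the ubatch.
ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
    ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, ubatch.n_tokens);
    ggml_set_input(k_idxs);
    return k_idxs;
}
```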
@@ -1243,10 +1255,11 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-        const auto & kv_idxs = inp->get_kv_idxs();
+        const auto & k_idxs = inp->get_k_idxs();
+        const auto & v_idxs = inp->get_v_idxs();
 
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, kv_idxs, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, kv_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }
 
     const auto & kq_mask = inp->get_kq_mask();
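The cpy_k/cpy_v nodes consume these indices when writing into the cache. For reference, a standalone illustration of the underlying scatter primitive, ggml_set_rows, with made-up shapes; the real cpy_k/cpy_v plumbing likely views the per-layer cache buffers before scattering:

```cpp
#include "ggml.h"

// Standalone illustration (not llama.cpp code): ggml_set_rows scatters the
// rows of `src` into `dst` at the slots named by an I64 index tensor, which
// is how graph nodes like cpy_k/cpy_v can place new K/V rows into the cache.
static ggml_tensor * scatter_example(ggml_context * ctx) {
    ggml_tensor * dst  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 128, 1024); // cache: 1024 slots of 128 elems
    ggml_tensor * src  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 128, 4);    // 4 new rows for this ubatch
    ggml_tensor * idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, 4);         // destination slot per row
    ggml_set_input(idxs); // filled on the host each ubatch, as in set_input above
    return ggml_set_rows(ctx, dst, src, idxs); // result: dst with the 4 rows replaced
}
```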
@@ -1299,10 +1312,11 @@ ggml_tensor * llm_graph_context::build_attn(
 
     // store to KV cache
     {
-        const auto & kv_idxs = is_swa ? inp->get_kv_idxs_swa() : inp->get_kv_idxs();
+        const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();
+        const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();
 
-        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, kv_idxs, il));
-        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, kv_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
     }
 
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
@@ -1444,8 +1458,8 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa()
     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();
 
-        inp->self_kv_idxs = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
-        ggml_set_input(inp->self_kv_idxs);
+        inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
+        inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
 
         inp->self_kq_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs);
         ggml_set_input(inp->self_kq_mask);
@@ -1458,8 +1472,8 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa()
 
         const auto n_kv = mctx_cur->get_swa()->get_n_kv();
 
-        inp->self_kv_idxs_swa = ggml_new_tensor_1d(ctx0, GGML_TYPE_I64, n_tokens);
-        ggml_set_input(inp->self_kv_idxs_swa);
+        inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
+        inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
 
         inp->self_kq_mask_swa = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), n_seqs);
         ggml_set_input(inp->self_kq_mask_swa);
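The iSWA builder creates a separate pair of index tensors for the base and SWA caches, since each has its own n_kv and slot layout. A hedged sketch of the matching build_input_v_idxs, assuming the simple case where V rows are laid out like K rows; a transposed V store would need element-granular destinations instead, which would explain why V gets its own index tensor rather than reusing K's:

```cpp
// Hedged sketch of build_input_v_idxs under the assumption that V is stored
// row-per-token like K. If the V cache were transposed, the index tensor
// would need a different size/meaning (assumption, not shown in this patch).
ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
    ggml_tensor * v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, ubatch.n_tokens);
    ggml_set_input(v_idxs);
    return v_idxs;
}
```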