@@ -413,13 +413,6 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
413
413
}
414
414
}
415
415
416
- llm_graph_input_attn_kv_hybrid_recurrent::llm_graph_input_attn_kv_hybrid_recurrent (
417
- const llama_hparams & hparams,
418
- const llama_cparams & cparams,
419
- const llama_kv_cache_hybrid_recurrent_state * kv_state) :
420
- llm_graph_input_attn_kv_unified(hparams, cparams, kv_state->get_state_attn ()) {
421
- }
422
-
423
416
//
424
417
// llm_graph_context
425
418
//
@@ -1295,7 +1288,9 @@ ggml_tensor * llm_graph_context::build_attn(
1295
1288
ggml_build_forward_expand (gf, k_cur);
1296
1289
ggml_build_forward_expand (gf, v_cur);
1297
1290
1298
- const auto * kv_state = static_cast <const llama_kv_cache_unified_state *>(mstate);
1291
+ // NOTE: For hybrid caches, this may be a child of mstate, so we use the one
1292
+ // encapsulated in inp
1293
+ const auto * kv_state = inp->kv_state ;
1299
1294
1300
1295
// store to KV cache
1301
1296
{
@@ -1327,10 +1322,10 @@ ggml_tensor * llm_graph_context::build_attn(
1327
1322
return cur;
1328
1323
}
1329
1324
1330
- llm_graph_input_attn_kv_hybrid_recurrent * llm_graph_context::build_attn_inp_kv_hybrid_recurrent () const {
1325
+ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_hybrid_recurrent () const {
1331
1326
const auto * kv_state = static_cast <const llama_kv_cache_hybrid_recurrent_state *>(mstate);
1332
1327
1333
- auto inp = std::make_unique<llm_graph_input_attn_kv_hybrid_recurrent >(hparams, cparams, kv_state);
1328
+ auto inp = std::make_unique<llm_graph_input_attn_kv_unified >(hparams, cparams, kv_state-> get_state_attn () );
1334
1329
1335
1330
{
1336
1331
GGML_ASSERT (hparams.swa_type == LLAMA_SWA_TYPE_NONE && " Hybrid recurrent is not supported with SWA attention layers" );
@@ -1344,25 +1339,7 @@ llm_graph_input_attn_kv_hybrid_recurrent * llm_graph_context::build_attn_inp_kv_
1344
1339
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast (ctx0, inp->self_kq_mask , GGML_TYPE_F16) : inp->self_kq_mask ;
1345
1340
}
1346
1341
1347
- return (llm_graph_input_attn_kv_hybrid_recurrent *) res->add_input (std::move (inp));
1348
- }
1349
-
1350
- ggml_tensor * llm_graph_context::build_attn (
1351
- llm_graph_input_attn_kv_hybrid_recurrent * inp,
1352
- ggml_cgraph * gf,
1353
- ggml_tensor * wo,
1354
- ggml_tensor * wo_b,
1355
- ggml_tensor * q_cur,
1356
- ggml_tensor * k_cur,
1357
- ggml_tensor * v_cur,
1358
- ggml_tensor * kq_b,
1359
- ggml_tensor * v_mla,
1360
- float kq_scale,
1361
- int il) const {
1362
- return build_attn (
1363
- static_cast <llm_graph_input_attn_kv_unified *>(inp),
1364
- gf, wo, wo_b, q_cur, k_cur, v_cur, kq_b, v_mla, kq_scale, il
1365
- );
1342
+ return (llm_graph_input_attn_kv_unified *) res->add_input (std::move (inp));
1366
1343
}
1367
1344
1368
1345
llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa () const {
@@ -1505,13 +1482,17 @@ ggml_tensor * llm_graph_context::build_attn(
1505
1482
}
1506
1483
1507
1484
ggml_tensor * llm_graph_context::build_copy_mask_state (
1508
- ggml_cgraph * gf,
1509
- ggml_tensor * s,
1510
- ggml_tensor * state_copy,
1511
- ggml_tensor * state_mask,
1512
- int32_t n_state,
1513
- int32_t n_seqs) const {
1514
- const auto * kv_state = static_cast <const llama_kv_cache_recurrent_state *>(mstate);
1485
+ ggml_cgraph * gf,
1486
+ ggml_tensor * s,
1487
+ ggml_tensor * state_copy,
1488
+ ggml_tensor * state_mask,
1489
+ int32_t n_state,
1490
+ int32_t n_seqs,
1491
+ const llama_kv_cache_recurrent_state * kv_state) const {
1492
+
1493
+ if (kv_state == nullptr ) {
1494
+ kv_state = static_cast <const llama_kv_cache_recurrent_state *>(mstate);
1495
+ }
1515
1496
1516
1497
const auto n_kv = kv_state->get_n_kv ();
1517
1498
const auto kv_head = kv_state->get_head ();
0 commit comments