@@ -413,13 +413,6 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-llm_graph_input_attn_kv_hybrid_recurrent::llm_graph_input_attn_kv_hybrid_recurrent(
-        const llama_hparams & hparams,
-        const llama_cparams & cparams,
-        const llama_kv_cache_hybrid_recurrent_state * kv_state) :
-    llm_graph_input_attn_kv_unified(hparams, cparams, kv_state->get_state_attn()) {
-}
-
 //
 // llm_graph_context
 //
@@ -1294,7 +1287,9 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
+    // NOTE: For hybrid caches, this may be a child of mstate, so we use the one
+    //  encapsulated in inp
+    const auto * kv_state = inp->kv_state;
 
     // store to KV cache
     {
@@ -1326,10 +1321,10 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-llm_graph_input_attn_kv_hybrid_recurrent * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const {
+llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const {
     const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
 
-    auto inp = std::make_unique<llm_graph_input_attn_kv_hybrid_recurrent>(hparams, cparams, kv_state);
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_state->get_state_attn());
 
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
@@ -1343,25 +1338,7 @@ llm_graph_input_attn_kv_hybrid_recurrent * llm_graph_context::build_attn_inp_kv_
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
     }
 
-    return (llm_graph_input_attn_kv_hybrid_recurrent *) res->add_input(std::move(inp));
-}
-
-ggml_tensor * llm_graph_context::build_attn(
-        llm_graph_input_attn_kv_hybrid_recurrent * inp,
-        ggml_cgraph * gf,
-        ggml_tensor * wo,
-        ggml_tensor * wo_b,
-        ggml_tensor * q_cur,
-        ggml_tensor * k_cur,
-        ggml_tensor * v_cur,
-        ggml_tensor * kq_b,
-        ggml_tensor * v_mla,
-        float kq_scale,
-        int il) const {
-    return build_attn(
-        static_cast<llm_graph_input_attn_kv_unified *>(inp),
-        gf, wo, wo_b, q_cur, k_cur, v_cur, kq_b, v_mla, kq_scale, il
-    );
+    return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
 }
 
 llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
@@ -1504,13 +1481,17 @@ ggml_tensor * llm_graph_context::build_attn(
 }
 
 ggml_tensor * llm_graph_context::build_copy_mask_state(
-        ggml_cgraph * gf,
-        ggml_tensor * s,
-        ggml_tensor * state_copy,
-        ggml_tensor * state_mask,
-        int32_t n_state,
-        int32_t n_seqs) const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+        ggml_cgraph * gf,
+        ggml_tensor * s,
+        ggml_tensor * state_copy,
+        ggml_tensor * state_mask,
+        int32_t n_state,
+        int32_t n_seqs,
+        const llama_kv_cache_recurrent_state * kv_state) const {
+
+    if (kv_state == nullptr) {
+        kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    }
 
     const auto n_kv = kv_state->get_n_kv();
     const auto kv_head = kv_state->get_head();
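
For context, a minimal usage sketch of how a hybrid model's graph-build code might consume these APIs after this change. This is not part of the commit; identifiers not shown in the diff (`get_state_recurrent()`, the layer/tensor variables) are assumptions, with `get_state_recurrent()` guessed as the counterpart of the `get_state_attn()` accessor that does appear above.

```cpp
// Hypothetical call site inside a hybrid model's build function (sketch only).
// The attention input is now the plain unified type, so the generic
// build_attn() overload applies directly: inp_attn->kv_state already points
// at the attention child state of the hybrid cache.
auto * inp_attn = build_attn_inp_kv_hybrid_recurrent();

ggml_tensor * cur = build_attn(inp_attn, gf,
        model.layers[il].wo, nullptr,                 // assumed layer weights
        q_cur, k_cur, v_cur, nullptr, nullptr, kq_scale, il);

// Recurrent half: pass the recurrent child state explicitly so that
// build_copy_mask_state() does not cast mstate (the hybrid parent) itself.
// Omitting the trailing argument (nullptr) falls back to casting mstate,
// preserving the old behavior for purely recurrent models.
const auto * hyb = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);

ggml_tensor * s = build_copy_mask_state(gf, s_all, state_copy, state_mask,
        n_state, n_seqs, hyb->get_state_recurrent()); // accessor name assumed
```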