@@ -397,13 +397,6 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-llm_graph_input_attn_kv_hybrid_recurrent::llm_graph_input_attn_kv_hybrid_recurrent(
-        const llama_hparams & hparams,
-        const llama_cparams & cparams,
-        const llama_kv_cache_hybrid_recurrent_state * kv_state) :
-    llm_graph_input_attn_kv_unified(hparams, cparams, kv_state->get_state_attn()) {
-}
-
 //
 // llm_graph_context
 //
@@ -1262,7 +1255,9 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
+    // NOTE: For hybrid caches, this may be a child of mstate, so we use the one
+    //  encapsulated in inp
+    const auto * kv_state = inp->kv_state;
 
     // store to KV cache
     {
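The point of this hunk is that build_attn() now trusts the input object to know which cache state it was built against, rather than assuming mstate is a plain llama_kv_cache_unified_state. A self-contained sketch of that pattern with stand-in types (none of these names are the real llama.cpp declarations):

    // Stand-in sketch: the attention input captures the state it was built from,
    // so the consuming code never needs to know whether the parent state is hybrid.
    struct attn_state { int n_kv = 32; };                          // stands in for the unified attention state

    struct hybrid_state {                                          // stands in for the hybrid parent state
        attn_state attn;
        const attn_state * get_state_attn() const { return &attn; }
    };

    struct attn_input {                                            // stands in for llm_graph_input_attn_kv_unified
        explicit attn_input(const attn_state * s) : kv_state(s) {}
        const attn_state * kv_state;                               // what build_attn() reads above
    };

    int build_attn(const attn_input & inp) {
        return inp.kv_state->n_kv;                                 // no cast of any global mstate
    }

    int main() {
        hybrid_state parent;
        attn_input inp(parent.get_state_attn());                   // child state captured at input build time
        return build_attn(inp) == 32 ? 0 : 1;
    }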
@@ -1294,10 +1289,10 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-llm_graph_input_attn_kv_hybrid_recurrent * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const {
+llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const {
     const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
 
-    auto inp = std::make_unique<llm_graph_input_attn_kv_hybrid_recurrent>(hparams, cparams, kv_state);
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_state->get_state_attn());
 
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
@@ -1311,25 +1306,7 @@ llm_graph_input_attn_kv_hybrid_recurrent * llm_graph_context::build_attn_inp_kv_
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
     }
 
-    return (llm_graph_input_attn_kv_hybrid_recurrent *) res->add_input(std::move(inp));
-}
-
-ggml_tensor * llm_graph_context::build_attn(
-        llm_graph_input_attn_kv_hybrid_recurrent * inp,
-        ggml_cgraph * gf,
-        ggml_tensor * wo,
-        ggml_tensor * wo_b,
-        ggml_tensor * q_cur,
-        ggml_tensor * k_cur,
-        ggml_tensor * v_cur,
-        ggml_tensor * kq_b,
-        ggml_tensor * v_mla,
-        float kq_scale,
-        int il) const {
-    return build_attn(
-        static_cast<llm_graph_input_attn_kv_unified *>(inp),
-        gf, wo, wo_b, q_cur, k_cur, v_cur, kq_b, v_mla, kq_scale, il
-    );
+    return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
 }
 
 llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
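With the forwarding overload above deleted, hybrid models now go through exactly the same build_attn() overload as unified-cache models; the hybrid-specific input type added nothing once the unified input could be constructed directly from the hybrid state's attention child. A tiny stand-in sketch of why the removed overload was redundant (illustrative names only, not the upstream API):

    // Stand-in sketch: once the "hybrid" input is just the unified input type,
    // the unified overload is selected directly and a forwarding overload is dead code.
    struct input_unified { int n_kv = 4; };                       // stands in for llm_graph_input_attn_kv_unified

    int build_attn(const input_unified & inp) { return inp.n_kv; }

    // previously: struct input_hybrid : input_unified {}; plus an overload that only
    // forwarded to build_attn(const input_unified &) through a static_cast

    int main() {
        input_unified inp;                                        // what build_attn_inp_kv_hybrid_recurrent() now returns
        return build_attn(inp) == 4 ? 0 : 1;
    }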
@@ -1472,13 +1449,17 @@ ggml_tensor * llm_graph_context::build_attn(
 }
 
 ggml_tensor * llm_graph_context::build_recurrent_state(
-         ggml_cgraph * gf,
-         ggml_tensor * s,
-         ggml_tensor * state_copy,
-             int32_t   state_size,
-             int32_t   n_seqs,
-                bool   avoid_copies) const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+                             ggml_cgraph * gf,
+                             ggml_tensor * s,
+                             ggml_tensor * state_copy,
+                                 int32_t   state_size,
+                                 int32_t   n_seqs,
+                                    bool   avoid_copies,
+    const llama_kv_cache_recurrent_state * kv_state) const {
+
+    if (kv_state == nullptr) {
+        kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    }
 
     const auto n_kv = kv_state->get_n_kv();
     const auto kv_head = kv_state->get_head();
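build_recurrent_state() now takes an optional recurrent-state pointer and falls back to mstate when it is null, so hybrid callers can pass in the recurrent child of their hybrid cache state while plain recurrent callers keep the old behaviour. A self-contained sketch of that fallback pattern with stand-in types (the nullptr default is assumed to live in the header declaration; none of these are the real llama.cpp signatures):

    #include <cassert>

    struct recurrent_state { int n_kv; int head; };               // stands in for llama_kv_cache_recurrent_state

    struct graph_context {                                        // stands in for llm_graph_context
        const recurrent_state * mstate;                           // the state the context was created with

        int build_recurrent_state(int n_seqs, const recurrent_state * kv_state = nullptr) const {
            if (kv_state == nullptr) {
                kv_state = mstate;                                // plain recurrent caches: unchanged behaviour
            }
            // hybrid caches pass their recurrent child explicitly instead
            return kv_state->n_kv * n_seqs;
        }
    };

    int main() {
        recurrent_state own   = {8, 0};
        recurrent_state child = {16, 0};
        graph_context ctx{&own};
        assert(ctx.build_recurrent_state(1) == 8);                // falls back to mstate
        assert(ctx.build_recurrent_state(1, &child) == 16);       // hybrid path overrides it
        return 0;
    }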