@@ -954,7 +954,7 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_s_copy() const {
-    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
+    const llama_kv_cache_recurrent * kv_self = get_recurrent_cache();
 
     auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);
 
@@ -971,7 +971,7 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_s_mask() const {
-    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
+    const llama_kv_cache_recurrent * kv_self = get_recurrent_cache();
 
     auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);
 
@@ -1025,7 +1025,7 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const llama_kv_cache_unified * kv_self = get_unified_cache();
 
     auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self);
 
@@ -1231,7 +1231,7 @@ ggml_tensor * llm_graph_context::build_attn(
 }
 
 llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const llama_kv_cache_unified * kv_self = get_unified_cache();
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
 
@@ -1268,7 +1268,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const llama_kv_cache_unified * kv_self = get_unified_cache();
 
     // store to KV cache
     {
@@ -1439,14 +1439,38 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
+const llama_kv_cache_recurrent * llm_graph_context::get_recurrent_cache() const {
+    const llama_kv_cache_recurrent * kv_self = dynamic_cast<const llama_kv_cache_recurrent *>(memory);
+    if (!kv_self) {
+        const llama_kv_cache_hybrid_recurrent * kv_hybrid = dynamic_cast<const llama_kv_cache_hybrid_recurrent *>(memory);
+        if (kv_hybrid) {
+            kv_self = kv_hybrid->get_kv_recurrent();
+        }
+    }
+    GGML_ASSERT(kv_self);
+    return kv_self;
+}
+
+const llama_kv_cache_unified * llm_graph_context::get_unified_cache() const {
+    const llama_kv_cache_unified * kv_self = dynamic_cast<const llama_kv_cache_unified *>(memory);
+    if (!kv_self) {
+        const llama_kv_cache_hybrid_recurrent * kv_hybrid = dynamic_cast<const llama_kv_cache_hybrid_recurrent *>(memory);
+        if (kv_hybrid) {
+            kv_self = kv_hybrid->get_kv_attn();
+        }
+    }
+    GGML_ASSERT(kv_self);
+    return kv_self;
+}
+
 ggml_tensor * llm_graph_context::build_copy_mask_state(
          ggml_cgraph * gf,
          ggml_tensor * s,
          ggml_tensor * state_copy,
          ggml_tensor * state_mask,
              int32_t   n_state,
              int32_t   n_seqs) const {
-    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
+    const llama_kv_cache_recurrent * kv_self = get_recurrent_cache();
 
     const auto n_kv    = kv_self->n;
     const auto kv_head = kv_self->head;
@@ -1478,7 +1502,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
          ggml_tensor * state_mask,
   const llama_ubatch & ubatch,
                  int   il) const {
-    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
+    const llama_kv_cache_recurrent * kv_self = get_recurrent_cache();
 
     const auto token_shift_count = hparams.token_shift_count;
 
@@ -1499,7 +1523,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
          ggml_tensor * token_shift,
   const llama_ubatch & ubatch,
                  int   il) const {
-    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
+    const llama_kv_cache_recurrent * kv_self = get_recurrent_cache();
 
     const auto token_shift_count = hparams.token_shift_count;
     const auto n_embd            = hparams.n_embd;
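
For context, the pattern the new get_recurrent_cache() / get_unified_cache() helpers follow is a checked downcast with a hybrid fallback: try to cast the opaque memory pointer to the requested cache type, and if that fails, check whether the memory is a hybrid cache and return its matching sub-cache. The sketch below illustrates that dispatch in isolation; all type names (memory_i, cache_recurrent, cache_unified, cache_hybrid) are simplified stand-ins invented for the example, not the actual llama.cpp classes.

// Minimal, self-contained sketch of the "dynamic_cast with hybrid fallback"
// dispatch used by the new getters. Stand-in types only.
#include <cassert>
#include <cstdio>

struct memory_i { virtual ~memory_i() = default; }; // stand-in for the opaque memory base class
struct cache_recurrent : memory_i {};               // stand-in for llama_kv_cache_recurrent
struct cache_unified   : memory_i {};               // stand-in for llama_kv_cache_unified

// A hybrid cache owns one sub-cache of each kind and hands out the right one.
struct cache_hybrid : memory_i {
    cache_recurrent recurrent;
    cache_unified   attn;
    const cache_recurrent * get_kv_recurrent() const { return &recurrent; }
    const cache_unified   * get_kv_attn()      const { return &attn; }
};

// Resolve the recurrent cache from an opaque memory pointer:
// first try a direct cast, then fall back to the hybrid wrapper.
const cache_recurrent * get_recurrent_cache(const memory_i * memory) {
    const cache_recurrent * kv = dynamic_cast<const cache_recurrent *>(memory);
    if (!kv) {
        if (const auto * hybrid = dynamic_cast<const cache_hybrid *>(memory)) {
            kv = hybrid->get_kv_recurrent();
        }
    }
    assert(kv && "memory is neither a recurrent nor a hybrid cache");
    return kv;
}

int main() {
    cache_recurrent plain;
    cache_hybrid    hybrid;

    // Both a plain recurrent cache and a hybrid cache resolve successfully.
    std::printf("plain : %p\n", (const void *) get_recurrent_cache(&plain));
    std::printf("hybrid: %p\n", (const void *) get_recurrent_cache(&hybrid));
    return 0;
}

Compared with the previous unchecked static_cast at each call site, this makes the assumption about what memory actually holds explicit and asserted, and it lets the existing graph-building helpers work unchanged when the memory object is a hybrid cache wrapping both an attention and a recurrent part.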