Commit bbf6c35 (parent: a11e742)

feat: Use common methods for accessing recurrent and unified caches in llama-graph

Branch: HybridRecurrentCache
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

File tree: 2 files changed (+37, -8 lines)

src/llama-graph.cpp (32 additions, 8 deletions)

@@ -954,7 +954,7 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_s_copy() const {
-    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
+    const llama_kv_cache_recurrent * kv_self = get_recurrent_cache();
 
     auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);
 
@@ -971,7 +971,7 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_s_mask() const {
-    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
+    const llama_kv_cache_recurrent * kv_self = get_recurrent_cache();
 
     auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);
 
@@ -1025,7 +1025,7 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const llama_kv_cache_unified * kv_self = get_unified_cache();
 
     auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self);
 
@@ -1231,7 +1231,7 @@ ggml_tensor * llm_graph_context::build_attn(
 }
 
 llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const llama_kv_cache_unified * kv_self = get_unified_cache();
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
 
@@ -1268,7 +1268,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
+    const llama_kv_cache_unified * kv_self = get_unified_cache();
 
     // store to KV cache
     {
@@ -1439,14 +1439,38 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
+const llama_kv_cache_recurrent * llm_graph_context::get_recurrent_cache() const {
+    const llama_kv_cache_recurrent * kv_self = dynamic_cast<const llama_kv_cache_recurrent *>(memory);
+    if (!kv_self) {
+        const llama_kv_cache_hybrid_recurrent * kv_hybrid = dynamic_cast<const llama_kv_cache_hybrid_recurrent *>(memory);
+        if (kv_hybrid) {
+            kv_self = kv_hybrid->get_kv_recurrent();
+        }
+    }
+    GGML_ASSERT(kv_self);
+    return kv_self;
+}
+
+const llama_kv_cache_unified * llm_graph_context::get_unified_cache() const {
+    const llama_kv_cache_unified * kv_self = dynamic_cast<const llama_kv_cache_unified *>(memory);
+    if (!kv_self) {
+        const llama_kv_cache_hybrid_recurrent * kv_hybrid = dynamic_cast<const llama_kv_cache_hybrid_recurrent *>(memory);
+        if (kv_hybrid) {
+            kv_self = kv_hybrid->get_kv_attn();
+        }
+    }
+    GGML_ASSERT(kv_self);
+    return kv_self;
+}
+
 ggml_tensor * llm_graph_context::build_copy_mask_state(
          ggml_cgraph * gf,
          ggml_tensor * s,
          ggml_tensor * state_copy,
          ggml_tensor * state_mask,
              int32_t   n_state,
              int32_t   n_seqs) const {
-    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
+    const llama_kv_cache_recurrent * kv_self = get_recurrent_cache();
 
     const auto n_kv    = kv_self->n;
     const auto kv_head = kv_self->head;
@@ -1478,7 +1502,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
          ggml_tensor * state_mask,
   const llama_ubatch & ubatch,
                   int   il) const {
-    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
+    const llama_kv_cache_recurrent * kv_self = get_recurrent_cache();
 
     const auto token_shift_count = hparams.token_shift_count;
 
@@ -1499,7 +1523,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
          ggml_tensor * token_shift,
   const llama_ubatch & ubatch,
                   int   il) const {
-    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);
+    const llama_kv_cache_recurrent * kv_self = get_recurrent_cache();
 
     const auto token_shift_count = hparams.token_shift_count;
     const auto n_embd            = hparams.n_embd;
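
Both new getters follow the same shape: try to interpret `memory` directly as the target cache type, and if that fails, check whether it is a `llama_kv_cache_hybrid_recurrent` and pull the matching child cache out of it. Because the lookup now goes through checked `dynamic_cast`s ending in a `GGML_ASSERT`, a graph builder invoked against the wrong cache type fails loudly instead of reading through a mis-typed pointer as the old `static_cast` would. A minimal standalone sketch of the dispatch pattern, using simplified stand-in types rather than the real llama.cpp classes:

#include <cassert>

// Simplified stand-ins for the llama.cpp cache hierarchy (assumed shapes,
// not the real classes). All caches derive from one polymorphic memory
// interface so the graph context can hold a single `memory` pointer.
struct llama_memory_i {
    virtual ~llama_memory_i() = default;
};
struct kv_cache_unified   : llama_memory_i {};
struct kv_cache_recurrent : llama_memory_i {};

// A hybrid cache owns one child of each kind and exposes them via getters,
// mirroring get_kv_attn()/get_kv_recurrent() in the diff above.
struct kv_cache_hybrid : llama_memory_i {
    kv_cache_unified   attn;
    kv_cache_recurrent recurrent;

    const kv_cache_unified   * get_kv_attn()      const { return &attn; }
    const kv_cache_recurrent * get_kv_recurrent() const { return &recurrent; }
};

// Same dispatch shape as get_recurrent_cache(): try the direct cast first,
// then fall back through the hybrid wrapper, and assert that one worked.
const kv_cache_recurrent * get_recurrent_cache(const llama_memory_i * memory) {
    const auto * kv = dynamic_cast<const kv_cache_recurrent *>(memory);
    if (!kv) {
        if (const auto * hybrid = dynamic_cast<const kv_cache_hybrid *>(memory)) {
            kv = hybrid->get_kv_recurrent();
        }
    }
    assert(kv); // callers require a recurrent cache; a wrong cache type is a bug
    return kv;
}

int main() {
    kv_cache_recurrent plain;
    kv_cache_hybrid    hybrid;

    // Both a plain recurrent cache and a hybrid cache resolve correctly.
    assert(get_recurrent_cache(&plain)  == &plain);
    assert(get_recurrent_cache(&hybrid) == hybrid.get_kv_recurrent());
    return 0;
}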

src/llama-graph.h (5 additions, 0 deletions)

@@ -604,6 +604,11 @@ struct llm_graph_context {
     // recurrent
     //
 
+    // Getters to support hybrid cache
+    // TODO: Should these be protected to the derived class hierarchy?
+    const llama_kv_cache_recurrent * get_recurrent_cache() const;
+    const llama_kv_cache_unified   * get_unified_cache() const;
+
     ggml_tensor * build_copy_mask_state(
             ggml_cgraph * gf,
             ggml_tensor * s,
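
Note that `llama_kv_cache_hybrid_recurrent`, which both getters probe for, is not defined anywhere in this diff; it comes from elsewhere on the HybridRecurrentCache branch. Inferred only from the two calls made here, its interface would look roughly like the sketch below; the member names, ownership model, and base class are guesses, not the real header:

#include <memory>

// Stand-in bases so the sketch compiles on its own; the real definitions
// live elsewhere in the tree.
struct llama_memory_i           { virtual ~llama_memory_i() = default; };
struct llama_kv_cache_unified   : llama_memory_i {};
struct llama_kv_cache_recurrent : llama_memory_i {};

// Inferred shape of the hybrid cache: one attention (unified) child and one
// recurrent child, each reachable through a const accessor. Only these two
// accessors are exercised by the new getters in llama-graph.cpp.
class llama_kv_cache_hybrid_recurrent : public llama_memory_i {
public:
    const llama_kv_cache_unified   * get_kv_attn()      const { return kv_attn.get(); }
    const llama_kv_cache_recurrent * get_kv_recurrent() const { return kv_recurrent.get(); }

private:
    std::unique_ptr<llama_kv_cache_unified>   kv_attn      = std::make_unique<llama_kv_cache_unified>();
    std::unique_ptr<llama_kv_cache_recurrent> kv_recurrent = std::make_unique<llama_kv_cache_recurrent>();
};

int main() {
    llama_kv_cache_hybrid_recurrent kv;
    return (kv.get_kv_attn() && kv.get_kv_recurrent()) ? 0 : 1;
}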
