Skip to content

Commit 5e4451f

Browse files
committed
fix: Fix logic for initializing inputs and attn layers for hybrid caches
Branch: GraniteFour Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent 007c005 commit 5e4451f

File tree

2 files changed

+25
-66
lines changed

2 files changed

+25
-66
lines changed

src/llama-graph.cpp

Lines changed: 17 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -413,13 +413,6 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
413413
}
414414
}
415415

416-
llm_graph_input_attn_kv_hybrid_recurrent::llm_graph_input_attn_kv_hybrid_recurrent(
417-
const llama_hparams & hparams,
418-
const llama_cparams & cparams,
419-
const llama_kv_cache_hybrid_recurrent_state * kv_state) :
420-
llm_graph_input_attn_kv_unified(hparams, cparams, kv_state->get_state_attn()) {
421-
}
422-
423416
//
424417
// llm_graph_context
425418
//
@@ -1280,7 +1273,9 @@ ggml_tensor * llm_graph_context::build_attn(
12801273
ggml_build_forward_expand(gf, k_cur);
12811274
ggml_build_forward_expand(gf, v_cur);
12821275

1283-
const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
1276+
// NOTE: For hybrid caches, this may be a child of mstate, so we use the one
1277+
// encapsulated in inp
1278+
const auto * kv_state = inp->kv_state;
12841279

12851280
// store to KV cache
12861281
{
@@ -1312,10 +1307,10 @@ ggml_tensor * llm_graph_context::build_attn(
13121307
return cur;
13131308
}
13141309

1315-
llm_graph_input_attn_kv_hybrid_recurrent * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const {
1310+
llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const {
13161311
const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
13171312

1318-
auto inp = std::make_unique<llm_graph_input_attn_kv_hybrid_recurrent>(hparams, cparams, kv_state);
1313+
auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_state->get_state_attn());
13191314

13201315
{
13211316
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
@@ -1329,25 +1324,7 @@ llm_graph_input_attn_kv_hybrid_recurrent * llm_graph_context::build_attn_inp_kv_
13291324
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
13301325
}
13311326

1332-
return (llm_graph_input_attn_kv_hybrid_recurrent *) res->add_input(std::move(inp));
1333-
}
1334-
1335-
ggml_tensor * llm_graph_context::build_attn(
1336-
llm_graph_input_attn_kv_hybrid_recurrent * inp,
1337-
ggml_cgraph * gf,
1338-
ggml_tensor * wo,
1339-
ggml_tensor * wo_b,
1340-
ggml_tensor * q_cur,
1341-
ggml_tensor * k_cur,
1342-
ggml_tensor * v_cur,
1343-
ggml_tensor * kq_b,
1344-
ggml_tensor * v_mla,
1345-
float kq_scale,
1346-
int il) const {
1347-
return build_attn(
1348-
static_cast<llm_graph_input_attn_kv_unified *>(inp),
1349-
gf, wo, wo_b, q_cur, k_cur, v_cur, kq_b, v_mla, kq_scale, il
1350-
);
1327+
return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
13511328
}
13521329

13531330
llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
@@ -1490,13 +1467,17 @@ ggml_tensor * llm_graph_context::build_attn(
14901467
}
14911468

14921469
ggml_tensor * llm_graph_context::build_copy_mask_state(
1493-
ggml_cgraph * gf,
1494-
ggml_tensor * s,
1495-
ggml_tensor * state_copy,
1496-
ggml_tensor * state_mask,
1497-
int32_t n_state,
1498-
int32_t n_seqs) const {
1499-
const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
1470+
ggml_cgraph * gf,
1471+
ggml_tensor * s,
1472+
ggml_tensor * state_copy,
1473+
ggml_tensor * state_mask,
1474+
int32_t n_state,
1475+
int32_t n_seqs,
1476+
const llama_kv_cache_recurrent_state * kv_state) const {
1477+
1478+
if (kv_state == nullptr) {
1479+
kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
1480+
}
15001481

15011482
const auto n_kv = kv_state->get_n_kv();
15021483
const auto kv_head = kv_state->get_head();

src/llama-graph.h

Lines changed: 8 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -297,16 +297,6 @@ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
297297
const llama_kv_cache_unified_iswa_state * kv_state;
298298
};
299299

300-
class llm_graph_input_attn_kv_hybrid_recurrent : public llm_graph_input_attn_kv_unified {
301-
public:
302-
llm_graph_input_attn_kv_hybrid_recurrent(
303-
const llama_hparams & hparams,
304-
const llama_cparams & cparams,
305-
const llama_kv_cache_hybrid_recurrent_state * kv_state);
306-
307-
virtual ~llm_graph_input_attn_kv_hybrid_recurrent() = default;
308-
};
309-
310300
class llm_graph_input_attn_cross : public llm_graph_input_i {
311301
public:
312302
llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {}
@@ -597,20 +587,7 @@ struct llm_graph_context {
597587
float kq_scale,
598588
int il) const;
599589

600-
llm_graph_input_attn_kv_hybrid_recurrent * build_attn_inp_kv_hybrid_recurrent() const;
601-
602-
ggml_tensor * build_attn(
603-
llm_graph_input_attn_kv_hybrid_recurrent * inp,
604-
ggml_cgraph * gf,
605-
ggml_tensor * wo,
606-
ggml_tensor * wo_b,
607-
ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
608-
ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
609-
ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
610-
ggml_tensor * kq_b,
611-
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
612-
float kq_scale,
613-
int il) const;
590+
llm_graph_input_attn_kv_unified * build_attn_inp_kv_hybrid_recurrent() const;
614591

615592
llm_graph_input_attn_cross * build_attn_inp_cross() const;
616593

@@ -632,12 +609,13 @@ struct llm_graph_context {
632609
//
633610

634611
ggml_tensor * build_copy_mask_state(
635-
ggml_cgraph * gf,
636-
ggml_tensor * s,
637-
ggml_tensor * state_copy,
638-
ggml_tensor * state_mask,
639-
int32_t n_state,
640-
int32_t n_seqs) const;
612+
ggml_cgraph * gf,
613+
ggml_tensor * s,
614+
ggml_tensor * state_copy,
615+
ggml_tensor * state_mask,
616+
int32_t n_state,
617+
int32_t n_seqs,
618+
const llama_kv_cache_recurrent_state * kv_state = nullptr) const;
641619

642620
ggml_tensor * build_rwkv_token_shift_load(
643621
ggml_cgraph * gf,

0 commit comments

Comments
 (0)