Commit be2ba0f

fix: Fix logic for initializing inputs and attn layers for hybrid caches
Branch: GraniteFour
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent 8bac844 commit be2ba0f

2 files changed: +25 -66 lines changed
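
The gist of the change: with a hybrid cache, the graph's top-level memory state (mstate) wraps a separate child state for the attention layers, so build_attn() must use the state that the input object captured at construction time rather than casting mstate. The dedicated llm_graph_input_attn_kv_hybrid_recurrent class becomes unnecessary once the builder constructs a plain llm_graph_input_attn_kv_unified from kv_state->get_state_attn(). A minimal, self-contained sketch of that pattern (toy types, not the llama.cpp sources):

// Toy illustration of the pattern adopted here: the attention input remembers the
// (possibly child) KV state it was built from, and the consumer reads that pointer
// instead of casting the top-level state.
#include <cassert>

struct attn_state { int n_kv; };                      // stand-in for llama_kv_cache_unified_state

struct hybrid_state {                                 // stand-in for the hybrid cache state
    attn_state attn;                                  // child state used by the attention layers
    const attn_state * get_state_attn() const { return &attn; }
};

struct attn_input {                                   // stand-in for llm_graph_input_attn_kv_unified
    const attn_state * kv_state;                      // captured at construction time
    explicit attn_input(const attn_state * s) : kv_state(s) {}
};

int main() {
    hybrid_state hybrid{attn_state{4}};
    attn_input   inp(hybrid.get_state_attn());        // built from the child, as in the patched builder

    // Consumer side: read inp.kv_state, which is correct for both plain and hybrid
    // caches, instead of casting the top-level state.
    assert(inp.kv_state->n_kv == 4);
    return 0;
}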

src/llama-graph.cpp

Lines changed: 17 additions & 36 deletions
@@ -397,13 +397,6 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-llm_graph_input_attn_kv_hybrid_recurrent::llm_graph_input_attn_kv_hybrid_recurrent(
-        const llama_hparams & hparams,
-        const llama_cparams & cparams,
-        const llama_kv_cache_hybrid_recurrent_state * kv_state) :
-    llm_graph_input_attn_kv_unified(hparams, cparams, kv_state->get_state_attn()) {
-}
-
 //
 // llm_graph_context
 //
@@ -1262,7 +1255,9 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
-    const auto * kv_state = static_cast<const llama_kv_cache_unified_state *>(mstate);
+    // NOTE: For hybrid caches, this may be a child of mstate, so we use the one
+    //  encapsulated in inp
+    const auto * kv_state = inp->kv_state;
 
     // store to KV cache
     {
@@ -1294,10 +1289,10 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-llm_graph_input_attn_kv_hybrid_recurrent * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const {
+llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const {
     const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
 
-    auto inp = std::make_unique<llm_graph_input_attn_kv_hybrid_recurrent>(hparams, cparams, kv_state);
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_state->get_state_attn());
 
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
@@ -1311,25 +1306,7 @@ llm_graph_input_attn_kv_hybrid_recurrent * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const {
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
     }
 
-    return (llm_graph_input_attn_kv_hybrid_recurrent *) res->add_input(std::move(inp));
-}
-
-ggml_tensor * llm_graph_context::build_attn(
-        llm_graph_input_attn_kv_hybrid_recurrent * inp,
-        ggml_cgraph * gf,
-        ggml_tensor * wo,
-        ggml_tensor * wo_b,
-        ggml_tensor * q_cur,
-        ggml_tensor * k_cur,
-        ggml_tensor * v_cur,
-        ggml_tensor * kq_b,
-        ggml_tensor * v_mla,
-        float kq_scale,
-        int il) const {
-    return build_attn(
-        static_cast<llm_graph_input_attn_kv_unified *>(inp),
-        gf, wo, wo_b, q_cur, k_cur, v_cur, kq_b, v_mla, kq_scale, il
-    );
+    return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
 }
 
 llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
@@ -1472,13 +1449,17 @@ ggml_tensor * llm_graph_context::build_attn(
 }
 
 ggml_tensor * llm_graph_context::build_recurrent_state(
-        ggml_cgraph * gf,
-        ggml_tensor * s,
-        ggml_tensor * state_copy,
-            int32_t   state_size,
-            int32_t   n_seqs,
-               bool   avoid_copies) const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+        ggml_cgraph * gf,
+        ggml_tensor * s,
+        ggml_tensor * state_copy,
+            int32_t   state_size,
+            int32_t   n_seqs,
+               bool   avoid_copies,
+        const llama_kv_cache_recurrent_state * kv_state) const {
+
+    if (kv_state == nullptr) {
+        kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    }
 
     const auto n_kv = kv_state->get_n_kv();
     const auto kv_head = kv_state->get_head();
src/llama-graph.h

Lines changed: 8 additions & 30 deletions
@@ -286,16 +286,6 @@ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
     const llama_kv_cache_unified_iswa_state * kv_state;
 };
 
-class llm_graph_input_attn_kv_hybrid_recurrent : public llm_graph_input_attn_kv_unified {
-public:
-    llm_graph_input_attn_kv_hybrid_recurrent(
-            const llama_hparams & hparams,
-            const llama_cparams & cparams,
-            const llama_kv_cache_hybrid_recurrent_state * kv_state);
-
-    virtual ~llm_graph_input_attn_kv_hybrid_recurrent() = default;
-};
-
 class llm_graph_input_attn_cross : public llm_graph_input_i {
 public:
     llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {}
@@ -585,20 +575,7 @@ struct llm_graph_context {
            float   kq_scale,
              int   il) const;
 
-    llm_graph_input_attn_kv_hybrid_recurrent * build_attn_inp_kv_hybrid_recurrent() const;
-
-    ggml_tensor * build_attn(
-            llm_graph_input_attn_kv_hybrid_recurrent * inp,
-            ggml_cgraph * gf,
-            ggml_tensor * wo,
-            ggml_tensor * wo_b,
-            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
-            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
-            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
-            ggml_tensor * kq_b,
-            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
-            float   kq_scale,
-              int   il) const;
+    llm_graph_input_attn_kv_unified * build_attn_inp_kv_hybrid_recurrent() const;
 
     llm_graph_input_attn_cross * build_attn_inp_cross() const;
 
@@ -620,12 +597,13 @@ struct llm_graph_context {
     //
 
     ggml_tensor * build_recurrent_state(
-            ggml_cgraph * gf,
-            ggml_tensor * s,
-            ggml_tensor * state_copy,
-                int32_t   state_size,
-                int32_t   n_seqs,
-                   bool   avoid_copies = false) const;
+            ggml_cgraph * gf,
+            ggml_tensor * s,
+            ggml_tensor * state_copy,
+                int32_t   state_size,
+                int32_t   n_seqs,
+                   bool   avoid_copies = false,
+            const llama_kv_cache_recurrent_state * kv_state = nullptr) const;
 
     ggml_tensor * build_rwkv_token_shift_load(
             ggml_cgraph * gf,