
Commit 0893b4c

feat: Support hybrid recurrent in llama-graph
NOTE: I intentionally did not add support for s_mask since it will be going away soon

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent 2fa7b32 commit 0893b4c

File tree (2 files changed: +78 −4 lines)

  src/llama-graph.cpp
  src/llama-graph.h

src/llama-graph.cpp

Lines changed: 50 additions & 2 deletions
@@ -7,6 +7,7 @@
 #include "llama-kv-cache-unified.h"
 #include "llama-kv-cache-unified-iswa.h"
 #include "llama-kv-cache-recurrent.h"
+#include "llama-kv-cache-hybrid-recurrent.h"

 #include <cassert>
 #include <cmath>
@@ -412,6 +413,13 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
     }
 }

+llm_graph_input_attn_kv_hybrid_recurrent::llm_graph_input_attn_kv_hybrid_recurrent(
+        const llama_hparams & hparams,
+        const llama_cparams & cparams,
+        const llama_kv_cache_hybrid_recurrent_state * kv_state) :
+    llm_graph_input_attn_kv_unified(hparams, cparams, kv_state->get_state_attn()) {
+}
+
 //
 // llm_graph_context
 //
@@ -955,8 +963,10 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
     return cur;
 }

-ggml_tensor * llm_graph_context::build_inp_s_copy() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+ggml_tensor * llm_graph_context::build_inp_s_copy(const llama_kv_cache_recurrent_state * kv_state) const {
+    if (kv_state == nullptr) {
+        kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
+    }

     auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);

@@ -1302,6 +1312,44 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }

+llm_graph_input_attn_kv_hybrid_recurrent * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const {
+    const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
+
+    auto inp = std::make_unique<llm_graph_input_attn_kv_hybrid_recurrent>(hparams, cparams, kv_state);
+
+    {
+        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
+
+        const auto n_kv = kv_state->get_state_attn()->get_n_kv();
+
+        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        ggml_set_input(inp->self_kq_mask);
+
+        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+    }
+
+    return (llm_graph_input_attn_kv_hybrid_recurrent *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_attn(
+        llm_graph_input_attn_kv_hybrid_recurrent * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
+        float kq_scale,
+        int il) const {
+    return build_attn(
+        static_cast<llm_graph_input_attn_kv_unified *>(inp),
+        gf, wo, wo_b, q_cur, k_cur, v_cur, kq_b, v_mla, kq_scale, il
+    );
+}
+
 llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
     const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
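For context, a hybrid model's graph builder could consume these helpers roughly as sketched below. This is a minimal illustration, not part of the commit: it assumes it runs inside an llm_graph_context-based model builder where ctx0, gf, cur, n_layer, n_tokens, n_head, n_head_kv, n_embd_head and the conventional per-layer weights model.layers[il].wq/wk/wv/wo are in scope, as in other llama.cpp model graphs; the hparams.recurrent_layer(il) check is a hypothetical stand-in for however the model distinguishes recurrent from attention layers, and norms, biases, residuals and the FFN are omitted.

// Hedged sketch: attention layers of a hybrid (attention + recurrent) model.
// None of the names below are introduced by this commit except the two builders.
auto * inp_attn = build_attn_inp_kv_hybrid_recurrent();

for (int il = 0; il < n_layer; ++il) {
    if (hparams.recurrent_layer(il)) {
        // recurrent (SSM-style) layers take the recurrent-cache path instead
        continue;
    }

    // project Q/K/V from the current hidden state `cur`
    ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
    ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
    ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);

    Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
    Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
    Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

    // the new overload forwards to the unified-KV attention path, so the body
    // of an attention layer is unchanged relative to a pure-attention model
    cur = build_attn(inp_attn, gf,
            model.layers[il].wo, nullptr,
            Qcur, Kcur, Vcur, nullptr, nullptr,
            1.0f/sqrtf(float(n_embd_head)), il);
}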

src/llama-graph.h

Lines changed: 28 additions & 2 deletions
@@ -22,6 +22,7 @@ class llama_memory_state_i;
 class llama_kv_cache_unified_state;
 class llama_kv_cache_unified_iswa_state;
 class llama_kv_cache_recurrent_state;
+class llama_kv_cache_hybrid_recurrent_state;

 // certain models (typically multi-modal) can produce different types of graphs
 enum llm_graph_type {
@@ -253,7 +254,7 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
         cparams(cparams),
         kv_state(kv_state) {
     }
-    ~llm_graph_input_attn_kv_unified() = default;
+    virtual ~llm_graph_input_attn_kv_unified() = default;

     void set_input(const llama_ubatch * ubatch) override;

@@ -296,6 +297,16 @@ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
     const llama_kv_cache_unified_iswa_state * kv_state;
 };

+class llm_graph_input_attn_kv_hybrid_recurrent : public llm_graph_input_attn_kv_unified {
+public:
+    llm_graph_input_attn_kv_hybrid_recurrent(
+            const llama_hparams & hparams,
+            const llama_cparams & cparams,
+            const llama_kv_cache_hybrid_recurrent_state * kv_state);
+
+    virtual ~llm_graph_input_attn_kv_hybrid_recurrent() = default;
+};
+
 class llm_graph_input_attn_cross : public llm_graph_input_i {
 public:
     llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {}
@@ -519,7 +530,7 @@ struct llm_graph_context {
     ggml_tensor * build_inp_out_ids() const;
     ggml_tensor * build_inp_mean() const;
     ggml_tensor * build_inp_cls() const;
-    ggml_tensor * build_inp_s_copy() const;
+    ggml_tensor * build_inp_s_copy(const llama_kv_cache_recurrent_state * kv_state = nullptr) const;
     ggml_tensor * build_inp_s_mask() const;

     ggml_tensor * build_inp_cross_embd() const;
@@ -586,6 +597,21 @@ struct llm_graph_context {
             float kq_scale,
             int il) const;

+    llm_graph_input_attn_kv_hybrid_recurrent * build_attn_inp_kv_hybrid_recurrent() const;
+
+    ggml_tensor * build_attn(
+            llm_graph_input_attn_kv_hybrid_recurrent * inp,
+            ggml_cgraph * gf,
+            ggml_tensor * wo,
+            ggml_tensor * wo_b,
+            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
+            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
+            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
+            ggml_tensor * kq_b,
+            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
+            float kq_scale,
+            int il) const;
+
     llm_graph_input_attn_cross * build_attn_inp_cross() const;

     ggml_tensor * build_attn(
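The widened build_inp_s_copy() signature is what lets the recurrent half of a hybrid model pull its state-copy input from the appropriate sub-state instead of casting mstate as a whole. A rough sketch of such a call site follows; it assumes the hybrid state exposes an accessor for its recurrent sub-state analogous to the get_state_attn() used above (the name get_state_recurrent() is hypothetical and not defined in this diff).

// Hedged sketch: recurrent layers of a hybrid model fetching the s_copy input.
// get_state_recurrent() is an assumed accessor mirroring get_state_attn();
// it is not introduced by this commit.
const auto * kv_state = static_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);

ggml_tensor * state_copy = build_inp_s_copy(kv_state->get_state_recurrent());

// Pure recurrent models keep the old behaviour: with no argument the method
// falls back to casting mstate itself.
// ggml_tensor * state_copy = build_inp_s_copy();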
