Commit 20f8e43

graph : add back hybrid memory graph input
But this time it contains the sub-cache graph inputs. This *should* make it easier to handle updating the inputs when caching the graph (eventually).
1 parent 4682e21 commit 20f8e43
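
In outline, the re-added input is a thin composite over the two existing sub-cache inputs; a condensed view of what the diffs below add (members and forwarding only):

    // Condensed from the diffs below: the hybrid input owns the unified-KV
    // attention sub-input and the recurrent-state sub-input, and set_input()
    // simply forwards to both.
    class llm_graph_input_mem_hybrid : public llm_graph_input_i {
    public:
        void set_input(const llama_ubatch * ubatch) override {
            inp_attn->set_input(ubatch); // attention (unified KV cache) side
            inp_rs->set_input(ubatch);   // recurrent-state side
        }

        std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn;
        std::unique_ptr<llm_graph_input_rs>              inp_rs;
    };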

File tree: 3 files changed, +80 −22 lines

src/llama-graph.cpp

Lines changed: 47 additions & 12 deletions
@@ -335,6 +335,11 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
+    inp_attn->set_input(ubatch);
+    inp_rs->set_input(ubatch);
+}
+
 void llm_graph_input_one::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
     GGML_ASSERT(one && ggml_nelements(one) == 1);
@@ -1147,17 +1152,20 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified(const llama_kv_cache_unified_context * mctx_cur) const {
-    if (!mctx_cur) {
-        mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
-    }
+static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unified_impl(
+           ggml_context * ctx0,
+     const llama_ubatch & ubatch,
+    const llama_hparams & hparams,
+    const llama_cparams & cparams,
+    const llama_kv_cache_unified_context * mctx_cur) {
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
 
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
 
         const auto n_kv     = mctx_cur->get_n_kv();
+        const auto n_tokens = ubatch.n_tokens;
 
         inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
@@ -1168,6 +1176,14 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified(c
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
     }
 
+    return inp;
+}
+
+llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
+
+    auto inp = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
+
     return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
 }
 
@@ -1346,10 +1362,11 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa(const llama_kv_cache_unified_iswa_context * mctx_cur) const {
-    if (!mctx_cur) {
-        mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
-    }
+// TODO: maybe separate the inner implementation into a separate function
+//       like with the non-sliding window equivalent
+//       once sliding-window hybrid caches are a thing.
+llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
 
@@ -1417,10 +1434,9 @@ ggml_tensor * llm_graph_context::build_rs(
     return output_states;
 }
 
-llm_graph_input_rs * llm_graph_context::build_rs_inp(const llama_memory_recurrent_context * mctx_cur) const {
-    if (!mctx_cur) {
-        mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
-    }
+static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
+           ggml_context * ctx0,
+    const llama_memory_recurrent_context * mctx_cur) {
 
     auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
 
@@ -1429,6 +1445,14 @@ llm_graph_input_rs * llm_graph_context::build_rs_inp(const llama_memory_recurren
     inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
     ggml_set_input(inp->s_copy);
 
+    return inp;
+}
+
+llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
+    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
+
+    auto inp = build_rs_inp_impl(ctx0, mctx_cur);
+
     return (llm_graph_input_rs *) res->add_input(std::move(inp));
 }
 
@@ -1486,6 +1510,17 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
     );
 }
 
+llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
+    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
+
+    auto inp_rs   = build_rs_inp_impl(ctx0, mctx_cur->get_recr());
+    auto inp_attn = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
+
+    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
+
+    return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
+}
+
 void llm_graph_context::build_pooling(
         ggml_cgraph * gf,
         ggml_tensor * cls,
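
The commit message only hints at graph caching; nothing here implements it, but the ownership change points at the intended refresh path. A hypothetical sketch (names as in this commit, the caching flow itself is an assumption):

    // Hypothetical flow, not part of this commit: with the sub-cache inputs
    // owned by the hybrid input, a cached graph could be refreshed for a new
    // ubatch through one object instead of tracking both inputs separately.
    llm_graph_input_mem_hybrid * inp = build_inp_mem_hybrid(); // once, while building the graph
    // ... later, when re-running the cached graph on the next ubatch:
    inp->set_input(&ubatch); // forwards to inp_attn->set_input() and inp_rs->set_input()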

src/llama-graph.h

Lines changed: 30 additions & 3 deletions
@@ -319,6 +319,28 @@ class llm_graph_input_attn_cross : public llm_graph_input_i {
     const llama_cross * cross = nullptr;
 };
 
+class llm_graph_input_mem_hybrid : public llm_graph_input_i {
+public:
+    llm_graph_input_mem_hybrid(
+            std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn,
+            std::unique_ptr<llm_graph_input_rs> inp_rs,
+            const llama_memory_hybrid_context * mctx) :
+        inp_attn(std::move(inp_attn)),
+        inp_rs(std::move(inp_rs)),
+        mctx(mctx) { }
+    virtual ~llm_graph_input_mem_hybrid() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn;
+    std::unique_ptr<llm_graph_input_rs> inp_rs;
+
+    llm_graph_input_attn_kv_unified * get_attn() const { return inp_attn.get(); }
+    llm_graph_input_rs * get_recr() const { return inp_rs.get(); }
+
+    const llama_memory_hybrid_context * mctx;
+};
+
 // TODO: remove this when ggml_scale_add is implemented
 class llm_graph_input_one : public llm_graph_input_i {
 public:
@@ -575,7 +597,7 @@ struct llm_graph_context {
             float kq_scale,
             int il) const;
 
-    llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified(const llama_kv_cache_unified_context * mctx_cur = nullptr) const;
+    llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified() const;
 
     ggml_tensor * build_attn(
             llm_graph_input_attn_kv_unified * inp,
@@ -590,7 +612,7 @@ struct llm_graph_context {
             float kq_scale,
             int il) const;
 
-    llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa(const llama_kv_cache_unified_iswa_context * mctx_cur = nullptr) const;
+    llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
 
     // note: if k_cur or v_cur are not provided, they will not be stored in the memory
     ggml_tensor * build_attn(
@@ -643,7 +665,7 @@ struct llm_graph_context {
             int32_t rs_zero,
             const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
 
-    llm_graph_input_rs * build_rs_inp(const llama_memory_recurrent_context * mctx_cur = nullptr) const;
+    llm_graph_input_rs * build_rs_inp() const;
 
     ggml_tensor * build_rs(
             llm_graph_input_rs * inp,
@@ -663,6 +685,11 @@ struct llm_graph_context {
             ggml_tensor * token_shift,
             const llama_ubatch & ubatch,
             int il) const;
+    //
+    // hybrid
+    //
+
+    llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
 
     //
     // pooling
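
Condensed from the llm_build_jamba hunks in the next file, the call pattern a hybrid model's graph builder ends up with is:

    // One builder call per graph; the getters hand out the per-cache inputs
    // where each layer type needs them.
    auto * inp_hybrid = build_inp_mem_hybrid();

    // recurrent (Mamba) layers take the recurrent-state sub-input
    cur = build_mamba_layer(inp_hybrid->get_recr(), gf, cur, model, ubatch, il);

    // attention layers take the unified-KV sub-input
    cur = build_attn(inp_hybrid->get_attn(), gf, model.layers[il].wo, NULL,
            Qcur, Kcur, Vcur, NULL, NULL, 1.0f/sqrtf(float(n_embd_head)), il);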

src/llama-model.cpp

Lines changed: 3 additions & 7 deletions
@@ -10220,11 +10220,7 @@ struct llm_build_jamba : public llm_graph_context_mamba {
         // {n_embd, n_tokens}
         inpL = build_inp_embd(model.tok_embd);
 
-        const auto * mctx_hyb = static_cast<const llama_memory_hybrid_context *>(mctx);
-
-        auto * inp_rs = build_rs_inp(mctx_hyb->get_recr());
-
-        auto * inp_attn = build_attn_inp_kv_unified(mctx_hyb->get_attn());
+        auto * inp_hybrid = build_inp_mem_hybrid();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -10235,7 +10231,7 @@ struct llm_build_jamba : public llm_graph_context_mamba {
             cb(cur, "attn_norm", il);
 
             if (n_head_kv == 0) {
-                cur = build_mamba_layer(inp_rs, gf, cur, model, ubatch, il);
+                cur = build_mamba_layer(inp_hybrid->get_recr(), gf, cur, model, ubatch, il);
             } else {
                 // Attention
 
@@ -10256,7 +10252,7 @@ struct llm_build_jamba : public llm_graph_context_mamba {
                 cb(Vcur, "Vcur", il);
 
                 // No RoPE :)
-                cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f/sqrtf(float(n_embd_head)), il);
+                cur = build_attn(inp_hybrid->get_attn(), gf, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f/sqrtf(float(n_embd_head)), il);
             }
 
             if (il == n_layer - 1 && inp_out_ids) {