
Commit 6bd7a3d

gabe-l-hart authored and ggerganov committed
memory : Hybrid recurrent cache (ggml-org#13979)
* feat: Add llama_model_is_hybrid API call. Also, split llama_model_is_recurrent into llm_arch_is_recurrent in llama-arch, with llama_model_is_recurrent delegating to llm_arch_is_recurrent. The same split is done for hybrid. This is needed because there are places where the llama_model has not yet been initialized but we need to check whether the model is recurrent (specifically for the per-layer recurrent check array in hparams). Branch: GraniteFour. Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* feat: Add C++-side constants for the attention layer indices hparam. Branch: GraniteFour
* feat: Add support for distinguishing recurrent vs non-recurrent layers in hparams. Branch: GraniteFour. Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* feat: Auto-fill hparams.recurrent_layer_arr based on whether the model is recurrent. Branch: GraniteFour. Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* refactor: rename *_is_hybrid -> *_is_hybrid_recurrent. The implementation of the hybrid cache intentionally does not specify the types of the child caches, so there was a naming mismatch with these predicate functions that used "hybrid" to imply "hybrid recurrent". Branch: HybridCache. Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* feat: Add layer filter to recurrent cache. Branch: HybridCache. Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* fix: Use per-layer sizing everywhere in kv caches. Branch: GraniteFour. Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* feat: First pass at llama_kv_cache_hybrid_recurrent. This follows the pattern in iswa where the two child caches are held explicitly, to support the case where a model requires a single attention cache and a single recurrent cache and each layer uses exactly one of the caches. This is a rewrite of the more generic approach in the original hybrid cache PR: ggml-org#13276. Branch: HybridRecurrentCache. Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* feat: Construct hybrid recurrent cache for hybrid recurrent models. This includes a refactor of the create_memory logic to avoid needing to use the arch enum explicitly unless a model needs explicit cache instantiation logic beyond the standard logic for recurrent, hybrid, unified, and iswa.
  Branch: HybridRecurrentCache. Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* fix: Fix wrong bool condition for split equal in hybrid cache. Branch: HybridRecurrentCache. Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* fix: Fix shift logic to defer to unified cache. Branch: HybridRecurrentCache. Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* feat: Support hybrid recurrent in llama-graph. NOTE: I intentionally did not add support for s_mask since it will be going away soon. Branch: HybridRecurrentCache. Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* fix: Fix logic for initializing inputs and attn layers for hybrid caches. Branch: GraniteFour. Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* fix: Update recurrent cache for changes to remove the intermediate kv_cache interface. Branch: HybridRecurrentCache. Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* fix: Fix status for init_update sig for recurrent cache state. Branch: GraniteFour. Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* fix: Add missing padding to n_ctx for hybrid cache construction. Branch: GraniteFour. Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* fix: Update clear signature for data argument after rebase. Branch: HybridRecurrentCache. Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* fix: Remove errant virtual destructor left over from a previous impl attempt. Branch: HybridRecurrentCache. Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* fix: Use per-layer n_embd_k/v_s calls for mamba (1) layers. Branch: HybridRecurrentCache. Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* refactor: Remove n_embd_k/v_s from unified cache. No longer needed now that unified isn't also supporting recurrent. ggml-org#13979 (comment). Branch: HybridRecurrentCache
* refactor: Remove layer index from n_embd_k/v_s. Now that it's not used at all in the unified cache, we don't need to use the layer index to zero it out for attention layers. Branch: HybridRecurrentCache. Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* refactor: Remove n_embd_k/v_gqa from recurrent cache. This is no longer needed now that there are separate implementations. ggml-org#13979 (comment). Branch: HybridRecurrentCache. Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* feat: Allow custom layer filters for hybrid recurrent. This should help support architectures like Falcon H1 where there is overlap between layers that need attention and recurrent caches. ggml-org#13979 (comment). Branch: HybridRecurrentCache. Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* fix: Remove logits_all after rebase. Branch: HybridRecurrentCache. Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* fix: Remove llama_model_is_hybrid_Recurrent public API. ggml-org#13979 (comment). Branch: HybridRecurrentCache. Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* refactor: Use llama_memory_state_ptr for child states in hybrid memory state. Branch: HybridRecurrentCache. Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* feat: Overhaul build_recurrent_state / build_inp_s_copy to match the attention pattern (https://github.com/ggml-org/llama.cpp/pull/13979/files#r2141701738). This is a big overhaul to bring consistency between how inputs and per-layer components are created for attention layers and recurrent layers.
  The main changes are:
  - Rename class llm_graph_input_s_copy -> llm_graph_input_rs
  - Add a corresponding llm_graph_input_rs_hybrid_recurrent
  - Rename build_inp_s_copy -> build_rs_inp_recurrent
  - Add a corresponding build_rs_inp_hybrid_recurrent
  - Rename build_recurrent_state -> build_rs to match build_attn, with llm_graph_input_rs as the first input
  - Add a corresponding overload of build_rs with llm_graph_input_rs_hybrid_recurrent as the first input
  - Add a llm_graph_input_attn_kv_hybrid_recurrent analogous to llm_graph_input_attn_kv_unified
  - Add a build_attn override that takes llm_graph_input_attn_kv_hybrid_recurrent as the first input
  This makes the two paradigms fully consistent. The main drawback is the code duplication in the build_attn and build_rs implementations, where the only difference between implementations is how they cast the memory state.
  Branch: HybridRecurrentCache. Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* fix: Fix resize vs reserve and skip null tensors in size computation (https://github.com/ggml-org/llama.cpp/pull/13979/files#r2149469788). Branch: HybridRecurrentCache. Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>. Co-Authored-By: @younesbelkada
* fix: Fix initialization of child states. Since initially writing this PR, the logic in the child state types changed such that using the "init full" signature and keeping the ubatches on the parent struct no longer worked. Branch: HybridRecurrentCache. Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* refactor: Use a common build_recurrent_state method that is cache-agnostic. This reduces the code duplication between the different build_rs impls and also retains a signature similar to the previous build_recurrent_state method while standardizing on the input-dispatched build_rs implementation. Branch: HybridRecurrentCache. Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* recurrent : rework graph inputs + add TODOs. ggml-ci
* refactor: Make status and child states const in hybrid and iswa. Branch: HybridRecurrentCache. Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* refactor: Rename llama_kv_cache_[recurrent|hybrid_recurrent] to remove "kv cache". This removes the notion of "kv" from the interface names for these memory types. There are still many references to kv in the implementation of the recurrent memory, which will need further adjustment. Branch: HybridRecurrentCache. Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* refactor!: Rename all k/v related values for recurrent/hybrid to r/s. Anywhere that "kv_<state|cell|size|etc>" is used, I've used the more generic "mem_" prefix. The specifics of "k" (key) translate to "r" (recurrent state) and "v" (value) translate to "s" (state-space embedding states). Branch: HybridRecurrentCache. Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* refactor: _recurrent -> _recr for brevity. It just _happens_ to have the same number of letters as _attn! Branch: HybridRecurrentCache. Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* style: Fix spacing for ref. Branch: HybridRecurrentCache. Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* refactor: recurrent_layer() -> is_recurrent(). Branch: HybridRecurrentCache. Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
* style: Fix spacing for size_s_bytes declaration. Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

---------

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
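The central design described above -- a hybrid memory that owns one attention (unified KV) child cache and one recurrent child cache, routing each layer to exactly one of them through a per-layer predicate or custom layer filter -- can be pictured with a minimal standalone sketch. The toy_* types and the filter rule below are illustrative only, not the classes introduced by this commit:

#include <cstdint>
#include <cstdio>
#include <functional>

// Illustrative stand-ins for the two child caches; the real classes in this PR are
// the unified (attention) KV cache and the recurrent memory.
struct toy_attn_cache { void alloc_layer(uint32_t il) { std::printf("layer %2u -> attention KV cache\n", il); } };
struct toy_recr_cache { void alloc_layer(uint32_t il) { std::printf("layer %2u -> recurrent r/s state\n", il); } };

// Per-layer predicate, analogous to hparams.recurrent_layer_arr / is_recurrent(il)
// and to the custom layer filters mentioned in the commit message.
using layer_filter = std::function<bool(uint32_t)>;

// Toy hybrid memory: holds both child caches and routes each layer to exactly one.
struct toy_hybrid_memory {
    toy_attn_cache attn;
    toy_recr_cache recr;

    toy_hybrid_memory(uint32_t n_layer, const layer_filter & is_recurrent) {
        for (uint32_t il = 0; il < n_layer; ++il) {
            if (is_recurrent(il)) {
                recr.alloc_layer(il);
            } else {
                attn.alloc_layer(il);
            }
        }
    }
};

int main() {
    // e.g. a hybrid model where every 4th layer is an attention layer
    toy_hybrid_memory mem(8, [](uint32_t il) { return il % 4 != 3; });
}

A real hybrid model would drive the filter from hparams.recurrent_layer_arr (auto-filled at load time, per the commits above) rather than from a hard-coded lambda.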
1 parent bc06756 commit 6bd7a3d

11 files changed: +313 -292 lines changed

src/llama-graph.cpp

Lines changed: 62 additions & 47 deletions
@@ -234,15 +234,15 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
 void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
-    const int64_t n_rs = mctx->get_n_rs();
+    const int64_t n_rs = mem_state->get_n_rs();
 
     if (s_copy) {
         GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
         int32_t * data = (int32_t *) s_copy->data;
 
         // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
-        for (uint32_t i = 0; i < n_kv; ++i) {
-            data[i] = kv_state->s_copy(i);
+        for (uint32_t i = 0; i < n_rs; ++i) {
+            data[i] = mem_state->s_copy(i);
         }
     }
 }
@@ -376,38 +376,33 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
                 }
             }
 
-        for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-            for (int j = 0; j < n_enc; ++j) {
-                data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
+            for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                for (int j = 0; j < n_enc; ++j) {
+                    data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
+                }
             }
         }
     }
 }
 
 void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask) {
-        mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+        mem_state->get_state_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
     }
 
-    const int64_t n_rs = mctx->get_recr()->get_n_rs();
+    const int64_t n_rs = mem_state->get_state_recr()->get_n_rs();
 
     if (s_copy) {
         GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
         int32_t * data = (int32_t *) s_copy->data;
 
         // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
         for (uint32_t i = 0; i < n_rs; ++i) {
-            data[i] = mctx->get_recr()->s_copy(i);
+            data[i] = mem_state->get_state_recr()->s_copy(i);
         }
     }
 }
 
-void llm_graph_input_one::set_input(const llama_ubatch *) {
-    GGML_ASSERT(one && ggml_nelements(one) == 1);
-    float f_one = 1.0f;
-    ggml_backend_tensor_set(one, &f_one, 0, sizeof(float));
-}
-
 //
 // llm_graph_context
 //
@@ -983,23 +978,6 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
     return cur;
 }
 
-ggml_tensor * llm_graph_context::build_inp_s_copy() const {
-    const auto * kv_state = static_cast<const llama_kv_cache_recurrent_state *>(mstate);
-
-    auto inp = std::make_unique<llm_graph_input_s_copy>(kv_state);
-
-    const auto n_kv = kv_state->get_n_kv();
-
-    auto & cur = inp->s_copy;
-
-    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
-    ggml_set_input(cur);
-
-    res->add_input(std::move(inp));
-
-    return cur;
-}
-
 ggml_tensor * llm_graph_context::build_inp_cross_embd() const {
     auto inp = std::make_unique<llm_graph_input_cross_embd>(cross);
 
@@ -1070,14 +1048,14 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
 }
 
 llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
-    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
+    const auto * mem_state = static_cast<const llama_memory_hybrid_state *>(mstate);
 
-    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mctx_cur);
+    auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mem_state);
 
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
 
-        const auto n_kv = inp->mctx->get_attn()->get_n_kv();
+        const auto n_kv = inp->mem_state->get_state_attn()->get_n_kv();
 
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ -1087,7 +1065,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     }
 
     {
-        const auto n_rs = mctx_cur->get_recr()->get_n_rs();
+        const auto n_rs = mem_state->get_state_recr()->get_n_rs();
 
         inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
         ggml_set_input(inp->s_copy);
@@ -1468,7 +1446,7 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
     const auto kv_head = kv_state->get_head();
     const auto rs_zero = kv_state->get_rs_z();
 
-    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size());
+    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);
 
     // Clear a single state which will then be copied to the other cleared states.
     // Note that this is a no-op when the view is zero-sized.
@@ -1496,22 +1474,59 @@ ggml_tensor * llm_graph_context::build_recurrent_state(
     return output_states;
 }
 
+llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
+    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+
+    auto inp = std::make_unique<llm_graph_input_rs>(kv_state);
+
+    const auto n_rs = kv_state->get_n_rs();
+
+    inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
+    ggml_set_input(inp->s_copy);
+
+    return (llm_graph_input_rs *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_rs(
+        llm_graph_input_rs * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * s,
+        int32_t state_size,
+        int32_t n_seqs,
+        bool avoid_copies) const {
+    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
+
+    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
+}
+
+ggml_tensor * llm_graph_context::build_rs(
+        llm_graph_input_mem_hybrid * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * s,
+        int32_t state_size,
+        int32_t n_seqs,
+        bool avoid_copies) const {
+    const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_recr();
+
+    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
+}
+
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
-        ggml_cgraph * gf,
-        ggml_tensor * state_copy,
-        const llama_ubatch & ubatch,
+        llm_graph_input_rs * inp,
+        ggml_cgraph * gf,
+        const llama_ubatch & ubatch,
         int il) const {
-    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
+    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
 
     const auto token_shift_count = hparams.token_shift_count;
 
     const int64_t n_seqs = ubatch.n_seqs;
 
-    ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
+    ggml_tensor * token_shift_all = kv_state->get_r_l(il);
 
-    ggml_tensor * token_shift = build_recurrent_state(
-            gf, token_shift_all, state_copy,
-            hparams.n_embd_k_s(), n_seqs);
+    ggml_tensor * token_shift = build_rs(
+            inp, gf, token_shift_all,
+            hparams.n_embd_r(), n_seqs);
 
     token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
 
@@ -1522,7 +1537,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
         ggml_tensor * token_shift,
         const llama_ubatch & ubatch,
         int il) const {
-    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
+    const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
 
     const auto token_shift_count = hparams.token_shift_count;
     const auto n_embd = hparams.n_embd;
@@ -1534,7 +1549,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
     return ggml_cpy(
         ctx0,
         ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
-        ggml_view_1d(ctx0, mctx_cur->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(mctx_cur->get_r_l(il)))
+        ggml_view_1d(ctx0, kv_state->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(kv_state->get_r_l(il)))
     );
 }
 
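Seen from a model's graph-building code, the llama-graph.cpp changes above replace the explicit state_copy plumbing with an input object that is created once and then dispatched, mirroring how build_attn consumes its input. A condensed, non-compilable sketch of the call-site migration, based on the build_rwkv_token_shift_load hunk above (gf, token_shift_all, hparams, and n_seqs are assumed to be in scope of the graph builder; the two halves show the same call site before and after this commit):

// before: the caller built a raw s_copy tensor and threaded it through explicitly
ggml_tensor * state_copy  = build_inp_s_copy();
ggml_tensor * token_shift = build_recurrent_state(
        gf, token_shift_all, state_copy,
        hparams.n_embd_k_s(), n_seqs);

// after: the input is created once via build_rs_inp() and then dispatched
// through build_rs(), which pulls n_rs/head/size/rs_z out of the memory state
llm_graph_input_rs * rs_inp = build_rs_inp();
ggml_tensor * token_shift   = build_rs(
        rs_inp, gf, token_shift_all,
        hparams.n_embd_r(), n_seqs);
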
src/llama-graph.h

Lines changed: 48 additions & 31 deletions
@@ -19,10 +19,10 @@ struct llama_cparams;
 
 struct llama_memory_context_i;
 
-class llama_kv_cache_unified_context;
-class llama_kv_cache_unified_iswa_context;
-class llama_memory_recurrent_context;
-class llama_memory_hybrid_context;
+class llama_kv_cache_unified_state;
+class llama_kv_cache_unified_iswa_state;
+class llama_memory_recurrent_state;
+class llama_memory_hybrid_state;
 
 // certain models (typically multi-modal) can produce different types of graphs
 enum llm_graph_type {
@@ -190,16 +190,17 @@ class llm_graph_input_cls : public llm_graph_input_i {
     const llama_cparams & cparams;
 };
 
+class llm_graph_input_rs : public llm_graph_input_i {
 class llm_graph_input_rs : public llm_graph_input_i {
 public:
-    llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
+    llm_graph_input_rs(const llama_memory_recurrent_state * mem_state) : mem_state(mem_state) {}
     virtual ~llm_graph_input_rs() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * s_copy; // I32 [kv_size]
 
-    const llama_kv_cache_recurrent_state * kv_state;
+    const llama_memory_recurrent_state * mem_state;
 };
 
 class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -307,10 +308,10 @@ class llm_graph_input_mem_hybrid : public llm_graph_input_i {
     llm_graph_input_mem_hybrid(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const llama_memory_hybrid_context * mctx) :
+            const llama_memory_hybrid_state * mem_state) :
         hparams(hparams),
         cparams(cparams),
-        mctx(mctx) {
+        mem_state(mem_state) {
     }
     virtual ~llm_graph_input_mem_hybrid() = default;
 
@@ -326,18 +327,7 @@ class llm_graph_input_mem_hybrid : public llm_graph_input_i {
     const llama_hparams & hparams;
     const llama_cparams & cparams;
 
-    const llama_memory_hybrid_context * mctx;
-};
-
-// TODO: remove this when ggml_scale_add is implemented
-class llm_graph_input_one : public llm_graph_input_i {
-public:
-    llm_graph_input_one() {}
-    virtual ~llm_graph_input_one() = default;
-
-    void set_input(const llama_ubatch *) override;
-
-    ggml_tensor * one = nullptr; // F32
+    const llama_memory_hybrid_state * mem_state;
 };
 
 //
@@ -547,7 +537,6 @@ struct llm_graph_context {
     ggml_tensor * build_inp_out_ids() const;
     ggml_tensor * build_inp_mean() const;
     ggml_tensor * build_inp_cls() const;
-    ggml_tensor * build_inp_s_copy() const;
 
     ggml_tensor * build_inp_cross_embd() const;
     ggml_tensor * build_inp_pos_bucket_enc() const;
@@ -647,18 +636,46 @@ struct llm_graph_context {
     // recurrent
     //
 
-    ggml_tensor * build_recurrent_state(
-            ggml_cgraph * gf,
-            ggml_tensor * s,
-            ggml_tensor * state_copy,
-            int32_t state_size,
-            int32_t n_seqs,
-            bool avoid_copies = false) const;
+    // TODO: avoid notion of "kv"
+    // TODO: move this implementation to llama_memory_recurrent.
+    //       this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
+    //       when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
+    //       implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
+    //       `llama_memory_recurrent`
+    ggml_tensor * build_rs(
+            ggml_cgraph * gf,
+            ggml_tensor * s,
+            ggml_tensor * state_copy,
+            int32_t state_size,
+            int32_t n_seqs,
+            uint32_t n_kv,
+            uint32_t kv_head,
+            uint32_t kv_size,
+            int32_t rs_zero,
+            bool avoid_copies = false) const;
+
+    llm_graph_input_rs * build_rs_inp() const;
+
+    ggml_tensor * build_rs(
+            llm_graph_input_rs * inp,
+            ggml_cgraph * gf,
+            ggml_tensor * s,
+            int32_t state_size,
+            int32_t n_seqs,
+            bool avoid_copies = false) const;
+
+    ggml_tensor * build_rs(
+            llm_graph_input_mem_hybrid * inp,
+            ggml_cgraph * gf,
+            ggml_tensor * s,
+            int32_t state_size,
+            int32_t n_seqs,
+            bool avoid_copies = false) const;
 
     ggml_tensor * build_rwkv_token_shift_load(
-            ggml_cgraph * gf,
-            ggml_tensor * state_copy,
-            const llama_ubatch & ubatch,
+            llm_graph_input_rs * inp,
+            ggml_cgraph * gf,
+            const llama_ubatch & ubatch,
             int il) const;
 
     ggml_tensor * build_rwkv_token_shift_store(

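The header changes spell out the "input-dispatched" pattern the commit message describes: one low-level build_rs that receives the raw cache parameters (n_kv, kv_head, kv_size, rs_zero), plus thin overloads that accept either the recurrent input or the hybrid input and pull those parameters out of the corresponding memory state. A minimal standalone sketch of the same dispatch shape, using toy types rather than the real llama.cpp classes:

#include <cstdint>
#include <cstdio>

// Toy stand-ins for the memory states and graph inputs (not the real llama.cpp types).
struct toy_recurrent_state {
    uint32_t n_rs = 4, head = 0, size = 8;
    int32_t  rs_z = -1;
};

struct toy_hybrid_state {
    toy_recurrent_state recr;
    const toy_recurrent_state * get_state_recr() const { return &recr; }
};

struct toy_input_rs         { const toy_recurrent_state * mem_state; };
struct toy_input_mem_hybrid { const toy_hybrid_state    * mem_state; };

// Low-level worker, analogous to the build_rs overload that takes n_kv/kv_head/kv_size/rs_zero.
static void build_rs_impl(uint32_t n_rs, uint32_t head, uint32_t size, int32_t rs_z) {
    std::printf("build_rs: n_rs=%u head=%u size=%u rs_z=%d\n", n_rs, head, size, rs_z);
}

// Overload dispatched on the pure-recurrent input.
static void build_rs(const toy_input_rs & inp) {
    const toy_recurrent_state * st = inp.mem_state;
    build_rs_impl(st->n_rs, st->head, st->size, st->rs_z);
}

// Overload dispatched on the hybrid input: same worker, different path to the recurrent state.
static void build_rs(const toy_input_mem_hybrid & inp) {
    const toy_recurrent_state * st = inp.mem_state->get_state_recr();
    build_rs_impl(st->n_rs, st->head, st->size, st->rs_z);
}

int main() {
    toy_recurrent_state rs;
    toy_hybrid_state    hs;
    build_rs(toy_input_rs{&rs});
    build_rs(toy_input_mem_hybrid{&hs});
}

In the actual commit the worker is the ten-argument build_rs declared above, and the overloads differ only in how they cast the memory state -- exactly the duplication the commit message calls out as the main drawback.
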
src/llama-hparams.cpp

Lines changed: 1 addition & 4 deletions
@@ -65,6 +65,7 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
     return n_embd_head_v * n_head_kv;
 }
 
+uint32_t llama_hparams::n_embd_r() const {
 uint32_t llama_hparams::n_embd_r() const {
     if (wkv_head_size != 0) {
         // for RWKV models
@@ -90,10 +91,6 @@ bool llama_hparams::is_recurrent(uint32_t il) const {
     return recurrent_layer_arr[il];
 }
 
-uint32_t llama_hparams::n_pos_per_embd() const {
-    return rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
-}
-
 bool llama_hparams::is_swa(uint32_t il) const {
     if (il < n_layer) {
         return swa_layers[il];

src/llama-hparams.h

Lines changed: 3 additions & 2 deletions
@@ -118,6 +118,9 @@ struct llama_hparams {
     // for hybrid state space models
     std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
 
+    // for hybrid state space models
+    std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
+
     bool ssm_dt_b_c_rms = false;
 
     float f_clamp_kqv = 0.0f;
@@ -198,8 +201,6 @@ struct llama_hparams {
     // whether or not the given layer is recurrent (for hybrid models)
     bool is_recurrent(uint32_t il) const;
 
-    uint32_t n_pos_per_embd() const;
-
     bool is_swa(uint32_t il) const;
 };

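The hparams additions give the memory constructors a per-layer routing table: recurrent_layer_arr is filled when the model is loaded (auto-filled based on whether the model is recurrent, per the commit message) and queried through is_recurrent(il). A small standalone sketch of how such an array can be filled from an architecture-level default and then queried per layer; the field names mirror the diff, while the fill rule and toy_* names are illustrative only:

#include <array>
#include <cstdint>
#include <cstdio>

constexpr uint32_t TOY_MAX_LAYERS = 512; // stand-in for LLAMA_MAX_LAYERS

struct toy_hparams {
    uint32_t n_layer = 0;

    // for hybrid state space models: one flag per layer, as in the diff above
    std::array<bool, TOY_MAX_LAYERS> recurrent_layer_arr{};

    bool is_recurrent(uint32_t il) const {
        return il < n_layer && recurrent_layer_arr[il];
    }
};

int main() {
    toy_hparams hp;
    hp.n_layer = 6;

    // Illustrative fill rule: a fully recurrent arch marks every layer, while a hybrid
    // arch marks only its SSM layers (here: every layer except each 3rd one).
    const bool fully_recurrent = false;
    for (uint32_t il = 0; il < hp.n_layer; ++il) {
        hp.recurrent_layer_arr[il] = fully_recurrent || (il % 3 != 2);
    }

    for (uint32_t il = 0; il < hp.n_layer; ++il) {
        std::printf("layer %u -> %s\n", il,
                    hp.is_recurrent(il) ? "recurrent (r/s state)" : "attention (KV cache)");
    }
}

The same predicate is what a layer filter for the hybrid cache can be built from, so each child cache only allocates the layers it actually serves.
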
0 commit comments
