Commit d78ca19

feat: Add layer filter to recurrent cache

Branch: HybridCache
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent 59a14d4
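This commit threads an optional layer-filter callback through the llama_kv_cache_recurrent constructor: when a filter is supplied, layers it rejects are skipped during cache allocation; a null filter preserves the existing behavior of allocating recurrent state for every layer. All call sites in this commit pass nullptr.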

4 files changed: 19 additions & 10 deletions

src/llama-kv-cache.cpp

Lines changed: 11 additions & 5 deletions
@@ -1866,11 +1866,12 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_kv_swa() const {
 //
 
 llama_kv_cache_recurrent::llama_kv_cache_recurrent(
-        const llama_model & model,
-                ggml_type   type_k,
-                ggml_type   type_v,
-                     bool   offload,
-                 uint32_t   kv_size) : hparams(model.hparams) {
+        const llama_model & model,
+          layer_filter_cb && filter,
+                ggml_type   type_k,
+                ggml_type   type_v,
+                     bool   offload,
+                 uint32_t   kv_size) : hparams(model.hparams) {
     const int32_t n_layer = hparams.n_layer;
 
     LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d\n",
@@ -1915,6 +1916,11 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
     v_l.reserve(n_layer);
 
     for (int i = 0; i < n_layer; i++) {
+        if (filter && !filter(i)) {
+            LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, i);
+            continue;
+        }
+
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
 
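For illustration, a minimal sketch of how a caller on the HybridCache branch might drive the new parameter. The lambda and kv_size below are hypothetical (this commit only ever passes nullptr), and the sketch assumes layer_filter_cb behaves like a std::function<bool(int32_t il)>, which is consistent with the filter && !filter(i) check above:

    // Hypothetical usage sketch, not part of this commit: allocate
    // recurrent state only for layers the predicate accepts.
    // Assumes layer_filter_cb is std::function<bool(int32_t il)>-like.
    llama_kv_cache_recurrent cache(
        model,
        [](int32_t il) { return il % 2 == 0; }, // e.g. even layers only
        GGML_TYPE_F32,
        GGML_TYPE_F32,
        /* offload */ false,
        /* kv_size */ 1024); // illustrative value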

src/llama-kv-cache.h

Lines changed: 6 additions & 5 deletions
@@ -440,11 +440,12 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
     };
 
     llama_kv_cache_recurrent(
-            const llama_model & model,
-                    ggml_type   type_k,
-                    ggml_type   type_v,
-                         bool   offload,
-                     uint32_t   kv_size);
+            const llama_model & model,
+              layer_filter_cb && filter,
+                    ggml_type   type_k,
+                    ggml_type   type_v,
+                         bool   offload,
+                     uint32_t   kv_size);
 
     ~llama_kv_cache_recurrent() = default;
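The definition of layer_filter_cb itself is outside this diff. Judging from how it is used (invoked with an int layer index, truth-tested, and accepting nullptr at call sites), a plausible shape is the following sketch; treat it as an assumption rather than the declared type:

    // Assumed, not shown in this diff (requires <functional>): a
    // std::function fits both filter(i) and the nullptr call sites.
    using layer_filter_cb = std::function<bool(int32_t il)>;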

src/llama-model.cpp

Lines changed: 1 addition & 0 deletions
@@ -13206,6 +13206,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                 {
                     res = new llama_kv_cache_recurrent(
                             *this,
+                            nullptr,
                             GGML_TYPE_F32,
                             GGML_TYPE_F32,
                             cparams.offload_kqv,
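Passing nullptr here means the new filter check in the constructor never fires, so every layer still gets its recurrent state allocated and this call site keeps its previous behavior.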

tests/test-memory.cpp

Lines changed: 1 addition & 0 deletions
@@ -156,6 +156,7 @@ static void test_llama_kv_cache_recurrent_constructor() {
     auto model = _make_model(LLM_ARCH_MAMBA);
     llama_kv_cache_recurrent cache(
         /* model */   *model,
+        /* filter */  nullptr,
         /* type_k */  GGML_TYPE_F32,
         /* type_v */  GGML_TYPE_F16,
         /* offload */ false,
