
Commit a886cc1

feat: Add layer filter to recurrent cache
Branch: HybridCache
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

1 parent 162639c · commit a886cc1

File tree: 3 files changed (+20 −12 lines)

src/llama-kv-cache.cpp
Lines changed: 12 additions & 6 deletions

@@ -1961,12 +1961,13 @@ const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state::get_swa(
 //

 llama_kv_cache_recurrent::llama_kv_cache_recurrent(
-        const llama_model & model,
-                ggml_type   type_k,
-                ggml_type   type_v,
-                     bool   offload,
-                 uint32_t   kv_size,
-                 uint32_t   n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
+        const llama_model & model,
+          layer_filter_cb && filter,
+                ggml_type   type_k,
+                ggml_type   type_v,
+                     bool   offload,
+                 uint32_t   kv_size,
+                 uint32_t   n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
     const int32_t n_layer = hparams.n_layer;

     LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n",

@@ -2008,6 +2009,11 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
     v_l.reserve(n_layer);

     for (int i = 0; i < n_layer; i++) {
+        if (filter && !filter(i)) {
+            LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, i);
+            continue;
+        }
+
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();

src/llama-kv-cache.h
Lines changed: 7 additions & 6 deletions

@@ -464,12 +464,13 @@ class llama_kv_cache_unified_iswa_state : public llama_memory_state_i {
 class llama_kv_cache_recurrent : public llama_kv_cache {
 public:
     llama_kv_cache_recurrent(
-            const llama_model & model,
-                    ggml_type   type_k,
-                    ggml_type   type_v,
-                         bool   offload,
-                     uint32_t   kv_size,
-                     uint32_t   n_seq_max);
+            const llama_model & model,
+              layer_filter_cb && filter,
+                    ggml_type   type_k,
+                    ggml_type   type_v,
+                         bool   offload,
+                     uint32_t   kv_size,
+                     uint32_t   n_seq_max);

     ~llama_kv_cache_recurrent() = default;

src/llama-model.cpp
Lines changed: 1 addition & 0 deletions

@@ -13206,6 +13206,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                 {
                     res = new llama_kv_cache_recurrent(
                             *this,
+                            nullptr,
                             GGML_TYPE_F32,
                             GGML_TYPE_F32,
                             cparams.offload_kqv,
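The only call site touched here passes nullptr, so all layers keep their state and behavior is unchanged. Given the HybridCache branch name, the parameter presumably exists so a hybrid model can restrict the recurrent cache to its recurrent layers. A hypothetical caller-side sketch follows; the recurrent_layer(il) predicate is an assumption for illustration, and the trailing kv_size/n_seq_max arguments (elided by the diff above) are placeholders:

res = new llama_kv_cache_recurrent(
        *this,
        // hypothetical predicate: allocate state only for recurrent layers
        [&](int32_t il) { return hparams.recurrent_layer(il); },
        GGML_TYPE_F32,
        GGML_TYPE_F32,
        cparams.offload_kqv,
        kv_size,
        n_seq_max);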
