Skip to content

Commit ebd34d0

Browse files
committed
feat: Add layer filter to recurrent cache
Branch: HybridCache Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent cf47b52 commit ebd34d0

File tree

3 files changed

+20
-12
lines changed

3 files changed

+20
-12
lines changed

src/llama-kv-cache.cpp

Lines changed: 12 additions & 6 deletions
```diff
@@ -2153,12 +2153,13 @@ class llama_kv_cache_recurrent_state_t : public llama_kv_cache_recurrent_state_i
 };

 llama_kv_cache_recurrent::llama_kv_cache_recurrent(
-        const llama_model & model,
-                ggml_type   type_k,
-                ggml_type   type_v,
-                     bool   offload,
-                 uint32_t   kv_size,
-                 uint32_t   n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
+        const llama_model & model,
+          layer_filter_cb && filter,
+                ggml_type   type_k,
+                ggml_type   type_v,
+                     bool   offload,
+                 uint32_t   kv_size,
+                 uint32_t   n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
     const int32_t n_layer = hparams.n_layer;

     LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n",
@@ -2200,6 +2201,11 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
     v_l.reserve(n_layer);

     for (int i = 0; i < n_layer; i++) {
+        if (filter && !filter(i)) {
+            LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, i);
+            continue;
+        }
+
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
```

src/llama-kv-cache.h

Lines changed: 7 additions & 6 deletions
```diff
@@ -370,12 +370,13 @@ class llama_kv_cache_unified_iswa_state_i : public llama_memory_state_i {
 class llama_kv_cache_recurrent : public llama_kv_cache {
 public:
     llama_kv_cache_recurrent(
-            const llama_model & model,
-                    ggml_type   type_k,
-                    ggml_type   type_v,
-                         bool   offload,
-                     uint32_t   kv_size,
-                     uint32_t   n_seq_max);
+            const llama_model & model,
+              layer_filter_cb && filter,
+                    ggml_type   type_k,
+                    ggml_type   type_v,
+                         bool   offload,
+                     uint32_t   kv_size,
+                     uint32_t   n_seq_max);

     ~llama_kv_cache_recurrent() = default;
```

src/llama-model.cpp

Lines changed: 1 addition & 0 deletions
```diff
@@ -13212,6 +13212,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
         {
             res = new llama_kv_cache_recurrent(
                     *this,
+                    nullptr,
                     GGML_TYPE_F32,
                     GGML_TYPE_F32,
                     cparams.offload_kqv,
```

0 commit comments

Comments
 (0)