Commit 8edfe63

feat: Add layer filter to recurrent cache

Branch: HybridCache
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>

1 parent 6986249

3 files changed: +26 -14 lines

src/llama-kv-cache-recurrent.cpp

Lines changed: 14 additions & 8 deletions
@@ -16,12 +16,13 @@
 //
 
 llama_kv_cache_recurrent::llama_kv_cache_recurrent(
-        const llama_model & model,
-                ggml_type   type_k,
-                ggml_type   type_v,
-                     bool   offload,
-                 uint32_t   kv_size,
-                 uint32_t   n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
+        const llama_model &  model,
+          layer_filter_cb && filter,
+                ggml_type    type_k,
+                ggml_type    type_v,
+                     bool    offload,
+                 uint32_t    kv_size,
+                 uint32_t    n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
     const int32_t n_layer = hparams.n_layer;
 
     LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n",
@@ -63,6 +64,11 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
     v_l.reserve(n_layer);
 
     for (int i = 0; i < n_layer; i++) {
+        if (filter && !filter(i)) {
+            LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, i);
+            continue;
+        }
+
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
 
@@ -88,8 +94,8 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
         ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
         ggml_format_name(k, "cache_k_l%d", i);
         ggml_format_name(v, "cache_v_l%d", i);
-        k_l.push_back(k);
-        v_l.push_back(v);
+        k_l[i] = k;
+        v_l[i] = v;
     }
 
     // allocate tensors and initialize the buffers to avoid NaNs in the padding
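The switch from push_back to indexed assignment means each layer's tensors are stored at that layer's index, so a filtered-out layer leaves an empty slot rather than shifting later layers. A minimal standalone sketch of this skip-and-index pattern, assuming a hypothetical even-layer filter and a string standing in for the real ggml tensor; note the sketch constructs n_layer slots up front, since indexed assignment requires the elements to exist (reserve alone would not create them):

#include <cstdint>
#include <cstdio>
#include <functional>
#include <vector>

// same callback shape as layer_filter_cb in the header
using layer_filter_cb = std::function<bool(int32_t il)>;

int main() {
    const int32_t n_layer = 6;

    // hypothetical filter: keep only even layers (illustrative, not from the commit)
    layer_filter_cb filter = [](int32_t il) { return il % 2 == 0; };

    // construct with n_layer null slots so k_l[i] is valid for every i
    std::vector<const char *> k_l(n_layer, nullptr);

    for (int32_t i = 0; i < n_layer; i++) {
        if (filter && !filter(i)) {
            printf("layer %3d: skipped\n", i);
            continue;
        }
        k_l[i] = "allocated"; // stands in for the real ggml_tensor *
    }

    for (int32_t i = 0; i < n_layer; i++) {
        printf("layer %3d: %s\n", i, k_l[i] ? k_l[i] : "(empty)");
    }
    return 0;
}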

src/llama-kv-cache-recurrent.h

Lines changed: 11 additions & 6 deletions
@@ -15,13 +15,18 @@
 // see the implementation of llama_kv_cache_unified_state_i for an example how to do it
 class llama_kv_cache_recurrent : public llama_memory_i {
 public:
+
+    // this callback is used to filter out layers that should not be included in the cache
+    using layer_filter_cb = std::function<bool(int32_t il)>;
+
     llama_kv_cache_recurrent(
-            const llama_model & model,
-                    ggml_type   type_k,
-                    ggml_type   type_v,
-                         bool   offload,
-                     uint32_t   kv_size,
-                     uint32_t   n_seq_max);
+            const llama_model &  model,
+              layer_filter_cb && filter,
+                    ggml_type    type_k,
+                    ggml_type    type_v,
+                         bool    offload,
+                     uint32_t    kv_size,
+                     uint32_t    n_seq_max);
 
     ~llama_kv_cache_recurrent() = default;
 
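Because layer_filter_cb is a std::function, the constructor accepts any callable, or nullptr for an empty filter; taking it by rvalue reference lets callers hand over a temporary lambda without a copy. A small runnable sketch of how such a parameter behaves (build_cache is an illustrative stand-in, not part of the commit):

#include <cstdint>
#include <cstdio>
#include <functional>

using layer_filter_cb = std::function<bool(int32_t il)>;

// illustrative stand-in for the constructor: an empty filter means "keep all layers"
static void build_cache(layer_filter_cb && filter, int32_t n_layer) {
    for (int32_t i = 0; i < n_layer; i++) {
        const bool keep = !filter || filter(i);
        printf("layer %d: %s\n", i, keep ? "cached" : "skipped");
    }
}

int main() {
    build_cache(nullptr, 3);                            // no filter: every layer cached
    build_cache([](int32_t il) { return il != 1; }, 3); // skip layer 1
    return 0;
}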

src/llama-model.cpp

Lines changed: 1 addition & 0 deletions
@@ -13213,6 +13213,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                 {
                     res = new llama_kv_cache_recurrent(
                             *this,
+                            nullptr,
                             GGML_TYPE_F32,
                             GGML_TYPE_F32,
                             cparams.offload_kqv,
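Passing nullptr at this call site preserves the existing behavior: with no filter, every layer gets a cache slot. The HybridCache branch name suggests the filter is meant for hybrid models that mix attention and recurrent layers; a speculative, self-contained sketch of how two complementary filters could partition layers between two caches (the layer layout is purely illustrative):

#include <cstdint>
#include <cstdio>
#include <functional>
#include <vector>

using layer_filter_cb = std::function<bool(int32_t il)>;

int main() {
    // hypothetical hybrid layout: layers 1, 2, 4, 5 are recurrent
    const std::vector<bool> is_recurrent = {false, true, true, false, true, true};

    // complementary filters let two caches cover one model with no overlap
    layer_filter_cb recurrent_filter = [&](int32_t il) { return  is_recurrent[il]; };
    layer_filter_cb attention_filter = [&](int32_t il) { return !is_recurrent[il]; };

    for (int32_t il = 0; il < (int32_t) is_recurrent.size(); il++) {
        if (recurrent_filter(il)) printf("layer %d -> recurrent cache\n", il);
        if (attention_filter(il)) printf("layer %d -> attention cache\n", il);
    }
    return 0;
}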
