
Commit db34e6b

committed
Added a KV cache shift function to llama as well; this allows manual purging of old tokens from the KV cache while keeping the ones that are still needed.
1 parent d428926 commit db34e6b

File tree

llama.cpp
llama.h

2 files changed: +32 −0 lines changed


llama.cpp

Lines changed: 29 additions & 0 deletions
@@ -2418,6 +2418,35 @@ int llama_get_kv_cache_token_count(struct llama_context * ctx) {
     return ctx->model.kv_self.n;
 }
 
+// Assumes contiguous data
+void llama_shift_kv_cache(struct llama_context * ctx, int n) {
+    auto & model   = ctx->model;
+    auto & kv_self = model.kv_self;
+    auto & hparams = model.hparams;
+    auto n_layer = hparams.n_layer;
+    auto n_embd  = hparams.n_embd;
+    auto n_ctx   = hparams.n_ctx;
+    for (int il = 0; il < n_layer; il++) {
+        // K: embeddings are stored in token order, so a layer can be shifted with a single block move
+        {
+            const size_t elem_byte_size = ggml_element_size(kv_self.k);
+            uint8_t * dst_ptr = ((uint8_t *) kv_self.k->data) + elem_byte_size * n_embd * (il * n_ctx);
+            uint8_t * src_ptr = ((uint8_t *) kv_self.k->data) + elem_byte_size * n_embd * (il * n_ctx + n);
+            memmove(dst_ptr, src_ptr, elem_byte_size * n_embd * (n_ctx - n)); // src and dst overlap: memmove, not memcpy
+        }
+
+        // V: embeddings are stored transposed (one row of n_ctx values per embedding element), so each row is shifted separately
+        {
+            const size_t elem_byte_size = ggml_element_size(kv_self.v);
+            for (int i = 0; i < n_embd; i++) {
+                uint8_t * dst_ptr = ((uint8_t *) kv_self.v->data) + elem_byte_size * ((il * n_embd + i) * n_ctx);
+                uint8_t * src_ptr = ((uint8_t *) kv_self.v->data) + elem_byte_size * ((il * n_embd + i) * n_ctx + n);
+                memmove(dst_ptr, src_ptr, elem_byte_size * (n_ctx - n)); // same overlap argument as for K
+            }
+        }
+    }
+}
+
 #define LLAMA_MAX_RNG_STATE 64*1024
 
 void llama_set_rng_seed(struct llama_context * ctx, int seed) {
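
For reference, the copy offsets above follow directly from the cache layout: K holds, per layer, n_ctx token rows of n_embd elements, while V holds the transpose, n_embd element rows of n_ctx token slots. That is why K needs one block move per layer but V needs one per embedding element. A minimal standalone sketch of the index arithmetic (the dimensions are made-up illustrative values, not tied to any real model):

#include <stdio.h>

// Hypothetical dimensions, for illustration only
enum { N_LAYER = 2, N_EMBD = 4, N_CTX = 8 };

// K cache: element i of token t in layer il lives at (il*N_CTX + t)*N_EMBD + i
static size_t k_index(int il, int t, int i) {
    return ((size_t) il * N_CTX + t) * N_EMBD + i;
}

// V cache (transposed): the same element lives at (il*N_EMBD + i)*N_CTX + t
static size_t v_index(int il, int t, int i) {
    return ((size_t) il * N_EMBD + i) * N_CTX + t;
}

int main(void) {
    const int n = 3; // shift out the first n tokens: token t moves to slot t - n
    printf("K: layer 1, token %d, element 0: %zu -> %zu\n", n, k_index(1, n, 0), k_index(1, 0, 0));
    printf("V: layer 1, token %d, element 2: %zu -> %zu\n", n, v_index(1, n, 2), v_index(1, 0, 2));
    return 0;
}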

llama.h

Lines changed: 3 additions & 0 deletions
@@ -126,6 +126,9 @@ extern "C" {
     // Returns the number of tokens in the KV cache
     LLAMA_API int llama_get_kv_cache_token_count(struct llama_context * ctx);
 
+    // Shifts the KV cache, effectively removing the first n tokens
+    LLAMA_API void llama_shift_kv_cache(struct llama_context * ctx, int n);
+
     // Sets the current rng seed.
     LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, int seed);
 
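
On the caller's side, a sliding-window generation loop could use the new entry point roughly as below. This is a minimal sketch, assuming the llama_eval-style API of this era; the n_discard step size and the n_past bookkeeping are illustrative choices, not part of the commit:

#include "llama.h"

// Sketch: keep generating past the context limit by shifting out the
// oldest n_discard tokens whenever the KV cache fills up.
static void eval_with_shift(struct llama_context * ctx, llama_token tok,
                            int * n_past, int n_ctx, int n_threads) {
    if (*n_past >= n_ctx) {
        const int n_discard = n_ctx / 4;       // illustrative window step
        llama_shift_kv_cache(ctx, n_discard);  // drop the oldest tokens
        *n_past -= n_discard;                  // remaining entries moved down by n_discard
    }

    // evaluate the next token at the (possibly shifted) position
    llama_eval(ctx, &tok, 1, *n_past, n_threads);
    (*n_past)++;
}

Note that the shifted entries keep the rotary position they were originally encoded with, so attention positions become approximate after a shift; whether that trade-off is acceptable is left to the caller.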
