Skip to content

Commit 5c84488

Browse files
committed
context : simplify kv cache updates
ggml-ci
1 parent 623954b commit 5c84488

File tree

2 files changed

+9
-18
lines changed

2 files changed

+9
-18
lines changed

src/llama-context.cpp

Lines changed: 3 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -1254,6 +1254,9 @@ int llama_context::decode(llama_batch & inp_batch) {
12541254
return -2;
12551255
};
12561256

1257+
// handle any pending defrags/shifts
1258+
kv_self_update();
1259+
12571260
int64_t n_outputs_prev = 0;
12581261

12591262
while (sbatch.n_tokens > 0) {
@@ -1293,14 +1296,6 @@ int llama_context::decode(llama_batch & inp_batch) {
12931296

12941297
// find KV slot
12951298
{
1296-
kv_self_update();
1297-
1298-
// if we have enough unused cells before the current head ->
1299-
// better to start searching from the beginning of the cache, hoping to fill it
1300-
if (kv_self->head > kv_self->used + 2*ubatch.n_tokens) {
1301-
kv_self->head = 0;
1302-
}
1303-
13041299
if (!kv_self->find_slot(ubatch)) {
13051300
LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
13061301
return -3;
@@ -1342,16 +1337,6 @@ int llama_context::decode(llama_batch & inp_batch) {
13421337
}
13431338
}
13441339

1345-
// update the kv ring buffer
1346-
{
1347-
kv_self->head += ubatch.n_tokens;
1348-
1349-
// Ensure kv cache head points to a valid index.
1350-
if (kv_self->head >= kv_self->size) {
1351-
kv_self->head = 0;
1352-
}
1353-
}
1354-
13551340
// plot the computation graph in dot format (for debugging purposes)
13561341
//if (n_past%100 == 0) {
13571342
// ggml_graph_dump_dot(gf, NULL, "llama.dot");

src/llama-kv-cache.cpp

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -492,6 +492,12 @@ bool llama_kv_cache_unified::find_slot(
492492
const uint32_t n_seqs = ubatch.n_seqs;
493493
const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
494494

495+
// if we have enough unused cells before the current head ->
496+
// better to start searching from the beginning of the cache, hoping to fill it
497+
if (head > used + 2*ubatch.n_tokens) {
498+
head = 0;
499+
}
500+
495501
if (recurrent) {
496502
// For recurrent state architectures (like Mamba or RWKV),
497503
// each cache cell can store the state for a whole sequence.

0 commit comments

Comments
 (0)