From f1d179ec12a25ae0e7ffd75a1bc712d44394cf52 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 1 Apr 2025 20:09:35 +0300 Subject: [PATCH 1/8] llama : refactor kv cache guard ggml-ci --- src/llama-context.cpp | 35 ++------------ src/llama-kv-cache.cpp | 31 +++++++++---- src/llama-kv-cache.h | 103 ++++++++++++++--------------------------- 3 files changed, 61 insertions(+), 108 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 3479a8cca3d64..88d0d2ea8a465 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1201,33 +1201,7 @@ int llama_context::decode(llama_batch & inp_batch) { const int64_t n_tokens_all = batch.n_tokens; const int64_t n_embd = hparams.n_embd; - // TODO: remove this stuff - class batch_guard { - public: - batch_guard(llama_kv_cache_unified & kv_self) : kv_slot_restorer(kv_self) { - } - - ~batch_guard() { - if (!is_done) { - kv_slot_restorer.restore(); - } - } - - void done() { - is_done = true; - } - - void save(const llama_kv_cache_slot_info & slot_info) { - kv_slot_restorer.save(slot_info); - } - - private: - bool is_done = false; - - llama_kv_slot_restorer kv_slot_restorer; - }; - - batch_guard bg(*kv_self); + llama_kv_cache_guard kvg(kv_self.get()); GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT @@ -1327,14 +1301,11 @@ int llama_context::decode(llama_batch & inp_batch) { kv_self->head = 0; } - const auto slot_info = kv_self->find_slot(ubatch); - if (!slot_info) { + if (!kv_self->find_slot(ubatch)) { LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__); return -3; } - bg.save(slot_info); - if (!kv_self->recurrent) { // a heuristic, to avoid attending the full cache if it is not yet utilized // after enough generations, the benefit from this heuristic disappears @@ -1467,7 +1438,7 @@ int llama_context::decode(llama_batch & inp_batch) { } // finalize the batch processing - bg.done(); + kvg.commit(); // set output mappings { diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 14c8933b4d6c4..f0851f5c3aa45 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -11,8 +11,6 @@ #include #include -static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false}; - llama_kv_cache_unified::llama_kv_cache_unified(const llama_hparams & hparams, callbacks cbs) : hparams(hparams), cbs(std::move(cbs)) { } @@ -446,11 +444,25 @@ void llama_kv_cache_unified::defrag() { } } +void llama_kv_cache_unified::restore() { + if (pending.ranges.empty()) { + return; + } + + for (auto & range : pending.ranges) { + seq_rm(-1, range.p0, range.p1); + } +} + +void llama_kv_cache_unified::commit() { + pending.ranges.clear(); +} + bool llama_kv_cache_unified::get_can_shift() const { return can_shift; } -llama_kv_cache_slot_info llama_kv_cache_unified::find_slot( +bool llama_kv_cache_unified::find_slot( const llama_ubatch & ubatch) { const uint32_t n_tokens = ubatch.n_tokens; const uint32_t n_seqs = ubatch.n_seqs; @@ -477,7 +489,7 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot( // too big seq_id // TODO: would it be possible to resize the cache instead? 
LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size); - return llama_kv_cache_slot_info_failed; + return false; } if (j > 0) { llama_kv_cell & seq = cells[seq_id]; @@ -616,14 +628,14 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot( [](const llama_kv_cell& cell){ return !cell.is_empty(); }); // sanity check - return llama_kv_cache_slot_info(n >= n_seqs); + return n >= n_seqs; } // otherwise, one cell per token. if (n_tokens > size) { LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %d\n", __func__, n_tokens, size); - return llama_kv_cache_slot_info_failed; + return false; } uint32_t n_tested = 0; @@ -651,7 +663,7 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot( if (n_tested >= size) { //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); - return llama_kv_cache_slot_info_failed; + return false; } } @@ -668,7 +680,9 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot( used += n_tokens; - return llama_kv_cache_slot_info(head, head + n_tokens); + pending.ranges.push_back({head, head + n_tokens}); + + return true; } uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) const { @@ -1033,6 +1047,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); return false; } + commit(); // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values) // Assume that this is one contiguous block of cells diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 0a7ff8a4ea3e6..5728d32a90414 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -17,6 +17,9 @@ struct llama_ubatch; struct llama_kv_cache : public llama_memory_i { using llama_memory_i::llama_memory_i; + virtual void restore() = 0; // call if batch processing fails to restore the cache state + virtual void commit() = 0; // call after successful batch processing + virtual int32_t get_n_tokens() const = 0; virtual uint32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache @@ -25,9 +28,24 @@ struct llama_kv_cache : public llama_memory_i { bool get_can_edit() const override { return get_can_shift(); } }; +struct llama_kv_cache_guard { + llama_kv_cache_guard(llama_kv_cache * kv) : kv(kv) {} + + ~llama_kv_cache_guard() { + kv->restore(); + } + + void commit() { + kv->commit(); + } + +private: + llama_kv_cache * kv; +}; + struct llama_kv_cell { llama_pos pos = -1; - llama_pos delta = 0; + llama_pos delta = 0; int32_t src = -1; // used by recurrent state models to copy states int32_t tail = -1; @@ -46,17 +64,6 @@ struct llama_kv_cell { } }; -// a structure holds information about the slot found in llama_kv_cache_find_slot -struct llama_kv_cache_slot_info { - std::pair boundaries; // slot boundaries [begin, end) - bool found = false; // the slot was found - - explicit llama_kv_cache_slot_info(bool found_) : found{found_} {} - llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {} - - operator bool() const { return found; } -}; - // ring-buffer of cached KV data // TODO: pimpl // TODO: add notion of max sequences @@ -93,6 +100,9 @@ class llama_kv_cache_unified : public llama_kv_cache { void clear() override; void defrag() override; + virtual void restore() override; + virtual void commit() override; + bool seq_rm (llama_seq_id seq_id, llama_pos p0, 
llama_pos p1) override; void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; void seq_keep(llama_seq_id seq_id) override; @@ -105,10 +115,9 @@ class llama_kv_cache_unified : public llama_kv_cache { // find an empty slot of size "n_tokens" in the cache // updates the cache head - // returns a structure holding information about the slot found // Note: On success, it's important that cache.head points // to the first cell of the slot. - llama_kv_cache_slot_info find_slot(const llama_ubatch & batch); + bool find_slot(const llama_ubatch & batch); // TODO: maybe not needed uint32_t get_padding(const llama_cparams & cparams) const; @@ -128,7 +137,18 @@ class llama_kv_cache_unified : public llama_kv_cache { // return true if cells have been moved bool defrag_prepare(int32_t n_max_nodes); - // state save/load + // commit/restore cache + + struct slot_range { + uint32_t p0 = 0; + uint32_t p1 = 0; + }; + + struct { + std::vector ranges; + } pending; + + // state write/load void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const; void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1); @@ -183,59 +203,6 @@ class llama_kv_cache_unified : public llama_kv_cache { // using llama_kv_cache_unified::llama_kv_cache_unified; //}; -// -// kv cache restore -// - -// saves the kv_cache state for future recovery. -// used to rollback llama_kv_cache_find_slot changes. -struct llama_kv_slot_restorer { - struct llama_kv_cache_state { - uint32_t head = 0; - uint32_t n = 0; - } old_state; - - // for non-recurrent models only - // list of slots to restore - std::vector> slot_boundaries; - - bool do_restore = false; - - llama_kv_cache_unified & cache; - - explicit llama_kv_slot_restorer(llama_kv_cache_unified & cache) : cache(cache) { - old_state.head = cache.head; - old_state.n = cache.n; - } - - // saves a slot information for future restoration - void save(const llama_kv_cache_slot_info & slot) { - if (slot) { - do_restore = true; - if (slot.boundaries.first != slot.boundaries.second) { - slot_boundaries.push_back(slot.boundaries); - } - } - } - - // must be explicitly called to restore the kv_cache state - // and rollback changes from all llama_kv_cache_find_slot calls - void restore() { - if (do_restore) { - cache.head = old_state.head; - cache.n = old_state.n; - - if (cache.recurrent) { // recurrent models like Mamba or RWKV can't have a state partially erased - cache.seq_rm(-1, -1, -1); - } else { - for (auto & slot : slot_boundaries) { - cache.seq_rm(-1, slot.first, slot.second); - } - } - } - } -}; - // TODO: maybe become part of the public llama_kv_cache in the future int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv); From 4fdd6e514e3afd227982e5f2ea4247072ce4a6f8 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 1 Apr 2025 20:12:05 +0300 Subject: [PATCH 2/8] cont : fix comment [no ci] --- src/llama-kv-cache.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 5728d32a90414..d670b32238d4a 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -17,8 +17,8 @@ struct llama_ubatch; struct llama_kv_cache : public llama_memory_i { using llama_memory_i::llama_memory_i; - virtual void restore() = 0; // call if batch processing fails to restore the cache state - virtual void commit() = 0; // call after successful batch processing + virtual void restore() = 0; // call if batch processing fails - restores the cache state + virtual void commit() = 0; // call 
after successful batch processing - clears any pending state virtual int32_t get_n_tokens() const = 0; virtual uint32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache From 623954b580fc342a108844d8d82ab54466312b65 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 2 Apr 2025 11:55:56 +0300 Subject: [PATCH 3/8] llama : fix kv_cache restore logic ggml-ci --- src/llama-kv-cache.cpp | 26 +++++++++++++++++++++++++- src/llama-kv-cache.h | 4 ++-- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index f0851f5c3aa45..4cc9d79fa8a15 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -449,8 +449,32 @@ void llama_kv_cache_unified::restore() { return; } + // TODO: tmp - move to llama_kv_cache_recurrent + if (recurrent) { + seq_rm(-1, -1, -1); + return; + } + + uint32_t new_head = size; + for (auto & range : pending.ranges) { - seq_rm(-1, range.p0, range.p1); + for (uint32_t i = range.c0; i < range.c1; ++i) { + cells[i].seq_id.clear(); + + // keep count of the number of used cells + if (cells[i].pos >= 0) { + used--; + } + + cells[i].pos = -1; + cells[i].src = -1; + } + + new_head = std::min(new_head, range.c0); + } + + if (new_head != size && new_head < head) { + head = new_head; } } diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index d670b32238d4a..d69885de8a69b 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -140,8 +140,8 @@ class llama_kv_cache_unified : public llama_kv_cache { // commit/restore cache struct slot_range { - uint32_t p0 = 0; - uint32_t p1 = 0; + uint32_t c0 = 0; // note: these are cell indices, not sequence positions + uint32_t c1 = 0; }; struct { From 5c8448874e6dac83099cfb1e5c295e958f1a7cc1 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 2 Apr 2025 12:55:59 +0300 Subject: [PATCH 4/8] context : simplify kv cache updates ggml-ci --- src/llama-context.cpp | 21 +++------------------ src/llama-kv-cache.cpp | 6 ++++++ 2 files changed, 9 insertions(+), 18 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 88d0d2ea8a465..aeab74077ffb5 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1254,6 +1254,9 @@ int llama_context::decode(llama_batch & inp_batch) { return -2; }; + // handle any pending defrags/shifts + kv_self_update(); + int64_t n_outputs_prev = 0; while (sbatch.n_tokens > 0) { @@ -1293,14 +1296,6 @@ int llama_context::decode(llama_batch & inp_batch) { // find KV slot { - kv_self_update(); - - // if we have enough unused cells before the current head -> - // better to start searching from the beginning of the cache, hoping to fill it - if (kv_self->head > kv_self->used + 2*ubatch.n_tokens) { - kv_self->head = 0; - } - if (!kv_self->find_slot(ubatch)) { LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__); return -3; @@ -1342,16 +1337,6 @@ int llama_context::decode(llama_batch & inp_batch) { } } - // update the kv ring buffer - { - kv_self->head += ubatch.n_tokens; - - // Ensure kv cache head points to a valid index. 
- if (kv_self->head >= kv_self->size) { - kv_self->head = 0; - } - } - // plot the computation graph in dot format (for debugging purposes) //if (n_past%100 == 0) { // ggml_graph_dump_dot(gf, NULL, "llama.dot"); diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 4cc9d79fa8a15..1b9db4bae3c67 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -492,6 +492,12 @@ bool llama_kv_cache_unified::find_slot( const uint32_t n_seqs = ubatch.n_seqs; const uint32_t n_seq_tokens = ubatch.n_seq_tokens; + // if we have enough unused cells before the current head -> + // better to start searching from the beginning of the cache, hoping to fill it + if (head > used + 2*ubatch.n_tokens) { + head = 0; + } + if (recurrent) { // For recurrent state architectures (like Mamba or RWKV), // each cache cell can store the state for a whole sequence. From eb5518f0153a6df63bccc8136e94891f99e4185c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 2 Apr 2025 12:57:04 +0300 Subject: [PATCH 5/8] cont : better name [no ci] --- src/llama-context.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index aeab74077ffb5..c9af88609ca6d 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1201,7 +1201,7 @@ int llama_context::decode(llama_batch & inp_batch) { const int64_t n_tokens_all = batch.n_tokens; const int64_t n_embd = hparams.n_embd; - llama_kv_cache_guard kvg(kv_self.get()); + llama_kv_cache_guard kv_guard(kv_self.get()); GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT @@ -1423,7 +1423,7 @@ int llama_context::decode(llama_batch & inp_batch) { } // finalize the batch processing - kvg.commit(); + kv_guard.commit(); // set output mappings { From 2c41dffcc72a20651e2dbfb2d2244c8c89fb1d3e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 2 Apr 2025 13:36:18 +0300 Subject: [PATCH 6/8] llama : fix llama_decode return code when could not find KV slot ggml-ci --- examples/parallel/parallel.cpp | 2 ++ src/llama-context.cpp | 2 +- src/llama-kv-cache.cpp | 2 ++ 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 588632f0432b2..cec9c45bb359e 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -106,6 +106,8 @@ int main(int argc, char ** argv) { common_params params; + params.n_predict = 128; + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PARALLEL)) { return 1; } diff --git a/src/llama-context.cpp b/src/llama-context.cpp index c9af88609ca6d..30d3dba95dad8 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1298,7 +1298,7 @@ int llama_context::decode(llama_batch & inp_batch) { { if (!kv_self->find_slot(ubatch)) { LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__); - return -3; + return 1; } if (!kv_self->recurrent) { diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 1b9db4bae3c67..35132a88f9001 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -204,6 +204,8 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos return false; } } + + return true; } for (uint32_t i = 0; i < size; ++i) { From 8ab37b183133028449212c58b7f5f84a302954ec Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 2 Apr 2025 14:10:12 +0300 Subject: [PATCH 7/8] context : change log err -> warn [no ci] --- src/llama-context.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git 
a/src/llama-context.cpp b/src/llama-context.cpp index 30d3dba95dad8..7d067afbe7399 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1297,7 +1297,8 @@ int llama_context::decode(llama_batch & inp_batch) { // find KV slot { if (!kv_self->find_slot(ubatch)) { - LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__); + LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens); + return 1; } From 626f822c98f5d36c7539ea6be1ab68295d15deb8 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 2 Apr 2025 14:28:37 +0300 Subject: [PATCH 8/8] kv-cache : add comment + warning [no ci] --- src/llama-kv-cache.cpp | 6 ++++++ src/llama-kv-cache.h | 1 + 2 files changed, 7 insertions(+) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 35132a88f9001..7ba546c10ff74 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -481,6 +481,12 @@ void llama_kv_cache_unified::restore() { } void llama_kv_cache_unified::commit() { + if (pending.ranges.empty()) { + LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n", + __func__, "https://github.com/ggml-org/llama.cpp/pull/12695"); + return; + } + pending.ranges.clear(); } diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index d69885de8a69b..ff0ba3540d6e2 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -144,6 +144,7 @@ class llama_kv_cache_unified : public llama_kv_cache { uint32_t c1 = 0; }; + // pending cell updates that are not yet committed struct { std::vector ranges; } pending;
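
The core idea of this series is that llama_kv_cache_unified::find_slot() now records every claimed cell range in pending.ranges, and llama_kv_cache_guard rolls those ranges back via restore() unless commit() is called after a successful decode. The sketch below is a minimal, self-contained illustration of that commit/restore pattern, not the real implementation: toy_kv_cache and toy_kv_cache_guard are hypothetical stand-ins for llama_kv_cache_unified and llama_kv_cache_guard, and they omit the recurrent-state, defrag and head-tracking logic.

// Minimal standalone sketch of the commit/restore pattern from this series.
// toy_kv_cache is a hypothetical stand-in for llama_kv_cache_unified: it only
// tracks which cells are occupied, nothing else.
#include <cstdint>
#include <cstdio>
#include <vector>

struct toy_kv_cache {
    struct slot_range { uint32_t c0 = 0; uint32_t c1 = 0; }; // cell indices, not positions

    std::vector<bool>       cells;   // true = occupied
    std::vector<slot_range> pending; // ranges claimed since the last commit()

    explicit toy_kv_cache(uint32_t size) : cells(size, false) {}

    // mirrors find_slot(): claim n contiguous cells and remember the range
    bool find_slot(uint32_t n) {
        for (uint32_t i = 0; i + n <= cells.size(); ++i) {
            bool free = true;
            for (uint32_t j = i; j < i + n; ++j) { free = free && !cells[j]; }
            if (!free) { continue; }
            for (uint32_t j = i; j < i + n; ++j) { cells[j] = true; }
            pending.push_back({ i, i + n });
            return true;
        }
        return false; // caller reports the failure and lets the guard roll back
    }

    // mirrors commit(): the batch went through, keep the claimed cells
    void commit() { pending.clear(); }

    // mirrors restore(): the batch failed, free every cell touched since the last commit
    void restore() {
        for (const auto & r : pending) {
            for (uint32_t i = r.c0; i < r.c1; ++i) { cells[i] = false; }
        }
        pending.clear();
    }
};

// same shape as llama_kv_cache_guard: restore() on scope exit unless commit() was called
struct toy_kv_cache_guard {
    explicit toy_kv_cache_guard(toy_kv_cache * kv) : kv(kv) {}
    ~toy_kv_cache_guard() { kv->restore(); }
    void commit() { kv->commit(); }
private:
    toy_kv_cache * kv;
};

int main() {
    toy_kv_cache kv(8);

    {
        toy_kv_cache_guard guard(&kv);
        kv.find_slot(3);
        // simulate a failed decode: the guard's destructor calls restore(),
        // so the 3 cells claimed above are released again
    }

    {
        toy_kv_cache_guard guard(&kv);
        if (kv.find_slot(3)) {
            guard.commit(); // decode succeeded: keep the cells, clear pending
        }
    }

    uint32_t used = 0;
    for (bool c : kv.cells) { used += c ? 1 : 0; }
    printf("used cells: %u\n", used); // prints 3: only the committed batch survived
    return 0;
}

The same shape appears in llama_context::decode() in the patches above: the guard is constructed before the ubatch loop and commit() is only reached after all ubatches have been processed, so any slots claimed by a partially processed batch are released automatically when decode() returns early.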