Skip to content

Commit f1d179e

Browse files
committed
llama : refactor kv cache guard
ggml-ci
1 parent e39e727 commit f1d179e

File tree

3 files changed

+61
-108
lines changed

3 files changed

+61
-108
lines changed

src/llama-context.cpp

Lines changed: 3 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1201,33 +1201,7 @@ int llama_context::decode(llama_batch & inp_batch) {
12011201
const int64_t n_tokens_all = batch.n_tokens;
12021202
const int64_t n_embd = hparams.n_embd;
12031203

1204-
// TODO: remove this stuff
1205-
class batch_guard {
1206-
public:
1207-
batch_guard(llama_kv_cache_unified & kv_self) : kv_slot_restorer(kv_self) {
1208-
}
1209-
1210-
~batch_guard() {
1211-
if (!is_done) {
1212-
kv_slot_restorer.restore();
1213-
}
1214-
}
1215-
1216-
void done() {
1217-
is_done = true;
1218-
}
1219-
1220-
void save(const llama_kv_cache_slot_info & slot_info) {
1221-
kv_slot_restorer.save(slot_info);
1222-
}
1223-
1224-
private:
1225-
bool is_done = false;
1226-
1227-
llama_kv_slot_restorer kv_slot_restorer;
1228-
};
1229-
1230-
batch_guard bg(*kv_self);
1204+
llama_kv_cache_guard kvg(kv_self.get());
12311205

12321206
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
12331207

@@ -1327,14 +1301,11 @@ int llama_context::decode(llama_batch & inp_batch) {
13271301
kv_self->head = 0;
13281302
}
13291303

1330-
const auto slot_info = kv_self->find_slot(ubatch);
1331-
if (!slot_info) {
1304+
if (!kv_self->find_slot(ubatch)) {
13321305
LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
13331306
return -3;
13341307
}
13351308

1336-
bg.save(slot_info);
1337-
13381309
if (!kv_self->recurrent) {
13391310
// a heuristic, to avoid attending the full cache if it is not yet utilized
13401311
// after enough generations, the benefit from this heuristic disappears
@@ -1467,7 +1438,7 @@ int llama_context::decode(llama_batch & inp_batch) {
14671438
}
14681439

14691440
// finalize the batch processing
1470-
bg.done();
1441+
kvg.commit();
14711442

14721443
// set output mappings
14731444
{

src/llama-kv-cache.cpp

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111
#include <map>
1212
#include <stdexcept>
1313

14-
static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};
15-
1614
llama_kv_cache_unified::llama_kv_cache_unified(const llama_hparams & hparams, callbacks cbs) : hparams(hparams), cbs(std::move(cbs)) {
1715
}
1816

@@ -446,11 +444,25 @@ void llama_kv_cache_unified::defrag() {
446444
}
447445
}
448446

447+
void llama_kv_cache_unified::restore() {
448+
if (pending.ranges.empty()) {
449+
return;
450+
}
451+
452+
for (auto & range : pending.ranges) {
453+
seq_rm(-1, range.p0, range.p1);
454+
}
455+
}
456+
457+
void llama_kv_cache_unified::commit() {
458+
pending.ranges.clear();
459+
}
460+
449461
bool llama_kv_cache_unified::get_can_shift() const {
450462
return can_shift;
451463
}
452464

453-
llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
465+
bool llama_kv_cache_unified::find_slot(
454466
const llama_ubatch & ubatch) {
455467
const uint32_t n_tokens = ubatch.n_tokens;
456468
const uint32_t n_seqs = ubatch.n_seqs;
@@ -477,7 +489,7 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
477489
// too big seq_id
478490
// TODO: would it be possible to resize the cache instead?
479491
LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size);
480-
return llama_kv_cache_slot_info_failed;
492+
return false;
481493
}
482494
if (j > 0) {
483495
llama_kv_cell & seq = cells[seq_id];
@@ -616,14 +628,14 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
616628
[](const llama_kv_cell& cell){ return !cell.is_empty(); });
617629

618630
// sanity check
619-
return llama_kv_cache_slot_info(n >= n_seqs);
631+
return n >= n_seqs;
620632
}
621633

622634
// otherwise, one cell per token.
623635

624636
if (n_tokens > size) {
625637
LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %d\n", __func__, n_tokens, size);
626-
return llama_kv_cache_slot_info_failed;
638+
return false;
627639
}
628640

629641
uint32_t n_tested = 0;
@@ -651,7 +663,7 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
651663

652664
if (n_tested >= size) {
653665
//LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
654-
return llama_kv_cache_slot_info_failed;
666+
return false;
655667
}
656668
}
657669

@@ -668,7 +680,9 @@ llama_kv_cache_slot_info llama_kv_cache_unified::find_slot(
668680

669681
used += n_tokens;
670682

671-
return llama_kv_cache_slot_info(head, head + n_tokens);
683+
pending.ranges.push_back({head, head + n_tokens});
684+
685+
return true;
672686
}
673687

674688
uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) const {
@@ -1033,6 +1047,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
10331047
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
10341048
return false;
10351049
}
1050+
commit();
10361051

10371052
// DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
10381053
// Assume that this is one contiguous block of cells

src/llama-kv-cache.h

Lines changed: 35 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ struct llama_ubatch;
1717
struct llama_kv_cache : public llama_memory_i {
1818
using llama_memory_i::llama_memory_i;
1919

20+
virtual void restore() = 0; // call if batch processing fails to restore the cache state
21+
virtual void commit() = 0; // call after successful batch processing
22+
2023
virtual int32_t get_n_tokens() const = 0;
2124
virtual uint32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
2225

@@ -25,9 +28,24 @@ struct llama_kv_cache : public llama_memory_i {
2528
bool get_can_edit() const override { return get_can_shift(); }
2629
};
2730

31+
struct llama_kv_cache_guard {
32+
llama_kv_cache_guard(llama_kv_cache * kv) : kv(kv) {}
33+
34+
~llama_kv_cache_guard() {
35+
kv->restore();
36+
}
37+
38+
void commit() {
39+
kv->commit();
40+
}
41+
42+
private:
43+
llama_kv_cache * kv;
44+
};
45+
2846
struct llama_kv_cell {
2947
llama_pos pos = -1;
30-
llama_pos delta = 0;
48+
llama_pos delta = 0;
3149
int32_t src = -1; // used by recurrent state models to copy states
3250
int32_t tail = -1;
3351

@@ -46,17 +64,6 @@ struct llama_kv_cell {
4664
}
4765
};
4866

49-
// a structure holds information about the slot found in llama_kv_cache_find_slot
50-
struct llama_kv_cache_slot_info {
51-
std::pair<uint32_t, uint32_t> boundaries; // slot boundaries [begin, end)
52-
bool found = false; // the slot was found
53-
54-
explicit llama_kv_cache_slot_info(bool found_) : found{found_} {}
55-
llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {}
56-
57-
operator bool() const { return found; }
58-
};
59-
6067
// ring-buffer of cached KV data
6168
// TODO: pimpl
6269
// TODO: add notion of max sequences
@@ -93,6 +100,9 @@ class llama_kv_cache_unified : public llama_kv_cache {
93100
void clear() override;
94101
void defrag() override;
95102

103+
virtual void restore() override;
104+
virtual void commit() override;
105+
96106
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
97107
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
98108
void seq_keep(llama_seq_id seq_id) override;
@@ -105,10 +115,9 @@ class llama_kv_cache_unified : public llama_kv_cache {
105115

106116
// find an empty slot of size "n_tokens" in the cache
107117
// updates the cache head
108-
// returns a structure holding information about the slot found
109118
// Note: On success, it's important that cache.head points
110119
// to the first cell of the slot.
111-
llama_kv_cache_slot_info find_slot(const llama_ubatch & batch);
120+
bool find_slot(const llama_ubatch & batch);
112121

113122
// TODO: maybe not needed
114123
uint32_t get_padding(const llama_cparams & cparams) const;
@@ -128,7 +137,18 @@ class llama_kv_cache_unified : public llama_kv_cache {
128137
// return true if cells have been moved
129138
bool defrag_prepare(int32_t n_max_nodes);
130139

131-
// state save/load
140+
// commit/restore cache
141+
142+
struct slot_range {
143+
uint32_t p0 = 0;
144+
uint32_t p1 = 0;
145+
};
146+
147+
struct {
148+
std::vector<slot_range> ranges;
149+
} pending;
150+
151+
// state write/load
132152

133153
void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const;
134154
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1);
@@ -183,59 +203,6 @@ class llama_kv_cache_unified : public llama_kv_cache {
183203
// using llama_kv_cache_unified::llama_kv_cache_unified;
184204
//};
185205

186-
//
187-
// kv cache restore
188-
//
189-
190-
// saves the kv_cache state for future recovery.
191-
// used to rollback llama_kv_cache_find_slot changes.
192-
struct llama_kv_slot_restorer {
193-
struct llama_kv_cache_state {
194-
uint32_t head = 0;
195-
uint32_t n = 0;
196-
} old_state;
197-
198-
// for non-recurrent models only
199-
// list of slots to restore
200-
std::vector<std::pair<uint32_t, uint32_t>> slot_boundaries;
201-
202-
bool do_restore = false;
203-
204-
llama_kv_cache_unified & cache;
205-
206-
explicit llama_kv_slot_restorer(llama_kv_cache_unified & cache) : cache(cache) {
207-
old_state.head = cache.head;
208-
old_state.n = cache.n;
209-
}
210-
211-
// saves a slot information for future restoration
212-
void save(const llama_kv_cache_slot_info & slot) {
213-
if (slot) {
214-
do_restore = true;
215-
if (slot.boundaries.first != slot.boundaries.second) {
216-
slot_boundaries.push_back(slot.boundaries);
217-
}
218-
}
219-
}
220-
221-
// must be explicitly called to restore the kv_cache state
222-
// and rollback changes from all llama_kv_cache_find_slot calls
223-
void restore() {
224-
if (do_restore) {
225-
cache.head = old_state.head;
226-
cache.n = old_state.n;
227-
228-
if (cache.recurrent) { // recurrent models like Mamba or RWKV can't have a state partially erased
229-
cache.seq_rm(-1, -1, -1);
230-
} else {
231-
for (auto & slot : slot_boundaries) {
232-
cache.seq_rm(-1, slot.first, slot.second);
233-
}
234-
}
235-
}
236-
}
237-
};
238-
239206
// TODO: maybe become part of the public llama_kv_cache in the future
240207
int32_t llama_kv_cache_n_tokens(const llama_kv_cache * kv);
241208

0 commit comments

Comments (0)