@@ -17,6 +17,9 @@ struct llama_ubatch;
17
17
struct llama_kv_cache : public llama_memory_i {
18
18
using llama_memory_i::llama_memory_i;
19
19
20
+ virtual void restore () = 0; // call if batch processing fails to restore the cache state
21
+ virtual void commit () = 0; // call after successful batch processing
22
+
20
23
virtual int32_t get_n_tokens () const = 0;
21
24
virtual uint32_t get_used_cells () const = 0; // TODO: remove, this is too-specific to the unified cache
22
25
@@ -25,9 +28,24 @@ struct llama_kv_cache : public llama_memory_i {
25
28
bool get_can_edit () const override { return get_can_shift (); }
26
29
};
27
30
31
+ struct llama_kv_cache_guard {
32
+ llama_kv_cache_guard (llama_kv_cache * kv) : kv(kv) {}
33
+
34
+ ~llama_kv_cache_guard () {
35
+ kv->restore ();
36
+ }
37
+
38
+ void commit () {
39
+ kv->commit ();
40
+ }
41
+
42
+ private:
43
+ llama_kv_cache * kv;
44
+ };
45
+
28
46
struct llama_kv_cell {
29
47
llama_pos pos = -1 ;
30
- llama_pos delta = 0 ;
48
+ llama_pos delta = 0 ;
31
49
int32_t src = -1 ; // used by recurrent state models to copy states
32
50
int32_t tail = -1 ;
33
51
@@ -46,17 +64,6 @@ struct llama_kv_cell {
46
64
}
47
65
};
48
66
49
// a structure holds information about the slot found in llama_kv_cache_find_slot
struct llama_kv_cache_slot_info {
    std::pair<uint32_t, uint32_t> boundaries; // slot boundaries [begin, end)
    bool found{false};                        // the slot was found

    // failure/flag result: no boundaries are recorded
    explicit llama_kv_cache_slot_info(bool found_) : found(found_) {}

    // success result: records the half-open cell range of the slot
    llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries(begin, end), found(true) {}

    // enables `if (slot_info) { ... }` checks
    operator bool() const { return found; }
};
59
-
60
67
// ring-buffer of cached KV data
61
68
// TODO: pimpl
62
69
// TODO: add notion of max sequences
@@ -93,6 +100,9 @@ class llama_kv_cache_unified : public llama_kv_cache {
93
100
void clear () override ;
94
101
void defrag () override ;
95
102
103
+ virtual void restore () override ;
104
+ virtual void commit () override ;
105
+
96
106
bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override ;
97
107
void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override ;
98
108
void seq_keep (llama_seq_id seq_id) override ;
@@ -105,10 +115,9 @@ class llama_kv_cache_unified : public llama_kv_cache {
105
115
106
116
// find an empty slot of size "n_tokens" in the cache
107
117
// updates the cache head
108
- // returns a structure holding information about the slot found
109
118
// Note: On success, it's important that cache.head points
110
119
// to the first cell of the slot.
111
- llama_kv_cache_slot_info find_slot (const llama_ubatch & batch);
120
+ bool find_slot (const llama_ubatch & batch);
112
121
113
122
// TODO: maybe not needed
114
123
uint32_t get_padding (const llama_cparams & cparams) const ;
@@ -128,7 +137,18 @@ class llama_kv_cache_unified : public llama_kv_cache {
128
137
// return true if cells have been moved
129
138
bool defrag_prepare (int32_t n_max_nodes);
130
139
131
- // state save/load
140
+ // commit/restore cache
141
+
142
// commit/restore cache

// a contiguous range of KV cells occupied by one found slot
// NOTE(review): presumably [p0, p1) half-open, matching the old
// slot_info boundaries — confirm against find_slot in the .cpp
struct slot_range {
    uint32_t p0{0};
    uint32_t p1{0};
};

// slot ranges written since the last commit; consumed by commit()/restore()
struct {
    std::vector<slot_range> ranges;
} pending;
150
+
151
+ // state write/load
132
152
133
153
void state_write (llama_io_write_i & io, llama_seq_id seq_id = -1 ) const ;
134
154
void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1 );
@@ -183,59 +203,6 @@ class llama_kv_cache_unified : public llama_kv_cache {
183
203
// using llama_kv_cache_unified::llama_kv_cache_unified;
184
204
// };
185
205
186
- //
187
- // kv cache restore
188
- //
189
-
190
- // saves the kv_cache state for future recovery.
191
- // used to rollback llama_kv_cache_find_slot changes.
192
- struct llama_kv_slot_restorer {
193
- struct llama_kv_cache_state {
194
- uint32_t head = 0 ;
195
- uint32_t n = 0 ;
196
- } old_state;
197
-
198
- // for non-recurrent models only
199
- // list of slots to restore
200
- std::vector<std::pair<uint32_t , uint32_t >> slot_boundaries;
201
-
202
- bool do_restore = false ;
203
-
204
- llama_kv_cache_unified & cache;
205
-
206
- explicit llama_kv_slot_restorer (llama_kv_cache_unified & cache) : cache(cache) {
207
- old_state.head = cache.head ;
208
- old_state.n = cache.n ;
209
- }
210
-
211
- // saves a slot information for future restoration
212
- void save (const llama_kv_cache_slot_info & slot) {
213
- if (slot) {
214
- do_restore = true ;
215
- if (slot.boundaries .first != slot.boundaries .second ) {
216
- slot_boundaries.push_back (slot.boundaries );
217
- }
218
- }
219
- }
220
-
221
- // must be explicitly called to restore the kv_cache state
222
- // and rollback changes from all llama_kv_cache_find_slot calls
223
- void restore () {
224
- if (do_restore) {
225
- cache.head = old_state.head ;
226
- cache.n = old_state.n ;
227
-
228
- if (cache.recurrent ) { // recurrent models like Mamba or RWKV can't have a state partially erased
229
- cache.seq_rm (-1 , -1 , -1 );
230
- } else {
231
- for (auto & slot : slot_boundaries) {
232
- cache.seq_rm (-1 , slot.first , slot.second );
233
- }
234
- }
235
- }
236
- }
237
- };
238
-
239
206
// TODO: maybe become part of the public llama_kv_cache in the future
240
207
int32_t llama_kv_cache_n_tokens (const llama_kv_cache * kv);
241
208
0 commit comments