Skip to content

Commit 6b50ba7

Browse files
committed
kv-cache : simplify interface (wip)
ggml-ci
1 parent 5ef7559 commit 6b50ba7

File tree

4 files changed

+108
-93
lines changed

4 files changed

+108
-93
lines changed

src/llama-context.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1108,7 +1108,7 @@ int llama_context::decode(llama_batch & inp_batch) {
11081108

11091109
// decide if we need to defrag the kv cache
11101110
if (cparams.defrag_thold > 0.0f) {
1111-
kv_self->defrag(cparams.defrag_thold);
1111+
kv_self->defrag_sched(cparams.defrag_thold);
11121112
}
11131113

11141114
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
@@ -2150,7 +2150,7 @@ void llama_kv_cache_seq_cp(
21502150
llama_seq_id seq_id_dst,
21512151
llama_pos p0,
21522152
llama_pos p1) {
2153-
return llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
2153+
llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
21542154
}
21552155

21562156
void llama_kv_self_seq_cp(
@@ -2164,14 +2164,14 @@ void llama_kv_self_seq_cp(
21642164
return;
21652165
}
21662166

2167-
return kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
2167+
kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
21682168
}
21692169

21702170
// deprecated
21712171
void llama_kv_cache_seq_keep(
21722172
llama_context * ctx,
21732173
llama_seq_id seq_id) {
2174-
return llama_kv_self_seq_keep(ctx, seq_id);
2174+
llama_kv_self_seq_keep(ctx, seq_id);
21752175
}
21762176

21772177
void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
@@ -2180,7 +2180,7 @@ void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
21802180
return;
21812181
}
21822182

2183-
return kv->seq_keep(seq_id);
2183+
kv->seq_keep(seq_id);
21842184
}
21852185

21862186
// deprecated
@@ -2190,7 +2190,7 @@ void llama_kv_cache_seq_add(
21902190
llama_pos p0,
21912191
llama_pos p1,
21922192
llama_pos delta) {
2193-
return llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
2193+
llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
21942194
}
21952195

21962196
void llama_kv_self_seq_add(
@@ -2204,7 +2204,7 @@ void llama_kv_self_seq_add(
22042204
return;
22052205
}
22062206

2207-
return kv->seq_add(seq_id, p0, p1, delta);
2207+
kv->seq_add(seq_id, p0, p1, delta);
22082208
}
22092209

22102210
// deprecated
@@ -2214,7 +2214,7 @@ void llama_kv_cache_seq_div(
22142214
llama_pos p0,
22152215
llama_pos p1,
22162216
int d) {
2217-
return llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
2217+
llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
22182218
}
22192219

22202220
void llama_kv_self_seq_div(
@@ -2228,7 +2228,7 @@ void llama_kv_self_seq_div(
22282228
return;
22292229
}
22302230

2231-
return kv->seq_div(seq_id, p0, p1, d);
2231+
kv->seq_div(seq_id, p0, p1, d);
22322232
}
22332233

22342234
// deprecated
@@ -2247,7 +2247,7 @@ llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
22472247

22482248
// deprecated
22492249
void llama_kv_cache_defrag(llama_context * ctx) {
2250-
return llama_kv_self_defrag(ctx);
2250+
llama_kv_self_defrag(ctx);
22512251
}
22522252

22532253
void llama_kv_self_defrag(llama_context * ctx) {
@@ -2257,7 +2257,7 @@ void llama_kv_self_defrag(llama_context * ctx) {
22572257
}
22582258

22592259
// force defrag
2260-
return kv->defrag(-1.0f);
2260+
kv->defrag_sched(-1.0f);
22612261
}
22622262

22632263
// deprecated

src/llama-graph.cpp

Lines changed: 2 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -284,24 +284,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {
284284

285285
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
286286
for (uint32_t i = 0; i < n_kv; ++i) {
287-
const uint32_t cell_id = i + kv_self->head;
288-
289-
//////////////////////////////////////////////
290-
// TODO: this should not mutate the KV cache !
291-
llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_recurrent *>(kv_self)->cells[i];
292-
293-
// prevent out-of-bound sources
294-
if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self->size) {
295-
kv_cell.src = cell_id;
296-
}
297-
298-
data[i] = kv_cell.src;
299-
300-
// TODO: do not mutate the KV cache
301-
// ensure copy only happens once
302-
if (kv_cell.src != (int32_t) cell_id) {
303-
kv_cell.src = cell_id;
304-
}
287+
data[i] = kv_self->s_copy(i);
305288
}
306289
}
307290
}
@@ -317,18 +300,7 @@ void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {
317300

318301
// clear unused states
319302
for (int i = 0; i < n_kv; ++i) {
320-
const uint32_t cell_id = i + kv_self->head;
321-
322-
//////////////////////////////////////////////
323-
// TODO: this should not mutate the KV cache !
324-
llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_recurrent *>(kv_self)->cells[i];
325-
326-
data[i] = (float) (kv_cell.src >= 0);
327-
328-
// only clear once
329-
if (kv_cell.src < 0) {
330-
kv_cell.src = cell_id;
331-
}
303+
data[i] = kv_self->s_mask(i);
332304
}
333305
}
334306
}

src/llama-kv-cache.cpp

Lines changed: 45 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -130,14 +130,6 @@ int32_t llama_kv_cache_unified::get_used_cells() const {
130130
return used;
131131
}
132132

133-
bool llama_kv_cache_unified::get_has_shift() const {
134-
return has_shift;
135-
}
136-
137-
bool llama_kv_cache_unified::get_do_defrag() const {
138-
return do_defrag;
139-
}
140-
141133
size_t llama_kv_cache_unified::total_size() const {
142134
size_t size = 0;
143135
for (const auto & buf : bufs) {
@@ -358,10 +350,10 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
358350
return result;
359351
}
360352

361-
void llama_kv_cache_unified::defrag(float thold) {
353+
void llama_kv_cache_unified::defrag_sched(float thold) {
362354
// - do not defrag small contexts (i.e. < 2048 tokens)
363355
// - count the padding towards the number of used tokens
364-
const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - float(used + padding)/float(n)) : 0.0f;
356+
const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(used + padding)/n)) : 0.0f;
365357

366358
// queue defragmentation for next llama_kv_cache_update
367359
if (fragmentation > thold) {
@@ -699,7 +691,7 @@ bool llama_kv_cache_unified::update(const graph_params & params) {
699691

700692
const auto & sched = params.sched;
701693

702-
if (get_has_shift()) {
694+
if (has_shift) {
703695
if (!get_can_shift()) {
704696
GGML_ABORT("The current KV cache / model configuration does not support K-shift");
705697
}
@@ -732,7 +724,7 @@ bool llama_kv_cache_unified::update(const graph_params & params) {
732724
}
733725
}
734726

735-
if (get_do_defrag()) {
727+
if (do_defrag) {
736728
LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
737729

738730
if (defrag_prepare(params.n_max_nodes)) {
@@ -1496,14 +1488,6 @@ int32_t llama_kv_cache_recurrent::get_used_cells() const {
14961488
return used;
14971489
}
14981490

1499-
bool llama_kv_cache_recurrent::get_has_shift() const {
1500-
return false;
1501-
}
1502-
1503-
bool llama_kv_cache_recurrent::get_do_defrag() const {
1504-
return false;
1505-
}
1506-
15071491
size_t llama_kv_cache_recurrent::total_size() const {
15081492
size_t size = 0;
15091493
for (const auto & buf : bufs) {
@@ -1716,7 +1700,7 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
17161700
return result;
17171701
}
17181702

1719-
void llama_kv_cache_recurrent::defrag(float thold) {
1703+
void llama_kv_cache_recurrent::defrag_sched(float thold) {
17201704
GGML_UNUSED(thold);
17211705
// noop
17221706
}
@@ -1742,6 +1726,46 @@ bool llama_kv_cache_recurrent::get_can_shift() const {
17421726
return false;
17431727
}
17441728

1729+
int32_t llama_kv_cache_recurrent::s_copy(int i) const {
1730+
const uint32_t cell_id = i + head;
1731+
1732+
//////////////////////////////////////////////
1733+
// TODO: this should not mutate the KV cache !
1734+
llama_kv_cell & kv_cell = const_cast<llama_kv_cell &>(cells[i]);
1735+
1736+
// prevent out-of-bound sources
1737+
if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= size) {
1738+
kv_cell.src = cell_id;
1739+
}
1740+
1741+
int32_t res = kv_cell.src;
1742+
1743+
// TODO: do not mutate the KV cache
1744+
// ensure copy only happens once
1745+
if (kv_cell.src != (int32_t) cell_id) {
1746+
kv_cell.src = cell_id;
1747+
}
1748+
1749+
return res;
1750+
}
1751+
1752+
float llama_kv_cache_recurrent::s_mask(int i) const {
1753+
const uint32_t cell_id = i + head;
1754+
1755+
//////////////////////////////////////////////
1756+
// TODO: this should not mutate the KV cache !
1757+
llama_kv_cell & kv_cell = const_cast<llama_kv_cell &>(cells[i]);
1758+
1759+
float res = (float) (kv_cell.src >= 0);
1760+
1761+
// only clear once
1762+
if (kv_cell.src < 0) {
1763+
kv_cell.src = cell_id;
1764+
}
1765+
1766+
return res;
1767+
}
1768+
17451769
bool llama_kv_cache_recurrent::find_slot(
17461770
const llama_ubatch & ubatch) {
17471771
const uint32_t n_tokens = ubatch.n_tokens;

0 commit comments

Comments
 (0)