Commit 5fbbb28

feat: Add can_seq_rm API to llama_kv_cache API
This will be key for the hybrid cache, which needs to be able to validate that all children can perform seq_rm cleanly before attempting to remove the seq from any single child, to avoid ending up in a corrupted state.

Branch: HybridCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent 5268278 commit 5fbbb28
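
To illustrate the intent described in the commit message, a hybrid parent cache could first poll can_seq_rm on every child and only then mutate them, so a refusal from one child never leaves its siblings in an inconsistent state. The sketch below is hypothetical (the hybrid cache is not part of this commit); the helper name hybrid_seq_rm and the children container are illustrative, and it assumes a seq_rm with this signature is available on the llama_kv_cache interface, as it is on the concrete caches in this diff.

#include <memory>
#include <vector>

#include "llama-kv-cache.h"

// Hypothetical helper, not part of this commit: remove a sequence range across
// several child caches, but only after every child has confirmed via can_seq_rm
// that it can honor the removal. This keeps the children from drifting apart if
// one of them (e.g. a recurrent cache) would have to refuse.
static bool hybrid_seq_rm(
        std::vector<std::unique_ptr<llama_kv_cache>> & children,
        llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
    // phase 1: validate on every child without mutating anything
    for (const auto & child : children) {
        if (!child->can_seq_rm(seq_id, p0, p1)) {
            return false;
        }
    }
    // phase 2: apply the removal everywhere
    bool ok = true;
    for (const auto & child : children) {
        ok = child->seq_rm(seq_id, p0, p1) && ok;
    }
    return ok;
}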

2 files changed: +65 −15 lines

src/llama-kv-cache.cpp

Lines changed: 54 additions & 15 deletions

@@ -463,6 +463,14 @@ void llama_kv_cache_unified::set_full() {
     head = 0;
 }

+bool llama_kv_cache_unified::can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
+    GGML_UNUSED(seq_id);
+    GGML_UNUSED(p0);
+    GGML_UNUSED(p1);
+    // Unified attention cache can always do a sequence removal
+    return true;
+}
+
 llama_sbatch llama_kv_cache_unified::sbatch_init(const llama_batch & batch, bool logits_all) {
     return llama_sbatch(batch, hparams.n_embd, true, logits_all);
 }

@@ -1773,6 +1781,15 @@ void llama_kv_cache_unified_iswa::set_full() {
     kv_swa ->set_full();
 }

+bool llama_kv_cache_unified_iswa::can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
+    GGML_UNUSED(seq_id);
+    GGML_UNUSED(p0);
+    GGML_UNUSED(p1);
+    // Unified attention caches can always do a sequence removal, so since both
+    // children can, the parent can as well.
+    return true;
+}
+
 llama_sbatch llama_kv_cache_unified_iswa::sbatch_init(const llama_batch & batch, bool logits_all) {
     pending.clear();

@@ -1968,39 +1985,33 @@ void llama_kv_cache_recurrent::clear() {
 }

 bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
-    uint32_t new_head = size;
+    if (!can_seq_rm(seq_id, p0, p1)) {
+        // could be fatal
+        return false;
+    }

+    uint32_t new_head = size;
     if (p0 < 0) {
         p0 = 0;
     }
-
     if (p1 < 0) {
         p1 = std::numeric_limits<llama_pos>::max();
     }

-    // models like Mamba or RWKV can't have a state partially erased
-    if (seq_id >= (int64_t) size) {
-        // could be fatal
-        return false;
-    }
     if (0 <= seq_id) {
         int32_t & tail_id = cells[seq_id].tail;
         if (tail_id >= 0) {
             const kv_cell & cell = cells[tail_id];
-            // partial intersection is invalid
-            if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
-                return false;
-            }
+            // already validated in can_seq_rm
+            GGML_ASSERT(!((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)));
             // invalidate tails which will be cleared
             if (p0 <= cell.pos && cell.pos < p1) {
                 tail_id = -1;
             }
         }
     } else {
-        // seq_id is negative, then the range should include everything or nothing
-        if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
-            return false;
-        }
+        // already validated in can_seq_rm
+        GGML_ASSERT(!(p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())));
     }

     for (uint32_t i = 0; i < size; ++i) {

@@ -2217,6 +2228,34 @@ void llama_kv_cache_recurrent::set_full() {
     n = size;
     head = 0;
 }
+bool llama_kv_cache_recurrent::can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
+    if (p0 < 0) {
+        p0 = 0;
+    }
+
+    if (p1 < 0) {
+        p1 = std::numeric_limits<llama_pos>::max();
+    }
+    // models like Mamba or RWKV can't have a state partially erased
+    if (seq_id >= (int64_t) size) {
+        // could be fatal
+        return false;
+    }
+    if (0 <= seq_id) {
+        const int32_t & tail_id = cells[seq_id].tail;
+        if (tail_id >= 0) {
+            const kv_cell & cell = cells[tail_id];
+            // partial intersection is invalid
+            if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
+                return false;
+            }
+        }
+    // seq_id is negative, then the range should include everything or nothing
+    } else if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
+        return false;
+    }
+    return true;
+}

 llama_sbatch llama_kv_cache_recurrent::sbatch_init(
         const llama_batch & batch,
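
For the recurrent cache, the interesting part of can_seq_rm is the range check: because models like Mamba or RWKV keep a single rolling state per sequence, a removal is only possible if the requested [p0, p1) range either lies entirely past the cached tail position or covers the state all the way from position 0. The standalone snippet below mirrors that predicate outside the cache class purely for illustration; it is not library code, tail_pos stands in for the position of the sequence's tail cell, and int32_t stands in for llama_pos.

#include <cstdint>
#include <cstdio>
#include <limits>

// Standalone illustration of the range check applied by
// llama_kv_cache_recurrent::can_seq_rm for a valid, non-negative seq_id.
static bool range_ok(int32_t tail_pos, int32_t p0, int32_t p1) {
    if (p0 < 0) { p0 = 0; }
    if (p1 < 0) { p1 = std::numeric_limits<int32_t>::max(); }
    // partial intersection with the single cached state is invalid
    return !((0 < p0 && p0 <= tail_pos) || (0 < p1 && p1 <= tail_pos));
}

int main() {
    const int32_t tail_pos = 7;                        // last cached position of the sequence
    std::printf("%d\n", range_ok(tail_pos, -1, -1));   // 1: remove everything
    std::printf("%d\n", range_ok(tail_pos,  0, -1));   // 1: remove from position 0 onward
    std::printf("%d\n", range_ok(tail_pos,  8, -1));   // 1: range lies entirely past the state
    std::printf("%d\n", range_ok(tail_pos,  3,  5));   // 0: would erase only part of the state
    return 0;
}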

src/llama-kv-cache.h

Lines changed: 11 additions & 0 deletions

@@ -37,6 +37,11 @@ struct llama_kv_cache : public llama_memory_i {
     // simulate full cache, used for allocating worst-case compute buffers
     virtual void set_full() = 0;

+    // sometimes it is useful to check whether a cache can remove a sequence
+    // before attempting to mutate the cache (eg a hybrid cache with multiple
+    // children to keep in sync)
+    virtual bool can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const = 0;
+
     //
     // batch processing
     //

@@ -142,6 +147,8 @@ class llama_kv_cache_unified : public llama_kv_cache {

     void set_full() override;

+    bool can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const override;
+
     llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
     llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;

@@ -353,6 +360,8 @@ class llama_kv_cache_unified_iswa : public llama_kv_cache {

     void set_full() override;

+    bool can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const override;
+
     llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
     llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;

@@ -464,6 +473,8 @@ class llama_kv_cache_recurrent : public llama_kv_cache {

     void set_full() override;

+    bool can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const override;
+
     llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
     llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
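
From the caller's side, the new virtual exposes the same "check before mutate" pattern on any cache behind the common interface, per the header comment above. A minimal, hypothetical usage sketch (the helper name is illustrative; seq_rm is assumed to be the existing removal entry point on llama_kv_cache, as used by the .cpp changes above):

// Hypothetical caller-side sketch, not part of this commit: roll a sequence back
// to its first n_keep positions if the cache can honor a partial removal,
// otherwise drop the whole sequence. A unified attention cache always accepts;
// a recurrent cache rejects ranges that would erase only part of its state.
static void rollback_or_reset(llama_kv_cache & kv, llama_seq_id seq_id, llama_pos n_keep) {
    if (kv.can_seq_rm(seq_id, n_keep, -1)) {
        kv.seq_rm(seq_id, n_keep, -1);
    } else {
        kv.seq_rm(seq_id, -1, -1); // remove everything for this sequence instead
    }
}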