Skip to content

Commit f1ceed6

Browse files
committed
fix: Split up seq_rm interface into immutable can_seq_rm and mutating seq_rm
This allows the hybrid cache to check first before mutating any of the children.

Branch: HybridCache

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
1 parent 9eca84e commit f1ceed6

File tree

2 files changed

+74
-21
lines changed

2 files changed

+74
-21
lines changed

src/llama-kv-cache.cpp

Lines changed: 63 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,11 @@ void llama_kv_cache_unified::set_full() {
450450
head = 0;
451451
}
452452

453+
bool llama_kv_cache_unified::can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
454+
// Unified attention cache can always do a sequence removal
455+
return true;
456+
}
457+
453458
llama_sbatch llama_kv_cache_unified::sbatch_init(
454459
const llama_batch & batch,
455460
bool logits_all) {
@@ -1488,39 +1493,33 @@ void llama_kv_cache_recurrent::clear() {
14881493
}
14891494

14901495
bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
1491-
uint32_t new_head = size;
1496+
if (!can_seq_rm(seq_id, p0, p1)) {
1497+
// could be fatal
1498+
return false;
1499+
}
14921500

1501+
uint32_t new_head = size;
14931502
if (p0 < 0) {
14941503
p0 = 0;
14951504
}
1496-
14971505
if (p1 < 0) {
14981506
p1 = std::numeric_limits<llama_pos>::max();
14991507
}
15001508

1501-
// models like Mamba or RWKV can't have a state partially erased
1502-
if (seq_id >= (int64_t) size) {
1503-
// could be fatal
1504-
return false;
1505-
}
15061509
if (0 <= seq_id) {
15071510
int32_t & tail_id = cells[seq_id].tail;
15081511
if (tail_id >= 0) {
15091512
const kv_cell & cell = cells[tail_id];
1510-
// partial intersection is invalid
1511-
if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
1512-
return false;
1513-
}
1513+
// already validated in can_seq_rm
1514+
GGML_ASSERT(!((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)));
15141515
// invalidate tails which will be cleared
15151516
if (p0 <= cell.pos && cell.pos < p1) {
15161517
tail_id = -1;
15171518
}
15181519
}
15191520
} else {
1520-
// seq_id is negative, then the range should include everything or nothing
1521-
if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
1522-
return false;
1523-
}
1521+
// already validated in can_seq_rm
1522+
GGML_ASSERT(!(p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())));
15241523
}
15251524

15261525
for (uint32_t i = 0; i < size; ++i) {
@@ -1722,6 +1721,35 @@ void llama_kv_cache_recurrent::set_full() {
17221721
head = 0;
17231722
}
17241723

1724+
bool llama_kv_cache_recurrent::can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
1725+
if (p0 < 0) {
1726+
p0 = 0;
1727+
}
1728+
1729+
if (p1 < 0) {
1730+
p1 = std::numeric_limits<llama_pos>::max();
1731+
}
1732+
// models like Mamba or RWKV can't have a state partially erased
1733+
if (seq_id >= (int64_t) size) {
1734+
// could be fatal
1735+
return false;
1736+
}
1737+
if (0 <= seq_id) {
1738+
const int32_t & tail_id = cells[seq_id].tail;
1739+
if (tail_id >= 0) {
1740+
const kv_cell & cell = cells[tail_id];
1741+
// partial intersection is invalid
1742+
if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
1743+
return false;
1744+
}
1745+
}
1746+
// seq_id is negative, then the range should include everything or nothing
1747+
} else if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
1748+
return false;
1749+
}
1750+
return true;
1751+
}
1752+
17251753
llama_sbatch llama_kv_cache_recurrent::sbatch_init(
17261754
const llama_batch & batch,
17271755
bool logits_all) {
@@ -2464,13 +2492,18 @@ void llama_kv_cache_hybrid::clear() {
24642492
}
24652493

24662494
bool llama_kv_cache_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
2467-
// TODO: Will it cause problems if some caches are able to remove the seq
2468-
// but others aren't?
2469-
bool removed = true;
2495+
// First check if we can do this removal. This checks all children so that
2496+
// no mutation happens before we know if it's possible
2497+
if (!can_seq_rm(seq_id, p0, p1)) {
2498+
return false;
2499+
}
2500+
2501+
// Do the removal from each child; after the can_seq_rm check above, this should never fail
24702502
for (const auto & cache : m_children) {
2471-
removed = cache->seq_rm(seq_id, p0, p1) && removed;
2503+
const bool ok = cache->seq_rm(seq_id, p0, p1);
2504
GGML_ASSERT(ok);
24722505
}
2473-
return removed;
2506+
return true;
24742507
}
24752508

24762509
void llama_kv_cache_hybrid::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
@@ -2537,6 +2570,15 @@ void llama_kv_cache_hybrid::set_full() {
25372570
}
25382571
}
25392572

2573+
bool llama_kv_cache_hybrid::can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
2574+
for (const auto & cache : m_children) {
2575+
if (!cache->can_seq_rm(seq_id, p0, p1)) {
2576+
return false;
2577+
}
2578+
}
2579+
return true;
2580+
}
2581+
25402582
llama_sbatch llama_kv_cache_hybrid::sbatch_init(const llama_batch & batch, bool logits_all) {
25412583
// If any of the caches are recurrent, require equal split
25422584
return llama_sbatch(batch, m_hparams.n_embd, !m_has_recurrent, logits_all);
@@ -2574,7 +2616,7 @@ int32_t llama_kv_cache_hybrid::get_n_tokens() const {
25742616

25752617
int32_t llama_kv_cache_hybrid::get_used_cells() const {
25762618
// TODO: Is this correct?
2577-
// Return the largetst number of used cells
2619+
// Return the largest number of used cells
25782620
int32_t used_cells = -1;
25792621
for (const auto & cache : m_children) {
25802622
used_cells = std::max(used_cells, cache->get_used_cells());

src/llama-kv-cache.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ struct llama_kv_cache : public llama_memory_i {
3737
// simulate full cache, used for allocating worst-case compute buffers
3838
virtual void set_full() = 0;
3939

40+
// sometimes it is useful to check whether a cache can remove a sequence
41+
// before attempting to mutate the cache (eg a hybrid cache with multiple
42+
// children to keep in sync)
43+
virtual bool can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const = 0;
44+
4045
//
4146
// batch processing
4247
//
@@ -150,6 +155,8 @@ class llama_kv_cache_unified : public llama_kv_cache {
150155

151156
void set_full() override;
152157

158+
bool can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const override;
159+
153160
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
154161

155162
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
@@ -318,6 +325,8 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
318325

319326
void set_full() override;
320327

328+
bool can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const override;
329+
321330
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
322331

323332
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
@@ -433,6 +442,8 @@ class llama_kv_cache_hybrid : public llama_kv_cache {
433442

434443
void set_full() override;
435444

445+
bool can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const override;
446+
436447
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
437448

438449
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;

0 commit comments

Comments
 (0)