@@ -450,6 +450,11 @@ void llama_kv_cache_unified::set_full() {
450
450
head = 0 ;
451
451
}
452
452
453
+ bool llama_kv_cache_unified::can_seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
454
+ // Unified attention cache can always do a sequence removal
455
+ return true ;
456
+ }
457
+
453
458
llama_sbatch llama_kv_cache_unified::sbatch_init (
454
459
const llama_batch & batch,
455
460
bool logits_all) {
@@ -1488,39 +1493,33 @@ void llama_kv_cache_recurrent::clear() {
1488
1493
}
1489
1494
1490
1495
bool llama_kv_cache_recurrent::seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
1491
- uint32_t new_head = size;
1496
+ if (!can_seq_rm (seq_id, p0, p1)) {
1497
+ // could be fatal
1498
+ return false ;
1499
+ }
1492
1500
1501
+ uint32_t new_head = size;
1493
1502
if (p0 < 0 ) {
1494
1503
p0 = 0 ;
1495
1504
}
1496
-
1497
1505
if (p1 < 0 ) {
1498
1506
p1 = std::numeric_limits<llama_pos>::max ();
1499
1507
}
1500
1508
1501
- // models like Mamba or RWKV can't have a state partially erased
1502
- if (seq_id >= (int64_t ) size) {
1503
- // could be fatal
1504
- return false ;
1505
- }
1506
1509
if (0 <= seq_id) {
1507
1510
int32_t & tail_id = cells[seq_id].tail ;
1508
1511
if (tail_id >= 0 ) {
1509
1512
const kv_cell & cell = cells[tail_id];
1510
- // partial intersection is invalid
1511
- if ((0 < p0 && p0 <= cell.pos ) || (0 < p1 && p1 <= cell.pos )) {
1512
- return false ;
1513
- }
1513
+ // already validated in can_seq_rm
1514
+ GGML_ASSERT (!((0 < p0 && p0 <= cell.pos ) || (0 < p1 && p1 <= cell.pos )));
1514
1515
// invalidate tails which will be cleared
1515
1516
if (p0 <= cell.pos && cell.pos < p1) {
1516
1517
tail_id = -1 ;
1517
1518
}
1518
1519
}
1519
1520
} else {
1520
- // seq_id is negative, then the range should include everything or nothing
1521
- if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max ())) {
1522
- return false ;
1523
- }
1521
+ // already validated in can_seq_rm
1522
+ GGML_ASSERT (!(p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max ())));
1524
1523
}
1525
1524
1526
1525
for (uint32_t i = 0 ; i < size; ++i) {
@@ -1722,6 +1721,35 @@ void llama_kv_cache_recurrent::set_full() {
1722
1721
head = 0 ;
1723
1722
}
1724
1723
1724
+ bool llama_kv_cache_recurrent::can_seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
1725
+ if (p0 < 0 ) {
1726
+ p0 = 0 ;
1727
+ }
1728
+
1729
+ if (p1 < 0 ) {
1730
+ p1 = std::numeric_limits<llama_pos>::max ();
1731
+ }
1732
+ // models like Mamba or RWKV can't have a state partially erased
1733
+ if (seq_id >= (int64_t ) size) {
1734
+ // could be fatal
1735
+ return false ;
1736
+ }
1737
+ if (0 <= seq_id) {
1738
+ const int32_t & tail_id = cells[seq_id].tail ;
1739
+ if (tail_id >= 0 ) {
1740
+ const kv_cell & cell = cells[tail_id];
1741
+ // partial intersection is invalid
1742
+ if ((0 < p0 && p0 <= cell.pos ) || (0 < p1 && p1 <= cell.pos )) {
1743
+ return false ;
1744
+ }
1745
+ }
1746
+ // seq_id is negative, then the range should include everything or nothing
1747
+ } else if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max ())) {
1748
+ return false ;
1749
+ }
1750
+ return true ;
1751
+ }
1752
+
1725
1753
llama_sbatch llama_kv_cache_recurrent::sbatch_init (
1726
1754
const llama_batch & batch,
1727
1755
bool logits_all) {
@@ -2464,13 +2492,18 @@ void llama_kv_cache_hybrid::clear() {
2464
2492
}
2465
2493
2466
2494
bool llama_kv_cache_hybrid::seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
2467
- // TODO: Will it cause problems if some caches are able to remove the seq
2468
- // but others aren't?
2469
- bool removed = true ;
2495
+ // First check if we can do this removal. This checks all children so that
2496
+ // no mutation happens before we know if it's possible
2497
+ if (!can_seq_rm (seq_id, p0, p1)) {
2498
+ return false ;
2499
+ }
2500
+
2501
+ // Do the removal from each child which should never fail
2470
2502
for (const auto & cache : m_children) {
2471
- removed = cache->seq_rm (seq_id, p0, p1) && removed;
2503
+ const bool failed = cache->seq_rm (seq_id, p0, p1);
2504
+ GGML_ASSERT (!failed);
2472
2505
}
2473
- return removed ;
2506
+ return true ;
2474
2507
}
2475
2508
2476
2509
void llama_kv_cache_hybrid::seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
@@ -2537,6 +2570,15 @@ void llama_kv_cache_hybrid::set_full() {
2537
2570
}
2538
2571
}
2539
2572
2573
+ bool llama_kv_cache_hybrid::can_seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
2574
+ for (const auto & cache : m_children) {
2575
+ if (!cache->can_seq_rm (seq_id, p0, p1)) {
2576
+ return false ;
2577
+ }
2578
+ }
2579
+ return true ;
2580
+ }
2581
+
2540
2582
llama_sbatch llama_kv_cache_hybrid::sbatch_init (const llama_batch & batch, bool logits_all) {
2541
2583
// If any of the caches are recurrent, require equal split
2542
2584
return llama_sbatch (batch, m_hparams.n_embd , !m_has_recurrent, logits_all);
@@ -2574,7 +2616,7 @@ int32_t llama_kv_cache_hybrid::get_n_tokens() const {
2574
2616
2575
2617
int32_t llama_kv_cache_hybrid::get_used_cells () const {
2576
2618
// TODO: Is this correct?
2577
- // Return the largetst number of used cells
2619
+ // Return the largest number of used cells
2578
2620
int32_t used_cells = -1 ;
2579
2621
for (const auto & cache : m_children) {
2580
2622
used_cells = std::max (used_cells, cache->get_used_cells ());
0 commit comments