@@ -463,6 +463,14 @@ void llama_kv_cache_unified::set_full() {
463
463
head = 0 ;
464
464
}
465
465
466
+ bool llama_kv_cache_unified::can_seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
467
+ GGML_UNUSED (seq_id);
468
+ GGML_UNUSED (p0);
469
+ GGML_UNUSED (p1);
470
+ // Unified attention cache can always do a sequence removal
471
+ return true ;
472
+ }
473
+
466
474
// Constructs the llama_sbatch for `batch`, passing along hparams.n_embd,
// a fixed `true` flag, and the caller's logits_all setting.
llama_sbatch llama_kv_cache_unified::sbatch_init(const llama_batch & batch, bool logits_all) {
    const auto & n_embd = hparams.n_embd;

    return llama_sbatch(batch, n_embd, true, logits_all);
}
@@ -1773,6 +1781,15 @@ void llama_kv_cache_unified_iswa::set_full() {
1773
1781
kv_swa ->set_full ();
1774
1782
}
1775
1783
1784
+ bool llama_kv_cache_unified_iswa::can_seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
1785
+ GGML_UNUSED (seq_id);
1786
+ GGML_UNUSED (p0);
1787
+ GGML_UNUSED (p1);
1788
+ // Unified attention caches can always do a sequence removal, so since both
1789
+ // children can, the parent can as well.
1790
+ return true ;
1791
+ }
1792
+
1776
1793
llama_sbatch llama_kv_cache_unified_iswa::sbatch_init (const llama_batch & batch, bool logits_all) {
1777
1794
pending.clear ();
1778
1795
@@ -1968,39 +1985,33 @@ void llama_kv_cache_recurrent::clear() {
1968
1985
}
1969
1986
1970
1987
bool llama_kv_cache_recurrent::seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
1971
- uint32_t new_head = size;
1988
+ if (!can_seq_rm (seq_id, p0, p1)) {
1989
+ // could be fatal
1990
+ return false ;
1991
+ }
1972
1992
1993
+ uint32_t new_head = size;
1973
1994
if (p0 < 0 ) {
1974
1995
p0 = 0 ;
1975
1996
}
1976
-
1977
1997
if (p1 < 0 ) {
1978
1998
p1 = std::numeric_limits<llama_pos>::max ();
1979
1999
}
1980
2000
1981
- // models like Mamba or RWKV can't have a state partially erased
1982
- if (seq_id >= (int64_t ) size) {
1983
- // could be fatal
1984
- return false ;
1985
- }
1986
2001
if (0 <= seq_id) {
1987
2002
int32_t & tail_id = cells[seq_id].tail ;
1988
2003
if (tail_id >= 0 ) {
1989
2004
const kv_cell & cell = cells[tail_id];
1990
- // partial intersection is invalid
1991
- if ((0 < p0 && p0 <= cell.pos ) || (0 < p1 && p1 <= cell.pos )) {
1992
- return false ;
1993
- }
2005
+ // already validated in can_seq_rm
2006
+ GGML_ASSERT (!((0 < p0 && p0 <= cell.pos ) || (0 < p1 && p1 <= cell.pos )));
1994
2007
// invalidate tails which will be cleared
1995
2008
if (p0 <= cell.pos && cell.pos < p1) {
1996
2009
tail_id = -1 ;
1997
2010
}
1998
2011
}
1999
2012
} else {
2000
- // seq_id is negative, then the range should include everything or nothing
2001
- if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max ())) {
2002
- return false ;
2003
- }
2013
+ // already validated in can_seq_rm
2014
+ GGML_ASSERT (!(p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max ())));
2004
2015
}
2005
2016
2006
2017
for (uint32_t i = 0 ; i < size; ++i) {
@@ -2217,6 +2228,34 @@ void llama_kv_cache_recurrent::set_full() {
2217
2228
n = size;
2218
2229
head = 0 ;
2219
2230
}
2231
+ bool llama_kv_cache_recurrent::can_seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
2232
+ if (p0 < 0 ) {
2233
+ p0 = 0 ;
2234
+ }
2235
+
2236
+ if (p1 < 0 ) {
2237
+ p1 = std::numeric_limits<llama_pos>::max ();
2238
+ }
2239
+ // models like Mamba or RWKV can't have a state partially erased
2240
+ if (seq_id >= (int64_t ) size) {
2241
+ // could be fatal
2242
+ return false ;
2243
+ }
2244
+ if (0 <= seq_id) {
2245
+ const int32_t & tail_id = cells[seq_id].tail ;
2246
+ if (tail_id >= 0 ) {
2247
+ const kv_cell & cell = cells[tail_id];
2248
+ // partial intersection is invalid
2249
+ if ((0 < p0 && p0 <= cell.pos ) || (0 < p1 && p1 <= cell.pos )) {
2250
+ return false ;
2251
+ }
2252
+ }
2253
+ // seq_id is negative, then the range should include everything or nothing
2254
+ } else if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max ())) {
2255
+ return false ;
2256
+ }
2257
+ return true ;
2258
+ }
2220
2259
2221
2260
llama_sbatch llama_kv_cache_recurrent::sbatch_init (
2222
2261
const llama_batch & batch,
0 commit comments