@@ -2142,18 +2142,17 @@ static bool llama_kv_cache_find_slot(
                 }
                 // Assuming the tokens are in-order
                 if (batch.pos[i] != cache.cells[seq_id].pos + 1) {
-                    // What should happen when the pos backtracks?
+                    // What should happen when the pos backtracks or skips a value?
                     // Clearing the state mid-batch would require special-casing which isn't done.
-                    LLAMA_LOG_ERROR("%s: non-consecutive token position %d after %d for sequence %d\n",
+                    LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n",
                         __func__, batch.pos[i], cache.cells[seq_id].pos, seq_id);
-                    return false;
                 }
                 cache.cells[seq_id].pos = batch.pos[i];
-                // NOTE: seq_ids are not inserted here, because they are handled when the graph is built.
+                // NOTE: seq_ids are not inserted here; they are handled when the input tensors are set
             } else {
                 // too big seq_id
                 // TODO: would it be possible to resize the KV cache size instead?
-                LLAMA_LOG_ERROR("%s: seq_id=%d >= kv_size=%d\n", __func__, seq_id, cache.size);
+                LLAMA_LOG_ERROR("%s: seq_id=%d >= kv_size=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size);
                 return false;
             }
         }
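Context for this hunk: in the unlimited (recurrent, Mamba-like) cache, each sequence owns exactly one cell indexed by its seq_id, so a position gap only mislabels the stored state instead of corrupting other sequences, which is why the gap is downgraded to a warning while an out-of-range seq_id remains fatal. A minimal, self-contained sketch of that behavior (simplified types, not the actual llama.cpp structures):

    #include <cstdio>
    #include <vector>

    struct kv_cell_sketch {
        int pos = -1; // -1 means the cell is empty
    };

    // hypothetical stand-in for the relevant part of llama_kv_cache_find_slot
    static bool find_slot_sketch(std::vector<kv_cell_sketch> & cells, int seq_id, int pos) {
        if (seq_id < 0 || seq_id >= (int) cells.size()) {
            // still a hard error: there is no cell reserved for this sequence
            std::fprintf(stderr, "seq_id=%d >= kv_size=%d\n", seq_id, (int) cells.size());
            return false;
        }
        if (pos != cells[seq_id].pos + 1) {
            // backtrack or skip: now only a warning (the old code returned false here)
            std::fprintf(stderr, "non-consecutive token position %d after %d for sequence %d\n",
                    pos, cells[seq_id].pos, seq_id);
        }
        cells[seq_id].pos = pos;
        return true;
    }

    int main() {
        std::vector<kv_cell_sketch> cells(4); // one cell per sequence
        find_slot_sketch(cells, 0, 0);
        find_slot_sketch(cells, 0, 1);
        find_slot_sketch(cells, 0, 5); // warns, but no longer aborts the batch
        find_slot_sketch(cells, 7, 0); // fails: seq_id out of range
    }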
@@ -2282,24 +2281,26 @@ static void llama_kv_cache_seq_cp(
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();

     if (cache.unlimited) {
-        if ((uint32_t)seq_id_dst < cache.size && (uint32_t)seq_id_src < cache.size) {
-            // intent to "copy from" (does not support copy chains)
+        if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) {
+            seq_id_src = cache.cells[seq_id_src].delta;
+            GGML_ASSERT((uint32_t) seq_id_src < cache.size);
+            // intent to "copy from"
+            // supports copy chains thanks to taking the source of the source
             cache.cells[seq_id_dst].delta = seq_id_src;
-            // NOTE: a sequence can't have multiple sources, but can have multiple destinations.
-            // For compatibility with the other KV cache API functions,
-            // the seq_id(s) of a cell suggests an intent to "copy to" those id(s),
-            // so that when a sequence is copied, it can initially be found from the source cell.
-            cache.cells[seq_id_src].seq_id.insert(seq_id_dst);
-            // prevent the destination from getting cleared
-            cache.cells[seq_id_dst].seq_id.insert(seq_id_dst);
+
+            // prevent the destination from getting cleared if the source is not empty
+            if (cache.cells[seq_id_src].has_seq_id(seq_id_src)) {
+                cache.cells[seq_id_dst].seq_id.insert(seq_id_dst);
+            }
             // repurposed as a "need copy" flag
             // (shifting can't be done anyway for this kind of KV cache)
-            cache.has_shift = seq_id_src != seq_id_dst;
-            // NOTE: this is not correct for sequence swaps (which aren't a thing in the KV cache API yet)
+            cache.has_shift = true;
+
             cache.cells[seq_id_dst].pos = cache.cells[seq_id_src].pos;
         }
         return;
     }
+    // otherwise, this is the KV cache of a Transformer-like model

     cache.head = 0;

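The key trick in this hunk is resolving the source's own delta before recording it, so every destination in a chain of copies points at the chain's root cell. A hypothetical sketch of just that resolution (names and types invented for illustration):

    #include <cassert>
    #include <cstdio>
    #include <vector>

    // hypothetical cell: delta is repurposed as "which cell to copy the state from"
    struct kv_cell_sketch {
        int pos   = -1;
        int delta = 0;
    };

    static void seq_cp_sketch(std::vector<kv_cell_sketch> & cells, int seq_id_src, int seq_id_dst) {
        seq_id_src = cells[seq_id_src].delta; // take the source of the source
        assert(seq_id_src >= 0 && seq_id_src < (int) cells.size());
        cells[seq_id_dst].delta = seq_id_src; // dst now copies from the chain's root
        cells[seq_id_dst].pos   = cells[seq_id_src].pos;
    }

    int main() {
        std::vector<kv_cell_sketch> cells(4);
        for (int i = 0; i < (int) cells.size(); ++i) {
            cells[i].delta = i; // initially, every cell is its own source
        }
        cells[0].pos = 9;

        seq_cp_sketch(cells, 0, 1); // cp 0 -> 1
        seq_cp_sketch(cells, 1, 2); // cp 1 -> 2: resolves through 1 back to 0
        std::printf("cell 2 copies from cell %d (pos %d)\n", cells[2].delta, cells[2].pos);
        // prints: cell 2 copies from cell 0 (pos 9)
    }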
@@ -2341,7 +2342,14 @@ static void llama_kv_cache_seq_add(
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();

     if (cache.unlimited) {
-        GGML_ASSERT(false); // not supported
+        // for Mamba-like models, only the pos needs to be shifted
+        if (0 <= seq_id && seq_id < (int64_t) cache.size) {
+            llama_kv_cell & cell = cache.cells[seq_id];
+            if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
+                cell.pos += delta;
+            }
+        }
+        return;
     }

     for (uint32_t i = 0; i < cache.size; ++i) {
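Since a recurrent state summarizes everything up to cell.pos, shifting positions (e.g. a context shift with a negative delta) only needs to relabel the single cell owned by the sequence; nothing is evicted. A small runnable sketch under the same single-cell-per-sequence assumption (simplified types, values chosen for illustration):

    #include <cstdio>
    #include <vector>

    struct kv_cell_sketch {
        int  pos     = -1;
        bool has_seq = false; // stand-in for cell.has_seq_id(seq_id)
    };

    static void seq_add_sketch(std::vector<kv_cell_sketch> & cells,
            int seq_id, int p0, int p1, int delta) {
        if (0 <= seq_id && seq_id < (int) cells.size()) {
            kv_cell_sketch & cell = cells[seq_id];
            if (cell.has_seq && p0 <= cell.pos && cell.pos < p1) {
                cell.pos += delta; // only the position label moves; the state is untouched
            }
        }
    }

    int main() {
        std::vector<kv_cell_sketch> cells(2);
        cells[0].pos     = 100;
        cells[0].has_seq = true;
        seq_add_sketch(cells, 0, 0, 200, -32); // e.g. a context shift
        std::printf("pos after shift: %d\n", cells[0].pos); // prints 68
    }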
@@ -2378,7 +2386,14 @@ static void llama_kv_cache_seq_div(
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();

     if (cache.unlimited) {
-        GGML_ASSERT(false); // not supported
+        // for Mamba-like models, only the pos needs to be changed
+        if (0 <= seq_id && seq_id < (int64_t) cache.size) {
+            llama_kv_cell & cell = cache.cells[seq_id];
+            if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
+                cell.pos /= d;
+            }
+        }
+        return;
     }

     for (uint32_t i = 0; i < cache.size; ++i) {
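The division branch mirrors the addition branch above: for the one cell owned by seq_id, cell.pos /= d is an integer division, so a cell at pos 100 with d = 4 is relabeled to pos 25. Presumably the intended caller is position-scaling schemes such as self-extend, though the diff itself doesn't say.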
@@ -8198,7 +8213,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     }

     if (kv_self.unlimited) {
-        const int64_t kv_size = kv_self.size;
         const int64_t n_kv = kv_self.n;

         {
@@ -8214,7 +8228,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                 data[i] = (float) has_self_seq;

                 // ensure current sequences will be kept
-                if (!has_self_seq) {
+                if (!has_self_seq && kv_cell.pos >= 0) {
                     kv_cell.seq_id.insert(seq_id);
                 }
             }
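The added kv_cell.pos >= 0 guard matters because pos == -1 marks an empty (cleared) cell; without the check, inserting seq_id here would mark a cleared cell as "keep" again. A tiny sketch of the guard's effect (hypothetical simplified cell type):

    #include <set>

    // hypothetical cell: pos == -1 marks an empty (cleared) cell
    struct kv_cell_sketch {
        int pos = -1;
        std::set<int> seq_id;
        bool has_seq_id(int id) const { return seq_id.count(id) > 0; }
    };

    int main() {
        kv_cell_sketch kv_cell; // freshly cleared: pos == -1
        const int seq = 3;

        const bool has_self_seq = kv_cell.has_seq_id(seq);
        if (!has_self_seq && kv_cell.pos >= 0) {
            // never reached for an empty cell: without the pos >= 0 guard,
            // the cleared cell would wrongly be marked as "keep" again
            kv_cell.seq_id.insert(seq);
        }
    }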
@@ -8243,21 +8257,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                 }
             }
         }
-        // remove extraneous seq_ids when state copies are made
-        if (kv_self.has_shift) {
-            for (int i = 0; i < kv_size; ++i) {
-                llama_kv_cell & kv_cell = lctx.kv_self.cells[i];
-                uint32_t n_seqs = kv_cell.seq_id.size();
-                bool has_self_seq = kv_cell.has_seq_id(i);
-
-                if (has_self_seq && n_seqs > 1) {
-                    kv_cell.seq_id.clear();
-                    kv_cell.seq_id.insert(i);
-                } else if (!has_self_seq && n_seqs > 0) {
-                    kv_cell.seq_id.clear();
-                }
-            }
-        }
     }
 }
