@@ -2181,18 +2181,17 @@ static bool llama_kv_cache_find_slot(
                 }
                 // Assuming the tokens are in-order
                 if (batch.pos[i] != cache.cells[seq_id].pos + 1) {
-                    // What should happen when the pos backtracks?
+                    // What should happen when the pos backtracks or skips a value?
                     // Clearing the state mid-batch would require special-casing which isn't done.
-                    LLAMA_LOG_ERROR("%s: non-consecutive token position %d after %d for sequence %d\n",
+                    LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n",
                         __func__, batch.pos[i], cache.cells[seq_id].pos, seq_id);
-                    return false;
                 }
                 cache.cells[seq_id].pos = batch.pos[i];
-                // NOTE: seq_ids are not inserted here, because they are handled when the graph is built.
+                // NOTE: seq_ids are not inserted here; they are handled when the input tensors are set
             } else {
                 // too big seq_id
                 // TODO: would it be possible to resize the KV cache size instead?
-                LLAMA_LOG_ERROR("%s: seq_id=%d >= kv_size=%d\n", __func__, seq_id, cache.size);
+                LLAMA_LOG_ERROR("%s: seq_id=%d >= kv_size=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size);
                 return false;
             }
         }
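
Side note on the hunk above: demoting the non-consecutive-position check from LLAMA_LOG_ERROR plus `return false` to a plain LLAMA_LOG_WARN means `llama_decode` no longer rejects the whole batch. A minimal repro sketch using the public batch API; `ctx`, `tok_a` and `tok_b` are placeholders, not part of the PR:

```cpp
// A batch whose positions skip a value for sequence 0: previously
// llama_kv_cache_find_slot() failed the decode, now it only warns.
llama_batch batch = llama_batch_init(/*n_tokens*/ 2, /*embd*/ 0, /*n_seq_max*/ 1);
batch.n_tokens    = 2;
batch.token[0]    = tok_a; batch.pos[0] = 5;
batch.n_seq_id[0] = 1;     batch.seq_id[0][0] = 0;
batch.token[1]    = tok_b; batch.pos[1] = 7; // skips pos 6 -> warning, decode continues
batch.n_seq_id[1] = 1;     batch.seq_id[1][0] = 0;
llama_decode(ctx, batch);
llama_batch_free(batch);
```
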
@@ -2321,24 +2320,26 @@ static void llama_kv_cache_seq_cp(
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
 
     if (cache.unlimited) {
-        if ((uint32_t)seq_id_dst < cache.size && (uint32_t)seq_id_src < cache.size) {
-            // intent to "copy from" (does not support copy chains)
+        if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) {
+            seq_id_src = cache.cells[seq_id_src].delta;
+            GGML_ASSERT((uint32_t) seq_id_src < cache.size);
+            // intent to "copy from"
+            // supports copy chains thanks to taking the source of the source
             cache.cells[seq_id_dst].delta = seq_id_src;
-            // NOTE: a sequence can't have multiple sources, but can have multiple destinations.
-            // For compatibility with the other KV cache API functions,
-            // the seq_id(s) of a cell suggests an intent to "copy to" those id(s),
-            // so that when a sequence is copied, it can initially be found from the source cell.
-            cache.cells[seq_id_src].seq_id.insert(seq_id_dst);
-            // prevent the destination from getting cleared
-            cache.cells[seq_id_dst].seq_id.insert(seq_id_dst);
+
+            // prevent the destination from getting cleared if the source is not empty
+            if (cache.cells[seq_id_src].has_seq_id(seq_id_src)) {
+                cache.cells[seq_id_dst].seq_id.insert(seq_id_dst);
+            }
+
             // repurposed as a "need copy" flag
             // (shifting can't be done anyway for this kind of KV cache)
-            cache.has_shift = seq_id_src != seq_id_dst;
-            // NOTE: this is not correct for sequence swaps (which aren't a thing in the KV cache API yet)
+            cache.has_shift = true;
+
             cache.cells[seq_id_dst].pos = cache.cells[seq_id_src].pos;
         }
         return;
     }
+    // otherwise, this is the KV cache of a Transformer-like model
 
     cache.head = 0;
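
To see why resolving the source's own `delta` first makes copy chains work, here is a self-contained sketch with a stripped-down cell (names assumed, everything else simplified away):

```cpp
#include <cstdio>
#include <vector>

struct cell { int delta; }; // stand-in for llama_kv_cell's "copy from" field

int main() {
    std::vector<cell> cells(4);
    for (int i = 0; i < (int) cells.size(); ++i) cells[i].delta = i; // self-sourced

    auto seq_cp = [&](int seq_id_src, int seq_id_dst) {
        seq_id_src = cells[seq_id_src].delta;  // take the source of the source
        cells[seq_id_dst].delta = seq_id_src;  // record the "copy from" intent
    };

    seq_cp(0, 1);
    seq_cp(1, 2); // chains through 1 back to the original cell 0
    printf("cell 2 copies from cell %d\n", cells[2].delta); // prints 0
}
```

Without that first step, cell 2 would point at cell 1, whose state may itself still be pending a copy.
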
@@ -2380,7 +2381,14 @@ static void llama_kv_cache_seq_add(
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
 
     if (cache.unlimited) {
-        GGML_ASSERT(false); // not supported
+        // for Mamba-like models, only the pos needs to be shifted
+        if (0 <= seq_id && seq_id < (int64_t) cache.size) {
+            llama_kv_cell & cell = cache.cells[seq_id];
+            if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
+                cell.pos += delta;
+            }
+        }
+        return;
     }
 
     for (uint32_t i = 0; i < cache.size; ++i) {
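
With the assert gone, the usual context-shift sequence from the examples also works on a Mamba-style cache; only the one cell's pos moves, since a recurrent state has no per-position entries to shift. A hedged usage sketch, with illustrative `n_keep`/`n_discard` values that are not from the PR:

```cpp
// Drop the oldest tokens of sequence 0, then shift the remainder back.
// On an unlimited cache the seq_add call now reduces to pos += delta on the
// single cell owned by seq 0, when its pos falls inside [p0, p1).
const int n_keep    = 4;
const int n_discard = 16;
llama_kv_cache_seq_rm (ctx, 0, n_keep            , n_keep + n_discard);
llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, -1    , -n_discard);
```
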
@@ -2417,7 +2425,14 @@ static void llama_kv_cache_seq_div(
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
 
     if (cache.unlimited) {
-        GGML_ASSERT(false); // not supported
+        // for Mamba-like models, only the pos needs to be changed
+        if (0 <= seq_id && seq_id < (int64_t) cache.size) {
+            llama_kv_cell & cell = cache.cells[seq_id];
+            if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
+                cell.pos /= d;
+            }
+        }
+        return;
     }
 
     for (uint32_t i = 0; i < cache.size; ++i) {
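
Same shape as the seq_add hunk, but note that `cell.pos /= d;` is integer division: a cell at pos 10 divided with d = 4 lands on pos 2, mirroring what the loop below does to each position in the Transformer case.
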
@@ -8426,7 +8441,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     }
 
     if (kv_self.unlimited) {
-        const int64_t kv_size = kv_self.size;
         const int64_t n_kv = kv_self.n;
 
         {
@@ -8442,7 +8456,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                 data[i] = (float) has_self_seq;
 
                 // ensure current sequences will be kept
-                if (!has_self_seq) {
+                if (!has_self_seq && kv_cell.pos >= 0) {
                     kv_cell.seq_id.insert(seq_id);
                 }
             }
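
For context, this is roughly how the mask written into `data[i]` above is consumed when the graph runs (a hedged sketch; the tensor names and the omitted view/offset arithmetic are assumptions, not the PR's exact code). It also suggests why the new `kv_cell.pos >= 0` guard matters: a cell that holds no state should not be marked as owned by its sequence just because a batch is passing through.

```cpp
// s_mask holds one float per cell: 1.0f = keep the recurrent state, 0.0f = clear it.
// With states of shape [d_state, n_kv] and s_mask of shape [1, n_kv],
// ggml_mul broadcasts the mask over dim 0 and zeroes whole state columns.
struct ggml_tensor * masked = ggml_mul(ctx0, states, s_mask);
```
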
@@ -8471,21 +8485,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                 }
             }
         }
-        // remove extraneous seq_ids when state copies are made
-        {
-            for (int i = 0; i < kv_size; ++i) {
-                llama_kv_cell & kv_cell = lctx.kv_self.cells[i];
-                uint32_t n_seqs = kv_cell.seq_id.size();
-                bool has_self_seq = kv_cell.has_seq_id(i);
-
-                if (has_self_seq && n_seqs > 1) {
-                    kv_cell.seq_id.clear();
-                    kv_cell.seq_id.insert(i);
-                } else if (!has_self_seq && n_seqs > 0) {
-                    kv_cell.seq_id.clear();
-                }
-            }
-        }
     }
 }
 
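The pruning pass deleted here is no longer needed: seq_cp stopped inserting destination ids into source cells (see the seq_cp hunk above), and per the NOTE in llama_kv_cache_find_slot, seq_ids are now maintained when the input tensors are set, so there are no extraneous seq_ids left to clean up.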