
Commit 60f70d8

mamba : support llama_kv_cache_seq_cp copy chains
* mamba : support shifting and dividing the kv cache pos
1 parent 05e2212 · commit 60f70d8

File tree

1 file changed: +34 -35 lines changed


llama.cpp

Lines changed: 34 additions & 35 deletions
@@ -2181,18 +2181,17 @@ static bool llama_kv_cache_find_slot(
                 }
                 // Assuming the tokens are in-order
                 if (batch.pos[i] != cache.cells[seq_id].pos + 1) {
-                    // What should happen when the pos backtracks?
+                    // What should happen when the pos backtracks or skips a value?
                     // Clearing the state mid-batch would require special-casing which isn't done.
-                    LLAMA_LOG_ERROR("%s: non-consecutive token position %d after %d for sequence %d\n",
+                    LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n",
                         __func__, batch.pos[i], cache.cells[seq_id].pos, seq_id);
-                    return false;
                 }
                 cache.cells[seq_id].pos = batch.pos[i];
-                // NOTE: seq_ids are not inserted here, because they are handled when the graph is built.
+                // NOTE: seq_ids are not inserted here; they are handled when the input tensors are set
             } else {
                 // too big seq_id
                 // TODO: would it be possible to resize the KV cache size instead?
-                LLAMA_LOG_ERROR("%s: seq_id=%d >= kv_size=%d\n", __func__, seq_id, cache.size);
+                LLAMA_LOG_ERROR("%s: seq_id=%d >= kv_size=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size);
                 return false;
             }
         }
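With this kind of cache a seq_id indexes its cell directly (cache.cells[seq_id]), so the cache size is also the hard limit on the number of parallel sequences, which is what the new --parallel hint in the error message refers to. Below is a small self-contained sketch of that mapping and of the in-order position check, using toy types rather than llama.cpp's real structures:

// toy_find_slot.cpp - simplified model of the per-sequence cell mapping (not llama.cpp's types)
#include <cstdio>
#include <vector>

struct toy_cell { int pos = -1; };

// one cell per sequence: seq_id indexes directly into the cell array
static bool toy_find_slot(std::vector<toy_cell> & cells, int seq_id, int token_pos) {
    if (seq_id < 0 || (size_t) seq_id >= cells.size()) {
        // too big seq_id: with one cell per sequence, the cache size is the sequence limit
        std::fprintf(stderr, "seq_id=%d >= kv_size=%zu (would need a bigger --parallel)\n",
                     seq_id, cells.size());
        return false;
    }
    if (token_pos != cells[seq_id].pos + 1) {
        // non-consecutive position: warn, but keep going (mirrors the LLAMA_LOG_WARN change)
        std::fprintf(stderr, "non-consecutive token position %d after %d for sequence %d\n",
                     token_pos, cells[seq_id].pos, seq_id);
    }
    cells[seq_id].pos = token_pos;
    return true;
}

int main() {
    std::vector<toy_cell> cells(4);   // 4 parallel sequences
    toy_find_slot(cells, 0, 0);       // ok: first token of seq 0
    toy_find_slot(cells, 0, 1);       // ok: consecutive
    toy_find_slot(cells, 0, 5);       // warns, but pos is still updated
    toy_find_slot(cells, 7, 0);       // fails: seq_id >= kv_size
}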
@@ -2321,24 +2320,26 @@ static void llama_kv_cache_seq_cp(
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
 
     if (cache.unlimited) {
-        if ((uint32_t)seq_id_dst < cache.size && (uint32_t)seq_id_src < cache.size) {
-            // intent to "copy from" (does not support copy chains)
+        if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) {
+            seq_id_src = cache.cells[seq_id_src].delta;
+            GGML_ASSERT((uint32_t) seq_id_src < cache.size);
+            // intent to "copy from"
+            // supports copy chains thanks to taking the source of the source
             cache.cells[seq_id_dst].delta = seq_id_src;
-            // NOTE: a sequence can't have multiple sources, but can have multiple destinations.
-            // For compatibility with the other KV cache API functions,
-            // the seq_id(s) of a cell suggests an intent to "copy to" those id(s),
-            // so that when a sequence is copied, it can initially be found from the source cell.
-            cache.cells[seq_id_src].seq_id.insert(seq_id_dst);
-            // prevent the destination from getting cleared
-            cache.cells[seq_id_dst].seq_id.insert(seq_id_dst);
+
+            // prevent the destination from getting cleared if the source is not empty
+            if (cache.cells[seq_id_src].has_seq_id(seq_id_src)) {
+                cache.cells[seq_id_dst].seq_id.insert(seq_id_dst);
+            }
             // repurposed as a "need copy" flag
             // (shifting can't be done anyway for this kind of KV cache)
-            cache.has_shift = seq_id_src != seq_id_dst;
-            // NOTE: this is not correct for sequence swaps (which aren't a thing in the KV cache API yet)
+            cache.has_shift = true;
+
             cache.cells[seq_id_dst].pos = cache.cells[seq_id_src].pos;
         }
         return;
     }
+    // otherwise, this is the KV cache of a Transformer-like model
 
     cache.head = 0;
 
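A self-contained sketch of the copy-chain behaviour added here, with a toy cell type standing in for llama_kv_cell: because the destination takes the source's own delta, chained copies such as 0 -> 1 followed by 1 -> 2 all resolve to cell 0 as the ultimate source.

// toy_seq_cp.cpp - simplified model of copy chains via the delta field
#include <cassert>
#include <cstdio>
#include <set>
#include <vector>

struct toy_cell {
    int           pos   = -1;
    int           delta = 0;   // which cell to copy the state from
    std::set<int> seq_id;

    bool has_seq_id(int id) const { return seq_id.count(id) > 0; }
};

static void toy_seq_cp(std::vector<toy_cell> & cells, int seq_id_src, int seq_id_dst) {
    // take the source of the source, so copy chains resolve to the original cell
    seq_id_src = cells[seq_id_src].delta;
    assert((size_t) seq_id_src < cells.size());
    cells[seq_id_dst].delta = seq_id_src;

    // prevent the destination from getting cleared if the source is not empty
    if (cells[seq_id_src].has_seq_id(seq_id_src)) {
        cells[seq_id_dst].seq_id.insert(seq_id_dst);
    }
    cells[seq_id_dst].pos = cells[seq_id_src].pos;
}

int main() {
    std::vector<toy_cell> cells(4);
    for (int i = 0; i < 4; ++i) {
        cells[i].delta = i;            // in this toy, each cell initially sources from itself
    }
    cells[0].pos = 10;
    cells[0].seq_id.insert(0);         // seq 0 holds some state

    toy_seq_cp(cells, 0, 1);           // 1 copies from 0
    toy_seq_cp(cells, 1, 2);           // 2 copies from 1, which resolves to 0
    std::printf("cell 2 copies from cell %d (pos %d)\n", cells[2].delta, cells[2].pos);
    // prints: cell 2 copies from cell 0 (pos 10)
}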
@@ -2380,7 +2381,14 @@ static void llama_kv_cache_seq_add(
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
 
     if (cache.unlimited) {
-        GGML_ASSERT(false); // not supported
+        // for Mamba-like models, only the pos needs to be shifted
+        if (0 <= seq_id && seq_id < (int64_t) cache.size) {
+            llama_kv_cell & cell = cache.cells[seq_id];
+            if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
+                cell.pos += delta;
+            }
+        }
+        return;
     }
 
     for (uint32_t i = 0; i < cache.size; ++i) {
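Since this cache keeps a single recurrent state per sequence rather than one entry per token, shifting a sequence only means adjusting that one cell's pos. A minimal sketch of the new seq_add path, again with toy types:

// toy_seq_add.cpp - simplified model of shifting the pos of a per-sequence cell
#include <climits>
#include <cstdio>
#include <set>
#include <vector>

struct toy_cell {
    int           pos = -1;
    std::set<int> seq_id;

    bool has_seq_id(int id) const { return seq_id.count(id) > 0; }
};

static void toy_seq_add(std::vector<toy_cell> & cells, int seq_id, int p0, int p1, int delta) {
    if (p0 < 0) p0 = 0;
    if (p1 < 0) p1 = INT_MAX;
    if (seq_id < 0 || (size_t) seq_id >= cells.size()) return;

    toy_cell & cell = cells[seq_id];
    // only touch the cell if it actually belongs to this sequence and its pos is in [p0, p1)
    if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
        cell.pos += delta;
    }
}

int main() {
    std::vector<toy_cell> cells(2);
    cells[0].pos = 12;
    cells[0].seq_id.insert(0);

    toy_seq_add(cells, 0, 0, -1, -4);                    // e.g. after discarding 4 old tokens
    std::printf("seq 0 pos is now %d\n", cells[0].pos);  // prints 8
}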
@@ -2417,7 +2425,14 @@ static void llama_kv_cache_seq_div(
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
 
     if (cache.unlimited) {
-        GGML_ASSERT(false); // not supported
+        // for Mamba-like models, only the pos needs to be changed
+        if (0 <= seq_id && seq_id < (int64_t) cache.size) {
+            llama_kv_cell & cell = cache.cells[seq_id];
+            if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
+                cell.pos /= d;
+            }
+        }
+        return;
     }
 
     for (uint32_t i = 0; i < cache.size; ++i) {
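seq_div is the same pattern with integer division; a tiny sketch of the arithmetic, showing how dividing by some factor d (chosen by the caller, e.g. a grouping factor) collapses neighbouring positions onto the same value:

// toy_seq_div.cpp - the effect of integer division on positions, mirroring cell.pos /= d
#include <cstdio>

int main() {
    const int d = 4;   // divisor chosen by the caller
    const int positions[] = {12, 13, 14, 15, 16};
    for (int pos : positions) {
        std::printf("pos %2d / %d -> %d\n", pos, d, pos / d);
    }
    // 12..15 all map to 3, 16 maps to 4
}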
@@ -8426,7 +8441,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     }
 
     if (kv_self.unlimited) {
-        const int64_t kv_size = kv_self.size;
         const int64_t n_kv = kv_self.n;
 
         {
@@ -8442,7 +8456,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                 data[i] = (float) has_self_seq;
 
                 // ensure current sequences will be kept
-                if (!has_self_seq) {
+                if (!has_self_seq && kv_cell.pos >= 0) {
                     kv_cell.seq_id.insert(seq_id);
                 }
             }
@@ -8471,21 +8485,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                 }
             }
         }
-        // remove extraneous seq_ids when state copies are made
-        {
-            for (int i = 0; i < kv_size; ++i) {
-                llama_kv_cell & kv_cell = lctx.kv_self.cells[i];
-                uint32_t n_seqs = kv_cell.seq_id.size();
-                bool has_self_seq = kv_cell.has_seq_id(i);
-
-                if (has_self_seq && n_seqs > 1) {
-                    kv_cell.seq_id.clear();
-                    kv_cell.seq_id.insert(i);
-                } else if (!has_self_seq && n_seqs > 0) {
-                    kv_cell.seq_id.clear();
-                }
-            }
-        }
     }
 }
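For context, a rough sketch of what the surrounding input-setup loop appears to do for the unlimited cache, based only on what is visible in these hunks (the mask semantics and all names are assumptions): each cell gets a float mask from has_self_seq, and a cell whose pos was already set by find_slot but which does not yet carry its own seq_id gets it inserted here, matching the NOTE in the find_slot hunk that seq_ids are handled when the input tensors are set.

// toy_set_inputs.cpp - rough model of the per-cell "keep" mask for the unlimited cache
// (toy types and names; only the mask line and the seq_id insertion mirror the diff)
#include <cstdio>
#include <set>
#include <vector>

struct toy_cell {
    int           pos = -1;
    std::set<int> seq_id;

    bool has_seq_id(int id) const { return seq_id.count(id) > 0; }
};

int main() {
    const int n_kv = 3;
    std::vector<toy_cell> cells(n_kv);
    std::vector<float>    mask(n_kv, 0.0f);

    cells[0].pos = 7; cells[0].seq_id.insert(0);   // continuing sequence: keeps its state
    cells[1].pos = 0;                              // new sequence: pos set by find_slot, seq_id not yet inserted

    for (int i = 0; i < n_kv; ++i) {
        toy_cell & cell = cells[i];
        const bool has_self_seq = cell.has_seq_id(i);

        // assumed meaning: 1.0f keeps the cell's state, 0.0f lets it be cleared before use
        mask[i] = (float) has_self_seq;

        // ensure currently used cells will be kept for the following steps
        if (!has_self_seq && cell.pos >= 0) {
            cell.seq_id.insert(i);
        }
        std::printf("cell %d: mask=%.0f pos=%d\n", i, mask[i], cell.pos);
    }
    // cell 0 -> mask 1 (kept), cell 1 -> mask 0 (cleared before first use), cell 2 -> mask 0 (unused)
}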
