
Commit a81b94f

mamba : support llama_kv_cache_seq_cp copy chains
* mamba : support shifting and dividing the kv cache pos
1 parent 21800e9 commit a81b94f
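
The copy-chain support in the llama_kv_cache_seq_cp hunk below works by first resolving seq_id_src through the source cell's delta, so a copy of a copy points back at the original state. A minimal usage sketch at the public API level (not part of this commit; it assumes an existing llama_context * ctx with a Mamba-style model and sequence ids 0-2 within the --parallel limit):

    // seq 1 becomes a pending copy of seq 0
    llama_kv_cache_seq_cp(ctx, /*seq_id_src=*/0, /*seq_id_dst=*/1, -1, -1);
    // seq 2 is copied from seq 1 before seq 1 has decoded anything of its own;
    // with this commit, the copy follows cells[1].delta back to seq 0,
    // so seq 2 sources the original state instead of an empty cell
    llama_kv_cache_seq_cp(ctx, /*seq_id_src=*/1, /*seq_id_dst=*/2, -1, -1);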


llama.cpp

Lines changed: 34 additions & 35 deletions
@@ -2142,18 +2142,17 @@ static bool llama_kv_cache_find_slot(
                 }
                 // Assuming the tokens are in-order
                 if (batch.pos[i] != cache.cells[seq_id].pos + 1) {
-                    // What should happen when the pos backtracks?
+                    // What should happen when the pos backtracks or skips a value?
                     // Clearing the state mid-batch would require special-casing which isn't done.
-                    LLAMA_LOG_ERROR("%s: non-consecutive token position %d after %d for sequence %d\n",
+                    LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n",
                         __func__, batch.pos[i], cache.cells[seq_id].pos, seq_id);
-                    return false;
                 }
                 cache.cells[seq_id].pos = batch.pos[i];
-                // NOTE: seq_ids are not inserted here, because they are handled when the graph is built.
+                // NOTE: seq_ids are not inserted here; they are handled when the input tensors are set
             } else {
                 // too big seq_id
                 // TODO: would it be possible to resize the KV cache size instead?
-                LLAMA_LOG_ERROR("%s: seq_id=%d >= kv_size=%d\n", __func__, seq_id, cache.size);
+                LLAMA_LOG_ERROR("%s: seq_id=%d >= kv_size=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size);
                 return false;
             }
         }
@@ -2282,24 +2281,26 @@ static void llama_kv_cache_seq_cp(
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();

     if (cache.unlimited) {
-        if ((uint32_t)seq_id_dst < cache.size && (uint32_t)seq_id_src < cache.size) {
-            // intent to "copy from" (does not support copy chains)
+        if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) {
+            seq_id_src = cache.cells[seq_id_src].delta;
+            GGML_ASSERT((uint32_t) seq_id_src < cache.size);
+            // intent to "copy from"
+            // supports copy chains thanks to taking the source of the source
             cache.cells[seq_id_dst].delta = seq_id_src;
-            // NOTE: a sequence can't have multiple sources, but can have multiple destinations.
-            // For compatibility with the other KV cache API functions,
-            // the seq_id(s) of a cell suggests an intent to "copy to" those id(s),
-            // so that when a sequence is copied, it can initially be found from the source cell.
-            cache.cells[seq_id_src].seq_id.insert(seq_id_dst);
-            // prevent the destination from getting cleared
-            cache.cells[seq_id_dst].seq_id.insert(seq_id_dst);
+
+            // prevent the destination from getting cleared if the source is not empty
+            if (cache.cells[seq_id_src].has_seq_id(seq_id_src)) {
+                cache.cells[seq_id_dst].seq_id.insert(seq_id_dst);
+            }
             // repurposed as a "need copy" flag
             // (shifting can't be done anyway for this kind of KV cache)
-            cache.has_shift = seq_id_src != seq_id_dst;
-            // NOTE: this is not correct for sequence swaps (which aren't a thing in the KV cache API yet)
+            cache.has_shift = true;
+
             cache.cells[seq_id_dst].pos = cache.cells[seq_id_src].pos;
         }
         return;
     }
+    // otherwise, this is the KV cache of a Transformer-like model

     cache.head = 0;

@@ -2341,7 +2342,14 @@ static void llama_kv_cache_seq_add(
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();

     if (cache.unlimited) {
-        GGML_ASSERT(false); // not supported
+        // for Mamba-like models, only the pos needs to be shifted
+        if (0 <= seq_id && seq_id < (int64_t) cache.size) {
+            llama_kv_cell & cell = cache.cells[seq_id];
+            if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
+                cell.pos += delta;
+            }
+        }
+        return;
     }

     for (uint32_t i = 0; i < cache.size; ++i) {
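
Since a sequence in the unlimited (Mamba-style) cache owns exactly one cell, the shift only needs to adjust that cell's pos. For reference, a hedged sketch of how the matching public call might be used (not part of this commit; ctx and n_discard are assumed/illustrative):

    // shift the tracked positions of sequence 0 back by n_discard,
    // e.g. after older tokens have been dropped during a context shift;
    // for a Mamba-style cache this updates just the one cell's pos
    const llama_pos n_discard = 256;
    llama_kv_cache_seq_add(ctx, /*seq_id=*/0, /*p0=*/0, /*p1=*/-1, /*delta=*/-n_discard);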
@@ -2378,7 +2386,14 @@ static void llama_kv_cache_seq_div(
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();

     if (cache.unlimited) {
-        GGML_ASSERT(false); // not supported
+        // for Mamba-like models, only the pos needs to be changed
+        if (0 <= seq_id && seq_id < (int64_t) cache.size) {
+            llama_kv_cell & cell = cache.cells[seq_id];
+            if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
+                cell.pos /= d;
+            }
+        }
+        return;
     }

     for (uint32_t i = 0; i < cache.size; ++i) {
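
Dividing works the same way: only the single cell's pos is divided by d. A hedged sketch of the corresponding public call (not part of this commit; the grouping factor is an illustrative value in the spirit of grouped/self-extend position scaling):

    // compress the positions of sequence 0 by a factor of 4 over the whole range
    const int grp_attn_n = 4;
    llama_kv_cache_seq_div(ctx, /*seq_id=*/0, /*p0=*/0, /*p1=*/-1, /*d=*/grp_attn_n);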
@@ -8198,7 +8213,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     }

     if (kv_self.unlimited) {
-        const int64_t kv_size = kv_self.size;
         const int64_t n_kv = kv_self.n;

         {
@@ -8214,7 +8228,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                 data[i] = (float) has_self_seq;

                 // ensure current sequences will be kept
-                if (!has_self_seq) {
+                if (!has_self_seq && kv_cell.pos >= 0) {
                     kv_cell.seq_id.insert(seq_id);
                 }
             }
@@ -8243,21 +8257,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                 }
             }
         }
-        // remove extraneous seq_ids when state copies are made
-        if (kv_self.has_shift) {
-            for (int i = 0; i < kv_size; ++i) {
-                llama_kv_cell & kv_cell = lctx.kv_self.cells[i];
-                uint32_t n_seqs = kv_cell.seq_id.size();
-                bool has_self_seq = kv_cell.has_seq_id(i);
-
-                if (has_self_seq && n_seqs > 1) {
-                    kv_cell.seq_id.clear();
-                    kv_cell.seq_id.insert(i);
-                } else if (!has_self_seq && n_seqs > 0) {
-                    kv_cell.seq_id.clear();
-                }
-            }
-        }
     }
 }
