Skip to content

Commit 12de5c7

Browse files
committed
mamba : make the server and parallel examples work with whole sequences
A seq_id is dedicated to the system prompt in both cases. * llama : make llama_kv_cache_seq_rm return whether it succeeded or not
1 parent a81b94f commit 12de5c7

File tree

4 files changed

+70
-33
lines changed

4 files changed

+70
-33
lines changed

examples/parallel/parallel.cpp

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,9 @@ int main(int argc, char ** argv) {
107107
// number of simultaneous "clients" to simulate
108108
const int32_t n_clients = params.n_parallel;
109109

110+
// dedicate one sequence to the system prompt
111+
params.n_parallel += 1;
112+
110113
// requests to simulate
111114
const int32_t n_seq = params.n_sequences;
112115

@@ -196,8 +199,8 @@ int main(int argc, char ** argv) {
196199
}
197200

198201
// assign the system KV cache to all parallel sequences
199-
for (int32_t i = 1; i < n_clients; ++i) {
200-
llama_kv_cache_seq_cp(ctx, 0, i, 0, n_tokens_system);
202+
for (int32_t i = 1; i <= n_clients; ++i) {
203+
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
201204
}
202205

203206
LOG_TEE("\n");
@@ -221,15 +224,17 @@ int main(int argc, char ** argv) {
221224

222225
client.i_batch = batch.n_tokens;
223226

224-
llama_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id }, true);
227+
llama_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id + 1 }, true);
225228

226229
client.n_decoded += 1;
227230
}
228231

229232
if (batch.n_tokens == 0) {
230233
// all sequences have ended - clear the entire KV cache
231-
for (int i = 0; i < n_clients; ++i) {
232-
llama_kv_cache_seq_rm(ctx, i, n_tokens_system, -1);
234+
for (int i = 1; i <= n_clients; ++i) {
235+
llama_kv_cache_seq_rm(ctx, i, -1, -1);
236+
// but keep the system prompt
237+
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
233238
}
234239

235240
LOG_TEE("%s: clearing the KV cache\n", __func__);
@@ -255,7 +260,7 @@ int main(int argc, char ** argv) {
255260
tokens_prompt = ::llama_tokenize(ctx, client.prompt, false);
256261

257262
for (size_t i = 0; i < tokens_prompt.size(); ++i) {
258-
llama_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id }, false);
263+
llama_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id + 1 }, false);
259264
}
260265

261266
// extract the logits only for the last token
@@ -366,7 +371,8 @@ int main(int argc, char ** argv) {
366371
}
367372

368373
// delete only the generated part of the sequence, i.e. keep the system prompt in the cache
369-
llama_kv_cache_seq_rm(ctx, client.id, n_tokens_system, -1);
374+
llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1);
375+
llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1);
370376

371377
const auto t_main_end = ggml_time_us();
372378

examples/server/server.cpp

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -438,12 +438,16 @@ struct llama_server_context
438438
return false;
439439
}
440440

441-
if (params.n_ctx < 2048) { // request larger context for the image embedding
441+
if (params.n_ctx != 0 && params.n_ctx < 2048) { // request larger context for the image embedding
442442
params.n_ctx = 2048;
443443
}
444444
}
445445

446+
// dedicate one sequence to the system prompt
447+
params.n_parallel += 1;
448+
446449
std::tie(model, ctx) = llama_init_from_gpt_params(params);
450+
params.n_parallel -= 1; // but be sneaky about it
447451
if (model == nullptr)
448452
{
449453
LOG_ERROR("unable to load model", {{"model", params.model}});
@@ -923,9 +927,9 @@ struct llama_server_context
923927
}
924928

925929
// assign the system KV cache to all parallel sequences
926-
for (int32_t i = 1; i < params.n_parallel; ++i)
930+
for (int32_t i = 1; i <= params.n_parallel; ++i)
927931
{
928-
llama_kv_cache_seq_cp(ctx, 0, i, 0, system_tokens.size());
932+
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
929933
}
930934
}
931935

@@ -1400,7 +1404,7 @@ struct llama_server_context
14001404
std::vector<llama_token> append_tokens = tokenize(json_prompt, false); // has next image
14011405
for (int i = 0; i < (int) append_tokens.size(); ++i)
14021406
{
1403-
llama_batch_add(batch, append_tokens[i], system_tokens.size() + slot.n_past, { slot.id }, true);
1407+
llama_batch_add(batch, append_tokens[i], system_tokens.size() + slot.n_past, { slot.id + 1 }, true);
14041408
slot.n_past += 1;
14051409
}
14061410
}
@@ -1636,8 +1640,8 @@ struct llama_server_context
16361640
{"n_system_tokens", system_tokens.size()},
16371641
{"n_cache_tokens", slot.cache_tokens.size()}
16381642
});
1639-
llama_kv_cache_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard);
1640-
llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
1643+
llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard);
1644+
llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
16411645

16421646
for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++)
16431647
{
@@ -1689,7 +1693,7 @@ struct llama_server_context
16891693

16901694
// TODO: we always have to take into account the "system_tokens"
16911695
// this is not great and needs to be improved somehow
1692-
llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true);
1696+
llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id + 1 }, true);
16931697
slot.n_past += 1;
16941698
}
16951699

@@ -1852,13 +1856,28 @@ struct llama_server_context
18521856
}
18531857
}
18541858

1859+
// keep only the common part
18551860
int p0 = (int) system_tokens.size() + slot.n_past;
18561861
LOG_INFO("kv cache rm [p0, end)", {
18571862
{ "slot_id", slot.id },
18581863
{ "task_id", slot.task_id },
18591864
{ "p0", p0 }
18601865
});
1861-
llama_kv_cache_seq_rm(ctx, slot.id, p0, -1);
1866+
if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
1867+
// could not partially delete (likely using a non-Transformer model)
1868+
// TODO: logging
1869+
llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);
1870+
llama_kv_cache_seq_cp(ctx, 0, slot.id + 1, -1, -1);
1871+
1872+
// there is no common part left (except for the system prompt)
1873+
// TODO: maybe find a way to refactor this to reuse the !cache_prompt case above
1874+
slot.n_past = 0;
1875+
slot.n_past_se = 0;
1876+
slot.ga_i = 0;
1877+
slot.num_prompt_tokens_processed = slot.num_prompt_tokens;
1878+
// TODO: is the system prompt ever in the sampling context?
1879+
llama_sampling_reset(slot.ctx_sampling);
1880+
}
18621881

18631882
LOG_VERBOSE("prompt ingested", {
18641883
{"n_past", slot.n_past},
@@ -1887,7 +1906,7 @@ struct llama_server_context
18871906
ga_i += ga_w/ga_n;
18881907
}
18891908
}
1890-
llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id }, false);
1909+
llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id + 1 }, false);
18911910
slot_npast++;
18921911
}
18931912

@@ -1941,9 +1960,9 @@ struct llama_server_context
19411960
LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
19421961
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
19431962

1944-
llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
1945-
llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w,slot.ga_n);
1946-
llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_se + ib * bd, dd);
1963+
llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd);
1964+
llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w,slot.ga_n);
1965+
llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_se + ib * bd, dd);
19471966

19481967
slot.n_past_se -= bd;
19491968

llama.cpp

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2231,7 +2231,7 @@ static void llama_kv_cache_clear(struct llama_kv_cache & cache) {
22312231
cache.used = 0;
22322232
}
22332233

2234-
static void llama_kv_cache_seq_rm(
2234+
static bool llama_kv_cache_seq_rm(
22352235
struct llama_kv_cache & cache,
22362236
llama_seq_id seq_id,
22372237
llama_pos p0,
@@ -2241,11 +2241,23 @@ static void llama_kv_cache_seq_rm(
22412241
if (p0 < 0) p0 = 0;
22422242
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
22432243

2244+
// models like Mamba can't have a state partially erased
22442245
if (cache.unlimited) {
2245-
// can only remove whole sequences for models like Mamba
2246-
GGML_ASSERT(p0 == 0);
2247-
GGML_ASSERT((uint32_t)seq_id < cache.size);
2248-
GGML_ASSERT(cache.cells[seq_id].pos < p1);
2246+
if (seq_id >= (int64_t) cache.size) {
2247+
// could be fatal
2248+
return false;
2249+
}
2250+
if (0 <= seq_id) {
2251+
// partial intersection is invalid
2252+
if ((0 < p0 && p0 <= cache.cells[seq_id].pos) || (0 < p1 && p1 <= cache.cells[seq_id].pos)) {
2253+
return false;
2254+
}
2255+
} else {
2256+
// seq_id is negative, then the range should include everything or nothing
2257+
if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
2258+
return false;
2259+
}
2260+
}
22492261
}
22502262

22512263
for (uint32_t i = 0; i < cache.size; ++i) {
@@ -2269,6 +2281,8 @@ static void llama_kv_cache_seq_rm(
22692281

22702282
// If we freed up a slot, set head to it so searching can start there.
22712283
if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
2284+
2285+
return true;
22722286
}
22732287

22742288
static void llama_kv_cache_seq_cp(
@@ -12283,13 +12297,11 @@ struct llama_context * llama_new_context_with_model(
1228312297

1228412298
// Mamba only needs a constant number of KV cache cells per sequence
1228512299
if (model->arch == LLM_ARCH_MAMBA) {
12286-
// Mamba needs as many KV cells as there are sequences kept at any time
12287-
// The extra cell allows dedicating a sequence id to the system prompt
12288-
// TODO: find a better way to get the max number of parallel sequences
12289-
kv_size = params.n_parallel + 1;
12300+
// Mamba needs at least as many KV cells as there are sequences kept at any time
12301+
kv_size = std::max((uint32_t) 1, params.n_parallel);
1229012302
// it's probably best to keep as much precision as possible for the states
12291-
type_k = GGML_TYPE_F32; // required by ggml_set for Mamba's conv_state
12292-
type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_state
12303+
type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
12304+
type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
1229312305
}
1229412306

1229512307
GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
@@ -12799,8 +12811,8 @@ void llama_kv_cache_clear(struct llama_context * ctx) {
1279912811
llama_kv_cache_clear(ctx->kv_self);
1280012812
}
1280112813

12802-
void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
12803-
llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1);
12814+
bool llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
12815+
return llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1);
1280412816
}
1280512817

1280612818
void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {

llama.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -502,7 +502,7 @@ extern "C" {
502502
// seq_id < 0 : match any sequence
503503
// p0 < 0 : [0, p1]
504504
// p1 < 0 : [p0, inf)
505-
LLAMA_API void llama_kv_cache_seq_rm(
505+
LLAMA_API bool llama_kv_cache_seq_rm(
506506
struct llama_context * ctx,
507507
llama_seq_id seq_id,
508508
llama_pos p0,

0 commit comments

Comments
 (0)