
Commit 1c9ae3a

mamba : make the server and parallel examples work with whole sequences
A seq_id is dedicated to the system prompt in both cases.

* llama : make llama_kv_cache_seq_rm return whether it succeeded or not

1 parent 60f70d8
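
For illustration, a minimal caller-side sketch of the pattern this commit establishes (the helper name rollback_sequence is hypothetical, not part of the diff or of the llama.cpp API): attempt a partial removal first, and if it fails, which it now can for recurrent models such as Mamba, clear the whole sequence and re-seed it from the dedicated system-prompt sequence 0.

#include "llama.h"

// Hypothetical helper, not from this commit: roll a sequence back to position p0,
// assuming sequence 0 holds the shared system prompt as in the examples below.
static void rollback_sequence(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0) {
    // try to drop only the tail [p0, inf) of the sequence
    if (!llama_kv_cache_seq_rm(ctx, seq_id, p0, -1)) {
        // recurrent models (e.g. Mamba) cannot erase part of their state,
        // so drop the whole sequence and copy the system prompt back into it
        llama_kv_cache_seq_rm(ctx, seq_id, -1, -1);
        llama_kv_cache_seq_cp(ctx, 0, seq_id, -1, -1);
    }
}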

File tree: 4 files changed (+70, -33)

examples/parallel/parallel.cpp

Lines changed: 13 additions & 7 deletions
@@ -107,6 +107,9 @@ int main(int argc, char ** argv) {
     // number of simultaneous "clients" to simulate
     const int32_t n_clients = params.n_parallel;
 
+    // dedicate one sequence to the system prompt
+    params.n_parallel += 1;
+
     // requests to simulate
     const int32_t n_seq = params.n_sequences;
 
@@ -196,8 +199,8 @@ int main(int argc, char ** argv) {
     }
 
     // assign the system KV cache to all parallel sequences
-    for (int32_t i = 1; i < n_clients; ++i) {
-        llama_kv_cache_seq_cp(ctx, 0, i, 0, n_tokens_system);
+    for (int32_t i = 1; i <= n_clients; ++i) {
+        llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
     }
 
     LOG_TEE("\n");
@@ -221,15 +224,17 @@ int main(int argc, char ** argv) {
 
             client.i_batch = batch.n_tokens;
 
-            llama_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id }, true);
+            llama_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id + 1 }, true);
 
             client.n_decoded += 1;
         }
 
         if (batch.n_tokens == 0) {
             // all sequences have ended - clear the entire KV cache
-            for (int i = 0; i < n_clients; ++i) {
-                llama_kv_cache_seq_rm(ctx, i, n_tokens_system, -1);
+            for (int i = 1; i <= n_clients; ++i) {
+                llama_kv_cache_seq_rm(ctx, i, -1, -1);
+                // but keep the system prompt
+                llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
             }
 
             LOG_TEE("%s: clearing the KV cache\n", __func__);
@@ -255,7 +260,7 @@ int main(int argc, char ** argv) {
                 tokens_prompt = ::llama_tokenize(ctx, client.prompt, false);
 
                 for (size_t i = 0; i < tokens_prompt.size(); ++i) {
-                    llama_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id }, false);
+                    llama_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id + 1 }, false);
                 }
 
                 // extract the logits only for the last token
@@ -366,7 +371,8 @@ int main(int argc, char ** argv) {
                 }
 
                 // delete only the generated part of the sequence, i.e. keep the system prompt in the cache
-                llama_kv_cache_seq_rm(ctx, client.id, n_tokens_system, -1);
+                llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1);
+                llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1);
 
                 const auto t_main_end = ggml_time_us();
 
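Taken together, the parallel example now relies on a fixed sequence layout. The sketch below restates it outside the diff (the helper names client_seq and seed_clients are illustrative, not part of the example):

#include "llama.h"

// Illustrative sketch of the layout assumed above: sequence 0 holds only the
// system prompt, and client i writes its tokens to sequence i + 1.
static llama_seq_id client_seq(int32_t client_id) {
    return client_id + 1;
}

// After the system prompt has been evaluated into sequence 0, share it with
// every client sequence; a whole-sequence copy is also valid for Mamba.
static void seed_clients(struct llama_context * ctx, int32_t n_clients) {
    for (int32_t i = 1; i <= n_clients; ++i) {
        llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
    }
}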

examples/server/server.cpp

Lines changed: 31 additions & 12 deletions
@@ -377,12 +377,16 @@ struct llama_server_context
                 return false;
             }
 
-            if (params.n_ctx < 2048) { // request larger context for the image embedding
+            if (params.n_ctx != 0 && params.n_ctx < 2048) { // request larger context for the image embedding
                 params.n_ctx = 2048;
             }
         }
 
+        // dedicate one sequence to the system prompt
+        params.n_parallel += 1;
+
         std::tie(model, ctx) = llama_init_from_gpt_params(params);
+        params.n_parallel -= 1; // but be sneaky about it
         if (model == nullptr)
         {
             LOG_ERROR("unable to load model", {{"model", params.model}});
@@ -862,9 +866,9 @@ struct llama_server_context
         }
 
         // assign the system KV cache to all parallel sequences
-        for (int32_t i = 1; i < params.n_parallel; ++i)
+        for (int32_t i = 1; i <= params.n_parallel; ++i)
         {
-            llama_kv_cache_seq_cp(ctx, 0, i, 0, system_tokens.size());
+            llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
         }
     }
 
@@ -1351,7 +1355,7 @@ struct llama_server_context
                 std::vector<llama_token> append_tokens = tokenize(json_prompt, false); // has next image
                 for (int i = 0; i < (int) append_tokens.size(); ++i)
                 {
-                    llama_batch_add(batch, append_tokens[i], system_tokens.size() + slot.n_past, { slot.id }, true);
+                    llama_batch_add(batch, append_tokens[i], system_tokens.size() + slot.n_past, { slot.id + 1 }, true);
                     slot.n_past += 1;
                 }
             }
@@ -1587,8 +1591,8 @@ struct llama_server_context
                     {"n_system_tokens", system_tokens.size()},
                     {"n_cache_tokens", slot.cache_tokens.size()}
                 });
-                llama_kv_cache_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard);
-                llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
+                llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard);
+                llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
 
                 for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++)
                 {
@@ -1640,7 +1644,7 @@ struct llama_server_context
 
             // TODO: we always have to take into account the "system_tokens"
             // this is not great and needs to be improved somehow
-            llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true);
+            llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id + 1 }, true);
             slot.n_past += 1;
         }
 
@@ -1808,13 +1812,28 @@ struct llama_server_context
                         }
                     }
 
+                    // keep only the common part
                     int p0 = (int) system_tokens.size() + slot.n_past;
                     LOG_INFO("kv cache rm [p0, end)", {
                         { "slot_id", slot.id },
                         { "task_id", slot.task_id },
                         { "p0", p0 }
                     });
-                    llama_kv_cache_seq_rm(ctx, slot.id, p0, -1);
+                    if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
+                        // could not partially delete (likely using a non-Transformer model)
+                        // TODO: logging
+                        llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);
+                        llama_kv_cache_seq_cp(ctx, 0, slot.id + 1, -1, -1);
+
+                        // there is no common part left (except for the system prompt)
+                        // TODO: maybe find a way to refactor this to reuse the !cache_prompt case above
+                        slot.n_past = 0;
+                        slot.n_past_se = 0;
+                        slot.ga_i = 0;
+                        slot.n_prompt_tokens_processed = slot.n_prompt_tokens;
+                        // TODO: is the system prompt ever in the sampling context?
+                        llama_sampling_reset(slot.ctx_sampling);
+                    }
 
                     LOG_VERBOSE("prompt ingested", {
                                                 {"n_past", slot.n_past},
@@ -1843,7 +1862,7 @@ struct llama_server_context
                                 ga_i += ga_w/ga_n;
                             }
                         }
-                        llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id }, false);
+                        llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id + 1 }, false);
                         slot_npast++;
                     }
 
@@ -1897,9 +1916,9 @@ struct llama_server_context
                 LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
                 LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
 
-                llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
-                llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w,slot.ga_n);
-                llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_se + ib * bd, dd);
+                llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd);
+                llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w,slot.ga_n);
+                llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_se + ib * bd, dd);
 
                 slot.n_past_se -= bd;
 
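The failure branch above amounts to restarting the slot from the system prompt. A condensed sketch follows (the slot_state struct and the restart_slot_from_system_prompt helper are hypothetical; the real server mutates server_slot fields in place and also resets the slot's sampling context):

#include "llama.h"

// Hypothetical condensation of the fallback path above.
struct slot_state {
    int32_t n_past, n_past_se, ga_i;
    int32_t n_prompt_tokens, n_prompt_tokens_processed;
};

static void restart_slot_from_system_prompt(struct llama_context * ctx, llama_seq_id seq, slot_state & s) {
    llama_kv_cache_seq_rm(ctx, seq, -1, -1);    // drop the slot's whole sequence
    llama_kv_cache_seq_cp(ctx, 0, seq, -1, -1); // re-seed it from the system prompt (sequence 0)

    s.n_past    = 0; // nothing is cached beyond the system prompt anymore
    s.n_past_se = 0;
    s.ga_i      = 0;
    s.n_prompt_tokens_processed = s.n_prompt_tokens; // the full prompt will be re-evaluated
}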

llama.cpp

Lines changed: 25 additions & 13 deletions
@@ -2270,7 +2270,7 @@ static void llama_kv_cache_clear(struct llama_kv_cache & cache) {
     cache.used = 0;
 }
 
-static void llama_kv_cache_seq_rm(
+static bool llama_kv_cache_seq_rm(
         struct llama_kv_cache & cache,
                  llama_seq_id   seq_id,
                     llama_pos   p0,
@@ -2280,11 +2280,23 @@ static void llama_kv_cache_seq_rm(
     if (p0 < 0) p0 = 0;
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
 
+    // models like Mamba can't have a state partially erased
     if (cache.unlimited) {
-        // can only remove whole sequences for models like Mamba
-        GGML_ASSERT(p0 == 0);
-        GGML_ASSERT((uint32_t)seq_id < cache.size);
-        GGML_ASSERT(cache.cells[seq_id].pos < p1);
+        if (seq_id >= (int64_t) cache.size) {
+            // could be fatal
+            return false;
+        }
+        if (0 <= seq_id) {
+            // partial intersection is invalid
+            if ((0 < p0 && p0 <= cache.cells[seq_id].pos) || (0 < p1 && p1 <= cache.cells[seq_id].pos)) {
+                return false;
+            }
+        } else {
+            // seq_id is negative, then the range should include everything or nothing
+            if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
+                return false;
+            }
+        }
     }
 
     for (uint32_t i = 0; i < cache.size; ++i) {
@@ -2308,6 +2320,8 @@ static void llama_kv_cache_seq_rm(
 
     // If we freed up a slot, set head to it so searching can start there.
     if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
+
+    return true;
 }
 
 static void llama_kv_cache_seq_cp(
@@ -12491,13 +12505,11 @@ struct llama_context * llama_new_context_with_model(
 
     // Mamba only needs a constant number of KV cache cells per sequence
     if (model->arch == LLM_ARCH_MAMBA) {
-        // Mamba needs as many KV cells as there are sequences kept at any time
-        // The extra cell allows dedicating a sequence id to the system prompt
-        // TODO: find a better way to get the max number of parallel sequences
-        kv_size = params.n_parallel + 1;
+        // Mamba needs at least as many KV cells as there are sequences kept at any time
+        kv_size = std::max((uint32_t) 1, params.n_parallel);
         // it's probably best to keep as much precision as possible for the states
-        type_k = GGML_TYPE_F32; // required by ggml_set for Mamba's conv_state
-        type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_state
+        type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
+        type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
     }
 
     GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
@@ -13016,8 +13028,8 @@ void llama_kv_cache_clear(struct llama_context * ctx) {
     llama_kv_cache_clear(ctx->kv_self);
 }
 
-void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
-    llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1);
+bool llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    return llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1);
 }
 
 void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
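
Under these rules, the sketch below spells out which ranges the new llama_kv_cache_seq_rm accepts for an unlimited (Mamba-style) cache. It is derived from the unlimited-cache checks above and assumes a sequence s whose stored state ends at position pos > 0; the function name is illustrative:

#include <cassert>
#include "llama.h"

// Illustrative only: expected return values for an unlimited (recurrent) cache,
// assuming sequence s currently holds state up to position pos (pos > 0).
static void check_seq_rm_semantics(struct llama_context * ctx, llama_seq_id s, llama_pos pos) {
    const bool past_end  = llama_kv_cache_seq_rm(ctx,  s, pos + 1, -1); // nothing stored in [pos+1, inf): succeeds
    const bool tail_only = llama_kv_cache_seq_rm(ctx,  s, 1,       -1); // partial erase of the state: rejected
    const bool any_part  = llama_kv_cache_seq_rm(ctx, -1, 1,  pos + 1); // partial range over all sequences: rejected
    const bool whole_seq = llama_kv_cache_seq_rm(ctx,  s, -1,      -1); // whole sequence: succeeds and clears it
    assert(past_end && !tail_only && !any_part && whole_seq);
    (void) past_end; (void) tail_only; (void) any_part; (void) whole_seq; // silence warnings in NDEBUG builds
}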

llama.h

Lines changed: 1 addition & 1 deletion
@@ -494,7 +494,7 @@ extern "C" {
     // seq_id < 0 : match any sequence
     // p0 < 0     : [0, p1]
     // p1 < 0     : [p0, inf)
-    LLAMA_API void llama_kv_cache_seq_rm(
+    LLAMA_API bool llama_kv_cache_seq_rm(
             struct llama_context * ctx,
                     llama_seq_id   seq_id,
                        llama_pos   p0,
