Skip to content

Commit 3be533d

Browse files
ngxson authored and Silver267 committed
server : fix cache_tokens bug with no cache_prompt (ggml-org#13533)
1 parent 1fa821d commit 3be533d

File tree

3 files changed

+25
-11
lines changed

3 files changed

+25
-11
lines changed

tools/server/server.cpp

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2955,7 +2955,8 @@ struct server_context {
29552955
llama_kv_self_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard);
29562956
llama_kv_self_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard);
29572957

2958-
if (slot.params.cache_prompt) {
2958+
// add generated tokens to cache
2959+
{
29592960
llama_tokens new_tokens = slot.cache_tokens.get_text_tokens(); // copy
29602961
for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) {
29612962
new_tokens[i - n_discard] = new_tokens[i];
@@ -3000,10 +3001,7 @@ struct server_context {
30003001
common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
30013002

30023003
slot.n_past += 1;
3003-
3004-
if (slot.params.cache_prompt) {
3005-
slot.cache_tokens.push_back(slot.sampled);
3006-
}
3004+
slot.cache_tokens.push_back(slot.sampled);
30073005

30083006
SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n",
30093007
slot.n_ctx, slot.n_past, (int) slot.cache_tokens.size(), slot.truncated);
@@ -3175,6 +3173,11 @@ struct server_context {
31753173

31763174
SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past);
31773175
}
3176+
} else {
3177+
// if we don't cache the prompt, we have to remove the entire KV cache
3178+
llama_kv_self_seq_rm(ctx, slot.id, 0, -1);
3179+
slot.n_past = 0;
3180+
slot.cache_tokens.clear();
31783181
}
31793182
}
31803183

@@ -3208,7 +3211,7 @@ struct server_context {
32083211
SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
32093212

32103213
// remove the non-common part from the cache
3211-
slot.cache_tokens.resize(slot.n_past);
3214+
slot.cache_tokens.keep_first(slot.n_past);
32123215

32133216
// check if we should process the image
32143217
if (slot.n_past < slot.n_prompt_tokens
@@ -3225,7 +3228,8 @@ struct server_context {
32253228
continue;
32263229
}
32273230

3228-
if (slot.params.cache_prompt) {
3231+
// add the image chunk to cache
3232+
{
32293233
const auto & chunk = slot.prompt_tokens.find_chunk(slot.n_past);
32303234
slot.cache_tokens.push_back(chunk.get()); // copy
32313235
}
@@ -3246,9 +3250,7 @@ struct server_context {
32463250
const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
32473251

32483252
common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd);
3249-
if (slot.params.cache_prompt) {
3250-
slot.cache_tokens.push_back(cur_tok);
3251-
}
3253+
slot.cache_tokens.push_back(cur_tok);
32523254

32533255
slot.n_prompt_tokens_processed++;
32543256
slot.n_past++;

tools/server/tests/unit/test_completion.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,18 @@ def test_cache_vs_nocache_prompt():
196196
assert res_cache.body["content"] == res_no_cache.body["content"]
197197

198198

199+
def test_nocache_long_input_prompt():
200+
global server
201+
server.start()
202+
res = server.make_request("POST", "/completion", data={
203+
"prompt": "I believe the meaning of life is"*32,
204+
"seed": 42,
205+
"temperature": 1.0,
206+
"cache_prompt": False,
207+
})
208+
assert res.status_code == 200
209+
210+
199211
def test_completion_with_tokens_input():
200212
global server
201213
server.temperature = 0.0

tools/server/utils.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1153,7 +1153,7 @@ struct server_tokens {
11531153
tokens.clear();
11541154
}
11551155

1156-
void resize(size_t n) {
1156+
void keep_first(size_t n) {
11571157
GGML_ASSERT(n <= tokens.size());
11581158
if (has_mtmd) {
11591159
// we throw an error if we try to remove a token in the middle of an image

0 commit comments

Comments (0)