From 536bea4c8206db24cfb2d27d0ab046227ff371b9 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 15 May 2025 23:55:44 +0200 Subject: [PATCH 1/6] server : separate the notion of position and KV tokens, remove prompt truncation --- tools/server/server.cpp | 153 ++++++++++------------ tools/server/tests/unit/test_ctx_shift.py | 2 +- tools/server/utils.hpp | 44 ++++++- 3 files changed, 109 insertions(+), 90 deletions(-) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index f32f3c86aad2c..2471766447973 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1274,9 +1274,10 @@ struct server_slot { int32_t i_batch = -1; int32_t n_predict = -1; // TODO: disambiguate from params.n_predict - // n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated - int32_t n_prompt_tokens = 0; - int32_t n_prompt_tokens_processed = 0; + // this reflects the number of tokens in KV cache + // not to be confused with n_past which reflects the positions + // models using m-rope may have multiple tokens in KV cache sharing the same position + int32_t n_kv_tokens = 0; // input prompt tokens server_tokens prompt_tokens; @@ -1324,7 +1325,7 @@ struct server_slot { void reset() { SLT_DBG(*this, "%s", "\n"); - n_prompt_tokens = 0; + n_kv_tokens = 0; last_nl_pos = 0; generated_text = ""; has_new_line = false; @@ -1384,6 +1385,10 @@ struct server_slot { generated_token_probs.push_back(token); } + int32_t n_prompt_tokens() const { + return prompt_tokens.n_kv_tokens(); + } + void release() { if (is_processing()) { SLT_INF(*this, "stop processing: n_past = %d, truncated = %d\n", n_past, truncated); @@ -1397,10 +1402,10 @@ struct server_slot { result_timings get_timings() const { result_timings timings; - timings.prompt_n = n_prompt_tokens_processed; + timings.prompt_n = prompt_tokens.n_kv_tokens(); timings.prompt_ms = t_prompt_processing; - timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed; - timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; + timings.prompt_per_token_ms = t_prompt_processing / prompt_tokens.n_kv_tokens(); + timings.prompt_per_second = 1e3 / t_prompt_processing * prompt_tokens.n_kv_tokens(); timings.predicted_n = n_decoded; timings.predicted_ms = t_token_generation; @@ -1446,6 +1451,7 @@ struct server_slot { } void print_timings() const { + const int32_t n_prompt_tokens_processed = prompt_tokens.n_kv_tokens(); const double t_prompt = t_prompt_processing / n_prompt_tokens_processed; const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; @@ -1516,8 +1522,8 @@ struct server_metrics { } void on_prompt_eval(const server_slot & slot) { - n_prompt_tokens_processed_total += slot.n_prompt_tokens_processed; - n_prompt_tokens_processed += slot.n_prompt_tokens_processed; + n_prompt_tokens_processed_total += slot.prompt_tokens.n_kv_tokens(); + n_prompt_tokens_processed += slot.prompt_tokens.n_kv_tokens(); t_prompt_processing += slot.t_prompt_processing; t_prompt_processing_total += slot.t_prompt_processing; } @@ -2096,7 +2102,7 @@ struct server_context { int cur_lcs_len = slot.cache_tokens.get_common_prefix(task.prompt_tokens); // fraction of the common subsequence length compared to the current slot's prompt length - float cur_similarity = static_cast(cur_lcs_len) / static_cast(slot.cache_tokens.size()); + float cur_similarity = static_cast(cur_lcs_len) / static_cast(slot.cache_tokens.n_kv_tokens()); // select the current slot if the criteria match if (cur_lcs_len > lcs_len && cur_similarity > slot_prompt_similarity) { @@ -2142,6 +2148,8 @@ struct server_context { slot.task_type = task.type; slot.params = std::move(task.params); slot.prompt_tokens = std::move(task.prompt_tokens); + slot.n_past = 0; + slot.n_kv_tokens = 0; if (!are_lora_equal(slot.params.lora, slot.lora)) { // if lora is changed, we cannot reuse cached tokens @@ -2309,13 +2317,13 @@ struct server_context { } // if context shift is disabled, we stop when it reaches the context limit - if (slot.n_past >= slot.n_ctx) { + if (!params_base.ctx_shift && slot.n_kv_tokens >= slot.n_ctx) { slot.truncated = true; slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; - SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n", - slot.n_decoded, slot.n_prompt_tokens, slot.n_past, slot.n_ctx); + SLT_DBG(slot, "stopped due to running out of context capacity, n_kv_tokens = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n", + slot.n_decoded, slot.n_prompt_tokens(), slot.n_past, slot.n_ctx); } if (llama_vocab_is_eog(vocab, result.tok)) { @@ -2327,7 +2335,7 @@ struct server_context { const auto n_ctx_train = llama_model_n_ctx_train(model); - if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) { + if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens() + slot.n_decoded >= n_ctx_train) { slot.truncated = true; slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; // stop prediction @@ -2429,7 +2437,7 @@ struct server_context { res->tokens = { tkn.tok }; res->n_decoded = slot.n_decoded; - res->n_prompt_tokens = slot.n_prompt_tokens; + res->n_prompt_tokens = slot.n_prompt_tokens(); res->post_sampling_probs = slot.params.post_sampling_probs; res->verbose = slot.params.verbose; @@ -2464,7 +2472,7 @@ struct server_context { res->truncated = slot.truncated; res->n_decoded = slot.n_decoded; - res->n_prompt_tokens = slot.n_prompt_tokens; + res->n_prompt_tokens = slot.n_prompt_tokens(); res->n_tokens_cached = slot.n_past; res->has_new_line = slot.has_new_line; res->stopping_word = slot.stopping_word; @@ -2502,7 +2510,7 @@ struct server_context { auto res = std::make_unique(); res->id = slot.id_task; res->index = slot.index; - res->n_tokens = slot.n_prompt_tokens; + res->n_tokens = slot.n_prompt_tokens(); res->oaicompat = slot.params.oaicompat; const int n_embd = llama_model_n_embd(model); @@ -2545,7 +2553,7 @@ struct server_context { auto res = std::make_unique(); res->id = slot.id_task; res->index = slot.index; - res->n_tokens = slot.n_prompt_tokens; + res->n_tokens = slot.n_prompt_tokens(); for (int i = 0; i < batch.n_tokens; ++i) { if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { @@ -2786,7 +2794,7 @@ struct server_context { break; } - const size_t token_count = slot->cache_tokens.size(); + const size_t token_count = slot->cache_tokens.n_kv_tokens(); const int64_t t_start = ggml_time_us(); std::string filename = task.slot_action.filename; @@ -2872,7 +2880,7 @@ struct server_context { } // Erase token cache - const size_t n_erased = slot->cache_tokens.size(); + const size_t n_erased = slot->cache_tokens.n_kv_tokens(); llama_kv_self_seq_rm(ctx, slot->id, -1, -1); slot->cache_tokens.clear(); @@ -2958,7 +2966,7 @@ struct server_context { new_tokens[i - n_discard] = new_tokens[i]; } - new_tokens.resize(slot.cache_tokens.size() - n_discard); + new_tokens.resize(slot.cache_tokens.n_kv_tokens() - n_discard); slot.cache_tokens.clear(); slot.cache_tokens.insert(new_tokens); } @@ -2996,11 +3004,12 @@ struct server_context { common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true); - slot.n_past += 1; + slot.n_past += 1; + slot.n_kv_tokens += 1; slot.cache_tokens.push_back(slot.sampled); - SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n", - slot.n_ctx, slot.n_past, (int) slot.cache_tokens.size(), slot.truncated); + SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_pos = %d, n_cache_tokens = %d, truncated = %d\n", + slot.n_ctx, slot.n_past, slot.cache_tokens.n_pos(), (int) slot.cache_tokens.n_kv_tokens(), slot.truncated); } // process in chunks of params.n_batch @@ -3028,11 +3037,9 @@ struct server_context { slot.t_start_process_prompt = ggml_time_us(); slot.t_start_generation = 0; - slot.n_past = 0; - slot.n_prompt_tokens = prompt_tokens.size(); slot.state = SLOT_STATE_PROCESSING_PROMPT; - SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens); + SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens()); // print prompt tokens (for debugging) /*if (1) { @@ -3058,13 +3065,13 @@ struct server_context { } if (slot.is_non_causal()) { - if (slot.n_prompt_tokens > n_ubatch) { + if (slot.n_prompt_tokens() > n_ubatch) { slot.release(); send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER); continue; } - if (slot.n_prompt_tokens > slot.n_ctx) { + if (slot.n_prompt_tokens() > slot.n_ctx) { slot.release(); send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_SERVER); continue; @@ -3074,52 +3081,21 @@ struct server_context { // if context shift is disabled, we make sure prompt size is smaller than KV size // TODO: there should be a separate parameter that control prompt truncation // context shift should be applied only during the generation phase - if (slot.n_prompt_tokens >= slot.n_ctx) { + if (slot.n_prompt_tokens() >= slot.n_ctx) { slot.release(); send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST); continue; } } if (slot.params.n_keep < 0) { - slot.params.n_keep = slot.n_prompt_tokens; + slot.params.n_keep = slot.n_prompt_tokens(); } slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep); - // if input prompt is too big, truncate it - if (slot.n_prompt_tokens >= slot.n_ctx) { - if (mctx) { - // we should never reach this - GGML_ABORT("not supported by multimodal"); - } - const int n_left = slot.n_ctx - slot.params.n_keep; - - const int n_block_size = n_left / 2; - const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size; - - const llama_tokens & curr_tokens = slot.prompt_tokens.get_text_tokens(); - llama_tokens new_tokens( - curr_tokens.begin(), - curr_tokens.begin() + slot.params.n_keep); - - new_tokens.insert( - new_tokens.end(), - curr_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size, - curr_tokens.end()); - - prompt_tokens.clear(); - prompt_tokens.insert(new_tokens); - - slot.truncated = true; - slot.n_prompt_tokens = prompt_tokens.size(); - - SLT_WRN(slot, "input truncated, n_ctx = %d, n_keep = %d, n_left = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, n_left, slot.n_prompt_tokens); - - GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx); - } - if (slot.params.cache_prompt) { // reuse any previously computed tokens that are common with the new prompt - slot.n_past = slot.cache_tokens.get_common_prefix(prompt_tokens); + slot.n_past = slot.cache_tokens.get_common_prefix(prompt_tokens); + slot.n_kv_tokens = slot.cache_tokens.n_kv_tokens(slot.n_past); // reuse chunks from the cached prompt by shifting their KV cache in the new position if (params_base.n_cache_reuse > 0) { @@ -3133,12 +3109,12 @@ struct server_context { SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past); - while (head_c < slot.cache_tokens.size() && - head_p < prompt_tokens.size()) { + while (head_c < slot.cache_tokens.n_kv_tokens() && + head_p < prompt_tokens.n_kv_tokens()) { size_t n_match = 0; - while (head_c + n_match < slot.cache_tokens.size() && - head_p + n_match < prompt_tokens.size() && + while (head_c + n_match < slot.cache_tokens.n_kv_tokens() && + head_p + n_match < prompt_tokens.n_kv_tokens() && slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) { n_match++; @@ -3168,29 +3144,31 @@ struct server_context { } SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past); + // because we're using this logic on text-only, n_past always == n_kv_tokens + slot.n_kv_tokens = slot.n_past; } } else { // if we don't cache the prompt, we have to remove the entire KV cache llama_kv_self_seq_rm(ctx, slot.id, 0, -1); slot.n_past = 0; + slot.n_kv_tokens = 0; slot.cache_tokens.clear(); } } - if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) { + if (slot.n_past == slot.prompt_tokens.n_pos() && slot.n_past > 0) { // we have to evaluate at least 1 token to generate logits. - SLT_WRN(slot, "need to evaluate at least 1 token to generate logits, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens); + SLT_WRN(slot, "need to evaluate at least 1 token to generate logits, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens()); slot.n_past--; + slot.n_kv_tokens--; } - - slot.n_prompt_tokens_processed = 0; } // non-causal tasks require to fit the entire prompt in the physical batch if (slot.is_non_causal()) { // cannot fit the prompt in the current batch - will try next iter - if (batch.n_tokens + slot.n_prompt_tokens > n_batch) { + if (batch.n_tokens + slot.n_prompt_tokens() > n_batch) { continue; } } @@ -3202,6 +3180,7 @@ struct server_context { // there is no common part left slot.n_past = 0; + slot.n_kv_tokens = 0; } SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past); @@ -3210,11 +3189,12 @@ struct server_context { slot.cache_tokens.keep_first(slot.n_past); // check if we should process the image - if (slot.n_past < slot.n_prompt_tokens + if (slot.n_kv_tokens < slot.n_prompt_tokens() && slot.prompt_tokens[slot.n_past] == LLAMA_TOKEN_NULL) { // process the image int32_t new_n_past; - int32_t res = slot.prompt_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past); + size_t n_tok = 0; + int32_t res = slot.prompt_tokens.process_chunk(ctx, mctx, slot.n_past, slot.id, new_n_past, n_tok); int32_t n_pos = new_n_past - slot.n_past; if (res != 0) { @@ -3230,12 +3210,12 @@ struct server_context { slot.cache_tokens.push_back(chunk.get()); // copy } - slot.n_past += n_pos; - slot.n_prompt_tokens_processed += n_pos; + slot.n_past += n_pos; + slot.n_kv_tokens += n_tok; } // add prompt tokens for processing in the current batch - while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) { + while (slot.n_kv_tokens < slot.n_prompt_tokens() && batch.n_tokens < n_batch) { // get next token to process llama_token cur_tok = slot.prompt_tokens[slot.n_past]; if (cur_tok == LLAMA_TOKEN_NULL) { @@ -3243,30 +3223,31 @@ struct server_context { } // without pooling, we want to output the embeddings for all the tokens in the batch - const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE; + const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING + && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE; common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd); slot.cache_tokens.push_back(cur_tok); - slot.n_prompt_tokens_processed++; + slot.n_kv_tokens++; slot.n_past++; } // SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str()); - SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens); + SLT_INF(slot, "prompt processing progress, n_past = %d, batch.n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_kv_tokens / slot.n_prompt_tokens()); // entire prompt has been processed - if (slot.n_past == slot.n_prompt_tokens) { + if (slot.n_kv_tokens == slot.n_prompt_tokens()) { slot.state = SLOT_STATE_DONE_PROMPT; GGML_ASSERT(batch.n_tokens > 0); - GGML_ASSERT((size_t) slot.n_prompt_tokens == slot.prompt_tokens.size()); + GGML_ASSERT((size_t) slot.n_prompt_tokens() == slot.prompt_tokens.n_kv_tokens()); common_sampler_reset(slot.smpl); // Process all prompt tokens through sampler system - for (int i = 0; i < slot.n_prompt_tokens; ++i) { + for (int i = 0; i < slot.n_prompt_tokens(); ++i) { llama_token id = slot.prompt_tokens[i]; if (id != LLAMA_TOKEN_NULL) { common_sampler_accept(slot.smpl, id, false); @@ -3279,7 +3260,7 @@ struct server_context { slot.n_decoded = 0; slot.i_batch = batch.n_tokens - 1; - SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens); + SLT_INF(slot, "prompt done, n_past = %d, batch.n_tokens = %d\n", slot.n_past, batch.n_tokens); } } @@ -3294,7 +3275,7 @@ struct server_context { return; } - SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens); + SRV_DBG("decoding batch, batch.n_tokens = %d\n", batch.n_tokens); if (slot_batched) { // make sure we're in the right embedding mode diff --git a/tools/server/tests/unit/test_ctx_shift.py b/tools/server/tests/unit/test_ctx_shift.py index be93a6d31f410..8303fc3fdf9fa 100644 --- a/tools/server/tests/unit/test_ctx_shift.py +++ b/tools/server/tests/unit/test_ctx_shift.py @@ -31,7 +31,7 @@ def test_ctx_shift_enabled(): "prompt": LONG_TEXT, }) assert res.status_code == 200 - assert res.body["timings"]["prompt_n"] == 109 + assert res.body["timings"]["prompt_n"] == 301 assert res.body["timings"]["predicted_n"] == 64 assert res.body["truncated"] is True diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp index 232eef195437f..9c4e190e4ab45 100644 --- a/tools/server/utils.hpp +++ b/tools/server/utils.hpp @@ -1045,6 +1045,7 @@ struct server_tokens { // it can include LLAMA_TOKEN_NULL, which is used to indicate a token that is not a text token // a mtmd_input_chunk can occupy multiple tokens, one llama_token per **position** // important: for models using mrope, an image can contain multiple tokens but will use only one **position** + // in otherwords, tokens.size() == n_past llama_tokens tokens; // for ex. with input of 5 text tokens and 2 images: @@ -1052,6 +1053,11 @@ struct server_tokens { // pos 0 1 2 3 4 5 6 7 8 9 // map_pos_to_image will contain: {5, img0}, {8, img1} + // number of tokens in KV cache + // it is named this way to avoid confusion between the notion of "tokens" and "positions" + // for example, models using m-rope can have multiple tokens in the KV cache but they all share one position + size_t n_kv = 0; + public: server_tokens() = default; ~server_tokens() = default; @@ -1074,7 +1080,7 @@ struct server_tokens { } } - server_tokens(llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) {} + server_tokens(llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens), n_kv(tokens.size()) {} // for debugging std::string str() const { @@ -1108,6 +1114,7 @@ struct server_tokens { if (tok == LLAMA_TOKEN_NULL) { throw std::runtime_error("Invalid token"); } + n_kv++; tokens.emplace_back(tok); } @@ -1122,6 +1129,7 @@ struct server_tokens { for (int i = 0; i < n_pos; ++i) { tokens.emplace_back(LLAMA_TOKEN_NULL); } + n_kv += mtmd_image_tokens_get_n_tokens(img_tokens); mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk)); map_pos_to_image[start_pos] = std::move(new_chunk); } else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) { @@ -1138,6 +1146,7 @@ struct server_tokens { // for compatibility with context shift and prompt truncation void insert(const llama_tokens & inp_tokens) { GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled + n_kv += inp_tokens.size(); tokens.insert(tokens.end(), inp_tokens.begin(), inp_tokens.end()); } @@ -1153,7 +1162,29 @@ struct server_tokens { tokens[pos] = id; } - size_t size() const { + // if end_pos == -1, we count all positions + size_t n_kv_tokens(llama_pos end_pos = -1) const { + if (end_pos == -1) { + return n_kv; + } else { + size_t res = 0; + for (llama_pos i = 0; i < end_pos;) { + auto & t = tokens[i]; + if (t == LLAMA_TOKEN_NULL) { + auto & chunk = find_chunk(i); + auto img_tokens = mtmd_input_chunk_get_tokens_image(chunk.get()); + res += mtmd_image_tokens_get_n_tokens(img_tokens); + i += mtmd_image_tokens_get_n_pos(img_tokens); + } else { + res++; + i++; + } + } + return res; + } + } + + llama_pos n_pos() const { return tokens.size(); } @@ -1162,6 +1193,7 @@ struct server_tokens { } void clear() { + n_kv = 0; tokens.clear(); } @@ -1185,6 +1217,8 @@ struct server_tokens { for (auto it = map_pos_to_image.begin(); it != map_pos_to_image.end(); ) { llama_pos pos = it->first; if (pos >= (llama_pos)n) { + auto img_tokens = mtmd_input_chunk_get_tokens_image(it->second.get()); + n_kv -= mtmd_image_tokens_get_n_tokens(img_tokens); it = map_pos_to_image.erase(it); } else { ++it; @@ -1205,6 +1239,7 @@ struct server_tokens { return common_detokenize(ctx, text_tokens, special); } + // returns the position of the first token that is different size_t get_common_prefix(const server_tokens & b) const { size_t max_idx = std::min(tokens.size(), b.tokens.size()); for (size_t i = 0; i < max_idx; ++i) { @@ -1268,7 +1303,8 @@ struct server_tokens { mtmd_context * mctx, llama_pos n_past, int32_t seq_id, - llama_pos & n_pos_out) { + llama_pos & n_pos_out, + size_t & n_kv_tokens_out) { auto it = map_pos_to_image.find(n_past); if (it == map_pos_to_image.end()) { throw std::runtime_error("Chunk not found"); @@ -1290,6 +1326,8 @@ struct server_tokens { n_pos_out = n_past; return result; } + auto img_tokens = mtmd_input_chunk_get_tokens_image(it->second.get()); + n_kv_tokens_out = mtmd_image_tokens_get_n_tokens(img_tokens); n_pos_out = new_n_past; return 0; } From 678d7b1569a4b0fcb15bf8d65ad9779a9bb77e9f Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 16 May 2025 13:28:36 +0200 Subject: [PATCH 2/6] no more scan loop in n_kv_tokens() --- tools/server/server.cpp | 10 ++++++---- tools/server/utils.hpp | 35 ++++++++++------------------------- 2 files changed, 16 insertions(+), 29 deletions(-) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 2471766447973..e93e15c1f096e 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -2099,10 +2099,11 @@ struct server_context { } // length of the Longest Common Subsequence between the current slot's prompt and the input prompt - int cur_lcs_len = slot.cache_tokens.get_common_prefix(task.prompt_tokens); + auto common_pos = slot.cache_tokens.get_common_prefix(task.prompt_tokens); + int cur_lcs_len = common_pos.first; // position, not tokens // fraction of the common subsequence length compared to the current slot's prompt length - float cur_similarity = static_cast(cur_lcs_len) / static_cast(slot.cache_tokens.n_kv_tokens()); + float cur_similarity = static_cast(cur_lcs_len) / static_cast(slot.cache_tokens.n_pos()); // select the current slot if the criteria match if (cur_lcs_len > lcs_len && cur_similarity > slot_prompt_similarity) { @@ -3094,8 +3095,9 @@ struct server_context { if (slot.params.cache_prompt) { // reuse any previously computed tokens that are common with the new prompt - slot.n_past = slot.cache_tokens.get_common_prefix(prompt_tokens); - slot.n_kv_tokens = slot.cache_tokens.n_kv_tokens(slot.n_past); + auto common_pos = slot.cache_tokens.get_common_prefix(prompt_tokens); + slot.n_past = common_pos.first; + slot.n_kv_tokens = common_pos.second; // reuse chunks from the cached prompt by shifting their KV cache in the new position if (params_base.n_cache_reuse > 0) { diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp index 9c4e190e4ab45..ab619406867f3 100644 --- a/tools/server/utils.hpp +++ b/tools/server/utils.hpp @@ -1162,26 +1162,8 @@ struct server_tokens { tokens[pos] = id; } - // if end_pos == -1, we count all positions - size_t n_kv_tokens(llama_pos end_pos = -1) const { - if (end_pos == -1) { - return n_kv; - } else { - size_t res = 0; - for (llama_pos i = 0; i < end_pos;) { - auto & t = tokens[i]; - if (t == LLAMA_TOKEN_NULL) { - auto & chunk = find_chunk(i); - auto img_tokens = mtmd_input_chunk_get_tokens_image(chunk.get()); - res += mtmd_image_tokens_get_n_tokens(img_tokens); - i += mtmd_image_tokens_get_n_pos(img_tokens); - } else { - res++; - i++; - } - } - return res; - } + size_t n_kv_tokens() const { + return n_kv; } llama_pos n_pos() const { @@ -1239,9 +1221,10 @@ struct server_tokens { return common_detokenize(ctx, text_tokens, special); } - // returns the position of the first token that is different - size_t get_common_prefix(const server_tokens & b) const { + // returns pair of + std::pair get_common_prefix(const server_tokens & b) const { size_t max_idx = std::min(tokens.size(), b.tokens.size()); + size_t n_tok = 0; for (size_t i = 0; i < max_idx; ++i) { auto & ai = tokens[i]; auto & bi = b.tokens[i]; @@ -1260,17 +1243,19 @@ struct server_tokens { if (ai_id == bi_id && a_pos == b_pos) { GGML_ASSERT(a_pos > 0 && "Invalid image token"); // should never happen i += a_pos - 1; // will be +1 by the for loop + n_tok += mtmd_image_tokens_get_n_tokens(a_img); continue; } else { - return i; + return {i, n_tok}; } } else if (ai == bi) { + n_tok++; continue; } else { - return i; + return {i, n_tok}; } } - return max_idx; // all tokens are equal + return {max_idx, n_tok}; // all tokens are equal } // make sure all text tokens are within the vocab range From f9cc9f23893d6a0491752f6df90f3790399bfae2 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 16 May 2025 16:36:45 +0200 Subject: [PATCH 3/6] fix stats report --- tools/server/server.cpp | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index e93e15c1f096e..ac8d3392879df 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1308,11 +1308,12 @@ struct server_slot { common_chat_format chat_format = COMMON_CHAT_FORMAT_CONTENT_ONLY; // stats - size_t n_sent_text = 0; // number of sent text character + size_t n_sent_text = 0; // number of sent text character int64_t t_start_process_prompt; int64_t t_start_generation; + size_t n_prompt_processing = 0; // number of decoded prompt tokens (may be less than prompt_tokens.n_kv_tokens(), in case we are using cache) double t_prompt_processing; // ms double t_token_generation; // ms @@ -1334,6 +1335,7 @@ struct server_slot { stopping_word = ""; n_past = 0; n_sent_text = 0; + n_prompt_processing = 0; task_type = SERVER_TASK_TYPE_COMPLETION; generated_tokens.clear(); @@ -1402,10 +1404,10 @@ struct server_slot { result_timings get_timings() const { result_timings timings; - timings.prompt_n = prompt_tokens.n_kv_tokens(); + timings.prompt_n = n_prompt_processing; timings.prompt_ms = t_prompt_processing; - timings.prompt_per_token_ms = t_prompt_processing / prompt_tokens.n_kv_tokens(); - timings.prompt_per_second = 1e3 / t_prompt_processing * prompt_tokens.n_kv_tokens(); + timings.prompt_per_token_ms = t_prompt_processing / n_prompt_processing; + timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_processing; timings.predicted_n = n_decoded; timings.predicted_ms = t_token_generation; @@ -3212,8 +3214,9 @@ struct server_context { slot.cache_tokens.push_back(chunk.get()); // copy } - slot.n_past += n_pos; - slot.n_kv_tokens += n_tok; + slot.n_past += n_pos; + slot.n_kv_tokens += n_tok; + slot.n_prompt_processing += n_tok; // for stats only } // add prompt tokens for processing in the current batch @@ -3233,6 +3236,7 @@ struct server_context { slot.n_kv_tokens++; slot.n_past++; + slot.n_prompt_processing++; // for stats only } // SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str()); From de8956adf482e3878b91b290c5eefc940195f061 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 17 May 2025 16:18:47 +0200 Subject: [PATCH 4/6] rm notion of n_kv_tokens --- tools/server/server.cpp | 152 +++++++++++++++++++--------------------- tools/server/utils.hpp | 66 +++++++++-------- 2 files changed, 109 insertions(+), 109 deletions(-) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index ac8d3392879df..2a56570b8495b 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1267,18 +1267,13 @@ struct server_slot { int64_t t_last_used = -1; // generation props - int32_t n_ctx = 0; // context size per slot - int32_t n_past = 0; - int32_t n_decoded = 0; + int32_t n_ctx = 0; // context size per slot + int32_t n_past = 0; // current position (note: it is not affected by context shift) + int32_t n_decoded = 0; // number of tokens generated int32_t n_remaining = -1; int32_t i_batch = -1; int32_t n_predict = -1; // TODO: disambiguate from params.n_predict - // this reflects the number of tokens in KV cache - // not to be confused with n_past which reflects the positions - // models using m-rope may have multiple tokens in KV cache sharing the same position - int32_t n_kv_tokens = 0; - // input prompt tokens server_tokens prompt_tokens; @@ -1326,7 +1321,6 @@ struct server_slot { void reset() { SLT_DBG(*this, "%s", "\n"); - n_kv_tokens = 0; last_nl_pos = 0; generated_text = ""; has_new_line = false; @@ -1335,8 +1329,8 @@ struct server_slot { stopping_word = ""; n_past = 0; n_sent_text = 0; - n_prompt_processing = 0; task_type = SERVER_TASK_TYPE_COMPLETION; + n_prompt_processing = 0; generated_tokens.clear(); generated_token_probs.clear(); @@ -1388,7 +1382,16 @@ struct server_slot { } int32_t n_prompt_tokens() const { - return prompt_tokens.n_kv_tokens(); + return prompt_tokens.n_tokens(); + } + + int32_t n_cache_tokens() const { + return cache_tokens.n_tokens(); + } + + // different from n_past if context is shifted + llama_pos curr_pos() const { + return cache_tokens.n_pos(); } void release() { @@ -1453,9 +1456,8 @@ struct server_slot { } void print_timings() const { - const int32_t n_prompt_tokens_processed = prompt_tokens.n_kv_tokens(); - const double t_prompt = t_prompt_processing / n_prompt_tokens_processed; - const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; + const double t_prompt = t_prompt_processing / n_prompt_processing; + const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_processing; const double t_gen = t_token_generation / n_decoded; const double n_gen_second = 1e3 / t_token_generation * n_decoded; @@ -1465,9 +1467,9 @@ struct server_slot { "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n" " eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n" " total time = %10.2f ms / %5d tokens\n", - t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second, + t_prompt_processing, (int)n_prompt_processing, t_prompt, n_prompt_second, t_token_generation, n_decoded, t_gen, n_gen_second, - t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded); + t_prompt_processing + t_token_generation, (int)n_prompt_processing + n_decoded); if (n_draft_total > 0) { const float draft_ratio = (float) n_draft_accepted / n_draft_total; @@ -1524,8 +1526,8 @@ struct server_metrics { } void on_prompt_eval(const server_slot & slot) { - n_prompt_tokens_processed_total += slot.prompt_tokens.n_kv_tokens(); - n_prompt_tokens_processed += slot.prompt_tokens.n_kv_tokens(); + n_prompt_tokens_processed_total += slot.n_prompt_tokens(); + n_prompt_tokens_processed += slot.n_prompt_processing; t_prompt_processing += slot.t_prompt_processing; t_prompt_processing_total += slot.t_prompt_processing; } @@ -2101,8 +2103,7 @@ struct server_context { } // length of the Longest Common Subsequence between the current slot's prompt and the input prompt - auto common_pos = slot.cache_tokens.get_common_prefix(task.prompt_tokens); - int cur_lcs_len = common_pos.first; // position, not tokens + int cur_lcs_len = slot.cache_tokens.get_common_prefix(task.prompt_tokens); // fraction of the common subsequence length compared to the current slot's prompt length float cur_similarity = static_cast(cur_lcs_len) / static_cast(slot.cache_tokens.n_pos()); @@ -2152,7 +2153,6 @@ struct server_context { slot.params = std::move(task.params); slot.prompt_tokens = std::move(task.prompt_tokens); slot.n_past = 0; - slot.n_kv_tokens = 0; if (!are_lora_equal(slot.params.lora, slot.lora)) { // if lora is changed, we cannot reuse cached tokens @@ -2320,13 +2320,13 @@ struct server_context { } // if context shift is disabled, we stop when it reaches the context limit - if (!params_base.ctx_shift && slot.n_kv_tokens >= slot.n_ctx) { + if (!params_base.ctx_shift && slot.n_cache_tokens() >= slot.n_ctx) { slot.truncated = true; slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; - SLT_DBG(slot, "stopped due to running out of context capacity, n_kv_tokens = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n", - slot.n_decoded, slot.n_prompt_tokens(), slot.n_past, slot.n_ctx); + SLT_DBG(slot, "stopped due to running out of context capacity, n_cache_tokens = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n", + slot.n_cache_tokens(), slot.n_prompt_tokens(), slot.n_decoded, slot.n_ctx); } if (llama_vocab_is_eog(vocab, result.tok)) { @@ -2797,7 +2797,7 @@ struct server_context { break; } - const size_t token_count = slot->cache_tokens.n_kv_tokens(); + const size_t token_count = slot->n_cache_tokens(); const int64_t t_start = ggml_time_us(); std::string filename = task.slot_action.filename; @@ -2883,7 +2883,7 @@ struct server_context { } // Erase token cache - const size_t n_erased = slot->cache_tokens.n_kv_tokens(); + const size_t n_erased = slot->n_cache_tokens(); llama_kv_self_seq_rm(ctx, slot->id, -1, -1); slot->cache_tokens.clear(); @@ -2937,7 +2937,7 @@ struct server_context { // apply context-shift if needed // TODO: simplify and improve for (server_slot & slot : slots) { - if (slot.is_processing() && slot.n_past + 1 >= slot.n_ctx) { + if (slot.is_processing() && slot.n_cache_tokens() + 1 >= slot.n_ctx) { if (!params_base.ctx_shift) { // this check is redundant (for good) // we should never get here, because generation should already stopped in process_token() @@ -2953,14 +2953,15 @@ struct server_context { } // Shift context + const int n_pos_cur = slot.cache_tokens.n_pos(); const int n_keep = slot.params.n_keep + add_bos_token; - const int n_left = slot.n_past - n_keep; + const int n_left = n_pos_cur - n_keep; const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2); SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard); llama_kv_self_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard); - llama_kv_self_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard); + llama_kv_self_seq_add(ctx, slot.id, n_keep + n_discard, n_pos_cur, -n_discard); // add generated tokens to cache { @@ -2969,13 +2970,11 @@ struct server_context { new_tokens[i - n_discard] = new_tokens[i]; } - new_tokens.resize(slot.cache_tokens.n_kv_tokens() - n_discard); + new_tokens.resize(slot.n_cache_tokens() - n_discard); slot.cache_tokens.clear(); slot.cache_tokens.insert(new_tokens); } - slot.n_past -= n_discard; - slot.truncated = true; } } @@ -3005,14 +3004,13 @@ struct server_context { slot.i_batch = batch.n_tokens; - common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true); + common_batch_add(batch, slot.sampled, slot.curr_pos(), { slot.id }, true); - slot.n_past += 1; - slot.n_kv_tokens += 1; + slot.n_past += 1; slot.cache_tokens.push_back(slot.sampled); SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_pos = %d, n_cache_tokens = %d, truncated = %d\n", - slot.n_ctx, slot.n_past, slot.cache_tokens.n_pos(), (int) slot.cache_tokens.n_kv_tokens(), slot.truncated); + slot.n_ctx, slot.n_past, slot.cache_tokens.n_pos(), (int) slot.n_cache_tokens(), slot.truncated); } // process in chunks of params.n_batch @@ -3022,6 +3020,8 @@ struct server_context { // next, batch any pending prompts without exceeding n_batch if (params_base.cont_batching || batch.n_tokens == 0) { for (auto & slot : slots) { + auto & prompt_tokens = slot.prompt_tokens; + // check if we can batch this slot with the previous one if (slot.is_processing()) { if (!slot_batched) { @@ -3033,9 +3033,7 @@ struct server_context { // this slot still has a prompt to be processed if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) { - auto & prompt_tokens = slot.prompt_tokens; - // TODO: maybe move branch to outside of this loop in the future if (slot.state == SLOT_STATE_STARTED) { slot.t_start_process_prompt = ggml_time_us(); slot.t_start_generation = 0; @@ -3097,9 +3095,8 @@ struct server_context { if (slot.params.cache_prompt) { // reuse any previously computed tokens that are common with the new prompt - auto common_pos = slot.cache_tokens.get_common_prefix(prompt_tokens); - slot.n_past = common_pos.first; - slot.n_kv_tokens = common_pos.second; + slot.n_past = slot.cache_tokens.get_common_prefix(prompt_tokens); + slot.cache_tokens.keep_first(slot.n_past); // reuse chunks from the cached prompt by shifting their KV cache in the new position if (params_base.n_cache_reuse > 0) { @@ -3113,12 +3110,12 @@ struct server_context { SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past); - while (head_c < slot.cache_tokens.n_kv_tokens() && - head_p < prompt_tokens.n_kv_tokens()) { + while (head_c < (size_t)slot.n_cache_tokens() && + head_p < prompt_tokens.n_tokens()) { size_t n_match = 0; - while (head_c + n_match < slot.cache_tokens.n_kv_tokens() && - head_p + n_match < prompt_tokens.n_kv_tokens() && + while (head_c + n_match < (size_t)slot.n_cache_tokens() && + head_p + n_match < prompt_tokens.n_tokens() && slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) { n_match++; @@ -3148,14 +3145,12 @@ struct server_context { } SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past); - // because we're using this logic on text-only, n_past always == n_kv_tokens - slot.n_kv_tokens = slot.n_past; + slot.cache_tokens.keep_first(slot.n_past); } } else { // if we don't cache the prompt, we have to remove the entire KV cache llama_kv_self_seq_rm(ctx, slot.id, 0, -1); slot.n_past = 0; - slot.n_kv_tokens = 0; slot.cache_tokens.clear(); } } @@ -3164,36 +3159,36 @@ struct server_context { // we have to evaluate at least 1 token to generate logits. SLT_WRN(slot, "need to evaluate at least 1 token to generate logits, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens()); + slot.cache_tokens.rm_last(1); slot.n_past--; - slot.n_kv_tokens--; } - } - // non-causal tasks require to fit the entire prompt in the physical batch - if (slot.is_non_causal()) { - // cannot fit the prompt in the current batch - will try next iter - if (batch.n_tokens + slot.n_prompt_tokens() > n_batch) { - continue; + // non-causal tasks require to fit the entire prompt in the physical batch + if (slot.is_non_causal()) { + // cannot fit the prompt in the current batch - will try next iter + if (batch.n_tokens + slot.n_prompt_tokens() > n_batch) { + continue; + } } - } - // keep only the common part - if (!llama_kv_self_seq_rm(ctx, slot.id, slot.n_past, -1)) { - // could not partially delete (likely using a non-Transformer model) - llama_kv_self_seq_rm(ctx, slot.id, -1, -1); + // keep only the common part + if (!llama_kv_self_seq_rm(ctx, slot.id, slot.n_past, -1)) { + // could not partially delete (likely using a non-Transformer model) + llama_kv_self_seq_rm(ctx, slot.id, -1, -1); - // there is no common part left - slot.n_past = 0; - slot.n_kv_tokens = 0; - } + // there is no common part left + slot.n_past = 0; + slot.cache_tokens.clear(); + } - SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past); + SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past); - // remove the non-common part from the cache - slot.cache_tokens.keep_first(slot.n_past); + // remove the non-common part from the cache + slot.cache_tokens.keep_first(slot.n_past); + } // check if we should process the image - if (slot.n_kv_tokens < slot.n_prompt_tokens() + if (slot.n_past < prompt_tokens.n_pos() && slot.prompt_tokens[slot.n_past] == LLAMA_TOKEN_NULL) { // process the image int32_t new_n_past; @@ -3215,12 +3210,11 @@ struct server_context { } slot.n_past += n_pos; - slot.n_kv_tokens += n_tok; slot.n_prompt_processing += n_tok; // for stats only } // add prompt tokens for processing in the current batch - while (slot.n_kv_tokens < slot.n_prompt_tokens() && batch.n_tokens < n_batch) { + while (slot.n_past < prompt_tokens.n_pos() && batch.n_tokens < n_batch) { // get next token to process llama_token cur_tok = slot.prompt_tokens[slot.n_past]; if (cur_tok == LLAMA_TOKEN_NULL) { @@ -3231,24 +3225,22 @@ struct server_context { const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE; - common_batch_add(batch, cur_tok, slot.n_past, { slot.id }, need_embd); + common_batch_add(batch, cur_tok, slot.curr_pos(), { slot.id }, need_embd); slot.cache_tokens.push_back(cur_tok); - slot.n_kv_tokens++; slot.n_past++; slot.n_prompt_processing++; // for stats only } // SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str()); - SLT_INF(slot, "prompt processing progress, n_past = %d, batch.n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_kv_tokens / slot.n_prompt_tokens()); + SLT_INF(slot, "prompt processing progress, n_past = %d, batch.n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_past / prompt_tokens.n_pos()); // entire prompt has been processed - if (slot.n_kv_tokens == slot.n_prompt_tokens()) { + if (slot.n_past == prompt_tokens.n_pos()) { slot.state = SLOT_STATE_DONE_PROMPT; GGML_ASSERT(batch.n_tokens > 0); - GGML_ASSERT((size_t) slot.n_prompt_tokens() == slot.prompt_tokens.n_kv_tokens()); common_sampler_reset(slot.smpl); @@ -3418,9 +3410,9 @@ struct server_context { // determine the max draft that fits the current slot state int n_draft_max = slot.params.speculative.n_max; - // note: n_past is not yet increased for the `id` token sampled above + // note: slot.curr_pos() is not yet increased for the `id` token sampled above // also, need to leave space for 1 extra token to allow context shifts - n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2); + n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.curr_pos() - 2); if (slot.n_remaining > 0) { n_draft_max = std::min(n_draft_max, slot.n_remaining - 1); @@ -3456,10 +3448,10 @@ struct server_context { // construct the speculation batch common_batch_clear(slot.batch_spec); - common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true); + common_batch_add (slot.batch_spec, id, slot.curr_pos(), { slot.id }, true); for (size_t i = 0; i < draft.size(); ++i) { - common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true); + common_batch_add(slot.batch_spec, draft[i], slot.curr_pos() + 1 + i, { slot.id }, true); } SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens); @@ -3478,7 +3470,7 @@ struct server_context { slot.cache_tokens.push_back(id); slot.cache_tokens.insert({ids.begin(), ids.end() - 1}); - llama_kv_self_seq_rm(ctx, slot.id, slot.n_past, -1); + llama_kv_self_seq_rm(ctx, slot.id, slot.curr_pos(), -1); for (size_t i = 0; i < ids.size(); ++i) { completion_token_output result; diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp index ab619406867f3..c72b93eff5fa1 100644 --- a/tools/server/utils.hpp +++ b/tools/server/utils.hpp @@ -1053,10 +1053,10 @@ struct server_tokens { // pos 0 1 2 3 4 5 6 7 8 9 // map_pos_to_image will contain: {5, img0}, {8, img1} - // number of tokens in KV cache - // it is named this way to avoid confusion between the notion of "tokens" and "positions" - // for example, models using m-rope can have multiple tokens in the KV cache but they all share one position - size_t n_kv = 0; + // number of tokens contained in this object + // note that the number of tokens can be larger than the number of positions + // for example, models using m-rope can have multiple tokens that share a position + size_t n_tok = 0; public: server_tokens() = default; @@ -1080,7 +1080,7 @@ struct server_tokens { } } - server_tokens(llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens), n_kv(tokens.size()) {} + server_tokens(llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens), n_tok(tokens.size()) {} // for debugging std::string str() const { @@ -1114,7 +1114,7 @@ struct server_tokens { if (tok == LLAMA_TOKEN_NULL) { throw std::runtime_error("Invalid token"); } - n_kv++; + n_tok++; tokens.emplace_back(tok); } @@ -1129,7 +1129,7 @@ struct server_tokens { for (int i = 0; i < n_pos; ++i) { tokens.emplace_back(LLAMA_TOKEN_NULL); } - n_kv += mtmd_image_tokens_get_n_tokens(img_tokens); + n_tok += mtmd_image_tokens_get_n_tokens(img_tokens); mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk)); map_pos_to_image[start_pos] = std::move(new_chunk); } else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) { @@ -1146,7 +1146,7 @@ struct server_tokens { // for compatibility with context shift and prompt truncation void insert(const llama_tokens & inp_tokens) { GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled - n_kv += inp_tokens.size(); + n_tok += inp_tokens.size(); tokens.insert(tokens.end(), inp_tokens.begin(), inp_tokens.end()); } @@ -1162,8 +1162,8 @@ struct server_tokens { tokens[pos] = id; } - size_t n_kv_tokens() const { - return n_kv; + size_t n_tokens() const { + return n_tok; } llama_pos n_pos() const { @@ -1175,12 +1175,17 @@ struct server_tokens { } void clear() { - n_kv = 0; + n_tok = 0; tokens.clear(); } - void keep_first(size_t n) { - GGML_ASSERT(n <= tokens.size()); + void keep_first(size_t n_pos) { + GGML_ASSERT(n_pos <= tokens.size()); + size_t n_pos_rm = tokens.size() - n_pos; + // num of tokens to remove = n_tok_text + n_tok_img + // = (n_pos_rm - n_pos_img) + n_tok_img + size_t n_pos_img = 0; + size_t n_tok_img = 0; if (has_mtmd) { // we throw an error if we try to remove a token in the middle of an image // for ex. with input of 5 text tokens and 2 images: @@ -1188,26 +1193,32 @@ struct server_tokens { // n 1 2 3 4 5 6 7 8 9 10 // allowed to resize ^ ^ // disallowed to resize ^ ^ ^ - if (n > 0) { - llama_token last_token = tokens[n - 1]; + if (n_pos > 0) { + llama_token last_token = tokens[n_pos - 1]; // make sure we never remove tokens in the middle of an image if (last_token == LLAMA_TOKEN_NULL) { - find_chunk(n - 1); // will throw an error if the token is not begin-of-chunk + find_chunk(n_pos - 1); // will throw an error if the token is not begin-of-chunk } } // remove all image chunks that are not used anymore for (auto it = map_pos_to_image.begin(); it != map_pos_to_image.end(); ) { llama_pos pos = it->first; - if (pos >= (llama_pos)n) { + if (pos >= (llama_pos)n_pos) { auto img_tokens = mtmd_input_chunk_get_tokens_image(it->second.get()); - n_kv -= mtmd_image_tokens_get_n_tokens(img_tokens); + n_pos_img += mtmd_image_tokens_get_n_pos(img_tokens); + n_tok_img += mtmd_image_tokens_get_n_tokens(img_tokens); it = map_pos_to_image.erase(it); } else { ++it; } } } - tokens.resize(n); + n_tok -= (n_pos_rm - n_pos_img) + n_tok_img; + tokens.resize(n_pos); + } + + void rm_last(size_t n_pos) { + keep_first(tokens.size() - n_pos); } std::string detokenize(const llama_context * ctx, bool special) const { @@ -1221,10 +1232,9 @@ struct server_tokens { return common_detokenize(ctx, text_tokens, special); } - // returns pair of - std::pair get_common_prefix(const server_tokens & b) const { + // returns the first position where the tokens differ + llama_pos get_common_prefix(const server_tokens & b) const { size_t max_idx = std::min(tokens.size(), b.tokens.size()); - size_t n_tok = 0; for (size_t i = 0; i < max_idx; ++i) { auto & ai = tokens[i]; auto & bi = b.tokens[i]; @@ -1243,19 +1253,17 @@ struct server_tokens { if (ai_id == bi_id && a_pos == b_pos) { GGML_ASSERT(a_pos > 0 && "Invalid image token"); // should never happen i += a_pos - 1; // will be +1 by the for loop - n_tok += mtmd_image_tokens_get_n_tokens(a_img); continue; } else { - return {i, n_tok}; + return i; } } else if (ai == bi) { - n_tok++; continue; } else { - return {i, n_tok}; + return i; } } - return {max_idx, n_tok}; // all tokens are equal + return max_idx; // all tokens are equal } // make sure all text tokens are within the vocab range @@ -1289,7 +1297,7 @@ struct server_tokens { llama_pos n_past, int32_t seq_id, llama_pos & n_pos_out, - size_t & n_kv_tokens_out) { + size_t & n_tokens_out) { auto it = map_pos_to_image.find(n_past); if (it == map_pos_to_image.end()) { throw std::runtime_error("Chunk not found"); @@ -1312,7 +1320,7 @@ struct server_tokens { return result; } auto img_tokens = mtmd_input_chunk_get_tokens_image(it->second.get()); - n_kv_tokens_out = mtmd_image_tokens_get_n_tokens(img_tokens); + n_tokens_out = mtmd_image_tokens_get_n_tokens(img_tokens); n_pos_out = new_n_past; return 0; } From 3012326f0f8cb9f86a653c0ad9771b44210a71f3 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 17 May 2025 16:25:30 +0200 Subject: [PATCH 5/6] rm double check for out-of-context --- tools/server/server.cpp | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 34dbc1af2d553..4d7d3d6fed853 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -2262,12 +2262,14 @@ struct server_context { slot.has_next_token = true; } - // if context shifting is disabled, make sure that we don't run out of context - if (!params_base.ctx_shift && slot.n_past + 1 >= slot.n_ctx) { + // if context shift is disabled, we stop when it reaches the context limit + if (!params_base.ctx_shift && slot.n_cache_tokens() + 1 >= slot.n_ctx) { + slot.truncated = true; slot.stop = STOP_TYPE_LIMIT; slot.has_next_token = false; - SLT_DBG(slot, "stopped due to running out of context, n_past = %d, n_ctx = %d\n", slot.n_past, slot.n_ctx); + SLT_DBG(slot, "stopped due to running out of context capacity, n_cache_tokens = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n", + slot.n_cache_tokens(), slot.n_prompt_tokens(), slot.n_decoded, slot.n_ctx); } // check the limits @@ -2327,16 +2329,6 @@ struct server_context { } } - // if context shift is disabled, we stop when it reaches the context limit - if (!params_base.ctx_shift && slot.n_cache_tokens() >= slot.n_ctx) { - slot.truncated = true; - slot.stop = STOP_TYPE_LIMIT; - slot.has_next_token = false; - - SLT_DBG(slot, "stopped due to running out of context capacity, n_cache_tokens = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n", - slot.n_cache_tokens(), slot.n_prompt_tokens(), slot.n_decoded, slot.n_ctx); - } - if (llama_vocab_is_eog(vocab, result.tok)) { slot.stop = STOP_TYPE_EOS; slot.has_next_token = false; From 987955f06b83c2cab579c7043816a888d1a61a04 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Mon, 19 May 2025 15:32:25 +0200 Subject: [PATCH 6/6] Apply suggestions from code review Co-authored-by: Georgi Gerganov --- tools/server/server.cpp | 4 ++-- tools/server/utils.hpp | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 4d7d3d6fed853..47f2b2026121c 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1308,7 +1308,7 @@ struct server_slot { int64_t t_start_process_prompt; int64_t t_start_generation; - size_t n_prompt_processing = 0; // number of decoded prompt tokens (may be less than prompt_tokens.n_kv_tokens(), in case we are using cache) + size_t n_prompt_processing = 0; // number of decoded prompt tokens (may be less than prompt_tokens.n_tokens(), in case we are using cache) double t_prompt_processing; // ms double t_token_generation; // ms @@ -2476,7 +2476,7 @@ struct server_context { res->truncated = slot.truncated; res->n_decoded = slot.n_decoded; res->n_prompt_tokens = slot.n_prompt_tokens(); - res->n_tokens_cached = slot.n_past; + res->n_tokens_cached = slot.n_cache_tokens(); res->has_new_line = slot.has_new_line; res->stopping_word = slot.stopping_word; res->stop = slot.stop; diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp index c72b93eff5fa1..f3919f3421eb3 100644 --- a/tools/server/utils.hpp +++ b/tools/server/utils.hpp @@ -1045,7 +1045,6 @@ struct server_tokens { // it can include LLAMA_TOKEN_NULL, which is used to indicate a token that is not a text token // a mtmd_input_chunk can occupy multiple tokens, one llama_token per **position** // important: for models using mrope, an image can contain multiple tokens but will use only one **position** - // in otherwords, tokens.size() == n_past llama_tokens tokens; // for ex. with input of 5 text tokens and 2 images: