diff --git a/tools/server/README.md b/tools/server/README.md
index e29511cb1b457..186ab2727c3e3 100644
--- a/tools/server/README.md
+++ b/tools/server/README.md
@@ -428,6 +428,8 @@ By default, this value is set to `0`, meaning no tokens are kept. Use `-1` to re
 `stream`: Allows receiving each predicted token in real-time instead of waiting for the completion to finish (uses a different response format). To enable this, set to `true`.
 
+`include_prompt_progress`: When `stream` is enabled, this option allows receiving prompt-processing progress notifications before text generation begins. Each progress response carries a `prompt_processing` field reporting how many prompt tokens have been processed so far and the overall fraction completed. This is useful for long prompts, where users would otherwise wait with no feedback while the prompt is evaluated. Only applies when `stream` is `true`. Default: `false`
+
 `stop`: Specify a JSON array of stopping strings. These words will not be included in the completion, so make sure to add them to the prompt for the next iteration. Default: `[]`
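To make the new option concrete, here is a rough sketch of a request body and of one intermediate progress chunk, as implied by the `server.cpp` changes below. The `prompt_processing` field names (`n_past`, `n_prompt_tokens`, `n_prompt_tokens_processed`, `progress`) come from the patch; all values, and the exact set of sibling fields in the chunk, are illustrative only:

```json
{"prompt": "...", "stream": true, "include_prompt_progress": true}
```

```json
{
  "content": "",
  "prompt_processing": {
    "n_past": 512,
    "n_prompt_tokens": 2048,
    "n_prompt_tokens_processed": 512,
    "progress": 0.25
  }
}
```

Progress chunks carry an empty `content`, so existing streaming clients that simply append `content` should keep working unchanged.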
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 0afe213af1e47..ea99a10589714 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -109,9 +109,10 @@ static bool server_task_type_need_logits(server_task_type task_type) {
 }
 
 struct slot_params {
-    bool stream        = true;
-    bool cache_prompt  = true; // remember the prompt to avoid reprocessing all prompt
-    bool return_tokens = false;
+    bool stream                  = true;
+    bool cache_prompt            = true;  // remember the prompt to avoid reprocessing all prompt
+    bool return_tokens           = false;
+    bool include_prompt_progress = false; // include prompt processing progress in streaming responses
 
     int32_t n_keep    = 0; // number of tokens to keep from initial prompt
     int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
@@ -258,9 +259,10 @@ struct server_task {
         params.verbose           = params_base.verbosity > 9;
         params.timings_per_token = json_value(data, "timings_per_token", false);
 
-        params.stream        = json_value(data, "stream",        false);
-        params.cache_prompt  = json_value(data, "cache_prompt",  true);
-        params.return_tokens = json_value(data, "return_tokens", false);
+        params.stream                  = json_value(data, "stream",                  false);
+        params.cache_prompt            = json_value(data, "cache_prompt",            true);
+        params.return_tokens           = json_value(data, "return_tokens",           false);
+        params.include_prompt_progress = json_value(data, "include_prompt_progress", false);
         params.n_predict     = json_value(data, "n_predict",     json_value(data, "max_tokens", defaults.n_predict));
         params.n_indent      = json_value(data, "n_indent",      defaults.n_indent);
         params.n_keep        = json_value(data, "n_keep",        defaults.n_keep);
@@ -898,6 +900,12 @@ struct server_task_result_cmpl_partial : server_task_result {
     completion_token_output prob_output;
     result_timings timings;
 
+    // progress fields (only populated when is_progress_response is true)
+    bool    is_progress_response      = false;
+    int32_t n_past                    = 0;
+    int32_t n_prompt_tokens_processed = 0;
+    float   progress                  = 0.0f;
+
     // OAI-compat fields
     bool           verbose   = false;
     oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
@@ -944,6 +952,15 @@ struct server_task_result_cmpl_partial : server_task_result {
         if (!prob_output.probs.empty()) {
             res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs);
         }
+        // include prompt processing progress if this is a progress response
+        if (is_progress_response) {
+            res["prompt_processing"] = json {
+                {"n_past",                    n_past},
+                {"n_prompt_tokens",           n_prompt_tokens},
+                {"n_prompt_tokens_processed", n_prompt_tokens_processed},
+                {"progress",                  progress},
+            };
+        }
         return res;
     }
 
@@ -2515,6 +2532,66 @@ struct server_context {
         queue_results.send(std::move(res));
     }
 
+    void send_progress_response(server_slot & slot) {
+        // only send progress if explicitly requested and streaming is enabled
+        if (!slot.params.include_prompt_progress || !slot.params.stream) {
+            return;
+        }
+
+        // fraction of the prompt processed so far
+        const float current_progress = slot.n_prompt_tokens > 0 ?
+            (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens : 0.0f;
+
+        // throttle state, kept per slot so that interleaved processing of
+        // parallel slots does not clobber each other's tracking:
+        // slot id -> {task id, last progress value sent for that task}
+        // (results are dispatched from a single thread, so a static map is safe)
+        static std::unordered_map<int, std::pair<int, float>> last_sent;
+
+        float last_progress = -1.0f;
+        const auto it = last_sent.find(slot.id);
+        if (it != last_sent.end() && it->second.first == slot.id_task) {
+            last_progress = it->second.second;
+        }
+
+        // send progress if:
+        // 1. this is the first progress update for this task (last_progress == -1)
+        // 2. progress increased by at least 1% since the last update
+        // 3. prompt processing just completed (current_progress reached 1.0)
+        const bool should_send = (last_progress < 0.0f) ||
+                                 (current_progress - last_progress >= 0.01f) ||
+                                 (current_progress >= 1.0f && last_progress < 1.0f);
+
+        if (!should_send) {
+            return;
+        }
+
+        last_sent[slot.id] = {slot.id_task, current_progress};
+
+        auto res = std::make_unique<server_task_result_cmpl_partial>();
+
+        res->id    = slot.id_task;
+        res->index = slot.index;
+
+        res->content = ""; // progress responses carry no generated text ...
+        res->tokens  = {}; // ... and no tokens
+
+        res->n_decoded       = 0; // no tokens have been decoded yet during prompt processing
+        res->n_prompt_tokens = slot.n_prompt_tokens;
+
+        // progress-specific fields
+        res->is_progress_response      = true;
+        res->n_past                    = slot.n_past;
+        res->n_prompt_tokens_processed = slot.n_prompt_tokens_processed;
+        res->progress                  = current_progress;
+
+        res->verbose           = slot.params.verbose;
+        res->oaicompat         = slot.params.oaicompat;
+        res->oaicompat_model   = slot.params.oaicompat_model;
+        res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
+
+        queue_results.send(std::move(res));
+    }
+
     void send_final_response(server_slot & slot) {
         auto res = std::make_unique<server_task_result_cmpl_final>();
         res->id = slot.id_task;
@@ -3334,12 +3409,18 @@ struct server_context {
                     slot.n_prompt_tokens_processed++;
                     slot.n_past++;
+
+                    // send an incremental progress update for this token (throttled inside send_progress_response)
+                    send_progress_response(slot);
                 }
 
                 // SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str());
 
                 SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
 
+                // flush a final progress update for this batch, if requested
+                send_progress_response(slot);
+
                 // entire prompt has been processed
                 if (slot.n_past == slot.n_prompt_tokens) {
                     slot.state = SLOT_STATE_DONE_PROMPT;
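On the client side, progress chunks can be told apart from normal token chunks by the presence of the `prompt_processing` key. Below is a minimal, hypothetical sketch of such a consumer, assuming each SSE `data:` payload has already been extracted into a string; it uses nlohmann/json (the library behind the server's `json` type), and `handle_chunk` is an illustrative helper, not part of the patch:

```cpp
// Minimal sketch of client-side handling for the new progress chunks.
// Assumes each `data: {...}` line has already been extracted from the
// SSE stream into a string.
#include <cstdio>
#include <string>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

// hypothetical helper: called once per streamed chunk
static void handle_chunk(const std::string & payload) {
    const json chunk = json::parse(payload);

    if (chunk.contains("prompt_processing")) {
        // progress chunk: content is empty, report progress instead
        const auto & pp = chunk.at("prompt_processing");
        const int   processed = pp.at("n_prompt_tokens_processed").get<int>();
        const int   total     = pp.at("n_prompt_tokens").get<int>();
        const float progress  = pp.at("progress").get<float>();
        fprintf(stderr, "\rprompt: %d/%d tokens (%.0f%%)", processed, total, 100.0f * progress);
        return;
    }

    // regular chunk: print the generated text as it arrives
    if (chunk.contains("content")) {
        printf("%s", chunk.at("content").get<std::string>().c_str());
        fflush(stdout);
    }
}

int main() {
    // two illustrative payloads: one progress chunk, one token chunk
    handle_chunk(R"({"content":"","prompt_processing":{"n_past":512,"n_prompt_tokens":2048,"n_prompt_tokens_processed":512,"progress":0.25}})");
    handle_chunk(R"({"content":"Hello"})");
    fprintf(stderr, "\n");
    return 0;
}
```

Since the server throttles updates to increments of at least 1% (plus a final update at 100%), a client receives at most about a hundred progress chunks per prompt regardless of prompt length; for a 2048-token prompt that works out to roughly one update per 20 processed tokens.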