feat: Add optional prompt processing progress streaming #14731


Open · wants to merge 1 commit into `master`

tools/server/README.md (2 additions, 0 deletions)
@@ -428,6 +428,8 @@ By default, this value is set to `0`, meaning no tokens are kept. Use `-1` to re

`stream`: Allows receiving each predicted token in real-time instead of waiting for the completion to finish (uses a different response format). To enable this, set to `true`.

`include_prompt_progress`: When `stream` is enabled, this option allows receiving prompt processing progress information before text generation begins. Progress responses contain a `prompt_processing` field with the number of prompt tokens processed so far and the overall progress. This is useful for long prompts, where users want to see evaluation progress instead of waiting silently. Default: `false` (only applies when `stream` is `true`).
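
For illustration, here is a sketch of how this option might be used. The request fields below are existing `/completion` parameters plus the new flag; the values and the shape of the response are illustrative, and the regular fields of a streamed chunk are omitted:

```json
{
  "prompt": "Summarize the following document: ...",
  "n_predict": 128,
  "stream": true,
  "include_prompt_progress": true
}
```

While the prompt is being evaluated, progress chunks are streamed with an empty `content` and a `prompt_processing` object, for example:

```json
{
  "content": "",
  "prompt_processing": {
    "n_past": 512,
    "n_prompt_tokens": 2048,
    "n_prompt_tokens_processed": 512,
    "progress": 0.25
  }
}
```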

`stop`: Specify a JSON array of stopping strings.
These words will not be included in the completion, so make sure to add them to the prompt for the next iteration. Default: `[]`

tools/server/server.cpp (87 additions, 6 deletions)
@@ -109,9 +109,10 @@ static bool server_task_type_need_logits(server_task_type task_type) {
}

struct slot_params {
bool stream = true;
bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
bool return_tokens = false;
bool include_prompt_progress = false; // include prompt processing progress in streaming responses

int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
@@ -258,9 +259,10 @@ struct server_task {
params.verbose = params_base.verbosity > 9;
params.timings_per_token = json_value(data, "timings_per_token", false);

params.stream = json_value(data, "stream", false);
params.cache_prompt = json_value(data, "cache_prompt", true);
params.return_tokens = json_value(data, "return_tokens", false);
params.include_prompt_progress = json_value(data, "include_prompt_progress", false);
params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
params.n_indent = json_value(data, "n_indent", defaults.n_indent);
params.n_keep = json_value(data, "n_keep", defaults.n_keep);
@@ -898,6 +900,12 @@ struct server_task_result_cmpl_partial : server_task_result {
completion_token_output prob_output;
result_timings timings;

// Progress fields (only populated when is_progress_response is true)
bool is_progress_response = false;
int32_t n_past = 0;
int32_t n_prompt_tokens_processed = 0;
float progress = 0.0f;

// OAI-compat fields
bool verbose = false;
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
@@ -944,6 +952,15 @@
if (!prob_output.probs.empty()) {
res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs);
}
// include prompt processing progress if this is a progress response
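// e.g. "prompt_processing": {"n_past": 512, "n_prompt_tokens": 2048, "n_prompt_tokens_processed": 512, "progress": 0.25} (illustrative values)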
if (is_progress_response) {
res["prompt_processing"] = json {
{"n_past", n_past},
{"n_prompt_tokens", n_prompt_tokens},
{"n_prompt_tokens_processed", n_prompt_tokens_processed},
{"progress", progress},
};
}
return res;
}

@@ -2515,6 +2532,64 @@ struct server_context {
queue_results.send(std::move(res));
}

void send_progress_response(server_slot & slot) {
// Only send progress if explicitly requested and streaming is enabled
if (!slot.params.include_prompt_progress || !slot.params.stream) {
return;
}

// Compute current progress as a fraction (0.0 to 1.0) of prompt tokens processed
float current_progress = slot.n_prompt_tokens > 0 ?
(float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens : 0.0f;

// Throttle updates: remember the progress value last sent for the current task
static float last_progress = -1.0f;
static int last_task_id = -1;

// Reset tracking when a new task starts
if (slot.id_task != last_task_id) {
last_progress = -1.0f;
last_task_id = slot.id_task;
}

// Send progress if:
// 1. This is the first progress update for this task (last_progress < 0)
// 2. Progress has increased by at least 1% since the last update
// 3. Prompt processing has just completed (current_progress >= 1.0 for the first time)
bool should_send = (last_progress < 0.0f) ||
(current_progress - last_progress >= 0.01f) ||
(current_progress >= 1.0f && last_progress < 1.0f);

if (!should_send) {
return;
}

last_progress = current_progress;

auto res = std::make_unique<server_task_result_cmpl_partial>();

res->id = slot.id_task;
res->index = slot.index;
res->content = ""; // empty content for progress responses
res->tokens = {}; // empty tokens for progress responses

res->n_decoded = 0; // no tokens decoded yet during prompt processing
res->n_prompt_tokens = slot.n_prompt_tokens;

// Progress-specific fields
res->is_progress_response = true;
res->n_past = slot.n_past;
res->n_prompt_tokens_processed = slot.n_prompt_tokens_processed;
res->progress = current_progress;

res->verbose = slot.params.verbose;
res->oaicompat = slot.params.oaicompat;
res->oaicompat_model = slot.params.oaicompat_model;
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;

queue_results.send(std::move(res));
}

void send_final_response(server_slot & slot) {
auto res = std::make_unique<server_task_result_cmpl_final>();
res->id = slot.id_task;
@@ -3334,12 +3409,18 @@

slot.n_prompt_tokens_processed++;
slot.n_past++;

// Send incremental progress updates during token processing
send_progress_response(slot);
}

// SLT_INF(slot, "new cache_tokens: %s\n", slot.cache_tokens.str().c_str());

SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);

// Send progress response if requested
send_progress_response(slot);

// entire prompt has been processed
if (slot.n_past == slot.n_prompt_tokens) {
slot.state = SLOT_STATE_DONE_PROMPT;