diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 0afe213af1e47..d83e4df4a865d 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -1034,6 +1034,112 @@ struct server_task_result_cmpl_partial : server_task_result {
     }
 };
 
+struct server_task_result_cmpl_progress : server_task_result {
+    int index = 0;
+
+    int32_t n_past;
+    int32_t n_prompt_tokens;
+    int32_t n_prompt_tokens_processed;
+    float   progress;
+
+    // OAI-compat fields
+    bool           verbose   = false;
+    oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
+    std::string    oaicompat_model;
+    std::string    oaicompat_cmpl_id;
+
+    virtual int get_index() override {
+        return index;
+    }
+
+    virtual bool is_stop() override {
+        return false; // progress responses are not considered stop
+    }
+
+    virtual json to_json() override {
+        switch (oaicompat) {
+            case OAICOMPAT_TYPE_NONE:
+                return to_json_non_oaicompat();
+            case OAICOMPAT_TYPE_COMPLETION:
+                return to_json_oaicompat();
+            case OAICOMPAT_TYPE_CHAT:
+                return to_json_oaicompat_chat();
+            default:
+                GGML_ASSERT(false && "Invalid oaicompat_type");
+        }
+    }
+
+    json to_json_non_oaicompat() {
+        return json {
+            {"index",   index},
+            {"stop",    false},
+            {"id_slot", id_slot},
+            {"prompt_processing", json {
+                {"n_past",                    n_past},
+                {"n_prompt_tokens",           n_prompt_tokens},
+                {"n_prompt_tokens_processed", n_prompt_tokens_processed},
+                {"progress",                  progress},
+            }},
+        };
+    }
+
+    json to_json_oaicompat() {
+        std::time_t t = std::time(0);
+        json res = json {
+            {"choices", json::array({
+                json{
+                    {"text",          ""},
+                    {"index",         index},
+                    {"logprobs",      nullptr},
+                    {"finish_reason", nullptr},
+                }
+            })},
+            {"created",            t},
+            {"model",              oaicompat_model},
+            {"system_fingerprint", build_info},
+            {"object",             "text_completion"},
+            {"id",                 oaicompat_cmpl_id},
+            {"prompt_processing", json {
+                {"n_past",                    n_past},
+                {"n_prompt_tokens",           n_prompt_tokens},
+                {"n_prompt_tokens_processed", n_prompt_tokens_processed},
+                {"progress",                  progress},
+            }},
+        };
+
+        // extra fields for debugging purposes
+        if (verbose) {
+            res["__verbose"] = to_json_non_oaicompat();
+        }
+
+        return res;
+    }
+
+    json to_json_oaicompat_chat() {
+        std::time_t t = std::time(0);
+        return json {
+            {"choices", json::array({
+                json {
+                    {"finish_reason", nullptr},
+                    {"index",         0},
+                    {"delta",         json::object()},
+                },
+            })},
+            {"created",            t},
+            {"id",                 oaicompat_cmpl_id},
+            {"model",              oaicompat_model},
+            {"system_fingerprint", build_info},
+            {"object",             "chat.completion.chunk"},
+            {"prompt_processing", json {
+                {"n_past",                    n_past},
+                {"n_prompt_tokens",           n_prompt_tokens},
+                {"n_prompt_tokens_processed", n_prompt_tokens_processed},
+                {"progress",                  progress},
+            }},
+        };
+    }
+};
+
 struct server_task_result_embd : server_task_result {
     int index = 0;
     std::vector<std::vector<float>> embedding;
@@ -2515,6 +2621,31 @@ struct server_context {
         queue_results.send(std::move(res));
     }
 
+    void send_progress_response(server_slot & slot) {
+        // Only send progress updates for streaming requests
+        if (!slot.params.stream) {
+            return;
+        }
+
+        auto res = std::make_unique<server_task_result_cmpl_progress>();
+
+        res->id      = slot.id_task;
+        res->id_slot = slot.id;
+        res->index   = slot.index;
+
+        res->n_past                    = slot.n_past;
+        res->n_prompt_tokens           = slot.n_prompt_tokens;
+        res->n_prompt_tokens_processed = slot.n_prompt_tokens_processed;
+        res->progress                  = (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens;
+
+        res->verbose           = slot.params.verbose;
+        res->oaicompat         = slot.params.oaicompat;
+        res->oaicompat_model   = slot.params.oaicompat_model;
+        res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
+
+        queue_results.send(std::move(res));
+    }
+
     void send_final_response(server_slot & slot) {
         auto res = std::make_unique<server_task_result_cmpl_final>();
         res->id = slot.id_task;
@@ -2725,6 +2856,7 @@ struct server_context {
         GGML_ASSERT(
             dynamic_cast<server_task_result_cmpl_partial *>(result.get()) != nullptr
             || dynamic_cast<server_task_result_cmpl_final *>(result.get()) != nullptr
+            || dynamic_cast<server_task_result_cmpl_progress *>(result.get()) != nullptr
         );
         if (!result_handler(result)) {
             cancel_tasks(id_tasks);
@@ -3340,6 +3472,9 @@ struct server_context {
             SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
 
+            // send progress update to client
+            send_progress_response(slot);
+
             // entire prompt has been processed
             if (slot.n_past == slot.n_prompt_tokens) {
                 slot.state = SLOT_STATE_DONE_PROMPT;
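Note: with this change, streaming clients receive extra chunks while the prompt is still being decoded, one per processed batch. As a rough sketch of what to_json_oaicompat_chat() above produces for a half-processed 1024-token prompt (all field values below are hypothetical placeholders), each SSE data: line would carry a payload like:

    {
      "choices": [
        { "finish_reason": null, "index": 0, "delta": {} }
      ],
      "created": 1759000000,
      "id": "chatcmpl-abc123",
      "model": "llama-3.1-8b-instruct",
      "system_fingerprint": "b1234-deadbeef",
      "object": "chat.completion.chunk",
      "prompt_processing": {
        "n_past": 512,
        "n_prompt_tokens": 1024,
        "n_prompt_tokens_processed": 512,
        "progress": 0.5
      }
    }

Since delta is empty and finish_reason is null, OAI-compatible clients that do not know about the added prompt_processing field should treat these chunks as no-ops; only clients that opt in to reading it need to change.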