From 1da67a395cd683469e0397d1496618bcb2725cc2 Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 29 Sep 2024 01:08:16 +0100 Subject: [PATCH 01/19] `server`: support cancelling non-streamed requests --- examples/server/server.cpp | 336 +++++++++++++++++++++---------------- 1 file changed, 188 insertions(+), 148 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index f343cc252f89a..1ce4d7e26beaa 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -33,6 +33,7 @@ #include #include +#include #include #include #include @@ -104,6 +105,7 @@ struct server_task { json data; server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL; + std::function is_alive; // utility function static std::unordered_set get_list_id(const std::vector & tasks) { @@ -173,7 +175,7 @@ struct server_slot { std::vector generated_token_probs; server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL; - + std::function is_alive; bool has_next_token = true; bool truncated = false; bool stopped_eos = false; @@ -876,6 +878,7 @@ struct server_context { // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them) auto default_sparams = params.sparams; const auto & data = task.data; + slot.is_alive = task.is_alive; if (data.count("__oaicompat") != 0) { slot.oaicompat = true; @@ -1117,6 +1120,13 @@ struct server_context { } bool process_token(completion_token_output & result, server_slot & slot) { + if (!slot.is_alive()) { + slot.truncated = false; + slot.has_next_token = false; + + SLT_DBG(slot, "stopped by client disconnection, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict); + return slot.has_next_token; + } // remember which tokens were sampled - used for repetition penalties during sampling const std::string token_str = llama_token_to_piece(ctx, result.tok, params.special); slot.sampled = result.tok; @@ -1461,13 +1471,14 @@ struct server_context { // Functions to create new task(s) and receive result(s) // - std::vector create_tasks_cmpl(json data, server_task_cmpl_type cmpl_type) { + std::vector create_tasks_cmpl(json data, server_task_cmpl_type cmpl_type, const std::function & is_alive) { std::vector tasks; auto create_task = [&](json & task_data, bool replace_prompt, json prompt) { server_task task; task.id = queue_tasks.get_new_id(); task.cmpl_type = cmpl_type; task.type = SERVER_TASK_TYPE_COMPLETION; + task.is_alive = is_alive; if (replace_prompt) { task.data = task_data; task.data["prompt"] = std::move(prompt); @@ -2412,6 +2423,60 @@ inline void signal_handler(int signal) { shutdown_handler(signal); } +static void handle_tasks( + bool stream, + httplib::Response & res, + server_context & ctx_server, + const std::function(const std::function &)> & create_tasks, + const std::function &, httplib::DataSink & sink, const std::function &)> & payload) +{ + struct State { + std::unordered_set task_ids; + }; + auto state = std::make_shared(); + httplib::ContentProviderResourceReleaser resource_releaser = [state, &ctx_server](bool success) { + if (!success && state) { + ctx_server.cancel_tasks(state->task_ids); + } + }; + if (!stream) { + res.set_content_provider(MIMETYPE_JSON, [create_tasks, payload, state, &ctx_server](size_t, httplib::DataSink & sink) { + auto is_alive = [&sink]() { return sink.is_writable(); }; + state->task_ids = create_tasks(is_alive); + payload(state->task_ids, sink, is_alive); + ctx_server.queue_results.remove_waiting_task_ids(state->task_ids); + return 
false; + }, resource_releaser); + } else { + res.set_chunked_content_provider("text/event-stream", [create_tasks, payload, state, &ctx_server](size_t, httplib::DataSink & sink) { + auto is_alive = [&sink]() { return sink.is_writable(); }; + state->task_ids = create_tasks(is_alive); + payload(state->task_ids, sink, is_alive); + ctx_server.queue_results.remove_waiting_task_ids(state->task_ids); + return false; + }, resource_releaser); + } +} + +static void respond(httplib::Response & res, httplib::DataSink * sink, int status, const json & response) { + res.status = status; + if (sink) { + res.set_header("Content-Type", MIMETYPE_JSON); + auto out = response.dump(-1, ' ', false, json::error_handler_t::replace); + sink->write(out.c_str(), out.size()); + } else { + res.set_content(response.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON); + } +} + +static void res_error(httplib::Response & res, httplib::DataSink * sink, const json & error_data) { + respond(res, sink, 200, {{"error", error_data}}); +} + +static void res_ok(httplib::Response & res, httplib::DataSink * sink, const json & data) { + respond(res, sink, 200, data); +} + int main(int argc, char ** argv) { // own arguments required by this example gpt_params params; @@ -2479,18 +2544,7 @@ int main(int argc, char ** argv) { svr->set_logger(log_server_request); - auto res_error = [](httplib::Response & res, const json & error_data) { - json final_response {{"error", error_data}}; - res.set_content(final_response.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON); - res.status = json_value(error_data, "code", 500); - }; - - auto res_ok = [](httplib::Response & res, const json & data) { - res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON); - res.status = 200; - }; - - svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) { + svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) { std::string message; try { std::rethrow_exception(ep); @@ -2502,12 +2556,12 @@ int main(int argc, char ** argv) { json formatted_error = format_error_response(message, ERROR_TYPE_SERVER); LOG_WRN("got exception: %s\n", formatted_error.dump().c_str()); - res_error(res, formatted_error); + res_error(res, /* sink= */ nullptr, formatted_error); }); - svr->set_error_handler([&res_error](const httplib::Request &, httplib::Response & res) { + svr->set_error_handler([](const httplib::Request &, httplib::Response & res) { if (res.status == 404) { - res_error(res, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND)); + res_error(res, /* sink= */ nullptr, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND)); } // for other error codes, we skip processing here because it's already done by res_error() }); @@ -2535,7 +2589,7 @@ int main(int argc, char ** argv) { // Middlewares // - auto middleware_validate_api_key = [¶ms, &res_error](const httplib::Request & req, httplib::Response & res) { + auto middleware_validate_api_key = [¶ms](const httplib::Request & req, httplib::Response & res) { // TODO: should we apply API key to all endpoints, including "/health" and "/models"? 
static const std::unordered_set protected_endpoints = { "/props", @@ -2574,14 +2628,14 @@ int main(int argc, char ** argv) { } // API key is invalid or not provided - res_error(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION)); + res_error(res, /* sink= */ nullptr, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION)); LOG_WRN("Unauthorized: Invalid API Key\n"); return false; }; - auto middleware_server_state = [&res_error, &state](const httplib::Request & req, httplib::Response & res) { + auto middleware_server_state = [&state](const httplib::Request & req, httplib::Response & res) { server_state current_state = state.load(); if (current_state == SERVER_STATE_LOADING_MODEL) { auto tmp = string_split(req.path, '.'); @@ -2589,7 +2643,7 @@ int main(int argc, char ** argv) { res.set_content(reinterpret_cast(loading_html), loading_html_len, "text/html; charset=utf-8"); res.status = 503; } else { - res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE)); + res_error(res, /* sink= */ nullptr, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE)); } return false; } @@ -2615,12 +2669,12 @@ int main(int argc, char ** argv) { const auto handle_health = [&](const httplib::Request &, httplib::Response & res) { // error and loading states are handled by middleware json health = {{"status", "ok"}}; - res_ok(res, health); + res_ok(res, /* sink= */ nullptr, health); }; const auto handle_slots = [&](const httplib::Request & req, httplib::Response & res) { if (!params.endpoint_slots) { - res_error(res, format_error_response("This server does not support slots endpoint. Start it without `--no-slots`", ERROR_TYPE_NOT_SUPPORTED)); + res_error(res, /* sink= */ nullptr, format_error_response("This server does not support slots endpoint. Start it without `--no-slots`", ERROR_TYPE_NOT_SUPPORTED)); return; } @@ -2640,17 +2694,17 @@ int main(int argc, char ** argv) { const int n_idle_slots = result.data.at("idle"); if (req.has_param("fail_on_no_slot")) { if (n_idle_slots == 0) { - res_error(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE)); + res_error(res, /* sink= */ nullptr, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE)); return; } } - res_ok(res, result.data.at("slots")); + res_ok(res, /* sink= */ nullptr, result.data.at("slots")); }; const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) { if (!params.endpoint_metrics) { - res_error(res, format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED)); + res_error(res, /* sink= */ nullptr, format_error_response("This server does not support metrics endpoint. 
Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED)); return; } @@ -2759,11 +2813,11 @@ int main(int argc, char ** argv) { res.status = 200; // HTTP OK }; - const auto handle_slots_save = [&ctx_server, &res_error, &res_ok, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) { + const auto handle_slots_save = [&ctx_server, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) { json request_data = json::parse(req.body); std::string filename = request_data.at("filename"); if (!fs_validate_filename(filename)) { - res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); + res_error(res, /* sink= */ nullptr, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); return; } std::string filepath = params.slot_save_path + filename; @@ -2783,17 +2837,17 @@ int main(int argc, char ** argv) { ctx_server.queue_results.remove_waiting_task_id(id_task); if (result.error) { - res_error(res, result.data); + res_error(res, /* sink= */ nullptr, result.data); } else { - res_ok(res, result.data); + res_ok(res, /* sink= */ nullptr, result.data); } }; - const auto handle_slots_restore = [&ctx_server, &res_error, &res_ok, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) { + const auto handle_slots_restore = [&ctx_server, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) { json request_data = json::parse(req.body); std::string filename = request_data.at("filename"); if (!fs_validate_filename(filename)) { - res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); + res_error(res, /* sink= */ nullptr, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); return; } std::string filepath = params.slot_save_path + filename; @@ -2813,13 +2867,13 @@ int main(int argc, char ** argv) { ctx_server.queue_results.remove_waiting_task_id(id_task); if (result.error) { - res_error(res, result.data); + res_error(res, /* sink= */ nullptr, result.data); } else { - res_ok(res, result.data); + res_ok(res, /* sink= */ nullptr, result.data); } }; - const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) { + const auto handle_slots_erase = [&ctx_server](const httplib::Request & /* req */, httplib::Response & res, int id_slot) { server_task task; task.type = SERVER_TASK_TYPE_SLOT_ERASE; task.data = { @@ -2833,15 +2887,15 @@ int main(int argc, char ** argv) { ctx_server.queue_results.remove_waiting_task_id(id_task); if (result.error) { - res_error(res, result.data); + res_error(res, /* sink= */ nullptr, result.data); } else { - res_ok(res, result.data); + res_ok(res, /* sink= */ nullptr, result.data); } }; - const auto handle_slots_action = [¶ms, &res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) { + const auto handle_slots_action = [¶ms, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) { if (params.slot_save_path.empty()) { - res_error(res, format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED)); + res_error(res, /* sink= */ nullptr, format_error_response("This server does not support slots action. 
Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED)); return; } @@ -2851,7 +2905,7 @@ int main(int argc, char ** argv) { try { id_slot = std::stoi(id_slot_str); } catch (const std::exception &) { - res_error(res, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST)); + res_error(res, /* sink= */ nullptr, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST)); return; } @@ -2864,11 +2918,11 @@ int main(int argc, char ** argv) { } else if (action == "erase") { handle_slots_erase(req, res, id_slot); } else { - res_error(res, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST)); + res_error(res, /* sink= */ nullptr, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST)); } }; - const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) { + const auto handle_props = [&ctx_server](const httplib::Request &, httplib::Response & res) { std::string template_key = "tokenizer.chat_template", curr_tmpl; int32_t tlen = llama_model_meta_val_str(ctx_server.model, template_key.c_str(), nullptr, 0); if (tlen > 0) { @@ -2884,57 +2938,49 @@ int main(int argc, char ** argv) { { "chat_template", curr_tmpl.c_str() }, }; - res_ok(res, data); + res_ok(res, /* sink= */ nullptr, data); }; - const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) { + const auto handle_completions_generic = [&ctx_server](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) { if (ctx_server.params.embedding || ctx_server.params.reranking) { - res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); + res_error(res, /* sink= */ nullptr, format_error_response("This server does not support completions. 
Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); return; } - std::vector tasks = ctx_server.create_tasks_cmpl(data, cmpl_type); - ctx_server.queue_results.add_waiting_tasks(tasks); - ctx_server.queue_tasks.post(tasks); - bool stream = json_value(data, "stream", false); - const auto task_ids = server_task::get_list_id(tasks); - if (!stream) { - ctx_server.receive_cmpl_results(task_ids, [&](std::vector & results) { - if (results.size() == 1) { - // single result - res_ok(res, results[0].data); - } else { - // multiple results (multitask) - json arr = json::array(); - for (const auto & res : results) { - arr.push_back(res.data); - } - res_ok(res, arr); - } - }, [&](const json & error_data) { - res_error(res, error_data); - }); + handle_tasks(stream, res, ctx_server, [data, cmpl_type, &ctx_server](const std::function & is_alive) { + std::vector tasks = ctx_server.create_tasks_cmpl(data, cmpl_type, is_alive); + ctx_server.queue_results.add_waiting_tasks(tasks); + ctx_server.queue_tasks.post(tasks); - ctx_server.queue_results.remove_waiting_task_ids(task_ids); - } else { - const auto chunked_content_provider = [task_ids, &ctx_server](size_t, httplib::DataSink & sink) { - ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool { + return server_task::get_list_id(tasks); + }, [stream, &res, &ctx_server](const std::unordered_set & task_ids, httplib::DataSink & sink, const std::function &) { + if (!stream) { + ctx_server.receive_cmpl_results(task_ids, [&res, &sink](std::vector & results) { + if (results.size() == 1) { + // single result + res_ok(res, &sink, results[0].data); + } else { + // multiple results (multitask) + json arr = json::array(); + for (const auto & res : results) { + arr.push_back(res.data); + } + res_ok(res, &sink, arr); + } + }, [&res, &sink](json error_data) { + res_error(res, &sink, error_data); + }); + } else { + ctx_server.receive_cmpl_results_stream(task_ids, [&sink](server_task_result result) -> bool { return server_sent_event(sink, "data", result.data); - }, [&](const json & error_data) { + }, [&sink](const json & error_data) { server_sent_event(sink, "error", error_data); }); sink.done(); - return false; - }; - - auto on_complete = [task_ids, &ctx_server] (bool) { - ctx_server.queue_results.remove_waiting_task_ids(task_ids); - }; - - res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); - } + } + }); }; const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) { @@ -2948,35 +2994,37 @@ int main(int argc, char ** argv) { }; // TODO: maybe merge this function with "handle_completions_generic" - const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error, &res_ok, verbose](const httplib::Request & req, httplib::Response & res) { + const auto handle_chat_completions = [&ctx_server, ¶ms, verbose](const httplib::Request & req, httplib::Response & res) { if (ctx_server.params.embedding || ctx_server.params.reranking) { - res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); + res_error(res, /* sink= */ nullptr, format_error_response("This server does not support completions. 
Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); return; } json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template); - std::vector tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL); - ctx_server.queue_results.add_waiting_tasks(tasks); - ctx_server.queue_tasks.post(tasks); - bool stream = json_value(data, "stream", false); - const auto task_ids = server_task::get_list_id(tasks); - const auto completion_id = gen_chatcmplid(); - - if (!stream) { - ctx_server.receive_cmpl_results(task_ids, [&](const std::vector & results) { - // multitask is never support in chat completion, there is only one result - json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose); - res_ok(res, result_oai); - }, [&](const json & error_data) { - res_error(res, error_data); - }); + + handle_tasks(stream, res, ctx_server, [data, &ctx_server](const std::function & is_alive) { + std::vector tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL, is_alive); + ctx_server.queue_results.add_waiting_tasks(tasks); + ctx_server.queue_tasks.post(tasks); - ctx_server.queue_results.remove_waiting_task_ids(task_ids); - } else { - const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) { - ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool { + return server_task::get_list_id(tasks); + }, [data, verbose, stream, &res, &ctx_server](const std::unordered_set & task_ids, httplib::DataSink & sink, const std::function & is_alive) { + const auto completion_id = gen_chatcmplid(); + if (!stream) { + ctx_server.receive_cmpl_results(task_ids, [completion_id, data, verbose, &sink, &res](std::vector & results) { + // multitask is never support in chat completion, there is only one result + json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose); + res_ok(res, &sink, result_oai); + }, [&res, &sink](json error_data) { + res_error(res, &sink, error_data); + }); + } else { + ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result result) -> bool { + if (!is_alive()) { + return false; // connection is closed + } std::vector result_array = format_partial_response_oaicompat(result.data, completion_id); for (auto & event_data : result_array) { if (event_data.empty()) { @@ -2993,15 +3041,8 @@ int main(int argc, char ** argv) { static const std::string ev_done = "data: [DONE]\n\n"; sink.write(ev_done.data(), ev_done.size()); sink.done(); - return true; - }; - - auto on_complete = [task_ids, &ctx_server] (bool) { - ctx_server.queue_results.remove_waiting_task_ids(task_ids); - }; - - res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); - } + } + }); }; const auto handle_models = [¶ms, &ctx_server](const httplib::Request &, httplib::Response & res) { @@ -3021,7 +3062,7 @@ int main(int argc, char ** argv) { res.set_content(models.dump(), MIMETYPE_JSON); }; - const auto handle_tokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) { + const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) { const json body = json::parse(req.body); json tokens_response = json::array(); @@ -3057,10 +3098,10 @@ int main(int argc, char ** argv) { } const json data = 
format_tokenizer_response(tokens_response); - res_ok(res, data); + res_ok(res, /* sink= */ nullptr, data); }; - const auto handle_detokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) { + const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) { const json body = json::parse(req.body); std::string content; @@ -3070,13 +3111,13 @@ int main(int argc, char ** argv) { } const json data = format_detokenized_response(content); - res_ok(res, data); + res_ok(res, /* sink= */ nullptr, data); }; - const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) { + const auto handle_embeddings = [&ctx_server](const httplib::Request & req, httplib::Response & res) { // TODO: somehow clean up this checks in the future if (!ctx_server.params.embedding || ctx_server.params.reranking) { - res_error(res, format_error_response("This server does not support embeddings. Start it with `--embeddings` and without `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); + res_error(res, /* sink= */ nullptr, format_error_response("This server does not support embeddings. Start it with `--embeddings` and without `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); return; } const json body = json::parse(req.body); @@ -3091,47 +3132,46 @@ int main(int argc, char ** argv) { // with "content", we only support single prompt prompt = std::vector{body.at("content")}; } else { - res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST)); + res_error(res, /* sink= */ nullptr, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST)); return; } - // create and queue the task - json responses = json::array(); - bool error = false; - { - std::vector tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING); + + handle_tasks(false, res, ctx_server, [prompt, &ctx_server](const std::function & is_alive) { + std::vector tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING, is_alive); ctx_server.queue_results.add_waiting_tasks(tasks); ctx_server.queue_tasks.post(tasks); - // get the result - std::unordered_set task_ids = server_task::get_list_id(tasks); + return server_task::get_list_id(tasks); + }, [is_openai, &ctx_server, &res, body](const std::unordered_set & task_ids, httplib::DataSink & sink, const std::function &) { + bool error = false; + json responses = json::array(); - ctx_server.receive_cmpl_results(task_ids, [&](std::vector & results) { + ctx_server.receive_cmpl_results(task_ids, [&responses](std::vector & results) { for (const auto & res : results) { responses.push_back(res.data); } - }, [&](const json & error_data) { - res_error(res, error_data); + }, [&res, &error](json error_data) { + res_error(res, /* sink= */ nullptr, error_data); error = true; }); - ctx_server.queue_results.remove_waiting_task_ids(task_ids); - } - - if (error) { - return; - } + if (error) { + return; + } - // write JSON response - json root = is_openai - ? format_embeddings_response_oaicompat(body, responses) - : responses[0]; - res_ok(res, root); + // write JSON response + json root = is_openai + ? 
format_embeddings_response_oaicompat(body, responses) + : responses[0]; + + res_ok(res, &sink, root); + }); }; - const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) { + const auto handle_rerank = [&ctx_server](const httplib::Request & req, httplib::Response & res) { if (!ctx_server.params.reranking) { - res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); + res_error(res, /* sink= */ nullptr, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); return; } const json body = json::parse(req.body); @@ -3149,17 +3189,17 @@ int main(int argc, char ** argv) { if (body.count("query") == 1) { query = body.at("query"); if (!query.is_string()) { - res_error(res, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST)); + res_error(res, /* sink= */ nullptr, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST)); return; } } else { - res_error(res, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST)); + res_error(res, /* sink= */ nullptr, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST)); return; } std::vector documents = json_value(body, "documents", std::vector()); if (documents.empty()) { - res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST)); + res_error(res, /* sink= */ nullptr, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST)); return; } @@ -3176,7 +3216,7 @@ int main(int argc, char ** argv) { json responses = json::array(); bool error = false; { - std::vector tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK); + std::vector tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK, []() { return true; }); ctx_server.queue_results.add_waiting_tasks(tasks); ctx_server.queue_tasks.post(tasks); @@ -3188,7 +3228,7 @@ int main(int argc, char ** argv) { responses.push_back(res.data); } }, [&](const json & error_data) { - res_error(res, error_data); + res_error(res, /* sink= */ nullptr, error_data); error = true; }); } @@ -3199,7 +3239,7 @@ int main(int argc, char ** argv) { // write JSON response json root = format_response_rerank(body, responses); - res_ok(res, root); + res_ok(res, /* sink= */ nullptr, root); }; const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) { @@ -3212,7 +3252,7 @@ int main(int argc, char ** argv) { {"scale", lora.scale}, }); } - res_ok(res, result); + res_ok(res, /* sink= */ nullptr, result); res.status = 200; // HTTP OK }; @@ -3244,7 +3284,7 @@ int main(int argc, char ** argv) { server_task_result result = ctx_server.queue_results.recv(id_task); ctx_server.queue_results.remove_waiting_task_id(id_task); - res_ok(res, result.data); + res_ok(res, /* sink= */ nullptr, result.data); res.status = 200; // HTTP OK }; From 4dcb3ea9439a36dcdf71db295d0c8b4fcbffc678 Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 29 Sep 2024 01:09:41 +0100 Subject: [PATCH 02/19] `tests`: allow artificial slowdown of sampling for tests --- common/arg.cpp | 7 +++++++ common/common.h | 2 ++ examples/server/server.cpp | 3 +++ examples/server/tests/features/steps/steps.py | 6 ++++++ 4 files changed, 18 insertions(+) diff --git a/common/arg.cpp b/common/arg.cpp index 
8266a16c261c5..1ae55b22c32f7 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1879,6 +1879,13 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex, params.slot_prompt_similarity = std::stof(value); } ).set_examples({LLAMA_EXAMPLE_SERVER})); + add_opt(llama_arg( + {"--testing-sampler-delay-millis"}, "N", + format("for tests: delay in milliseconds to add to each sampling (default: %d)", params.testing_sampler_delay_millis), + [](gpt_params & params, int value) { + params.testing_sampler_delay_millis = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER})); add_opt(llama_arg( {"--lora-init-without-apply"}, format("load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: %s)", params.lora_init_without_apply ? "enabled" : "disabled"), diff --git a/common/common.h b/common/common.h index 8b84cf9ad45ee..154d59846be62 100644 --- a/common/common.h +++ b/common/common.h @@ -299,6 +299,8 @@ struct gpt_params { float slot_prompt_similarity = 0.5f; + int testing_sampler_delay_millis = 0; + // batched-bench params bool is_pp_shared = false; diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 1ce4d7e26beaa..c308e23ca7985 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2348,6 +2348,9 @@ struct server_context { } completion_token_output result; + if (params.testing_sampler_delay_millis > 0) { + std::this_thread::sleep_for(std::chrono::milliseconds(params.testing_sampler_delay_millis)); + } const llama_token id = gpt_sampler_sample(slot.smpl, ctx, slot.i_batch - i); gpt_sampler_accept(slot.smpl, id, true); diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py index 2611614ba3633..31bfb0b2b152a 100644 --- a/examples/server/tests/features/steps/steps.py +++ b/examples/server/tests/features/steps/steps.py @@ -78,6 +78,7 @@ def step_server_config(context, server_fqdn: str, server_port: str): context.response_format = None context.temperature = None context.lora_file = None + context.testing_sampler_delay_millis = None context.disable_ctx_shift = False context.tasks_result = [] @@ -455,6 +456,9 @@ def step_impl(context, n_ga): def step_impl(context, n_ga_w): context.n_ga_w = n_ga_w +@step('{testing_sampler_delay_millis:d} milliseconds delay in sampler for testing') +def step_testing_sampler_delay_millis(context, testing_sampler_delay_millis): + context.testing_sampler_delay_millis = testing_sampler_delay_millis @step('a passkey prompt template') def step_prompt_passkey(context): @@ -1436,6 +1440,8 @@ def start_server_background(context): server_args.append('--verbose') if context.lora_file: server_args.extend(['--lora', context.lora_file]) + if context.testing_sampler_delay_millis: + server_args.extend(['--testing-sampler-delay-millis', context.testing_sampler_delay_millis]) if context.disable_ctx_shift: server_args.extend(['--no-context-shift']) From 5f00747a9042b68c1591ca11efd0942acbecfac9 Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 29 Sep 2024 01:10:18 +0100 Subject: [PATCH 03/19] `server`: test request cancellation (WIP) --- examples/server/tests/features/cancel.feature | 43 +++++++++++++++++++ examples/server/tests/features/steps/steps.py | 28 +++++++++++- 2 files changed, 69 insertions(+), 2 deletions(-) create mode 100644 examples/server/tests/features/cancel.feature diff --git a/examples/server/tests/features/cancel.feature b/examples/server/tests/features/cancel.feature new file mode 100644 index 0000000000000..54ded24c67c19 --- 
/dev/null +++ b/examples/server/tests/features/cancel.feature @@ -0,0 +1,43 @@ +@llama.cpp +@server +Feature: Cancellation of llama.cpp server requests + + Background: Server startup + Given a server listening on localhost:8080 + And 500 milliseconds delay in sampler for testing + And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models + And a model file test-model.gguf + And a model alias tinyllama-2 + And BOS token is 1 + And 42 as server seed + # KV Cache corresponds to the total amount of tokens + # that can be stored across all independent sequences: #4130 + # see --ctx-size and #5568 + And 256 KV cache size + And 32 as batch size + And 1 slots + And 64 server max tokens to predict + Then the server is starting + Then the server is healthy + + # Scenario: Health + # Then the server is ready + # And all slots are idle + + @wip + Scenario Outline: Cancelling completion request frees up slot + Given a prompt: + """ + Once upon + """ + And 256 max tokens to predict + And 256 server max tokens to predict + And streaming is + And a completion request cancelled after 100 milliseconds + # And wait for 50 milliseconds + Then all slots are idle + + Examples: Prompts + | enable_streaming | + | disabled | + | enabled | diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py index 31bfb0b2b152a..5bc4b06316351 100644 --- a/examples/server/tests/features/steps/steps.py +++ b/examples/server/tests/features/steps/steps.py @@ -291,6 +291,25 @@ async def step_request_completion(context, api_error: Literal['raised'] | str): api_error_code = int(api_error) assert completion == api_error_code, f"completion must be an {api_error_code} status code: {completion}" +@step('wait for {millis:d} milliseconds') +@async_run_until_complete +async def step_request_completion(context, millis: int): + await asyncio.sleep(millis / 1000.0) + +@step('a completion request cancelled after {disconnect_after_millis:d} milliseconds') +@async_run_until_complete +async def step_request_completion(context, disconnect_after_millis: int): + seeds = await completions_seed(context, num_seeds=1) + await request_completion(context.prompts.pop(), + seeds[0] if seeds is not None else seeds, + context.base_url, + debug=context.debug, + n_predict=context.n_predict, + cache_prompt=context.cache_prompt, + id_slot=context.id_slot, + disconnect_after_millis=disconnect_after_millis, + user_api_key=context.user_api_key, + temperature=context.temperature) @step('{predicted_n:d} tokens are predicted matching {re_content}') def step_n_tokens_predicted_with_content(context, predicted_n, re_content): @@ -982,9 +1001,10 @@ async def request_completion(prompt, id_slot=None, expect_api_error=None, user_api_key=None, + disconnect_after_millis=None, temperature=None) -> int | dict[str, Any]: if debug: - print(f"Sending completion request: {prompt}") + print(f"Sending completion request: {prompt} with n_predict={n_predict}") origin = "my.super.domain" headers = { 'Origin': origin @@ -1008,6 +1028,10 @@ async def request_completion(prompt, "n_probs": 2, }, headers=headers) as response: + if disconnect_after_millis is not None: + await asyncio.sleep(disconnect_after_millis / 1000) + return 0 + if expect_api_error is None or not expect_api_error: assert response.status == 200 assert response.headers['Access-Control-Allow-Origin'] == origin @@ -1352,7 +1376,7 @@ async def request_slots_status(context, expected_slots): def assert_slots_status(slots, expected_slots): - assert len(slots) == 
len(expected_slots) + assert len(slots) == len(expected_slots), f'invalid number of slots: {len(slots)} (actual) != {len(expected_slots)} (expected)' for slot_id, (expected, slot) in enumerate(zip(expected_slots, slots)): for key in expected: assert expected[key] == slot[key], (f"invalid slot {slot_id}" From 419e9952c9310e620454b5f666a73bff8fbfffcb Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 29 Sep 2024 17:11:53 +0100 Subject: [PATCH 04/19] `server`: rm superfluous is_alive check in streamed code --- examples/server/server.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index c308e23ca7985..3142470cfdafa 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -3025,9 +3025,6 @@ int main(int argc, char ** argv) { }); } else { ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result result) -> bool { - if (!is_alive()) { - return false; // connection is closed - } std::vector result_array = format_partial_response_oaicompat(result.data, completion_id); for (auto & event_data : result_array) { if (event_data.empty()) { From 88c9b5497a601068932b73845019910623688237 Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 29 Sep 2024 19:01:48 +0100 Subject: [PATCH 05/19] `server`: simplify handle_tasks signature --- examples/server/server.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 3142470cfdafa..66f6c49800842 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2431,7 +2431,7 @@ static void handle_tasks( httplib::Response & res, server_context & ctx_server, const std::function(const std::function &)> & create_tasks, - const std::function &, httplib::DataSink & sink, const std::function &)> & payload) + const std::function &, httplib::DataSink & sink)> & payload) { struct State { std::unordered_set task_ids; @@ -2446,7 +2446,7 @@ static void handle_tasks( res.set_content_provider(MIMETYPE_JSON, [create_tasks, payload, state, &ctx_server](size_t, httplib::DataSink & sink) { auto is_alive = [&sink]() { return sink.is_writable(); }; state->task_ids = create_tasks(is_alive); - payload(state->task_ids, sink, is_alive); + payload(state->task_ids, sink); ctx_server.queue_results.remove_waiting_task_ids(state->task_ids); return false; }, resource_releaser); @@ -2454,7 +2454,7 @@ static void handle_tasks( res.set_chunked_content_provider("text/event-stream", [create_tasks, payload, state, &ctx_server](size_t, httplib::DataSink & sink) { auto is_alive = [&sink]() { return sink.is_writable(); }; state->task_ids = create_tasks(is_alive); - payload(state->task_ids, sink, is_alive); + payload(state->task_ids, sink); ctx_server.queue_results.remove_waiting_task_ids(state->task_ids); return false; }, resource_releaser); @@ -2958,7 +2958,7 @@ int main(int argc, char ** argv) { ctx_server.queue_tasks.post(tasks); return server_task::get_list_id(tasks); - }, [stream, &res, &ctx_server](const std::unordered_set & task_ids, httplib::DataSink & sink, const std::function &) { + }, [stream, &res, &ctx_server](const std::unordered_set & task_ids, httplib::DataSink & sink) { if (!stream) { ctx_server.receive_cmpl_results(task_ids, [&res, &sink](std::vector & results) { if (results.size() == 1) { @@ -3013,7 +3013,7 @@ int main(int argc, char ** argv) { ctx_server.queue_tasks.post(tasks); return server_task::get_list_id(tasks); - }, [data, verbose, stream, &res, &ctx_server](const std::unordered_set & 
task_ids, httplib::DataSink & sink, const std::function & is_alive) { + }, [data, verbose, stream, &res, &ctx_server](const std::unordered_set & task_ids, httplib::DataSink & sink) { const auto completion_id = gen_chatcmplid(); if (!stream) { ctx_server.receive_cmpl_results(task_ids, [completion_id, data, verbose, &sink, &res](std::vector & results) { @@ -3143,7 +3143,7 @@ int main(int argc, char ** argv) { ctx_server.queue_tasks.post(tasks); return server_task::get_list_id(tasks); - }, [is_openai, &ctx_server, &res, body](const std::unordered_set & task_ids, httplib::DataSink & sink, const std::function &) { + }, [is_openai, &ctx_server, &res, body](const std::unordered_set & task_ids, httplib::DataSink & sink) { bool error = false; json responses = json::array(); From 3f96ab04a6da3ee3b766d3a7d957fee2696910bd Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 29 Sep 2024 19:12:59 +0100 Subject: [PATCH 06/19] `server`: fix cancel tests --- examples/server/server.cpp | 7 +-- examples/server/tests/features/cancel.feature | 44 +++++++++++++------ examples/server/tests/features/steps/steps.py | 34 +++++++------- 3 files changed, 53 insertions(+), 32 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 66f6c49800842..0869e4623a24f 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2349,6 +2349,7 @@ struct server_context { completion_token_output result; if (params.testing_sampler_delay_millis > 0) { + LOG_DBG("sleeping for %dms before sampling (for tests!)\n", params.testing_sampler_delay_millis); std::this_thread::sleep_for(std::chrono::milliseconds(params.testing_sampler_delay_millis)); } const llama_token id = gpt_sampler_sample(slot.smpl, ctx, slot.i_batch - i); @@ -3006,7 +3007,7 @@ int main(int argc, char ** argv) { json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template); bool stream = json_value(data, "stream", false); - + handle_tasks(stream, res, ctx_server, [data, &ctx_server](const std::function & is_alive) { std::vector tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL, is_alive); ctx_server.queue_results.add_waiting_tasks(tasks); @@ -3136,7 +3137,7 @@ int main(int argc, char ** argv) { return; } - + handle_tasks(false, res, ctx_server, [prompt, &ctx_server](const std::function & is_alive) { std::vector tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING, is_alive); ctx_server.queue_results.add_waiting_tasks(tasks); @@ -3164,7 +3165,7 @@ int main(int argc, char ** argv) { json root = is_openai ? 
format_embeddings_response_oaicompat(body, responses) : responses[0]; - + res_ok(res, &sink, root); }); }; diff --git a/examples/server/tests/features/cancel.feature b/examples/server/tests/features/cancel.feature index 54ded24c67c19..241507024eeed 100644 --- a/examples/server/tests/features/cancel.feature +++ b/examples/server/tests/features/cancel.feature @@ -4,7 +4,6 @@ Feature: Cancellation of llama.cpp server requests Background: Server startup Given a server listening on localhost:8080 - And 500 milliseconds delay in sampler for testing And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models And a model file test-model.gguf And a model alias tinyllama-2 @@ -13,28 +12,45 @@ Feature: Cancellation of llama.cpp server requests # KV Cache corresponds to the total amount of tokens # that can be stored across all independent sequences: #4130 # see --ctx-size and #5568 - And 256 KV cache size + And 512 KV cache size And 32 as batch size - And 1 slots + And 2 slots And 64 server max tokens to predict + And prometheus compatible metrics exposed + And 300 milliseconds delay in sampler for testing + And no warmup Then the server is starting Then the server is healthy + # Then the server is healthy with timeout 10 seconds - # Scenario: Health - # Then the server is ready - # And all slots are idle - @wip - Scenario Outline: Cancelling completion request frees up slot - Given a prompt: - """ - Once upon - """ + Scenario Outline: Cancelling an OAI chat completion request frees up slot (streaming ) + Given a model llama-2 + And a user prompt Once upon a time + And a system prompt You tell lengthy stories And 256 max tokens to predict And 256 server max tokens to predict And streaming is - And a completion request cancelled after 100 milliseconds - # And wait for 50 milliseconds + And disconnect after 100 milliseconds + Given concurrent OAI completions requests + And wait for 700 milliseconds + Then all slots are idle + + Examples: Prompts + | enable_streaming | + | disabled | + | enabled | + + + Scenario Outline: Cancelling a completion request frees up slot (streaming ) + Given a model llama-2 + Given a prompt Once upon a time + And 256 max tokens to predict + And 256 server max tokens to predict + And streaming is + And disconnect after 100 milliseconds + Given a completion request with no api error + And wait for 700 milliseconds Then all slots are idle Examples: Prompts diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py index 5bc4b06316351..561fc03ffd173 100644 --- a/examples/server/tests/features/steps/steps.py +++ b/examples/server/tests/features/steps/steps.py @@ -80,6 +80,7 @@ def step_server_config(context, server_fqdn: str, server_port: str): context.lora_file = None context.testing_sampler_delay_millis = None context.disable_ctx_shift = False + context.disconnect_after_millis = None context.tasks_result = [] context.concurrent_tasks = [] @@ -279,6 +280,7 @@ async def step_request_completion(context, api_error: Literal['raised'] | str): n_predict=context.n_predict, cache_prompt=context.cache_prompt, id_slot=context.id_slot, + disconnect_after_millis=context.disconnect_after_millis, expect_api_error=expect_api_error, user_api_key=context.user_api_key, temperature=context.temperature) @@ -296,20 +298,12 @@ async def step_request_completion(context, api_error: Literal['raised'] | str): async def step_request_completion(context, millis: int): await asyncio.sleep(millis / 1000.0) -@step('a completion request cancelled 
after {disconnect_after_millis:d} milliseconds') + +@step('disconnect after {disconnect_after_millis:d} milliseconds') @async_run_until_complete -async def step_request_completion(context, disconnect_after_millis: int): - seeds = await completions_seed(context, num_seeds=1) - await request_completion(context.prompts.pop(), - seeds[0] if seeds is not None else seeds, - context.base_url, - debug=context.debug, - n_predict=context.n_predict, - cache_prompt=context.cache_prompt, - id_slot=context.id_slot, - disconnect_after_millis=disconnect_after_millis, - user_api_key=context.user_api_key, - temperature=context.temperature) +async def step_disconnect_after(context, disconnect_after_millis: int): + context.disconnect_after_millis = disconnect_after_millis + @step('{predicted_n:d} tokens are predicted matching {re_content}') def step_n_tokens_predicted_with_content(context, predicted_n, re_content): @@ -519,6 +513,7 @@ async def step_oai_chat_completions(context, api_error): print(f"Submitting OAI compatible completions request...") expect_api_error = api_error == 'raised' seeds = await completions_seed(context, num_seeds=1), + seeds = await completions_seed(context, num_seeds=1) completion = await oai_chat_completions(context.prompts.pop(), seeds[0] if seeds is not None else seeds, context.system_prompt, @@ -539,6 +534,8 @@ async def step_oai_chat_completions(context, api_error): user_api_key=context.user_api_key if hasattr(context, 'user_api_key') else None, + disconnect_after_millis=context.disconnect_after_millis, + expect_api_error=expect_api_error) context.tasks_result.append(completion) if context.debug: @@ -606,6 +603,7 @@ async def step_oai_chat_completions(context): if hasattr(context, 'enable_streaming') else None, response_format=context.response_format if hasattr(context, 'response_format') else None, + disconnect_after_millis=context.disconnect_after_millis, user_api_key=context.user_api_key if hasattr(context, 'user_api_key') else None) @@ -1029,9 +1027,9 @@ async def request_completion(prompt, }, headers=headers) as response: if disconnect_after_millis is not None: - await asyncio.sleep(disconnect_after_millis / 1000) + await asyncio.sleep(disconnect_after_millis / 1000.0) return 0 - + if expect_api_error is None or not expect_api_error: assert response.status == 200 assert response.headers['Access-Control-Allow-Origin'] == origin @@ -1050,6 +1048,7 @@ async def oai_chat_completions(user_prompt, temperature=None, model=None, n_predict=None, + disconnect_after_millis=None, enable_streaming=None, response_format=None, user_api_key=None, @@ -1093,6 +1092,10 @@ async def oai_chat_completions(user_prompt, async with session.post(f'{base_url}{base_path}', json=payload, headers=headers) as response: + if disconnect_after_millis is not None: + await asyncio.sleep(disconnect_after_millis / 1000.0) + return 0 + if enable_streaming: assert response.status == 200 assert response.headers['Access-Control-Allow-Origin'] == origin @@ -1133,6 +1136,7 @@ async def oai_chat_completions(user_prompt, else: return response.status else: + assert disconnect_after_millis is None, "disconnect_after_millis is not supported with sync client" try: openai.api_key = user_api_key openai.base_url = f'{base_url}{base_path.removesuffix("chat")}' From 231a5e4914d5dd4e2b2b304778834870f12bba3f Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 29 Sep 2024 19:13:27 +0100 Subject: [PATCH 07/19] `server`: fix seed in tests (comma creates a tuple) --- examples/server/tests/features/steps/steps.py | 1 - 1 file changed, 1 
deletion(-) diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py index 561fc03ffd173..bfba17198583e 100644 --- a/examples/server/tests/features/steps/steps.py +++ b/examples/server/tests/features/steps/steps.py @@ -512,7 +512,6 @@ async def step_oai_chat_completions(context, api_error): if context.debug: print(f"Submitting OAI compatible completions request...") expect_api_error = api_error == 'raised' - seeds = await completions_seed(context, num_seeds=1), seeds = await completions_seed(context, num_seeds=1) completion = await oai_chat_completions(context.prompts.pop(), seeds[0] if seeds is not None else seeds, From c5a0d57ee5013a194ea6c4ebced14895cac8bb11 Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 29 Sep 2024 19:37:23 +0100 Subject: [PATCH 08/19] Update cancel.feature --- examples/server/tests/features/cancel.feature | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/server/tests/features/cancel.feature b/examples/server/tests/features/cancel.feature index 241507024eeed..e7753b5dd06ef 100644 --- a/examples/server/tests/features/cancel.feature +++ b/examples/server/tests/features/cancel.feature @@ -18,10 +18,8 @@ Feature: Cancellation of llama.cpp server requests And 64 server max tokens to predict And prometheus compatible metrics exposed And 300 milliseconds delay in sampler for testing - And no warmup Then the server is starting Then the server is healthy - # Then the server is healthy with timeout 10 seconds Scenario Outline: Cancelling an OAI chat completion request frees up slot (streaming ) From 0e9c4bf5af7e4d04fda2f4c8ac6a0b6b3d49c57d Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 29 Sep 2024 21:02:18 +0100 Subject: [PATCH 09/19] `server`: update log --- examples/server/server.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 0869e4623a24f..a766f17f684e4 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2349,7 +2349,7 @@ struct server_context { completion_token_output result; if (params.testing_sampler_delay_millis > 0) { - LOG_DBG("sleeping for %dms before sampling (for tests!)\n", params.testing_sampler_delay_millis); + SRV_DBG("sleeping for %dms before sampling (for tests!)\n", params.testing_sampler_delay_millis); std::this_thread::sleep_for(std::chrono::milliseconds(params.testing_sampler_delay_millis)); } const llama_token id = gpt_sampler_sample(slot.smpl, ctx, slot.i_batch - i); From d07387ca9ceb3818db53e5e1fe897286a7952d80 Mon Sep 17 00:00:00 2001 From: ochafik Date: Sun, 29 Sep 2024 21:04:27 +0100 Subject: [PATCH 10/19] `server`: speed up cancel test setup --- examples/server/tests/features/cancel.feature | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/server/tests/features/cancel.feature b/examples/server/tests/features/cancel.feature index e7753b5dd06ef..7112367808451 100644 --- a/examples/server/tests/features/cancel.feature +++ b/examples/server/tests/features/cancel.feature @@ -12,7 +12,7 @@ Feature: Cancellation of llama.cpp server requests # KV Cache corresponds to the total amount of tokens # that can be stored across all independent sequences: #4130 # see --ctx-size and #5568 - And 512 KV cache size + And 256 KV cache size And 32 as batch size And 2 slots And 64 server max tokens to predict From 43e306e08fc0eb10c1bc2ce7b8ee22bca7846e51 Mon Sep 17 00:00:00 2001 From: ochafik Date: Fri, 4 Oct 2024 04:26:11 +0100 Subject: [PATCH 11/19] server: fix error status 
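res_error() previously reported HTTP status 200 on error responses built through the new respond() helper; set 500 instead.

A minimal check, as a sketch only (not part of the patch): it assumes a patched server listening on localhost:8080 and reuses the aiohttp dependency of the existing test suite; the /v1/embeddings route and the empty request body are simply a convenient way to trigger an error response.

    import asyncio
    import aiohttp

    async def check_error_status():
        async with aiohttp.ClientSession() as session:
            # An empty body triggers an error: "not supported" when the server
            # runs without --embeddings, or a missing-"input" validation error
            # when it does. Either way the patched res_error() should answer 500
            # with a JSON body carrying an "error" object.
            async with session.post("http://localhost:8080/v1/embeddings", json={}) as response:
                assert response.status == 500
                body = await response.json()
                assert "error" in body

    asyncio.run(check_error_status())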
--- examples/server/server.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index a766f17f684e4..7972e556574b9 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2474,7 +2474,7 @@ static void respond(httplib::Response & res, httplib::DataSink * sink, int statu } static void res_error(httplib::Response & res, httplib::DataSink * sink, const json & error_data) { - respond(res, sink, 200, {{"error", error_data}}); + respond(res, sink, 500, {{"error", error_data}}); } static void res_ok(httplib::Response & res, httplib::DataSink * sink, const json & data) { From 42f546500fe4eba7c68d272f1abfa875cc708c79 Mon Sep 17 00:00:00 2001 From: ochafik Date: Fri, 4 Oct 2024 05:18:39 +0100 Subject: [PATCH 12/19] `server`: introduce supposedly lighterweight is_alive in httplib (https://github.com/yhirose/cpp-httplib/pull/1956) --- examples/server/httplib.h | 18 ++++++++++++++++++ examples/server/server.cpp | 4 ++-- tests/.DS_Store | Bin 0 -> 8196 bytes 3 files changed, 20 insertions(+), 2 deletions(-) create mode 100644 tests/.DS_Store diff --git a/examples/server/httplib.h b/examples/server/httplib.h index f360bd93ea098..025946180eaac 100644 --- a/examples/server/httplib.h +++ b/examples/server/httplib.h @@ -457,6 +457,7 @@ class DataSink { std::function write; std::function is_writable; + std::function is_alive; std::function done; std::function done_with_trailer; std::ostream os; @@ -639,6 +640,7 @@ class Stream { virtual bool is_readable() const = 0; virtual bool is_writable() const = 0; + virtual bool is_alive() const = 0; virtual ssize_t read(char *ptr, size_t size) = 0; virtual ssize_t write(const char *ptr, size_t size) = 0; @@ -2135,6 +2137,7 @@ class BufferStream final : public Stream { bool is_readable() const override; bool is_writable() const override; + bool is_alive() const override; ssize_t read(char *ptr, size_t size) override; ssize_t write(const char *ptr, size_t size) override; void get_remote_ip_and_port(std::string &ip, int &port) const override; @@ -2945,6 +2948,7 @@ class SocketStream final : public Stream { bool is_readable() const override; bool is_writable() const override; + bool is_alive() const override; ssize_t read(char *ptr, size_t size) override; ssize_t write(const char *ptr, size_t size) override; void get_remote_ip_and_port(std::string &ip, int &port) const override; @@ -2975,6 +2979,7 @@ class SSLSocketStream final : public Stream { bool is_readable() const override; bool is_writable() const override; + bool is_alive() const override; ssize_t read(char *ptr, size_t size) override; ssize_t write(const char *ptr, size_t size) override; void get_remote_ip_and_port(std::string &ip, int &port) const override; @@ -4088,6 +4093,7 @@ inline bool write_content(Stream &strm, const ContentProvider &content_provider, }; data_sink.is_writable = [&]() -> bool { return strm.is_writable(); }; + data_sink.is_alive = [&]() -> bool { return strm.is_alive(); }; while (offset < end_offset && !is_shutting_down()) { if (!strm.is_writable()) { @@ -4134,6 +4140,7 @@ write_content_without_length(Stream &strm, }; data_sink.is_writable = [&]() -> bool { return strm.is_writable(); }; + data_sink.is_alive = [&]() -> bool { return strm.is_alive(); }; data_sink.done = [&](void) { data_available = false; }; @@ -4186,6 +4193,7 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider, }; data_sink.is_writable = [&]() -> bool { return strm.is_writable(); }; + data_sink.is_alive = 
[&]() -> bool { return strm.is_alive(); }; auto done_with_trailer = [&](const Headers *trailer) { if (!ok) { return; } @@ -5484,6 +5492,10 @@ inline bool SocketStream::is_writable() const { is_socket_alive(sock_); } +inline bool SocketStream::is_alive() const { + return is_socket_alive(sock_); +} + inline ssize_t SocketStream::read(char *ptr, size_t size) { #ifdef _WIN32 size = @@ -5558,6 +5570,8 @@ inline bool BufferStream::is_readable() const { return true; } inline bool BufferStream::is_writable() const { return true; } +inline bool BufferStream::is_alive() const { return true; } + inline ssize_t BufferStream::read(char *ptr, size_t size) { #if defined(_MSC_VER) && _MSC_VER < 1910 auto len_read = buffer._Copy_s(ptr, size, size, position); @@ -8348,6 +8362,10 @@ inline bool SSLSocketStream::is_writable() const { is_socket_alive(sock_); } +inline bool SSLSocketStream::is_alive() const { + return is_socket_alive(sock_); +} + inline ssize_t SSLSocketStream::read(char *ptr, size_t size) { if (SSL_pending(ssl_) > 0) { return SSL_read(ssl_, ptr, static_cast(size)); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 7972e556574b9..8f8a91ad5bf5f 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2445,7 +2445,7 @@ static void handle_tasks( }; if (!stream) { res.set_content_provider(MIMETYPE_JSON, [create_tasks, payload, state, &ctx_server](size_t, httplib::DataSink & sink) { - auto is_alive = [&sink]() { return sink.is_writable(); }; + auto is_alive = [&sink]() { return sink.is_alive(); }; state->task_ids = create_tasks(is_alive); payload(state->task_ids, sink); ctx_server.queue_results.remove_waiting_task_ids(state->task_ids); @@ -2453,7 +2453,7 @@ static void handle_tasks( }, resource_releaser); } else { res.set_chunked_content_provider("text/event-stream", [create_tasks, payload, state, &ctx_server](size_t, httplib::DataSink & sink) { - auto is_alive = [&sink]() { return sink.is_writable(); }; + auto is_alive = [&sink]() { return sink.is_alive(); }; state->task_ids = create_tasks(is_alive); payload(state->task_ids, sink); ctx_server.queue_results.remove_waiting_task_ids(state->task_ids); diff --git a/tests/.DS_Store b/tests/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..9baef261ea77e8aec20a8434c24c60399e395110 GIT binary patch literal 8196 zcmeHMO>fgc5S>j^lBhz-fr?6Su!Oj^kVX|B;(!U^$N`K%2vD%&sC99?A$G_i3i1V^ zaO1{L;K+>wcYZW=z8+=I z3a|qIr2_PRu&@!@22+jd)qzeO0TA;TR)+lwbr3Csw!u^*j-Ux$il|G4IbsN1j&aN6 z*#=XMx*UW#dl1^C$7sVaf7JcQw4$F#*~Nm zg26tc*sb;{2KOS8g|-gq=u^ zH3MPZ0&x5{Mz6AT{J0f*xI9YU;`q$zMWKvVF-GxuNT3Zd`XyALk0_yskr11U@iv?T z9~`(3H^bFsrzJaGtOD`*^^5&gcVweszKs-X-@bqNboJx4{m$SYa)$0&mcctBQqJ)2 zKmZY}Bk!r4VSM`O?iE>{@tfxxUv4R$yWFQEe{tJ;RhAEoZQ^rk8y|_!j-KTbHi^7^ z?6nXh^c0`UQ~Wsl^58RN`Om(Uo6q@x>&ohK_}Q|0<;q`4uRj_6J}hj6w!u^*g(%MZ zKMBFc3LHTNX7rjD==^_U=kNbVkQpzT6<`JSs({EhoJJk9)vi9JigtMm+hc5Gl$UB$ kCFr*OA%NFg|1iY9g{f3y8%#A~1dTrgC>hwX0)JJ3pI46gjsO4v literal 0 HcmV?d00001 From 03efb92fdedf3bf604aed0c78d3fbf8e5368c020 Mon Sep 17 00:00:00 2001 From: ochafik Date: Fri, 4 Oct 2024 13:22:57 +0100 Subject: [PATCH 13/19] `server`: support cancellation of prompt processing --- examples/server/server.cpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 8f8a91ad5bf5f..c4dde5347edd3 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1877,6 +1877,13 @@ struct server_context { system_prompt_update(); } + for (auto & slot : slots) { + if 
(slot.is_processing() && slot.is_alive && !slot.is_alive()) { + SLT_WRN(slot, "%s", "slot connection died\n"); + slot.release(); + } + } + // check if all slots are idle { bool all_idle = true; From 16ff5022145428a60159c7f58843e792e5b77aa3 Mon Sep 17 00:00:00 2001 From: ochafik Date: Fri, 4 Oct 2024 13:24:15 +0100 Subject: [PATCH 14/19] `server`: mime nit --- examples/server/server.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index c4dde5347edd3..bee140b2142f4 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -12,6 +12,7 @@ #include "json.hpp" // mime type for sending response #define MIMETYPE_JSON "application/json; charset=utf-8" +#define MIMETYPE_EVENT_STREAM "text/event-stream" // auto generated files (update with ./deps.sh) #include "colorthemes.css.hpp" @@ -2459,7 +2460,7 @@ static void handle_tasks( return false; }, resource_releaser); } else { - res.set_chunked_content_provider("text/event-stream", [create_tasks, payload, state, &ctx_server](size_t, httplib::DataSink & sink) { + res.set_chunked_content_provider(MIMETYPE_EVENT_STREAM, [create_tasks, payload, state, &ctx_server](size_t, httplib::DataSink & sink) { auto is_alive = [&sink]() { return sink.is_alive(); }; state->task_ids = create_tasks(is_alive); payload(state->task_ids, sink); From 6297d9fec1653599acdc8bfdc5a52cee89fb1d75 Mon Sep 17 00:00:00 2001 From: ochafik Date: Fri, 4 Oct 2024 18:45:09 +0100 Subject: [PATCH 15/19] =?UTF-8?q?`server`:=20avoid=20calling=20sink.is=5Fa?= =?UTF-8?q?live()=20after=20it=20died=20=F0=9F=A7=9F=E2=80=8D=E2=99=82?= =?UTF-8?q?=EF=B8=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- examples/server/server.cpp | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index bee140b2142f4..93c6c43a8f535 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1121,7 +1121,7 @@ struct server_context { } bool process_token(completion_token_output & result, server_slot & slot) { - if (!slot.is_alive()) { + if (slot.is_alive && !slot.is_alive()) { slot.truncated = false; slot.has_next_token = false; @@ -2444,16 +2444,20 @@ static void handle_tasks( { struct State { std::unordered_set task_ids; + bool is_sink_valid = true; }; auto state = std::make_shared(); httplib::ContentProviderResourceReleaser resource_releaser = [state, &ctx_server](bool success) { + state->is_sink_valid = false; if (!success && state) { ctx_server.cancel_tasks(state->task_ids); } }; if (!stream) { res.set_content_provider(MIMETYPE_JSON, [create_tasks, payload, state, &ctx_server](size_t, httplib::DataSink & sink) { - auto is_alive = [&sink]() { return sink.is_alive(); }; + auto is_alive = [state, &sink]() { + return state->is_sink_valid && sink.is_alive(); + }; state->task_ids = create_tasks(is_alive); payload(state->task_ids, sink); ctx_server.queue_results.remove_waiting_task_ids(state->task_ids); @@ -2461,7 +2465,9 @@ static void handle_tasks( }, resource_releaser); } else { res.set_chunked_content_provider(MIMETYPE_EVENT_STREAM, [create_tasks, payload, state, &ctx_server](size_t, httplib::DataSink & sink) { - auto is_alive = [&sink]() { return sink.is_alive(); }; + auto is_alive = [state, &sink]() { + return state->is_sink_valid && sink.is_alive(); + }; state->task_ids = create_tasks(is_alive); payload(state->task_ids, sink); 
ctx_server.queue_results.remove_waiting_task_ids(state->task_ids); From 2dc708c72ad8f34c3b3c87265062feafcf4e1321 Mon Sep 17 00:00:00 2001 From: ochafik Date: Fri, 4 Oct 2024 19:32:46 +0100 Subject: [PATCH 16/19] Delete tests/.DS_Store --- tests/.DS_Store | Bin 8196 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 tests/.DS_Store diff --git a/tests/.DS_Store b/tests/.DS_Store deleted file mode 100644 index 9baef261ea77e8aec20a8434c24c60399e395110..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 8196 zcmeHMO>fgc5S>j^lBhz-fr?6Su!Oj^kVX|B;(!U^$N`K%2vD%&sC99?A$G_i3i1V^ zaO1{L;K+>wcYZW=z8+=I z3a|qIr2_PRu&@!@22+jd)qzeO0TA;TR)+lwbr3Csw!u^*j-Ux$il|G4IbsN1j&aN6 z*#=XMx*UW#dl1^C$7sVaf7JcQw4$F#*~Nm zg26tc*sb;{2KOS8g|-gq=u^ zH3MPZ0&x5{Mz6AT{J0f*xI9YU;`q$zMWKvVF-GxuNT3Zd`XyALk0_yskr11U@iv?T z9~`(3H^bFsrzJaGtOD`*^^5&gcVweszKs-X-@bqNboJx4{m$SYa)$0&mcctBQqJ)2 zKmZY}Bk!r4VSM`O?iE>{@tfxxUv4R$yWFQEe{tJ;RhAEoZQ^rk8y|_!j-KTbHi^7^ z?6nXh^c0`UQ~Wsl^58RN`Om(Uo6q@x>&ohK_}Q|0<;q`4uRj_6J}hj6w!u^*g(%MZ zKMBFc3LHTNX7rjD==^_U=kNbVkQpzT6<`JSs({EhoJJk9)vi9JigtMm+hc5Gl$UB$ kCFr*OA%NFg|1iY9g{f3y8%#A~1dTrgC>hwX0)JJ3pI46gjsO4v From 6f693f14b0ee9cc525c62dc3cca85af1f5da36a5 Mon Sep 17 00:00:00 2001 From: ochafik Date: Fri, 4 Oct 2024 23:51:58 +0100 Subject: [PATCH 17/19] `server`: use (new) Request::is_alive as set_content_provider called after status / headers sent --- examples/server/httplib.h | 7 +- examples/server/server.cpp | 326 +++++++++++++++++-------------------- 2 files changed, 149 insertions(+), 184 deletions(-) diff --git a/examples/server/httplib.h b/examples/server/httplib.h index 025946180eaac..05ee81a088ed7 100644 --- a/examples/server/httplib.h +++ b/examples/server/httplib.h @@ -457,7 +457,6 @@ class DataSink { std::function write; std::function is_writable; - std::function is_alive; std::function done; std::function done_with_trailer; std::ostream os; @@ -591,6 +590,7 @@ struct Response { Headers headers; std::string body; std::string location; // Redirect location + std::function is_alive; bool has_header(const std::string &key) const; std::string get_header_value(const std::string &key, size_t id = 0) const; @@ -4093,7 +4093,6 @@ inline bool write_content(Stream &strm, const ContentProvider &content_provider, }; data_sink.is_writable = [&]() -> bool { return strm.is_writable(); }; - data_sink.is_alive = [&]() -> bool { return strm.is_alive(); }; while (offset < end_offset && !is_shutting_down()) { if (!strm.is_writable()) { @@ -4140,7 +4139,6 @@ write_content_without_length(Stream &strm, }; data_sink.is_writable = [&]() -> bool { return strm.is_writable(); }; - data_sink.is_alive = [&]() -> bool { return strm.is_alive(); }; data_sink.done = [&](void) { data_available = false; }; @@ -4193,7 +4191,6 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider, }; data_sink.is_writable = [&]() -> bool { return strm.is_writable(); }; - data_sink.is_alive = [&]() -> bool { return strm.is_alive(); }; auto done_with_trailer = [&](const Headers *trailer) { if (!ok) { return; } @@ -4287,6 +4284,7 @@ inline bool redirect(T &cli, Request &req, Response &res, } Response new_res; + new_res.is_alive = res.is_alive; auto ret = cli.send(new_req, new_res, error); if (ret) { @@ -6648,6 +6646,7 @@ Server::process_request(Stream &strm, bool close_connection, Request req; Response res; + res.is_alive = [&strm]() { return strm.is_alive(); }; res.version = "HTTP/1.1"; res.headers = default_headers_; diff --git a/examples/server/server.cpp 
b/examples/server/server.cpp index 93c6c43a8f535..880862aac3305 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -12,7 +12,6 @@ #include "json.hpp" // mime type for sending response #define MIMETYPE_JSON "application/json; charset=utf-8" -#define MIMETYPE_EVENT_STREAM "text/event-stream" // auto generated files (update with ./deps.sh) #include "colorthemes.css.hpp" @@ -34,7 +33,6 @@ #include #include -#include #include #include #include @@ -2435,66 +2433,6 @@ inline void signal_handler(int signal) { shutdown_handler(signal); } -static void handle_tasks( - bool stream, - httplib::Response & res, - server_context & ctx_server, - const std::function(const std::function &)> & create_tasks, - const std::function &, httplib::DataSink & sink)> & payload) -{ - struct State { - std::unordered_set task_ids; - bool is_sink_valid = true; - }; - auto state = std::make_shared(); - httplib::ContentProviderResourceReleaser resource_releaser = [state, &ctx_server](bool success) { - state->is_sink_valid = false; - if (!success && state) { - ctx_server.cancel_tasks(state->task_ids); - } - }; - if (!stream) { - res.set_content_provider(MIMETYPE_JSON, [create_tasks, payload, state, &ctx_server](size_t, httplib::DataSink & sink) { - auto is_alive = [state, &sink]() { - return state->is_sink_valid && sink.is_alive(); - }; - state->task_ids = create_tasks(is_alive); - payload(state->task_ids, sink); - ctx_server.queue_results.remove_waiting_task_ids(state->task_ids); - return false; - }, resource_releaser); - } else { - res.set_chunked_content_provider(MIMETYPE_EVENT_STREAM, [create_tasks, payload, state, &ctx_server](size_t, httplib::DataSink & sink) { - auto is_alive = [state, &sink]() { - return state->is_sink_valid && sink.is_alive(); - }; - state->task_ids = create_tasks(is_alive); - payload(state->task_ids, sink); - ctx_server.queue_results.remove_waiting_task_ids(state->task_ids); - return false; - }, resource_releaser); - } -} - -static void respond(httplib::Response & res, httplib::DataSink * sink, int status, const json & response) { - res.status = status; - if (sink) { - res.set_header("Content-Type", MIMETYPE_JSON); - auto out = response.dump(-1, ' ', false, json::error_handler_t::replace); - sink->write(out.c_str(), out.size()); - } else { - res.set_content(response.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON); - } -} - -static void res_error(httplib::Response & res, httplib::DataSink * sink, const json & error_data) { - respond(res, sink, 500, {{"error", error_data}}); -} - -static void res_ok(httplib::Response & res, httplib::DataSink * sink, const json & data) { - respond(res, sink, 200, data); -} - int main(int argc, char ** argv) { // own arguments required by this example gpt_params params; @@ -2562,7 +2500,18 @@ int main(int argc, char ** argv) { svr->set_logger(log_server_request); - svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) { + auto res_error = [](httplib::Response & res, const json & error_data) { + json final_response {{"error", error_data}}; + res.set_content(final_response.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON); + res.status = json_value(error_data, "code", 500); + }; + + auto res_ok = [](httplib::Response & res, const json & data) { + res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON); + res.status = 200; + }; + + svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, 
std::exception_ptr ep) { std::string message; try { std::rethrow_exception(ep); @@ -2574,12 +2523,12 @@ int main(int argc, char ** argv) { json formatted_error = format_error_response(message, ERROR_TYPE_SERVER); LOG_WRN("got exception: %s\n", formatted_error.dump().c_str()); - res_error(res, /* sink= */ nullptr, formatted_error); + res_error(res, formatted_error); }); - svr->set_error_handler([](const httplib::Request &, httplib::Response & res) { + svr->set_error_handler([&res_error](const httplib::Request &, httplib::Response & res) { if (res.status == 404) { - res_error(res, /* sink= */ nullptr, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND)); + res_error(res, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND)); } // for other error codes, we skip processing here because it's already done by res_error() }); @@ -2607,7 +2556,7 @@ int main(int argc, char ** argv) { // Middlewares // - auto middleware_validate_api_key = [¶ms](const httplib::Request & req, httplib::Response & res) { + auto middleware_validate_api_key = [¶ms, &res_error](const httplib::Request & req, httplib::Response & res) { // TODO: should we apply API key to all endpoints, including "/health" and "/models"? static const std::unordered_set protected_endpoints = { "/props", @@ -2646,14 +2595,14 @@ int main(int argc, char ** argv) { } // API key is invalid or not provided - res_error(res, /* sink= */ nullptr, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION)); + res_error(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION)); LOG_WRN("Unauthorized: Invalid API Key\n"); return false; }; - auto middleware_server_state = [&state](const httplib::Request & req, httplib::Response & res) { + auto middleware_server_state = [&res_error, &state](const httplib::Request & req, httplib::Response & res) { server_state current_state = state.load(); if (current_state == SERVER_STATE_LOADING_MODEL) { auto tmp = string_split(req.path, '.'); @@ -2661,7 +2610,7 @@ int main(int argc, char ** argv) { res.set_content(reinterpret_cast(loading_html), loading_html_len, "text/html; charset=utf-8"); res.status = 503; } else { - res_error(res, /* sink= */ nullptr, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE)); + res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE)); } return false; } @@ -2687,12 +2636,12 @@ int main(int argc, char ** argv) { const auto handle_health = [&](const httplib::Request &, httplib::Response & res) { // error and loading states are handled by middleware json health = {{"status", "ok"}}; - res_ok(res, /* sink= */ nullptr, health); + res_ok(res, health); }; const auto handle_slots = [&](const httplib::Request & req, httplib::Response & res) { if (!params.endpoint_slots) { - res_error(res, /* sink= */ nullptr, format_error_response("This server does not support slots endpoint. Start it without `--no-slots`", ERROR_TYPE_NOT_SUPPORTED)); + res_error(res, format_error_response("This server does not support slots endpoint. 
Start it without `--no-slots`", ERROR_TYPE_NOT_SUPPORTED)); return; } @@ -2712,17 +2661,17 @@ int main(int argc, char ** argv) { const int n_idle_slots = result.data.at("idle"); if (req.has_param("fail_on_no_slot")) { if (n_idle_slots == 0) { - res_error(res, /* sink= */ nullptr, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE)); + res_error(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE)); return; } } - res_ok(res, /* sink= */ nullptr, result.data.at("slots")); + res_ok(res, result.data.at("slots")); }; const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) { if (!params.endpoint_metrics) { - res_error(res, /* sink= */ nullptr, format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED)); + res_error(res, format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED)); return; } @@ -2831,11 +2780,11 @@ int main(int argc, char ** argv) { res.status = 200; // HTTP OK }; - const auto handle_slots_save = [&ctx_server, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) { + const auto handle_slots_save = [&ctx_server, &res_error, &res_ok, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) { json request_data = json::parse(req.body); std::string filename = request_data.at("filename"); if (!fs_validate_filename(filename)) { - res_error(res, /* sink= */ nullptr, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); + res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); return; } std::string filepath = params.slot_save_path + filename; @@ -2855,17 +2804,17 @@ int main(int argc, char ** argv) { ctx_server.queue_results.remove_waiting_task_id(id_task); if (result.error) { - res_error(res, /* sink= */ nullptr, result.data); + res_error(res, result.data); } else { - res_ok(res, /* sink= */ nullptr, result.data); + res_ok(res, result.data); } }; - const auto handle_slots_restore = [&ctx_server, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) { + const auto handle_slots_restore = [&ctx_server, &res_error, &res_ok, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) { json request_data = json::parse(req.body); std::string filename = request_data.at("filename"); if (!fs_validate_filename(filename)) { - res_error(res, /* sink= */ nullptr, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); + res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); return; } std::string filepath = params.slot_save_path + filename; @@ -2885,13 +2834,13 @@ int main(int argc, char ** argv) { ctx_server.queue_results.remove_waiting_task_id(id_task); if (result.error) { - res_error(res, /* sink= */ nullptr, result.data); + res_error(res, result.data); } else { - res_ok(res, /* sink= */ nullptr, result.data); + res_ok(res, result.data); } }; - const auto handle_slots_erase = [&ctx_server](const httplib::Request & /* req */, httplib::Response & res, int id_slot) { + const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) { server_task task; task.type = SERVER_TASK_TYPE_SLOT_ERASE; task.data = { @@ -2905,15 +2854,15 @@ int main(int argc, char ** argv) { ctx_server.queue_results.remove_waiting_task_id(id_task); if (result.error) { - res_error(res, /* sink= */ nullptr, 
result.data); + res_error(res, result.data); } else { - res_ok(res, /* sink= */ nullptr, result.data); + res_ok(res, result.data); } }; - const auto handle_slots_action = [¶ms, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) { + const auto handle_slots_action = [¶ms, &res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) { if (params.slot_save_path.empty()) { - res_error(res, /* sink= */ nullptr, format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED)); + res_error(res, format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED)); return; } @@ -2923,7 +2872,7 @@ int main(int argc, char ** argv) { try { id_slot = std::stoi(id_slot_str); } catch (const std::exception &) { - res_error(res, /* sink= */ nullptr, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST)); + res_error(res, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST)); return; } @@ -2936,11 +2885,11 @@ int main(int argc, char ** argv) { } else if (action == "erase") { handle_slots_erase(req, res, id_slot); } else { - res_error(res, /* sink= */ nullptr, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST)); + res_error(res, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST)); } }; - const auto handle_props = [&ctx_server](const httplib::Request &, httplib::Response & res) { + const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) { std::string template_key = "tokenizer.chat_template", curr_tmpl; int32_t tlen = llama_model_meta_val_str(ctx_server.model, template_key.c_str(), nullptr, 0); if (tlen > 0) { @@ -2956,49 +2905,57 @@ int main(int argc, char ** argv) { { "chat_template", curr_tmpl.c_str() }, }; - res_ok(res, /* sink= */ nullptr, data); + res_ok(res, data); }; - const auto handle_completions_generic = [&ctx_server](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) { + const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) { if (ctx_server.params.embedding || ctx_server.params.reranking) { - res_error(res, /* sink= */ nullptr, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); + res_error(res, format_error_response("This server does not support completions. 
Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); return; } - bool stream = json_value(data, "stream", false); + std::vector tasks = ctx_server.create_tasks_cmpl(data, cmpl_type, res.is_alive); + ctx_server.queue_results.add_waiting_tasks(tasks); + ctx_server.queue_tasks.post(tasks); - handle_tasks(stream, res, ctx_server, [data, cmpl_type, &ctx_server](const std::function & is_alive) { - std::vector tasks = ctx_server.create_tasks_cmpl(data, cmpl_type, is_alive); - ctx_server.queue_results.add_waiting_tasks(tasks); - ctx_server.queue_tasks.post(tasks); + bool stream = json_value(data, "stream", false); + const auto task_ids = server_task::get_list_id(tasks); - return server_task::get_list_id(tasks); - }, [stream, &res, &ctx_server](const std::unordered_set & task_ids, httplib::DataSink & sink) { - if (!stream) { - ctx_server.receive_cmpl_results(task_ids, [&res, &sink](std::vector & results) { - if (results.size() == 1) { - // single result - res_ok(res, &sink, results[0].data); - } else { - // multiple results (multitask) - json arr = json::array(); - for (const auto & res : results) { - arr.push_back(res.data); - } - res_ok(res, &sink, arr); + if (!stream) { + ctx_server.receive_cmpl_results(task_ids, [&](std::vector & results) { + if (results.size() == 1) { + // single result + res_ok(res, results[0].data); + } else { + // multiple results (multitask) + json arr = json::array(); + for (const auto & res : results) { + arr.push_back(res.data); } - }, [&res, &sink](json error_data) { - res_error(res, &sink, error_data); - }); - } else { - ctx_server.receive_cmpl_results_stream(task_ids, [&sink](server_task_result result) -> bool { + res_ok(res, arr); + } + }, [&](const json & error_data) { + res_error(res, error_data); + }); + + ctx_server.queue_results.remove_waiting_task_ids(task_ids); + } else { + const auto chunked_content_provider = [task_ids, &ctx_server](size_t, httplib::DataSink & sink) { + ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool { return server_sent_event(sink, "data", result.data); - }, [&sink](const json & error_data) { + }, [&](const json & error_data) { server_sent_event(sink, "error", error_data); }); sink.done(); - } - }); + return false; + }; + + auto on_complete = [task_ids, &ctx_server] (bool) { + ctx_server.queue_results.remove_waiting_task_ids(task_ids); + }; + + res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); + } }; const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) { @@ -3012,34 +2969,35 @@ int main(int argc, char ** argv) { }; // TODO: maybe merge this function with "handle_completions_generic" - const auto handle_chat_completions = [&ctx_server, ¶ms, verbose](const httplib::Request & req, httplib::Response & res) { + const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error, &res_ok, verbose](const httplib::Request & req, httplib::Response & res) { if (ctx_server.params.embedding || ctx_server.params.reranking) { - res_error(res, /* sink= */ nullptr, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); + res_error(res, format_error_response("This server does not support completions. 
Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); return; } json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template); - bool stream = json_value(data, "stream", false); + std::vector tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL, res.is_alive); + ctx_server.queue_results.add_waiting_tasks(tasks); + ctx_server.queue_tasks.post(tasks); - handle_tasks(stream, res, ctx_server, [data, &ctx_server](const std::function & is_alive) { - std::vector tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL, is_alive); - ctx_server.queue_results.add_waiting_tasks(tasks); - ctx_server.queue_tasks.post(tasks); + bool stream = json_value(data, "stream", false); + const auto task_ids = server_task::get_list_id(tasks); + const auto completion_id = gen_chatcmplid(); + + if (!stream) { + ctx_server.receive_cmpl_results(task_ids, [&](const std::vector & results) { + // multitask is never support in chat completion, there is only one result + json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose); + res_ok(res, result_oai); + }, [&](const json & error_data) { + res_error(res, error_data); + }); - return server_task::get_list_id(tasks); - }, [data, verbose, stream, &res, &ctx_server](const std::unordered_set & task_ids, httplib::DataSink & sink) { - const auto completion_id = gen_chatcmplid(); - if (!stream) { - ctx_server.receive_cmpl_results(task_ids, [completion_id, data, verbose, &sink, &res](std::vector & results) { - // multitask is never support in chat completion, there is only one result - json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose); - res_ok(res, &sink, result_oai); - }, [&res, &sink](json error_data) { - res_error(res, &sink, error_data); - }); - } else { - ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result result) -> bool { + ctx_server.queue_results.remove_waiting_task_ids(task_ids); + } else { + const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) { + ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool { std::vector result_array = format_partial_response_oaicompat(result.data, completion_id); for (auto & event_data : result_array) { if (event_data.empty()) { @@ -3056,8 +3014,15 @@ int main(int argc, char ** argv) { static const std::string ev_done = "data: [DONE]\n\n"; sink.write(ev_done.data(), ev_done.size()); sink.done(); - } - }); + return true; + }; + + auto on_complete = [task_ids, &ctx_server] (bool) { + ctx_server.queue_results.remove_waiting_task_ids(task_ids); + }; + + res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); + } }; const auto handle_models = [¶ms, &ctx_server](const httplib::Request &, httplib::Response & res) { @@ -3077,7 +3042,7 @@ int main(int argc, char ** argv) { res.set_content(models.dump(), MIMETYPE_JSON); }; - const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) { + const auto handle_tokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) { const json body = json::parse(req.body); json tokens_response = json::array(); @@ -3113,10 +3078,10 @@ int main(int argc, char ** argv) { } const json data = format_tokenizer_response(tokens_response); - res_ok(res, /* sink= */ 
nullptr, data); + res_ok(res, data); }; - const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) { + const auto handle_detokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) { const json body = json::parse(req.body); std::string content; @@ -3126,13 +3091,13 @@ int main(int argc, char ** argv) { } const json data = format_detokenized_response(content); - res_ok(res, /* sink= */ nullptr, data); + res_ok(res, data); }; - const auto handle_embeddings = [&ctx_server](const httplib::Request & req, httplib::Response & res) { + const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) { // TODO: somehow clean up this checks in the future if (!ctx_server.params.embedding || ctx_server.params.reranking) { - res_error(res, /* sink= */ nullptr, format_error_response("This server does not support embeddings. Start it with `--embeddings` and without `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); + res_error(res, format_error_response("This server does not support embeddings. Start it with `--embeddings` and without `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); return; } const json body = json::parse(req.body); @@ -3147,46 +3112,47 @@ int main(int argc, char ** argv) { // with "content", we only support single prompt prompt = std::vector{body.at("content")}; } else { - res_error(res, /* sink= */ nullptr, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST)); + res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST)); return; } - - handle_tasks(false, res, ctx_server, [prompt, &ctx_server](const std::function & is_alive) { - std::vector tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING, is_alive); + // create and queue the task + json responses = json::array(); + bool error = false; + { + std::vector tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING, res.is_alive); ctx_server.queue_results.add_waiting_tasks(tasks); ctx_server.queue_tasks.post(tasks); - return server_task::get_list_id(tasks); - }, [is_openai, &ctx_server, &res, body](const std::unordered_set & task_ids, httplib::DataSink & sink) { - bool error = false; - json responses = json::array(); + // get the result + std::unordered_set task_ids = server_task::get_list_id(tasks); - ctx_server.receive_cmpl_results(task_ids, [&responses](std::vector & results) { + ctx_server.receive_cmpl_results(task_ids, [&](std::vector & results) { for (const auto & res : results) { responses.push_back(res.data); } - }, [&res, &error](json error_data) { - res_error(res, /* sink= */ nullptr, error_data); + }, [&](const json & error_data) { + res_error(res, error_data); error = true; }); - if (error) { - return; - } + ctx_server.queue_results.remove_waiting_task_ids(task_ids); + } - // write JSON response - json root = is_openai - ? format_embeddings_response_oaicompat(body, responses) - : responses[0]; + if (error) { + return; + } - res_ok(res, &sink, root); - }); + // write JSON response + json root = is_openai + ? 
format_embeddings_response_oaicompat(body, responses) + : responses[0]; + res_ok(res, root); }; - const auto handle_rerank = [&ctx_server](const httplib::Request & req, httplib::Response & res) { + const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) { if (!ctx_server.params.reranking) { - res_error(res, /* sink= */ nullptr, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); + res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); return; } const json body = json::parse(req.body); @@ -3204,17 +3170,17 @@ int main(int argc, char ** argv) { if (body.count("query") == 1) { query = body.at("query"); if (!query.is_string()) { - res_error(res, /* sink= */ nullptr, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST)); + res_error(res, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST)); return; } } else { - res_error(res, /* sink= */ nullptr, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST)); + res_error(res, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST)); return; } std::vector documents = json_value(body, "documents", std::vector()); if (documents.empty()) { - res_error(res, /* sink= */ nullptr, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST)); + res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST)); return; } @@ -3231,7 +3197,7 @@ int main(int argc, char ** argv) { json responses = json::array(); bool error = false; { - std::vector tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK, []() { return true; }); + std::vector tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK, res.is_alive); ctx_server.queue_results.add_waiting_tasks(tasks); ctx_server.queue_tasks.post(tasks); @@ -3243,7 +3209,7 @@ int main(int argc, char ** argv) { responses.push_back(res.data); } }, [&](const json & error_data) { - res_error(res, /* sink= */ nullptr, error_data); + res_error(res, error_data); error = true; }); } @@ -3254,7 +3220,7 @@ int main(int argc, char ** argv) { // write JSON response json root = format_response_rerank(body, responses); - res_ok(res, /* sink= */ nullptr, root); + res_ok(res, root); }; const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) { @@ -3267,7 +3233,7 @@ int main(int argc, char ** argv) { {"scale", lora.scale}, }); } - res_ok(res, /* sink= */ nullptr, result); + res_ok(res, result); res.status = 200; // HTTP OK }; @@ -3299,7 +3265,7 @@ int main(int argc, char ** argv) { server_task_result result = ctx_server.queue_results.recv(id_task); ctx_server.queue_results.remove_waiting_task_id(id_task); - res_ok(res, /* sink= */ nullptr, result.data); + res_ok(res, result.data); res.status = 200; // HTTP OK }; @@ -3454,4 +3420,4 @@ int main(int argc, char ** argv) { t.join(); return 0; -} +} \ No newline at end of file From 52c5a6244f76607abee2447a2f2e498ef92b1101 Mon Sep 17 00:00:00 2001 From: ochafik Date: Sat, 5 Oct 2024 00:44:13 +0100 Subject: [PATCH 18/19] `server`: fix disconnection logic in test (before post response headers) --- examples/server/tests/features/steps/steps.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 
deletions(-) diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py index bfba17198583e..cc3107b2b773f 100644 --- a/examples/server/tests/features/steps/steps.py +++ b/examples/server/tests/features/steps/steps.py @@ -1012,6 +1012,9 @@ async def request_completion(prompt, headers['Authorization'] = f'Bearer {user_api_key}' async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: + if disconnect_after_millis is not None: + await asyncio.sleep(disconnect_after_millis / 1000.0) + return 0 async with session.post(f'{base_url}/completion', json={ "input_prefix": prompt_prefix, @@ -1025,10 +1028,6 @@ async def request_completion(prompt, "n_probs": 2, }, headers=headers) as response: - if disconnect_after_millis is not None: - await asyncio.sleep(disconnect_after_millis / 1000.0) - return 0 - if expect_api_error is None or not expect_api_error: assert response.status == 200 assert response.headers['Access-Control-Allow-Origin'] == origin @@ -1088,13 +1087,12 @@ async def oai_chat_completions(user_prompt, origin = 'llama.cpp' headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin} async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: + if disconnect_after_millis is not None: + await asyncio.sleep(disconnect_after_millis / 1000.0) + return 0 async with session.post(f'{base_url}{base_path}', json=payload, headers=headers) as response: - if disconnect_after_millis is not None: - await asyncio.sleep(disconnect_after_millis / 1000.0) - return 0 - if enable_streaming: assert response.status == 200 assert response.headers['Access-Control-Allow-Origin'] == origin From e41a5403a0d1a305480d961abce467d03bb5013a Mon Sep 17 00:00:00 2001 From: ochafik Date: Sat, 5 Oct 2024 01:15:48 +0100 Subject: [PATCH 19/19] reinstate trailing \n --- examples/server/server.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 880862aac3305..01998eabe61a0 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -3420,4 +3420,4 @@ int main(int argc, char ** argv) { t.join(); return 0; -} \ No newline at end of file +}
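
The mechanism built up across these patches reduces to threading a liveness predicate from the HTTP stream (strm.is_alive(), surfaced as res.is_alive) through create_tasks_cmpl into the slot, then checking it between sampled tokens in process_token and once per update_slots pass so that slots whose connection has died are released even during prompt processing. The following is a minimal, self-contained sketch of that pattern only; Connection, Task, and run_generation are illustrative stand-ins and are not names from the server code.

    #include <atomic>
    #include <chrono>
    #include <functional>
    #include <iostream>
    #include <thread>

    // Illustrative stand-in for the HTTP connection whose socket may die mid-request.
    struct Connection {
        std::atomic<bool> open{true};
    };

    // Illustrative stand-in for a queued task: it carries the liveness predicate
    // captured at request time, to be queried by the worker between tokens.
    struct Task {
        std::function<bool()> is_alive;
    };

    // Worker loop: stop producing tokens as soon as the client goes away.
    static void run_generation(const Task & task, int n_predict) {
        for (int i = 0; i < n_predict; ++i) {
            if (task.is_alive && !task.is_alive()) {
                std::cout << "client disconnected, aborting after " << i << " tokens\n";
                return;
            }
            // ... sample and emit one token here ...
            std::this_thread::sleep_for(std::chrono::milliseconds(10));
        }
        std::cout << "generation finished normally\n";
    }

    int main() {
        Connection conn;
        Task task;
        // The predicate closes over the connection, mirroring how res.is_alive
        // closes over the underlying stream in the patches above.
        task.is_alive = [&conn]() { return conn.open.load(); };

        std::thread worker(run_generation, std::cref(task), 1000);

        // Simulate the client dropping the connection mid-generation.
        std::this_thread::sleep_for(std::chrono::milliseconds(100));
        conn.open = false;

        worker.join();
        return 0;
    }

Checking the predicate once per token (and once per update_slots iteration for slots still decoding the prompt) keeps the per-step cost to a single indirect call, while still freeing the slot promptly after the peer socket dies instead of generating the remaining n_predict tokens for nobody.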