diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 25bc296396772..f6a01a7dd302a 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -2282,6 +2282,7 @@ struct server_context {
             {"size",       llama_model_size    (model)},
         };
     }
+
 };
 
 static void server_print_usage(const char * argv0, const gpt_params & params, const server_params & sparams) {
@@ -3009,42 +3010,58 @@ int main(int argc, char ** argv) {
         log_data["api_key"] = "api_key: " + std::to_string(sparams.api_keys.size()) + " keys loaded";
     }
 
-    // load the model
-    if (!ctx_server.load_model(params)) {
-        state.store(SERVER_STATE_ERROR);
-        return 1;
-    } else {
-        ctx_server.init();
-        state.store(SERVER_STATE_READY);
-    }
+    const auto handle_health = [&](const httplib::Request & req, httplib::Response & res) {
+        server_state current_state = state.load();
+        switch (current_state) {
+            case SERVER_STATE_READY:
+                {
+                    // request slots data using task queue
+                    server_task task;
+                    task.id   = ctx_server.queue_tasks.get_new_id();
+                    task.type = SERVER_TASK_TYPE_METRICS;
+                    task.id_target = -1;
 
-    LOG_INFO("model loaded", {});
+                    ctx_server.queue_results.add_waiting_task_id(task.id);
+                    ctx_server.queue_tasks.post(task);
 
-    const auto model_meta = ctx_server.model_meta();
+                    // get the result
+                    server_task_result result = ctx_server.queue_results.recv(task.id);
+                    ctx_server.queue_results.remove_waiting_task_id(task.id);
 
-    // if a custom chat template is not supplied, we will use the one that comes with the model (if any)
-    if (sparams.chat_template.empty()) {
-        if (!ctx_server.validate_model_chat_template()) {
-            LOG_ERROR("The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
-            sparams.chat_template = "chatml";
-        }
-    }
+                    const int n_idle_slots       = result.data["idle"];
+                    const int n_processing_slots = result.data["processing"];
 
-    // print sample chat example to make it clear which template is used
-    {
-        json chat;
-        chat.push_back({{"role", "system"},    {"content", "You are a helpful assistant"}});
-        chat.push_back({{"role", "user"},      {"content", "Hello"}});
-        chat.push_back({{"role", "assistant"}, {"content", "Hi there"}});
-        chat.push_back({{"role", "user"},      {"content", "How are you?"}});
+                    json health = {
+                        {"status",           "ok"},
+                        {"slots_idle",       n_idle_slots},
+                        {"slots_processing", n_processing_slots}
+                    };
 
-        const std::string chat_example = format_chat(ctx_server.model, sparams.chat_template, chat);
+                    res.status = 200; // HTTP OK
+                    if (sparams.slots_endpoint && req.has_param("include_slots")) {
+                        health["slots"] = result.data["slots"];
+                    }
 
-        LOG_INFO("chat template", {
-            {"chat_example", chat_example},
-            {"built_in", sparams.chat_template.empty()},
-        });
-    }
+                    if (n_idle_slots == 0) {
+                        health["status"] = "no slot available";
+                        if (req.has_param("fail_on_no_slot")) {
+                            res.status = 503; // HTTP Service Unavailable
+                        }
+                    }
+
+                    res.set_content(health.dump(), "application/json");
+                    break;
+                }
+            case SERVER_STATE_LOADING_MODEL:
+                {
+                    res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
+                } break;
+            case SERVER_STATE_ERROR:
+                {
+                    res_error(res, format_error_response("Model failed to load", ERROR_TYPE_SERVER));
+                } break;
+        }
+    };
     //
     // Middlewares
     //
@@ -3098,71 +3115,43 @@ int main(int argc, char ** argv) {
         return false;
     };
 
-    // register server middlewares
-    svr->set_pre_routing_handler([&middleware_validate_api_key](const httplib::Request & req, httplib::Response & res) {
-        if (!middleware_validate_api_key(req, res)) {
-            return httplib::Server::HandlerResponse::Handled;
-        }
-        return httplib::Server::HandlerResponse::Unhandled;
-    });
+    auto middleware_model_loading = [&sparams, &res_error](const httplib::Request & req, httplib::Response & res, server_state current_state) {
 
-    //
-    // Route handlers (or controllers)
-    //
+        // If path is not an health check skip validation
+        if (req.path == "/health" || req.path == "/v1/health") {
+            return true;
+        }
 
-    const auto handle_health = [&](const httplib::Request & req, httplib::Response & res) {
-        server_state current_state = state.load();
         switch (current_state) {
-            case SERVER_STATE_READY:
-                {
-                    // request slots data using task queue
-                    server_task task;
-                    task.id   = ctx_server.queue_tasks.get_new_id();
-                    task.type = SERVER_TASK_TYPE_METRICS;
-                    task.id_target = -1;
-
-                    ctx_server.queue_results.add_waiting_task_id(task.id);
-                    ctx_server.queue_tasks.post(task);
-
-                    // get the result
-                    server_task_result result = ctx_server.queue_results.recv(task.id);
-                    ctx_server.queue_results.remove_waiting_task_id(task.id);
-
-                    const int n_idle_slots       = result.data["idle"];
-                    const int n_processing_slots = result.data["processing"];
-
-                    json health = {
-                        {"status",           "ok"},
-                        {"slots_idle",       n_idle_slots},
-                        {"slots_processing", n_processing_slots}
-                    };
-
-                    res.status = 200; // HTTP OK
-                    if (sparams.slots_endpoint && req.has_param("include_slots")) {
-                        health["slots"] = result.data["slots"];
-                    }
-
-                    if (n_idle_slots == 0) {
-                        health["status"] = "no slot available";
-                        if (req.has_param("fail_on_no_slot")) {
-                            res.status = 503; // HTTP Service Unavailable
-                        }
-                    }
-
-                    res.set_content(health.dump(), "application/json");
-                    break;
-                }
             case SERVER_STATE_LOADING_MODEL:
                 {
                     res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
+                    return false;
                 } break;
             case SERVER_STATE_ERROR:
                 {
                     res_error(res, format_error_response("Model failed to load", ERROR_TYPE_SERVER));
+                    return false;
                 } break;
+            default:
+                break;
         }
+        return true;
     };
 
+    // register server middlewares
+    svr->set_pre_routing_handler([&middleware_validate_api_key, &state, &middleware_model_loading](const httplib::Request & req, httplib::Response & res) {
+        server_state current_state = state.load();
+        if (!middleware_model_loading(req, res, current_state) || !middleware_validate_api_key(req, res)) {
+            return httplib::Server::HandlerResponse::Handled;
+        }
+        return httplib::Server::HandlerResponse::Unhandled;
+    });
+
+    //
+    // Route handlers (or controllers)
+    //
+
     const auto handle_slots = [&](const httplib::Request &, httplib::Response & res) {
         if (!sparams.slots_endpoint) {
             res_error(res, format_error_response("This server does not support slots endpoint.", ERROR_TYPE_NOT_SUPPORTED));
@@ -3482,25 +3471,6 @@ int main(int argc, char ** argv) {
         }
     };
 
-    const auto handle_models = [&params, &model_meta](const httplib::Request & req, httplib::Response & res) {
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-
-        json models = {
-            {"object", "list"},
-            {"data", {
-                 {
-                     {"id",       params.model_alias},
-                     {"object",   "model"},
-                     {"created",  std::time(0)},
-                     {"owned_by", "llamacpp"},
-                     {"meta",     model_meta}
-                 },
-            }}
-        };
-
-        res.set_content(models.dump(), "application/json; charset=utf-8");
-    };
-
     const auto handle_chat_completions = [&ctx_server, &sparams, &res_error](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), sparams.chat_template);
@@ -3714,33 +3684,31 @@ int main(int argc, char ** argv) {
                 : responses[0];
         return res.set_content(root.dump(), "application/json; charset=utf-8");
     };
+
+    const auto handle_models = [&params, &ctx_server](const httplib::Request & req, httplib::Response & res) {
+        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
 
-    auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
-        return [content, len, mime_type](const httplib::Request &, httplib::Response & res) {
-            res.set_content(reinterpret_cast<const char*>(content), len, mime_type);
-            return false;
-        };
-    };
-
-    //
-    // Router
-    //
+        json model_meta = ctx_server.model_meta();
 
-    // register static assets routes
-    if (!sparams.public_path.empty()) {
-        // Set the base directory for serving static files
-        svr->set_base_dir(sparams.public_path);
-    }
+        json models = {
+            {"object", "list"},
+            {"data", {
+                 {
+                     {"id",       params.model_alias},
+                     {"object",   "model"},
+                     {"created",  std::time(0)},
+                     {"owned_by", "llamacpp"},
+                     {"meta",     model_meta}
+                 },
+            }}
+        };
 
-    // using embedded static files
-    svr->Get("/", handle_static_file(index_html, index_html_len, "text/html; charset=utf-8"));
-    svr->Get("/index.js", handle_static_file(index_js, index_js_len, "text/javascript; charset=utf-8"));
-    svr->Get("/completion.js", handle_static_file(completion_js, completion_js_len, "text/javascript; charset=utf-8"));
-    svr->Get("/json-schema-to-grammar.mjs", handle_static_file(
-            json_schema_to_grammar_mjs, json_schema_to_grammar_mjs_len, "text/javascript; charset=utf-8"));
+        res.set_content(models.dump(), "application/json; charset=utf-8");
+    };
 
     // register API routes
     svr->Get ("/health",              handle_health);
+    svr->Get ("/v1/health",           handle_health);
     svr->Get ("/slots",               handle_slots);
    svr->Get ("/metrics",             handle_metrics);
     svr->Get ("/props",               handle_props);
@@ -3756,11 +3724,31 @@ int main(int argc, char ** argv) {
     svr->Post("/v1/embeddings",       handle_embeddings);
     svr->Post("/tokenize",            handle_tokenize);
     svr->Post("/detokenize",          handle_detokenize);
+
     if (!sparams.slot_save_path.empty()) {
         // only enable slot endpoints if slot_save_path is set
         svr->Post("/slots/:id_slot",  handle_slots_action);
     }
 
+    auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
+        return [content, len, mime_type](const httplib::Request &, httplib::Response & res) {
+            res.set_content(reinterpret_cast<const char*>(content), len, mime_type);
+            return false;
+        };
+    };
+
+    if (!sparams.public_path.empty()) {
+        // Set the base directory for serving static files
+        svr->set_base_dir(sparams.public_path);
+    }
+
+    // using embedded static files
+    svr->Get("/", handle_static_file(index_html, index_html_len, "text/html; charset=utf-8"));
+    svr->Get("/index.js", handle_static_file(index_js, index_js_len, "text/javascript; charset=utf-8"));
+    svr->Get("/completion.js", handle_static_file(completion_js, completion_js_len, "text/javascript; charset=utf-8"));
+    svr->Get("/json-schema-to-grammar.mjs", handle_static_file(
+            json_schema_to_grammar_mjs, json_schema_to_grammar_mjs_len, "text/javascript; charset=utf-8"));
+
     //
     // Start the server
     //
@@ -3782,6 +3770,41 @@ int main(int argc, char ** argv) {
         return 0;
     });
+
+    // load the model
+    if (!ctx_server.load_model(params)) {
+        state.store(SERVER_STATE_ERROR);
+        return 1;
+    } else {
+        ctx_server.init();
+        state.store(SERVER_STATE_READY);
+    }
+
+    LOG_INFO("model loaded", {});
+
+    // if a custom chat template is not supplied, we will use the one that comes with the model (if any)
+    if (sparams.chat_template.empty()) {
+        if (!ctx_server.validate_model_chat_template()) {
+            LOG_ERROR("The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
+            sparams.chat_template = "chatml";
+        }
+    }
+
+    // print sample chat example to make it clear which template is used
+    {
+        json chat;
+        chat.push_back({{"role", "system"},    {"content", "You are a helpful assistant"}});
+        chat.push_back({{"role", "user"},      {"content", "Hello"}});
+        chat.push_back({{"role", "assistant"}, {"content", "Hi there"}});
+        chat.push_back({{"role", "user"},      {"content", "How are you?"}});
+
+        const std::string chat_example = format_chat(ctx_server.model, sparams.chat_template, chat);
+
+        LOG_INFO("chat template", {
+            {"chat_example", chat_example},
+            {"built_in", sparams.chat_template.empty()},
+        });
+    }
 
     ctx_server.queue_tasks.on_new_task(std::bind(
         &server_context::process_single_task, &ctx_server, std::placeholders::_1));
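
With these changes, /health (and its new /v1/health alias) is the only route that answers while the model is still loading: middleware_model_loading lets those two paths through unconditionally and rejects everything else until load_model() completes. A minimal client-side sketch of a readiness probe built on that behavior follows, written against cpp-httplib (the same HTTP library the server embeds). The host, port, and one-second poll interval are illustrative assumptions, not part of the patch, and the sketch assumes res_error maps ERROR_TYPE_UNAVAILABLE to HTTP 503, as the error-type name suggests.

// readiness_probe.cpp - hypothetical readiness probe for the patched server.
// Assumes the server listens on localhost:8080 (not specified by the patch).
#include "httplib.h"

#include <chrono>
#include <iostream>
#include <thread>

int main() {
    httplib::Client cli("localhost", 8080);

    for (;;) {
        // While the model is loading, /health is expected to answer 503
        // ("Loading model"). Once the server is ready it answers 200, unless
        // fail_on_no_slot is set and every slot is busy, which is also 503.
        auto res = cli.Get("/health?fail_on_no_slot=1");

        if (res && res->status == 200) {
            std::cout << "ready: " << res->body << "\n"; // {"status":"ok",...}
            return 0;
        }

        std::cout << "not ready (status " << (res ? res->status : -1) << "), retrying\n";
        std::this_thread::sleep_for(std::chrono::seconds(1));
    }
}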