From 712f1d5795bfc3a115bb93250d7f1c57650307bb Mon Sep 17 00:00:00 2001 From: irfnfnkemed Date: Fri, 14 Mar 2025 12:22:34 +0800 Subject: [PATCH 01/17] [Grammar] Upgrade xgrammar to latest version - upgrade xgrammar calling to latest API --- cpp/serve/engine.cc | 14 ++++++++------ cpp/serve/request_state.cc | 4 ++-- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index 2f09219392..3f54878b6b 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -463,9 +463,11 @@ class EngineImpl : public Engine { ModelWorkspace{model->AllocEmbeddingTensor(), model->AllocHiddenStatesTensor()}); } // - Initialize tokenizer and grammar + n->tokenizer_ = Tokenizer::FromPath(engine_config->model, GetTokenizerInfo(model_configs[0])); n->token_table_ = n->tokenizer_->PostProcessedTokenTable(); - n->cached_grammar_compiler_ = xgrammar::CachedGrammarCompiler(n->token_table_); + // TODO: check 'vocab_size' of TokenizerInfo + n->grammar_compiler_ = xgrammar::GrammarCompiler(xgrammar::TokenizerInfo(n->token_table_)); // - Create the logit processor and sampler, and // the DraftTokenWorkspaceManager for speculative decoding. int max_num_tokens = engine_config->max_num_sequence; @@ -975,13 +977,13 @@ class EngineImpl : public Engine { * is not JSON, return std::nullopt. */ std::optional GetGrammarFromResponseFormat( const ResponseFormat& response_format) { + // TODO: add other grammar type if (response_format.type != "json_object") { return std::nullopt; } else if (!response_format.schema) { - return cached_grammar_compiler_.GetCompiledGrammarForJSON(); + return grammar_compiler_.CompileBuiltinJSONGrammar(); } else { - return cached_grammar_compiler_.GetCompiledGrammarForJSONSchema( - response_format.schema.value()); + return grammar_compiler_.CompileJSONSchema(response_format.schema.value()); } } @@ -992,8 +994,8 @@ class EngineImpl : public Engine { // internal tokenizer Tokenizer tokenizer_; std::vector token_table_; - // Cached grammar compiler for grammar matching. - xgrammar::CachedGrammarCompiler cached_grammar_compiler_; + // Grammar compiler for grammar matching. + xgrammar::GrammarCompiler grammar_compiler_; // Models Array models_; // Device that the models run on. diff --git a/cpp/serve/request_state.cc b/cpp/serve/request_state.cc index 17e02ee85b..4771f14d3b 100644 --- a/cpp/serve/request_state.cc +++ b/cpp/serve/request_state.cc @@ -24,7 +24,7 @@ RequestModelState::RequestModelState( if (compiled_grammar.has_value()) { // TODO(yixin): set rollback limit to a configurable value. 
n->grammar_matcher = - xgrammar::GrammarMatcher(compiled_grammar.value(), std::nullopt, false, std::nullopt, 10); + xgrammar::GrammarMatcher(compiled_grammar.value(), std::nullopt, false, 10); } n->request = std::move(request); @@ -44,7 +44,7 @@ bool RequestModelStateNode::RequireNextTokenBitmask() { return grammar_matcher.h void RequestModelStateNode::GetNextTokenBitmask(DLTensor* bitmask) { ICHECK(grammar_matcher.has_value()); - grammar_matcher->GetNextTokenBitmask(bitmask); + grammar_matcher->FillNextTokenBitmask(bitmask); } void RequestModelStateNode::CommitToken(SampleResult sampled_token) { From 6e37a0b6e2bbc346b3eae68adbc9a5bd48db053d Mon Sep 17 00:00:00 2001 From: irfnfnkemed Date: Thu, 20 Mar 2025 10:37:07 +0800 Subject: [PATCH 02/17] Updated submodule-xgrammar references --- 3rdparty/xgrammar | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/xgrammar b/3rdparty/xgrammar index d4f57c440f..f4badadfe7 160000 --- a/3rdparty/xgrammar +++ b/3rdparty/xgrammar @@ -1 +1 @@ -Subproject commit d4f57c440f3da8e7330a1e5d50bba9c31f9433ea +Subproject commit f4badadfe7363e4e09fafcde3c253a46dd5d6e97 From f70a37a4870c49d16ec328a85dfad01c83396d50 Mon Sep 17 00:00:00 2001 From: irfnfnkemed Date: Thu, 20 Mar 2025 11:05:46 +0800 Subject: [PATCH 03/17] [fix] modify cmake since xgrammar use nanobind to replace pybind --- 3rdparty/tvm | 2 +- CMakeLists.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 9c894f78fd..7752c9221c 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 9c894f78fdef156263ced19eed67e79203ca4a11 +Subproject commit 7752c9221c768617af01711f8ad155e0a1cd409e diff --git a/CMakeLists.txt b/CMakeLists.txt index a010a05192..99926be832 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -71,7 +71,7 @@ add_subdirectory(${TOKENZIER_CPP_PATH} tokenizers EXCLUDE_FROM_ALL) set(XGRAMMAR_PATH 3rdparty/xgrammar) tvm_file_glob(GLOB_RECURSE MLC_LLM_SRCS cpp/*.cc) tvm_file_glob(GLOB_RECURSE XGRAMMAR_SRCS ${XGRAMMAR_PATH}/cpp/*.cc) -list(FILTER XGRAMMAR_SRCS EXCLUDE REGEX "${XGRAMMAR_PATH}/cpp/pybind/.*\\.cc") +list(FILTER XGRAMMAR_SRCS EXCLUDE REGEX "${XGRAMMAR_PATH}/cpp/nanobind/.*\\.cc") list(APPEND MLC_LLM_SRCS ${XGRAMMAR_SRCS}) add_library(mlc_llm_objs OBJECT ${MLC_LLM_SRCS}) From 6e7426e8275270c30f138750e88dc0cdcdd3151f Mon Sep 17 00:00:00 2001 From: irfnfnkemed Date: Thu, 20 Mar 2025 11:43:25 +0800 Subject: [PATCH 04/17] [feature] Support tool function calling with Structural Tag with xgrammar - ensure the tool function will be called in expected format using xgrammar - modify RequestResponseFormat: add structural tag according to the tools when building response format - the tool function calling is now constrained by format: parameters - tools call list will be parsed according to the calling format when processing the response - also expose the Structural Tag api of xgrammar to RequestResponseFormat --- cpp/serve/config.cc | 55 ++++++++++++++++- cpp/serve/config.h | 3 + cpp/serve/engine.cc | 23 +++++-- cpp/serve/logit_processor.cc | 2 + .../mlc_llm/protocol/openai_api_protocol.py | 25 +++++++- python/mlc_llm/serve/engine.py | 14 +++++ python/mlc_llm/serve/engine_base.py | 60 +++++++++++++------ 7 files changed, 156 insertions(+), 26 deletions(-) diff --git a/cpp/serve/config.cc b/cpp/serve/config.cc index f7e71e72c9..22c431ba3e 100644 --- a/cpp/serve/config.cc +++ b/cpp/serve/config.cc @@ -10,11 +10,14 @@ #include #include +#include +#include #include "../json_ffi/openai_api_protocol.h" 
#include "../support/json_parser.h" #include "../support/utils.h" #include "data.h" +#include "tvm/runtime/container/array.h" namespace mlc { namespace llm { @@ -42,13 +45,43 @@ Result ResponseFormat::FromJSON(const picojson::object& config) ResponseFormat res; res.type = json::LookupOrDefault(config, "type", "text"); + if (res.type != "text" && res.type != "function" && res.type != "json_object" && + res.type != "structural_tag") { + return TResult::Error("Uknonwn response_format type " + res.type); + } + std::optional schema = json::LookupOptional(config, "schema"); if (schema.has_value()) { res.schema = schema.value(); } - if (res.type != "text" && res.type != "function" && res.type != "json_object") { - return TResult::Error("Uknonwn response_format type " + res.type); + if (auto tags_obj = json::LookupOptional(config, "tags")) { + auto tags = Array>(); + for (auto tag_obj : tags_obj.value()) { + Array tag = Array(); + std::optional begin = + json::LookupOptional(tag_obj.get(), "begin"); + std::optional schema = + json::LookupOptional(tag_obj.get(), "schema"); + std::optional end = + json::LookupOptional(tag_obj.get(), "end"); + if (!(begin.has_value() && schema.has_value() && end.has_value())) { + return TResult::Error("Miss tag attribute."); + } + tag.push_back(begin.value()); + tag.push_back(schema.value()); + tag.push_back(end.value()); + tags.push_back(std::move(tag)); + } + res.tags = tags; + } + + if (auto triggers_obj = json::LookupOptional(config, "triggers")) { + auto triggers = Array(); + for (auto trigger : triggers_obj.value()) { + triggers.push_back(trigger.get()); + } + res.triggers = triggers; } return TResult::Ok(res); @@ -60,6 +93,24 @@ picojson::object ResponseFormat::AsJSON() const { if (schema.defined()) { config["schema"] = picojson::value(schema.value().operator std::string()); } + if (tags.defined()) { + picojson::array tags_obj = picojson::array(); + for (auto tag : tags.value()) { + picojson::array tag_obj = picojson::array(); + tag_obj.emplace_back(tag[0]); + tag_obj.emplace_back(tag[1]); + tag_obj.emplace_back(tag[2]); + tags_obj.emplace_back(tag_obj); + } + config["tags"] = picojson::value(tags_obj); + } + if (triggers.defined()) { + picojson::array trigger_obj = picojson::array(); + for (std::string trigger : triggers.value()) { + trigger_obj.emplace_back(trigger); + } + config["triggers"] = picojson::value(trigger_obj); + } return config; } diff --git a/cpp/serve/config.h b/cpp/serve/config.h index 9da3ba2517..b48d981cd7 100644 --- a/cpp/serve/config.h +++ b/cpp/serve/config.h @@ -14,6 +14,7 @@ #include "../metadata/model.h" #include "../support/result.h" +#include "tvm/runtime/container/optional.h" namespace mlc { namespace llm { @@ -28,6 +29,8 @@ using namespace tvm::runtime; struct ResponseFormat { String type = "text"; Optional schema = NullOpt; + Optional>> tags = NullOpt; + Optional> triggers = NullOpt; /*! * \brief Create debug config from JSON. 
* \param config_json The json string for generation config diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index 3f54878b6b..08a2ff2963 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -19,8 +19,10 @@ #include #include #include +#include #include #include +#include #include "../support/json_parser.h" #include "../support/result.h" @@ -35,6 +37,7 @@ #include "request.h" #include "request_state.h" #include "sampler/sampler.h" +#include "xgrammar/grammar.h" namespace mlc { namespace llm { @@ -978,12 +981,24 @@ class EngineImpl : public Engine { std::optional GetGrammarFromResponseFormat( const ResponseFormat& response_format) { // TODO: add other grammar type - if (response_format.type != "json_object") { + if (response_format.type == "text") { return std::nullopt; - } else if (!response_format.schema) { - return grammar_compiler_.CompileBuiltinJSONGrammar(); + } else if (response_format.type == "json_object") { + if (!response_format.schema) { + return grammar_compiler_.CompileBuiltinJSONGrammar(); + } else { + return grammar_compiler_.CompileJSONSchema(response_format.schema.value()); + } } else { - return grammar_compiler_.CompileJSONSchema(response_format.schema.value()); + std::vector tags; + std::vector triggers; + for (auto tag : response_format.tags.value()) { + tags.emplace_back(xgrammar::StructuralTagItem{tag[0], tag[1], tag[2]}); + } + for (auto trigger : response_format.triggers.value()) { + triggers.emplace_back(trigger); + } + return grammar_compiler_.CompileStructuralTag(std::move(tags), std::move(triggers)); } } diff --git a/cpp/serve/logit_processor.cc b/cpp/serve/logit_processor.cc index 3c9bf88717..aa3d41751a 100644 --- a/cpp/serve/logit_processor.cc +++ b/cpp/serve/logit_processor.cc @@ -12,6 +12,8 @@ #include #include +#include + namespace mlc { namespace llm { namespace serve { diff --git a/python/mlc_llm/protocol/openai_api_protocol.py b/python/mlc_llm/protocol/openai_api_protocol.py index cb2e1f2852..88147f12d3 100644 --- a/python/mlc_llm/protocol/openai_api_protocol.py +++ b/python/mlc_llm/protocol/openai_api_protocol.py @@ -86,12 +86,33 @@ class ModelResponse(BaseModel): class RequestResponseFormat(BaseModel): - type: Literal["text", "json_object"] = "text" - json_schema: Optional[str] = Field(default=None, alias="schema") + type: Literal["text", "json_object", "structural_tag"] = "text" """This field is named json_schema instead of schema because BaseModel defines a method called schema. During construction of RequestResponseFormat, key "schema" still should be used: `RequestResponseFormat(type="json_object", schema="{}")` """ + json_schema: Optional[str] = Field(default=None, alias="schema") + + """These field are only used for type="structural_tag".""" + tags: Optional[List[Dict[str, str]]] = Field(default=None, alias="tags") + triggers: Optional[List[str]] = Field(default=None, alias="triggers") + + @model_validator(mode="after") + def check_request_response_format(self) -> "RequestResponseFormat": + """Check if the RequestResponseFormat is valid.""" + if self.type == "structural_tag": + if self.tags is None or self.triggers is None: + raise ValueError("structural_tag type must contain keys 'tags' and 'triggers'.") + for tag in self.tags: + if set(tag.keys()) != {"begin", "schema", "end"}: + raise ValueError( + f"Each tag must contain exactly 'begin', 'schema' and 'end' keys. Got keys: {list(tag.keys())}." 
+ ) + elif self.tags is not None or self.triggers is not None: + raise Warning( + "'tags' and 'triggers' attributes should be used when type='structural_tag'" + ) + return self class CompletionRequest(BaseModel): diff --git a/python/mlc_llm/serve/engine.py b/python/mlc_llm/serve/engine.py index 3d9d181b1f..0250db3211 100644 --- a/python/mlc_llm/serve/engine.py +++ b/python/mlc_llm/serve/engine.py @@ -976,6 +976,12 @@ async def _chat_completion( # pylint: disable=too-many-arguments,too-many-local if request_id is None: request_id = f"chatcmpl-{engine_utils.random_uuid()}" + tools = ( + [openai_api_protocol.ChatTool.model_validate(tool) for tool in tools] + if tools is not None + else None + ) + chatcmpl_generator = self._handle_chat_completion( openai_api_protocol.ChatCompletionRequest( messages=[ @@ -1207,6 +1213,10 @@ async def _handle_chat_completion( e : BadRequestError BadRequestError is raised when the request is invalid. """ + request.response_format = engine_base.set_structural_tag_from_tools( + request.tools, request.response_format + ) + ( prompts, generation_cfg, @@ -1764,6 +1774,10 @@ def _handle_chat_completion( e : BadRequestError BadRequestError is raised when the request is invalid. """ + request.response_format = engine_base.set_structural_tag_from_tools( + request.tools, request.response_format + ) + ( prompts, generation_cfg, diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py index 1d5303e412..2b7178eb4c 100644 --- a/python/mlc_llm/serve/engine_base.py +++ b/python/mlc_llm/serve/engine_base.py @@ -7,6 +7,7 @@ import json import numbers import queue +import re import sys import threading from dataclasses import dataclass @@ -1146,29 +1147,52 @@ def create_completion_suffix_response( return response -def convert_function_str_to_json(stringified_calls: str) -> List[Union[Dict, None]]: +def convert_function_str_to_json(stringified_calls: str): """Convert a (possibly list) of function call string to a list of json objects. Return None for invalid function call string.""" + function_calls_json = [] + for call in re.finditer(r"(.*?)", stringified_calls, re.DOTALL): + function_name = call.group(1) + params_str = call.group(2).strip() + params = ast.literal_eval(params_str) + function_calls_json.append({"name": function_name, "arguments": params}) - def parse_function_call(call_str: str): - node = ast.parse(call_str, mode="eval") - call_node = node.body - if isinstance(call_node, ast.Call) and isinstance(call_node.func, ast.Name): - name = call_node.func.id - arguments = {} - for keyword in call_node.keywords: - arguments[keyword.arg] = ast.literal_eval(keyword.value) - return {"name": name, "arguments": arguments} - return None + return function_calls_json - if ( - stringified_calls[0] == "[" and stringified_calls[-1] == "]" - ): # hacky way to check if string list - calls = ast.literal_eval(stringified_calls) + +def set_structural_tag_from_tools( + tools: Optional[List[openai_api_protocol.ChatTool]], + response_format: Optional[openai_api_protocol.RequestResponseFormat], +): + """Add the corresponding structural tag to the response format according to the tools to ensure valid function calling. + Return the updated response format. 
+ """ + if tools is None: + return response_format else: - calls = [stringified_calls] - function_calls_json = [parse_function_call(call_str) for call_str in calls] - return function_calls_json + if response_format is None or response_format.type == "text": + response_format = openai_api_protocol.RequestResponseFormat.model_validate( + {"type": "structural_tag", "tags": [], "triggers": []} + ) + elif response_format.type == "json_object": + response_format.tags = [] + response_format.triggers = [] + + response_format.triggers.append("", + "schema": json.dumps(schema), + "end": "", + } + ) + return response_format def process_function_call_output( From 4f980f3be8a5839e8501aa77feb98aaee36cfcaf Mon Sep 17 00:00:00 2001 From: irfnfnkemed Date: Mon, 24 Mar 2025 14:11:58 +0800 Subject: [PATCH 05/17] [feat] Add Structural-Tag api to RequestResponseFormat - Expose Structural-Tag api, which can be used to standarlize function calling format - Add test script for Structural-Tag (passed on Llama-2-7b-chat-hf-q0f16-MLC and Llama-3-8B-Instruct-q4f16_1-MLC) --- 3rdparty/xgrammar | 2 +- cpp/serve/config.cc | 5 +- cpp/serve/config.h | 3 +- cpp/serve/engine.cc | 3 - cpp/serve/logit_processor.cc | 2 - .../mlc_llm/protocol/openai_api_protocol.py | 1 + python/mlc_llm/serve/engine.py | 14 - python/mlc_llm/serve/engine_base.py | 60 +-- .../server/test_server_structural_tag.py | 429 ++++++++++++++++++ 9 files changed, 451 insertions(+), 68 deletions(-) create mode 100644 tests/python/serve/server/test_server_structural_tag.py diff --git a/3rdparty/xgrammar b/3rdparty/xgrammar index f4badadfe7..dbf200ecde 160000 --- a/3rdparty/xgrammar +++ b/3rdparty/xgrammar @@ -1 +1 @@ -Subproject commit f4badadfe7363e4e09fafcde3c253a46dd5d6e97 +Subproject commit dbf200ecde5dd5467c8320076ee60b1e248b23e0 diff --git a/cpp/serve/config.cc b/cpp/serve/config.cc index 22c431ba3e..aa14a6611d 100644 --- a/cpp/serve/config.cc +++ b/cpp/serve/config.cc @@ -10,14 +10,11 @@ #include #include -#include -#include #include "../json_ffi/openai_api_protocol.h" #include "../support/json_parser.h" #include "../support/utils.h" #include "data.h" -#include "tvm/runtime/container/array.h" namespace mlc { namespace llm { @@ -1124,4 +1121,4 @@ Result ModelsUseKVCache(const std::vector& model_configs } // namespace serve } // namespace llm -} // namespace mlc +} // namespace mlc \ No newline at end of file diff --git a/cpp/serve/config.h b/cpp/serve/config.h index b48d981cd7..f39e1911ba 100644 --- a/cpp/serve/config.h +++ b/cpp/serve/config.h @@ -14,7 +14,6 @@ #include "../metadata/model.h" #include "../support/result.h" -#include "tvm/runtime/container/optional.h" namespace mlc { namespace llm { @@ -451,4 +450,4 @@ inline PrefillMode PrefillModeFromString(const std::string& prefill_mode) { } // namespace llm } // namespace mlc -#endif // MLC_LLM_SERVE_CONFIG_H_ +#endif // MLC_LLM_SERVE_CONFIG_H_ \ No newline at end of file diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index 08a2ff2963..0db12299a1 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -19,10 +19,8 @@ #include #include #include -#include #include #include -#include #include "../support/json_parser.h" #include "../support/result.h" @@ -37,7 +35,6 @@ #include "request.h" #include "request_state.h" #include "sampler/sampler.h" -#include "xgrammar/grammar.h" namespace mlc { namespace llm { diff --git a/cpp/serve/logit_processor.cc b/cpp/serve/logit_processor.cc index aa3d41751a..3c9bf88717 100644 --- a/cpp/serve/logit_processor.cc +++ b/cpp/serve/logit_processor.cc @@ -12,8 
+12,6 @@ #include #include -#include - namespace mlc { namespace llm { namespace serve { diff --git a/python/mlc_llm/protocol/openai_api_protocol.py b/python/mlc_llm/protocol/openai_api_protocol.py index 88147f12d3..5b617810df 100644 --- a/python/mlc_llm/protocol/openai_api_protocol.py +++ b/python/mlc_llm/protocol/openai_api_protocol.py @@ -112,6 +112,7 @@ def check_request_response_format(self) -> "RequestResponseFormat": raise Warning( "'tags' and 'triggers' attributes should be used when type='structural_tag'" ) + return self diff --git a/python/mlc_llm/serve/engine.py b/python/mlc_llm/serve/engine.py index 0250db3211..3d9d181b1f 100644 --- a/python/mlc_llm/serve/engine.py +++ b/python/mlc_llm/serve/engine.py @@ -976,12 +976,6 @@ async def _chat_completion( # pylint: disable=too-many-arguments,too-many-local if request_id is None: request_id = f"chatcmpl-{engine_utils.random_uuid()}" - tools = ( - [openai_api_protocol.ChatTool.model_validate(tool) for tool in tools] - if tools is not None - else None - ) - chatcmpl_generator = self._handle_chat_completion( openai_api_protocol.ChatCompletionRequest( messages=[ @@ -1213,10 +1207,6 @@ async def _handle_chat_completion( e : BadRequestError BadRequestError is raised when the request is invalid. """ - request.response_format = engine_base.set_structural_tag_from_tools( - request.tools, request.response_format - ) - ( prompts, generation_cfg, @@ -1774,10 +1764,6 @@ def _handle_chat_completion( e : BadRequestError BadRequestError is raised when the request is invalid. """ - request.response_format = engine_base.set_structural_tag_from_tools( - request.tools, request.response_format - ) - ( prompts, generation_cfg, diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py index 2b7178eb4c..1d5303e412 100644 --- a/python/mlc_llm/serve/engine_base.py +++ b/python/mlc_llm/serve/engine_base.py @@ -7,7 +7,6 @@ import json import numbers import queue -import re import sys import threading from dataclasses import dataclass @@ -1147,52 +1146,29 @@ def create_completion_suffix_response( return response -def convert_function_str_to_json(stringified_calls: str): +def convert_function_str_to_json(stringified_calls: str) -> List[Union[Dict, None]]: """Convert a (possibly list) of function call string to a list of json objects. Return None for invalid function call string.""" - function_calls_json = [] - for call in re.finditer(r"(.*?)", stringified_calls, re.DOTALL): - function_name = call.group(1) - params_str = call.group(2).strip() - params = ast.literal_eval(params_str) - function_calls_json.append({"name": function_name, "arguments": params}) - - return function_calls_json + def parse_function_call(call_str: str): + node = ast.parse(call_str, mode="eval") + call_node = node.body + if isinstance(call_node, ast.Call) and isinstance(call_node.func, ast.Name): + name = call_node.func.id + arguments = {} + for keyword in call_node.keywords: + arguments[keyword.arg] = ast.literal_eval(keyword.value) + return {"name": name, "arguments": arguments} + return None -def set_structural_tag_from_tools( - tools: Optional[List[openai_api_protocol.ChatTool]], - response_format: Optional[openai_api_protocol.RequestResponseFormat], -): - """Add the corresponding structural tag to the response format according to the tools to ensure valid function calling. - Return the updated response format. 
- """ - if tools is None: - return response_format + if ( + stringified_calls[0] == "[" and stringified_calls[-1] == "]" + ): # hacky way to check if string list + calls = ast.literal_eval(stringified_calls) else: - if response_format is None or response_format.type == "text": - response_format = openai_api_protocol.RequestResponseFormat.model_validate( - {"type": "structural_tag", "tags": [], "triggers": []} - ) - elif response_format.type == "json_object": - response_format.tags = [] - response_format.triggers = [] - - response_format.triggers.append("", - "schema": json.dumps(schema), - "end": "", - } - ) - return response_format + calls = [stringified_calls] + function_calls_json = [parse_function_call(call_str) for call_str in calls] + return function_calls_json def process_function_call_output( diff --git a/tests/python/serve/server/test_server_structural_tag.py b/tests/python/serve/server/test_server_structural_tag.py new file mode 100644 index 0000000000..5a7c93e4f5 --- /dev/null +++ b/tests/python/serve/server/test_server_structural_tag.py @@ -0,0 +1,429 @@ +# pylint: disable=line-too-long +""" +Test script for structural tag in chat completion. To run this script, use the following command: +- start a new shell session, run + mlc_llm serve --model "YOUR_MODEL" (e.g. ./dist/Llama-2-7b-chat-hf-q0f16-MLC) +- start another shell session, run this file + MLC_SERVE_MODEL="YOUR_MODEL" python tests/python/serve/server/test_server_structural_tag.py +""" + +# pylint: disable=missing-function-docstring,too-many-arguments,too-many-locals,too-many-branches +import json +import os +import re +from typing import Dict, List, Optional, Tuple + +import pytest +import requests + +OPENAI_V1_CHAT_COMPLETION_URL = "http://127.0.0.1:8000/v1/chat/completions" + + +def check_openai_nonstream_response( + response: Dict, + *, + model: str, + object_str: str, + num_choices: int, + finish_reason: List[str], + completion_tokens: Optional[int] = None, +): + assert response["model"] == model + assert response["object"] == object_str + + choices = response["choices"] + assert isinstance(choices, list) + assert len(choices) == num_choices + for idx, choice in enumerate(choices): + assert choice["index"] == idx + assert choice["finish_reason"] in finish_reason + + find_format_start = set() + beg_tag_start = set() + message = choice["message"]["content"] + print("Outputs:\n-----------") + print(message, flush=True) + pattern1 = r"(.*?)\|(.*?)\|End<---(.*?)>" + pattern2 = r"(.*?)\|(.*?)\|End<---(.*?)>" + # check format + for match in re.finditer(pattern1, message): + find_format_start.add(match.start()) + check_format(match.group(1), match.group(3), "CALL", match.group(2)) + for match in re.finditer(pattern2, message): + find_format_start.add(match.start()) + check_format(match.group(1), match.group(3), "call", match.group(2)) + for match in re.finditer(r"", message): + beg_tag_start.add(match.start()) + for match in re.finditer(r"", message): + beg_tag_start.add(match.start()) + assert find_format_start == beg_tag_start + + usage = response["usage"] + assert isinstance(usage, dict) + assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"] + assert usage["prompt_tokens"] > 0 + + if completion_tokens is not None: + assert usage["completion_tokens"] == completion_tokens + + +def check_openai_stream_response( + responses: List[Dict], + *, + model: str, + object_str: str, + num_choices: int, + finish_reason: str, + echo_prompt: Optional[str] = None, + suffix: Optional[str] = None, + stop: 
Optional[List[str]] = None, + require_substr: Optional[List[str]] = None, +): + assert len(responses) > 0 + + finished = [False for _ in range(num_choices)] + outputs = ["" for _ in range(num_choices)] + for response in responses: + assert response["model"] == model + assert response["object"] == object_str + + choices = response["choices"] + assert isinstance(choices, list) + assert len(choices) == num_choices + for idx, choice in enumerate(choices): + assert choice["index"] == idx + + delta = choice["delta"] + assert delta["role"] == "assistant" + assert isinstance(delta["content"], str) + outputs[idx] += delta["content"] + + if finished[idx]: + assert choice["finish_reason"] == finish_reason + elif choice["finish_reason"] is not None: + assert choice["finish_reason"] == finish_reason + finished[idx] = True + + for output in outputs: + if echo_prompt is not None: + assert output.startswith(echo_prompt) + if suffix is not None: + assert output.endswith(suffix) + if stop is not None: + for stop_str in stop: + assert stop_str not in output + if require_substr is not None: + for substr in require_substr: + assert substr in output + find_format_start = set() + beg_tag_start = set() + print("Outputs:\n-----------") + print(output, flush=True) + pattern1 = r"(.*?)\|(.*?)\|End<---(.*?)>" + pattern2 = r"(.*?)\|(.*?)\|End<---(.*?)>" + # check format + for match in re.finditer(pattern1, output): + find_format_start.add(match.start()) + check_format(match.group(1), match.group(3), "CALL", match.group(2)) + for match in re.finditer(pattern2, output): + find_format_start.add(match.start()) + check_format(match.group(1), match.group(3), "call", match.group(2)) + for match in re.finditer(r"", output): + beg_tag_start.add(match.start()) + for match in re.finditer(r"", output): + beg_tag_start.add(match.start()) + assert find_format_start == beg_tag_start + + +def check_format(name_beg: str, name_end: str, beg_tag: str, schema: str): + schema = json.loads(schema) + assert "hash_code" in schema + hash_code = schema["hash_code"] + assert hash_code in CHECK_INFO + info = CHECK_INFO[hash_code] + assert name_beg == info["name"] + assert name_end == info["name"] + assert beg_tag == info["beg_tag"] + for key in info["required"]: + assert key in schema + + +# NOTE: the end-tag format and the hash_code number is been hidden in the SYSTEM_PROMPT. +# By checking whether the end tag and hash code can be generated correctly without any prompts, the correctness of the structural tag can be verified. + +SYSTEM_PROMPT = { + "role": "system", + "content": """ +# Tool Instructions +- Always execute python code in messages that you share. +- When looking for real time information use relevant functions if available else fallback to brave_search +You have access to the following functions: +Use the function 'get_current_weather' to: Get the current weather in a given location +{ + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city to find the weather for, e.g. 'San Francisco'", + }, + "state": { + "type": "string", + "description": "the two-letter abbreviation for the state that the city is" + " in, e.g. 
'CA' which would mean 'California'", + }, + "unit": { + "type": "string", + "description": "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"], + }, + "hash_code": { + "type": "string", + }, + }, + "required": ["city", "state", "unit", "hash_code"], + }, +} +Use the function 'get_current_date' to: Get the current date and time for a given timezone +{ + "name": "get_current_date", + "description": "Get the current date and time for a given timezone", + "parameters": { + "type": "object", + "properties": { + "timezone": { + "type": "string", + "description": "The timezone to fetch the current date and time for, e.g. 'America/New_York'", + }, + "hash_code": { + "type": "string", + }, + }, + "required": ["timezone", "hash_code"], + }, +} +If a you choose to call a function ONLY reply in the following format: +<{start_tag}--->{function_name}|{parameters}|{end_tag}<---{function_name}> +where +start_tag => ` a JSON dict with the function argument name as key and function argument value as value. +Here is an example, +example_function_name|{"example_name": "example_value"}... +or +example_function_name|{"example_name": "example_value"}... +Reminder: +- Function calls MUST follow the specified format +- Required parameters MUST be specified +You are a helpful assistant.""", +} + + +STRUCTURAL_TAGS = { + "triggers": ["", ""], + "tags": [ + { + "begin": "get_current_weather|", + "schema": json.dumps( + { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city to find the weather for, e.g. 'San Francisco'", + }, + "state": { + "type": "string", + "description": "the two-letter abbreviation for the state that the city is" + " in, e.g. 'CA' which would mean 'California'", + }, + "unit": { + "type": "string", + "description": "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"], + }, + "hash_code": {"const": 1234}, + }, + "required": ["city", "state", "unit", "hash_code"], + } + ), + "end": "|End<---get_current_weather>", + }, + { + "begin": "get_current_date|", + "schema": json.dumps( + { + "type": "object", + "properties": { + "timezone": { + "type": "string", + "description": "The timezone to fetch the current date and time for, e.g. 'America/New_York'", + }, + "hash_code": {"const": 2345}, + }, + "required": ["timezone", "hash_code"], + } + ), + "end": "|End<---get_current_date>", + }, + { + "begin": "get_current_weather|", + "schema": json.dumps( + { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city to find the weather for, e.g. 'San Francisco'", + }, + "state": { + "type": "string", + "description": "the two-letter abbreviation for the state that the city is" + " in, e.g. 'CA' which would mean 'California'", + }, + "unit": { + "type": "string", + "description": "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"], + }, + "hash_code": {"const": 3456}, + }, + "required": ["city", "state", "unit", "hash_code"], + } + ), + "end": "|End<---get_current_weather>", + }, + { + "begin": "get_current_date|", + "schema": json.dumps( + { + "type": "object", + "properties": { + "timezone": { + "type": "string", + "description": "The timezone to fetch the current date and time for, e.g. 
'America/New_York'", + }, + "hash_code": {"const": 4567}, + }, + "required": ["timezone", "hash_code"], + } + ), + "end": "|End<---get_current_date>", + }, + ], +} + +CHECK_INFO = { + 1234: { + "name": "get_current_weather", + "beg_tag": "CALL", + "required": ["city", "state", "unit", "hash_code"], + }, + 2345: { + "name": "get_current_date", + "beg_tag": "CALL", + "required": ["timezone", "hash_code"], + }, + 3456: { + "name": "get_current_weather", + "beg_tag": "call", + "required": ["city", "state", "unit", "hash_code"], + }, + 4567: { + "name": "get_current_date", + "beg_tag": "call", + "required": ["timezone", "hash_code"], + }, +} + + +CHAT_COMPLETION_MESSAGES = [ + # messages #0 + [ + SYSTEM_PROMPT, + { + "role": "user", + "content": "You are in New York. Please get the current date and time.", + }, + ], + # messages #1 + [ + SYSTEM_PROMPT, + { + "role": "user", + "content": "You are in New York. Please get the current weather.", + }, + ], + # messages #2 + [ + SYSTEM_PROMPT, + { + "role": "user", + "content": "You are in New York. Please get the current date and time, and the weather.", + }, + ], +] + + +@pytest.mark.parametrize("stream", [False, True]) +@pytest.mark.parametrize("messages", CHAT_COMPLETION_MESSAGES) +def test_openai_v1_chat_completion_structural_tag( + served_model: str, + launch_server, # pylint: disable=unused-argument + stream: bool, + messages: List[Dict[str, str]], +): + # `served_model` and `launch_server` are pytest fixtures + # defined in conftest.py. + + payload = { + "model": served_model, + "messages": messages, + "stream": stream, + "response_format": { + "type": "structural_tag", + "tags": STRUCTURAL_TAGS["tags"], + "triggers": STRUCTURAL_TAGS["triggers"], + }, + "max_tokens": 1024, + } + + response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=60) + if not stream: + check_openai_nonstream_response( + response.json(), + model=served_model, + object_str="chat.completion", + num_choices=1, + finish_reason=["stop"], + ) + else: + responses = [] + for chunk in response.iter_lines(chunk_size=512): + if not chunk or chunk == b"data: [DONE]": + continue + responses.append(json.loads(chunk.decode("utf-8")[6:])) + check_openai_stream_response( + responses, + model=served_model, + object_str="chat.completion.chunk", + num_choices=1, + finish_reason="stop", + ) + + print(f"-----------\nCheck for stream={stream} is passed!\n") + + +if __name__ == "__main__": + MODEL = os.environ.get("MLC_SERVE_MODEL") + if MODEL is None: + raise ValueError( + 'Environment variable "MLC_SERVE_MODEL" not found. 
' + "Please set it to model compiled by MLC LLM " + "(e.g., `./dist/Llama-2-7b-chat-hf-q0f16-MLC`) " + ) + + for msg in CHAT_COMPLETION_MESSAGES: + test_openai_v1_chat_completion_structural_tag(MODEL, None, stream=False, messages=msg) + test_openai_v1_chat_completion_structural_tag(MODEL, None, stream=True, messages=msg) From 8b70dd7565c7ae0bdc8f05e090bd5244acf2bbfc Mon Sep 17 00:00:00 2001 From: irfnfnkemed Date: Mon, 24 Mar 2025 15:56:33 +0800 Subject: [PATCH 06/17] [fix] type annotation in test scripts --- python/mlc_llm/protocol/openai_api_protocol.py | 3 ++- .../serve/server/test_server_structural_tag.py | 15 +++++++++------ 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/python/mlc_llm/protocol/openai_api_protocol.py b/python/mlc_llm/protocol/openai_api_protocol.py index 5b617810df..eb6d6c7e7e 100644 --- a/python/mlc_llm/protocol/openai_api_protocol.py +++ b/python/mlc_llm/protocol/openai_api_protocol.py @@ -106,7 +106,8 @@ def check_request_response_format(self) -> "RequestResponseFormat": for tag in self.tags: if set(tag.keys()) != {"begin", "schema", "end"}: raise ValueError( - f"Each tag must contain exactly 'begin', 'schema' and 'end' keys. Got keys: {list(tag.keys())}." + "Each tag must contain exactly 'begin', 'schema' and 'end' keys." + f"Got keys: {list(tag.keys())}." ) elif self.tags is not None or self.triggers is not None: raise Warning( diff --git a/tests/python/serve/server/test_server_structural_tag.py b/tests/python/serve/server/test_server_structural_tag.py index 5a7c93e4f5..f34df96d31 100644 --- a/tests/python/serve/server/test_server_structural_tag.py +++ b/tests/python/serve/server/test_server_structural_tag.py @@ -11,7 +11,7 @@ import json import os import re -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional import pytest import requests @@ -136,16 +136,21 @@ def check_openai_stream_response( def check_format(name_beg: str, name_end: str, beg_tag: str, schema: str): - schema = json.loads(schema) + try: + paras: Dict[str, Any] = json.loads(schema) + except json.JSONDecodeError as e: + print(f"Invalid JSON format: {e}") + assert False + assert "hash_code" in paras assert "hash_code" in schema - hash_code = schema["hash_code"] + hash_code = paras["hash_code"] assert hash_code in CHECK_INFO info = CHECK_INFO[hash_code] assert name_beg == info["name"] assert name_end == info["name"] assert beg_tag == info["beg_tag"] for key in info["required"]: - assert key in schema + assert key in paras # NOTE: the end-tag format and the hash_code number is been hidden in the SYSTEM_PROMPT. 
@@ -219,7 +224,6 @@ def check_format(name_beg: str, name_end: str, beg_tag: str, schema: str): You are a helpful assistant.""", } - STRUCTURAL_TAGS = { "triggers": ["", ""], "tags": [ @@ -337,7 +341,6 @@ def check_format(name_beg: str, name_end: str, beg_tag: str, schema: str): }, } - CHAT_COMPLETION_MESSAGES = [ # messages #0 [ From 7df3f34b6c5f88499ca85083bd8a58ae21bc3e9a Mon Sep 17 00:00:00 2001 From: irfnfnkemed Date: Wed, 26 Mar 2025 20:41:01 +0800 Subject: [PATCH 07/17] [feat] Support tool function calls under strict format constraints - Add "tool_call_format" attribute in EngineConfig, which determines the tool calls format - Add "strict" attribute in ChatFunction, which is aligned to OpenAI format - Set system prompt according to tool_call_format - Set structural tag to ensure strict func calls - Parse output to json-style func calls - TODO: Now only supports format {PARA} --- .../mlc_llm/protocol/conversation_protocol.py | 23 ++++ .../mlc_llm/protocol/openai_api_protocol.py | 4 +- python/mlc_llm/serve/config.py | 8 ++ python/mlc_llm/serve/engine.py | 16 ++- python/mlc_llm/serve/engine_base.py | 114 ++++++++++++++---- .../serve/entrypoints/openai_entrypoints.py | 2 +- 6 files changed, 140 insertions(+), 27 deletions(-) diff --git a/python/mlc_llm/protocol/conversation_protocol.py b/python/mlc_llm/protocol/conversation_protocol.py index 71738efeef..4963f17a1f 100644 --- a/python/mlc_llm/protocol/conversation_protocol.py +++ b/python/mlc_llm/protocol/conversation_protocol.py @@ -81,6 +81,8 @@ class Conversation(BaseModel): function_string: str = "" # whether using function calling or not, helps check for output message format in API call use_function_calling: bool = False + # Tool function call format mode + tool_call_format: str = "default" def __init__(self, role_templates: Optional[Dict[str, str]] = None, **kwargs): # Defaults templates which would be overridden by model specific templates @@ -124,6 +126,7 @@ def as_prompt(self, config=None) -> List[Any]: from ..serve import data # pylint: disable=import-outside-toplevel # - Get the system message. 
+ self.set_tool_call_format_in_system_message() system_msg = self.system_template.replace( MessagePlaceholders.SYSTEM.value, self.system_message ) @@ -195,6 +198,26 @@ def as_prompt(self, config=None) -> List[Any]: return prompt + def set_tool_call_format_in_system_message(self): + """Add tool function information and call format to the system message.""" + if self.tool_call_format == "default": + tool_call_instruct = ( + "Tool Instructions:" + f"You have access to the following tool functions: {MessagePlaceholders.FUNCTION.value}" + "If a you choose to call a function, you should ONLY reply in the following format:" + "`{parameters(JSON dict)}`" + "Here is an example," + '` {"location": "Pittsburgh"} `' + "Reminder:" + "- Function calls MUST follow the specified format" + "- Required parameters MUST be specified" + ) + self.system_message += tool_call_instruct + elif self.tool_call_format == "python": + raise ValueError("TODO: Not supported yet.") + else: + raise ValueError("Unknown tool calling format.") + def _get_url_from_item(item: Dict) -> str: image_url: str diff --git a/python/mlc_llm/protocol/openai_api_protocol.py b/python/mlc_llm/protocol/openai_api_protocol.py index eb6d6c7e7e..3a7607ad76 100644 --- a/python/mlc_llm/protocol/openai_api_protocol.py +++ b/python/mlc_llm/protocol/openai_api_protocol.py @@ -106,8 +106,7 @@ def check_request_response_format(self) -> "RequestResponseFormat": for tag in self.tags: if set(tag.keys()) != {"begin", "schema", "end"}: raise ValueError( - "Each tag must contain exactly 'begin', 'schema' and 'end' keys." - f"Got keys: {list(tag.keys())}." + f"Each tag must contain exactly 'begin', 'schema' and 'end' keys. Got keys: {list(tag.keys())}." ) elif self.tags is not None or self.triggers is not None: raise Warning( @@ -204,6 +203,7 @@ class ChatFunction(BaseModel): description: Optional[str] = None name: str parameters: Dict + strict: bool = True class ChatTool(BaseModel): diff --git a/python/mlc_llm/serve/config.py b/python/mlc_llm/serve/config.py index 9b82de8350..1438f85f0a 100644 --- a/python/mlc_llm/serve/config.py +++ b/python/mlc_llm/serve/config.py @@ -132,6 +132,13 @@ class EngineConfig: # pylint: disable=too-many-instance-attributes verbose : bool A boolean indicating whether to print logging info in engine. + + tool_call_format: Literal["default", "python"] + The tool function call foramt. + "default" means model will call tool function in format '{parameters(JSON dict)}', + e.g. ' {"location": "Pittsburgh"} '. + "python" means model will call tool function in python-style format, + e.g. 'wolfram_alpha.call(query="solve x^3 - 4x^2 + 6x - 24 = 0")'. 
""" model: Optional[str] = None @@ -158,6 +165,7 @@ class EngineConfig: # pylint: disable=too-many-instance-attributes prefix_cache_max_num_recycling_seqs: Optional[int] = None prefill_mode: Literal["chunked", "hybrid"] = "hybrid" verbose: bool = True + tool_call_format: Literal["default", "python"] = "default" def asjson(self) -> str: """Return the config in string of JSON format.""" diff --git a/python/mlc_llm/serve/engine.py b/python/mlc_llm/serve/engine.py index 3d9d181b1f..57841dd48f 100644 --- a/python/mlc_llm/serve/engine.py +++ b/python/mlc_llm/serve/engine.py @@ -1056,7 +1056,7 @@ async def _chat_completion( # pylint: disable=too-many-arguments,too-many-local assert all(finish_reason is not None for finish_reason in finish_reasons) use_function_calling, tool_calls_list = engine_base.process_function_call_output( - output_texts, finish_reasons + output_texts, finish_reasons, self.engine_config.tool_call_format ) return engine_base.wrap_chat_completion_response( request_id=request_id, @@ -1207,6 +1207,12 @@ async def _handle_chat_completion( e : BadRequestError BadRequestError is raised when the request is invalid. """ + request.response_format = engine_base.set_structural_tag_from_tools( + request.tools, + request.response_format, + request.tool_choice, + self.engine_config.tool_call_format, + ) ( prompts, generation_cfg, @@ -1617,7 +1623,7 @@ def _chat_completion( # pylint: disable=too-many-arguments,too-many-locals assert all(finish_reason is not None for finish_reason in finish_reasons) use_function_calling, tool_calls_list = engine_base.process_function_call_output( - output_texts, finish_reasons + output_texts, finish_reasons, self.engine_config.tool_call_format ) return engine_base.wrap_chat_completion_response( request_id=request_id, @@ -1764,6 +1770,12 @@ def _handle_chat_completion( e : BadRequestError BadRequestError is raised when the request is invalid. """ + request.response_format = engine_base.set_structural_tag_from_tools( + request.tools, + request.response_format, + request.tool_choice, + self.engine_config.tool_call_format, + ) ( prompts, generation_cfg, diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py index 1d5303e412..0ba14ae0f2 100644 --- a/python/mlc_llm/serve/engine_base.py +++ b/python/mlc_llm/serve/engine_base.py @@ -9,6 +9,7 @@ import queue import sys import threading +import re from dataclasses import dataclass from pathlib import Path from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union @@ -130,6 +131,7 @@ def _convert_model_info(model: ModelInfo) -> Tuple[str, str]: if conversation is None: conversation = mlc_chat_config.conv_template + conversation.tool_call_format = engine_config.tool_call_format if model.model_lib is not None: # do model lib search if the model lib is provided @@ -1146,36 +1148,104 @@ def create_completion_suffix_response( return response -def convert_function_str_to_json(stringified_calls: str) -> List[Union[Dict, None]]: +def set_structural_tag_from_tools( + tools: Optional[List[openai_api_protocol.ChatTool]], + response_format: Optional[openai_api_protocol.RequestResponseFormat], + tool_choice: Optional[Union[Literal["none", "auto"], Dict]], + tool_call_format: str, +): + """Add the corresponding structural tag to the response format according to + the tools to ensure valid function calling. Only set in strict mode of the tool. + Return the updated response format. 
+ """ + if tools is None or (isinstance(tool_choice, str) and tool_choice == "none"): + return response_format + + if response_format is None or response_format.type == "text": + response_format = openai_api_protocol.RequestResponseFormat.model_validate( + {"type": "structural_tag", "tags": [], "triggers": []} + ) + elif response_format.type == "json_object": + response_format.tags = [] + response_format.triggers = [] + + if tool_call_format == "default": + for tool in tools: + if tool.function.strict and ( + tool_choice is None + or (isinstance(tool_choice, str) and tool_choice == "auto") + or ( + isinstance(tool_choice, dict) + and tool.function.name == tool_choice["function"]["name"] + ) + ): + schema = { + "properties": tool.function.parameters["properties"], + "required": tool.function.parameters["required"], + "type": tool.function.parameters["type"], + } + response_format.tags.append( + { + "begin": f"", + "schema": json.dumps(schema), + "end": "", + } + ) + response_format.triggers.append(" List[Union[Dict, None]]: """Convert a (possibly list) of function call string to a list of json objects. Return None for invalid function call string.""" - def parse_function_call(call_str: str): - node = ast.parse(call_str, mode="eval") - call_node = node.body - if isinstance(call_node, ast.Call) and isinstance(call_node.func, ast.Name): - name = call_node.func.id - arguments = {} - for keyword in call_node.keywords: - arguments[keyword.arg] = ast.literal_eval(keyword.value) - return {"name": name, "arguments": arguments} - return None + if tool_call_format == "default": + function_calls_json = [] + pattern = r"(.+?)" + for match in re.finditer(pattern, stringified_calls): + args: Dict = json.loads(match.group(2)) + function_calls_json.append({"name": match.group(1), "arguments": args}) + return function_calls_json + if tool_call_format == "python": + + def parse_function_call(call_str: str): + node = ast.parse(call_str, mode="eval") + call_node = node.body + if isinstance(call_node, ast.Call) and isinstance(call_node.func, ast.Name): + name = call_node.func.id + arguments = {} + for keyword in call_node.keywords: + arguments[keyword.arg] = ast.literal_eval(keyword.value) + return {"name": name, "arguments": arguments} + return None - if ( - stringified_calls[0] == "[" and stringified_calls[-1] == "]" - ): # hacky way to check if string list - calls = ast.literal_eval(stringified_calls) - else: - calls = [stringified_calls] - function_calls_json = [parse_function_call(call_str) for call_str in calls] - return function_calls_json + if ( + stringified_calls[0] == "[" and stringified_calls[-1] == "]" + ): # hacky way to check if string list + calls = ast.literal_eval(stringified_calls) + else: + calls = [stringified_calls] + function_calls_json = [parse_function_call(call_str) for call_str in calls] + return function_calls_json + raise ValueError("Unknown tool calling format.") def process_function_call_output( - output_texts: List[str], finish_reasons: List[str] + output_texts: List[str], finish_reasons: List[str], tool_call_format: str ) -> Tuple[bool, List[List[openai_api_protocol.ChatToolCall]]]: """Process the potential function call results outputted by model, - according to the finish reasons. + according to the finish reasons and the tool calling format. Return whether the output has function call, and the list of tool calls. 
""" n = len(output_texts) @@ -1184,7 +1254,7 @@ def process_function_call_output( if use_function_calling: for i, output_text in enumerate(output_texts): try: - fn_json_list = convert_function_str_to_json(output_text) + fn_json_list = convert_function_str_to_json(output_text, tool_call_format) except (SyntaxError, ValueError): output_text = "Got an invalid function call output from model" finish_reasons[i] = "error" diff --git a/python/mlc_llm/serve/entrypoints/openai_entrypoints.py b/python/mlc_llm/serve/entrypoints/openai_entrypoints.py index 18a415e413..212fc1273c 100644 --- a/python/mlc_llm/serve/entrypoints/openai_entrypoints.py +++ b/python/mlc_llm/serve/entrypoints/openai_entrypoints.py @@ -222,7 +222,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: assert all(finish_reason is not None for finish_reason in finish_reasons) use_function_calling, tool_calls_list = engine_base.process_function_call_output( - output_texts, finish_reasons + output_texts, finish_reasons, async_engine.engine_config.tool_call_format ) return engine_base.wrap_chat_completion_response( From 379ce428074d83741b977b24ca51009571945d31 Mon Sep 17 00:00:00 2001 From: irfnfnkemed Date: Wed, 26 Mar 2025 21:13:51 +0800 Subject: [PATCH 08/17] [fix] Trigger CI on branch --- .github/workflows/documentation.yaml | 1 + .github/workflows/windows-build.yaml | 1 + 2 files changed, 2 insertions(+) diff --git a/.github/workflows/documentation.yaml b/.github/workflows/documentation.yaml index 6ec3492e2f..1e58dc0da7 100644 --- a/.github/workflows/documentation.yaml +++ b/.github/workflows/documentation.yaml @@ -4,6 +4,7 @@ on: push: branches: - main + - tool_call jobs: test_linux: diff --git a/.github/workflows/windows-build.yaml b/.github/workflows/windows-build.yaml index 560d2f275c..b048890c98 100644 --- a/.github/workflows/windows-build.yaml +++ b/.github/workflows/windows-build.yaml @@ -7,6 +7,7 @@ on: push: branches: - main + - tool_call pull_request: branches: - main From a5143a230cb32a7e19ca1c8f1a411040bd1e09c0 Mon Sep 17 00:00:00 2001 From: irfnfnkemed Date: Wed, 26 Mar 2025 22:17:05 +0800 Subject: [PATCH 09/17] [format] adjust the format --- python/mlc_llm/serve/engine_base.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py index 0ba14ae0f2..de46771692 100644 --- a/python/mlc_llm/serve/engine_base.py +++ b/python/mlc_llm/serve/engine_base.py @@ -7,9 +7,9 @@ import json import numbers import queue +import re import sys import threading -import re from dataclasses import dataclass from pathlib import Path from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union @@ -1148,7 +1148,7 @@ def create_completion_suffix_response( return response -def set_structural_tag_from_tools( +def set_structural_tag_from_tools( # pylint: disable=too-many-boolean-expressions tools: Optional[List[openai_api_protocol.ChatTool]], response_format: Optional[openai_api_protocol.RequestResponseFormat], tool_choice: Optional[Union[Literal["none", "auto"], Dict]], @@ -1193,7 +1193,9 @@ def set_structural_tag_from_tools( ) response_format.triggers.append("{PARA}` pattern = r"(.+?)" for match in re.finditer(pattern, stringified_calls): args: Dict = json.loads(match.group(2)) function_calls_json.append({"name": match.group(1), "arguments": args}) - return function_calls_json - if tool_call_format == "python": - + elif tool_call_format == "python": + # tool calling in python grammar def 
parse_function_call(call_str: str): node = ast.parse(call_str, mode="eval") call_node = node.body @@ -1237,8 +1239,9 @@ def parse_function_call(call_str: str): else: calls = [stringified_calls] function_calls_json = [parse_function_call(call_str) for call_str in calls] - return function_calls_json - raise ValueError("Unknown tool calling format.") + else: + raise ValueError("Unknown tool calling format.") + return function_calls_json def process_function_call_output( From 56050450aeee20172a62fbb932696d4705114086 Mon Sep 17 00:00:00 2001 From: irfnfnkemed Date: Wed, 26 Mar 2025 23:50:29 +0800 Subject: [PATCH 10/17] [fix] fix some typos --- python/mlc_llm/protocol/conversation_protocol.py | 3 ++- python/mlc_llm/protocol/openai_api_protocol.py | 3 ++- python/mlc_llm/serve/config.py | 7 ++++--- python/mlc_llm/serve/engine_base.py | 2 +- 4 files changed, 9 insertions(+), 6 deletions(-) diff --git a/python/mlc_llm/protocol/conversation_protocol.py b/python/mlc_llm/protocol/conversation_protocol.py index 4963f17a1f..25adeac07f 100644 --- a/python/mlc_llm/protocol/conversation_protocol.py +++ b/python/mlc_llm/protocol/conversation_protocol.py @@ -203,7 +203,8 @@ def set_tool_call_format_in_system_message(self): if self.tool_call_format == "default": tool_call_instruct = ( "Tool Instructions:" - f"You have access to the following tool functions: {MessagePlaceholders.FUNCTION.value}" + "You have access to the following tool functions:" + f"{MessagePlaceholders.FUNCTION.value}" "If a you choose to call a function, you should ONLY reply in the following format:" "`{parameters(JSON dict)}`" "Here is an example," diff --git a/python/mlc_llm/protocol/openai_api_protocol.py b/python/mlc_llm/protocol/openai_api_protocol.py index 3a7607ad76..da5f607a2c 100644 --- a/python/mlc_llm/protocol/openai_api_protocol.py +++ b/python/mlc_llm/protocol/openai_api_protocol.py @@ -106,7 +106,8 @@ def check_request_response_format(self) -> "RequestResponseFormat": for tag in self.tags: if set(tag.keys()) != {"begin", "schema", "end"}: raise ValueError( - f"Each tag must contain exactly 'begin', 'schema' and 'end' keys. Got keys: {list(tag.keys())}." + "Each tag must contain exactly 'begin', 'schema' and 'end' keys." + f"Got keys: {list(tag.keys())}." ) elif self.tags is not None or self.triggers is not None: raise Warning( diff --git a/python/mlc_llm/serve/config.py b/python/mlc_llm/serve/config.py index 1438f85f0a..7f24821935 100644 --- a/python/mlc_llm/serve/config.py +++ b/python/mlc_llm/serve/config.py @@ -134,9 +134,10 @@ class EngineConfig: # pylint: disable=too-many-instance-attributes A boolean indicating whether to print logging info in engine. tool_call_format: Literal["default", "python"] - The tool function call foramt. - "default" means model will call tool function in format '{parameters(JSON dict)}', - e.g. ' {"location": "Pittsburgh"} '. + The tool function call format. + "default" means model will call tool function + in format '{parameters(JSON dict)}', + e.g. ' {"location": "Pittsburgh"} '. "python" means model will call tool function in python-style format, e.g. 'wolfram_alpha.call(query="solve x^3 - 4x^2 + 6x - 24 = 0")'. 
""" diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py index de46771692..d337bee8fb 100644 --- a/python/mlc_llm/serve/engine_base.py +++ b/python/mlc_llm/serve/engine_base.py @@ -1148,7 +1148,7 @@ def create_completion_suffix_response( return response -def set_structural_tag_from_tools( # pylint: disable=too-many-boolean-expressions +def set_structural_tag_from_tools( # pylint: disable=too-many-boolean-expressions tools: Optional[List[openai_api_protocol.ChatTool]], response_format: Optional[openai_api_protocol.RequestResponseFormat], tool_choice: Optional[Union[Literal["none", "auto"], Dict]], From 3f277e661ae5aad78cafd875b0baf4a60f8aa5f4 Mon Sep 17 00:00:00 2001 From: irfnfnkemed Date: Fri, 28 Mar 2025 12:56:18 +0800 Subject: [PATCH 11/17] [feat] Add 'json' format mode for tool calls - rename the 'default' format mode to 'xml' --- .../mlc_llm/protocol/conversation_protocol.py | 30 +++++-- .../mlc_llm/protocol/openai_api_protocol.py | 3 +- python/mlc_llm/serve/config.py | 21 ++--- python/mlc_llm/serve/engine_base.py | 79 ++++++++++++++++--- 4 files changed, 104 insertions(+), 29 deletions(-) diff --git a/python/mlc_llm/protocol/conversation_protocol.py b/python/mlc_llm/protocol/conversation_protocol.py index 25adeac07f..a59aa8c386 100644 --- a/python/mlc_llm/protocol/conversation_protocol.py +++ b/python/mlc_llm/protocol/conversation_protocol.py @@ -195,27 +195,43 @@ def as_prompt(self, config=None) -> List[Any]: ) # Replace with remaining function string placeholders with empty string prompt[0] = prompt[0].replace(MessagePlaceholders.FUNCTION.value, "") - return prompt def set_tool_call_format_in_system_message(self): """Add tool function information and call format to the system message.""" - if self.tool_call_format == "default": + if self.tool_call_format == "xml": tool_call_instruct = ( "Tool Instructions:" - "You have access to the following tool functions:" - f"{MessagePlaceholders.FUNCTION.value}" + f"You have access to the following tool functions: {MessagePlaceholders.FUNCTION.value}" "If a you choose to call a function, you should ONLY reply in the following format:" - "`{parameters(JSON dict)}`" + "`\n{parameters(JSON dict)}\n`" "Here is an example," - '` {"location": "Pittsburgh"} `' + '`\n{"location": "Pittsburgh"}\n`' + "Reminder:" + "- Function calls MUST follow the specified format" + "- Required parameters MUST be specified" + ) + self.system_message += tool_call_instruct + elif self.tool_call_format == "json": + tool_call_instruct = ( + "Tool Instructions:" + f"You have access to the following tool functions: {MessagePlaceholders.FUNCTION.value}" + "If a you choose to call a function, you should ONLY reply in the following format:" + '`{"name": func_name, "parameters": parameters(JSON dict)}`' + "Here is an example," + '`{"name": "get_time", "parameters": {"location": "Pittsburgh"} }`' "Reminder:" "- Function calls MUST follow the specified format" "- Required parameters MUST be specified" ) self.system_message += tool_call_instruct elif self.tool_call_format == "python": - raise ValueError("TODO: Not supported yet.") + tool_call_instruct = ( + "Tool Instructions:" + f"- You have access to the following tool functions: {MessagePlaceholders.FUNCTION.value}" + "- Required parameters MUST be specified" + ) + self.system_message += tool_call_instruct else: raise ValueError("Unknown tool calling format.") diff --git a/python/mlc_llm/protocol/openai_api_protocol.py b/python/mlc_llm/protocol/openai_api_protocol.py index 
da5f607a2c..3a7607ad76 100644 --- a/python/mlc_llm/protocol/openai_api_protocol.py +++ b/python/mlc_llm/protocol/openai_api_protocol.py @@ -106,8 +106,7 @@ def check_request_response_format(self) -> "RequestResponseFormat": for tag in self.tags: if set(tag.keys()) != {"begin", "schema", "end"}: raise ValueError( - "Each tag must contain exactly 'begin', 'schema' and 'end' keys." - f"Got keys: {list(tag.keys())}." + f"Each tag must contain exactly 'begin', 'schema' and 'end' keys. Got keys: {list(tag.keys())}." ) elif self.tags is not None or self.triggers is not None: raise Warning( diff --git a/python/mlc_llm/serve/config.py b/python/mlc_llm/serve/config.py index 7f24821935..60ad50ab4f 100644 --- a/python/mlc_llm/serve/config.py +++ b/python/mlc_llm/serve/config.py @@ -130,16 +130,19 @@ class EngineConfig: # pylint: disable=too-many-instance-attributes "hybrid" means the hybrid prefill or split-fuse, so that decode step will be converted into prefill. - verbose : bool - A boolean indicating whether to print logging info in engine. - - tool_call_format: Literal["default", "python"] - The tool function call format. - "default" means model will call tool function - in format '{parameters(JSON dict)}', - e.g. ' {"location": "Pittsburgh"} '. + tool_call_format : Literal["xml", "json", "python"] + The tool function call foramt. + "xml" means model will call tool function in xml style format + '\n{parameters(JSON dict)}\n', + e.g. '\n{"location": "Pittsburgh"}\n'. + "json" means model will call tool function in json style format + '{"name": func_name, "parameters": parameters(JSON dict)}', + e.g. '{"name": "get_time", "parameters": {"location": "Pittsburgh"}}'. "python" means model will call tool function in python-style format, e.g. 'wolfram_alpha.call(query="solve x^3 - 4x^2 + 6x - 24 = 0")'. + + verbose : bool + A boolean indicating whether to print logging info in engine. 
""" model: Optional[str] = None @@ -165,8 +168,8 @@ class EngineConfig: # pylint: disable=too-many-instance-attributes prefix_cache_mode: Literal["disable", "radix"] = "radix" prefix_cache_max_num_recycling_seqs: Optional[int] = None prefill_mode: Literal["chunked", "hybrid"] = "hybrid" + tool_call_format: Literal["xml", "json", "python"] = "xml" verbose: bool = True - tool_call_format: Literal["default", "python"] = "default" def asjson(self) -> str: """Return the config in string of JSON format.""" diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py index d337bee8fb..b618feff3c 100644 --- a/python/mlc_llm/serve/engine_base.py +++ b/python/mlc_llm/serve/engine_base.py @@ -646,6 +646,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals engine_config.mode = mode self._ffi["reload"](engine_config.asjson()) self.engine_config = EngineConfig.from_json(self._ffi["get_complete_engine_config"]()) + self.engine_config.tool_call_format = engine_config.tool_call_format self.max_input_sequence_length = min( self.engine_config.max_single_sequence_length, self.engine_config.max_total_sequence_length, @@ -1148,7 +1149,7 @@ def create_completion_suffix_response( return response -def set_structural_tag_from_tools( # pylint: disable=too-many-boolean-expressions +def set_structural_tag_from_tools( tools: Optional[List[openai_api_protocol.ChatTool]], response_format: Optional[openai_api_protocol.RequestResponseFormat], tool_choice: Optional[Union[Literal["none", "auto"], Dict]], @@ -1169,7 +1170,9 @@ def set_structural_tag_from_tools( # pylint: disable=too-many-boolean-expressio response_format.tags = [] response_format.triggers = [] - if tool_call_format == "default": + if tool_call_format == "xml": + begin_format = "\n" + end = "\n" for tool in tools: if tool.function.strict and ( tool_choice is None @@ -1186,12 +1189,37 @@ def set_structural_tag_from_tools( # pylint: disable=too-many-boolean-expressio } response_format.tags.append( { - "begin": f"", + "begin": begin_format.format(func_name=tool.function.name), "schema": json.dumps(schema), - "end": "", + "end": end, } ) response_format.triggers.append(" List[Union[Dict, None]]: """Convert a (possibly list) of function call string to a list of json objects. 
Return None for invalid function call string.""" - function_calls_json = [] - if tool_call_format == "default": - # tool calling in format `{PARA}` - pattern = r"(.+?)" - for match in re.finditer(pattern, stringified_calls): - args: Dict = json.loads(match.group(2)) - function_calls_json.append({"name": match.group(1), "arguments": args}) + if tool_call_format == "xml": + # tool calling in format `\n{PARA}\n` + pattern = r"\n(.*?)\n" + matches = re.findall(pattern, stringified_calls, re.DOTALL) + for func_name, args_str in matches: + args: Dict = json.loads(args_str) + function_calls_json.append({"name": func_name, "arguments": args}) + elif tool_call_format == "json": + # tool calling in format `{"name": func_name, "parameters": parameters(JSON dict)}` + starts = [-1] + while True: + index = stringified_calls.find('{"name":', starts[-1] + 1) + if index == -1: + break + else: + starts.append(index) + starts.append(len(stringified_calls)) + for i in range(1, len(starts) - 1): + cnt = 1 + quote = False + for j in range(starts[i] + 1, starts[i + 1]): + if stringified_calls[j] == '"': + quote = not quote + elif not quote: + if stringified_calls[j] == "{": + cnt += 1 + elif stringified_calls[j] == "}": + cnt -= 1 + if cnt == 0: + func_call: Dict = json.loads(stringified_calls[starts[i] : j + 1]) + assert "name" in func_call + assert "parameters" in func_call + function_calls_json.append( + {"name": func_call["name"], "arguments": func_call["parameters"]} + ) + break elif tool_call_format == "python": # tool calling in python grammar def parse_function_call(call_str: str): From 5c998316fdf691dc3ef522d4f80b565897d8fc1a Mon Sep 17 00:00:00 2001 From: irfnfnkemed Date: Tue, 1 Apr 2025 01:24:04 +0800 Subject: [PATCH 12/17] [feat] initial version of the stag-eval script --- .github/workflows/documentation.yaml | 1 + .github/workflows/windows-build.yaml | 1 + cpp/serve/engine.cc | 18 +- cpp/tokenizers/streamer.cc | 5 +- cpp/tokenizers/tokenizers.cc | 5 + cpp/tokenizers/tokenizers.h | 3 + eval/__init__.py | 1 + eval/__main__.py | 474 +++++++++++ eval/api_endpoint.py | 215 +++++ eval/dataset.py | 500 ++++++++++++ eval/request_processor.py | 738 ++++++++++++++++++ eval/request_record.py | 209 +++++ .../mlc_llm/protocol/conversation_protocol.py | 7 +- .../mlc_llm/protocol/openai_api_protocol.py | 2 - python/mlc_llm/serve/config.py | 8 +- python/mlc_llm/serve/engine_base.py | 14 +- .../serve/entrypoints/openai_entrypoints.py | 1 - 17 files changed, 2181 insertions(+), 21 deletions(-) create mode 100644 eval/__init__.py create mode 100644 eval/__main__.py create mode 100644 eval/api_endpoint.py create mode 100644 eval/dataset.py create mode 100644 eval/request_processor.py create mode 100644 eval/request_record.py diff --git a/.github/workflows/documentation.yaml b/.github/workflows/documentation.yaml index 1e58dc0da7..9b0fc4eaee 100644 --- a/.github/workflows/documentation.yaml +++ b/.github/workflows/documentation.yaml @@ -5,6 +5,7 @@ on: branches: - main - tool_call + - eval jobs: test_linux: diff --git a/.github/workflows/windows-build.yaml b/.github/workflows/windows-build.yaml index b048890c98..a9b10039e2 100644 --- a/.github/workflows/windows-build.yaml +++ b/.github/workflows/windows-build.yaml @@ -8,6 +8,7 @@ on: branches: - main - tool_call + - eval pull_request: branches: - main diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index 0db12299a1..27765fb65f 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -35,6 +35,7 @@ #include "request.h" #include "request_state.h" #include 
"sampler/sampler.h" +#include "xgrammar/tokenizer_info.h" namespace mlc { namespace llm { @@ -64,6 +65,9 @@ inline std::optional GetTokenizerInfo(const picojson::object& mod if (tokenizer_info_obj.count("strip_space_in_decode")) { info->strip_space_in_decode = tokenizer_info_obj.at("strip_space_in_decode").get(); } + if (model_config.count("vocab_size")) { + info->vocab_size = model_config.at("vocab_size").get(); + } return TokenizerInfo(info); } @@ -464,10 +468,17 @@ class EngineImpl : public Engine { } // - Initialize tokenizer and grammar - n->tokenizer_ = Tokenizer::FromPath(engine_config->model, GetTokenizerInfo(model_configs[0])); + std::optional info = GetTokenizerInfo(model_configs[0]); + n->tokenizer_ = Tokenizer::FromPath(engine_config->model, info); n->token_table_ = n->tokenizer_->PostProcessedTokenTable(); - // TODO: check 'vocab_size' of TokenizerInfo - n->grammar_compiler_ = xgrammar::GrammarCompiler(xgrammar::TokenizerInfo(n->token_table_)); + int64_t vocab_size = n->tokenizer_->GetVocabSize(); + if (info.has_value() && info.value()->vocab_size != 0) { + vocab_size = info.value()->vocab_size; + } + n->grammar_compiler_ = xgrammar::GrammarCompiler(xgrammar::TokenizerInfo(n->token_table_, xgrammar::VocabType::RAW, vocab_size)); + + + // - Create the logit processor and sampler, and // the DraftTokenWorkspaceManager for speculative decoding. int max_num_tokens = engine_config->max_num_sequence; @@ -977,7 +988,6 @@ class EngineImpl : public Engine { * is not JSON, return std::nullopt. */ std::optional GetGrammarFromResponseFormat( const ResponseFormat& response_format) { - // TODO: add other grammar type if (response_format.type == "text") { return std::nullopt; } else if (response_format.type == "json_object") { diff --git a/cpp/tokenizers/streamer.cc b/cpp/tokenizers/streamer.cc index 2901834f3b..7aaacd6b59 100644 --- a/cpp/tokenizers/streamer.cc +++ b/cpp/tokenizers/streamer.cc @@ -193,7 +193,10 @@ void StopStrHandlerObj::Put(int32_t token_id, std::vector* return_token } CHECK(!stop_triggered_) << "Cannot put new token when already stopped."; - + // TODO: find better solution + if (token_id >= static_cast(token_table_.size())){ + token_id = 0; + } ICHECK_LT(token_id, static_cast(token_table_.size())); const std::string& token = token_table_[token_id]; pending_token_ids_.push_back(token_id); diff --git a/cpp/tokenizers/tokenizers.cc b/cpp/tokenizers/tokenizers.cc index 13ae547d72..2c0d1f4e96 100644 --- a/cpp/tokenizers/tokenizers.cc +++ b/cpp/tokenizers/tokenizers.cc @@ -30,6 +30,7 @@ String TokenizerInfoNode::AsJSONString() const { obj["token_postproc_method"] = picojson::value(token_postproc_method); obj["prepend_space_in_encode"] = picojson::value(prepend_space_in_encode); obj["strip_space_in_decode"] = picojson::value(strip_space_in_decode); + obj["vocab_size"] = picojson::value(vocab_size); return picojson::value(obj).serialize(false); } @@ -54,6 +55,10 @@ TokenizerInfo TokenizerInfo::FromJSONString(String json_string) { ICHECK(obj.at("strip_space_in_decode").is()); n->strip_space_in_decode = obj.at("strip_space_in_decode").get(); } + if (obj.count("vocab_size")) { + ICHECK(obj.at("vocab_size").is()); + n->vocab_size = obj.at("vocab_size").get(); + } return TokenizerInfo(n); } diff --git a/cpp/tokenizers/tokenizers.h b/cpp/tokenizers/tokenizers.h index 2b1847f524..bca2a0e50c 100644 --- a/cpp/tokenizers/tokenizers.h +++ b/cpp/tokenizers/tokenizers.h @@ -43,6 +43,9 @@ class TokenizerInfoNode : public Object { bool prepend_space_in_encode = false; /*! 
\brief Whether to strip the first space during decoding. */ bool strip_space_in_decode = false; + /*! \brief The vocab_size in config.json (length of logits).This may be bigger than the vocabulary + * size. The value will be 0 if not set.*/ + int64_t vocab_size = 0; String AsJSONString() const; diff --git a/eval/__init__.py b/eval/__init__.py new file mode 100644 index 0000000000..f8fc6a6220 --- /dev/null +++ b/eval/__init__.py @@ -0,0 +1 @@ +"""Subdirectory of bench.""" diff --git a/eval/__main__.py b/eval/__main__.py new file mode 100644 index 0000000000..6b045cb9d4 --- /dev/null +++ b/eval/__main__.py @@ -0,0 +1,474 @@ +"""MLC LLM benchmark main entrance""" + +import functools +import json +import random +from typing import Any, Dict, List, Optional, Tuple + +from mlc_llm.protocol.openai_api_protocol import ChatToolCall +import numpy as np +import requests +from transformers import AutoTokenizer # pylint: disable=import-error + +import mlc_llm +from api_endpoint import SUPPORTED_BACKENDS, create_api_endpoint +from dataset import SUPPORTED_DATASET, Dataset, GorillaDataset, create_dataset +from request_processor import ( + MetricAnalyzer, + RequestProcessor, + create_pipelines, +) +from request_record import ( + RequestRecord, + convert_reports_to_df, + generate_metrics_summary, + pretty_print_report, +) +from mlc_llm.cli.serve import EngineConfigOverride +from mlc_llm.serve import EngineConfig +from mlc_llm.support import argparse, logging + +logging.enable_logging() +logger = logging.getLogger(__name__) + + +def _parse_num_concurrent_requests(num_str: Optional[str]) -> Optional[List[int]]: + if num_str is None: + return None + numbers = num_str.split(",") + if any(not number.isdigit() for number in numbers): + raise ValueError(f"Unrecognized num_concurrent_requests list: {numbers}") + return list(int(number) for number in numbers) + + +def _parse_request_rate(request_rate_str: Optional[str]) -> Optional[List[np.float32]]: + if request_rate_str is None: + return None + request_rates = request_rate_str.split(",") + results = [] + for rate_str in request_rates: + request_rate = float(rate_str) + if request_rate <= 0: + raise ValueError(f"Invalid request rate {request_rate}") + results.append(np.float32(request_rate)) + return results + + +def _parse_mlc_engine_config(config_str: Optional[str]) -> EngineConfig: + if config_str is None: + return None + engine_config_override = EngineConfigOverride.from_str(config_str) + return EngineConfig( + tensor_parallel_shards=engine_config_override.tensor_parallel_shards, + max_num_sequence=engine_config_override.max_num_sequence, + max_total_sequence_length=engine_config_override.max_total_seq_length, + prefill_chunk_size=engine_config_override.prefill_chunk_size, + sliding_window_size=engine_config_override.sliding_window_size, + attention_sink_size=engine_config_override.attention_sink_size, + max_history_size=engine_config_override.max_history_size, + gpu_memory_utilization=engine_config_override.gpu_memory_utilization, + spec_draft_length=engine_config_override.spec_draft_length, + prefill_mode=engine_config_override.prefill_mode, + prefix_cache_max_num_recycling_seqs=engine_config_override.prefix_cache_max_num_recycling_seqs, # pylint: disable=line-too-long + prefix_cache_mode=engine_config_override.prefix_cache_mode, + ) + + +def _launch_mlc_server(args: argparse.argparse.Namespace): + return mlc_llm.serve.PopenServer( + model=args.tokenizer, + mode="server", + model_lib=args.mlc_model_lib, + enable_tracing=False, + host=args.host, + 
port=args.port, + engine_config=args.mlc_engine_config, + ) + + +def run_pipeline( + pipeline: RequestProcessor, + dataset: Dataset, + args: argparse.argparse.Namespace, +) -> Tuple[Dict[str, Any], List[RequestRecord]]: + """Run the pipeline with the given dataset and args. Return the benchmark report dict.""" + random.seed(args.seed) + np.random.seed(args.seed) + request_records = dataset.generate_request_records( + args.input_len, + args.output_len, + args.input_len_std, + args.output_len_std, + ) + request_records = pipeline(request_records) + num_total_requests = ( + args.num_requests if not args.per_gpu_workload else args.num_requests * args.num_gpus + ) + assert len(request_records) == num_total_requests + sorted_requests: List[RequestRecord] = [None] * num_total_requests + for request_record in request_records: + assert request_record.request_id is not None + assert sorted_requests[request_record.request_id] is None + sorted_requests[request_record.request_id] = request_record + + report = generate_metrics_summary(request_records, num_total_requests, args.num_gpus) + + return report, sorted_requests + + +def query_mlc_server_metrics(host: str, port: int): + """Try to get the MLC server metrics whenever it exists.""" + try: + r = requests.post(f"http://{host}:{port}/debug/dump_engine_metrics", json={}, timeout=10) + if r.status_code == 200: + print(f"MLC server metrics: {r.json()}") + except Exception: # pylint: disable=broad-exception-caught + pass + +def convert_calls_to_json(calls: List[ChatToolCall])-> List[Dict[str, Any]]: + """Convert the list of ChatToolCall to a list of dict.""" + result = [] + for call in calls: + call_dict = { + "function": {"name": call.function.name, "arguments": call.function.arguments} + } + result.append(call_dict) + return result + + +def check_acc(args: argparse.argparse.Namespace, dataset: GorillaDataset): + request_records = [] + final_output = {"fail_format": [], "fail_call": []} + with open(args.generate_output, "r") as f: + request_records = json.load(f) + count = 0 + for request in request_records: + info = dataset.gorilla_data[request["id"]] + if info["source"] == "BFCL_v3_simple.json": + count += 1 + if "call" not in request: + final_output["fail_format"].append(request["id"]) + final_output["fail_call"].append(request["id"]) + continue + format, call = dataset.check_simple(request["call"][0], info["tool"][0], info["ideal_call"][0]) + if not format: + final_output["fail_format"].append(request["id"]) + if not call: + final_output["fail_call"].append(request["id"]) + correct_format = count - len(final_output["fail_format"]) + correct_call = count - len(final_output["fail_call"]) + final_output["format_accuracy"] = correct_format / count + final_output["call_accuracy"] = correct_call / count + print(f"correct_format: {correct_format}/{count}, correct_call: {correct_call}/{count}") + with open(args.final_output, "w", encoding="utf-8") as file: + json.dump(final_output, file, indent=4) + + + +def main(args: argparse.argparse.Namespace): + """Main benchmark entrance.""" + mlc_server = None + if args.mlc_model_lib: + mlc_server = _launch_mlc_server(args) + if args.num_requests <= 0: + raise ValueError("Number of requests to benchmark must be positive.") + + def _main(): + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) + dataset: GorillaDataset = create_dataset(args, tokenizer, args.use_stag) + f_create_api_endpoint = functools.partial(create_api_endpoint, args) + pipelines = create_pipelines(args, f_create_api_endpoint, dataset) + 
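        # A sketch of what one dataset.gorilla_data entry is expected to look like at this
        # point (keys follow GorillaDataset in eval/dataset.py; the concrete values below are
        # hypothetical and only for illustration -- the loop further down and check_acc() read
        # entries of this shape by request id):
        _example_gorilla_entry = {
            "id": 0,
            "question": [{"role": "user", "content": "What time is it in Pittsburgh?"}],
            "tool": [
                {
                    "type": "function",
                    "function": {
                        "name": "get_time",
                        "parameters": {
                            "type": "object",
                            "properties": {"location": {"type": "string"}},
                            "required": ["location"],
                        },
                    },
                }
            ],
            "ideal_call": [{"name": "get_time", "arguments": {"location": "Pittsburgh"}}],
            "source": "BFCL_v3_simple.json",
            "num_tokens": 16,
        }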
reports = [] + alltime_records = {} + store_record = [] + for i, pipeline in enumerate(pipelines): + report, request_records = run_pipeline(pipeline, dataset, args) + + for request in request_records: + info = dataset.gorilla_data[request.request_id] + if info["source"] == "BFCL_v3_simple.json": + store_record.append({"id": request.request_id}) + if len(request.chat_cmpl.messages) == 2: + store_record[-1]["output"] = request.chat_cmpl.messages[1].content + if len(request.chat_cmpl.messages) == 2 and request.chat_cmpl.messages[1].tool_calls is not None: + store_record[-1]["call"] = convert_calls_to_json(request.chat_cmpl.messages[1].tool_calls) + + with open(args.generate_output, "w") as f: + json.dump(store_record, f, indent=4) + + exec_feature = ( + json.dumps(report["exec_feature"]) + if report["exec_feature"] is not None + else f"pipeline{i}" + ) + alltime_records[exec_feature] = [ + request_record.model_dump() for request_record in request_records + ] + reports.append(report) + pretty_print_report(report) + query_mlc_server_metrics(args.host, args.port) + + # Construct data frame + df = convert_reports_to_df(reports) + print(df) + df.to_csv(args.bench_output, index=False) + logger.info("Benchmark results dumped to file %s", args.bench_output) + if args.debug_dump: + debug_dump_filepath = ( + args.bench_output[:-4] if args.bench_output.endswith(".csv") else args.bench_output + ) + "_debug_dump.log" + with open(debug_dump_filepath, "w", encoding="utf-8") as file: + json.dump(alltime_records, file, indent=4) + logger.info("Debug log dumped to file %s", debug_dump_filepath) + + check_acc(args, dataset) + + if mlc_server is not None: + with mlc_server: + _main() + else: + _main() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("MLC LLM benchmark") + + parser.add_argument( + "--dataset", + type=str, + choices=SUPPORTED_DATASET, + help=f"The benchmark dataset kind. Supporting {SUPPORTED_DATASET}", + ) + parser.add_argument( + "--dataset-path", + type=str, + help="The dataset file path.", + ) + parser.add_argument( + "--api-endpoint", + type=str, + choices=SUPPORTED_BACKENDS, + default="openai", + help="The API endpoint API for benchmarking.", + ) + parser.add_argument( + "--tokenizer", + type=str, + required=True, + help="The path of the tokenizer directory.", + ) + parser.add_argument( + "--num-gpus", + type=int, + required=True, + help="The number of GPUs used by the server. " + "We need this to better analyze the throughput per GPU.", + ) + parser.add_argument( + "--num-requests", + type=int, + required=True, + help="The number of requests for benchmark.", + ) + parser.add_argument( + "--num-warmup-requests", + type=int, + help="The number of requests for warmup. " + "It is optional when fixing the number of concurrent requests, and is required otherwise.", + ) + parser.add_argument( + "--per-gpu-workload", + default=False, + action="store_true", + help='When set to True, the specified "num_concurrent_requests"/"request_rate" ' + "denote the workload **per GPU**, which means that the real values of " + '"num_concurrent_requests"/"request_rate" used in benchmark' + 'will be multiplied by "num_gpus".', + ) + parser.add_argument( + "--num-concurrent-requests", + type=_parse_num_concurrent_requests, + help="The number(s) of concurrent requests to benchmark. " + 'It can be either one integer or a list of integer separated by commas(","). 
' + "When specified, for each integer, the benchmark keeps these many consistent " + "number of concurrently running requests.", + ) + parser.add_argument( + "--request-rate", + type=_parse_request_rate, + help="The request rate(s) denoting the number of new requests each second. " + 'It can be either one float number (or "inf") or a list of numbers separated ' + 'by commas(","). ' + "When specified, the benchmark sends these many new requests each second. " + 'If it is "inf", all requests will be sent together at once.', + ) + parser.add_argument( + "--replay-timestamp-scale", + type=float, + help="The timestamp scale when replaying the timestamps in a dataset. " + 'The dataset replay mode is enabled when neither "--num-concurrent-requests" and ' + '"--request-rate" is specified. ' + "The scale is 1 by default in the replay mode.", + ) + parser.add_argument( + "--input-len", + type=int, + help="The benchmark request average input length. Default to None, " + "which means the request input length depends on the dataset being used.", + ) + parser.add_argument( + "--input-len-std", + type=float, + default=0, + help="The benchmark request input length standard deviation. Default to 0.", + ) + parser.add_argument( + "--output-len", + type=int, + help="The benchmark request average output length. Default to None, " + "which means the request output length depends on the dataset being used.", + ) + parser.add_argument( + "--output-len-std", + type=float, + default=0, + help="The benchmark request output length standard deviation. Default to 0.", + ) + parser.add_argument( + "--stream", + action="store_true", + default=False, + help="Whether to benchmark stream responses. " + "When not enabled, metrics such as time-to-first-token (TTFT) will not be available. " + "Default to False.", + ) + parser.add_argument( + # NOTE: The current implementation of server metrics still has some issues that need fixes, + # which makes it not work to include server metrics. + "--include-server-metrics", + action="store_true", + help="Whether to also benchmark the server side request metrics. " + "This option is only available when benchmarking MLC server.", + ) + parser.add_argument( + "--host", + type=str, + required=True, + help="The host address of the backend API.", + ) + parser.add_argument( + "--port", + type=int, + required=True, + help="The port of the backend API.", + ) + parser.add_argument( + "--timeout", + type=float, + default=3 * 60 * 60, + help="The timeout limit of each request.", + ) + parser.add_argument( + "--seed", + type=int, + default=0, + help="The random number seed. Default to 0.", + ) + parser.add_argument( + "--temperature", + type=float, + default=1.0, + help="The temperature value for logit adjustment. Default to 1.", + ) + parser.add_argument( + "--top-p", + type=float, + default=1.0, + help="The top-p value for sampling. Default to 1.", + ) + parser.add_argument( + "--ignore-eos", + default=False, + action="store_true", + help='Whether to set the "ignore_eos" field.', + ) + parser.add_argument( + "--apply-chat-template", + default=False, + action="store_true", + help="Whether to apply chat template to the request input text. 
" + 'It is not supported when "--input-len" is specified.', + ) + parser.add_argument( + "--num-process-workers", + type=int, + help="The number of parallel process workers to send the requests.", + ) + parser.add_argument( + "--disable-tqdm", + action="store_true", + help="Whether to disable showing progress bar with tqdm during benchmarking.", + ) + parser.add_argument( + "--max-schedule-gap", + type=float, + default=0.5, + help="The maximum allowed delay between the scheduled time in seconds.", + ) + parser.add_argument( + "--mlc-model-lib", + type=str, + help="The model lib path when benchmarking MLC serve. " + "When specified, the server is automatic launched and no external server launch is needed.", + ) + parser.add_argument( + "--mlc-engine-config", + type=_parse_mlc_engine_config, + help="The engine config used when launch MLC server.", + ) + parser.add_argument( + "--cuda-profile", + default=False, + action="store_true", + help="Whether to enable cuda profile on server. " + "The --mlc-model-lib path should be provided when enabling this option.", + ) + parser.add_argument( + "--debug-dump", + default=False, + action="store_true", + help="Whether to dump all request record raw data to file.", + ) + parser.add_argument( + "--multi-round", + default=False, + action="store_true", + help="Whether to chat like multi round conversion with history log each request. " + "Only enabled when benchmarked with fixed concurrent request mode." + "The --num-concurrent-requests should be provided when enabling this option.", + ) + parser.add_argument( + "--bench-output", + "-o", + type=str, + required=True, + help="The path of the output file where to dump the benchmark results.", + ) + parser.add_argument( + "--generate-output", + type=str, + required=True, + help="The path of the generated output file where to dump the output results.", + ) + parser.add_argument( + "--final-output", + type=str, + required=True, + help="The path of the final output file where to dump the final accuracy results.", + ) + parser.add_argument( + "--use-stag", + action="store_true", + help="Whether to set stag.", + ) + main(parser.parse_args()) diff --git a/eval/api_endpoint.py b/eval/api_endpoint.py new file mode 100644 index 0000000000..198fd47a08 --- /dev/null +++ b/eval/api_endpoint.py @@ -0,0 +1,215 @@ +"""MLC LLM bench backends""" + +import argparse +import json +import os +import time +import traceback +from typing import Optional + +from mlc_llm.protocol.openai_api_protocol import ChatCompletionMessage +from typing_extensions import Self + +from request_record import Metrics, RequestRecord, ServerMetrics +from mlc_llm.support import logging + +logger = logging.getLogger(__name__) + + +class APIEndPoint: + """Manages the sending of requests to a specified API endpoint and gathers + inference statistics. 
+ """ + + def __init__(self, include_server_metrics: bool = False) -> None: + self.include_server_metrics = include_server_metrics + + async def __aenter__(self) -> Self: + return self + + async def __aexit__(self, exc_type, exc_value, tb) -> None: + pass + + async def __call__(self, request: RequestRecord) -> RequestRecord: + raise NotImplementedError() + + +class OpenAIChatEndPoint(APIEndPoint): + """The backend of sending HTTP requests in OpenAI API through "v1/chat/completions".""" + + def __init__( # pylint: disable=too-many-arguments + self, + host: str, + port: int, + timeout: Optional[float] = None, + include_server_metrics: bool = False, + ) -> None: + super().__init__(include_server_metrics=include_server_metrics) + + import aiohttp # pylint: disable=import-outside-toplevel,import-error + + self.timeout = timeout + self.client: aiohttp.ClientSession = None + self.url = f"http://{host}:{port}/v1/chat/completions" + self.headers = {"Content-Type": "application/json"} + if os.getenv("MLC_LLM_API_KEY"): + self.headers["Authorization"] = f"Bearer {os.getenv('MLC_LLM_API_KEY')}" + + async def __aenter__(self) -> Self: + import aiohttp # pylint: disable=import-outside-toplevel,import-error + + self.client = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(self.timeout)) + return self + + async def __aexit__(self, exc_type, exc_value, tb) -> None: + await self.client.close() + + async def __call__( # pylint: disable=too-many-branches,too-many-statements,too-many-locals + self, request_record: RequestRecord + ) -> RequestRecord: + payload = request_record.chat_cmpl.model_dump() + if self.timeout is not None and "timeout" not in payload: + payload["timeout"] = self.timeout + if self.include_server_metrics: + if "stream_options" not in payload or payload["stream_options"] is None: + payload["stream_options"] = {"include_usage": True} + else: + payload["stream_options"]["include_usage"] = True + if ( + request_record.chat_cmpl.debug_config is not None + and request_record.chat_cmpl.debug_config.ignore_eos + ): + payload["ignore_eos"] = True + + generated_text = "" + first_chunk_output_str = "" + time_to_first_token_s = None + start_time = time.monotonic() + server_metrics = None + + try: + async with self.client.post(self.url, json=payload, headers=self.headers) as response: + assert response.status == 200, await response.text() + if payload["stream"]: + async for chunk in response.content: + chunk = chunk.strip() + if not chunk or chunk == b"\n": + continue + # Get rid of the prefix "data: " and suffix "\n" + raw_data = chunk[6:].strip() + if raw_data == b"[DONE]": + continue + data = json.loads(raw_data) + if not data["choices"]: + continue + delta = data["choices"][0]["delta"] + content = delta.get("content", None) + if content is not None and not time_to_first_token_s: + time_to_first_token_s = time.monotonic() - start_time + first_chunk_output_str = content + if self.include_server_metrics and data["usage"] is not None: + # fmt: off + # pylint: disable=line-too-long + server_metrics = ServerMetrics( + input_tokens=data["usage"]["extra"]["prompt_tokens"], + prefill_tokens=data["usage"]["extra"]["prefill_tokens"], + output_tokens=data["usage"]["extra"]["completion_tokens"], + end_to_end_latency_s=data["usage"]["extra"]["end_to_end_latency_s"], + prefill_tokens_per_s=data["usage"]["extra"]["prefill_tokens_per_s"], + inter_token_latency_s=data["usage"]["extra"]["inter_token_latency_s"], + time_per_output_token_s=1 / data["usage"]["extra"]["decode_tokens_per_s"], + 
time_to_first_token_s=data["usage"]["extra"]["ttft_s"], + ) + # pylint: enable=line-too-long + # fmt: on + + if content is not None: + generated_text += content + else: + data = await response.json() + generated_text = data["choices"][0]["message"]["content"] + if self.include_server_metrics and data["usage"] is not None: + # fmt: off + # pylint: disable=line-too-long + server_metrics = ServerMetrics( + input_tokens=data["usage"]["extra"]["prompt_tokens"], + prefill_tokens=data["usage"]["extra"]["prefill_tokens"], + output_tokens=data["usage"]["extra"]["completion_tokens"], + end_to_end_latency_s=data["usage"]["extra"]["end_to_end_latency_s"], + prefill_tokens_per_s=data["usage"]["extra"]["prefill_tokens_per_s"], + inter_token_latency_s=data["usage"]["extra"]["inter_token_latency_s"], + time_per_output_token_s=1 / data["usage"]["extra"]["decode_tokens_per_s"], + time_to_first_token_s=data["usage"]["extra"]["ttft_s"], + ) + # pylint: enable=line-too-long + # fmt: on + except Exception: # pylint: disable=broad-except + error_msg = "API endpoint errored when sending request: " + traceback.format_exc() + logger.info(error_msg) + finish_time = time.monotonic() + request_record.output_str = generated_text + request_record.first_chunk_output_str = first_chunk_output_str + request_record.metrics = Metrics( + success=False, + start_time=start_time, + finish_time=finish_time, + end_to_end_latency_s=finish_time - start_time, + input_tokens=request_record.metrics.input_tokens, + time_to_first_token_s=time_to_first_token_s, + server_metrics=server_metrics, + exec_feature=request_record.metrics.exec_feature, + ) + request_record.error_msg = error_msg + return request_record + + finish_time = time.monotonic() + request_record.output_str = generated_text + request_record.first_chunk_output_str = first_chunk_output_str + success = True + error_msg = None + if generated_text is None: + if data["choices"][0]["finish_reason"] == "tool_calls": + if data["choices"][0]["message"]["tool_calls"] is None or len(data["choices"][0]["message"]["tool_calls"]) == 0: + success = False + error_msg = "Invalid tool call." + else: + success = True + else: + success = False + error_msg = "Invalid response." + else: + if len(generated_text) == 0: + success = False + error_msg = "Empty generated text." 
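        # Illustrative sketch of the non-streaming choice object that the branch above assumes
        # when the model answers with a tool call (values are hypothetical; only the keys
        # accessed in this method are relied upon, and additional keys may be present):
        _example_tool_call_choice = {
            "finish_reason": "tool_calls",
            "message": {
                "role": "assistant",
                "content": None,
                "tool_calls": [
                    {"function": {"name": "get_time", "arguments": {"location": "Pittsburgh"}}}
                ],
            },
        }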
+ + message = ChatCompletionMessage( + role=data["choices"][0]["message"]["role"], + content=generated_text, + function_call=data["choices"][0]["message"].get("function_call", None), + tool_calls=data["choices"][0]["message"].get("tool_calls", None), + tool_call_id=data["choices"][0]["message"].get("tool_call_id", None), + ) + request_record.chat_cmpl.messages.append(message) + request_record.metrics = Metrics( + success=success, + start_time=start_time, + finish_time=finish_time, + end_to_end_latency_s=finish_time - start_time, + input_tokens=request_record.metrics.input_tokens, + time_to_first_token_s=time_to_first_token_s, + server_metrics=server_metrics, + exec_feature=request_record.metrics.exec_feature, + ) + request_record.error_msg = error_msg + return request_record + +SUPPORTED_BACKENDS = [ + "openai-chat", +] + + +def create_api_endpoint(args: argparse.Namespace) -> APIEndPoint: + """Create an API endpoint instance with regard to the specified endpoint kind.""" + if args.api_endpoint == "openai-chat": + return OpenAIChatEndPoint(args.host, args.port, args.timeout, args.include_server_metrics) + raise ValueError(f'Unrecognized endpoint "{args.api_endpoint}"') diff --git a/eval/dataset.py b/eval/dataset.py new file mode 100644 index 0000000000..7807383435 --- /dev/null +++ b/eval/dataset.py @@ -0,0 +1,500 @@ +"""MLC LLM benchmark dataset classes""" + +import argparse +import json +import os +import requests +import random +from datetime import datetime +import re +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import pandas as pd # pylint: disable=import-error +from datasets import load_dataset # pylint: disable=import-error +from transformers import AutoTokenizer # pylint: disable=import-error + +from request_record import GroupedRequestRecord, Metrics, RequestRecord +from mlc_llm.protocol.openai_api_protocol import ( + ChatCompletionMessage, + ChatCompletionRequest, + ChatToolCall, + DebugConfig, +) + + +class Dataset: # pylint: disable=too-few-public-methods + """The dataset base class.""" + + # We set a truncation limit of 100k. + truncate_length = int(1e5) + # For some that datasets (e.g., dataset that has shared common prefix), + # we need fake warmup requests to avoid prefilling common prefixes to the engine. + require_fake_warmup: bool = False + # Whether the dataset contains timestamps already. + # If the dataset comes with timestamps, the benchmark can just replay + # the requests according to their timestamps. + timestamp_available: bool = False + + def generate_request_records( + self, + input_len: Optional[int], + output_len: Optional[int], + input_len_std: float = 0.0, + output_len_std: float = 0.0, + ) -> List[RequestRecord]: + """Get the raw unprocessed request records of the dataset.""" + raise NotImplementedError() + +GORILLA_TO_OPENAPI = { + "integer": "integer", + "number": "number", + "float": "number", + "string": "string", + "boolean": "boolean", + "bool": "boolean", + "array": "array", + "list": "array", + "dict": "object", + "object": "object", + "tuple": "array", + "any": "string", + "byte": "integer", + "short": "integer", + "long": "integer", + "double": "number", + "char": "string", + "ArrayList": "array", + "Array": "array", + "HashMap": "object", + "Hashtable": "object", + "Queue": "array", + "Stack": "array", + "Any": "string", + "String": "string", + "Bigint": "integer", +} + +class GorillaDataset(Dataset): # pylint: disable=too-few-public-methods + """The dataset class for Gorilla dataset. 
+ Reference: https://github.com/ShishirPatil/gorilla + """ + + def __init__(self, dataset_path: str, tokenizer: AutoTokenizer, use_stag: bool) -> None: + self.tokenizer = tokenizer + self.require_fake_warmup = True + self.gorilla_data = [] + file_patterns = [ + "BFCL_v3_simple.json", + ] + base_url = "https://raw.githubusercontent.com/ShishirPatil/gorilla/main/berkeley-function-call-leaderboard/data" + + for filename in file_patterns: + id = 0 + dataset_file = f"{dataset_path}/{filename}" + if os.path.exists(dataset_file): + with open(dataset_file, mode="r", encoding="utf-8") as file: + self.gorilla_data = json.load(file) + else: + function_url = f"{base_url}/{filename}" + answer_url = f"{base_url}/possible_answer/{filename}" + print(f"Downloading {filename} from GitHub...") + functions_data = [] + answers_data = [] + try: + function_response = requests.get(function_url) + function_response.raise_for_status() + function_text = function_response.text + for line in function_text.strip().split("\n"): + if line.strip(): + try: + functions_data.append(json.loads(line)) + except json.JSONDecodeError as e: + print(f"Error parsing function line in {filename}: {e}") + answer_response = requests.get(answer_url) + answer_response.raise_for_status() + answer_text = answer_response.text + for line in answer_text.strip().split("\n"): + if line.strip(): + try: + answers_data.append(json.loads(line)) + except json.JSONDecodeError as e: + print(f"Error parsing answer line in {filename}: {e}") + print( + f"Successfully downloaded {filename}: {len(functions_data)} functions, {len(answers_data)} answers" + ) + except requests.RequestException as e: + print(f"Error downloading {filename}: {e}") + functions_data = [] + answers_data = [] + if not functions_data or not answers_data: + print(f"Skipping {filename} - failed to download data") + continue + print(f"Processing {filename}...") + answers_by_id = {item["id"]: item for item in answers_data} + for item in functions_data: + item_id = item["id"] + question = item["question"][0] + if item_id not in answers_by_id: + print(f"Warning: No answer found for item {item_id}") + continue + if "function" not in item or not item["function"]: + print(f"Warning: No function definition for item {item_id}") + continue + tool = [{"type": "function", "function": func} for func in item["function"]] + self.map_type_values(tool) + answer = answers_by_id[item_id] + if "ground_truth" not in answer or not answer["ground_truth"]: + print(f"Warning: No ground truth for item {item_id}") + continue + ideal_call = [] + for ground_truth in answer["ground_truth"]: + function_name = list(ground_truth.keys())[0] + params = ground_truth[function_name] + ideal_call.append({"name": function_name, "arguments": params}) + self.gorilla_data.append( + { + "id": id, + "question": question, + "tool": tool, + "ideal_call": ideal_call, + "source": filename, + } + ) + id += 1 + with open(dataset_file, mode="w", encoding="utf-8") as file: + json.dump(self.gorilla_data, file, ensure_ascii=False, indent=4) + if self.tokenizer is not None: + for item in self.gorilla_data: + num_tokens = 0 + for message in item["question"]: + num_tokens += len( + tokenizer.encode(message["content"], add_special_tokens=False) + ) + item["num_tokens"] = num_tokens + if not use_stag: + for item in self.gorilla_data: + for tool in item["tool"]: + tool["function"]["strict"] = False + + def generate_request_records( + self, + input_len: Optional[int], + output_len: Optional[int], + input_len_std: float = 0.0, + output_len_std: 
float = 0.0, + ) -> List[RequestRecord]: + + request_records = [] + for entry in self.gorilla_data: + # If the request does not have enough length, discard it. + # if input_len is not None and entry["num_tokens"] < input_len + 4 * input_len_std: + # continue + + if output_len is not None: + output_length = max( + round(np.random.normal(loc=output_len, scale=output_len_std)), 1 + ) + else: + output_length = 256 + request_records.append( + RequestRecord( + request_id=entry["id"], + chat_cmpl=ChatCompletionRequest( + messages=[ + ChatCompletionMessage(content=message["content"], role=message["role"]) + for message in entry["question"] + ], + model="", + max_tokens=output_length, + tools=entry["tool"], + ), + metrics=Metrics( + success=False, + start_time=0, + finish_time=0, + end_to_end_latency_s=0, + input_tokens=entry["num_tokens"], + ), + ) + ) + return request_records + + # Modified by https://github.com/ShishirPatil/gorilla/blob/main/berkeley-function-call-leaderboard/bfcl/eval_checker/ast_eval/ast_checker.py + def check_simple(self, tool_call: Dict[str, Any], + tool: Dict[str, Any], ideal: Dict[str, Any]) -> Tuple[bool, bool]: + # check func name + if ideal["name"] != tool_call["function"]["name"]: + return True, False + func = tool["function"] + # check func args + for arg in func["parameters"]["required"]: + if arg not in tool_call["function"]["arguments"]: + return True, False + for arg in tool_call["function"]["arguments"].keys(): + ideal_arg: List = ideal["arguments"][arg] if arg in ideal["arguments"] else None + real_arg = tool_call["function"]["arguments"][arg] + if arg not in func["parameters"]["properties"]: + return True, False + info_arg = func["parameters"]["properties"][arg] + if info_arg["type"] == "integer": + if not self.check_integer(real_arg, ideal_arg): + return True, False + elif info_arg["type"] == "number": + if not self.check_number(real_arg, ideal_arg): + return True, False + elif info_arg["type"] == "boolean": + if not self.check_boolean(real_arg, ideal_arg): + return True, False + elif info_arg["type"] == "string": + enum = info_arg["enum"] if "enum" in info_arg else None + if not self.check_string(real_arg, ideal_arg, enum): + return True, False + elif info_arg["type"] == "array": + if not self.check_list(real_arg, ideal_arg, info_arg["items"]): + return True, False + elif info_arg["type"] == "dict": + if not self.check_dict(real_arg, ideal_arg, info_arg["properties"]): + return True, False + return True, True + + + + def check_integer(self, real_arg: Any, ideal_arg: Optional[List[Any]]) -> bool: + try: + if type(real_arg) != int: + return False + if ideal_arg is None: + return True + match = False + for ideal in ideal_arg: + if real_arg == ideal: + match = True + break + return match + except: + return False + + def check_number(self, real_arg: Any, ideal_arg: Optional[List[Any]]) -> bool: + if type(real_arg) != float and type(real_arg) != int: + return False + if ideal_arg is None: + return True + match = False + for ideal in ideal_arg: + if real_arg == ideal: + match = True + break + return match + + def check_string(self, real_arg: Any, ideal_arg: Optional[List[Any]], enum: Optional[List[str]]) -> bool: + + def standardize_string(string: Any) -> str: + if not isinstance(string, str): + return "Error><><><><><>" + regex_string = r"[ \,\.\/\-\_\*\^]" + return re.sub(regex_string, "", string).lower().replace("'", '"') + + if type(real_arg) != str: + return False + match = False + real_arg = standardize_string(real_arg) + if ideal_arg is None: + if enum is 
None: + return True + else: + for ideal in enum: + if real_arg == standardize_string(ideal): + match = True + break + else: + for ideal in ideal_arg: + if real_arg == standardize_string(ideal): + match = True + break + return match + + def check_boolean(self, real_arg: bool, ideal_arg: Optional[List[bool]]) -> bool: + if type(real_arg) != bool: + return False + if ideal_arg is None: + return True + match = False + for ideal in ideal_arg: + if real_arg == ideal: + match = True + break + return match + + def check_list(self, real_arg: List, ideal_arg: Optional[List[List]], item: Dict[str, Any]) -> bool: + if type(real_arg) != list: + return False + item_type = item["type"] + if ideal_arg is None: + if item_type == "integer": + for i, integer in enumerate(real_arg): + if not self.check_integer(integer, None): + return False + elif item_type == "number": + for i, integer in enumerate(real_arg): + if not self.check_number(integer, None): + return False + elif item_type == "boolean": + for i, boolean in enumerate(real_arg): + if not self.check_boolean(boolean, None): + return False + elif item_type == "string": + for i, string in enumerate(real_arg): + enum = item["enum"] if "enum" in item else None + if not self.check_string(string, None, enum): + return False + elif item_type == "array": + for i, array in enumerate(real_arg): + if not self.check_list(array, None, item["items"]): + return False + elif item_type == "dict": + for i, dictionary in enumerate(real_arg): + if not self.check_dict(dictionary, None, item["properties"]): + return False + return True + else: + for ideal in ideal_arg: + if len(ideal) != len(real_arg): + continue + match = True + if item_type == "integer": + for i, integer in enumerate(real_arg): + if not self.check_integer(integer, [ideal[i]]): + match = False + break + elif item_type == "number": + for i, integer in enumerate(real_arg): + if not self.check_number(integer, [ideal[i]]): + match = False + break + elif item_type == "boolean": + for i, boolean in enumerate(real_arg): + if not self.check_boolean(boolean, [ideal[i]]): + match = False + break + elif item_type == "string": + for i, string in enumerate(real_arg): + enum = item["enum"] if "enum" in item else None + if not self.check_string(string, [ideal[i]], enum): + match = False + break + elif item_type == "array": + for i, array in enumerate(real_arg): + if not self.check_list(array, [ideal[i]], item["items"]): + match = False + break + elif item_type == "dict": + for i, dictionary in enumerate(real_arg): + if not self.check_dict(dictionary, [ideal[i]], item["properties"]): + match = False + break + if match: + return True + return False + + def check_dict(self, real_arg: Dict[str, Any], ideal_arg: Optional[Dict[str, Any]], properties: Dict[str, Any]) -> bool: + if type(real_arg) != dict: + return False + if ideal_arg is None: + for key in properties.keys(): + if key not in real_arg: + return False + item_type = properties[key]["type"] + if item_type == "integer": + if not self.check_integer(real_arg[key], None): + return False + elif item_type == "number": + if not self.check_number(real_arg[key], None): + return False + elif item_type == "boolean": + if not self.check_boolean(real_arg[key], None): + return False + elif item_type == "string": + enum = properties[key]["enum"] if "enum" in properties[key] else None + if not self.check_string(real_arg[key], None, enum): + return False + elif item_type == "array": + if not self.check_list(real_arg[key], None, properties[key]["items"]): + return False + elif 
item_type == "dict": + if not self.check_dict(real_arg[key], None, properties[key]["properties"]): + return False + return True + else: + for ideal in ideal_arg: + match = True + for key in properties.keys(): + if key not in real_arg: + match = False + break + item_type = properties[key]["type"] + if item_type == "integer": + if not self.check_integer(real_arg[key], [ideal[key]]): + match = False + break + elif item_type == "number": + if not self.check_number(real_arg[key], [ideal[key]]): + match = False + break + elif item_type == "boolean": + if not self.check_boolean(real_arg[key], [ideal[key]]): + match = False + break + elif item_type == "string": + enum = properties[key]["enum"] if "enum" in properties[key] else None + if not self.check_string(real_arg[key], [ideal[key]], enum): + match = False + break + elif item_type == "array": + if not self.check_list(real_arg[key], [ideal[key]], properties[key]["items"]): + match = False + break + elif item_type == "dict": + if not self.check_dict(real_arg[key], [ideal[key]], properties[key]["properties"]): + match = False + break + if match: + return True + return False + + def map_type_values(self, data): + if isinstance(data, dict): + for key, value in data.items(): + if isinstance(value, (dict, list)): + self.map_type_values(value) + elif key == "type" and value in GORILLA_TO_OPENAPI: + data[key] = GORILLA_TO_OPENAPI[value] + elif isinstance(data, list): + for item in data: + if isinstance(item, (dict, list)): + self.map_type_values(item) + + + +SUPPORTED_DATASET = [ + "gorilla" +] + + +def create_dataset( # pylint: disable=too-many-return-statements,too-many-branches + args: argparse.Namespace, tokenizer: AutoTokenizer +) -> Dataset: + """Create a dataset instance with regard to the specified dataset kind and file path.""" + if args.dataset_path is not None and not isinstance(args.dataset_path, str): + raise TypeError(f"Invalid dataset path {args.dataset_path}. Please use a string.") + if args.dataset == "gorilla": + if args.dataset_path is None: + raise ValueError( + "Gorilla dataset requires dataset path. " + 'Please specify it with "--dataset-path".' + ) + assert ( + args.apply_chat_template is False + ), "Gorilla dataset does not support applying chat template" + return GorillaDataset(args.dataset_path, tokenizer) + raise ValueError(f"Unrecognized dataset {args.dataset}") diff --git a/eval/request_processor.py b/eval/request_processor.py new file mode 100644 index 0000000000..99de2d7293 --- /dev/null +++ b/eval/request_processor.py @@ -0,0 +1,738 @@ +"""MLC LLM Bench Request""" + +import argparse +import asyncio +import concurrent.futures +import copy +import os +import random +import time +from typing import Any, Callable, Dict, List, Optional, Tuple + +import numpy as np +import requests +from tqdm import tqdm +from transformers import AutoTokenizer # pylint: disable=import-error + +from api_endpoint import APIEndPoint +from dataset import Dataset +from request_record import GroupedRequestRecord, RequestRecord +from mlc_llm.protocol.openai_api_protocol import ( + ChatCompletionMessage, + ChatCompletionRequest, + DebugConfig, +) +from mlc_llm.support import logging + +logger = logging.getLogger(__name__) + + +class RequestProcessor: # pylint: disable=too-few-public-methods + """The request processor base class. + Each processor can take a list of RequestRecord, applying the process, + and returning the processed RequestRecord in the end. 
+ """ + + def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]: + raise NotImplementedError() + + +class LogMessage(RequestProcessor): # pylint: disable=too-few-public-methods + """The processor that prints the logger message.""" + + def __init__(self, message: str) -> None: + self.message = message + + def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]: + logger.info(self.message) + return request_records + + +class SampleRequests(RequestProcessor): # pylint: disable=too-few-public-methods + """The processor that samples requests out from the given request list.""" + + def __init__(self, num_requests: int, take_first_x_requests: bool = True) -> None: + self.num_requests = num_requests + # If `take_first_x_requests` is True, the first `num_requests` requests + # are returned and sampling will not happen. + self.take_first_x_requests = take_first_x_requests + + def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]: + assert len(request_records) > 0, "Empty input request record." + + # We expect the input request records to be all grouped or all plain. + if isinstance(request_records[0], GroupedRequestRecord): + assert all(isinstance(record, GroupedRequestRecord) for record in request_records) + return self._sample_from_grouped_request_records(request_records) + + assert all(not isinstance(record, GroupedRequestRecord) for record in request_records) + return self._sample_from_plain_request_records(request_records) + + def _sample_from_plain_request_records( + self, request_records: List[RequestRecord] + ) -> List[RequestRecord]: + samples: List[RequestRecord] = [] + if self.take_first_x_requests: + if len(request_records) < self.num_requests: + raise ValueError( + f"Insufficient requests. Requiring {self.num_requests} requests " + f"but only {len(request_records)} are available." + ) + samples = copy.deepcopy(list(request_records[: self.num_requests])) + else: + while len(samples) < self.num_requests: + # Create a new list so that the in-place shuffle does not mutate the input list. + records = list(request_records) + random.shuffle(records) + samples += copy.deepcopy(records) + samples = samples[: self.num_requests] + for i, record in enumerate(samples): + record.request_id = i + return samples + + def _sample_from_grouped_request_records( + self, grouped_request_records: List[GroupedRequestRecord] + ) -> List[RequestRecord]: + num_total_available_requests = sum( + len(record.records) for record in grouped_request_records + ) + if self.num_requests > num_total_available_requests: + raise ValueError( + "Due to the existence of shared common prefixes, we do not allow " + "benchmarking with requests more than the available requests in the dataset. " + f"The required number of requests {self.num_requests} exceeds the " + f"number of total available requests {num_total_available_requests}." + ) + + # Create a new list so that the in-place shuffle does not mutate the input list. 
+ records = list(grouped_request_records) + if not self.take_first_x_requests: + random.shuffle(records) + remaining = self.num_requests + samples: List[RequestRecord] = [] + for grouped_request_record in grouped_request_records: + num_used_requests = min(len(grouped_request_record.records), remaining) + samples += grouped_request_record.records[:num_used_requests] + remaining -= num_used_requests + if remaining == 0: + break + for i, record in enumerate(samples): + record.request_id = i + return samples + + +class AttachModelName(RequestProcessor): # pylint: disable=too-few-public-methods + """The processor that attaches model name to requests.""" + + def __init__(self, model: str) -> None: + self.model = model + + def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]: + for request_record in request_records: + request_record.chat_cmpl.model = self.model + return request_records + + +class AttachRequestRateTimestamp(RequestProcessor): # pylint: disable=too-few-public-methods + """The processor that applies timestamps to the requests.""" + + def __init__(self, request_rate: np.float32) -> None: + self.request_rate = request_rate + + def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]: + timestamp = 0.0 + for request_record in request_records: + assert request_record.timestamp is None, "The request record already has a timestamp" + request_record.timestamp = timestamp + timestamp += float(np.random.exponential(1.0 / self.request_rate)) + return request_records + + +class AttachExecutionFeature(RequestProcessor): # pylint: disable=too-few-public-methods + """The processor that attaches execution features to all requests""" + + def __init__(self, exec_feature: Dict[str, Any]) -> None: + self.exec_feature = exec_feature + + def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]: + for request_record in request_records: + assert request_record.metrics is not None + request_record.metrics.exec_feature = self.exec_feature + return request_records + + +class AttachStreamFlag(RequestProcessor): # pylint: disable=too-few-public-methods + """The processor that attaches the stream flag to the requests.""" + + def __init__(self, stream: Optional[bool]) -> None: + self.stream = stream + + def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]: + if self.stream is None: + return request_records + for request_record in request_records: + request_record.chat_cmpl.stream = self.stream + return request_records + + +class AttachSamplingOptions(RequestProcessor): # pylint: disable=too-few-public-methods + """The processor that attaches the stream flag to the requests.""" + + def __init__(self, temperature: float, top_p: float, ignore_eos: bool) -> None: + self.temperature = temperature + self.top_p = top_p + self.ignore_eos = ignore_eos + + def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]: + for request_record in request_records: + request_record.chat_cmpl.temperature = self.temperature + request_record.chat_cmpl.top_p = self.top_p + request_record.chat_cmpl.frequency_penalty = 0.0 + request_record.chat_cmpl.presence_penalty = 0.0 + # request_record.chat_cmpl.tool_choice = "none" + if self.ignore_eos: + request_record.chat_cmpl.debug_config = DebugConfig(ignore_eos=True) + return request_records + + +class ScaleTimestamp(RequestProcessor): # pylint: disable=too-few-public-methods + """Scale the timestamp of requests by the given scale factor.""" + + def __init__(self, 
timestamp_scale: float): + self.timestamp_scale = timestamp_scale + + def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]: + for request_record in request_records: + if request_record.timestamp is None: + raise ValueError( + f"The timestamp of request {request_record} has not been initialized." + ) + request_record.timestamp *= self.timestamp_scale + return request_records + + +class MetricAnalyzer(RequestProcessor): # pylint: disable=too-few-public-methods + """The processor that analyzes the raw benchmark results and computes more detailed metrics.""" + + def __init__(self, tokenizer: AutoTokenizer) -> None: + self.tokenizer = tokenizer + + def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]: + updated_records = [] + for request_record in request_records: + metrics = request_record.metrics + if not metrics.success: + assert request_record.error_msg is not None + continue + + metrics.output_tokens = len( + self.tokenizer.encode(request_record.output_str, add_special_tokens=False) + ) + first_chunk_output_tokens = len( + self.tokenizer.encode( + request_record.first_chunk_output_str, add_special_tokens=False + ) + ) + if metrics.output_tokens <= first_chunk_output_tokens: + metrics.success = False + request_record.error_msg = ( + f"Total output token num ({metrics.output_tokens}) equals " + f'the first chunk output token. Output text "{request_record.output_str}", ' + f'first chunk output text "{request_record.first_chunk_output_str}"' + ) + continue + assert metrics.input_tokens > 0, "Invalid prompt tokens" + metrics.inter_token_latency_s = metrics.end_to_end_latency_s / metrics.output_tokens + if metrics.time_to_first_token_s is None: + metrics.time_to_first_token_s = 0 + metrics.time_per_output_token_s = ( + metrics.end_to_end_latency_s - metrics.time_to_first_token_s + ) / (metrics.output_tokens - first_chunk_output_tokens) + updated_records.append(request_record) + return updated_records + + +class WarmupAndRun(RequestProcessor): # pylint: disable=too-few-public-methods,line-too-long + """The processor that runs warmup first and then runs the benchmark with the given pipeline.""" + + def __init__( # pylint: disable=too-many-arguments + self, + num_warmup_requests: int, + num_benchmark_requests: int, + pipeline: RequestProcessor, + cuda_profile_url: Optional[str], + fake_warmup: bool = False, + ) -> None: + self.num_warmup_requests = num_warmup_requests + self.num_benchmark_requests = num_benchmark_requests + self.pipeline = pipeline + self.cuda_profile_url = cuda_profile_url + self.fake_warmup = fake_warmup + + def generate_fake_warmup_requests( # pylint: disable=missing-function-docstring + self, num_warmup_requests: int, example_request: RequestRecord + ) -> List[RequestRecord]: + records = [] + for _ in range(num_warmup_requests): + record = copy.deepcopy(example_request) + record.chat_cmpl = ChatCompletionRequest( + messages=[ + { + "role": "user", + "content": "Please output arbitrary coherent sentences. 
Do not output eos token.",  # pylint: disable=line-too-long
+                    }
+                ],
+                model="",
+                max_tokens=128,
+            )
+            records.append(record)
+        return records
+
+    def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:
+        # Warmup
+        if self.fake_warmup:
+            assert len(request_records) == self.num_benchmark_requests
+            benchmark_requests = request_records
+            example_request = benchmark_requests[0]
+            warmup_requests = self.generate_fake_warmup_requests(
+                self.num_warmup_requests, example_request=example_request
+            )
+        else:
+            assert len(request_records) == self.num_warmup_requests + self.num_benchmark_requests
+            benchmark_requests = request_records[: -self.num_warmup_requests]
+            warmup_requests = request_records[-self.num_warmup_requests :]
+        for request_record in warmup_requests:
+            request_record.timestamp = 0 if request_record.timestamp is not None else None
+        warmup_requests = self._process_warmup_requests(warmup_requests)
+        logger.info("Warmup with %d request(s)...", self.num_warmup_requests)
+        self.pipeline(warmup_requests)
+
+        # Then run the benchmark
+        if self.cuda_profile_url is not None:
+            cuda_profiler_start_url = self.cuda_profile_url + "/debug/cuda_profiler_start"
+            cuda_profiler_start_response = requests.post(cuda_profiler_start_url, timeout=60)
+            assert cuda_profiler_start_response.status_code == 200
+        logger.info("Warmup finished. Start benchmarking...")
+        updated_request_records = self.pipeline(benchmark_requests)
+        if self.cuda_profile_url is not None:
+            cuda_profiler_stop_url = self.cuda_profile_url + "/debug/cuda_profiler_stop"
+            cuda_profiler_stop_response = requests.post(cuda_profiler_stop_url, timeout=60)
+            assert cuda_profiler_stop_response.status_code == 200
+
+        return updated_request_records
+
+    def _process_warmup_requests(self, warmup_requests: List[RequestRecord]) -> List[RequestRecord]:
+        if len(warmup_requests) == 0:
+            return warmup_requests
+        # NOTE: To warm up the server with as many different batch sizes as possible,
+        # we use 128 output tokens for the first request and one more token
+        # for every follow-up request.
+        # A high temperature and top-p are set to avoid early stopping as much as possible.
+ warmup_requests[0].chat_cmpl.max_tokens = 128 + for i in range(1, len(warmup_requests)): + warmup_requests[i].chat_cmpl.max_tokens = ( + warmup_requests[i - 1].chat_cmpl.max_tokens + 1 + ) + warmup_requests[i].chat_cmpl.temperature = 2.0 + warmup_requests[i].chat_cmpl.top_p = 1.0 + return warmup_requests + + +class SequentialProcessor(RequestProcessor): # pylint: disable=too-few-public-methods + """The processor that sequentially applies a list of processors in order.""" + + processors: List[RequestProcessor] + + def __init__(self, *processors: RequestProcessor) -> None: + self.processors = list(processors) + + def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]: + for processor in self.processors: + request_records = processor(request_records) + return request_records + + +class Executor(RequestProcessor): # pylint: disable=too-few-public-methods + """The executor base class, denoting the kind of benchmark mode.""" + + def __init__( + self, + f_create_api_endpoint: Callable[[], APIEndPoint], + num_processes: int, + disable_tqdm: bool, + ) -> None: + self.f_create_api_endpoint = f_create_api_endpoint + self.disable_tqdm = disable_tqdm + self.num_processes = num_processes + + def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]: + raise NotImplementedError() + + +class FixedConcurrentRequestExecutor(Executor): # pylint: disable=too-few-public-methods + """The benchmark executor of fixing the number of concurrent requests.""" + + def __init__( # pylint: disable=too-many-arguments + self, + f_create_api_endpoint: Callable[[], APIEndPoint], + num_processes: Optional[int], + disable_tqdm: bool, + num_concurrent_requests: int, + multi_round: bool, + ) -> None: + if num_processes is None: + # We assign each process at most 32 concurrent requests to send + # so that the asyncio pressure will not be too much. + num_processes = min((num_concurrent_requests + 31) // 32, 10) + super().__init__(f_create_api_endpoint, num_processes, disable_tqdm) + self.num_concurrent_requests = num_concurrent_requests + self.multi_round = multi_round + + def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]: + partitions: List[List[RequestRecord]] = [ + request_records[slice(i, len(request_records), self.num_processes)] + for i in range(self.num_processes) + ] + # Package "tokenizers" reports warnings with multiprocessing. + # We disable "TOKENIZERS_PARALLELISM" to depress the warnings. 
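+        # The variable is set before the worker pool below is created so that the
+        # spawned processes inherit it.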
+ os.environ["TOKENIZERS_PARALLELISM"] = "false" + + pbar = None if self.disable_tqdm else tqdm(total=len(request_records)) + with concurrent.futures.ProcessPoolExecutor(max_workers=self.num_processes) as pool: + futures = [ + pool.submit( + FixedConcurrentRequestExecutor._process_task, + self.f_create_api_endpoint, + partition, + self.num_concurrent_requests // self.num_processes + + int(i < self.num_concurrent_requests % self.num_processes), + self.multi_round, + ) + for i, partition in enumerate(partitions) + ] + results: List[RequestRecord] = [] + for i, future in enumerate(concurrent.futures.as_completed(futures)): + results.extend(future.result()) + if pbar is not None: + pbar.update(len(partitions[i])) + + return results + + @staticmethod + def _process_task( + f_create_api_endpoint: Callable[[], APIEndPoint], + request_records: List[RequestRecord], + num_concurrent_requests: int, + multi_round: bool, + ) -> List[RequestRecord]: + if len(request_records) == 0: + return [] + chat_history: List[List[ChatCompletionMessage]] = [ + [] for _ in range(num_concurrent_requests) + ] + + async def process_task_impl( + f_create_api_endpoint: Callable[[], APIEndPoint], + request_records: List[RequestRecord], + num_concurrent_requests: int, + multi_round: bool, + ) -> List[RequestRecord]: + api_endpoint = f_create_api_endpoint() + updated_request_records: List[RequestRecord] = [None for _ in request_records] + async with api_endpoint: + num_sent_request = 0 + + async def _task(i: int) -> None: + nonlocal num_sent_request + while True: + if num_sent_request == len(request_records): + break + idx = num_sent_request + num_sent_request += 1 + request = request_records[idx] + + if multi_round: + request.chat_cmpl.messages = ( + chat_history[i] + request.chat_cmpl.messages + ) + + updated_request_records[idx] = await api_endpoint(request) + + if multi_round: + chat_history[i] = updated_request_records[idx].chat_cmpl.messages + [ + ChatCompletionMessage( + content=updated_request_records[idx].output_str, + role="assistant", + ) + ] + + tasks = [asyncio.create_task(_task(i)) for i in range(num_concurrent_requests)] + await asyncio.gather(*tasks) + + return updated_request_records + + return asyncio.run( + process_task_impl( + f_create_api_endpoint, + request_records, + num_concurrent_requests, + multi_round, + ) + ) + + +class FixTimestampExecutor(Executor): # pylint: disable=too-few-public-methods + """The benchmark executor of fixing the timestamps of sending requests.""" + + def __init__( # pylint: disable=too-many-arguments + self, + f_create_api_endpoint: Callable[[], APIEndPoint], + num_processes: Optional[int], + disable_tqdm: bool, + max_schedule_gap: float, + num_requests: int, + ) -> None: + if num_processes is None: + # We assign each process at most 32 requests to send + # so that the asyncio pressure will not be too much. + num_processes = min((num_requests + 31) // 32, 10) + super().__init__(f_create_api_endpoint, num_processes, disable_tqdm) + self.max_schedule_gap = max_schedule_gap + self.num_requests = num_requests + + def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]: + assert len(request_records) > 0 + assert all(request_record.timestamp is not None for request_record in request_records) + # Sort the request records in timestamp ascending order before partitioning. 
+ request_records.sort(key=lambda request_record: request_record.timestamp) + base_timestamp = request_records[0].timestamp + partitions: List[List[RequestRecord]] = [ + request_records[slice(i, len(request_records), self.num_processes)] + for i in range(self.num_processes) + ] + base_sys_time = time.time() + # Package "tokenizers" reports warnings with multiprocessing. + # We disable "TOKENIZERS_PARALLELISM" to depress the warnings. + os.environ["TOKENIZERS_PARALLELISM"] = "false" + + pbar = None if self.disable_tqdm else tqdm(total=len(request_records)) + with concurrent.futures.ProcessPoolExecutor(max_workers=self.num_processes) as pool: + futures = [ + pool.submit( + FixTimestampExecutor._process_task, + self.f_create_api_endpoint, + partition, + base_timestamp, + base_sys_time, + self.max_schedule_gap, + ) + for partition in partitions + ] + results: List[RequestRecord] = [] + for i, future in enumerate(concurrent.futures.as_completed(futures)): + results.extend(future.result()) + if pbar is not None: + pbar.update(len(partitions[i])) + + return results + + @staticmethod + def _process_task( + f_create_api_endpoint: Callable[[], APIEndPoint], + request_records: List[RequestRecord], + base_timestamp: float, + base_sys_time: float, + max_schedule_gap: float, + ) -> List[RequestRecord]: + if len(request_records) == 0: + return [] + + async def process_task_impl( + f_create_api_endpoint: Callable[[], APIEndPoint], + request_records: List[RequestRecord], + base_timestamp: float, + base_sys_time: float, + max_schedule_gap: float, + ) -> List[RequestRecord]: + api_endpoint = f_create_api_endpoint() + loop = asyncio.get_running_loop() + # Get the delta time to convert system time to the loop time. + # We must use the system time `time.time()` which is consistent across processes. + loop_sys_delta_time = loop.time() - time.time() + updated_request_records: List[RequestRecord] = [] + async with api_endpoint: + + async def _task(request_record: RequestRecord) -> None: + updated_request_records.append(await api_endpoint(request_record)) + + tasks = [] + for request_record in request_records: + launch_time = ( + (request_record.timestamp - base_timestamp) + + (base_sys_time + max_schedule_gap) + + loop_sys_delta_time + ) + loop.call_at( + launch_time, + lambda record: tasks.append(asyncio.create_task(_task(record))), + request_record, + ) + # Sleep to allow runs of other scheduled tasks if any. + await asyncio.sleep(max(launch_time - loop.time() - max_schedule_gap, 0)) + + # Sleep until all the tasks are launched. + await asyncio.sleep(launch_time - loop.time() + max_schedule_gap) + # Wait for all tasks to be scheduled + assert len(tasks) == len(request_records) + await asyncio.gather(*tasks) + + assert len(updated_request_records) == len(request_records) + return updated_request_records + + return asyncio.run( + process_task_impl( + f_create_api_endpoint, + request_records, + base_timestamp, + base_sys_time, + max_schedule_gap, + ) + ) + + +def create_pipelines( # pylint: disable=too-many-branches + args: argparse.Namespace, f_create_api_endpoint: Callable[[], APIEndPoint], dataset: Dataset +) -> List[RequestProcessor]: + """Creating request processing pipelines with regard to the specified args.""" + cuda_profile_url = f"http://{args.host}:{args.port}" if args.cuda_profile else None + pipelines: List[RequestProcessor] = [] + if args.num_concurrent_requests is not None: + if args.request_rate is not None: + raise ValueError( + 'Both "num_concurrent_requests" and "request_rate" are specified. 
' + "Please specify only one of them." + ) + if args.replay_timestamp_scale is not None: + raise ValueError( + "Dataset replay is unsupported when fixing number of concurrent requests." + ) + for num_concurrent_requests in args.num_concurrent_requests: + num_warmup_requests = ( + args.num_warmup_requests + if args.num_warmup_requests is not None + else num_concurrent_requests + ) + pipelines.append( + SequentialProcessor( + LogMessage(f"Fixing number of concurrent requests: {num_concurrent_requests}"), + SampleRequests(args.num_requests + num_warmup_requests), + AttachModelName(args.tokenizer), + AttachStreamFlag(args.stream), + AttachSamplingOptions(args.temperature, args.top_p, args.ignore_eos), + AttachExecutionFeature({"num_concurrent_requests": num_concurrent_requests}), + WarmupAndRun( + num_warmup_requests=num_warmup_requests, + num_benchmark_requests=args.num_requests, + pipeline=FixedConcurrentRequestExecutor( + f_create_api_endpoint, + args.num_process_workers, + args.disable_tqdm, + num_concurrent_requests, + args.multi_round, + ), + cuda_profile_url=cuda_profile_url, + fake_warmup=dataset.require_fake_warmup, + ), + ) + ) + return pipelines + if args.request_rate is not None: + if args.num_warmup_requests is None: + raise ValueError( + "Please specify the number of warmup requests via " + '"--num-warmup-requests" when fixing request rate.' + ) + if args.replay_timestamp_scale is not None: + raise ValueError("Dataset replay is unsupported when fixing request rates.") + num_total_requests = int( + args.num_requests if not args.per_gpu_workload else args.num_requests * args.num_gpus + ) + if dataset.require_fake_warmup: + num_samples = num_total_requests + else: + num_samples = num_total_requests + args.num_warmup_requests + return [ + SequentialProcessor( + LogMessage(f"Fixing request rate: {request_rate}"), + SampleRequests(num_samples), + AttachModelName(args.tokenizer), + AttachRequestRateTimestamp( + request_rate if not args.per_gpu_workload else request_rate * args.num_gpus + ), + AttachStreamFlag(args.stream), + AttachSamplingOptions(args.temperature, args.top_p, args.ignore_eos), + AttachExecutionFeature({"request_rate": float(request_rate)}), + WarmupAndRun( + num_warmup_requests=args.num_warmup_requests, + num_benchmark_requests=num_total_requests, + pipeline=FixTimestampExecutor( + f_create_api_endpoint, + args.num_process_workers, + args.disable_tqdm, + args.max_schedule_gap, + args.num_requests, + ), + cuda_profile_url=cuda_profile_url, + fake_warmup=dataset.require_fake_warmup, + ), + ) + for request_rate in args.request_rate + ] + + # Default: dataset replay mode + # The dataset must come with timestamps. + if not dataset.timestamp_available: + raise ValueError( + "The dataset does not have timestamps, so dataset replay is unsupported. " + 'Please specify one of "num_concurrent_requests" ' + 'and "request_rate".' + ) + if args.per_gpu_workload: + raise ValueError("Fixing per-GPU workload is not compatible with dataset replay.") + if args.num_warmup_requests is None: + raise ValueError( + "Please specify the number of warmup requests via " + '"--num-warmup-requests" for dataset replay.' 
+ ) + timestamp_scale = args.replay_timestamp_scale or 1.0 + if dataset.require_fake_warmup: + num_samples = args.num_requests + else: + num_samples = args.num_requests + args.num_warmup_requests + return [ + SequentialProcessor( + LogMessage(f"Dataset replay with time scaling of {timestamp_scale}"), + SampleRequests(num_samples, take_first_x_requests=True), + AttachModelName(args.tokenizer), + ScaleTimestamp(timestamp_scale), + AttachStreamFlag(args.stream), + AttachSamplingOptions(args.temperature, args.top_p, args.ignore_eos), + AttachExecutionFeature({"timestamp_scale": timestamp_scale}), + WarmupAndRun( + num_warmup_requests=args.num_warmup_requests, + num_benchmark_requests=args.num_requests, + pipeline=FixTimestampExecutor( + f_create_api_endpoint, + args.num_process_workers, + args.disable_tqdm, + args.max_schedule_gap, + args.num_requests, + ), + cuda_profile_url=cuda_profile_url, + fake_warmup=dataset.require_fake_warmup, + ), + ) + ] + + + \ No newline at end of file diff --git a/eval/request_record.py b/eval/request_record.py new file mode 100644 index 0000000000..774519b7d3 --- /dev/null +++ b/eval/request_record.py @@ -0,0 +1,209 @@ +"""MLC LLM Bench Request""" + +from typing import Any, Dict, List, Optional, Tuple, Union + +import pandas as pd # pylint: disable=import-error +from pydantic import BaseModel + +from mlc_llm.protocol.openai_api_protocol import ChatCompletionRequest +from mlc_llm.support import logging + +logger = logging.getLogger(__name__) + + +class ServerMetrics(BaseModel): + """The metrics from the server side.""" + + input_tokens: int + prefill_tokens: int + output_tokens: int + end_to_end_latency_s: float + prefill_tokens_per_s: float + inter_token_latency_s: float + time_per_output_token_s: float + time_to_first_token_s: Optional[float] = None + + +class Metrics(BaseModel): + """The list of metric keys""" + + success: bool + start_time: float + finish_time: float + end_to_end_latency_s: float + + input_tokens: Optional[int] = None + output_tokens: Optional[int] = None + inter_token_latency_s: Optional[float] = None + time_per_output_token_s: Optional[float] = None + time_to_first_token_s: Optional[float] = None + server_metrics: Optional[ServerMetrics] = None + + exec_feature: Optional[Dict[str, Any]] = None + + +class RequestRecord(BaseModel): + """The request records collected from LLM inference requests.""" + + request_id: Optional[int] = None + chat_cmpl: ChatCompletionRequest + output_str: Optional[str] = None + first_chunk_output_str: str = "" + timestamp: Optional[float] = None + metrics: Optional[Metrics] = None + error_msg: Optional[str] = None + + +class GroupedRequestRecord(RequestRecord): + """The data structure for request record groups. + For datasets that have common prefix sharing, the request records + that share a same common prefix will be wrapped in a GroupedRequestRecord + at the beginning. + """ + + records: List[RequestRecord] + + +def generate_metrics_summary( + request_records: List[RequestRecord], + num_total_requests: int, + num_gpus: int, +) -> Dict[str, Any]: + """Computes summary statistics across all metrics collected. + Return a dictionary as the report. 
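+    The report covers the benchmark duration, request throughput, latency and token
+    statistics, and server-side statistics when they are available.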
+ """ + num_completed_requests = len(request_records) + assert num_completed_requests <= num_total_requests + request_metrics = [record.metrics for record in request_records] + duration = ( + max(metrics.finish_time for metrics in request_metrics) + - min(metrics.start_time for metrics in request_metrics) + if num_completed_requests > 0 + else 1e-5 + ) + + report = _compute_metrics_statistics(request_metrics) + report["num_gpus"] = num_gpus + report["duration"] = duration + report["num_total_requests"] = num_total_requests + report["num_completed_requests"] = num_completed_requests + report["request_throughput"] = num_completed_requests / duration + + # Generate the server metrics statistics + server_metrics = [metric.server_metrics for metric in request_metrics if metric.server_metrics] + server_report = _compute_metrics_statistics(server_metrics) + if server_report is not None and len(server_report) > 0: + report["server_metrics"] = server_report + + report = { + "exec_feature": ( + request_records[0].metrics.exec_feature if num_completed_requests > 0 else None + ), + **report, + } + return report + + +def _compute_metrics_statistics(metrics: List[Union[Metrics, ServerMetrics]]) -> Dict[str, Any]: + """ + Compute the statistics of the metrics. + + Parameters + ---------- + metrics : List[Union[Metrics, ServerMetrics]] + The list of metrics to get the statistics. + + Returns + ------- + report : Dict + The statistics of the metrics. + """ + if not metrics: + return {} + + report: Dict = {} + df = pd.DataFrame([metric.model_dump() for metric in metrics]) + for key, _ in metrics[0].model_fields.items(): + if key in ["success", "start_time", "finish_time", "server_metrics", "exec_feature"]: + continue + if key in ["end_to_end_latency_s", "input_tokens"]: + if key in df.columns: + series = df[key].dropna() + report[key] = { + "quantiles": { + f"p{int(q * 100)}": v + for q, v in series.quantile([0.25, 0.5, 0.75, 0.9, 0.95, 0.99]).items() + }, + "mean": series.mean(), + "min": series.min(), + "max": series.max(), + "stddev": series.std(), + } + return report + + +def convert_reports_to_df(reports: List[Dict[str, Any]]) -> pd.DataFrame: + """Convert benchmark reports to pandas DataFrame.""" + + def _flatten_dict(d: Dict[str, Any], parent_key: str = "") -> Dict[str, Any]: + items: List[Tuple[str, Any]] = [] + for key, value in d.items(): + new_key = f"{parent_key}.{key}" if parent_key != "" else key + if isinstance(value, dict): + items.extend(_flatten_dict(value, new_key).items()) + else: + items.append((new_key, value)) + return dict(items) + + return pd.DataFrame([_flatten_dict(report) for report in reports]) + + +def pretty_print_report(report: Dict[str, Any]) -> None: # pylint: disable=too-many-statements + """Pretty print the metrics report.""" + + def _print(report: Dict[str, Any], server_metrics: bool): # pylint: disable=too-many-statements + # pylint: disable=line-too-long + # fmt: off + title = "Benchmark Result" + if server_metrics: + title += " (server side)" + print(f" {title} ".center(50, "=")) + if not server_metrics: + print(f"{'Total requests:':<40} {report['num_total_requests']:<10}") + print(f"{'Completed requests:':<40} {report['num_completed_requests']:<10}") + print(f"{'Duration (s):':<40} {report['duration']:<10.2f}") + print(f"{'Num GPUs:':<40} {report['num_gpus']:<10}") + if report["num_completed_requests"] == 0: + return + + + e2e_latency = report["end_to_end_latency_s"] + print(" End-to-End Latency (ms) ".center(50, "-")) + print(f"{'Mean:':<40} {e2e_latency['mean'] * 
1000:<10.2f}") + print(f"{'Stddev:':<40} {e2e_latency['stddev'] * 1000:<10.2f}") + print(f"{'P25:':<40} {e2e_latency['quantiles']['p25'] * 1000:<10.2f}") + print(f"{'P50:':<40} {e2e_latency['quantiles']['p50'] * 1000:<10.2f}") + print(f"{'P75:':<40} {e2e_latency['quantiles']['p75'] * 1000:<10.2f}") + print(f"{'P90:':<40} {e2e_latency['quantiles']['p90'] * 1000:<10.2f}") + print(f"{'P95:':<40} {e2e_latency['quantiles']['p95'] * 1000:<10.2f}") + print(f"{'P99:':<40} {e2e_latency['quantiles']['p99'] * 1000:<10.2f}") + print(f"{'Min:':<40} {e2e_latency['min'] * 1000:<10.2f}") + print(f"{'Max:':<40} {e2e_latency['max'] * 1000:<10.2f}") + + input_tokens = report["input_tokens"] + print(" Input Tokens ".center(50, "-")) + print(f"{'Mean:':<40} {input_tokens['mean']:<1}") + print(f"{'Stddev:':<40} {input_tokens['stddev']:<1}") + print(f"{'P25:':<40} {input_tokens['quantiles']['p25']:<1}") + print(f"{'P50:':<40} {input_tokens['quantiles']['p50']:<1}") + print(f"{'P95:':<40} {input_tokens['quantiles']['p95']:<1}") + print(f"{'Min:':<40} {input_tokens['min']:<1}") + print(f"{'Max:':<40} {input_tokens['max']:<1}") + + print("=" * 50) + + # fmt: on + # pylint: enable=line-too-long + _print(report, server_metrics=False) + if "server_metrics" in report: + _print(report["server_metrics"], server_metrics=True) diff --git a/python/mlc_llm/protocol/conversation_protocol.py b/python/mlc_llm/protocol/conversation_protocol.py index a59aa8c386..0bb9ebb251 100644 --- a/python/mlc_llm/protocol/conversation_protocol.py +++ b/python/mlc_llm/protocol/conversation_protocol.py @@ -82,7 +82,7 @@ class Conversation(BaseModel): # whether using function calling or not, helps check for output message format in API call use_function_calling: bool = False # Tool function call format mode - tool_call_format: str = "default" + tool_call_format: str = "json" def __init__(self, role_templates: Optional[Dict[str, str]] = None, **kwargs): # Defaults templates which would be overridden by model specific templates @@ -195,6 +195,7 @@ def as_prompt(self, config=None) -> List[Any]: ) # Replace with remaining function string placeholders with empty string prompt[0] = prompt[0].replace(MessagePlaceholders.FUNCTION.value, "") + return prompt def set_tool_call_format_in_system_message(self): @@ -217,9 +218,9 @@ def set_tool_call_format_in_system_message(self): "Tool Instructions:" f"You have access to the following tool functions: {MessagePlaceholders.FUNCTION.value}" "If a you choose to call a function, you should ONLY reply in the following format:" - '`{"name": func_name, "parameters": parameters(JSON dict)}`' + '`{"name": func_name, "parameters": parameters(JSON dict)}\n`' "Here is an example," - '`{"name": "get_time", "parameters": {"location": "Pittsburgh"} }`' + '`{"name": "get_time", "parameters": {"location": "Pittsburgh"} }\n`' "Reminder:" "- Function calls MUST follow the specified format" "- Required parameters MUST be specified" diff --git a/python/mlc_llm/protocol/openai_api_protocol.py b/python/mlc_llm/protocol/openai_api_protocol.py index 3a7607ad76..8d1e9c7863 100644 --- a/python/mlc_llm/protocol/openai_api_protocol.py +++ b/python/mlc_llm/protocol/openai_api_protocol.py @@ -341,12 +341,10 @@ def check_function_call_usage(self, conv_template: Conversation) -> None: """Check if function calling is used and update the conversation template. Return error message if invalid request format for function calling. 
""" - # return if no tools are provided or tool_choice is set to none if self.tools is None or (isinstance(self.tool_choice, str) and self.tool_choice == "none"): conv_template.use_function_calling = False return - # select the tool based on the tool_choice if specified if isinstance(self.tool_choice, dict): if self.tool_choice["type"] != "function": # pylint: disable=unsubscriptable-object diff --git a/python/mlc_llm/serve/config.py b/python/mlc_llm/serve/config.py index 60ad50ab4f..058952ca5c 100644 --- a/python/mlc_llm/serve/config.py +++ b/python/mlc_llm/serve/config.py @@ -129,13 +129,13 @@ class EngineConfig: # pylint: disable=too-many-instance-attributes "chunked" means the basic prefill with chunked input enabled. "hybrid" means the hybrid prefill or split-fuse, so that decode step will be converted into prefill. - + tool_call_format : Literal["xml", "json", "python"] The tool function call foramt. - "xml" means model will call tool function in xml style format + "xml" means model will call tool function in xml style format '\n{parameters(JSON dict)}\n', e.g. '\n{"location": "Pittsburgh"}\n'. - "json" means model will call tool function in json style format + "json" means model will call tool function in json style format '{"name": func_name, "parameters": parameters(JSON dict)}', e.g. '{"name": "get_time", "parameters": {"location": "Pittsburgh"}}'. "python" means model will call tool function in python-style format, @@ -168,7 +168,7 @@ class EngineConfig: # pylint: disable=too-many-instance-attributes prefix_cache_mode: Literal["disable", "radix"] = "radix" prefix_cache_max_num_recycling_seqs: Optional[int] = None prefill_mode: Literal["chunked", "hybrid"] = "hybrid" - tool_call_format: Literal["xml", "json", "python"] = "xml" + tool_call_format: Literal["xml", "json", "python"] = "json" verbose: bool = True def asjson(self) -> str: diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py index b618feff3c..dc504c20ac 100644 --- a/python/mlc_llm/serve/engine_base.py +++ b/python/mlc_llm/serve/engine_base.py @@ -7,9 +7,9 @@ import json import numbers import queue -import re import sys import threading +import re from dataclasses import dataclass from pathlib import Path from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union @@ -1196,8 +1196,8 @@ def set_structural_tag_from_tools( ) response_format.triggers.append(" AsyncGenerator[str, None]: if choice.logprobs is not None: assert logprob_results is not None logprob_results[choice.index] += choice.logprobs.content - assert all(finish_reason is not None for finish_reason in finish_reasons) use_function_calling, tool_calls_list = engine_base.process_function_call_output( output_texts, finish_reasons, async_engine.engine_config.tool_call_format From b8c331a92ee85ee70ccd3f6594044d3132175283 Mon Sep 17 00:00:00 2001 From: irfnfnkemed Date: Wed, 16 Apr 2025 00:36:54 +0800 Subject: [PATCH 13/17] [feat] using structural tag to support func-call with the latest version --- cpp/serve/config.h | 2 +- cpp/serve/engine.cc | 15 +- cpp/tokenizers/streamer.cc | 5 +- eval/__init__.py | 1 - eval/__main__.py | 474 ----------- eval/api_endpoint.py | 215 ----- eval/dataset.py | 500 ------------ eval/request_processor.py | 738 ------------------ eval/request_record.py | 209 ----- .../mlc_llm/protocol/conversation_protocol.py | 29 +- .../mlc_llm/protocol/openai_api_protocol.py | 36 +- python/mlc_llm/serve/config.py | 14 +- python/mlc_llm/serve/engine_base.py | 107 +-- 13 files changed, 115 
insertions(+), 2230 deletions(-) delete mode 100644 eval/__init__.py delete mode 100644 eval/__main__.py delete mode 100644 eval/api_endpoint.py delete mode 100644 eval/dataset.py delete mode 100644 eval/request_processor.py delete mode 100644 eval/request_record.py diff --git a/cpp/serve/config.h b/cpp/serve/config.h index f39e1911ba..406ee28307 100644 --- a/cpp/serve/config.h +++ b/cpp/serve/config.h @@ -249,7 +249,7 @@ class EngineConfigNode : public Object { * significantly smaller than this number. Under mode "server", the actual * memory usage may be slightly larger than this number. */ - float gpu_memory_utilization = 0.85; + float gpu_memory_utilization = 0.55; /*! \brief The number of consecutive tokens handled in each page in paged KV cache. */ int kv_cache_page_size = 16; /*! diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index 27765fb65f..7895e6c860 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -475,10 +475,9 @@ class EngineImpl : public Engine { if (info.has_value() && info.value()->vocab_size != 0) { vocab_size = info.value()->vocab_size; } - n->grammar_compiler_ = xgrammar::GrammarCompiler(xgrammar::TokenizerInfo(n->token_table_, xgrammar::VocabType::RAW, vocab_size)); - - - + n->grammar_compiler_ = xgrammar::GrammarCompiler( + xgrammar::TokenizerInfo(n->token_table_, xgrammar::VocabType::RAW, vocab_size)); + // - Create the logit processor and sampler, and // the DraftTokenWorkspaceManager for speculative decoding. int max_num_tokens = engine_config->max_num_sequence; @@ -991,11 +990,9 @@ class EngineImpl : public Engine { if (response_format.type == "text") { return std::nullopt; } else if (response_format.type == "json_object") { - if (!response_format.schema) { - return grammar_compiler_.CompileBuiltinJSONGrammar(); - } else { - return grammar_compiler_.CompileJSONSchema(response_format.schema.value()); - } + return grammar_compiler_.CompileBuiltinJSONGrammar(); + } else if (response_format.type == "json_schema") { + return grammar_compiler_.CompileJSONSchema(response_format.schema.value()); } else { std::vector tags; std::vector triggers; diff --git a/cpp/tokenizers/streamer.cc b/cpp/tokenizers/streamer.cc index 7aaacd6b59..2901834f3b 100644 --- a/cpp/tokenizers/streamer.cc +++ b/cpp/tokenizers/streamer.cc @@ -193,10 +193,7 @@ void StopStrHandlerObj::Put(int32_t token_id, std::vector* return_token } CHECK(!stop_triggered_) << "Cannot put new token when already stopped."; - // TODO: find better solution - if (token_id >= static_cast(token_table_.size())){ - token_id = 0; - } + ICHECK_LT(token_id, static_cast(token_table_.size())); const std::string& token = token_table_[token_id]; pending_token_ids_.push_back(token_id); diff --git a/eval/__init__.py b/eval/__init__.py deleted file mode 100644 index f8fc6a6220..0000000000 --- a/eval/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Subdirectory of bench.""" diff --git a/eval/__main__.py b/eval/__main__.py deleted file mode 100644 index 6b045cb9d4..0000000000 --- a/eval/__main__.py +++ /dev/null @@ -1,474 +0,0 @@ -"""MLC LLM benchmark main entrance""" - -import functools -import json -import random -from typing import Any, Dict, List, Optional, Tuple - -from mlc_llm.protocol.openai_api_protocol import ChatToolCall -import numpy as np -import requests -from transformers import AutoTokenizer # pylint: disable=import-error - -import mlc_llm -from api_endpoint import SUPPORTED_BACKENDS, create_api_endpoint -from dataset import SUPPORTED_DATASET, Dataset, GorillaDataset, create_dataset -from request_processor 
import ( - MetricAnalyzer, - RequestProcessor, - create_pipelines, -) -from request_record import ( - RequestRecord, - convert_reports_to_df, - generate_metrics_summary, - pretty_print_report, -) -from mlc_llm.cli.serve import EngineConfigOverride -from mlc_llm.serve import EngineConfig -from mlc_llm.support import argparse, logging - -logging.enable_logging() -logger = logging.getLogger(__name__) - - -def _parse_num_concurrent_requests(num_str: Optional[str]) -> Optional[List[int]]: - if num_str is None: - return None - numbers = num_str.split(",") - if any(not number.isdigit() for number in numbers): - raise ValueError(f"Unrecognized num_concurrent_requests list: {numbers}") - return list(int(number) for number in numbers) - - -def _parse_request_rate(request_rate_str: Optional[str]) -> Optional[List[np.float32]]: - if request_rate_str is None: - return None - request_rates = request_rate_str.split(",") - results = [] - for rate_str in request_rates: - request_rate = float(rate_str) - if request_rate <= 0: - raise ValueError(f"Invalid request rate {request_rate}") - results.append(np.float32(request_rate)) - return results - - -def _parse_mlc_engine_config(config_str: Optional[str]) -> EngineConfig: - if config_str is None: - return None - engine_config_override = EngineConfigOverride.from_str(config_str) - return EngineConfig( - tensor_parallel_shards=engine_config_override.tensor_parallel_shards, - max_num_sequence=engine_config_override.max_num_sequence, - max_total_sequence_length=engine_config_override.max_total_seq_length, - prefill_chunk_size=engine_config_override.prefill_chunk_size, - sliding_window_size=engine_config_override.sliding_window_size, - attention_sink_size=engine_config_override.attention_sink_size, - max_history_size=engine_config_override.max_history_size, - gpu_memory_utilization=engine_config_override.gpu_memory_utilization, - spec_draft_length=engine_config_override.spec_draft_length, - prefill_mode=engine_config_override.prefill_mode, - prefix_cache_max_num_recycling_seqs=engine_config_override.prefix_cache_max_num_recycling_seqs, # pylint: disable=line-too-long - prefix_cache_mode=engine_config_override.prefix_cache_mode, - ) - - -def _launch_mlc_server(args: argparse.argparse.Namespace): - return mlc_llm.serve.PopenServer( - model=args.tokenizer, - mode="server", - model_lib=args.mlc_model_lib, - enable_tracing=False, - host=args.host, - port=args.port, - engine_config=args.mlc_engine_config, - ) - - -def run_pipeline( - pipeline: RequestProcessor, - dataset: Dataset, - args: argparse.argparse.Namespace, -) -> Tuple[Dict[str, Any], List[RequestRecord]]: - """Run the pipeline with the given dataset and args. 
Return the benchmark report dict.""" - random.seed(args.seed) - np.random.seed(args.seed) - request_records = dataset.generate_request_records( - args.input_len, - args.output_len, - args.input_len_std, - args.output_len_std, - ) - request_records = pipeline(request_records) - num_total_requests = ( - args.num_requests if not args.per_gpu_workload else args.num_requests * args.num_gpus - ) - assert len(request_records) == num_total_requests - sorted_requests: List[RequestRecord] = [None] * num_total_requests - for request_record in request_records: - assert request_record.request_id is not None - assert sorted_requests[request_record.request_id] is None - sorted_requests[request_record.request_id] = request_record - - report = generate_metrics_summary(request_records, num_total_requests, args.num_gpus) - - return report, sorted_requests - - -def query_mlc_server_metrics(host: str, port: int): - """Try to get the MLC server metrics whenever it exists.""" - try: - r = requests.post(f"http://{host}:{port}/debug/dump_engine_metrics", json={}, timeout=10) - if r.status_code == 200: - print(f"MLC server metrics: {r.json()}") - except Exception: # pylint: disable=broad-exception-caught - pass - -def convert_calls_to_json(calls: List[ChatToolCall])-> List[Dict[str, Any]]: - """Convert the list of ChatToolCall to a list of dict.""" - result = [] - for call in calls: - call_dict = { - "function": {"name": call.function.name, "arguments": call.function.arguments} - } - result.append(call_dict) - return result - - -def check_acc(args: argparse.argparse.Namespace, dataset: GorillaDataset): - request_records = [] - final_output = {"fail_format": [], "fail_call": []} - with open(args.generate_output, "r") as f: - request_records = json.load(f) - count = 0 - for request in request_records: - info = dataset.gorilla_data[request["id"]] - if info["source"] == "BFCL_v3_simple.json": - count += 1 - if "call" not in request: - final_output["fail_format"].append(request["id"]) - final_output["fail_call"].append(request["id"]) - continue - format, call = dataset.check_simple(request["call"][0], info["tool"][0], info["ideal_call"][0]) - if not format: - final_output["fail_format"].append(request["id"]) - if not call: - final_output["fail_call"].append(request["id"]) - correct_format = count - len(final_output["fail_format"]) - correct_call = count - len(final_output["fail_call"]) - final_output["format_accuracy"] = correct_format / count - final_output["call_accuracy"] = correct_call / count - print(f"correct_format: {correct_format}/{count}, correct_call: {correct_call}/{count}") - with open(args.final_output, "w", encoding="utf-8") as file: - json.dump(final_output, file, indent=4) - - - -def main(args: argparse.argparse.Namespace): - """Main benchmark entrance.""" - mlc_server = None - if args.mlc_model_lib: - mlc_server = _launch_mlc_server(args) - if args.num_requests <= 0: - raise ValueError("Number of requests to benchmark must be positive.") - - def _main(): - tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) - dataset: GorillaDataset = create_dataset(args, tokenizer, args.use_stag) - f_create_api_endpoint = functools.partial(create_api_endpoint, args) - pipelines = create_pipelines(args, f_create_api_endpoint, dataset) - reports = [] - alltime_records = {} - store_record = [] - for i, pipeline in enumerate(pipelines): - report, request_records = run_pipeline(pipeline, dataset, args) - - for request in request_records: - info = dataset.gorilla_data[request.request_id] - if info["source"] == 
"BFCL_v3_simple.json": - store_record.append({"id": request.request_id}) - if len(request.chat_cmpl.messages) == 2: - store_record[-1]["output"] = request.chat_cmpl.messages[1].content - if len(request.chat_cmpl.messages) == 2 and request.chat_cmpl.messages[1].tool_calls is not None: - store_record[-1]["call"] = convert_calls_to_json(request.chat_cmpl.messages[1].tool_calls) - - with open(args.generate_output, "w") as f: - json.dump(store_record, f, indent=4) - - exec_feature = ( - json.dumps(report["exec_feature"]) - if report["exec_feature"] is not None - else f"pipeline{i}" - ) - alltime_records[exec_feature] = [ - request_record.model_dump() for request_record in request_records - ] - reports.append(report) - pretty_print_report(report) - query_mlc_server_metrics(args.host, args.port) - - # Construct data frame - df = convert_reports_to_df(reports) - print(df) - df.to_csv(args.bench_output, index=False) - logger.info("Benchmark results dumped to file %s", args.bench_output) - if args.debug_dump: - debug_dump_filepath = ( - args.bench_output[:-4] if args.bench_output.endswith(".csv") else args.bench_output - ) + "_debug_dump.log" - with open(debug_dump_filepath, "w", encoding="utf-8") as file: - json.dump(alltime_records, file, indent=4) - logger.info("Debug log dumped to file %s", debug_dump_filepath) - - check_acc(args, dataset) - - if mlc_server is not None: - with mlc_server: - _main() - else: - _main() - - -if __name__ == "__main__": - parser = argparse.ArgumentParser("MLC LLM benchmark") - - parser.add_argument( - "--dataset", - type=str, - choices=SUPPORTED_DATASET, - help=f"The benchmark dataset kind. Supporting {SUPPORTED_DATASET}", - ) - parser.add_argument( - "--dataset-path", - type=str, - help="The dataset file path.", - ) - parser.add_argument( - "--api-endpoint", - type=str, - choices=SUPPORTED_BACKENDS, - default="openai", - help="The API endpoint API for benchmarking.", - ) - parser.add_argument( - "--tokenizer", - type=str, - required=True, - help="The path of the tokenizer directory.", - ) - parser.add_argument( - "--num-gpus", - type=int, - required=True, - help="The number of GPUs used by the server. " - "We need this to better analyze the throughput per GPU.", - ) - parser.add_argument( - "--num-requests", - type=int, - required=True, - help="The number of requests for benchmark.", - ) - parser.add_argument( - "--num-warmup-requests", - type=int, - help="The number of requests for warmup. " - "It is optional when fixing the number of concurrent requests, and is required otherwise.", - ) - parser.add_argument( - "--per-gpu-workload", - default=False, - action="store_true", - help='When set to True, the specified "num_concurrent_requests"/"request_rate" ' - "denote the workload **per GPU**, which means that the real values of " - '"num_concurrent_requests"/"request_rate" used in benchmark' - 'will be multiplied by "num_gpus".', - ) - parser.add_argument( - "--num-concurrent-requests", - type=_parse_num_concurrent_requests, - help="The number(s) of concurrent requests to benchmark. " - 'It can be either one integer or a list of integer separated by commas(","). ' - "When specified, for each integer, the benchmark keeps these many consistent " - "number of concurrently running requests.", - ) - parser.add_argument( - "--request-rate", - type=_parse_request_rate, - help="The request rate(s) denoting the number of new requests each second. " - 'It can be either one float number (or "inf") or a list of numbers separated ' - 'by commas(","). 
' - "When specified, the benchmark sends these many new requests each second. " - 'If it is "inf", all requests will be sent together at once.', - ) - parser.add_argument( - "--replay-timestamp-scale", - type=float, - help="The timestamp scale when replaying the timestamps in a dataset. " - 'The dataset replay mode is enabled when neither "--num-concurrent-requests" and ' - '"--request-rate" is specified. ' - "The scale is 1 by default in the replay mode.", - ) - parser.add_argument( - "--input-len", - type=int, - help="The benchmark request average input length. Default to None, " - "which means the request input length depends on the dataset being used.", - ) - parser.add_argument( - "--input-len-std", - type=float, - default=0, - help="The benchmark request input length standard deviation. Default to 0.", - ) - parser.add_argument( - "--output-len", - type=int, - help="The benchmark request average output length. Default to None, " - "which means the request output length depends on the dataset being used.", - ) - parser.add_argument( - "--output-len-std", - type=float, - default=0, - help="The benchmark request output length standard deviation. Default to 0.", - ) - parser.add_argument( - "--stream", - action="store_true", - default=False, - help="Whether to benchmark stream responses. " - "When not enabled, metrics such as time-to-first-token (TTFT) will not be available. " - "Default to False.", - ) - parser.add_argument( - # NOTE: The current implementation of server metrics still has some issues that need fixes, - # which makes it not work to include server metrics. - "--include-server-metrics", - action="store_true", - help="Whether to also benchmark the server side request metrics. " - "This option is only available when benchmarking MLC server.", - ) - parser.add_argument( - "--host", - type=str, - required=True, - help="The host address of the backend API.", - ) - parser.add_argument( - "--port", - type=int, - required=True, - help="The port of the backend API.", - ) - parser.add_argument( - "--timeout", - type=float, - default=3 * 60 * 60, - help="The timeout limit of each request.", - ) - parser.add_argument( - "--seed", - type=int, - default=0, - help="The random number seed. Default to 0.", - ) - parser.add_argument( - "--temperature", - type=float, - default=1.0, - help="The temperature value for logit adjustment. Default to 1.", - ) - parser.add_argument( - "--top-p", - type=float, - default=1.0, - help="The top-p value for sampling. Default to 1.", - ) - parser.add_argument( - "--ignore-eos", - default=False, - action="store_true", - help='Whether to set the "ignore_eos" field.', - ) - parser.add_argument( - "--apply-chat-template", - default=False, - action="store_true", - help="Whether to apply chat template to the request input text. " - 'It is not supported when "--input-len" is specified.', - ) - parser.add_argument( - "--num-process-workers", - type=int, - help="The number of parallel process workers to send the requests.", - ) - parser.add_argument( - "--disable-tqdm", - action="store_true", - help="Whether to disable showing progress bar with tqdm during benchmarking.", - ) - parser.add_argument( - "--max-schedule-gap", - type=float, - default=0.5, - help="The maximum allowed delay between the scheduled time in seconds.", - ) - parser.add_argument( - "--mlc-model-lib", - type=str, - help="The model lib path when benchmarking MLC serve. 
" - "When specified, the server is automatic launched and no external server launch is needed.", - ) - parser.add_argument( - "--mlc-engine-config", - type=_parse_mlc_engine_config, - help="The engine config used when launch MLC server.", - ) - parser.add_argument( - "--cuda-profile", - default=False, - action="store_true", - help="Whether to enable cuda profile on server. " - "The --mlc-model-lib path should be provided when enabling this option.", - ) - parser.add_argument( - "--debug-dump", - default=False, - action="store_true", - help="Whether to dump all request record raw data to file.", - ) - parser.add_argument( - "--multi-round", - default=False, - action="store_true", - help="Whether to chat like multi round conversion with history log each request. " - "Only enabled when benchmarked with fixed concurrent request mode." - "The --num-concurrent-requests should be provided when enabling this option.", - ) - parser.add_argument( - "--bench-output", - "-o", - type=str, - required=True, - help="The path of the output file where to dump the benchmark results.", - ) - parser.add_argument( - "--generate-output", - type=str, - required=True, - help="The path of the generated output file where to dump the output results.", - ) - parser.add_argument( - "--final-output", - type=str, - required=True, - help="The path of the final output file where to dump the final accuracy results.", - ) - parser.add_argument( - "--use-stag", - action="store_true", - help="Whether to set stag.", - ) - main(parser.parse_args()) diff --git a/eval/api_endpoint.py b/eval/api_endpoint.py deleted file mode 100644 index 198fd47a08..0000000000 --- a/eval/api_endpoint.py +++ /dev/null @@ -1,215 +0,0 @@ -"""MLC LLM bench backends""" - -import argparse -import json -import os -import time -import traceback -from typing import Optional - -from mlc_llm.protocol.openai_api_protocol import ChatCompletionMessage -from typing_extensions import Self - -from request_record import Metrics, RequestRecord, ServerMetrics -from mlc_llm.support import logging - -logger = logging.getLogger(__name__) - - -class APIEndPoint: - """Manages the sending of requests to a specified API endpoint and gathers - inference statistics. 
- """ - - def __init__(self, include_server_metrics: bool = False) -> None: - self.include_server_metrics = include_server_metrics - - async def __aenter__(self) -> Self: - return self - - async def __aexit__(self, exc_type, exc_value, tb) -> None: - pass - - async def __call__(self, request: RequestRecord) -> RequestRecord: - raise NotImplementedError() - - -class OpenAIChatEndPoint(APIEndPoint): - """The backend of sending HTTP requests in OpenAI API through "v1/chat/completions".""" - - def __init__( # pylint: disable=too-many-arguments - self, - host: str, - port: int, - timeout: Optional[float] = None, - include_server_metrics: bool = False, - ) -> None: - super().__init__(include_server_metrics=include_server_metrics) - - import aiohttp # pylint: disable=import-outside-toplevel,import-error - - self.timeout = timeout - self.client: aiohttp.ClientSession = None - self.url = f"http://{host}:{port}/v1/chat/completions" - self.headers = {"Content-Type": "application/json"} - if os.getenv("MLC_LLM_API_KEY"): - self.headers["Authorization"] = f"Bearer {os.getenv('MLC_LLM_API_KEY')}" - - async def __aenter__(self) -> Self: - import aiohttp # pylint: disable=import-outside-toplevel,import-error - - self.client = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(self.timeout)) - return self - - async def __aexit__(self, exc_type, exc_value, tb) -> None: - await self.client.close() - - async def __call__( # pylint: disable=too-many-branches,too-many-statements,too-many-locals - self, request_record: RequestRecord - ) -> RequestRecord: - payload = request_record.chat_cmpl.model_dump() - if self.timeout is not None and "timeout" not in payload: - payload["timeout"] = self.timeout - if self.include_server_metrics: - if "stream_options" not in payload or payload["stream_options"] is None: - payload["stream_options"] = {"include_usage": True} - else: - payload["stream_options"]["include_usage"] = True - if ( - request_record.chat_cmpl.debug_config is not None - and request_record.chat_cmpl.debug_config.ignore_eos - ): - payload["ignore_eos"] = True - - generated_text = "" - first_chunk_output_str = "" - time_to_first_token_s = None - start_time = time.monotonic() - server_metrics = None - - try: - async with self.client.post(self.url, json=payload, headers=self.headers) as response: - assert response.status == 200, await response.text() - if payload["stream"]: - async for chunk in response.content: - chunk = chunk.strip() - if not chunk or chunk == b"\n": - continue - # Get rid of the prefix "data: " and suffix "\n" - raw_data = chunk[6:].strip() - if raw_data == b"[DONE]": - continue - data = json.loads(raw_data) - if not data["choices"]: - continue - delta = data["choices"][0]["delta"] - content = delta.get("content", None) - if content is not None and not time_to_first_token_s: - time_to_first_token_s = time.monotonic() - start_time - first_chunk_output_str = content - if self.include_server_metrics and data["usage"] is not None: - # fmt: off - # pylint: disable=line-too-long - server_metrics = ServerMetrics( - input_tokens=data["usage"]["extra"]["prompt_tokens"], - prefill_tokens=data["usage"]["extra"]["prefill_tokens"], - output_tokens=data["usage"]["extra"]["completion_tokens"], - end_to_end_latency_s=data["usage"]["extra"]["end_to_end_latency_s"], - prefill_tokens_per_s=data["usage"]["extra"]["prefill_tokens_per_s"], - inter_token_latency_s=data["usage"]["extra"]["inter_token_latency_s"], - time_per_output_token_s=1 / data["usage"]["extra"]["decode_tokens_per_s"], - 
time_to_first_token_s=data["usage"]["extra"]["ttft_s"], - ) - # pylint: enable=line-too-long - # fmt: on - - if content is not None: - generated_text += content - else: - data = await response.json() - generated_text = data["choices"][0]["message"]["content"] - if self.include_server_metrics and data["usage"] is not None: - # fmt: off - # pylint: disable=line-too-long - server_metrics = ServerMetrics( - input_tokens=data["usage"]["extra"]["prompt_tokens"], - prefill_tokens=data["usage"]["extra"]["prefill_tokens"], - output_tokens=data["usage"]["extra"]["completion_tokens"], - end_to_end_latency_s=data["usage"]["extra"]["end_to_end_latency_s"], - prefill_tokens_per_s=data["usage"]["extra"]["prefill_tokens_per_s"], - inter_token_latency_s=data["usage"]["extra"]["inter_token_latency_s"], - time_per_output_token_s=1 / data["usage"]["extra"]["decode_tokens_per_s"], - time_to_first_token_s=data["usage"]["extra"]["ttft_s"], - ) - # pylint: enable=line-too-long - # fmt: on - except Exception: # pylint: disable=broad-except - error_msg = "API endpoint errored when sending request: " + traceback.format_exc() - logger.info(error_msg) - finish_time = time.monotonic() - request_record.output_str = generated_text - request_record.first_chunk_output_str = first_chunk_output_str - request_record.metrics = Metrics( - success=False, - start_time=start_time, - finish_time=finish_time, - end_to_end_latency_s=finish_time - start_time, - input_tokens=request_record.metrics.input_tokens, - time_to_first_token_s=time_to_first_token_s, - server_metrics=server_metrics, - exec_feature=request_record.metrics.exec_feature, - ) - request_record.error_msg = error_msg - return request_record - - finish_time = time.monotonic() - request_record.output_str = generated_text - request_record.first_chunk_output_str = first_chunk_output_str - success = True - error_msg = None - if generated_text is None: - if data["choices"][0]["finish_reason"] == "tool_calls": - if data["choices"][0]["message"]["tool_calls"] is None or len(data["choices"][0]["message"]["tool_calls"]) == 0: - success = False - error_msg = "Invalid tool call." - else: - success = True - else: - success = False - error_msg = "Invalid response." - else: - if len(generated_text) == 0: - success = False - error_msg = "Empty generated text." 
- - message = ChatCompletionMessage( - role=data["choices"][0]["message"]["role"], - content=generated_text, - function_call=data["choices"][0]["message"].get("function_call", None), - tool_calls=data["choices"][0]["message"].get("tool_calls", None), - tool_call_id=data["choices"][0]["message"].get("tool_call_id", None), - ) - request_record.chat_cmpl.messages.append(message) - request_record.metrics = Metrics( - success=success, - start_time=start_time, - finish_time=finish_time, - end_to_end_latency_s=finish_time - start_time, - input_tokens=request_record.metrics.input_tokens, - time_to_first_token_s=time_to_first_token_s, - server_metrics=server_metrics, - exec_feature=request_record.metrics.exec_feature, - ) - request_record.error_msg = error_msg - return request_record - -SUPPORTED_BACKENDS = [ - "openai-chat", -] - - -def create_api_endpoint(args: argparse.Namespace) -> APIEndPoint: - """Create an API endpoint instance with regard to the specified endpoint kind.""" - if args.api_endpoint == "openai-chat": - return OpenAIChatEndPoint(args.host, args.port, args.timeout, args.include_server_metrics) - raise ValueError(f'Unrecognized endpoint "{args.api_endpoint}"') diff --git a/eval/dataset.py b/eval/dataset.py deleted file mode 100644 index 7807383435..0000000000 --- a/eval/dataset.py +++ /dev/null @@ -1,500 +0,0 @@ -"""MLC LLM benchmark dataset classes""" - -import argparse -import json -import os -import requests -import random -from datetime import datetime -import re -from typing import Any, Dict, List, Optional, Tuple - -import numpy as np -import pandas as pd # pylint: disable=import-error -from datasets import load_dataset # pylint: disable=import-error -from transformers import AutoTokenizer # pylint: disable=import-error - -from request_record import GroupedRequestRecord, Metrics, RequestRecord -from mlc_llm.protocol.openai_api_protocol import ( - ChatCompletionMessage, - ChatCompletionRequest, - ChatToolCall, - DebugConfig, -) - - -class Dataset: # pylint: disable=too-few-public-methods - """The dataset base class.""" - - # We set a truncation limit of 100k. - truncate_length = int(1e5) - # For some that datasets (e.g., dataset that has shared common prefix), - # we need fake warmup requests to avoid prefilling common prefixes to the engine. - require_fake_warmup: bool = False - # Whether the dataset contains timestamps already. - # If the dataset comes with timestamps, the benchmark can just replay - # the requests according to their timestamps. - timestamp_available: bool = False - - def generate_request_records( - self, - input_len: Optional[int], - output_len: Optional[int], - input_len_std: float = 0.0, - output_len_std: float = 0.0, - ) -> List[RequestRecord]: - """Get the raw unprocessed request records of the dataset.""" - raise NotImplementedError() - -GORILLA_TO_OPENAPI = { - "integer": "integer", - "number": "number", - "float": "number", - "string": "string", - "boolean": "boolean", - "bool": "boolean", - "array": "array", - "list": "array", - "dict": "object", - "object": "object", - "tuple": "array", - "any": "string", - "byte": "integer", - "short": "integer", - "long": "integer", - "double": "number", - "char": "string", - "ArrayList": "array", - "Array": "array", - "HashMap": "object", - "Hashtable": "object", - "Queue": "array", - "Stack": "array", - "Any": "string", - "String": "string", - "Bigint": "integer", -} - -class GorillaDataset(Dataset): # pylint: disable=too-few-public-methods - """The dataset class for Gorilla dataset. 
- Reference: https://github.com/ShishirPatil/gorilla - """ - - def __init__(self, dataset_path: str, tokenizer: AutoTokenizer, use_stag: bool) -> None: - self.tokenizer = tokenizer - self.require_fake_warmup = True - self.gorilla_data = [] - file_patterns = [ - "BFCL_v3_simple.json", - ] - base_url = "https://raw.githubusercontent.com/ShishirPatil/gorilla/main/berkeley-function-call-leaderboard/data" - - for filename in file_patterns: - id = 0 - dataset_file = f"{dataset_path}/{filename}" - if os.path.exists(dataset_file): - with open(dataset_file, mode="r", encoding="utf-8") as file: - self.gorilla_data = json.load(file) - else: - function_url = f"{base_url}/{filename}" - answer_url = f"{base_url}/possible_answer/{filename}" - print(f"Downloading {filename} from GitHub...") - functions_data = [] - answers_data = [] - try: - function_response = requests.get(function_url) - function_response.raise_for_status() - function_text = function_response.text - for line in function_text.strip().split("\n"): - if line.strip(): - try: - functions_data.append(json.loads(line)) - except json.JSONDecodeError as e: - print(f"Error parsing function line in {filename}: {e}") - answer_response = requests.get(answer_url) - answer_response.raise_for_status() - answer_text = answer_response.text - for line in answer_text.strip().split("\n"): - if line.strip(): - try: - answers_data.append(json.loads(line)) - except json.JSONDecodeError as e: - print(f"Error parsing answer line in {filename}: {e}") - print( - f"Successfully downloaded {filename}: {len(functions_data)} functions, {len(answers_data)} answers" - ) - except requests.RequestException as e: - print(f"Error downloading {filename}: {e}") - functions_data = [] - answers_data = [] - if not functions_data or not answers_data: - print(f"Skipping {filename} - failed to download data") - continue - print(f"Processing {filename}...") - answers_by_id = {item["id"]: item for item in answers_data} - for item in functions_data: - item_id = item["id"] - question = item["question"][0] - if item_id not in answers_by_id: - print(f"Warning: No answer found for item {item_id}") - continue - if "function" not in item or not item["function"]: - print(f"Warning: No function definition for item {item_id}") - continue - tool = [{"type": "function", "function": func} for func in item["function"]] - self.map_type_values(tool) - answer = answers_by_id[item_id] - if "ground_truth" not in answer or not answer["ground_truth"]: - print(f"Warning: No ground truth for item {item_id}") - continue - ideal_call = [] - for ground_truth in answer["ground_truth"]: - function_name = list(ground_truth.keys())[0] - params = ground_truth[function_name] - ideal_call.append({"name": function_name, "arguments": params}) - self.gorilla_data.append( - { - "id": id, - "question": question, - "tool": tool, - "ideal_call": ideal_call, - "source": filename, - } - ) - id += 1 - with open(dataset_file, mode="w", encoding="utf-8") as file: - json.dump(self.gorilla_data, file, ensure_ascii=False, indent=4) - if self.tokenizer is not None: - for item in self.gorilla_data: - num_tokens = 0 - for message in item["question"]: - num_tokens += len( - tokenizer.encode(message["content"], add_special_tokens=False) - ) - item["num_tokens"] = num_tokens - if not use_stag: - for item in self.gorilla_data: - for tool in item["tool"]: - tool["function"]["strict"] = False - - def generate_request_records( - self, - input_len: Optional[int], - output_len: Optional[int], - input_len_std: float = 0.0, - output_len_std: 
float = 0.0, - ) -> List[RequestRecord]: - - request_records = [] - for entry in self.gorilla_data: - # If the request does not have enough length, discard it. - # if input_len is not None and entry["num_tokens"] < input_len + 4 * input_len_std: - # continue - - if output_len is not None: - output_length = max( - round(np.random.normal(loc=output_len, scale=output_len_std)), 1 - ) - else: - output_length = 256 - request_records.append( - RequestRecord( - request_id=entry["id"], - chat_cmpl=ChatCompletionRequest( - messages=[ - ChatCompletionMessage(content=message["content"], role=message["role"]) - for message in entry["question"] - ], - model="", - max_tokens=output_length, - tools=entry["tool"], - ), - metrics=Metrics( - success=False, - start_time=0, - finish_time=0, - end_to_end_latency_s=0, - input_tokens=entry["num_tokens"], - ), - ) - ) - return request_records - - # Modified by https://github.com/ShishirPatil/gorilla/blob/main/berkeley-function-call-leaderboard/bfcl/eval_checker/ast_eval/ast_checker.py - def check_simple(self, tool_call: Dict[str, Any], - tool: Dict[str, Any], ideal: Dict[str, Any]) -> Tuple[bool, bool]: - # check func name - if ideal["name"] != tool_call["function"]["name"]: - return True, False - func = tool["function"] - # check func args - for arg in func["parameters"]["required"]: - if arg not in tool_call["function"]["arguments"]: - return True, False - for arg in tool_call["function"]["arguments"].keys(): - ideal_arg: List = ideal["arguments"][arg] if arg in ideal["arguments"] else None - real_arg = tool_call["function"]["arguments"][arg] - if arg not in func["parameters"]["properties"]: - return True, False - info_arg = func["parameters"]["properties"][arg] - if info_arg["type"] == "integer": - if not self.check_integer(real_arg, ideal_arg): - return True, False - elif info_arg["type"] == "number": - if not self.check_number(real_arg, ideal_arg): - return True, False - elif info_arg["type"] == "boolean": - if not self.check_boolean(real_arg, ideal_arg): - return True, False - elif info_arg["type"] == "string": - enum = info_arg["enum"] if "enum" in info_arg else None - if not self.check_string(real_arg, ideal_arg, enum): - return True, False - elif info_arg["type"] == "array": - if not self.check_list(real_arg, ideal_arg, info_arg["items"]): - return True, False - elif info_arg["type"] == "dict": - if not self.check_dict(real_arg, ideal_arg, info_arg["properties"]): - return True, False - return True, True - - - - def check_integer(self, real_arg: Any, ideal_arg: Optional[List[Any]]) -> bool: - try: - if type(real_arg) != int: - return False - if ideal_arg is None: - return True - match = False - for ideal in ideal_arg: - if real_arg == ideal: - match = True - break - return match - except: - return False - - def check_number(self, real_arg: Any, ideal_arg: Optional[List[Any]]) -> bool: - if type(real_arg) != float and type(real_arg) != int: - return False - if ideal_arg is None: - return True - match = False - for ideal in ideal_arg: - if real_arg == ideal: - match = True - break - return match - - def check_string(self, real_arg: Any, ideal_arg: Optional[List[Any]], enum: Optional[List[str]]) -> bool: - - def standardize_string(string: Any) -> str: - if not isinstance(string, str): - return "Error><><><><><>" - regex_string = r"[ \,\.\/\-\_\*\^]" - return re.sub(regex_string, "", string).lower().replace("'", '"') - - if type(real_arg) != str: - return False - match = False - real_arg = standardize_string(real_arg) - if ideal_arg is None: - if enum is 
None: - return True - else: - for ideal in enum: - if real_arg == standardize_string(ideal): - match = True - break - else: - for ideal in ideal_arg: - if real_arg == standardize_string(ideal): - match = True - break - return match - - def check_boolean(self, real_arg: bool, ideal_arg: Optional[List[bool]]) -> bool: - if type(real_arg) != bool: - return False - if ideal_arg is None: - return True - match = False - for ideal in ideal_arg: - if real_arg == ideal: - match = True - break - return match - - def check_list(self, real_arg: List, ideal_arg: Optional[List[List]], item: Dict[str, Any]) -> bool: - if type(real_arg) != list: - return False - item_type = item["type"] - if ideal_arg is None: - if item_type == "integer": - for i, integer in enumerate(real_arg): - if not self.check_integer(integer, None): - return False - elif item_type == "number": - for i, integer in enumerate(real_arg): - if not self.check_number(integer, None): - return False - elif item_type == "boolean": - for i, boolean in enumerate(real_arg): - if not self.check_boolean(boolean, None): - return False - elif item_type == "string": - for i, string in enumerate(real_arg): - enum = item["enum"] if "enum" in item else None - if not self.check_string(string, None, enum): - return False - elif item_type == "array": - for i, array in enumerate(real_arg): - if not self.check_list(array, None, item["items"]): - return False - elif item_type == "dict": - for i, dictionary in enumerate(real_arg): - if not self.check_dict(dictionary, None, item["properties"]): - return False - return True - else: - for ideal in ideal_arg: - if len(ideal) != len(real_arg): - continue - match = True - if item_type == "integer": - for i, integer in enumerate(real_arg): - if not self.check_integer(integer, [ideal[i]]): - match = False - break - elif item_type == "number": - for i, integer in enumerate(real_arg): - if not self.check_number(integer, [ideal[i]]): - match = False - break - elif item_type == "boolean": - for i, boolean in enumerate(real_arg): - if not self.check_boolean(boolean, [ideal[i]]): - match = False - break - elif item_type == "string": - for i, string in enumerate(real_arg): - enum = item["enum"] if "enum" in item else None - if not self.check_string(string, [ideal[i]], enum): - match = False - break - elif item_type == "array": - for i, array in enumerate(real_arg): - if not self.check_list(array, [ideal[i]], item["items"]): - match = False - break - elif item_type == "dict": - for i, dictionary in enumerate(real_arg): - if not self.check_dict(dictionary, [ideal[i]], item["properties"]): - match = False - break - if match: - return True - return False - - def check_dict(self, real_arg: Dict[str, Any], ideal_arg: Optional[Dict[str, Any]], properties: Dict[str, Any]) -> bool: - if type(real_arg) != dict: - return False - if ideal_arg is None: - for key in properties.keys(): - if key not in real_arg: - return False - item_type = properties[key]["type"] - if item_type == "integer": - if not self.check_integer(real_arg[key], None): - return False - elif item_type == "number": - if not self.check_number(real_arg[key], None): - return False - elif item_type == "boolean": - if not self.check_boolean(real_arg[key], None): - return False - elif item_type == "string": - enum = properties[key]["enum"] if "enum" in properties[key] else None - if not self.check_string(real_arg[key], None, enum): - return False - elif item_type == "array": - if not self.check_list(real_arg[key], None, properties[key]["items"]): - return False - elif 
item_type == "dict": - if not self.check_dict(real_arg[key], None, properties[key]["properties"]): - return False - return True - else: - for ideal in ideal_arg: - match = True - for key in properties.keys(): - if key not in real_arg: - match = False - break - item_type = properties[key]["type"] - if item_type == "integer": - if not self.check_integer(real_arg[key], [ideal[key]]): - match = False - break - elif item_type == "number": - if not self.check_number(real_arg[key], [ideal[key]]): - match = False - break - elif item_type == "boolean": - if not self.check_boolean(real_arg[key], [ideal[key]]): - match = False - break - elif item_type == "string": - enum = properties[key]["enum"] if "enum" in properties[key] else None - if not self.check_string(real_arg[key], [ideal[key]], enum): - match = False - break - elif item_type == "array": - if not self.check_list(real_arg[key], [ideal[key]], properties[key]["items"]): - match = False - break - elif item_type == "dict": - if not self.check_dict(real_arg[key], [ideal[key]], properties[key]["properties"]): - match = False - break - if match: - return True - return False - - def map_type_values(self, data): - if isinstance(data, dict): - for key, value in data.items(): - if isinstance(value, (dict, list)): - self.map_type_values(value) - elif key == "type" and value in GORILLA_TO_OPENAPI: - data[key] = GORILLA_TO_OPENAPI[value] - elif isinstance(data, list): - for item in data: - if isinstance(item, (dict, list)): - self.map_type_values(item) - - - -SUPPORTED_DATASET = [ - "gorilla" -] - - -def create_dataset( # pylint: disable=too-many-return-statements,too-many-branches - args: argparse.Namespace, tokenizer: AutoTokenizer -) -> Dataset: - """Create a dataset instance with regard to the specified dataset kind and file path.""" - if args.dataset_path is not None and not isinstance(args.dataset_path, str): - raise TypeError(f"Invalid dataset path {args.dataset_path}. Please use a string.") - if args.dataset == "gorilla": - if args.dataset_path is None: - raise ValueError( - "Gorilla dataset requires dataset path. " - 'Please specify it with "--dataset-path".' - ) - assert ( - args.apply_chat_template is False - ), "Gorilla dataset does not support applying chat template" - return GorillaDataset(args.dataset_path, tokenizer) - raise ValueError(f"Unrecognized dataset {args.dataset}") diff --git a/eval/request_processor.py b/eval/request_processor.py deleted file mode 100644 index 99de2d7293..0000000000 --- a/eval/request_processor.py +++ /dev/null @@ -1,738 +0,0 @@ -"""MLC LLM Bench Request""" - -import argparse -import asyncio -import concurrent.futures -import copy -import os -import random -import time -from typing import Any, Callable, Dict, List, Optional, Tuple - -import numpy as np -import requests -from tqdm import tqdm -from transformers import AutoTokenizer # pylint: disable=import-error - -from api_endpoint import APIEndPoint -from dataset import Dataset -from request_record import GroupedRequestRecord, RequestRecord -from mlc_llm.protocol.openai_api_protocol import ( - ChatCompletionMessage, - ChatCompletionRequest, - DebugConfig, -) -from mlc_llm.support import logging - -logger = logging.getLogger(__name__) - - -class RequestProcessor: # pylint: disable=too-few-public-methods - """The request processor base class. - Each processor can take a list of RequestRecord, applying the process, - and returning the processed RequestRecord in the end. 
- """ - - def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]: - raise NotImplementedError() - - -class LogMessage(RequestProcessor): # pylint: disable=too-few-public-methods - """The processor that prints the logger message.""" - - def __init__(self, message: str) -> None: - self.message = message - - def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]: - logger.info(self.message) - return request_records - - -class SampleRequests(RequestProcessor): # pylint: disable=too-few-public-methods - """The processor that samples requests out from the given request list.""" - - def __init__(self, num_requests: int, take_first_x_requests: bool = True) -> None: - self.num_requests = num_requests - # If `take_first_x_requests` is True, the first `num_requests` requests - # are returned and sampling will not happen. - self.take_first_x_requests = take_first_x_requests - - def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]: - assert len(request_records) > 0, "Empty input request record." - - # We expect the input request records to be all grouped or all plain. - if isinstance(request_records[0], GroupedRequestRecord): - assert all(isinstance(record, GroupedRequestRecord) for record in request_records) - return self._sample_from_grouped_request_records(request_records) - - assert all(not isinstance(record, GroupedRequestRecord) for record in request_records) - return self._sample_from_plain_request_records(request_records) - - def _sample_from_plain_request_records( - self, request_records: List[RequestRecord] - ) -> List[RequestRecord]: - samples: List[RequestRecord] = [] - if self.take_first_x_requests: - if len(request_records) < self.num_requests: - raise ValueError( - f"Insufficient requests. Requiring {self.num_requests} requests " - f"but only {len(request_records)} are available." - ) - samples = copy.deepcopy(list(request_records[: self.num_requests])) - else: - while len(samples) < self.num_requests: - # Create a new list so that the in-place shuffle does not mutate the input list. - records = list(request_records) - random.shuffle(records) - samples += copy.deepcopy(records) - samples = samples[: self.num_requests] - for i, record in enumerate(samples): - record.request_id = i - return samples - - def _sample_from_grouped_request_records( - self, grouped_request_records: List[GroupedRequestRecord] - ) -> List[RequestRecord]: - num_total_available_requests = sum( - len(record.records) for record in grouped_request_records - ) - if self.num_requests > num_total_available_requests: - raise ValueError( - "Due to the existence of shared common prefixes, we do not allow " - "benchmarking with requests more than the available requests in the dataset. " - f"The required number of requests {self.num_requests} exceeds the " - f"number of total available requests {num_total_available_requests}." - ) - - # Create a new list so that the in-place shuffle does not mutate the input list. 
- records = list(grouped_request_records) - if not self.take_first_x_requests: - random.shuffle(records) - remaining = self.num_requests - samples: List[RequestRecord] = [] - for grouped_request_record in grouped_request_records: - num_used_requests = min(len(grouped_request_record.records), remaining) - samples += grouped_request_record.records[:num_used_requests] - remaining -= num_used_requests - if remaining == 0: - break - for i, record in enumerate(samples): - record.request_id = i - return samples - - -class AttachModelName(RequestProcessor): # pylint: disable=too-few-public-methods - """The processor that attaches model name to requests.""" - - def __init__(self, model: str) -> None: - self.model = model - - def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]: - for request_record in request_records: - request_record.chat_cmpl.model = self.model - return request_records - - -class AttachRequestRateTimestamp(RequestProcessor): # pylint: disable=too-few-public-methods - """The processor that applies timestamps to the requests.""" - - def __init__(self, request_rate: np.float32) -> None: - self.request_rate = request_rate - - def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]: - timestamp = 0.0 - for request_record in request_records: - assert request_record.timestamp is None, "The request record already has a timestamp" - request_record.timestamp = timestamp - timestamp += float(np.random.exponential(1.0 / self.request_rate)) - return request_records - - -class AttachExecutionFeature(RequestProcessor): # pylint: disable=too-few-public-methods - """The processor that attaches execution features to all requests""" - - def __init__(self, exec_feature: Dict[str, Any]) -> None: - self.exec_feature = exec_feature - - def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]: - for request_record in request_records: - assert request_record.metrics is not None - request_record.metrics.exec_feature = self.exec_feature - return request_records - - -class AttachStreamFlag(RequestProcessor): # pylint: disable=too-few-public-methods - """The processor that attaches the stream flag to the requests.""" - - def __init__(self, stream: Optional[bool]) -> None: - self.stream = stream - - def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]: - if self.stream is None: - return request_records - for request_record in request_records: - request_record.chat_cmpl.stream = self.stream - return request_records - - -class AttachSamplingOptions(RequestProcessor): # pylint: disable=too-few-public-methods - """The processor that attaches the stream flag to the requests.""" - - def __init__(self, temperature: float, top_p: float, ignore_eos: bool) -> None: - self.temperature = temperature - self.top_p = top_p - self.ignore_eos = ignore_eos - - def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]: - for request_record in request_records: - request_record.chat_cmpl.temperature = self.temperature - request_record.chat_cmpl.top_p = self.top_p - request_record.chat_cmpl.frequency_penalty = 0.0 - request_record.chat_cmpl.presence_penalty = 0.0 - # request_record.chat_cmpl.tool_choice = "none" - if self.ignore_eos: - request_record.chat_cmpl.debug_config = DebugConfig(ignore_eos=True) - return request_records - - -class ScaleTimestamp(RequestProcessor): # pylint: disable=too-few-public-methods - """Scale the timestamp of requests by the given scale factor.""" - - def __init__(self, 
timestamp_scale: float): - self.timestamp_scale = timestamp_scale - - def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]: - for request_record in request_records: - if request_record.timestamp is None: - raise ValueError( - f"The timestamp of request {request_record} has not been initialized." - ) - request_record.timestamp *= self.timestamp_scale - return request_records - - -class MetricAnalyzer(RequestProcessor): # pylint: disable=too-few-public-methods - """The processor that analyzes the raw benchmark results and computes more detailed metrics.""" - - def __init__(self, tokenizer: AutoTokenizer) -> None: - self.tokenizer = tokenizer - - def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]: - updated_records = [] - for request_record in request_records: - metrics = request_record.metrics - if not metrics.success: - assert request_record.error_msg is not None - continue - - metrics.output_tokens = len( - self.tokenizer.encode(request_record.output_str, add_special_tokens=False) - ) - first_chunk_output_tokens = len( - self.tokenizer.encode( - request_record.first_chunk_output_str, add_special_tokens=False - ) - ) - if metrics.output_tokens <= first_chunk_output_tokens: - metrics.success = False - request_record.error_msg = ( - f"Total output token num ({metrics.output_tokens}) equals " - f'the first chunk output token. Output text "{request_record.output_str}", ' - f'first chunk output text "{request_record.first_chunk_output_str}"' - ) - continue - assert metrics.input_tokens > 0, "Invalid prompt tokens" - metrics.inter_token_latency_s = metrics.end_to_end_latency_s / metrics.output_tokens - if metrics.time_to_first_token_s is None: - metrics.time_to_first_token_s = 0 - metrics.time_per_output_token_s = ( - metrics.end_to_end_latency_s - metrics.time_to_first_token_s - ) / (metrics.output_tokens - first_chunk_output_tokens) - updated_records.append(request_record) - return updated_records - - -class WarmupAndRun(RequestProcessor): # pylint: disable=too-few-public-methods,line-too-long - """The processor that runs warmup first and then runs the benchmark with the given pipeline.""" - - def __init__( # pylint: disable=too-many-arguments - self, - num_warmup_requests: int, - num_benchmark_requests: int, - pipeline: RequestProcessor, - cuda_profile_url: Optional[str], - fake_warmup: bool = False, - ) -> None: - self.num_warmup_requests = num_warmup_requests - self.num_benchmark_requests = num_benchmark_requests - self.pipeline = pipeline - self.cuda_profile_url = cuda_profile_url - self.fake_warmup = fake_warmup - - def generate_fake_warmup_requests( # pylint: disable=missing-function-docstring - self, num_warmup_requests: int, example_request: RequestRecord - ) -> List[RequestRecord]: - records = [] - for _ in range(num_warmup_requests): - record = copy.deepcopy(example_request) - record.chat_cmpl = ChatCompletionRequest( - messages=[ - { - "role": "user", - "content": "Please output arbitrary coherent sentences. 
Do not output eos token.", # pylint: disable=line-too-long - } - ], - model="", - max_tokens=128, - ) - records.append(record) - return records - - def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]: - # Warmup - if self.fake_warmup: - assert len(request_records) == self.num_benchmark_requests - benchmark_requests = request_records - example_request = benchmark_requests[0] - warmup_requests = self.generate_fake_warmup_requests( - self.num_warmup_requests, example_request=example_request - ) - else: - assert len(request_records) == self.num_warmup_requests + self.num_benchmark_requests - benchmark_requests = request_records[: -self.num_warmup_requests] - warmup_requests = request_records[-self.num_warmup_requests :] - for request_record in warmup_requests: - request_record.timestamp = 0 if request_record.timestamp is not None else None - warmup_requests = self._process_warmup_requests(warmup_requests) - logger.info("Warmup with %d request(s)...", self.num_warmup_requests) - self.pipeline(warmup_requests) - - # Then run benchmark - if self.cuda_profile_url is not None: - cuda_profiler_start_url = self.cuda_profile_url + "/debug/cuda_profiler_start" - cuda_profiler_start_response = requests.post(cuda_profiler_start_url, timeout=60) - assert cuda_profiler_start_response.status_code == 200 - logger.info("Warmup finished. Start benchmarking...") - updated_request_records = self.pipeline(benchmark_requests) - if self.cuda_profile_url is not None: - cuda_profiler_stop_url = self.cuda_profile_url + "/debug/cuda_profiler_stop" - cuda_profiler_stop_response = requests.post(cuda_profiler_stop_url, timeout=60) - assert cuda_profiler_stop_response.status_code == 200 - - return updated_request_records - - def _process_warmup_requests(self, warmup_requests: List[RequestRecord]) -> List[RequestRecord]: - if len(warmup_requests) == 0: - return warmup_requests - # NOTE: to warm up the server for as more different batch sizes as possible, - # we usese 128 output tokens for the first request and use two more tokens - # for every followup request. - # Setting a high temperature and top-p to avoid early stop as much as possible. 
- warmup_requests[0].chat_cmpl.max_tokens = 128 - for i in range(1, len(warmup_requests)): - warmup_requests[i].chat_cmpl.max_tokens = ( - warmup_requests[i - 1].chat_cmpl.max_tokens + 1 - ) - warmup_requests[i].chat_cmpl.temperature = 2.0 - warmup_requests[i].chat_cmpl.top_p = 1.0 - return warmup_requests - - -class SequentialProcessor(RequestProcessor): # pylint: disable=too-few-public-methods - """The processor that sequentially applies a list of processors in order.""" - - processors: List[RequestProcessor] - - def __init__(self, *processors: RequestProcessor) -> None: - self.processors = list(processors) - - def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]: - for processor in self.processors: - request_records = processor(request_records) - return request_records - - -class Executor(RequestProcessor): # pylint: disable=too-few-public-methods - """The executor base class, denoting the kind of benchmark mode.""" - - def __init__( - self, - f_create_api_endpoint: Callable[[], APIEndPoint], - num_processes: int, - disable_tqdm: bool, - ) -> None: - self.f_create_api_endpoint = f_create_api_endpoint - self.disable_tqdm = disable_tqdm - self.num_processes = num_processes - - def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]: - raise NotImplementedError() - - -class FixedConcurrentRequestExecutor(Executor): # pylint: disable=too-few-public-methods - """The benchmark executor of fixing the number of concurrent requests.""" - - def __init__( # pylint: disable=too-many-arguments - self, - f_create_api_endpoint: Callable[[], APIEndPoint], - num_processes: Optional[int], - disable_tqdm: bool, - num_concurrent_requests: int, - multi_round: bool, - ) -> None: - if num_processes is None: - # We assign each process at most 32 concurrent requests to send - # so that the asyncio pressure will not be too much. - num_processes = min((num_concurrent_requests + 31) // 32, 10) - super().__init__(f_create_api_endpoint, num_processes, disable_tqdm) - self.num_concurrent_requests = num_concurrent_requests - self.multi_round = multi_round - - def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]: - partitions: List[List[RequestRecord]] = [ - request_records[slice(i, len(request_records), self.num_processes)] - for i in range(self.num_processes) - ] - # Package "tokenizers" reports warnings with multiprocessing. - # We disable "TOKENIZERS_PARALLELISM" to depress the warnings. 
- os.environ["TOKENIZERS_PARALLELISM"] = "false" - - pbar = None if self.disable_tqdm else tqdm(total=len(request_records)) - with concurrent.futures.ProcessPoolExecutor(max_workers=self.num_processes) as pool: - futures = [ - pool.submit( - FixedConcurrentRequestExecutor._process_task, - self.f_create_api_endpoint, - partition, - self.num_concurrent_requests // self.num_processes - + int(i < self.num_concurrent_requests % self.num_processes), - self.multi_round, - ) - for i, partition in enumerate(partitions) - ] - results: List[RequestRecord] = [] - for i, future in enumerate(concurrent.futures.as_completed(futures)): - results.extend(future.result()) - if pbar is not None: - pbar.update(len(partitions[i])) - - return results - - @staticmethod - def _process_task( - f_create_api_endpoint: Callable[[], APIEndPoint], - request_records: List[RequestRecord], - num_concurrent_requests: int, - multi_round: bool, - ) -> List[RequestRecord]: - if len(request_records) == 0: - return [] - chat_history: List[List[ChatCompletionMessage]] = [ - [] for _ in range(num_concurrent_requests) - ] - - async def process_task_impl( - f_create_api_endpoint: Callable[[], APIEndPoint], - request_records: List[RequestRecord], - num_concurrent_requests: int, - multi_round: bool, - ) -> List[RequestRecord]: - api_endpoint = f_create_api_endpoint() - updated_request_records: List[RequestRecord] = [None for _ in request_records] - async with api_endpoint: - num_sent_request = 0 - - async def _task(i: int) -> None: - nonlocal num_sent_request - while True: - if num_sent_request == len(request_records): - break - idx = num_sent_request - num_sent_request += 1 - request = request_records[idx] - - if multi_round: - request.chat_cmpl.messages = ( - chat_history[i] + request.chat_cmpl.messages - ) - - updated_request_records[idx] = await api_endpoint(request) - - if multi_round: - chat_history[i] = updated_request_records[idx].chat_cmpl.messages + [ - ChatCompletionMessage( - content=updated_request_records[idx].output_str, - role="assistant", - ) - ] - - tasks = [asyncio.create_task(_task(i)) for i in range(num_concurrent_requests)] - await asyncio.gather(*tasks) - - return updated_request_records - - return asyncio.run( - process_task_impl( - f_create_api_endpoint, - request_records, - num_concurrent_requests, - multi_round, - ) - ) - - -class FixTimestampExecutor(Executor): # pylint: disable=too-few-public-methods - """The benchmark executor of fixing the timestamps of sending requests.""" - - def __init__( # pylint: disable=too-many-arguments - self, - f_create_api_endpoint: Callable[[], APIEndPoint], - num_processes: Optional[int], - disable_tqdm: bool, - max_schedule_gap: float, - num_requests: int, - ) -> None: - if num_processes is None: - # We assign each process at most 32 requests to send - # so that the asyncio pressure will not be too much. - num_processes = min((num_requests + 31) // 32, 10) - super().__init__(f_create_api_endpoint, num_processes, disable_tqdm) - self.max_schedule_gap = max_schedule_gap - self.num_requests = num_requests - - def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]: - assert len(request_records) > 0 - assert all(request_record.timestamp is not None for request_record in request_records) - # Sort the request records in timestamp ascending order before partitioning. 
- request_records.sort(key=lambda request_record: request_record.timestamp) - base_timestamp = request_records[0].timestamp - partitions: List[List[RequestRecord]] = [ - request_records[slice(i, len(request_records), self.num_processes)] - for i in range(self.num_processes) - ] - base_sys_time = time.time() - # Package "tokenizers" reports warnings with multiprocessing. - # We disable "TOKENIZERS_PARALLELISM" to depress the warnings. - os.environ["TOKENIZERS_PARALLELISM"] = "false" - - pbar = None if self.disable_tqdm else tqdm(total=len(request_records)) - with concurrent.futures.ProcessPoolExecutor(max_workers=self.num_processes) as pool: - futures = [ - pool.submit( - FixTimestampExecutor._process_task, - self.f_create_api_endpoint, - partition, - base_timestamp, - base_sys_time, - self.max_schedule_gap, - ) - for partition in partitions - ] - results: List[RequestRecord] = [] - for i, future in enumerate(concurrent.futures.as_completed(futures)): - results.extend(future.result()) - if pbar is not None: - pbar.update(len(partitions[i])) - - return results - - @staticmethod - def _process_task( - f_create_api_endpoint: Callable[[], APIEndPoint], - request_records: List[RequestRecord], - base_timestamp: float, - base_sys_time: float, - max_schedule_gap: float, - ) -> List[RequestRecord]: - if len(request_records) == 0: - return [] - - async def process_task_impl( - f_create_api_endpoint: Callable[[], APIEndPoint], - request_records: List[RequestRecord], - base_timestamp: float, - base_sys_time: float, - max_schedule_gap: float, - ) -> List[RequestRecord]: - api_endpoint = f_create_api_endpoint() - loop = asyncio.get_running_loop() - # Get the delta time to convert system time to the loop time. - # We must use the system time `time.time()` which is consistent across processes. - loop_sys_delta_time = loop.time() - time.time() - updated_request_records: List[RequestRecord] = [] - async with api_endpoint: - - async def _task(request_record: RequestRecord) -> None: - updated_request_records.append(await api_endpoint(request_record)) - - tasks = [] - for request_record in request_records: - launch_time = ( - (request_record.timestamp - base_timestamp) - + (base_sys_time + max_schedule_gap) - + loop_sys_delta_time - ) - loop.call_at( - launch_time, - lambda record: tasks.append(asyncio.create_task(_task(record))), - request_record, - ) - # Sleep to allow runs of other scheduled tasks if any. - await asyncio.sleep(max(launch_time - loop.time() - max_schedule_gap, 0)) - - # Sleep until all the tasks are launched. - await asyncio.sleep(launch_time - loop.time() + max_schedule_gap) - # Wait for all tasks to be scheduled - assert len(tasks) == len(request_records) - await asyncio.gather(*tasks) - - assert len(updated_request_records) == len(request_records) - return updated_request_records - - return asyncio.run( - process_task_impl( - f_create_api_endpoint, - request_records, - base_timestamp, - base_sys_time, - max_schedule_gap, - ) - ) - - -def create_pipelines( # pylint: disable=too-many-branches - args: argparse.Namespace, f_create_api_endpoint: Callable[[], APIEndPoint], dataset: Dataset -) -> List[RequestProcessor]: - """Creating request processing pipelines with regard to the specified args.""" - cuda_profile_url = f"http://{args.host}:{args.port}" if args.cuda_profile else None - pipelines: List[RequestProcessor] = [] - if args.num_concurrent_requests is not None: - if args.request_rate is not None: - raise ValueError( - 'Both "num_concurrent_requests" and "request_rate" are specified. 
' - "Please specify only one of them." - ) - if args.replay_timestamp_scale is not None: - raise ValueError( - "Dataset replay is unsupported when fixing number of concurrent requests." - ) - for num_concurrent_requests in args.num_concurrent_requests: - num_warmup_requests = ( - args.num_warmup_requests - if args.num_warmup_requests is not None - else num_concurrent_requests - ) - pipelines.append( - SequentialProcessor( - LogMessage(f"Fixing number of concurrent requests: {num_concurrent_requests}"), - SampleRequests(args.num_requests + num_warmup_requests), - AttachModelName(args.tokenizer), - AttachStreamFlag(args.stream), - AttachSamplingOptions(args.temperature, args.top_p, args.ignore_eos), - AttachExecutionFeature({"num_concurrent_requests": num_concurrent_requests}), - WarmupAndRun( - num_warmup_requests=num_warmup_requests, - num_benchmark_requests=args.num_requests, - pipeline=FixedConcurrentRequestExecutor( - f_create_api_endpoint, - args.num_process_workers, - args.disable_tqdm, - num_concurrent_requests, - args.multi_round, - ), - cuda_profile_url=cuda_profile_url, - fake_warmup=dataset.require_fake_warmup, - ), - ) - ) - return pipelines - if args.request_rate is not None: - if args.num_warmup_requests is None: - raise ValueError( - "Please specify the number of warmup requests via " - '"--num-warmup-requests" when fixing request rate.' - ) - if args.replay_timestamp_scale is not None: - raise ValueError("Dataset replay is unsupported when fixing request rates.") - num_total_requests = int( - args.num_requests if not args.per_gpu_workload else args.num_requests * args.num_gpus - ) - if dataset.require_fake_warmup: - num_samples = num_total_requests - else: - num_samples = num_total_requests + args.num_warmup_requests - return [ - SequentialProcessor( - LogMessage(f"Fixing request rate: {request_rate}"), - SampleRequests(num_samples), - AttachModelName(args.tokenizer), - AttachRequestRateTimestamp( - request_rate if not args.per_gpu_workload else request_rate * args.num_gpus - ), - AttachStreamFlag(args.stream), - AttachSamplingOptions(args.temperature, args.top_p, args.ignore_eos), - AttachExecutionFeature({"request_rate": float(request_rate)}), - WarmupAndRun( - num_warmup_requests=args.num_warmup_requests, - num_benchmark_requests=num_total_requests, - pipeline=FixTimestampExecutor( - f_create_api_endpoint, - args.num_process_workers, - args.disable_tqdm, - args.max_schedule_gap, - args.num_requests, - ), - cuda_profile_url=cuda_profile_url, - fake_warmup=dataset.require_fake_warmup, - ), - ) - for request_rate in args.request_rate - ] - - # Default: dataset replay mode - # The dataset must come with timestamps. - if not dataset.timestamp_available: - raise ValueError( - "The dataset does not have timestamps, so dataset replay is unsupported. " - 'Please specify one of "num_concurrent_requests" ' - 'and "request_rate".' - ) - if args.per_gpu_workload: - raise ValueError("Fixing per-GPU workload is not compatible with dataset replay.") - if args.num_warmup_requests is None: - raise ValueError( - "Please specify the number of warmup requests via " - '"--num-warmup-requests" for dataset replay.' 
- ) - timestamp_scale = args.replay_timestamp_scale or 1.0 - if dataset.require_fake_warmup: - num_samples = args.num_requests - else: - num_samples = args.num_requests + args.num_warmup_requests - return [ - SequentialProcessor( - LogMessage(f"Dataset replay with time scaling of {timestamp_scale}"), - SampleRequests(num_samples, take_first_x_requests=True), - AttachModelName(args.tokenizer), - ScaleTimestamp(timestamp_scale), - AttachStreamFlag(args.stream), - AttachSamplingOptions(args.temperature, args.top_p, args.ignore_eos), - AttachExecutionFeature({"timestamp_scale": timestamp_scale}), - WarmupAndRun( - num_warmup_requests=args.num_warmup_requests, - num_benchmark_requests=args.num_requests, - pipeline=FixTimestampExecutor( - f_create_api_endpoint, - args.num_process_workers, - args.disable_tqdm, - args.max_schedule_gap, - args.num_requests, - ), - cuda_profile_url=cuda_profile_url, - fake_warmup=dataset.require_fake_warmup, - ), - ) - ] - - - \ No newline at end of file diff --git a/eval/request_record.py b/eval/request_record.py deleted file mode 100644 index 774519b7d3..0000000000 --- a/eval/request_record.py +++ /dev/null @@ -1,209 +0,0 @@ -"""MLC LLM Bench Request""" - -from typing import Any, Dict, List, Optional, Tuple, Union - -import pandas as pd # pylint: disable=import-error -from pydantic import BaseModel - -from mlc_llm.protocol.openai_api_protocol import ChatCompletionRequest -from mlc_llm.support import logging - -logger = logging.getLogger(__name__) - - -class ServerMetrics(BaseModel): - """The metrics from the server side.""" - - input_tokens: int - prefill_tokens: int - output_tokens: int - end_to_end_latency_s: float - prefill_tokens_per_s: float - inter_token_latency_s: float - time_per_output_token_s: float - time_to_first_token_s: Optional[float] = None - - -class Metrics(BaseModel): - """The list of metric keys""" - - success: bool - start_time: float - finish_time: float - end_to_end_latency_s: float - - input_tokens: Optional[int] = None - output_tokens: Optional[int] = None - inter_token_latency_s: Optional[float] = None - time_per_output_token_s: Optional[float] = None - time_to_first_token_s: Optional[float] = None - server_metrics: Optional[ServerMetrics] = None - - exec_feature: Optional[Dict[str, Any]] = None - - -class RequestRecord(BaseModel): - """The request records collected from LLM inference requests.""" - - request_id: Optional[int] = None - chat_cmpl: ChatCompletionRequest - output_str: Optional[str] = None - first_chunk_output_str: str = "" - timestamp: Optional[float] = None - metrics: Optional[Metrics] = None - error_msg: Optional[str] = None - - -class GroupedRequestRecord(RequestRecord): - """The data structure for request record groups. - For datasets that have common prefix sharing, the request records - that share a same common prefix will be wrapped in a GroupedRequestRecord - at the beginning. - """ - - records: List[RequestRecord] - - -def generate_metrics_summary( - request_records: List[RequestRecord], - num_total_requests: int, - num_gpus: int, -) -> Dict[str, Any]: - """Computes summary statistics across all metrics collected. - Return a dictionary as the report. 
- """ - num_completed_requests = len(request_records) - assert num_completed_requests <= num_total_requests - request_metrics = [record.metrics for record in request_records] - duration = ( - max(metrics.finish_time for metrics in request_metrics) - - min(metrics.start_time for metrics in request_metrics) - if num_completed_requests > 0 - else 1e-5 - ) - - report = _compute_metrics_statistics(request_metrics) - report["num_gpus"] = num_gpus - report["duration"] = duration - report["num_total_requests"] = num_total_requests - report["num_completed_requests"] = num_completed_requests - report["request_throughput"] = num_completed_requests / duration - - # Generate the server metrics statistics - server_metrics = [metric.server_metrics for metric in request_metrics if metric.server_metrics] - server_report = _compute_metrics_statistics(server_metrics) - if server_report is not None and len(server_report) > 0: - report["server_metrics"] = server_report - - report = { - "exec_feature": ( - request_records[0].metrics.exec_feature if num_completed_requests > 0 else None - ), - **report, - } - return report - - -def _compute_metrics_statistics(metrics: List[Union[Metrics, ServerMetrics]]) -> Dict[str, Any]: - """ - Compute the statistics of the metrics. - - Parameters - ---------- - metrics : List[Union[Metrics, ServerMetrics]] - The list of metrics to get the statistics. - - Returns - ------- - report : Dict - The statistics of the metrics. - """ - if not metrics: - return {} - - report: Dict = {} - df = pd.DataFrame([metric.model_dump() for metric in metrics]) - for key, _ in metrics[0].model_fields.items(): - if key in ["success", "start_time", "finish_time", "server_metrics", "exec_feature"]: - continue - if key in ["end_to_end_latency_s", "input_tokens"]: - if key in df.columns: - series = df[key].dropna() - report[key] = { - "quantiles": { - f"p{int(q * 100)}": v - for q, v in series.quantile([0.25, 0.5, 0.75, 0.9, 0.95, 0.99]).items() - }, - "mean": series.mean(), - "min": series.min(), - "max": series.max(), - "stddev": series.std(), - } - return report - - -def convert_reports_to_df(reports: List[Dict[str, Any]]) -> pd.DataFrame: - """Convert benchmark reports to pandas DataFrame.""" - - def _flatten_dict(d: Dict[str, Any], parent_key: str = "") -> Dict[str, Any]: - items: List[Tuple[str, Any]] = [] - for key, value in d.items(): - new_key = f"{parent_key}.{key}" if parent_key != "" else key - if isinstance(value, dict): - items.extend(_flatten_dict(value, new_key).items()) - else: - items.append((new_key, value)) - return dict(items) - - return pd.DataFrame([_flatten_dict(report) for report in reports]) - - -def pretty_print_report(report: Dict[str, Any]) -> None: # pylint: disable=too-many-statements - """Pretty print the metrics report.""" - - def _print(report: Dict[str, Any], server_metrics: bool): # pylint: disable=too-many-statements - # pylint: disable=line-too-long - # fmt: off - title = "Benchmark Result" - if server_metrics: - title += " (server side)" - print(f" {title} ".center(50, "=")) - if not server_metrics: - print(f"{'Total requests:':<40} {report['num_total_requests']:<10}") - print(f"{'Completed requests:':<40} {report['num_completed_requests']:<10}") - print(f"{'Duration (s):':<40} {report['duration']:<10.2f}") - print(f"{'Num GPUs:':<40} {report['num_gpus']:<10}") - if report["num_completed_requests"] == 0: - return - - - e2e_latency = report["end_to_end_latency_s"] - print(" End-to-End Latency (ms) ".center(50, "-")) - print(f"{'Mean:':<40} {e2e_latency['mean'] * 
1000:<10.2f}") - print(f"{'Stddev:':<40} {e2e_latency['stddev'] * 1000:<10.2f}") - print(f"{'P25:':<40} {e2e_latency['quantiles']['p25'] * 1000:<10.2f}") - print(f"{'P50:':<40} {e2e_latency['quantiles']['p50'] * 1000:<10.2f}") - print(f"{'P75:':<40} {e2e_latency['quantiles']['p75'] * 1000:<10.2f}") - print(f"{'P90:':<40} {e2e_latency['quantiles']['p90'] * 1000:<10.2f}") - print(f"{'P95:':<40} {e2e_latency['quantiles']['p95'] * 1000:<10.2f}") - print(f"{'P99:':<40} {e2e_latency['quantiles']['p99'] * 1000:<10.2f}") - print(f"{'Min:':<40} {e2e_latency['min'] * 1000:<10.2f}") - print(f"{'Max:':<40} {e2e_latency['max'] * 1000:<10.2f}") - - input_tokens = report["input_tokens"] - print(" Input Tokens ".center(50, "-")) - print(f"{'Mean:':<40} {input_tokens['mean']:<1}") - print(f"{'Stddev:':<40} {input_tokens['stddev']:<1}") - print(f"{'P25:':<40} {input_tokens['quantiles']['p25']:<1}") - print(f"{'P50:':<40} {input_tokens['quantiles']['p50']:<1}") - print(f"{'P95:':<40} {input_tokens['quantiles']['p95']:<1}") - print(f"{'Min:':<40} {input_tokens['min']:<1}") - print(f"{'Max:':<40} {input_tokens['max']:<1}") - - print("=" * 50) - - # fmt: on - # pylint: enable=line-too-long - _print(report, server_metrics=False) - if "server_metrics" in report: - _print(report["server_metrics"], server_metrics=True) diff --git a/python/mlc_llm/protocol/conversation_protocol.py b/python/mlc_llm/protocol/conversation_protocol.py index 0bb9ebb251..ad19460ed5 100644 --- a/python/mlc_llm/protocol/conversation_protocol.py +++ b/python/mlc_llm/protocol/conversation_protocol.py @@ -82,7 +82,7 @@ class Conversation(BaseModel): # whether using function calling or not, helps check for output message format in API call use_function_calling: bool = False # Tool function call format mode - tool_call_format: str = "json" + _tool_call_format: str = "json" def __init__(self, role_templates: Optional[Dict[str, str]] = None, **kwargs): # Defaults templates which would be overridden by model specific templates @@ -200,37 +200,44 @@ def as_prompt(self, config=None) -> List[Any]: def set_tool_call_format_in_system_message(self): """Add tool function information and call format to the system message.""" - if self.tool_call_format == "xml": + if self._tool_call_format == "xml": tool_call_instruct = ( "Tool Instructions:" - f"You have access to the following tool functions: {MessagePlaceholders.FUNCTION.value}" + "You have access to the following tool functions:" + f"{MessagePlaceholders.FUNCTION.value}" "If a you choose to call a function, you should ONLY reply in the following format:" - "`\n{parameters(JSON dict)}\n`" + "`{parameters(JSON dict)}`" "Here is an example," - '`\n{"location": "Pittsburgh"}\n`' + '`{"location": "Pittsburgh"}`' "Reminder:" "- Function calls MUST follow the specified format" "- Required parameters MUST be specified" + "- You should not repeat or miss the call" ) self.system_message += tool_call_instruct - elif self.tool_call_format == "json": + elif self._tool_call_format == "json": tool_call_instruct = ( "Tool Instructions:" - f"You have access to the following tool functions: {MessagePlaceholders.FUNCTION.value}" + "You have access to the following tool functions:" + f"{MessagePlaceholders.FUNCTION.value}" "If a you choose to call a function, you should ONLY reply in the following format:" - '`{"name": func_name, "parameters": parameters(JSON dict)}\n`' + '`{"name": func_name, "parameters": parameters(JSON dict)}`' "Here is an example," - '`{"name": "get_time", "parameters": {"location": "Pittsburgh"} 
}\n`' + '`{"name": "get_time", "parameters": {"location": "Pittsburgh"}}}}`' "Reminder:" "- Function calls MUST follow the specified format" "- Required parameters MUST be specified" + "- You should not repeat or miss the call" + "- You should response with at least one function calling" ) self.system_message += tool_call_instruct - elif self.tool_call_format == "python": + elif self._tool_call_format == "python": tool_call_instruct = ( "Tool Instructions:" - f"- You have access to the following tool functions: {MessagePlaceholders.FUNCTION.value}" + "- You have access to the following tool functions:" + f"{MessagePlaceholders.FUNCTION.value}" "- Required parameters MUST be specified" + "- You should not repeat or miss the call" ) self.system_message += tool_call_instruct else: diff --git a/python/mlc_llm/protocol/openai_api_protocol.py b/python/mlc_llm/protocol/openai_api_protocol.py index 8d1e9c7863..d11319acd2 100644 --- a/python/mlc_llm/protocol/openai_api_protocol.py +++ b/python/mlc_llm/protocol/openai_api_protocol.py @@ -86,33 +86,45 @@ class ModelResponse(BaseModel): class RequestResponseFormat(BaseModel): - type: Literal["text", "json_object", "structural_tag"] = "text" + type: Literal["text", "json_object", "json_schema", "structural_tag"] = "text" """This field is named json_schema instead of schema because BaseModel defines a method called schema. During construction of RequestResponseFormat, key "schema" still should be used: - `RequestResponseFormat(type="json_object", schema="{}")` + `RequestResponseFormat(type="json_schema", schema="{}")` """ json_schema: Optional[str] = Field(default=None, alias="schema") """These field are only used for type="structural_tag".""" - tags: Optional[List[Dict[str, str]]] = Field(default=None, alias="tags") - triggers: Optional[List[str]] = Field(default=None, alias="triggers") + tags: Optional[List[Dict[str, str]]] = None + triggers: Optional[List[str]] = None @model_validator(mode="after") def check_request_response_format(self) -> "RequestResponseFormat": """Check if the RequestResponseFormat is valid.""" - if self.type == "structural_tag": + if self.type in ["text", "json_object"]: + if self.json_schema is not None: + raise Warning("'json_schema' should be set in 'json_schema' type.") + if self.tags is not None or self.triggers is not None: + raise Warning( + "'tags' and 'triggers' attributes should be used when type='structural_tag'" + ) + elif self.type == "json_schema": + if self.json_schema is None: + raise ValueError("'json_schema' should be set in 'json_schema' type.") + if self.tags is not None or self.triggers is not None: + raise Warning( + "'tags' and 'triggers' attributes should be used when type='structural_tag'" + ) + elif self.type == "structural_tag": if self.tags is None or self.triggers is None: raise ValueError("structural_tag type must contain keys 'tags' and 'triggers'.") - for tag in self.tags: + for tag in self.tags: # pylint: disable=not-an-iterable if set(tag.keys()) != {"begin", "schema", "end"}: raise ValueError( - f"Each tag must contain exactly 'begin', 'schema' and 'end' keys. Got keys: {list(tag.keys())}." + "Each tag must contain exactly 'begin', 'schema' and 'end' keys." + f"Got keys: {list(tag.keys())}." 
) - elif self.tags is not None or self.triggers is not None: - raise Warning( - "'tags' and 'triggers' attributes should be used when type='structural_tag'" - ) - + if self.json_schema is not None: + raise Warning("'json_schema' should be set in 'json_schema' type.") return self diff --git a/python/mlc_llm/serve/config.py b/python/mlc_llm/serve/config.py index 058952ca5c..ff8a7e69d7 100644 --- a/python/mlc_llm/serve/config.py +++ b/python/mlc_llm/serve/config.py @@ -129,15 +129,15 @@ class EngineConfig: # pylint: disable=too-many-instance-attributes "chunked" means the basic prefill with chunked input enabled. "hybrid" means the hybrid prefill or split-fuse, so that decode step will be converted into prefill. - - tool_call_format : Literal["xml", "json", "python"] + + tool_call_format : Literal["json", "xml", "python"] The tool function call foramt. - "xml" means model will call tool function in xml style format - '\n{parameters(JSON dict)}\n', - e.g. '\n{"location": "Pittsburgh"}\n'. - "json" means model will call tool function in json style format + "json" means model will call tool function in json style format '{"name": func_name, "parameters": parameters(JSON dict)}', e.g. '{"name": "get_time", "parameters": {"location": "Pittsburgh"}}'. + "xml" means model will call tool function in xml style format + '{parameters(JSON dict)}', + e.g. '{"location": "Pittsburgh"}'. "python" means model will call tool function in python-style format, e.g. 'wolfram_alpha.call(query="solve x^3 - 4x^2 + 6x - 24 = 0")'. @@ -168,7 +168,7 @@ class EngineConfig: # pylint: disable=too-many-instance-attributes prefix_cache_mode: Literal["disable", "radix"] = "radix" prefix_cache_max_num_recycling_seqs: Optional[int] = None prefill_mode: Literal["chunked", "hybrid"] = "hybrid" - tool_call_format: Literal["xml", "json", "python"] = "json" + tool_call_format: Literal["json", "xml", "python"] = "json" verbose: bool = True def asjson(self) -> str: diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py index dc504c20ac..61118f8283 100644 --- a/python/mlc_llm/serve/engine_base.py +++ b/python/mlc_llm/serve/engine_base.py @@ -9,7 +9,6 @@ import queue import sys import threading -import re from dataclasses import dataclass from pathlib import Path from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union @@ -131,8 +130,9 @@ def _convert_model_info(model: ModelInfo) -> Tuple[str, str]: if conversation is None: conversation = mlc_chat_config.conv_template - conversation.tool_call_format = engine_config.tool_call_format - + conversation._tool_call_format = ( # pylint: disable=protected-access + engine_config.tool_call_format + ) if model.model_lib is not None: # do model lib search if the model lib is provided # error out if file not found @@ -1149,7 +1149,7 @@ def create_completion_suffix_response( return response -def set_structural_tag_from_tools( +def set_structural_tag_from_tools( # pylint: disable=too-many-branches,too-many-boolean-expressions tools: Optional[List[openai_api_protocol.ChatTool]], response_format: Optional[openai_api_protocol.RequestResponseFormat], tool_choice: Optional[Union[Literal["none", "auto"], Dict]], @@ -1170,9 +1170,9 @@ def set_structural_tag_from_tools( response_format.tags = [] response_format.triggers = [] - if tool_call_format == "xml": - begin_format = "\n" - end = "\n" + if tool_call_format == "json": + begin_format = '{{"name": "{func_name}", "parameters":' + end = "}" for tool in tools: if tool.function.strict and ( tool_choice is 
None @@ -1194,10 +1194,11 @@ def set_structural_tag_from_tools( "end": end, } ) - response_format.triggers.append(" List[Union[Dict, None]]: """Convert a (possibly list) of function call string to a list of json objects. Return None for invalid function call string.""" function_calls_json = [] - if tool_call_format == "xml": - # tool calling in format `\n{PARA}\n` - pattern = r"\n(.*?)\n" - matches = re.findall(pattern, stringified_calls, re.DOTALL) - for func_name, args_str in matches: - args: Dict = json.loads(args_str) - function_calls_json.append({"name": func_name, "arguments": args}) - elif tool_call_format == "json": + + if tool_call_format == "json": # tool calling in format `{"name": func_name, "parameters": parameters(JSON dict)}` - starts = [-1] + start = 0 while True: - index = stringified_calls.find('{"name":', starts[-1] + 1) + index = stringified_calls.find('{"name":', start) if index == -1: break - else: - starts.append(index) - starts.append(len(stringified_calls)) - for i in range(1, len(starts) - 1): - cnt = 1 - quote = False - for j in range(starts[i] + 1, starts[i + 1]): - if stringified_calls[j] == '"': - quote = not quote - elif not quote: - if stringified_calls[j] == "{": - cnt += 1 - elif stringified_calls[j] == "}": - cnt -= 1 - if cnt == 0: - func_call: Dict = json.loads(stringified_calls[starts[i] : j + 1]) - if "name" not in func_call or "parameters" not in func_call: - raise ValueError("Invalid function call output.") - if not isinstance(func_call["name"], str) or not isinstance(func_call["parameters"], dict): - raise ValueError("Invalid function call output type.") - function_calls_json.append( - {"name": func_call["name"], "arguments": func_call["parameters"]} - ) - break + try: + decoder = json.JSONDecoder() + result, end_index = decoder.raw_decode(stringified_calls, index) + except: # pylint: disable=bare-except + start = index + 1 + continue + start = end_index + if not isinstance(result, dict) or "name" not in result or "parameters" not in result: + continue + function_calls_json.append({"name": result["name"], "arguments": result["parameters"]}) + + elif tool_call_format == "xml": + # tool calling in format `{PARA}` + start = 0 + while True: + begin_start = stringified_calls.find("", begin_start) + if begin_end == -1: + break + end_start = stringified_calls.find("", begin_end) + if end_start == -1: + break + start = end_start + len("") + + func_name = stringified_calls[begin_start + len(" Date: Wed, 16 Apr 2025 00:38:06 +0800 Subject: [PATCH 14/17] [fix] remove CI check for other branch --- .github/workflows/documentation.yaml | 2 -- .github/workflows/windows-build.yaml | 2 -- 2 files changed, 4 deletions(-) diff --git a/.github/workflows/documentation.yaml b/.github/workflows/documentation.yaml index 9b0fc4eaee..6ec3492e2f 100644 --- a/.github/workflows/documentation.yaml +++ b/.github/workflows/documentation.yaml @@ -4,8 +4,6 @@ on: push: branches: - main - - tool_call - - eval jobs: test_linux: diff --git a/.github/workflows/windows-build.yaml b/.github/workflows/windows-build.yaml index a9b10039e2..560d2f275c 100644 --- a/.github/workflows/windows-build.yaml +++ b/.github/workflows/windows-build.yaml @@ -7,8 +7,6 @@ on: push: branches: - main - - tool_call - - eval pull_request: branches: - main From 7e4da8f1e6c7a21053302c55ec449b60b72ecb00 Mon Sep 17 00:00:00 2001 From: irfnfnkemed Date: Wed, 16 Apr 2025 01:14:05 +0800 Subject: [PATCH 15/17] [fix] missing type check for RequestFormat type --- cpp/serve/config.cc | 2 +- cpp/serve/engine.cc | 2 +- 2 
From: irfnfnkemed
Date: Wed, 16 Apr 2025 00:38:06 +0800
Subject: [PATCH 14/17] [fix] remove CI check for other branch

---
 .github/workflows/documentation.yaml | 2 --
 .github/workflows/windows-build.yaml | 2 --
 2 files changed, 4 deletions(-)

diff --git a/.github/workflows/documentation.yaml b/.github/workflows/documentation.yaml
index 9b0fc4eaee..6ec3492e2f 100644
--- a/.github/workflows/documentation.yaml
+++ b/.github/workflows/documentation.yaml
@@ -4,8 +4,6 @@ on:
   push:
     branches:
       - main
-      - tool_call
-      - eval

 jobs:
   test_linux:
diff --git a/.github/workflows/windows-build.yaml b/.github/workflows/windows-build.yaml
index a9b10039e2..560d2f275c 100644
--- a/.github/workflows/windows-build.yaml
+++ b/.github/workflows/windows-build.yaml
@@ -7,8 +7,6 @@ on:
   push:
     branches:
       - main
-      - tool_call
-      - eval
   pull_request:
     branches:
       - main

From 7e4da8f1e6c7a21053302c55ec449b60b72ecb00 Mon Sep 17 00:00:00 2001
From: irfnfnkemed
Date: Wed, 16 Apr 2025 01:14:05 +0800
Subject: [PATCH 15/17] [fix] missing type check for RequestFormat type

---
 cpp/serve/config.cc | 2 +-
 cpp/serve/engine.cc | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/cpp/serve/config.cc b/cpp/serve/config.cc
index aa14a6611d..3be0b815bd 100644
--- a/cpp/serve/config.cc
+++ b/cpp/serve/config.cc
@@ -43,7 +43,7 @@ Result<ResponseFormat> ResponseFormat::FromJSON(const picojson::object& config)
   res.type = json::LookupOrDefault<std::string>(config, "type", "text");
   if (res.type != "text" && res.type != "function" && res.type != "json_object" &&
-      res.type != "structural_tag") {
+      res.type != "json_schema" && res.type != "structural_tag") {
     return TResult::Error("Uknonwn response_format type " + res.type);
   }
diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc
index 7895e6c860..c2b967ecf3 100644
--- a/cpp/serve/engine.cc
+++ b/cpp/serve/engine.cc
@@ -993,7 +993,7 @@ class EngineImpl : public Engine {
       return grammar_compiler_.CompileBuiltinJSONGrammar();
     } else if (response_format.type == "json_schema") {
       return grammar_compiler_.CompileJSONSchema(response_format.schema.value());
-    } else {
+    } else if (response_format.type == "structural_tag") {
       std::vector<xgrammar::StructuralTagItem> tags;
      std::vector<std::string> triggers;
       for (auto tag : response_format.tags.value()) {

From 3ddacf06943f2db4bac16c52798b65bfe67c1008 Mon Sep 17 00:00:00 2001
From: Irfnfnkemed <119502078+Irfnfnkemed@users.noreply.github.com>
Date: Wed, 16 Apr 2025 02:13:32 +0800
Subject: [PATCH 16/17] Update config.h

---
 cpp/serve/config.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/cpp/serve/config.h b/cpp/serve/config.h
index 406ee28307..c73b7d0b50 100644
--- a/cpp/serve/config.h
+++ b/cpp/serve/config.h
@@ -249,7 +249,7 @@ class EngineConfigNode : public Object {
    * significantly smaller than this number. Under mode "server", the actual
    * memory usage may be slightly larger than this number.
    */
-  float gpu_memory_utilization = 0.55;
+  float gpu_memory_utilization = 0.85;
   /*! \brief The number of consecutive tokens handled in each page in paged KV cache. */
   int kv_cache_page_size = 16;
   /*!
@@ -450,4 +450,4 @@ inline PrefillMode PrefillModeFromString(const std::string& prefill_mode) {
 }  // namespace llm
 }  // namespace mlc

-#endif  // MLC_LLM_SERVE_CONFIG_H_
\ No newline at end of file
+#endif  // MLC_LLM_SERVE_CONFIG_H_
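Note on the new structural_tag branch in engine.cc above: on the request side, each tool is described by a tag (a begin marker, a JSON schema constraining the arguments, and an end marker) plus a trigger prefix that tells the grammar matcher when to switch into constrained decoding. The sketch below shows one plausible shape of such a response_format payload; the get_time schema is invented for illustration, and the field names "begin" and "schema" are inferred from the diffs rather than quoted from the API.

    import json

    # Hypothetical tool argument schema, for illustration only.
    get_time_params = {
        "type": "object",
        "properties": {"location": {"type": "string"}},
        "required": ["location"],
    }

    # Assumed shape of a "structural_tag" response_format: one tag per tool,
    # with a shared trigger prefix that starts constrained decoding.
    response_format = {
        "type": "structural_tag",
        "tags": [
            {
                "begin": "<function=get_time>",
                "schema": json.dumps(get_time_params),
                "end": "</function>",
            }
        ],
        "triggers": ["<function="],
    }

    print(json.dumps(response_format, indent=2))

Because generation is only constrained between a matched trigger and the corresponding end tag, free-form text outside the tags is left unconstrained.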
From 9dace3a09ca444c56eaa6468e0043d04b3eb8385 Mon Sep 17 00:00:00 2001
From: irfnfnkemed
Date: Fri, 25 Apr 2025 00:26:21 +0800
Subject: [PATCH 17/17] [fix] remove unexpected prompt when don't use func-call

---
 .../mlc_llm/protocol/conversation_protocol.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/python/mlc_llm/protocol/conversation_protocol.py b/python/mlc_llm/protocol/conversation_protocol.py
index ad19460ed5..14adcb8dd1 100644
--- a/python/mlc_llm/protocol/conversation_protocol.py
+++ b/python/mlc_llm/protocol/conversation_protocol.py
@@ -126,7 +126,8 @@ def as_prompt(self, config=None) -> List[Any]:
         from ..serve import data  # pylint: disable=import-outside-toplevel

         # - Get the system message.
-        self.set_tool_call_format_in_system_message()
+        if self.use_function_calling:
+            self.set_tool_call_format_in_system_message()
         system_msg = self.system_template.replace(
             MessagePlaceholders.SYSTEM.value, self.system_message
         )
@@ -200,35 +201,35 @@ def as_prompt(self, config=None) -> List[Any]:

     def set_tool_call_format_in_system_message(self):
         """Add tool function information and call format to the system message."""
-        if self._tool_call_format == "xml":
+        if self._tool_call_format == "json":
             tool_call_instruct = (
                 "Tool Instructions:"
                 "You have access to the following tool functions:"
                 f"{MessagePlaceholders.FUNCTION.value}"
                 "If a you choose to call a function, you should ONLY reply in the following format:"
-                "`<function=func_name>{parameters(JSON dict)}</function>`"
+                '`{"name": func_name, "parameters": parameters(JSON dict)}`'
                 "Here is an example,"
-                '`<function=get_time>{"location": "Pittsburgh"}</function>`'
+                '`{"name": "get_time", "parameters": {"location": "Pittsburgh"}}}}`'
                 "Reminder:"
                 "- Function calls MUST follow the specified format"
                 "- Required parameters MUST be specified"
                 "- You should not repeat or miss the call"
+                "- You should response with at least one function calling"
             )
             self.system_message += tool_call_instruct
-        elif self._tool_call_format == "json":
+        elif self._tool_call_format == "xml":
             tool_call_instruct = (
                 "Tool Instructions:"
                 "You have access to the following tool functions:"
                 f"{MessagePlaceholders.FUNCTION.value}"
                 "If a you choose to call a function, you should ONLY reply in the following format:"
-                '`{"name": func_name, "parameters": parameters(JSON dict)}`'
+                "`<function=func_name>{parameters(JSON dict)}</function>`"
                 "Here is an example,"
-                '`{"name": "get_time", "parameters": {"location": "Pittsburgh"}}}}`'
+                '`<function=get_time>{"location": "Pittsburgh"}</function>`'
                 "Reminder:"
                 "- Function calls MUST follow the specified format"
                 "- Required parameters MUST be specified"
                 "- You should not repeat or miss the call"
-                "- You should response with at least one function calling"
             )
             self.system_message += tool_call_instruct
         elif self._tool_call_format == "python":