From 712f1d5795bfc3a115bb93250d7f1c57650307bb Mon Sep 17 00:00:00 2001
From: irfnfnkemed
Date: Fri, 14 Mar 2025 12:22:34 +0800
Subject: [PATCH 1/6] [Grammar] Upgrade xgrammar to latest version

- update xgrammar calls to the latest API

---
 cpp/serve/engine.cc        | 14 ++++++++------
 cpp/serve/request_state.cc |  4 ++--
 2 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc
index 2f09219392..3f54878b6b 100644
--- a/cpp/serve/engine.cc
+++ b/cpp/serve/engine.cc
@@ -463,9 +463,11 @@ class EngineImpl : public Engine {
           ModelWorkspace{model->AllocEmbeddingTensor(), model->AllocHiddenStatesTensor()});
     }
     // - Initialize tokenizer and grammar
+    n->tokenizer_ = Tokenizer::FromPath(engine_config->model, GetTokenizerInfo(model_configs[0]));
     n->token_table_ = n->tokenizer_->PostProcessedTokenTable();
-    n->cached_grammar_compiler_ = xgrammar::CachedGrammarCompiler(n->token_table_);
+    // TODO: check 'vocab_size' of TokenizerInfo
+    n->grammar_compiler_ = xgrammar::GrammarCompiler(xgrammar::TokenizerInfo(n->token_table_));
     // - Create the logit processor and sampler, and
     // the DraftTokenWorkspaceManager for speculative decoding.
     int max_num_tokens = engine_config->max_num_sequence;
@@ -975,13 +977,13 @@ class EngineImpl : public Engine {
    * is not JSON, return std::nullopt.
    */
  std::optional<xgrammar::CompiledGrammar> GetGrammarFromResponseFormat(
      const ResponseFormat& response_format) {
+    // TODO: add other grammar type
    if (response_format.type != "json_object") {
      return std::nullopt;
    } else if (!response_format.schema) {
-      return cached_grammar_compiler_.GetCompiledGrammarForJSON();
+      return grammar_compiler_.CompileBuiltinJSONGrammar();
    } else {
-      return cached_grammar_compiler_.GetCompiledGrammarForJSONSchema(
-          response_format.schema.value());
+      return grammar_compiler_.CompileJSONSchema(response_format.schema.value());
    }
  }
@@ -992,8 +994,8 @@ class EngineImpl : public Engine {
   // internal tokenizer
   Tokenizer tokenizer_;
   std::vector<std::string> token_table_;
-  // Cached grammar compiler for grammar matching.
-  xgrammar::CachedGrammarCompiler cached_grammar_compiler_;
+  // Grammar compiler for grammar matching.
+  xgrammar::GrammarCompiler grammar_compiler_;
   // Models
   Array<Model> models_;
   // Device that the models run on.

diff --git a/cpp/serve/request_state.cc b/cpp/serve/request_state.cc
index 17e02ee85b..4771f14d3b 100644
--- a/cpp/serve/request_state.cc
+++ b/cpp/serve/request_state.cc
@@ -24,7 +24,7 @@ RequestModelState::RequestModelState(
   if (compiled_grammar.has_value()) {
     // TODO(yixin): set rollback limit to a configurable value.
     n->grammar_matcher =
-        xgrammar::GrammarMatcher(compiled_grammar.value(), std::nullopt, false, std::nullopt, 10);
+        xgrammar::GrammarMatcher(compiled_grammar.value(), std::nullopt, false, 10);
   }
 
   n->request = std::move(request);
@@ -44,7 +44,7 @@ bool RequestModelStateNode::RequireNextTokenBitmask() { return grammar_matcher.h
 
 void RequestModelStateNode::GetNextTokenBitmask(DLTensor* bitmask) {
   ICHECK(grammar_matcher.has_value());
-  grammar_matcher->GetNextTokenBitmask(bitmask);
+  grammar_matcher->FillNextTokenBitmask(bitmask);
 }
 
 void RequestModelStateNode::CommitToken(SampleResult sampled_token) {

From 6e37a0b6e2bbc346b3eae68adbc9a5bd48db053d Mon Sep 17 00:00:00 2001
From: irfnfnkemed
Date: Thu, 20 Mar 2025 10:37:07 +0800
Subject: [PATCH 2/6] Update xgrammar submodule reference

---
 3rdparty/xgrammar | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/3rdparty/xgrammar b/3rdparty/xgrammar
index d4f57c440f..f4badadfe7 160000
--- a/3rdparty/xgrammar
+++ b/3rdparty/xgrammar
@@ -1 +1 @@
-Subproject commit d4f57c440f3da8e7330a1e5d50bba9c31f9433ea
+Subproject commit f4badadfe7363e4e09fafcde3c253a46dd5d6e97

From f70a37a4870c49d16ec328a85dfad01c83396d50 Mon Sep 17 00:00:00 2001
From: irfnfnkemed
Date: Thu, 20 Mar 2025 11:05:46 +0800
Subject: [PATCH 3/6] [fix] modify cmake since xgrammar uses nanobind instead of pybind

---
 3rdparty/tvm   | 2 +-
 CMakeLists.txt | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/3rdparty/tvm b/3rdparty/tvm
index 9c894f78fd..7752c9221c 160000
--- a/3rdparty/tvm
+++ b/3rdparty/tvm
@@ -1 +1 @@
-Subproject commit 9c894f78fdef156263ced19eed67e79203ca4a11
+Subproject commit 7752c9221c768617af01711f8ad155e0a1cd409e

diff --git a/CMakeLists.txt b/CMakeLists.txt
index a010a05192..99926be832 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -71,7 +71,7 @@ add_subdirectory(${TOKENZIER_CPP_PATH} tokenizers EXCLUDE_FROM_ALL)
 set(XGRAMMAR_PATH 3rdparty/xgrammar)
 tvm_file_glob(GLOB_RECURSE MLC_LLM_SRCS cpp/*.cc)
 tvm_file_glob(GLOB_RECURSE XGRAMMAR_SRCS ${XGRAMMAR_PATH}/cpp/*.cc)
-list(FILTER XGRAMMAR_SRCS EXCLUDE REGEX "${XGRAMMAR_PATH}/cpp/pybind/.*\\.cc")
+list(FILTER XGRAMMAR_SRCS EXCLUDE REGEX "${XGRAMMAR_PATH}/cpp/nanobind/.*\\.cc")
 list(APPEND MLC_LLM_SRCS ${XGRAMMAR_SRCS})
 add_library(mlc_llm_objs OBJECT ${MLC_LLM_SRCS})

From 6e7426e8275270c30f138750e88dc0cdcdd3151f Mon Sep 17 00:00:00 2001
From: irfnfnkemed
Date: Thu, 20 Mar 2025 11:43:25 +0800
Subject: [PATCH 4/6] [feature] Support tool function calling with Structural
 Tag via xgrammar

- ensure tool functions are called in the expected format, enforced by xgrammar
- modify RequestResponseFormat: add structural tags according to the tools when
  building the response format
- tool function calling is now constrained to the format
  <function=FUNC_NAME>parameters</function>
- the tool call list is parsed according to this calling format when processing
  the response
- also expose the Structural Tag API of xgrammar through RequestResponseFormat

---
 cpp/serve/config.cc                           | 55 ++++++++++++++++-
 cpp/serve/config.h                            |  3 +
 cpp/serve/engine.cc                           | 23 +++++--
 cpp/serve/logit_processor.cc                  |  2 +
 .../mlc_llm/protocol/openai_api_protocol.py   | 25 +++++++-
 python/mlc_llm/serve/engine.py                | 14 +++++
 python/mlc_llm/serve/engine_base.py           | 60 +++++++++++++------
 7 files changed, 156 insertions(+), 26 deletions(-)

diff --git a/cpp/serve/config.cc b/cpp/serve/config.cc
index f7e71e72c9..22c431ba3e 100644
--- a/cpp/serve/config.cc
+++ b/cpp/serve/config.cc
@@ -10,11 +10,14 @@
 #include
 #include
+#include
+#include
 
 #include "../json_ffi/openai_api_protocol.h"
 #include "../support/json_parser.h"
"../support/json_parser.h" #include "../support/utils.h" #include "data.h" +#include "tvm/runtime/container/array.h" namespace mlc { namespace llm { @@ -42,13 +45,43 @@ Result ResponseFormat::FromJSON(const picojson::object& config) ResponseFormat res; res.type = json::LookupOrDefault(config, "type", "text"); + if (res.type != "text" && res.type != "function" && res.type != "json_object" && + res.type != "structural_tag") { + return TResult::Error("Uknonwn response_format type " + res.type); + } + std::optional schema = json::LookupOptional(config, "schema"); if (schema.has_value()) { res.schema = schema.value(); } - if (res.type != "text" && res.type != "function" && res.type != "json_object") { - return TResult::Error("Uknonwn response_format type " + res.type); + if (auto tags_obj = json::LookupOptional(config, "tags")) { + auto tags = Array>(); + for (auto tag_obj : tags_obj.value()) { + Array tag = Array(); + std::optional begin = + json::LookupOptional(tag_obj.get(), "begin"); + std::optional schema = + json::LookupOptional(tag_obj.get(), "schema"); + std::optional end = + json::LookupOptional(tag_obj.get(), "end"); + if (!(begin.has_value() && schema.has_value() && end.has_value())) { + return TResult::Error("Miss tag attribute."); + } + tag.push_back(begin.value()); + tag.push_back(schema.value()); + tag.push_back(end.value()); + tags.push_back(std::move(tag)); + } + res.tags = tags; + } + + if (auto triggers_obj = json::LookupOptional(config, "triggers")) { + auto triggers = Array(); + for (auto trigger : triggers_obj.value()) { + triggers.push_back(trigger.get()); + } + res.triggers = triggers; } return TResult::Ok(res); @@ -60,6 +93,24 @@ picojson::object ResponseFormat::AsJSON() const { if (schema.defined()) { config["schema"] = picojson::value(schema.value().operator std::string()); } + if (tags.defined()) { + picojson::array tags_obj = picojson::array(); + for (auto tag : tags.value()) { + picojson::array tag_obj = picojson::array(); + tag_obj.emplace_back(tag[0]); + tag_obj.emplace_back(tag[1]); + tag_obj.emplace_back(tag[2]); + tags_obj.emplace_back(tag_obj); + } + config["tags"] = picojson::value(tags_obj); + } + if (triggers.defined()) { + picojson::array trigger_obj = picojson::array(); + for (std::string trigger : triggers.value()) { + trigger_obj.emplace_back(trigger); + } + config["triggers"] = picojson::value(trigger_obj); + } return config; } diff --git a/cpp/serve/config.h b/cpp/serve/config.h index 9da3ba2517..b48d981cd7 100644 --- a/cpp/serve/config.h +++ b/cpp/serve/config.h @@ -14,6 +14,7 @@ #include "../metadata/model.h" #include "../support/result.h" +#include "tvm/runtime/container/optional.h" namespace mlc { namespace llm { @@ -28,6 +29,8 @@ using namespace tvm::runtime; struct ResponseFormat { String type = "text"; Optional schema = NullOpt; + Optional>> tags = NullOpt; + Optional> triggers = NullOpt; /*! * \brief Create debug config from JSON. 
   * \param config_json The json string for generation config

diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc
index 3f54878b6b..08a2ff2963 100644
--- a/cpp/serve/engine.cc
+++ b/cpp/serve/engine.cc
@@ -19,8 +19,10 @@
 #include
 #include
 #include
+#include
 #include
 #include
+#include
 
 #include "../support/json_parser.h"
 #include "../support/result.h"
@@ -35,6 +37,7 @@
 #include "request.h"
 #include "request_state.h"
 #include "sampler/sampler.h"
+#include "xgrammar/grammar.h"
 
 namespace mlc {
 namespace llm {
@@ -978,12 +981,24 @@ class EngineImpl : public Engine {
   std::optional<xgrammar::CompiledGrammar> GetGrammarFromResponseFormat(
       const ResponseFormat& response_format) {
     // TODO: add other grammar type
-    if (response_format.type != "json_object") {
+    if (response_format.type == "text") {
       return std::nullopt;
-    } else if (!response_format.schema) {
-      return grammar_compiler_.CompileBuiltinJSONGrammar();
+    } else if (response_format.type == "json_object") {
+      if (!response_format.schema) {
+        return grammar_compiler_.CompileBuiltinJSONGrammar();
+      } else {
+        return grammar_compiler_.CompileJSONSchema(response_format.schema.value());
+      }
     } else {
-      return grammar_compiler_.CompileJSONSchema(response_format.schema.value());
+      std::vector<xgrammar::StructuralTagItem> tags;
+      std::vector<std::string> triggers;
+      for (auto tag : response_format.tags.value()) {
+        tags.emplace_back(xgrammar::StructuralTagItem{tag[0], tag[1], tag[2]});
+      }
+      for (auto trigger : response_format.triggers.value()) {
+        triggers.emplace_back(trigger);
+      }
+      return grammar_compiler_.CompileStructuralTag(std::move(tags), std::move(triggers));
     }
   }

diff --git a/cpp/serve/logit_processor.cc b/cpp/serve/logit_processor.cc
index 3c9bf88717..aa3d41751a 100644
--- a/cpp/serve/logit_processor.cc
+++ b/cpp/serve/logit_processor.cc
@@ -12,6 +12,8 @@
 #include
 #include
 
+#include
+
 namespace mlc {
 namespace llm {
 namespace serve {

diff --git a/python/mlc_llm/protocol/openai_api_protocol.py b/python/mlc_llm/protocol/openai_api_protocol.py
index cb2e1f2852..88147f12d3 100644
--- a/python/mlc_llm/protocol/openai_api_protocol.py
+++ b/python/mlc_llm/protocol/openai_api_protocol.py
@@ -86,12 +86,33 @@ class ModelResponse(BaseModel):
 
 class RequestResponseFormat(BaseModel):
-    type: Literal["text", "json_object"] = "text"
-    json_schema: Optional[str] = Field(default=None, alias="schema")
+    type: Literal["text", "json_object", "structural_tag"] = "text"
     """This field is named json_schema instead of schema because BaseModel defines a method
     called schema. During construction of RequestResponseFormat, key "schema" still should be
     used: `RequestResponseFormat(type="json_object", schema="{}")`
     """
+    json_schema: Optional[str] = Field(default=None, alias="schema")
+
+    """These fields are only used for type="structural_tag"."""
+    tags: Optional[List[Dict[str, str]]] = Field(default=None, alias="tags")
+    triggers: Optional[List[str]] = Field(default=None, alias="triggers")
+
+    @model_validator(mode="after")
+    def check_request_response_format(self) -> "RequestResponseFormat":
+        """Check if the RequestResponseFormat is valid."""
+        if self.type == "structural_tag":
+            if self.tags is None or self.triggers is None:
+                raise ValueError("structural_tag type must contain keys 'tags' and 'triggers'.")
+            for tag in self.tags:
+                if set(tag.keys()) != {"begin", "schema", "end"}:
+                    raise ValueError(
+                        f"Each tag must contain exactly 'begin', 'schema' and 'end' keys. Got keys: {list(tag.keys())}."
+                    )
+        elif self.tags is not None or self.triggers is not None:
+            raise Warning(
+                "'tags' and 'triggers' attributes should be used when type='structural_tag'"
+            )
+        return self
 
 
 class CompletionRequest(BaseModel):

diff --git a/python/mlc_llm/serve/engine.py b/python/mlc_llm/serve/engine.py
index 3d9d181b1f..0250db3211 100644
--- a/python/mlc_llm/serve/engine.py
+++ b/python/mlc_llm/serve/engine.py
@@ -976,6 +976,12 @@ async def _chat_completion(  # pylint: disable=too-many-arguments,too-many-local
         if request_id is None:
             request_id = f"chatcmpl-{engine_utils.random_uuid()}"
 
+        tools = (
+            [openai_api_protocol.ChatTool.model_validate(tool) for tool in tools]
+            if tools is not None
+            else None
+        )
+
         chatcmpl_generator = self._handle_chat_completion(
             openai_api_protocol.ChatCompletionRequest(
                 messages=[
@@ -1207,6 +1213,10 @@ async def _handle_chat_completion(
         e : BadRequestError
             BadRequestError is raised when the request is invalid.
         """
+        request.response_format = engine_base.set_structural_tag_from_tools(
+            request.tools, request.response_format
+        )
+
         (
             prompts,
             generation_cfg,
@@ -1764,6 +1774,10 @@ def _handle_chat_completion(
         e : BadRequestError
             BadRequestError is raised when the request is invalid.
         """
+        request.response_format = engine_base.set_structural_tag_from_tools(
+            request.tools, request.response_format
+        )
+
         (
             prompts,
             generation_cfg,

diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py
index 1d5303e412..2b7178eb4c 100644
--- a/python/mlc_llm/serve/engine_base.py
+++ b/python/mlc_llm/serve/engine_base.py
@@ -7,6 +7,7 @@
 import json
 import numbers
 import queue
+import re
 import sys
 import threading
 from dataclasses import dataclass
@@ -1146,29 +1147,52 @@ def create_completion_suffix_response(
     return response
 
 
-def convert_function_str_to_json(stringified_calls: str) -> List[Union[Dict, None]]:
+def convert_function_str_to_json(stringified_calls: str):
     """Convert a (possibly list) of function call string to a list of json objects.
     Return None for invalid function call string."""
+    function_calls_json = []
+    for call in re.finditer(r"<function=(.*?)>(.*?)</function>", stringified_calls, re.DOTALL):
+        function_name = call.group(1)
+        params_str = call.group(2).strip()
+        params = ast.literal_eval(params_str)
+        function_calls_json.append({"name": function_name, "arguments": params})
 
-    def parse_function_call(call_str: str):
-        node = ast.parse(call_str, mode="eval")
-        call_node = node.body
-        if isinstance(call_node, ast.Call) and isinstance(call_node.func, ast.Name):
-            name = call_node.func.id
-            arguments = {}
-            for keyword in call_node.keywords:
-                arguments[keyword.arg] = ast.literal_eval(keyword.value)
-            return {"name": name, "arguments": arguments}
-        return None
+    return function_calls_json
 
-    if (
-        stringified_calls[0] == "[" and stringified_calls[-1] == "]"
-    ):  # hacky way to check if string list
-        calls = ast.literal_eval(stringified_calls)
+
+def set_structural_tag_from_tools(
+    tools: Optional[List[openai_api_protocol.ChatTool]],
+    response_format: Optional[openai_api_protocol.RequestResponseFormat],
+):
+    """Add the corresponding structural tag to the response format according to the tools to ensure valid function calling.
+    Return the updated response format.
+    """
+    if tools is None:
+        return response_format
     else:
-        calls = [stringified_calls]
-    function_calls_json = [parse_function_call(call_str) for call_str in calls]
-    return function_calls_json
+        if response_format is None or response_format.type == "text":
+            response_format = openai_api_protocol.RequestResponseFormat.model_validate(
+                {"type": "structural_tag", "tags": [], "triggers": []}
+            )
+        elif response_format.type == "json_object":
+            response_format.tags = []
+            response_format.triggers = []
+
+        response_format.triggers.append("<function=")
+        for tool in tools:
+            schema = tool.function.parameters
+            response_format.tags.append(
+                {
+                    "begin": f"<function={tool.function.name}>",
+                    "schema": json.dumps(schema),
+                    "end": "</function>",
+                }
+            )
+        return response_format
 
 
 def process_function_call_output(

From 4f980f3be8a5839e8501aa77feb98aaee36cfcaf Mon Sep 17 00:00:00 2001
From: irfnfnkemed
Date: Mon, 24 Mar 2025 14:11:58 +0800
Subject: [PATCH 5/6] [feat] Add Structural-Tag API to RequestResponseFormat

- Expose the Structural-Tag API, which can be used to standardize the function
  calling format
- Add test script for Structural-Tag (passed on Llama-2-7b-chat-hf-q0f16-MLC
  and Llama-3-8B-Instruct-q4f16_1-MLC)

---
 3rdparty/xgrammar                             |   2 +-
 cpp/serve/config.cc                           |   5 +-
 cpp/serve/config.h                            |   3 +-
 cpp/serve/engine.cc                           |   3 -
 cpp/serve/logit_processor.cc                  |   2 -
 .../mlc_llm/protocol/openai_api_protocol.py   |   1 +
 python/mlc_llm/serve/engine.py                |  14 -
 python/mlc_llm/serve/engine_base.py           |  60 +--
 .../server/test_server_structural_tag.py      | 429 ++++++++++++++++++
 9 files changed, 451 insertions(+), 68 deletions(-)
 create mode 100644 tests/python/serve/server/test_server_structural_tag.py

diff --git a/3rdparty/xgrammar b/3rdparty/xgrammar
index f4badadfe7..dbf200ecde 160000
--- a/3rdparty/xgrammar
+++ b/3rdparty/xgrammar
@@ -1 +1 @@
-Subproject commit f4badadfe7363e4e09fafcde3c253a46dd5d6e97
+Subproject commit dbf200ecde5dd5467c8320076ee60b1e248b23e0

diff --git a/cpp/serve/config.cc b/cpp/serve/config.cc
index 22c431ba3e..aa14a6611d 100644
--- a/cpp/serve/config.cc
+++ b/cpp/serve/config.cc
@@ -10,14 +10,11 @@
 #include
 #include
-#include
-#include
 
 #include "../json_ffi/openai_api_protocol.h"
 #include "../support/json_parser.h"
 #include "../support/utils.h"
 #include "data.h"
-#include "tvm/runtime/container/array.h"
 
 namespace mlc {
 namespace llm {
@@ -1124,4 +1121,4 @@ Result<bool> ModelsUseKVCache(const std::vector<picojson::object>& model_configs
 
 }  // namespace serve
 }  // namespace llm
-}  // namespace mlc
+}  // namespace mlc
\ No newline at end of file

diff --git a/cpp/serve/config.h b/cpp/serve/config.h
index b48d981cd7..f39e1911ba 100644
--- a/cpp/serve/config.h
+++ b/cpp/serve/config.h
@@ -14,7 +14,6 @@
 
 #include "../metadata/model.h"
 #include "../support/result.h"
-#include "tvm/runtime/container/optional.h"
 
 namespace mlc {
 namespace llm {
@@ -451,4 +450,4 @@ inline PrefillMode PrefillModeFromString(const std::string& prefill_mode) {
 
 }  // namespace llm
 }  // namespace mlc
 
-#endif  // MLC_LLM_SERVE_CONFIG_H_
+#endif  // MLC_LLM_SERVE_CONFIG_H_
\ No newline at end of file

diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc
index 08a2ff2963..0db12299a1 100644
--- a/cpp/serve/engine.cc
+++ b/cpp/serve/engine.cc
@@ -19,10 +19,8 @@
 #include
 #include
 #include
-#include
 #include
 #include
-#include
 
 #include "../support/json_parser.h"
 #include "../support/result.h"
@@ -37,7 +35,6 @@
 #include "request.h"
 #include "request_state.h"
 #include "sampler/sampler.h"
-#include "xgrammar/grammar.h"
 
 namespace mlc {
 namespace llm {

diff --git a/cpp/serve/logit_processor.cc b/cpp/serve/logit_processor.cc
index aa3d41751a..3c9bf88717 100644
--- a/cpp/serve/logit_processor.cc
+++ b/cpp/serve/logit_processor.cc
@@ -12,8 +12,6 @@
 #include
 #include
 
-#include
-
 namespace mlc {
 namespace llm {
 namespace serve {

diff --git a/python/mlc_llm/protocol/openai_api_protocol.py b/python/mlc_llm/protocol/openai_api_protocol.py
index 88147f12d3..5b617810df 100644
--- a/python/mlc_llm/protocol/openai_api_protocol.py
+++ b/python/mlc_llm/protocol/openai_api_protocol.py
@@ -112,6 +112,7 @@ def check_request_response_format(self) -> "RequestResponseFormat":
             raise Warning(
                 "'tags' and 'triggers' attributes should be used when type='structural_tag'"
             )
+
         return self
 

diff --git a/python/mlc_llm/serve/engine.py b/python/mlc_llm/serve/engine.py
index 0250db3211..3d9d181b1f 100644
--- a/python/mlc_llm/serve/engine.py
+++ b/python/mlc_llm/serve/engine.py
@@ -976,12 +976,6 @@ async def _chat_completion(  # pylint: disable=too-many-arguments,too-many-local
         if request_id is None:
             request_id = f"chatcmpl-{engine_utils.random_uuid()}"
 
-        tools = (
-            [openai_api_protocol.ChatTool.model_validate(tool) for tool in tools]
-            if tools is not None
-            else None
-        )
-
         chatcmpl_generator = self._handle_chat_completion(
             openai_api_protocol.ChatCompletionRequest(
                 messages=[
@@ -1213,10 +1207,6 @@ async def _handle_chat_completion(
         e : BadRequestError
             BadRequestError is raised when the request is invalid.
         """
-        request.response_format = engine_base.set_structural_tag_from_tools(
-            request.tools, request.response_format
-        )
-
         (
             prompts,
             generation_cfg,
@@ -1774,10 +1764,6 @@ def _handle_chat_completion(
         e : BadRequestError
             BadRequestError is raised when the request is invalid.
         """
-        request.response_format = engine_base.set_structural_tag_from_tools(
-            request.tools, request.response_format
-        )
-
         (
             prompts,
             generation_cfg,

diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py
index 2b7178eb4c..1d5303e412 100644
--- a/python/mlc_llm/serve/engine_base.py
+++ b/python/mlc_llm/serve/engine_base.py
@@ -7,7 +7,6 @@
 import json
 import numbers
 import queue
-import re
 import sys
 import threading
 from dataclasses import dataclass
@@ -1147,52 +1146,29 @@ def create_completion_suffix_response(
     return response
 
 
-def convert_function_str_to_json(stringified_calls: str):
+def convert_function_str_to_json(stringified_calls: str) -> List[Union[Dict, None]]:
     """Convert a (possibly list) of function call string to a list of json objects.
     Return None for invalid function call string."""
-    function_calls_json = []
-    for call in re.finditer(r"<function=(.*?)>(.*?)</function>", stringified_calls, re.DOTALL):
-        function_name = call.group(1)
-        params_str = call.group(2).strip()
-        params = ast.literal_eval(params_str)
-        function_calls_json.append({"name": function_name, "arguments": params})
-
-    return function_calls_json
+    def parse_function_call(call_str: str):
+        node = ast.parse(call_str, mode="eval")
+        call_node = node.body
+        if isinstance(call_node, ast.Call) and isinstance(call_node.func, ast.Name):
+            name = call_node.func.id
+            arguments = {}
+            for keyword in call_node.keywords:
+                arguments[keyword.arg] = ast.literal_eval(keyword.value)
+            return {"name": name, "arguments": arguments}
+        return None
 
-def set_structural_tag_from_tools(
-    tools: Optional[List[openai_api_protocol.ChatTool]],
-    response_format: Optional[openai_api_protocol.RequestResponseFormat],
-):
-    """Add the corresponding structural tag to the response format according to the tools to ensure valid function calling.
-    Return the updated response format.
-    """
-    if tools is None:
-        return response_format
+    if (
+        stringified_calls[0] == "[" and stringified_calls[-1] == "]"
+    ):  # hacky way to check if string list
+        calls = ast.literal_eval(stringified_calls)
     else:
-        if response_format is None or response_format.type == "text":
-            response_format = openai_api_protocol.RequestResponseFormat.model_validate(
-                {"type": "structural_tag", "tags": [], "triggers": []}
-            )
-        elif response_format.type == "json_object":
-            response_format.tags = []
-            response_format.triggers = []
-
-        response_format.triggers.append("<function=")
-        for tool in tools:
-            schema = tool.function.parameters
-            response_format.tags.append(
-                {
-                    "begin": f"<function={tool.function.name}>",
-                    "schema": json.dumps(schema),
-                    "end": "</function>",
-                }
-            )
-        return response_format
+        calls = [stringified_calls]
+    function_calls_json = [parse_function_call(call_str) for call_str in calls]
+    return function_calls_json
 
 
 def process_function_call_output(

diff --git a/tests/python/serve/server/test_server_structural_tag.py b/tests/python/serve/server/test_server_structural_tag.py
new file mode 100644
index 0000000000..5a7c93e4f5
--- /dev/null
+++ b/tests/python/serve/server/test_server_structural_tag.py
@@ -0,0 +1,429 @@
+# pylint: disable=line-too-long
+"""
+Test script for structural tag in chat completion. To run this script, use the following command:
+- start a new shell session, run
+    mlc_llm serve --model "YOUR_MODEL" (e.g. ./dist/Llama-2-7b-chat-hf-q0f16-MLC)
+- start another shell session, run this file
+    MLC_SERVE_MODEL="YOUR_MODEL" python tests/python/serve/server/test_server_structural_tag.py
+"""
+
+# pylint: disable=missing-function-docstring,too-many-arguments,too-many-locals,too-many-branches
+import json
+import os
+import re
+from typing import Dict, List, Optional, Tuple
+
+import pytest
+import requests
+
+OPENAI_V1_CHAT_COMPLETION_URL = "http://127.0.0.1:8000/v1/chat/completions"
+
+
+def check_openai_nonstream_response(
+    response: Dict,
+    *,
+    model: str,
+    object_str: str,
+    num_choices: int,
+    finish_reason: List[str],
+    completion_tokens: Optional[int] = None,
+):
+    assert response["model"] == model
+    assert response["object"] == object_str
+
+    choices = response["choices"]
+    assert isinstance(choices, list)
+    assert len(choices) == num_choices
+    for idx, choice in enumerate(choices):
+        assert choice["index"] == idx
+        assert choice["finish_reason"] in finish_reason
+
+        find_format_start = set()
+        beg_tag_start = set()
+        message = choice["message"]["content"]
+        print("Outputs:\n-----------")
+        print(message, flush=True)
+        pattern1 = r"<CALL--->(.*?)\|(.*?)\|End<---(.*?)>"
+        pattern2 = r"<call--->(.*?)\|(.*?)\|End<---(.*?)>"
+        # check format
+        for match in re.finditer(pattern1, message):
+            find_format_start.add(match.start())
+            check_format(match.group(1), match.group(3), "CALL", match.group(2))
+        for match in re.finditer(pattern2, message):
+            find_format_start.add(match.start())
+            check_format(match.group(1), match.group(3), "call", match.group(2))
+        for match in re.finditer(r"<CALL--->", message):
+            beg_tag_start.add(match.start())
+        for match in re.finditer(r"<call--->", message):
+            beg_tag_start.add(match.start())
+        assert find_format_start == beg_tag_start
+
+    usage = response["usage"]
+    assert isinstance(usage, dict)
+    assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"]
+    assert usage["prompt_tokens"] > 0
+
+    if completion_tokens is not None:
+        assert usage["completion_tokens"] == completion_tokens
+
+
+def check_openai_stream_response(
+    responses: List[Dict],
+    *,
+    model: str,
+    object_str: str,
+    num_choices: int,
+    finish_reason: str,
+    echo_prompt: Optional[str] = None,
+    suffix: Optional[str] = None,
+    stop: Optional[List[str]] = None,
+    require_substr: Optional[List[str]] = None,
+):
+    assert len(responses) > 0
+
+    finished = [False for _ in range(num_choices)]
+    outputs = ["" for _ in range(num_choices)]
+    for response in responses:
+        assert response["model"] == model
+        assert response["object"] == object_str
+
+        choices = response["choices"]
+        assert isinstance(choices, list)
+        assert len(choices) == num_choices
+        for idx, choice in enumerate(choices):
+            assert choice["index"] == idx
+
+            delta = choice["delta"]
+            assert delta["role"] == "assistant"
+            assert isinstance(delta["content"], str)
+            outputs[idx] += delta["content"]
+
+            if finished[idx]:
+                assert choice["finish_reason"] == finish_reason
+            elif choice["finish_reason"] is not None:
+                assert choice["finish_reason"] == finish_reason
+                finished[idx] = True
+
+    for output in outputs:
+        if echo_prompt is not None:
+            assert output.startswith(echo_prompt)
+        if suffix is not None:
+            assert output.endswith(suffix)
+        if stop is not None:
+            for stop_str in stop:
+                assert stop_str not in output
+        if require_substr is not None:
+            for substr in require_substr:
+                assert substr in output
+        find_format_start = set()
+        beg_tag_start = set()
+        print("Outputs:\n-----------")
+        print(output, flush=True)
+        pattern1 = r"<CALL--->(.*?)\|(.*?)\|End<---(.*?)>"
+        pattern2 = r"<call--->(.*?)\|(.*?)\|End<---(.*?)>"
+        # check format
+        for match in re.finditer(pattern1, output):
+            find_format_start.add(match.start())
+            check_format(match.group(1), match.group(3), "CALL", match.group(2))
+        for match in re.finditer(pattern2, output):
+            find_format_start.add(match.start())
+            check_format(match.group(1), match.group(3), "call", match.group(2))
+        for match in re.finditer(r"<CALL--->", output):
+            beg_tag_start.add(match.start())
+        for match in re.finditer(r"<call--->", output):
+            beg_tag_start.add(match.start())
+        assert find_format_start == beg_tag_start
+
+
+def check_format(name_beg: str, name_end: str, beg_tag: str, schema: str):
+    schema = json.loads(schema)
+    assert "hash_code" in schema
+    hash_code = schema["hash_code"]
+    assert hash_code in CHECK_INFO
+    info = CHECK_INFO[hash_code]
+    assert name_beg == info["name"]
+    assert name_end == info["name"]
+    assert beg_tag == info["beg_tag"]
+    for key in info["required"]:
+        assert key in schema
+
+
+# NOTE: the end-tag format and the hash_code number are hidden in the SYSTEM_PROMPT.
+# By checking whether the end tag and hash code are still generated correctly without being
+# given in the prompt, the correctness of the structural tag can be verified.
+
+SYSTEM_PROMPT = {
+    "role": "system",
+    "content": """
+# Tool Instructions
+- Always execute python code in messages that you share.
+- When looking for real time information use relevant functions if available else fallback to brave_search
+You have access to the following functions:
+Use the function 'get_current_weather' to: Get the current weather in a given location
+{
+    "name": "get_current_weather",
+    "description": "Get the current weather in a given location",
+    "parameters": {
+        "type": "object",
+        "properties": {
+            "city": {
+                "type": "string",
+                "description": "The city to find the weather for, e.g. 'San Francisco'",
+            },
+            "state": {
+                "type": "string",
+                "description": "the two-letter abbreviation for the state that the city is"
+                " in, e.g. 'CA' which would mean 'California'",
'CA' which would mean 'California'", + }, + "unit": { + "type": "string", + "description": "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"], + }, + "hash_code": { + "type": "string", + }, + }, + "required": ["city", "state", "unit", "hash_code"], + }, +} +Use the function 'get_current_date' to: Get the current date and time for a given timezone +{ + "name": "get_current_date", + "description": "Get the current date and time for a given timezone", + "parameters": { + "type": "object", + "properties": { + "timezone": { + "type": "string", + "description": "The timezone to fetch the current date and time for, e.g. 'America/New_York'", + }, + "hash_code": { + "type": "string", + }, + }, + "required": ["timezone", "hash_code"], + }, +} +If a you choose to call a function ONLY reply in the following format: +<{start_tag}--->{function_name}|{parameters}|{end_tag}<---{function_name}> +where +start_tag => ` a JSON dict with the function argument name as key and function argument value as value. +Here is an example, +example_function_name|{"example_name": "example_value"}... +or +example_function_name|{"example_name": "example_value"}... +Reminder: +- Function calls MUST follow the specified format +- Required parameters MUST be specified +You are a helpful assistant.""", +} + + +STRUCTURAL_TAGS = { + "triggers": ["", ""], + "tags": [ + { + "begin": "get_current_weather|", + "schema": json.dumps( + { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city to find the weather for, e.g. 'San Francisco'", + }, + "state": { + "type": "string", + "description": "the two-letter abbreviation for the state that the city is" + " in, e.g. 'CA' which would mean 'California'", + }, + "unit": { + "type": "string", + "description": "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"], + }, + "hash_code": {"const": 1234}, + }, + "required": ["city", "state", "unit", "hash_code"], + } + ), + "end": "|End<---get_current_weather>", + }, + { + "begin": "get_current_date|", + "schema": json.dumps( + { + "type": "object", + "properties": { + "timezone": { + "type": "string", + "description": "The timezone to fetch the current date and time for, e.g. 'America/New_York'", + }, + "hash_code": {"const": 2345}, + }, + "required": ["timezone", "hash_code"], + } + ), + "end": "|End<---get_current_date>", + }, + { + "begin": "get_current_weather|", + "schema": json.dumps( + { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city to find the weather for, e.g. 'San Francisco'", + }, + "state": { + "type": "string", + "description": "the two-letter abbreviation for the state that the city is" + " in, e.g. 'CA' which would mean 'California'", + }, + "unit": { + "type": "string", + "description": "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"], + }, + "hash_code": {"const": 3456}, + }, + "required": ["city", "state", "unit", "hash_code"], + } + ), + "end": "|End<---get_current_weather>", + }, + { + "begin": "get_current_date|", + "schema": json.dumps( + { + "type": "object", + "properties": { + "timezone": { + "type": "string", + "description": "The timezone to fetch the current date and time for, e.g. 
'America/New_York'", + }, + "hash_code": {"const": 4567}, + }, + "required": ["timezone", "hash_code"], + } + ), + "end": "|End<---get_current_date>", + }, + ], +} + +CHECK_INFO = { + 1234: { + "name": "get_current_weather", + "beg_tag": "CALL", + "required": ["city", "state", "unit", "hash_code"], + }, + 2345: { + "name": "get_current_date", + "beg_tag": "CALL", + "required": ["timezone", "hash_code"], + }, + 3456: { + "name": "get_current_weather", + "beg_tag": "call", + "required": ["city", "state", "unit", "hash_code"], + }, + 4567: { + "name": "get_current_date", + "beg_tag": "call", + "required": ["timezone", "hash_code"], + }, +} + + +CHAT_COMPLETION_MESSAGES = [ + # messages #0 + [ + SYSTEM_PROMPT, + { + "role": "user", + "content": "You are in New York. Please get the current date and time.", + }, + ], + # messages #1 + [ + SYSTEM_PROMPT, + { + "role": "user", + "content": "You are in New York. Please get the current weather.", + }, + ], + # messages #2 + [ + SYSTEM_PROMPT, + { + "role": "user", + "content": "You are in New York. Please get the current date and time, and the weather.", + }, + ], +] + + +@pytest.mark.parametrize("stream", [False, True]) +@pytest.mark.parametrize("messages", CHAT_COMPLETION_MESSAGES) +def test_openai_v1_chat_completion_structural_tag( + served_model: str, + launch_server, # pylint: disable=unused-argument + stream: bool, + messages: List[Dict[str, str]], +): + # `served_model` and `launch_server` are pytest fixtures + # defined in conftest.py. + + payload = { + "model": served_model, + "messages": messages, + "stream": stream, + "response_format": { + "type": "structural_tag", + "tags": STRUCTURAL_TAGS["tags"], + "triggers": STRUCTURAL_TAGS["triggers"], + }, + "max_tokens": 1024, + } + + response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=60) + if not stream: + check_openai_nonstream_response( + response.json(), + model=served_model, + object_str="chat.completion", + num_choices=1, + finish_reason=["stop"], + ) + else: + responses = [] + for chunk in response.iter_lines(chunk_size=512): + if not chunk or chunk == b"data: [DONE]": + continue + responses.append(json.loads(chunk.decode("utf-8")[6:])) + check_openai_stream_response( + responses, + model=served_model, + object_str="chat.completion.chunk", + num_choices=1, + finish_reason="stop", + ) + + print(f"-----------\nCheck for stream={stream} is passed!\n") + + +if __name__ == "__main__": + MODEL = os.environ.get("MLC_SERVE_MODEL") + if MODEL is None: + raise ValueError( + 'Environment variable "MLC_SERVE_MODEL" not found. 
' + "Please set it to model compiled by MLC LLM " + "(e.g., `./dist/Llama-2-7b-chat-hf-q0f16-MLC`) " + ) + + for msg in CHAT_COMPLETION_MESSAGES: + test_openai_v1_chat_completion_structural_tag(MODEL, None, stream=False, messages=msg) + test_openai_v1_chat_completion_structural_tag(MODEL, None, stream=True, messages=msg) From 8b70dd7565c7ae0bdc8f05e090bd5244acf2bbfc Mon Sep 17 00:00:00 2001 From: irfnfnkemed Date: Mon, 24 Mar 2025 15:56:33 +0800 Subject: [PATCH 6/6] [fix] type annotation in test scripts --- python/mlc_llm/protocol/openai_api_protocol.py | 3 ++- .../serve/server/test_server_structural_tag.py | 15 +++++++++------ 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/python/mlc_llm/protocol/openai_api_protocol.py b/python/mlc_llm/protocol/openai_api_protocol.py index 5b617810df..eb6d6c7e7e 100644 --- a/python/mlc_llm/protocol/openai_api_protocol.py +++ b/python/mlc_llm/protocol/openai_api_protocol.py @@ -106,7 +106,8 @@ def check_request_response_format(self) -> "RequestResponseFormat": for tag in self.tags: if set(tag.keys()) != {"begin", "schema", "end"}: raise ValueError( - f"Each tag must contain exactly 'begin', 'schema' and 'end' keys. Got keys: {list(tag.keys())}." + "Each tag must contain exactly 'begin', 'schema' and 'end' keys." + f"Got keys: {list(tag.keys())}." ) elif self.tags is not None or self.triggers is not None: raise Warning( diff --git a/tests/python/serve/server/test_server_structural_tag.py b/tests/python/serve/server/test_server_structural_tag.py index 5a7c93e4f5..f34df96d31 100644 --- a/tests/python/serve/server/test_server_structural_tag.py +++ b/tests/python/serve/server/test_server_structural_tag.py @@ -11,7 +11,7 @@ import json import os import re -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional import pytest import requests @@ -136,16 +136,21 @@ def check_openai_stream_response( def check_format(name_beg: str, name_end: str, beg_tag: str, schema: str): - schema = json.loads(schema) + try: + paras: Dict[str, Any] = json.loads(schema) + except json.JSONDecodeError as e: + print(f"Invalid JSON format: {e}") + assert False + assert "hash_code" in paras assert "hash_code" in schema - hash_code = schema["hash_code"] + hash_code = paras["hash_code"] assert hash_code in CHECK_INFO info = CHECK_INFO[hash_code] assert name_beg == info["name"] assert name_end == info["name"] assert beg_tag == info["beg_tag"] for key in info["required"]: - assert key in schema + assert key in paras # NOTE: the end-tag format and the hash_code number is been hidden in the SYSTEM_PROMPT. @@ -219,7 +224,6 @@ def check_format(name_beg: str, name_end: str, beg_tag: str, schema: str): You are a helpful assistant.""", } - STRUCTURAL_TAGS = { "triggers": ["", ""], "tags": [ @@ -337,7 +341,6 @@ def check_format(name_beg: str, name_end: str, beg_tag: str, schema: str): }, } - CHAT_COMPLETION_MESSAGES = [ # messages #0 [