diff --git a/3rdparty/tvm b/3rdparty/tvm
index 9c894f78fd..7752c9221c 160000
--- a/3rdparty/tvm
+++ b/3rdparty/tvm
@@ -1 +1 @@
-Subproject commit 9c894f78fdef156263ced19eed67e79203ca4a11
+Subproject commit 7752c9221c768617af01711f8ad155e0a1cd409e
diff --git a/3rdparty/xgrammar b/3rdparty/xgrammar
index d4f57c440f..dbf200ecde 160000
--- a/3rdparty/xgrammar
+++ b/3rdparty/xgrammar
@@ -1 +1 @@
-Subproject commit d4f57c440f3da8e7330a1e5d50bba9c31f9433ea
+Subproject commit dbf200ecde5dd5467c8320076ee60b1e248b23e0
diff --git a/CMakeLists.txt b/CMakeLists.txt
index a010a05192..99926be832 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -71,7 +71,7 @@ add_subdirectory(${TOKENZIER_CPP_PATH} tokenizers EXCLUDE_FROM_ALL)
 set(XGRAMMAR_PATH 3rdparty/xgrammar)
 tvm_file_glob(GLOB_RECURSE MLC_LLM_SRCS cpp/*.cc)
 tvm_file_glob(GLOB_RECURSE XGRAMMAR_SRCS ${XGRAMMAR_PATH}/cpp/*.cc)
-list(FILTER XGRAMMAR_SRCS EXCLUDE REGEX "${XGRAMMAR_PATH}/cpp/pybind/.*\\.cc")
+list(FILTER XGRAMMAR_SRCS EXCLUDE REGEX "${XGRAMMAR_PATH}/cpp/nanobind/.*\\.cc")
 list(APPEND MLC_LLM_SRCS ${XGRAMMAR_SRCS})
 add_library(mlc_llm_objs OBJECT ${MLC_LLM_SRCS})
diff --git a/cpp/serve/config.cc b/cpp/serve/config.cc
index f7e71e72c9..aa14a6611d 100644
--- a/cpp/serve/config.cc
+++ b/cpp/serve/config.cc
@@ -42,13 +42,43 @@ Result<ResponseFormat> ResponseFormat::FromJSON(const picojson::object& config)
   ResponseFormat res;
   res.type = json::LookupOrDefault<std::string>(config, "type", "text");
 
+  if (res.type != "text" && res.type != "function" && res.type != "json_object" &&
+      res.type != "structural_tag") {
+    return TResult::Error("Unknown response_format type " + res.type);
+  }
+
   std::optional<std::string> schema = json::LookupOptional<std::string>(config, "schema");
   if (schema.has_value()) {
     res.schema = schema.value();
   }
 
-  if (res.type != "text" && res.type != "function" && res.type != "json_object") {
-    return TResult::Error("Uknonwn response_format type " + res.type);
+  if (auto tags_obj = json::LookupOptional<picojson::array>(config, "tags")) {
+    auto tags = Array<Array<String>>();
+    for (auto tag_obj : tags_obj.value()) {
+      Array<String> tag = Array<String>();
+      std::optional<std::string> begin =
+          json::LookupOptional<std::string>(tag_obj.get<picojson::object>(), "begin");
+      std::optional<std::string> schema =
+          json::LookupOptional<std::string>(tag_obj.get<picojson::object>(), "schema");
+      std::optional<std::string> end =
+          json::LookupOptional<std::string>(tag_obj.get<picojson::object>(), "end");
+      if (!(begin.has_value() && schema.has_value() && end.has_value())) {
+        return TResult::Error("Missing tag attribute.");
+      }
+      tag.push_back(begin.value());
+      tag.push_back(schema.value());
+      tag.push_back(end.value());
+      tags.push_back(std::move(tag));
+    }
+    res.tags = tags;
+  }
+
+  if (auto triggers_obj = json::LookupOptional<picojson::array>(config, "triggers")) {
+    auto triggers = Array<String>();
+    for (auto trigger : triggers_obj.value()) {
+      triggers.push_back(trigger.get<std::string>());
+    }
+    res.triggers = triggers;
   }
 
   return TResult::Ok(res);
@@ -60,6 +90,24 @@ picojson::object ResponseFormat::AsJSON() const {
   if (schema.defined()) {
     config["schema"] = picojson::value(schema.value().operator std::string());
   }
+  if (tags.defined()) {
+    picojson::array tags_obj = picojson::array();
+    for (auto tag : tags.value()) {
+      picojson::array tag_obj = picojson::array();
+      tag_obj.emplace_back(tag[0]);
+      tag_obj.emplace_back(tag[1]);
+      tag_obj.emplace_back(tag[2]);
+      tags_obj.emplace_back(tag_obj);
+    }
+    config["tags"] = picojson::value(tags_obj);
+  }
+  if (triggers.defined()) {
+    picojson::array trigger_obj = picojson::array();
+    for (std::string trigger : triggers.value()) {
+      trigger_obj.emplace_back(trigger);
+    }
+    config["triggers"] = picojson::value(trigger_obj);
+  }
   return config;
 }
@@ -1073,4 +1121,4 @@ Result<bool> ModelsUseKVCache(const std::vector<picojson::object>& model_configs
 
 }  // namespace serve
 }  // namespace llm
-}  // namespace mlc
+}  // namespace mlc
\ No newline at end of file
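Note on the new config surface: the hunk above accepts a `structural_tag` response format whose `tags` entries each carry `begin`, `schema`, and `end` strings, plus a top-level `triggers` list. A minimal sketch of an accepted payload follows; only the key names come from the patch, the concrete marker strings are invented for illustration:

```python
# Sketch of a response_format payload that ResponseFormat::FromJSON accepts.
# The marker strings ("<CALL--->my_fn|", ...) are hypothetical examples.
response_format = {
    "type": "structural_tag",
    "tags": [
        {
            "begin": "<CALL--->my_fn|",      # text that opens a tagged region
            "schema": '{"type": "object"}',  # JSON schema (as a string) for the region body
            "end": "|End<---my_fn>",         # text that closes the region
        }
    ],
    "triggers": ["<CALL--->"],  # prefixes that activate tag matching during generation
}
```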
diff --git a/cpp/serve/config.h b/cpp/serve/config.h
index 9da3ba2517..f39e1911ba 100644
--- a/cpp/serve/config.h
+++ b/cpp/serve/config.h
@@ -28,6 +28,8 @@ using namespace tvm::runtime;
 struct ResponseFormat {
   String type = "text";
   Optional<String> schema = NullOpt;
+  Optional<Array<Array<String>>> tags = NullOpt;
+  Optional<Array<String>> triggers = NullOpt;
   /*!
    * \brief Create debug config from JSON.
    * \param config_json The json string for generation config
@@ -448,4 +450,4 @@ inline PrefillMode PrefillModeFromString(const std::string& prefill_mode) {
 
 }  // namespace llm
 }  // namespace mlc
-#endif  // MLC_LLM_SERVE_CONFIG_H_
+#endif  // MLC_LLM_SERVE_CONFIG_H_
\ No newline at end of file
diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc
index 2f09219392..0db12299a1 100644
--- a/cpp/serve/engine.cc
+++ b/cpp/serve/engine.cc
@@ -463,9 +463,11 @@ class EngineImpl : public Engine {
           ModelWorkspace{model->AllocEmbeddingTensor(), model->AllocHiddenStatesTensor()});
     }
     // - Initialize tokenizer and grammar
+
     n->tokenizer_ = Tokenizer::FromPath(engine_config->model, GetTokenizerInfo(model_configs[0]));
     n->token_table_ = n->tokenizer_->PostProcessedTokenTable();
-    n->cached_grammar_compiler_ = xgrammar::CachedGrammarCompiler(n->token_table_);
+    // TODO: check 'vocab_size' of TokenizerInfo
+    n->grammar_compiler_ = xgrammar::GrammarCompiler(xgrammar::TokenizerInfo(n->token_table_));
     // - Create the logit processor and sampler, and
     // the DraftTokenWorkspaceManager for speculative decoding.
     int max_num_tokens = engine_config->max_num_sequence;
@@ -975,13 +977,25 @@ class EngineImpl : public Engine {
    * is not JSON, return std::nullopt.
    */
   std::optional<xgrammar::CompiledGrammar> GetGrammarFromResponseFormat(
       const ResponseFormat& response_format) {
-    if (response_format.type != "json_object") {
+    // TODO: add other grammar types
+    if (response_format.type == "text") {
       return std::nullopt;
-    } else if (!response_format.schema) {
-      return cached_grammar_compiler_.GetCompiledGrammarForJSON();
+    } else if (response_format.type == "json_object") {
+      if (!response_format.schema) {
+        return grammar_compiler_.CompileBuiltinJSONGrammar();
+      } else {
+        return grammar_compiler_.CompileJSONSchema(response_format.schema.value());
+      }
     } else {
-      return cached_grammar_compiler_.GetCompiledGrammarForJSONSchema(
-          response_format.schema.value());
+      std::vector<xgrammar::StructuralTagItem> tags;
+      std::vector<std::string> triggers;
+      for (auto tag : response_format.tags.value()) {
+        tags.emplace_back(xgrammar::StructuralTagItem{tag[0], tag[1], tag[2]});
+      }
+      for (auto trigger : response_format.triggers.value()) {
+        triggers.emplace_back(trigger);
+      }
+      return grammar_compiler_.CompileStructuralTag(std::move(tags), std::move(triggers));
     }
   }
@@ -992,8 +1006,8 @@ class EngineImpl : public Engine {
   // internal tokenizer
   Tokenizer tokenizer_;
   std::vector<std::string> token_table_;
-  // Cached grammar compiler for grammar matching.
-  xgrammar::CachedGrammarCompiler cached_grammar_compiler_;
+  // Grammar compiler for grammar matching.
+  xgrammar::GrammarCompiler grammar_compiler_;
   // Models
   Array<Model> models_;
   // Device that the models run on.
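For orientation, the migration above replaces `xgrammar::CachedGrammarCompiler` with an `xgrammar::GrammarCompiler` built from a `TokenizerInfo`, then dispatches on the response-format type. A rough Python-side sketch of the same flow, assuming xgrammar's Python bindings at a version matching the bumped submodule (method names may differ across releases):

```python
# Hedged sketch of the grammar-compilation flow, mirroring the C++ above.
import xgrammar as xgr

encoded_vocab = ["<s>", "</s>", "hello", "world"]  # stand-in for the real token table
compiler = xgr.GrammarCompiler(xgr.TokenizerInfo(encoded_vocab))

# type == "json_object": builtin JSON grammar, or a user-provided JSON schema
compiled = compiler.compile_builtin_json_grammar()
compiled = compiler.compile_json_schema('{"type": "object"}')

# type == "structural_tag": begin/schema/end tags plus dispatch triggers
compiled = compiler.compile_structural_tag(
    tags=[
        xgr.StructuralTagItem(
            begin="<CALL--->my_fn|", schema='{"type": "object"}', end="|End<---my_fn>"
        )
    ],
    triggers=["<CALL--->"],
)
```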
diff --git a/cpp/serve/request_state.cc b/cpp/serve/request_state.cc
index 17e02ee85b..4771f14d3b 100644
--- a/cpp/serve/request_state.cc
+++ b/cpp/serve/request_state.cc
@@ -24,7 +24,7 @@ RequestModelState::RequestModelState(
   if (compiled_grammar.has_value()) {
     // TODO(yixin): set rollback limit to a configurable value.
     n->grammar_matcher =
-        xgrammar::GrammarMatcher(compiled_grammar.value(), std::nullopt, false, std::nullopt, 10);
+        xgrammar::GrammarMatcher(compiled_grammar.value(), std::nullopt, false, 10);
   }
 
   n->request = std::move(request);
@@ -44,7 +44,7 @@ bool RequestModelStateNode::RequireNextTokenBitmask() { return grammar_matcher.has_value(); }
 
 void RequestModelStateNode::GetNextTokenBitmask(DLTensor* bitmask) {
   ICHECK(grammar_matcher.has_value());
-  grammar_matcher->GetNextTokenBitmask(bitmask);
+  grammar_matcher->FillNextTokenBitmask(bitmask);
 }
 
 void RequestModelStateNode::CommitToken(SampleResult sampled_token) {
diff --git a/python/mlc_llm/protocol/openai_api_protocol.py b/python/mlc_llm/protocol/openai_api_protocol.py
index cb2e1f2852..eb6d6c7e7e 100644
--- a/python/mlc_llm/protocol/openai_api_protocol.py
+++ b/python/mlc_llm/protocol/openai_api_protocol.py
@@ -86,12 +86,35 @@ class ModelResponse(BaseModel):
 
 
 class RequestResponseFormat(BaseModel):
-    type: Literal["text", "json_object"] = "text"
-    json_schema: Optional[str] = Field(default=None, alias="schema")
+    type: Literal["text", "json_object", "structural_tag"] = "text"
     """This field is named json_schema instead of schema because BaseModel defines a method called schema.
     During construction of RequestResponseFormat, key "schema" still should be used:
     `RequestResponseFormat(type="json_object", schema="{}")`
     """
+    json_schema: Optional[str] = Field(default=None, alias="schema")
+
+    """These fields are only used when type="structural_tag"."""
+    tags: Optional[List[Dict[str, str]]] = Field(default=None, alias="tags")
+    triggers: Optional[List[str]] = Field(default=None, alias="triggers")
+
+    @model_validator(mode="after")
+    def check_request_response_format(self) -> "RequestResponseFormat":
+        """Check if the RequestResponseFormat is valid."""
+        if self.type == "structural_tag":
+            if self.tags is None or self.triggers is None:
+                raise ValueError("structural_tag type must contain keys 'tags' and 'triggers'.")
+            for tag in self.tags:
+                if set(tag.keys()) != {"begin", "schema", "end"}:
+                    raise ValueError(
+                        "Each tag must contain exactly 'begin', 'schema' and 'end' keys. "
+                        f"Got keys: {list(tag.keys())}."
+                    )
+        elif self.tags is not None or self.triggers is not None:
+            raise Warning(
+                "'tags' and 'triggers' attributes should only be used when type='structural_tag'."
+            )
+
+        return self
 
 
 class CompletionRequest(BaseModel):
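The validator above enforces that `tags` and `triggers` travel together with `type="structural_tag"`. A small usage sketch (class and behavior as defined in the hunk; pydantic surfaces the `ValueError` as a `ValidationError`):

```python
from mlc_llm.protocol.openai_api_protocol import RequestResponseFormat

# Valid: structural_tag with both tags and triggers (marker strings invented).
fmt = RequestResponseFormat(
    type="structural_tag",
    tags=[{"begin": "<CALL--->my_fn|", "schema": "{}", "end": "|End<---my_fn>"}],
    triggers=["<CALL--->"],
)

# Invalid: structural_tag without tags/triggers fails validation.
try:
    RequestResponseFormat(type="structural_tag")
except Exception as err:
    print(err)
```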
diff --git a/tests/python/serve/server/test_server_structural_tag.py b/tests/python/serve/server/test_server_structural_tag.py
new file mode 100644
index 0000000000..f34df96d31
--- /dev/null
+++ b/tests/python/serve/server/test_server_structural_tag.py
@@ -0,0 +1,432 @@
+# pylint: disable=line-too-long
+"""
+Test script for structural tag in chat completion. To run this script, use the following commands:
+- start a new shell session, run
+  mlc_llm serve --model "YOUR_MODEL" (e.g. ./dist/Llama-2-7b-chat-hf-q0f16-MLC)
+- start another shell session, run this file
+  MLC_SERVE_MODEL="YOUR_MODEL" python tests/python/serve/server/test_server_structural_tag.py
+"""
+
+# pylint: disable=missing-function-docstring,too-many-arguments,too-many-locals,too-many-branches
+import json
+import os
+import re
+from typing import Any, Dict, List, Optional
+
+import pytest
+import requests
+
+OPENAI_V1_CHAT_COMPLETION_URL = "http://127.0.0.1:8000/v1/chat/completions"
+
+
+def check_openai_nonstream_response(
+    response: Dict,
+    *,
+    model: str,
+    object_str: str,
+    num_choices: int,
+    finish_reason: List[str],
+    completion_tokens: Optional[int] = None,
+):
+    assert response["model"] == model
+    assert response["object"] == object_str
+
+    choices = response["choices"]
+    assert isinstance(choices, list)
+    assert len(choices) == num_choices
+    for idx, choice in enumerate(choices):
+        assert choice["index"] == idx
+        assert choice["finish_reason"] in finish_reason
+
+        find_format_start = set()
+        beg_tag_start = set()
+        message = choice["message"]["content"]
+        print("Outputs:\n-----------")
+        print(message, flush=True)
+        pattern1 = r"<CALL--->(.*?)\|(.*?)\|End<---(.*?)>"
+        pattern2 = r"<call--->(.*?)\|(.*?)\|End<---(.*?)>"
+        # check format
+        for match in re.finditer(pattern1, message):
+            find_format_start.add(match.start())
+            check_format(match.group(1), match.group(3), "CALL", match.group(2))
+        for match in re.finditer(pattern2, message):
+            find_format_start.add(match.start())
+            check_format(match.group(1), match.group(3), "call", match.group(2))
+        for match in re.finditer(r"<CALL--->", message):
+            beg_tag_start.add(match.start())
+        for match in re.finditer(r"<call--->", message):
+            beg_tag_start.add(match.start())
+        assert find_format_start == beg_tag_start
+
+    usage = response["usage"]
+    assert isinstance(usage, dict)
+    assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"]
+    assert usage["prompt_tokens"] > 0
+
+    if completion_tokens is not None:
+        assert usage["completion_tokens"] == completion_tokens
+
+
+def check_openai_stream_response(
+    responses: List[Dict],
+    *,
+    model: str,
+    object_str: str,
+    num_choices: int,
+    finish_reason: str,
+    echo_prompt: Optional[str] = None,
+    suffix: Optional[str] = None,
+    stop: Optional[List[str]] = None,
+    require_substr: Optional[List[str]] = None,
+):
+    assert len(responses) > 0
+
+    finished = [False for _ in range(num_choices)]
+    outputs = ["" for _ in range(num_choices)]
+    for response in responses:
+        assert response["model"] == model
+        assert response["object"] == object_str
+
+        choices = response["choices"]
+        assert isinstance(choices, list)
+        assert len(choices) == num_choices
+        for idx, choice in enumerate(choices):
+            assert choice["index"] == idx
+
+            delta = choice["delta"]
+            assert delta["role"] == "assistant"
+            assert isinstance(delta["content"], str)
+            outputs[idx] += delta["content"]
+
+            if finished[idx]:
+                assert choice["finish_reason"] == finish_reason
+            elif choice["finish_reason"] is not None:
+                assert choice["finish_reason"] == finish_reason
+                finished[idx] = True
+
+    for output in outputs:
+        if echo_prompt is not None:
+            assert output.startswith(echo_prompt)
+        if suffix is not None:
+            assert output.endswith(suffix)
+        if stop is not None:
+            for stop_str in stop:
+                assert stop_str not in output
+        if require_substr is not None:
+            for substr in require_substr:
+                assert substr in output
+        find_format_start = set()
+        beg_tag_start = set()
+        print("Outputs:\n-----------")
+        print(output, flush=True)
+        pattern1 = r"<CALL--->(.*?)\|(.*?)\|End<---(.*?)>"
+        pattern2 = r"<call--->(.*?)\|(.*?)\|End<---(.*?)>"
+        # check format
+        for match in re.finditer(pattern1, output):
+            find_format_start.add(match.start())
+            check_format(match.group(1), match.group(3), "CALL", match.group(2))
+        for match in re.finditer(pattern2, output):
+            find_format_start.add(match.start())
+            check_format(match.group(1), match.group(3), "call", match.group(2))
+        for match in re.finditer(r"<CALL--->", output):
+            beg_tag_start.add(match.start())
+        for match in re.finditer(r"<call--->", output):
+            beg_tag_start.add(match.start())
+        assert find_format_start == beg_tag_start
+
+
+def check_format(name_beg: str, name_end: str, beg_tag: str, schema: str):
+    try:
+        paras: Dict[str, Any] = json.loads(schema)
+    except json.JSONDecodeError as e:
+        print(f"Invalid JSON format: {e}")
+        assert False
+    assert "hash_code" in paras
+    assert "hash_code" in schema
+    hash_code = paras["hash_code"]
+    assert hash_code in CHECK_INFO
+    info = CHECK_INFO[hash_code]
+    assert name_beg == info["name"]
+    assert name_end == info["name"]
+    assert beg_tag == info["beg_tag"]
+    for key in info["required"]:
+        assert key in paras
+
+
+# NOTE: the end-tag format and the hash_code values are deliberately hidden in the SYSTEM_PROMPT.
+# By checking whether the end tag and hash code are generated correctly without being prompted,
+# the correctness of the structural tag can be verified.
+
+SYSTEM_PROMPT = {
+    "role": "system",
+    "content": """
+# Tool Instructions
+- Always execute python code in messages that you share.
+- When looking for real time information use relevant functions if available else fallback to brave_search
+You have access to the following functions:
+Use the function 'get_current_weather' to: Get the current weather in a given location
+{
+    "name": "get_current_weather",
+    "description": "Get the current weather in a given location",
+    "parameters": {
+        "type": "object",
+        "properties": {
+            "city": {
+                "type": "string",
+                "description": "The city to find the weather for, e.g. 'San Francisco'",
+            },
+            "state": {
+                "type": "string",
+                "description": "the two-letter abbreviation for the state that the city is"
+                " in, e.g. 'CA' which would mean 'California'",
+            },
+            "unit": {
+                "type": "string",
+                "description": "The unit to fetch the temperature in",
+                "enum": ["celsius", "fahrenheit"],
+            },
+            "hash_code": {
+                "type": "string",
+            },
+        },
+        "required": ["city", "state", "unit", "hash_code"],
+    },
+}
+Use the function 'get_current_date' to: Get the current date and time for a given timezone
+{
+    "name": "get_current_date",
+    "description": "Get the current date and time for a given timezone",
+    "parameters": {
+        "type": "object",
+        "properties": {
+            "timezone": {
+                "type": "string",
+                "description": "The timezone to fetch the current date and time for, e.g. 'America/New_York'",
+            },
+            "hash_code": {
+                "type": "string",
+            },
+        },
+        "required": ["timezone", "hash_code"],
+    },
+}
+If you choose to call a function ONLY reply in the following format:
+<{start_tag}--->{function_name}|{parameters}|{end_tag}<---{function_name}>
+where
+start_tag => `<CALL` or `<call`
+parameters => a JSON dict with the function argument name as key and function argument value as value.
+Here is an example,
+<CALL--->example_function_name|{"example_name": "example_value"}...
+or
+<call--->example_function_name|{"example_name": "example_value"}...
+Reminder:
+- Function calls MUST follow the specified format
+- Required parameters MUST be specified
+You are a helpful assistant.""",
+}
+
+STRUCTURAL_TAGS = {
+    "triggers": ["<CALL--->", "<call--->"],
+    "tags": [
+        {
+            "begin": "<CALL--->get_current_weather|",
+            "schema": json.dumps(
+                {
+                    "type": "object",
+                    "properties": {
+                        "city": {
+                            "type": "string",
+                            "description": "The city to find the weather for, e.g. 'San Francisco'",
+                        },
+                        "state": {
+                            "type": "string",
+                            "description": "the two-letter abbreviation for the state that the city is"
+                            " in, e.g. 'CA' which would mean 'California'",
+                        },
+                        "unit": {
+                            "type": "string",
+                            "description": "The unit to fetch the temperature in",
+                            "enum": ["celsius", "fahrenheit"],
+                        },
+                        "hash_code": {"const": 1234},
+                    },
+                    "required": ["city", "state", "unit", "hash_code"],
+                }
+            ),
+            "end": "|End<---get_current_weather>",
+        },
+        {
+            "begin": "<CALL--->get_current_date|",
+            "schema": json.dumps(
+                {
+                    "type": "object",
+                    "properties": {
+                        "timezone": {
+                            "type": "string",
+                            "description": "The timezone to fetch the current date and time for, e.g. 'America/New_York'",
+                        },
+                        "hash_code": {"const": 2345},
+                    },
+                    "required": ["timezone", "hash_code"],
+                }
+            ),
+            "end": "|End<---get_current_date>",
+        },
+        {
+            "begin": "<call--->get_current_weather|",
+            "schema": json.dumps(
+                {
+                    "type": "object",
+                    "properties": {
+                        "city": {
+                            "type": "string",
+                            "description": "The city to find the weather for, e.g. 'San Francisco'",
+                        },
+                        "state": {
+                            "type": "string",
+                            "description": "the two-letter abbreviation for the state that the city is"
+                            " in, e.g. 'CA' which would mean 'California'",
+                        },
+                        "unit": {
+                            "type": "string",
+                            "description": "The unit to fetch the temperature in",
+                            "enum": ["celsius", "fahrenheit"],
+                        },
+                        "hash_code": {"const": 3456},
+                    },
+                    "required": ["city", "state", "unit", "hash_code"],
+                }
+            ),
+            "end": "|End<---get_current_weather>",
+        },
+        {
+            "begin": "<call--->get_current_date|",
+            "schema": json.dumps(
+                {
+                    "type": "object",
+                    "properties": {
+                        "timezone": {
+                            "type": "string",
+                            "description": "The timezone to fetch the current date and time for, e.g. 'America/New_York'",
+                        },
+                        "hash_code": {"const": 4567},
+                    },
+                    "required": ["timezone", "hash_code"],
+                }
+            ),
+            "end": "|End<---get_current_date>",
+        },
+    ],
+}
+
+CHECK_INFO = {
+    1234: {
+        "name": "get_current_weather",
+        "beg_tag": "CALL",
+        "required": ["city", "state", "unit", "hash_code"],
+    },
+    2345: {
+        "name": "get_current_date",
+        "beg_tag": "CALL",
+        "required": ["timezone", "hash_code"],
+    },
+    3456: {
+        "name": "get_current_weather",
+        "beg_tag": "call",
+        "required": ["city", "state", "unit", "hash_code"],
+    },
+    4567: {
+        "name": "get_current_date",
+        "beg_tag": "call",
+        "required": ["timezone", "hash_code"],
+    },
+}
+
+CHAT_COMPLETION_MESSAGES = [
+    # messages #0
+    [
+        SYSTEM_PROMPT,
+        {
+            "role": "user",
+            "content": "You are in New York. Please get the current date and time.",
+        },
+    ],
+    # messages #1
+    [
+        SYSTEM_PROMPT,
+        {
+            "role": "user",
+            "content": "You are in New York. Please get the current weather.",
+        },
+    ],
+    # messages #2
+    [
+        SYSTEM_PROMPT,
+        {
+            "role": "user",
+            "content": "You are in New York. Please get the current date and time, and the weather.",
+        },
+    ],
+]
+
+
+@pytest.mark.parametrize("stream", [False, True])
+@pytest.mark.parametrize("messages", CHAT_COMPLETION_MESSAGES)
+def test_openai_v1_chat_completion_structural_tag(
+    served_model: str,
+    launch_server,  # pylint: disable=unused-argument
+    stream: bool,
+    messages: List[Dict[str, str]],
+):
+    # `served_model` and `launch_server` are pytest fixtures
+    # defined in conftest.py.
+
+    payload = {
+        "model": served_model,
+        "messages": messages,
+        "stream": stream,
+        "response_format": {
+            "type": "structural_tag",
+            "tags": STRUCTURAL_TAGS["tags"],
+            "triggers": STRUCTURAL_TAGS["triggers"],
+        },
+        "max_tokens": 1024,
+    }
+
+    response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=60)
+    if not stream:
+        check_openai_nonstream_response(
+            response.json(),
+            model=served_model,
+            object_str="chat.completion",
+            num_choices=1,
+            finish_reason=["stop"],
+        )
+    else:
+        responses = []
+        for chunk in response.iter_lines(chunk_size=512):
+            if not chunk or chunk == b"data: [DONE]":
+                continue
+            responses.append(json.loads(chunk.decode("utf-8")[6:]))
+        check_openai_stream_response(
+            responses,
+            model=served_model,
+            object_str="chat.completion.chunk",
+            num_choices=1,
+            finish_reason="stop",
+        )
+
+    print(f"-----------\nCheck for stream={stream} passed!\n")
+
+
+if __name__ == "__main__":
+    MODEL = os.environ.get("MLC_SERVE_MODEL")
+    if MODEL is None:
+        raise ValueError(
+            'Environment variable "MLC_SERVE_MODEL" not found. '
+            "Please set it to a model compiled by MLC LLM "
+            "(e.g., `./dist/Llama-2-7b-chat-hf-q0f16-MLC`)."
+        )
+
+    for msg in CHAT_COMPLETION_MESSAGES:
+        test_openai_v1_chat_completion_structural_tag(MODEL, None, stream=False, messages=msg)
+        test_openai_v1_chat_completion_structural_tag(MODEL, None, stream=True, messages=msg)
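For context (not part of the patch): a completion that satisfies the structural tag with `hash_code` 1234 and passes the checkers above would look like the `sample` below; the argument values are invented.

```python
# Standalone illustration of the tagged output format the test expects.
import json
import re

sample = (
    "<CALL--->get_current_weather|"
    '{"city": "San Francisco", "state": "CA", "unit": "celsius", "hash_code": 1234}'
    "|End<---get_current_weather>"
)
match = re.search(r"<CALL--->(.*?)\|(.*?)\|End<---(.*?)>", sample)
assert match is not None
assert match.group(1) == match.group(3) == "get_current_weather"
assert json.loads(match.group(2))["hash_code"] == 1234
```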