Skip to content

Commit 6e7426e

Browse files
committed
[feature] Support tool function calling with Structural Tag with xgrammar
- Ensure tool functions are called in the expected format by using xgrammar. - Modify RequestResponseFormat: add structural tags derived from the tools when building the response format. - Tool function calls are now constrained to the format <function=function_name>parameters</function>. - The tool-call list is parsed according to this calling format when processing the response. - Also expose xgrammar's Structural Tag API through RequestResponseFormat.
1 parent f70a37a commit 6e7426e

File tree

7 files changed

+156
-26
lines changed

7 files changed

+156
-26
lines changed

cpp/serve/config.cc

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,14 @@
1010

1111
#include <limits>
1212
#include <random>
13+
#include <string>
14+
#include <vector>
1315

1416
#include "../json_ffi/openai_api_protocol.h"
1517
#include "../support/json_parser.h"
1618
#include "../support/utils.h"
1719
#include "data.h"
20+
#include "tvm/runtime/container/array.h"
1821

1922
namespace mlc {
2023
namespace llm {
@@ -42,13 +45,43 @@ Result<ResponseFormat> ResponseFormat::FromJSON(const picojson::object& config)
4245
ResponseFormat res;
4346
res.type = json::LookupOrDefault<std::string>(config, "type", "text");
4447

48+
if (res.type != "text" && res.type != "function" && res.type != "json_object" &&
49+
res.type != "structural_tag") {
50+
return TResult::Error("Uknonwn response_format type " + res.type);
51+
}
52+
4553
std::optional<std::string> schema = json::LookupOptional<std::string>(config, "schema");
4654
if (schema.has_value()) {
4755
res.schema = schema.value();
4856
}
4957

50-
if (res.type != "text" && res.type != "function" && res.type != "json_object") {
51-
return TResult::Error("Uknonwn response_format type " + res.type);
58+
if (auto tags_obj = json::LookupOptional<picojson::array>(config, "tags")) {
59+
auto tags = Array<Array<String>>();
60+
for (auto tag_obj : tags_obj.value()) {
61+
Array<String> tag = Array<String>();
62+
std::optional<std::string> begin =
63+
json::LookupOptional<std::string>(tag_obj.get<picojson::object>(), "begin");
64+
std::optional<std::string> schema =
65+
json::LookupOptional<std::string>(tag_obj.get<picojson::object>(), "schema");
66+
std::optional<std::string> end =
67+
json::LookupOptional<std::string>(tag_obj.get<picojson::object>(), "end");
68+
if (!(begin.has_value() && schema.has_value() && end.has_value())) {
69+
return TResult::Error("Miss tag attribute.");
70+
}
71+
tag.push_back(begin.value());
72+
tag.push_back(schema.value());
73+
tag.push_back(end.value());
74+
tags.push_back(std::move(tag));
75+
}
76+
res.tags = tags;
77+
}
78+
79+
if (auto triggers_obj = json::LookupOptional<picojson::array>(config, "triggers")) {
80+
auto triggers = Array<String>();
81+
for (auto trigger : triggers_obj.value()) {
82+
triggers.push_back(trigger.get<std::string>());
83+
}
84+
res.triggers = triggers;
5285
}
5386

5487
return TResult::Ok(res);
@@ -60,6 +93,24 @@ picojson::object ResponseFormat::AsJSON() const {
6093
if (schema.defined()) {
6194
config["schema"] = picojson::value(schema.value().operator std::string());
6295
}
96+
if (tags.defined()) {
97+
picojson::array tags_obj = picojson::array();
98+
for (auto tag : tags.value()) {
99+
picojson::array tag_obj = picojson::array();
100+
tag_obj.emplace_back(tag[0]);
101+
tag_obj.emplace_back(tag[1]);
102+
tag_obj.emplace_back(tag[2]);
103+
tags_obj.emplace_back(tag_obj);
104+
}
105+
config["tags"] = picojson::value(tags_obj);
106+
}
107+
if (triggers.defined()) {
108+
picojson::array trigger_obj = picojson::array();
109+
for (std::string trigger : triggers.value()) {
110+
trigger_obj.emplace_back(trigger);
111+
}
112+
config["triggers"] = picojson::value(trigger_obj);
113+
}
63114
return config;
64115
}
65116

cpp/serve/config.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "../metadata/model.h"
1616
#include "../support/result.h"
17+
#include "tvm/runtime/container/optional.h"
1718

1819
namespace mlc {
1920
namespace llm {
@@ -28,6 +29,8 @@ using namespace tvm::runtime;
2829
struct ResponseFormat {
2930
String type = "text";
3031
Optional<String> schema = NullOpt;
32+
Optional<Array<Array<String>>> tags = NullOpt;
33+
Optional<Array<String>> triggers = NullOpt;
3134
/*!
3235
* \brief Create debug config from JSON.
3336
* \param config_json The json string for generation config

cpp/serve/engine.cc

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@
1919
#include <functional>
2020
#include <numeric>
2121
#include <optional>
22+
#include <string>
2223
#include <tuple>
2324
#include <unordered_set>
25+
#include <utility>
2426

2527
#include "../support/json_parser.h"
2628
#include "../support/result.h"
@@ -35,6 +37,7 @@
3537
#include "request.h"
3638
#include "request_state.h"
3739
#include "sampler/sampler.h"
40+
#include "xgrammar/grammar.h"
3841

3942
namespace mlc {
4043
namespace llm {
@@ -978,12 +981,24 @@ class EngineImpl : public Engine {
978981
std::optional<xgrammar::CompiledGrammar> GetGrammarFromResponseFormat(
979982
const ResponseFormat& response_format) {
980983
// TODO: add other grammar type
981-
if (response_format.type != "json_object") {
984+
if (response_format.type == "text") {
982985
return std::nullopt;
983-
} else if (!response_format.schema) {
984-
return grammar_compiler_.CompileBuiltinJSONGrammar();
986+
} else if (response_format.type == "json_object") {
987+
if (!response_format.schema) {
988+
return grammar_compiler_.CompileBuiltinJSONGrammar();
989+
} else {
990+
return grammar_compiler_.CompileJSONSchema(response_format.schema.value());
991+
}
985992
} else {
986-
return grammar_compiler_.CompileJSONSchema(response_format.schema.value());
993+
std::vector<xgrammar::StructuralTagItem> tags;
994+
std::vector<std::string> triggers;
995+
for (auto tag : response_format.tags.value()) {
996+
tags.emplace_back(xgrammar::StructuralTagItem{tag[0], tag[1], tag[2]});
997+
}
998+
for (auto trigger : response_format.triggers.value()) {
999+
triggers.emplace_back(trigger);
1000+
}
1001+
return grammar_compiler_.CompileStructuralTag(std::move(tags), std::move(triggers));
9871002
}
9881003
}
9891004

cpp/serve/logit_processor.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
#include <tvm/runtime/registry.h>
1313
#include <tvm/runtime/threading_backend.h>
1414

15+
#include <cstdio>
16+
1517
namespace mlc {
1618
namespace llm {
1719
namespace serve {

python/mlc_llm/protocol/openai_api_protocol.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,33 @@ class ModelResponse(BaseModel):
8686

8787

8888
class RequestResponseFormat(BaseModel):
    # Response-format constraint of a request. `type` selects the kind of constraint;
    # `json_schema` optionally carries a schema string, and `tags` / `triggers` are
    # required when type="structural_tag" (enforced by the validator below).

    type: Literal["text", "json_object", "structural_tag"] = "text"

    # This field is named json_schema instead of schema because BaseModel defines a method
    # called schema. During construction of RequestResponseFormat, key "schema" still should
    # be used: `RequestResponseFormat(type="json_object", schema="{}")`
    json_schema: Optional[str] = Field(default=None, alias="schema")

    # These fields are only used for type="structural_tag".
    # Each tag is a dict with exactly the keys "begin", "schema" and "end".
    tags: Optional[List[Dict[str, str]]] = Field(default=None, alias="tags")
    triggers: Optional[List[str]] = Field(default=None, alias="triggers")

    @model_validator(mode="after")
    def check_request_response_format(self) -> "RequestResponseFormat":
        """Check if the RequestResponseFormat is valid.

        Raises
        ------
        ValueError
            If type is "structural_tag" and `tags` or `triggers` is missing, or a tag
            does not have exactly the keys "begin", "schema" and "end".
        """
        import warnings  # pylint: disable=import-outside-toplevel

        if self.type == "structural_tag":
            if self.tags is None or self.triggers is None:
                raise ValueError("structural_tag type must contain keys 'tags' and 'triggers'.")
            for tag in self.tags:
                if set(tag.keys()) != {"begin", "schema", "end"}:
                    raise ValueError(
                        "Each tag must contain exactly 'begin', 'schema' and 'end' keys. "
                        f"Got keys: {list(tag.keys())}."
                    )
        elif self.tags is not None or self.triggers is not None:
            # Emit a real warning instead of raising the Warning class: a raised Warning
            # is neither a ValueError nor an AssertionError, so it would leave the
            # validator as an unhandled exception rather than as advice to the caller.
            warnings.warn(
                "'tags' and 'triggers' attributes should be used when type='structural_tag'"
            )
        return self
95116

96117

97118
class CompletionRequest(BaseModel):

python/mlc_llm/serve/engine.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -976,6 +976,12 @@ async def _chat_completion( # pylint: disable=too-many-arguments,too-many-local
976976
if request_id is None:
977977
request_id = f"chatcmpl-{engine_utils.random_uuid()}"
978978

979+
tools = (
980+
[openai_api_protocol.ChatTool.model_validate(tool) for tool in tools]
981+
if tools is not None
982+
else None
983+
)
984+
979985
chatcmpl_generator = self._handle_chat_completion(
980986
openai_api_protocol.ChatCompletionRequest(
981987
messages=[
@@ -1207,6 +1213,10 @@ async def _handle_chat_completion(
12071213
e : BadRequestError
12081214
BadRequestError is raised when the request is invalid.
12091215
"""
1216+
request.response_format = engine_base.set_structural_tag_from_tools(
1217+
request.tools, request.response_format
1218+
)
1219+
12101220
(
12111221
prompts,
12121222
generation_cfg,
@@ -1764,6 +1774,10 @@ def _handle_chat_completion(
17641774
e : BadRequestError
17651775
BadRequestError is raised when the request is invalid.
17661776
"""
1777+
request.response_format = engine_base.set_structural_tag_from_tools(
1778+
request.tools, request.response_format
1779+
)
1780+
17671781
(
17681782
prompts,
17691783
generation_cfg,

python/mlc_llm/serve/engine_base.py

Lines changed: 42 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import json
88
import numbers
99
import queue
10+
import re
1011
import sys
1112
import threading
1213
from dataclasses import dataclass
@@ -1146,29 +1147,52 @@ def create_completion_suffix_response(
11461147
return response
11471148

11481149

1149-
def convert_function_str_to_json(stringified_calls: str) -> List[Union[Dict, None]]:
1150+
def convert_function_str_to_json(stringified_calls: str) -> List[Dict]:
    """Extract every `<function=NAME>PARAMS</function>` call from a string.

    Parameters
    ----------
    stringified_calls : str
        Text that may contain zero or more function calls in the tagged format.

    Returns
    -------
    function_calls_json : List[Dict]
        One `{"name": NAME, "arguments": PARAMS}` dict per call, in order of
        appearance. The list is empty when no call in the expected format is found.

    Raises
    ------
    ValueError, SyntaxError
        If the parameters of a call are neither valid JSON nor a valid Python literal.
    """
    function_calls_json = []
    for call in re.finditer(r"<function=(.*?)>(.*?)</function>", stringified_calls, re.DOTALL):
        function_name = call.group(1)
        params_str = call.group(2).strip()
        try:
            # Parse as JSON first: `ast.literal_eval` rejects the JSON literals
            # true / false / null.
            params = json.loads(params_str)
        except ValueError:
            # Not valid JSON: fall back to Python-literal syntax (e.g. single-quoted
            # keys). Errors raised here propagate to the caller.
            params = ast.literal_eval(params_str)
        function_calls_json.append({"name": function_name, "arguments": params})

    return function_calls_json
1162+
1163+
def set_structural_tag_from_tools(
    tools: Optional[List[openai_api_protocol.ChatTool]],
    response_format: Optional[openai_api_protocol.RequestResponseFormat],
):
    """Add one structural tag per tool to the response format so that function
    calls are generated as `<function=NAME>PARAMS</function>`.

    Parameters
    ----------
    tools : Optional[List[openai_api_protocol.ChatTool]]
        The tools of the request. When None or empty, nothing is added.

    response_format : Optional[openai_api_protocol.RequestResponseFormat]
        The response format of the request. A missing or "text" format is replaced
        by a new "structural_tag" format; any other format is updated in place.

    Returns
    -------
    response_format : Optional[openai_api_protocol.RequestResponseFormat]
        The updated response format, or the input unchanged when there are no tools.
    """
    if not tools:
        # Without tools there is no tag to add, so do not add the trigger either.
        return response_format

    if response_format is None or response_format.type == "text":
        response_format = openai_api_protocol.RequestResponseFormat.model_validate(
            {"type": "structural_tag", "tags": [], "triggers": []}
        )
    elif response_format.type == "json_object":
        # NOTE(review): `type` stays "json_object" here; confirm the engine applies
        # the tags for that type.
        response_format.tags = []
        response_format.triggers = []
    else:
        # Keep tags/triggers already present, but make sure both lists exist.
        if response_format.tags is None:
            response_format.tags = []
        if response_format.triggers is None:
            response_format.triggers = []

    if "<function=" not in response_format.triggers:
        response_format.triggers.append("<function=")
    for tool in tools:
        parameters = tool.function.parameters or {}
        # "properties", "required" and "type" may be absent from a parameter schema;
        # default them instead of raising KeyError.
        schema = {
            "properties": parameters.get("properties", {}),
            "required": parameters.get("required", []),
            "type": parameters.get("type", "object"),
        }
        response_format.tags.append(
            {
                "begin": f"<function={tool.function.name}>",
                "schema": json.dumps(schema),
                "end": "</function>",
            }
        )
    return response_format
11721196

11731197

11741198
def process_function_call_output(

0 commit comments

Comments
 (0)