Skip to content

Commit 4f980f3

Browse files
committed
[feat] Add Structural-Tag api to RequestResponseFormat
- Expose Structural-Tag api, which can be used to standardize the function calling format - Add test script for Structural-Tag (passed on Llama-2-7b-chat-hf-q0f16-MLC and Llama-3-8B-Instruct-q4f16_1-MLC)
1 parent 6e7426e commit 4f980f3

File tree

9 files changed

+451
-68
lines changed

9 files changed

+451
-68
lines changed

cpp/serve/config.cc

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,11 @@
1010

1111
#include <limits>
1212
#include <random>
13-
#include <string>
14-
#include <vector>
1513

1614
#include "../json_ffi/openai_api_protocol.h"
1715
#include "../support/json_parser.h"
1816
#include "../support/utils.h"
1917
#include "data.h"
20-
#include "tvm/runtime/container/array.h"
2118

2219
namespace mlc {
2320
namespace llm {
@@ -1124,4 +1121,4 @@ Result<bool> ModelsUseKVCache(const std::vector<picojson::object>& model_configs
11241121

11251122
} // namespace serve
11261123
} // namespace llm
1127-
} // namespace mlc
1124+
} // namespace mlc

cpp/serve/config.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
#include "../metadata/model.h"
1616
#include "../support/result.h"
17-
#include "tvm/runtime/container/optional.h"
1817

1918
namespace mlc {
2019
namespace llm {
@@ -451,4 +450,4 @@ inline PrefillMode PrefillModeFromString(const std::string& prefill_mode) {
451450
} // namespace llm
452451
} // namespace mlc
453452

454-
#endif // MLC_LLM_SERVE_CONFIG_H_
453+
#endif // MLC_LLM_SERVE_CONFIG_H_

cpp/serve/engine.cc

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,8 @@
1919
#include <functional>
2020
#include <numeric>
2121
#include <optional>
22-
#include <string>
2322
#include <tuple>
2423
#include <unordered_set>
25-
#include <utility>
2624

2725
#include "../support/json_parser.h"
2826
#include "../support/result.h"
@@ -37,7 +35,6 @@
3735
#include "request.h"
3836
#include "request_state.h"
3937
#include "sampler/sampler.h"
40-
#include "xgrammar/grammar.h"
4138

4239
namespace mlc {
4340
namespace llm {

cpp/serve/logit_processor.cc

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
#include <tvm/runtime/registry.h>
1313
#include <tvm/runtime/threading_backend.h>
1414

15-
#include <cstdio>
16-
1715
namespace mlc {
1816
namespace llm {
1917
namespace serve {

python/mlc_llm/protocol/openai_api_protocol.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def check_request_response_format(self) -> "RequestResponseFormat":
112112
raise Warning(
113113
"'tags' and 'triggers' attributes should be used when type='structural_tag'"
114114
)
115+
115116
return self
116117

117118

python/mlc_llm/serve/engine.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -976,12 +976,6 @@ async def _chat_completion( # pylint: disable=too-many-arguments,too-many-local
976976
if request_id is None:
977977
request_id = f"chatcmpl-{engine_utils.random_uuid()}"
978978

979-
tools = (
980-
[openai_api_protocol.ChatTool.model_validate(tool) for tool in tools]
981-
if tools is not None
982-
else None
983-
)
984-
985979
chatcmpl_generator = self._handle_chat_completion(
986980
openai_api_protocol.ChatCompletionRequest(
987981
messages=[
@@ -1213,10 +1207,6 @@ async def _handle_chat_completion(
12131207
e : BadRequestError
12141208
BadRequestError is raised when the request is invalid.
12151209
"""
1216-
request.response_format = engine_base.set_structural_tag_from_tools(
1217-
request.tools, request.response_format
1218-
)
1219-
12201210
(
12211211
prompts,
12221212
generation_cfg,
@@ -1774,10 +1764,6 @@ def _handle_chat_completion(
17741764
e : BadRequestError
17751765
BadRequestError is raised when the request is invalid.
17761766
"""
1777-
request.response_format = engine_base.set_structural_tag_from_tools(
1778-
request.tools, request.response_format
1779-
)
1780-
17811767
(
17821768
prompts,
17831769
generation_cfg,

python/mlc_llm/serve/engine_base.py

Lines changed: 18 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import json
88
import numbers
99
import queue
10-
import re
1110
import sys
1211
import threading
1312
from dataclasses import dataclass
@@ -1147,52 +1146,29 @@ def create_completion_suffix_response(
11471146
return response
11481147

11491148

1150-
def convert_function_str_to_json(stringified_calls: str):
1149+
def convert_function_str_to_json(stringified_calls: str) -> List[Union[Dict, None]]:
11511150
"""Convert a (possibly list) of function call string to a list of json objects.
11521151
Return None for invalid function call string."""
1153-
function_calls_json = []
1154-
for call in re.finditer(r"<function=(.*?)>(.*?)</function>", stringified_calls, re.DOTALL):
1155-
function_name = call.group(1)
1156-
params_str = call.group(2).strip()
1157-
params = ast.literal_eval(params_str)
1158-
function_calls_json.append({"name": function_name, "arguments": params})
1159-
1160-
return function_calls_json
11611152

1153+
def parse_function_call(call_str: str):
1154+
node = ast.parse(call_str, mode="eval")
1155+
call_node = node.body
1156+
if isinstance(call_node, ast.Call) and isinstance(call_node.func, ast.Name):
1157+
name = call_node.func.id
1158+
arguments = {}
1159+
for keyword in call_node.keywords:
1160+
arguments[keyword.arg] = ast.literal_eval(keyword.value)
1161+
return {"name": name, "arguments": arguments}
1162+
return None
11621163

1163-
def set_structural_tag_from_tools(
1164-
tools: Optional[List[openai_api_protocol.ChatTool]],
1165-
response_format: Optional[openai_api_protocol.RequestResponseFormat],
1166-
):
1167-
"""Add the corresponding structural tag to the response format according to the tools to ensure valid function calling.
1168-
Return the updated response format.
1169-
"""
1170-
if tools is None:
1171-
return response_format
1164+
if (
1165+
stringified_calls[0] == "[" and stringified_calls[-1] == "]"
1166+
): # hacky way to check if string list
1167+
calls = ast.literal_eval(stringified_calls)
11721168
else:
1173-
if response_format is None or response_format.type == "text":
1174-
response_format = openai_api_protocol.RequestResponseFormat.model_validate(
1175-
{"type": "structural_tag", "tags": [], "triggers": []}
1176-
)
1177-
elif response_format.type == "json_object":
1178-
response_format.tags = []
1179-
response_format.triggers = []
1180-
1181-
response_format.triggers.append("<function=")
1182-
for tool in tools:
1183-
schema = {
1184-
"properties": tool.function.parameters["properties"],
1185-
"required": tool.function.parameters["required"],
1186-
"type": tool.function.parameters["type"],
1187-
}
1188-
response_format.tags.append(
1189-
{
1190-
"begin": f"<function={tool.function.name}>",
1191-
"schema": json.dumps(schema),
1192-
"end": "</function>",
1193-
}
1194-
)
1195-
return response_format
1169+
calls = [stringified_calls]
1170+
function_calls_json = [parse_function_call(call_str) for call_str in calls]
1171+
return function_calls_json
11961172

11971173

11981174
def process_function_call_output(

0 commit comments

Comments
 (0)