Skip to content

Commit 066434f

Browse files
committed
[Model] Hunyuan A13B tool parser refine and tests.
- Add tests for the Hunyuan A13B tool parser. - Fix mypy errors in the tool parser. - Refine the reasoning parser test. - Refactor the tool parser streaming function. Signed-off-by: Asher Zhang <asherszhang@tencent.com>
1 parent c0fdba0 commit 066434f

File tree

3 files changed

+364
-276
lines changed

3 files changed

+364
-276
lines changed
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
# ruff: noqa: E501
4+
5+
import json
6+
from unittest.mock import MagicMock
7+
8+
import pytest
9+
10+
from tests.entrypoints.openai.tool_parsers.utils import (
11+
run_tool_extraction, run_tool_extraction_streaming)
12+
from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
13+
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
14+
15+
16+
def make_tool_call(name, arguments):
    """Build the ``ToolCall`` a test case expects the parser to produce.

    Args:
        name: Name of the function being called.
        arguments: JSON-serializable arguments; stored on the call as the
            string produced by ``json.dumps``.

    Returns:
        A ``ToolCall`` of type ``"function"`` wrapping a ``FunctionCall``.
    """
    serialized_arguments = json.dumps(arguments)
    function_call = FunctionCall(name=name, arguments=serialized_arguments)
    return ToolCall(type="function", function=function_call)
20+
21+
22+
# TODO: add test cases with a reasoning prefix and suffix.
23+
24+
25+
@pytest.mark.parametrize(
    "model_output,expected_tool_calls,expected_content",
    [
        # No tool call
        ("How can I help you today?", [], "How can I help you today?"),
        # Single tool call, no content
        (
            "<tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"San Francisco\", \"metric\": \"celsius\"}}]</tool_calls>",
            [
                make_tool_call("get_weather", {
                    "city": "San Francisco",
                    "metric": "celsius"
                })
            ],
            None),
        # Multiple tool calls
        (
            "<tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"San Francisco\", \"metric\": \"celsius\"}}, {\"name\": \"register_user\", \"arguments\": {\"name\": \"John Doe\", \"age\": 37, \"address\": {\"city\": \"San Francisco\", \"state\": \"CA\"}, \"role\": null, \"passed_test\": true, \"aliases\": [\"John\", \"Johnny\"]}}]</tool_calls>",
            [
                make_tool_call("get_weather", {
                    "city": "San Francisco",
                    "metric": "celsius"
                }),
                make_tool_call(
                    "register_user", {
                        "name": "John Doe",
                        "age": 37,
                        "address": {
                            "city": "San Francisco",
                            "state": "CA"
                        },
                        "role": None,
                        "passed_test": True,
                        "aliases": ["John", "Johnny"]
                    })
            ],
            None),
        # Content before tool call
        (
            "I will call the tool now. <tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Boston\"}}]</tool_calls>",
            [make_tool_call("get_weather", {"city": "Boston"})],
            "I will call the tool now. "),
        # Content after tool call (should be stripped)
        (
            "<tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Seattle\"}}]</tool_calls>\nThank you!",
            [make_tool_call("get_weather", {"city": "Seattle"})],
            None),
        # Deeply nested arguments
        (
            "<tool_calls>[{\"name\": \"complex_tool\", \"arguments\": {\"level1\": {\"level2\": {\"level3\": {\"value\": 123}}}}}]</tool_calls>",
            [
                make_tool_call(
                    "complex_tool",
                    {"level1": {
                        "level2": {
                            "level3": {
                                "value": 123
                            }
                        }
                    }})
            ],
            None,
        ),
    ])
def test_hunyuan_a13b_tool_parser_extract(model_output, expected_tool_calls,
                                          expected_content):
    """Check non-streaming extraction of tool calls and content.

    Args:
        model_output: Complete model output handed to the parser.
        expected_tool_calls: Tool calls the parser should extract, in order.
        expected_content: Content the parser should return, or ``None``.
    """
    mock_tokenizer = MagicMock()
    tool_parser: ToolParser = ToolParserManager.get_tool_parser(
        "hunyuan_a13b")(mock_tokenizer)
    content, tool_calls = run_tool_extraction(tool_parser,
                                              model_output,
                                              streaming=False)

    # Compare counts first so a surplus or missing tool call is reported as
    # an assertion failure rather than an IndexError while aligning ids.
    assert len(tool_calls) == len(expected_tool_calls)

    # Tool call ids are random, so copy the expected ids over before
    # comparing the objects for equality.
    for actual_call, expected_call in zip(tool_calls, expected_tool_calls):
        actual_call.id = expected_call.id

    assert tool_calls == expected_tool_calls
    assert content == expected_content
102+
103+
104+
# Streaming test: simulate incremental output
105+
# Streaming test: simulate incremental output
@pytest.mark.parametrize("model_deltas,expected_tool_calls", [
    # Tool call split across a few large deltas
    ([
        "<tool_calls>[{\"name\": \"get_weather\", ",
        "\"arguments\": {\"city\": \"San Francisco\", ",
        "\"metric\": \"celsius\"}}]", "</tool_calls>"
    ], [
        make_tool_call("get_weather", {
            "city": "San Francisco",
            "metric": "celsius"
        })
    ]),
    # Tool call split into finer-grained deltas
    ([
        "<tool_calls>[{\"name\":", " \"get_weather\",", " \"arguments\":",
        " {\"city\": \"Boston\"}", "}]", "</tool_calls>"
    ], [make_tool_call("get_weather", {"city": "Boston"})]),
    # Leading empty delta and trailing text after the closing tag
    ([
        "", "<tool_calls>[{\"name\":", " \"get_weather\",", " \"arguments\":",
        " {\"city\": \"Boston\"}", "}]", "</tool_calls>", "\n</answer>"
    ], [make_tool_call("get_weather", {"city": "Boston"})]),
    # Deeply nested arguments (expected to fail for now)
    pytest.param([
        "<tool_calls>[{\"name\": \"complex_tool\",", " \"arguments\": ",
        " {\"level1\": {\"level2\": ", "{\"level3\": {\"value\": 123}}}}}",
        "]</tool_calls>"
    ], [
        make_tool_call("complex_tool",
                       {"level1": {
                           "level2": {
                               "level3": {
                                   "value": 123
                               }
                           }
                       }})
    ],
                 marks=pytest.mark.xfail(
                     reason="stream parsing not support nested json yet.")),
])
def test_hunyuan_a13b_tool_parser_streaming(model_deltas, expected_tool_calls):
    """Check tool calls reconstructed from incrementally streamed output.

    Args:
        model_deltas: Output fragments fed to the parser one after another.
        expected_tool_calls: Tool calls the reconstructed stream should
            contain, in order.
    """
    mock_tokenizer = MagicMock()

    tool_parser: ToolParser = ToolParserManager.get_tool_parser(
        "hunyuan_a13b")(mock_tokenizer)
    reconstructor = run_tool_extraction_streaming(
        tool_parser, model_deltas, assert_one_tool_per_delta=False)

    # Compare counts first so a surplus or missing tool call is reported as
    # an assertion failure rather than an IndexError while aligning ids.
    assert len(reconstructor.tool_calls) == len(expected_tool_calls)

    # Tool call ids are random, so copy the expected ids over before
    # comparing the objects for equality.
    for actual_call, expected_call in zip(reconstructor.tool_calls,
                                          expected_tool_calls):
        actual_call.id = expected_call.id

    assert reconstructor.tool_calls == expected_tool_calls

tests/reasoning/test_hunyuan_reasoning_parser.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@
3030
"reasoning_content": "This is a reasoning section",
3131
"content": None,
3232
}
33+
34+
# Test case: the reasoning text ends with a punctuation symbol ("!") and is
# immediately followed by START_RESPONSE; the expected reasoning keeps the
# symbol and the expected content is None.
COMPLETE_REASONING_WITH_SYMBOL = {
35+
"output": f"{START_REASONING}This is a reasoning section!{START_RESPONSE}",
36+
"reasoning_content": "This is a reasoning section!",
37+
"content": None,
38+
}
3339
NO_REASONING = {
3440
"output": "This is content",
3541
"reasoning_content": None,
@@ -70,6 +76,11 @@
7076
COMPLETE_REASONING,
7177
id="complete_reasoning",
7278
),
79+
pytest.param(
80+
False,
81+
COMPLETE_REASONING_WITH_SYMBOL,
82+
id="complete_reasoning_with_symbol",
83+
),
7384
pytest.param(
7485
False,
7586
NO_REASONING,

0 commit comments

Comments
 (0)