Skip to content

Commit 2b2b014

Browse files
committed
[Model] Hunyuan A13B tool parser refine and tests.
- Add tests for the Hunyuan A13B tool parser. - Fix mypy error in the tool parser. - Refine the reasoning parser test. - Refactor the tool parser streaming function. Signed-off-by: Asher Zhang <asherszhang@tencent.com>
1 parent c45f97d commit 2b2b014

File tree

3 files changed

+333
-275
lines changed

3 files changed

+333
-275
lines changed
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
import json
5+
from unittest.mock import MagicMock
6+
7+
import pytest
8+
9+
from tests.entrypoints.openai.tool_parsers.utils import (
10+
run_tool_extraction, run_tool_extraction_streaming)
11+
from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
12+
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
13+
14+
15+
def make_tool_call(name, arguments):
    """Build a ``ToolCall`` of type ``"function"`` for use as a test fixture.

    Args:
        name: Function name stored on the nested ``FunctionCall``.
        arguments: JSON-serializable value; it is encoded with
            ``json.dumps`` and stored as the call's ``arguments`` string.

    Returns:
        A ``ToolCall`` wrapping a ``FunctionCall`` with the given name and
        the serialized arguments.
    """
    serialized_arguments = json.dumps(arguments)
    function_call = FunctionCall(name=name, arguments=serialized_arguments)
    return ToolCall(type="function", function=function_call)
19+
20+
21+
# TODO: add test cases with a reasoning prefix and suffix.
22+
23+
24+
@pytest.mark.parametrize(
    "model_output,expected_tool_calls,expected_content",
    [
        # No tool call
        ("How can I help you today?", [], "How can I help you today?"),
        # Single tool call, no content
        (
            "<tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"San Francisco\", \"metric\": \"celsius\"}}]</tool_calls>",  # noqa: E501
            [
                make_tool_call("get_weather", {
                    "city": "San Francisco",
                    "metric": "celsius"
                })
            ],
            None),
        # Multiple tool calls
        (
            "<tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"San Francisco\", \"metric\": \"celsius\"}}, {\"name\": \"register_user\", \"arguments\": {\"name\": \"John Doe\", \"age\": 37, \"address\": {\"city\": \"San Francisco\", \"state\": \"CA\"}, \"role\": null, \"passed_test\": true, \"aliases\": [\"John\", \"Johnny\"]}}]</tool_calls>",  # noqa: E501
            [
                make_tool_call("get_weather", {
                    "city": "San Francisco",
                    "metric": "celsius"
                }),
                make_tool_call(
                    "register_user", {
                        "name": "John Doe",
                        "age": 37,
                        "address": {
                            "city": "San Francisco",
                            "state": "CA"
                        },
                        "role": None,
                        "passed_test": True,
                        "aliases": ["John", "Johnny"]
                    })
            ],
            None),
        # Content before tool call
        (
            "I will call the tool now. <tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Boston\"}}]</tool_calls>",  # noqa: E501
            [make_tool_call("get_weather", {"city": "Boston"})],
            "I will call the tool now. "),
        # Content after tool call (should be stripped)
        (
            "<tool_calls>[{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Seattle\"}}]</tool_calls>\nThank you!",  # noqa: E501
            [make_tool_call("get_weather", {"city": "Seattle"})],
            None),
    ])
def test_hunyuan_a13b_tool_parser_extract(model_output, expected_tool_calls,
                                          expected_content):
    """Check non-streaming extraction of the ``hunyuan_a13b`` tool parser.

    Args:
        model_output: Complete model output text handed to the parser.
        expected_tool_calls: ``ToolCall`` objects the parser should produce.
        expected_content: Expected remaining content, or ``None``.
    """
    mock_tokenizer = MagicMock()
    tool_parser: ToolParser = ToolParserManager.get_tool_parser(
        "hunyuan_a13b")(mock_tokenizer)
    content, tool_calls = run_tool_extraction(tool_parser,
                                              model_output,
                                              streaming=False)

    # Align the random ids so that only the payload is compared. zip() stops
    # at the shorter list, so a wrong number of parsed calls is reported by
    # the equality assertion below rather than by an IndexError raised here.
    for actual, expected in zip(tool_calls, expected_tool_calls):
        actual.id = expected.id

    assert tool_calls == expected_tool_calls
    assert content == expected_content
86+
87+
88+
# Streaming test: simulate incremental output
89+
@pytest.mark.parametrize("model_deltas,expected_tool_calls", [
    ([
        "<tool_calls>[{\"name\": \"get_weather\", ",
        "\"arguments\": {\"city\": \"San Francisco\", ",
        "\"metric\": \"celsius\"}}]", "</tool_calls>"
    ], [
        make_tool_call("get_weather", {
            "city": "San Francisco",
            "metric": "celsius"
        })
    ]),
    ([
        "<tool_calls>[{\"name\":", " \"get_weather\",", " \"arguments\":",
        " {\"city\": \"Boston\"}", "}]", "</tool_calls>"
    ], [make_tool_call("get_weather", {"city": "Boston"})]),
    ([
        "", "<tool_calls>[{\"name\":", " \"get_weather\",", " \"arguments\":",
        " {\"city\": \"Boston\"}", "}]", "</tool_calls>", "\n</answer>"
    ], [make_tool_call("get_weather", {"city": "Boston"})]),
])
def test_hunyuan_a13b_tool_parser_streaming(model_deltas, expected_tool_calls):
    """Check streaming extraction of the ``hunyuan_a13b`` tool parser.

    Args:
        model_deltas: Output fragments fed to the parser in order; the cases
            include an empty first delta and text after ``</tool_calls>``.
        expected_tool_calls: ``ToolCall`` objects expected after all deltas
            have been processed.
    """
    mock_tokenizer = MagicMock()

    tool_parser: ToolParser = ToolParserManager.get_tool_parser(
        "hunyuan_a13b")(mock_tokenizer)
    reconstructor = run_tool_extraction_streaming(
        tool_parser, model_deltas, assert_one_tool_per_delta=False)

    # Align the random ids so that only the payload is compared. zip() stops
    # at the shorter list, so a wrong number of reconstructed calls is
    # reported by the assertion below rather than by an IndexError here.
    for actual, expected in zip(reconstructor.tool_calls,
                                expected_tool_calls):
        actual.id = expected.id

    assert reconstructor.tool_calls == expected_tool_calls

tests/reasoning/test_hunyuan_reasoning_parser.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@
3030
"reasoning_content": "This is a reasoning section",
3131
"content": None,
3232
}
33+
34+
# Complete reasoning section whose text ends with a punctuation symbol ("!")
# directly before START_RESPONSE; no response content is expected.
COMPLETE_REASONING_WITH_SYMBOL = {
    "output": f"{START_REASONING}This is a reasoning section!{START_RESPONSE}",
    "reasoning_content": "This is a reasoning section!",
    "content": None,
}
3339
NO_REASONING = {
3440
"output": "This is content",
3541
"reasoning_content": None,
@@ -70,6 +76,11 @@
7076
COMPLETE_REASONING,
7177
id="complete_reasoning",
7278
),
79+
pytest.param(
80+
False,
81+
COMPLETE_REASONING_WITH_SYMBOL,
82+
id="complete_reasoning_with_symbol",
83+
),
7384
pytest.param(
7485
False,
7586
NO_REASONING,

0 commit comments

Comments
 (0)