Skip to content

Commit c78d1fb

Browse files
avignysjuxax
authored and committed
Bring in tests from vllm-project#19425
1 parent 976dc82 commit c78d1fb

File tree

1 file changed

+311
-0
lines changed

1 file changed

+311
-0
lines changed
Lines changed: 311 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,311 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
import json
4+
from collections.abc import Generator
5+
from typing import Optional
6+
7+
import partial_json_parser
8+
import pytest
9+
from partial_json_parser.core.options import Allow
10+
11+
from vllm.entrypoints.openai.protocol import (DeltaMessage, FunctionCall,
12+
ToolCall)
13+
from vllm.entrypoints.openai.tool_parsers import MistralToolParser
14+
from vllm.transformers_utils.detokenizer import detokenize_incrementally
15+
from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
16+
17+
# Model identifier passed to get_tokenizer() by the mistral_tokenizer fixture.
MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
18+
19+
20+
@pytest.fixture(scope="module")
def mistral_tokenizer():
    """Return the tokenizer for ``MODEL``, built once per test module."""
    tokenizer = get_tokenizer(tokenizer_name=MODEL)
    return tokenizer
23+
24+
25+
@pytest.fixture
def mistral_tool_parser(mistral_tokenizer):
    """Return a new ``MistralToolParser`` wrapping the shared tokenizer.

    The fixture is function-scoped, so every test receives its own parser
    instance.
    """
    parser = MistralToolParser(mistral_tokenizer)
    return parser
28+
29+
30+
def assert_tool_calls(actual_tool_calls: list[ToolCall],
                      expected_tool_calls: list[ToolCall]):
    """Assert that ``actual_tool_calls`` match ``expected_tool_calls``.

    Both lists must have the same length. Calls are compared pairwise, in
    order:

    * ``id`` is only checked to be a 9-character string; it is never
      compared with the expected call's id.
    * ``type`` must equal ``"function"``.
    * ``function`` must equal the expected call's ``function``.

    Raises:
        AssertionError: on the first check that fails.
    """
    assert len(actual_tool_calls) == len(expected_tool_calls), (
        f"expected {len(expected_tool_calls)} tool call(s), "
        f"got {len(actual_tool_calls)}")

    for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
                                                    expected_tool_calls):
        # IDs are validated for shape only, never for equality.
        assert isinstance(actual_tool_call.id, str)
        assert len(actual_tool_call.id) == 9

        assert actual_tool_call.type == "function"
        assert actual_tool_call.function == expected_tool_call.function, (
            f"got {actual_tool_call.function}, "
            f"expected {expected_tool_call.function}")
42+
43+
44+
def stream_delta_message_generator(
        mistral_tool_parser: MistralToolParser,
        mistral_tokenizer: AnyTokenizer,
        model_output: str) -> Generator[DeltaMessage, None, None]:
    """Replay ``model_output`` through the streaming parser token by token.

    The text is encoded without special tokens and then detokenized
    incrementally, one token per step. After each step the parser's
    ``extract_tool_calls_streaming`` is called with the previous text, the
    current text, the delta text and the matching token-id slices.

    Yields:
        Every truthy delta message returned by the parser, in order.
    """
    token_ids = mistral_tokenizer.encode(model_output,
                                         add_special_tokens=False)

    text_so_far = ""
    tokens_so_far = None
    prefix_offset = 0
    read_offset = 0
    for position, token_id in enumerate(token_ids):
        ids_before = token_ids[:position]
        ids_through = token_ids[:position + 1]

        detokenized = detokenize_incrementally(
            tokenizer=mistral_tokenizer,
            all_input_ids=ids_through,
            prev_tokens=tokens_so_far,
            prefix_offset=prefix_offset,
            read_offset=read_offset,
            skip_special_tokens=False,
            spaces_between_special_tokens=True,
        )
        new_tokens, delta_text, next_prefix, next_read = detokenized

        text_now = text_so_far + delta_text

        delta_message = mistral_tool_parser.extract_tool_calls_streaming(
            text_so_far,
            text_now,
            delta_text,
            ids_before,
            ids_through,
            [token_id],
            request=None,  # type: ignore[arg-type]
        )
        if delta_message:
            yield delta_message

        # Carry the detokenizer state over to the next token.
        text_so_far = text_now
        if tokens_so_far:
            tokens_so_far = tokens_so_far + new_tokens
        else:
            tokens_so_far = new_tokens
        prefix_offset, read_offset = next_prefix, next_read
90+
91+
92+
def test_extract_tool_calls_no_tools(mistral_tool_parser):
    """Plain text yields no tool calls and is returned as the content."""
    plain_text = "This is a test"
    result = mistral_tool_parser.extract_tool_calls(
        plain_text, request=None)  # type: ignore[arg-type]

    assert not result.tools_called
    assert result.tool_calls == []
    assert result.content == plain_text
99+
100+
101+
@pytest.mark.parametrize(
    ["model_output", "expected_tool_calls", "expected_content"],
    [
        pytest.param(
            '''[TOOL_CALLS][{"name": "add", "arguments":{"a": 3.5, "b": 4}}]''',  # noqa: E501
            [
                ToolCall(function=FunctionCall(
                    name="add",
                    arguments=json.dumps({"a": 3.5, "b": 4}),
                ))
            ],
            None,
            id="single_tool_add",
        ),
        pytest.param(
            '''[TOOL_CALLS] [{"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}}]''',  # noqa: E501
            [
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "San Francisco",
                        "state": "CA",
                        "unit": "celsius",
                    }),
                ))
            ],
            None,
            id="single_tool_weather",
        ),
        pytest.param(
            '''[TOOL_CALLS] [{"arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]''',  # noqa: E501
            [
                ToolCall(function=FunctionCall(
                    name="get_current_weather",
                    arguments=json.dumps({
                        "city": "San Francisco",
                        "state": "CA",
                        "unit": "celsius",
                    }),
                ))
            ],
            None,
            id="argument_before_name",
        ),
        pytest.param(
            '''[TOOL_CALLS] [{"arguments":{"name": "John Doe"}, "name": "get_age"}]''',  # noqa: E501
            [
                ToolCall(function=FunctionCall(
                    name="get_age",
                    arguments=json.dumps({"name": "John Doe"}),
                ))
            ],
            None,
            id="argument_before_name_and_name_in_argument",
        ),
    ],
)
def test_extract_tool_calls(mistral_tool_parser, model_output,
                            expected_tool_calls, expected_content):
    """Non-streaming extraction returns the expected calls and content.

    Each case feeds a complete ``[TOOL_CALLS]`` output to
    ``extract_tool_calls`` and checks that tools were called, that the
    calls match, and that ``content`` equals ``expected_content``.
    """
    result = mistral_tool_parser.extract_tool_calls(
        model_output, request=None)  # type: ignore[arg-type]
    assert result.tools_called

    assert_tool_calls(result.tool_calls, expected_tool_calls)

    assert result.content == expected_content
163+
164+
165+
@pytest.mark.parametrize(
    ids=[
        "no_tools",
        "single_tool_add",
        "single_tool_add_strings",
        "single_tool_weather",
        "argument_before_name",
        "argument_before_name_and_name_in_argument",
        "multiple_tools",
    ],
    argnames=["model_output", "expected_tool_calls", "expected_content"],
    argvalues=[
        ('''This is a test''', [], '''This is a test'''),
        (
            '''[TOOL_CALLS] [ {"name":"add" , "arguments" : {"a": 3, "b": 4} } ]''',  # noqa: E501
            [
                ToolCall(function=FunctionCall(name="add",
                                               arguments=json.dumps({
                                                   "a": 3,
                                                   "b": 4
                                               })))
            ],
            ""),
        (
            '''[TOOL_CALLS] [{"name": "add", "arguments":{"a": "3", "b": "4"}}]''',  # noqa: E501
            [
                ToolCall(function=FunctionCall(name="add",
                                               arguments=json.dumps({
                                                   "a": "3",
                                                   "b": "4"
                                               })))
            ],
            ""),
        (
            '''[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}}]''',  # noqa: E501
            [
                ToolCall(function=FunctionCall(name="get_current_weather",
                                               arguments=json.dumps(
                                                   {
                                                       "city": "San Francisco",
                                                       "state": "CA",
                                                       "unit": "celsius"
                                                   })))
            ],
            ""),
        (
            '''[TOOL_CALLS] [{"arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]''',  # noqa: E501
            [
                ToolCall(function=FunctionCall(name="get_current_weather",
                                               arguments=json.dumps(
                                                   {
                                                       "city": "San Francisco",
                                                       "state": "CA",
                                                       "unit": "celsius"
                                                   })))
            ],
            ''),
        (
            '''[TOOL_CALLS] [{"arguments": {"name": "John Doe"}, "name": "get_age"}]''',  # noqa: E501
            [
                ToolCall(function=FunctionCall(name="get_age",
                                               arguments=json.dumps({
                                                   "name":
                                                   "John Doe",
                                               })))
            ],
            ''),
        (
            '''[TOOL_CALLS][{"name": "add", "arguments": {"a": 3.5, "b": 4}}, {"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}}]''',  # noqa: E501
            [
                ToolCall(function=FunctionCall(name="add",
                                               arguments=json.dumps({
                                                   "a": 3.5,
                                                   "b": 4
                                               }))),
                ToolCall(function=FunctionCall(name="get_current_weather",
                                               arguments=json.dumps(
                                                   {
                                                       "city": "San Francisco",
                                                       "state": "CA",
                                                       "unit": "celsius"
                                                   })))
            ],
            ''),
    ],
)
def test_extract_tool_calls_streaming(mistral_tool_parser, mistral_tokenizer,
                                      model_output, expected_tool_calls,
                                      expected_content):
    """Streaming extraction reassembles the expected calls and content.

    ``model_output`` is replayed token by token through
    ``stream_delta_message_generator``. Streamed content is concatenated,
    and for each tool-call index the id, the name and the argument
    fragments are collected. The reassembled calls are then compared with
    ``expected_tool_calls`` and the content with ``expected_content``.
    """
    other_content: str = ''
    function_names: list[str] = []
    function_args_strs: list[str] = []
    tool_call_idx: int = -1
    tool_call_ids: list[Optional[str]] = []

    for delta_message in stream_delta_message_generator(
            mistral_tool_parser, mistral_tokenizer, model_output):
        # role should never be streamed from tool parser
        assert not delta_message.role

        if delta_message.content:
            other_content += delta_message.content

        streamed_tool_calls = delta_message.tool_calls

        if streamed_tool_calls:
            # make sure only one diff is present - correct even for parallel
            assert len(streamed_tool_calls) == 1
            tool_call = streamed_tool_calls[0]

            # if a new tool is being called, set up empty arguments
            if tool_call.index != tool_call_idx:
                tool_call_idx = tool_call.index
                function_args_strs.append("")
                tool_call_ids.append(None)

            # record the first id streamed for this tool call; later ids
            # for the same index are ignored
            if tool_call.id and not tool_call_ids[tool_call.index]:
                tool_call_ids[tool_call.index] = tool_call.id

            # if parts of the function start being streamed
            if tool_call.function:
                # every streamed name is appended as a new entry, so the
                # final comparison only passes if each name arrived whole
                # and exactly once
                if tool_call.function.name:
                    assert isinstance(tool_call.function.name, str)
                    function_names.append(tool_call.function.name)

                if tool_call.function.arguments:
                    # make sure they're a string and then add them to the list
                    assert isinstance(tool_call.function.arguments, str)

                    function_args_strs[
                        tool_call.index] += tool_call.function.arguments

    assert other_content == expected_content

    # NOTE(review): ensure_json presumably completes/normalizes the
    # accumulated argument text into valid JSON - confirm against the
    # partial_json_parser documentation.
    actual_tool_calls = [
        ToolCall(id=tool_call_id,
                 function=FunctionCall(
                     name=function_name,
                     arguments=partial_json_parser.ensure_json(
                         function_args_str, Allow.OBJ | Allow.STR)))
        for tool_call_id, function_name, function_args_str in zip(
            tool_call_ids, function_names, function_args_strs)
    ]
    assert_tool_calls(actual_tool_calls, expected_tool_calls)

0 commit comments

Comments
 (0)