From 554ff47dec44c6229f30cc9c486ac6690d2e55fd Mon Sep 17 00:00:00 2001 From: avigny <47987522+avigny@users.noreply.github.com> Date: Tue, 10 Jun 2025 14:38:18 +0200 Subject: [PATCH 01/12] Testing mistral tool parser Tests are similar as the ones added for Jamba models in https://github.com/vllm-project/vllm/pull/9154 Signed-off-by: avigny <47987522+avigny@users.noreply.github.com> --- tests/tool_use/test_mistral_tool_parser.py | 315 +++++++++++++++++++++ 1 file changed, 315 insertions(+) create mode 100644 tests/tool_use/test_mistral_tool_parser.py diff --git a/tests/tool_use/test_mistral_tool_parser.py b/tests/tool_use/test_mistral_tool_parser.py new file mode 100644 index 000000000000..1405bc9a8d5f --- /dev/null +++ b/tests/tool_use/test_mistral_tool_parser.py @@ -0,0 +1,315 @@ +# SPDX-License-Identifier: Apache-2.0 + +import json +from collections.abc import Generator +from typing import Optional + +import partial_json_parser +import pytest +from partial_json_parser.core.options import Allow + +from vllm.entrypoints.openai.protocol import (DeltaMessage, FunctionCall, + ToolCall) +from vllm.entrypoints.openai.tool_parsers import MistralToolParser +from vllm.transformers_utils.detokenizer import detokenize_incrementally +from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer + +MODEL = "mistralai/Mistral-7B-Instruct-v0.3" + + +@pytest.fixture(scope="module") +def mistral_tokenizer(): + return get_tokenizer(tokenizer_name=MODEL) + + +@pytest.fixture +def mistral_tool_parser(mistral_tokenizer): + return MistralToolParser(mistral_tokenizer) + + +def assert_tool_calls(actual_tool_calls: list[ToolCall], + expected_tool_calls: list[ToolCall]): + assert len(actual_tool_calls) == len(expected_tool_calls) + + for actual_tool_call, expected_tool_call in zip(actual_tool_calls, + expected_tool_calls): + assert isinstance(actual_tool_call.id, str) + assert len(actual_tool_call.id) == 9 + + assert actual_tool_call.type == "function" + assert actual_tool_call.function == expected_tool_call.function, f'got ${actual_tool_call.function}' + + +def stream_delta_message_generator( + mistral_tool_parser: MistralToolParser, mistral_tokenizer: AnyTokenizer, + model_output: str) -> Generator[DeltaMessage, None, None]: + all_token_ids = mistral_tokenizer.encode(model_output, + add_special_tokens=False) + + previous_text = "" + previous_tokens = None + prefix_offset = 0 + read_offset = 0 + for i, delta_token in enumerate(all_token_ids): + delta_token_ids = [delta_token] + previous_token_ids = all_token_ids[:i] + current_token_ids = all_token_ids[:i + 1] + + (new_tokens, delta_text, new_prefix_offset, + new_read_offset) = detokenize_incrementally( + tokenizer=mistral_tokenizer, + all_input_ids=current_token_ids, + prev_tokens=previous_tokens, + prefix_offset=prefix_offset, + read_offset=read_offset, + skip_special_tokens=False, + spaces_between_special_tokens=True, + ) + + current_text = previous_text + delta_text + + delta_message = mistral_tool_parser.extract_tool_calls_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + delta_token_ids, + request=None, # type: ignore[arg-type] + ) + if delta_message: + yield delta_message + + previous_text = current_text + previous_tokens = previous_tokens + new_tokens if previous_tokens\ + else new_tokens + prefix_offset = new_prefix_offset + read_offset = new_read_offset + + +def test_extract_tool_calls_no_tools(mistral_tool_parser): + model_output = "This is a test" + extracted_tool_calls = mistral_tool_parser.extract_tool_calls( + model_output, request=None) # type: ignore[arg-type] + assert not extracted_tool_calls.tools_called + assert extracted_tool_calls.tool_calls == [] + assert extracted_tool_calls.content == model_output + + +@pytest.mark.parametrize( + ids=[ + "single_tool_add", + "single_tool_weather", + "argument_before_name", + "argument_before_name_and_name_in_argument" + ], + argnames=["model_output", "expected_tool_calls", "expected_content"], + argvalues=[ + ( + '''[TOOL_CALLS][{"name": "add", "arguments":{"a": 3.5, "b": 4}}]''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="add", + arguments=json.dumps( + { + "a": 3.5, + "b": 4 + }))) + ], + None), + ( + '''[TOOL_CALLS] [{"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}}]''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="get_current_weather", + arguments=json.dumps( + { + "city": "San Francisco", + "state": "CA", + "unit": "celsius" + }))) + ], + None), + ( + '''[TOOL_CALLS] [{"arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="get_current_weather", + arguments=json.dumps( + { + "city": "San Francisco", + "state": "CA", + "unit": "celsius" + }))) + ], + None), + ( + '''[TOOL_CALLS] [{"arguments":{"name": "John Doe"}, "name": "get_age"}]''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="get_age", + arguments=json.dumps( + { + "name": "John Doe", + }))) + ], + None), + ], +) +def test_extract_tool_calls(mistral_tool_parser, model_output, + expected_tool_calls, expected_content): + extracted_tool_calls = mistral_tool_parser.extract_tool_calls( + model_output, request=None) # type: ignore[arg-type] + assert extracted_tool_calls.tools_called + + assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls) + + assert extracted_tool_calls.content == expected_content + + +@pytest.mark.parametrize( + ids=[ + "no_tools", + "single_tool_add", + "single_tool_add_strings", + "single_tool_weather", + "argument_before_name", + "argument_before_name_and_name_in_argument", + "multiple_tools", + ], + argnames=["model_output", "expected_tool_calls", "expected_content"], + argvalues=[ + ('''This is a test''', [], '''This is a test'''), + ( + '''[TOOL_CALLS] [ {"name":"add" , "arguments" : {"a": 3, "b": 4} } ]''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="add", + arguments=json.dumps( + { + "a": 3, + "b": 4 + }))) + ], + ""), + ( + '''[TOOL_CALLS] [{"name": "add", "arguments":{"a": "3", "b": "4"}}]''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="add", + arguments=json.dumps( + { + "a": "3", + "b": "4" + }))) + ], + ""), + ( + '''[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}}]''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="get_current_weather", + arguments=json.dumps( + { + "city": "San Francisco", + "state": "CA", + "unit": "celsius" + }))) + ], + ""), + ( + '''[TOOL_CALLS] [{"arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="get_current_weather", + arguments=json.dumps( + { + "city": "San Francisco", + "state": "CA", + "unit": "celsius" + }))) + ], + ''), + ( + '''[TOOL_CALLS] [{"arguments": {"name": "John Doe"}, "name": "get_age"}]''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="get_age", + arguments=json.dumps( + { + "name": "John Doe", + }))) + ], + ''), + ( + '''[TOOL_CALLS][{"name": "add", "arguments": {"a": 3.5, "b": 4}}, {"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}]''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="add", + arguments=json.dumps( + { + "a": 3.5, + "b": 4 + }))), + ToolCall(function=FunctionCall(name="get_current_weather", + arguments=json.dumps( + { + "city": "San Francisco", + "state": "CA", + "unit": "celsius" + }))) + ], + ''), + ], +) +def test_extract_tool_calls_streaming(mistral_tool_parser, mistral_tokenizer, + model_output, expected_tool_calls, + expected_content): + other_content: str = '' + function_names: list[str] = [] + function_args_strs: list[str] = [] + tool_call_idx: int = -1 + tool_call_ids: list[Optional[str]] = [] + + for delta_message in stream_delta_message_generator( + mistral_tool_parser, mistral_tokenizer, model_output): + # role should never be streamed from tool parser + assert not delta_message.role + + if delta_message.content: + other_content += delta_message.content + + streamed_tool_calls = delta_message.tool_calls + + if streamed_tool_calls and len(streamed_tool_calls) > 0: + # make sure only one diff is present - correct even for parallel + assert len(streamed_tool_calls) == 1 + tool_call = streamed_tool_calls[0] + + # if a new tool is being called, set up empty arguments + if tool_call.index != tool_call_idx: + tool_call_idx = tool_call.index + function_args_strs.append("") + tool_call_ids.append(None) + + # if a tool call ID is streamed, make sure one hasn't been already + if tool_call.id and not tool_call_ids[tool_call.index]: + tool_call_ids[tool_call.index] = tool_call.id + + # if parts of the function start being streamed + if tool_call.function: + # if the function name is defined, set it. it should be streamed + # IN ENTIRETY, exactly one time. + if tool_call.function.name: + assert isinstance(tool_call.function.name, str) + function_names.append(tool_call.function.name) + + if tool_call.function.arguments: + # make sure they're a string and then add them to the list + assert isinstance(tool_call.function.arguments, str) + + function_args_strs[ + tool_call.index] += tool_call.function.arguments + + assert other_content == expected_content + + actual_tool_calls = [ + ToolCall(id=tool_call_id, + function=FunctionCall( + name=function_name, + arguments=partial_json_parser.ensure_json( + function_args_str, Allow.OBJ | Allow.STR))) + for tool_call_id, function_name, function_args_str in zip( + tool_call_ids, function_names, function_args_strs) + ] + assert_tool_calls(actual_tool_calls, expected_tool_calls) From f1e1e3800aef0b64a6dd3a1ea5a36b937a61d76a Mon Sep 17 00:00:00 2001 From: avigny <47987522+avigny@users.noreply.github.com> Date: Tue, 10 Jun 2025 16:15:39 +0200 Subject: [PATCH 02/12] Update streaming tool parser for mistral Signed-off-by: avigny <47987522+avigny@users.noreply.github.com> --- .../tool_parsers/mistral_tool_parser.py | 361 ++++++++++-------- 1 file changed, 193 insertions(+), 168 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index fecad7e653ab..c2aec0f17dcb 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -4,11 +4,9 @@ from collections.abc import Sequence from random import choices from string import ascii_letters, digits -from typing import Union +from typing import Literal, Union -import partial_json_parser import regex as re -from partial_json_parser.core.options import Allow from pydantic import Field from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, @@ -61,11 +59,24 @@ def __init__(self, tokenizer: AnyTokenizer): # initialize properties used for state when parsing tool calls in # streaming mode - self.prev_tool_call_arr: list[dict] = [] + self.json_decoder: json.JSONDecoder = json.JSONDecoder() + self.tool_call_first_attribute_name: re.Pattern[str] = re.compile(r'.*\s*"name"\s*:\s*') + self.string_value_pattern: re.Pattern[str] = re.compile(r'\s*"(.*?)(? end_index: + additional_content = self.raw_tool_calls[end_index:] + except json.decoder.JSONDecodeError: + # we are in tool calls + pass + + if self.current_tool_name_finished and self.current_tool_arguments_finished: + if self.tools_parsing_finished: + return self._none_or_additional_content(additional_content) + # let's find the next tool starting position + next_tool_start_index = self._next_tool_starting_position() + if next_tool_start_index > 0: + self.current_tool_id += 1 + self.current_tool_start_index = next_tool_start_index + self.current_attribute_start_index = -1 + self.previous_attribute_end_index = 0 + self.current_tool_name_finished = False + self.current_tool_arguments_finished = False + + if self.current_tool_start_index >= len(self.raw_tool_calls): + # tool call has not started + return self._none_or_additional_content(additional_content) + raw_current_tool_call = self.raw_tool_calls[self.current_tool_start_index:] + + if self.current_element_streaming is None: + # we are waiting for the next argument to be parsed + match_name = self.tool_call_first_attribute_name.match(raw_current_tool_call) + match_arguments = self.tool_call_first_attribute_arguments.match(raw_current_tool_call) + if not self.current_tool_name_finished and match_name: + if self.previous_attribute_end_index is not None and match_name.end() <= self.previous_attribute_end_index: + return self._none_or_additional_content(additional_content) + self.current_element_streaming = "name" + self.current_attribute_start_index = match_name.end() + elif not self.current_tool_arguments_finished and match_arguments: + if self.previous_attribute_end_index is not None and match_arguments.end() <= self.previous_attribute_end_index: + return self._none_or_additional_content(additional_content) + self.current_element_streaming = "arguments" + self.current_attribute_start_index = match_arguments.end() - 1 # the `{` is the last IN the match part. We want it as the start index element + else: + # let's wait for more deltas + return self._none_or_additional_content(additional_content) - # replace BOT token with empty string, and convert single quotes - # to double to allow parsing as JSON since mistral uses single - # quotes instead of double for tool calls - parsable_arr = current_text.split(self.bot_token)[-1] - - # tool calls are generated in an array, so do partial JSON - # parsing on the entire array + if self.current_element_streaming == "name": try: - tool_call_arr: list[dict] = partial_json_parser.loads( - parsable_arr, flags) - except partial_json_parser.core.exceptions.MalformedJSON: - logger.debug('not enough tokens to parse into JSON yet') - return None - - # select as the current tool call the one we're on the state at - - current_tool_call: dict = tool_call_arr[self.current_tool_id] \ - if len(tool_call_arr) > 0 else {} - - # case -- if no tokens have been streamed for the tool, e.g. - # only the array brackets, stream nothing - if len(tool_call_arr) == 0: - return None - - # case: we are starting a new tool in the array - # -> array has > 0 length AND length has moved past cursor - elif (len(tool_call_arr) > 0 - and len(tool_call_arr) > self.current_tool_id + 1): - - # if we're moving on to a new call, first make sure we - # haven't missed anything in the previous one that was - # auto-generated due to JSON completions, but wasn't - # streamed to the client yet. - if self.current_tool_id >= 0: - diff: Union[str, None] = current_tool_call.get("arguments") - - if diff: - diff = json.dumps(diff, ensure_ascii=False).replace( - self.streamed_args_for_tool[self.current_tool_id], - "") - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=diff).model_dump( - exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += diff - else: - delta = None - else: - delta = None - # re-set stuff pertaining to progress in the current tool - self.current_tool_id = len(tool_call_arr) - 1 - self.current_tool_name_sent = False - self.streamed_args_for_tool.append("") - logger.debug("starting on new tool %d", self.current_tool_id) - return delta - - # case: update an existing tool - this is handled below - - # if the current tool name hasn't been sent, send if available - # - otherwise send nothing - if not self.current_tool_name_sent: - function_name = current_tool_call.get("name") - if function_name: - - delta = DeltaMessage(tool_calls=[ + function_name, name_end_index = self._extracted_complete_name(raw_current_tool_call, self.current_attribute_start_index) + except IndexError: + # name value has not started being generated + return self._none_or_additional_content(additional_content) + if function_name == "": + return self._none_or_additional_content(additional_content) + else: + assert name_end_index is not None # because the function name was successfully retrieved + self.current_tool_name_finished = True + self.current_element_streaming = None + self.current_attribute_start_index = -1 + self.previous_attribute_end_index = name_end_index + delta = DeltaMessage( + content=additional_content, + tool_calls=[ DeltaToolCall(index=self.current_tool_id, type="function", id=MistralToolCall.generate_random_id(), function=DeltaFunctionCall( name=function_name).model_dump( exclude_none=True)) + ]) + return delta + if self.current_element_streaming == "arguments": + try: + diff, arguments_end_index = self._extract_argument_fragment(raw_current_tool_call, self.current_attribute_start_index, delta_text) + self.current_tool_arguments_finished = arguments_end_index != -1 + if self.current_tool_arguments_finished: + self.current_element_streaming = None + self.current_attribute_start_index = -1 + self.previous_attribute_end_index = arguments_end_index + delta = DeltaMessage( + content=additional_content, + tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=diff). + model_dump(exclude_none=True)) ]) - self.current_tool_name_sent = True - else: - delta = None + return delta + except IndexError: + # arguments value has not started being generated + return self._none_or_additional_content(additional_content) - # now we know we're on the same tool call and we're streaming - # arguments - else: - prev_arguments = self.prev_tool_call_arr[ - self.current_tool_id].get("arguments") - cur_arguments = current_tool_call.get("arguments") - - new_text = delta_text.replace("\'", "\"") - if ('"}' in new_text): - new_text = new_text[:new_text.rindex('"}')] - - if not cur_arguments and not prev_arguments: - - delta = None - elif not cur_arguments and prev_arguments: - logger.error( - "INVARIANT - impossible to have arguments reset " - "mid-arguments") - delta = None - elif cur_arguments and not prev_arguments: - cur_arguments_json = json.dumps(cur_arguments, - ensure_ascii=False)[:-2] - logger.debug("finding %s in %s", new_text, - cur_arguments_json) - - if (new_text not in cur_arguments_json): - return None - arguments_delta = cur_arguments_json[:cur_arguments_json. - rindex(new_text) + - len(new_text)] - logger.debug("First tokens in arguments received: %s", - arguments_delta) - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=arguments_delta). - model_dump(exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += arguments_delta - - elif cur_arguments and prev_arguments: - cur_args_json = json.dumps(cur_arguments, - ensure_ascii=False) - prev_args_json = json.dumps(prev_arguments, - ensure_ascii=False) - logger.debug("Searching for diff between \n%s\n%s", - cur_args_json, prev_args_json) - - argument_diff = extract_intermediate_diff( - cur_args_json, prev_args_json) - logger.debug("got arguments diff: %s", argument_diff) - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=argument_diff).model_dump( - exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += argument_diff - else: - # try parsing it with regular JSON - if it works we're - # at the end, and we need to send the difference between - # tokens streamed so far and the valid JSON - delta = None - - # check to see if the name is defined and has been sent. if so, - # stream the name - otherwise keep waiting - # finish by setting old and returning None as base case - self.prev_tool_call_arr = tool_call_arr - return delta + def _extracted_complete_name(self, raw_current_tool_call: str, current_attribute_start_index: int) -> tuple[str, Union[int, None]]: + """ + Extract the complete function name from the current tool call. - except Exception: - logger.exception("Error trying to handle streaming tool call.") - logger.debug( - "Skipping chunk as a result of tool streaming extraction " - "error") - return None + Args: + raw_current_tool_call: The raw JSON string of the current tool call + current_attribute_start_index: The starting index of the name attribute in the raw_current_tool_call string + + Returns: + tuple: + - The function name, or "" if extraction failed + - The end index of the name in raw_current_tool_call, or None if extraction failed + """ + partial_name_value = raw_current_tool_call[current_attribute_start_index:] + if match := self.string_value_pattern.match(partial_name_value): + return match.group(1), match.end() + current_attribute_start_index + return "", None + + def _extract_argument_fragment(self, raw_current_tool_call: str, current_attribute_start_index: int, delta: str) -> tuple[str, int]: + """ + Extract the relevant argument fragment from the current streaming delta. + + Args: + raw_current_tool_call: The raw JSON string of the current tool call + current_attribute_start_index: The starting index of the arguments attribute in the raw string + delta: The new text added in this streaming step + + Returns: + tuple: + - The extracted argument diff text to be sent in the streaming response + - The end index of the arguments in the raw string, or -1 if not yet complete + """ + partial_arguments_value = raw_current_tool_call[current_attribute_start_index:] + try: + _, end_index = self.json_decoder.raw_decode(partial_arguments_value) + return delta[:len(delta) + end_index - len(partial_arguments_value)], current_attribute_start_index + end_index + except json.decoder.JSONDecodeError: + # The arguments object is not complete + + # delta contains data from before the argument start + if len(delta) > len(partial_arguments_value): + return delta[-len(partial_arguments_value):], -1 + + # We can send the whole delta + return delta, -1 + + def _next_tool_starting_position(self) -> int: + """ + Find the starting position of the next tool in the raw tool calls string. + + Returns: + The index position where the next tool starts, or -1 if no next tool is found yet + """ + assert self.current_tool_start_index >= 0 + current_tool_call = self.raw_tool_calls[self.current_tool_start_index:] + try: + _, end_index = self.json_decoder.raw_decode(current_tool_call) + return self.current_tool_start_index + end_index + current_tool_call[end_index:].find("{") + except json.decoder.JSONDecodeError: + # The current tool object is not yet closed + return -1 + except IndexError: + # The next tool has not started yet + # and the delta just closes the current tool call + return -1 + + def _none_or_additional_content(self, additional_content: str) -> Union[DeltaMessage, None]: + """ + Create a DeltaMessage with additional content if present, otherwise return None. + + Args: + additional_content: The text content to include in the message + + Returns: + A DeltaMessage with the additional content, or None if no content is provided + """ + if additional_content: + return DeltaMessage(content=additional_content) + return None \ No newline at end of file From 92601d984cff280248f0593fa4b76347aa54d7fb Mon Sep 17 00:00:00 2001 From: avigny <47987522+avigny@users.noreply.github.com> Date: Tue, 10 Jun 2025 17:12:02 +0200 Subject: [PATCH 03/12] Removing unneeded check Signed-off-by: avigny <47987522+avigny@users.noreply.github.com> --- vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index c2aec0f17dcb..0a2c4bdccbfe 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -226,12 +226,12 @@ def extract_tool_calls_streaming( match_name = self.tool_call_first_attribute_name.match(raw_current_tool_call) match_arguments = self.tool_call_first_attribute_arguments.match(raw_current_tool_call) if not self.current_tool_name_finished and match_name: - if self.previous_attribute_end_index is not None and match_name.end() <= self.previous_attribute_end_index: + if match_name.end() <= self.previous_attribute_end_index: return self._none_or_additional_content(additional_content) self.current_element_streaming = "name" self.current_attribute_start_index = match_name.end() elif not self.current_tool_arguments_finished and match_arguments: - if self.previous_attribute_end_index is not None and match_arguments.end() <= self.previous_attribute_end_index: + if match_arguments.end() <= self.previous_attribute_end_index: return self._none_or_additional_content(additional_content) self.current_element_streaming = "arguments" self.current_attribute_start_index = match_arguments.end() - 1 # the `{` is the last IN the match part. We want it as the start index element From d6d17c15a639dc6a4d741c9284b40981acaf631a Mon Sep 17 00:00:00 2001 From: avigny <47987522+avigny@users.noreply.github.com> Date: Tue, 10 Jun 2025 17:40:24 +0200 Subject: [PATCH 04/12] repair ruff pre-commit Signed-off-by: avigny <47987522+avigny@users.noreply.github.com> --- tests/tool_use/test_mistral_tool_parser.py | 74 ++++--- .../tool_parsers/mistral_tool_parser.py | 184 +++++++++++------- 2 files changed, 151 insertions(+), 107 deletions(-) diff --git a/tests/tool_use/test_mistral_tool_parser.py b/tests/tool_use/test_mistral_tool_parser.py index 1405bc9a8d5f..b9bbdef57d97 100644 --- a/tests/tool_use/test_mistral_tool_parser.py +++ b/tests/tool_use/test_mistral_tool_parser.py @@ -37,14 +37,16 @@ def assert_tool_calls(actual_tool_calls: list[ToolCall], assert len(actual_tool_call.id) == 9 assert actual_tool_call.type == "function" - assert actual_tool_call.function == expected_tool_call.function, f'got ${actual_tool_call.function}' + assert actual_tool_call.function == expected_tool_call.function, ( + f'got ${actual_tool_call.function}') def stream_delta_message_generator( - mistral_tool_parser: MistralToolParser, mistral_tokenizer: AnyTokenizer, + mistral_tool_parser: MistralToolParser, + mistral_tokenizer: AnyTokenizer, model_output: str) -> Generator[DeltaMessage, None, None]: all_token_ids = mistral_tokenizer.encode(model_output, - add_special_tokens=False) + add_special_tokens=False) previous_text = "" previous_tokens = None @@ -98,9 +100,7 @@ def test_extract_tool_calls_no_tools(mistral_tool_parser): @pytest.mark.parametrize( ids=[ - "single_tool_add", - "single_tool_weather", - "argument_before_name", + "single_tool_add", "single_tool_weather", "argument_before_name", "argument_before_name_and_name_in_argument" ], argnames=["model_output", "expected_tool_calls", "expected_content"], @@ -109,11 +109,10 @@ def test_extract_tool_calls_no_tools(mistral_tool_parser): '''[TOOL_CALLS][{"name": "add", "arguments":{"a": 3.5, "b": 4}}]''', # noqa: E501 [ ToolCall(function=FunctionCall(name="add", - arguments=json.dumps( - { - "a": 3.5, - "b": 4 - }))) + arguments=json.dumps({ + "a": 3.5, + "b": 4 + }))) ], None), ( @@ -125,7 +124,7 @@ def test_extract_tool_calls_no_tools(mistral_tool_parser): "city": "San Francisco", "state": "CA", "unit": "celsius" - }))) + }))) ], None), ( @@ -137,17 +136,17 @@ def test_extract_tool_calls_no_tools(mistral_tool_parser): "city": "San Francisco", "state": "CA", "unit": "celsius" - }))) + }))) ], None), ( '''[TOOL_CALLS] [{"arguments":{"name": "John Doe"}, "name": "get_age"}]''', # noqa: E501 [ ToolCall(function=FunctionCall(name="get_age", - arguments=json.dumps( - { - "name": "John Doe", - }))) + arguments=json.dumps({ + "name": + "John Doe", + }))) ], None), ], @@ -180,22 +179,20 @@ def test_extract_tool_calls(mistral_tool_parser, model_output, '''[TOOL_CALLS] [ {"name":"add" , "arguments" : {"a": 3, "b": 4} } ]''', # noqa: E501 [ ToolCall(function=FunctionCall(name="add", - arguments=json.dumps( - { - "a": 3, - "b": 4 - }))) + arguments=json.dumps({ + "a": 3, + "b": 4 + }))) ], ""), ( '''[TOOL_CALLS] [{"name": "add", "arguments":{"a": "3", "b": "4"}}]''', # noqa: E501 [ ToolCall(function=FunctionCall(name="add", - arguments=json.dumps( - { - "a": "3", - "b": "4" - }))) + arguments=json.dumps({ + "a": "3", + "b": "4" + }))) ], ""), ( @@ -207,7 +204,7 @@ def test_extract_tool_calls(mistral_tool_parser, model_output, "city": "San Francisco", "state": "CA", "unit": "celsius" - }))) + }))) ], ""), ( @@ -219,35 +216,34 @@ def test_extract_tool_calls(mistral_tool_parser, model_output, "city": "San Francisco", "state": "CA", "unit": "celsius" - }))) + }))) ], ''), ( '''[TOOL_CALLS] [{"arguments": {"name": "John Doe"}, "name": "get_age"}]''', # noqa: E501 [ ToolCall(function=FunctionCall(name="get_age", - arguments=json.dumps( - { - "name": "John Doe", - }))) + arguments=json.dumps({ + "name": + "John Doe", + }))) ], ''), ( '''[TOOL_CALLS][{"name": "add", "arguments": {"a": 3.5, "b": 4}}, {"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}]''', # noqa: E501 [ ToolCall(function=FunctionCall(name="add", - arguments=json.dumps( - { - "a": 3.5, - "b": 4 - }))), + arguments=json.dumps({ + "a": 3.5, + "b": 4 + }))), ToolCall(function=FunctionCall(name="get_current_weather", arguments=json.dumps( { "city": "San Francisco", "state": "CA", "unit": "celsius" - }))) + }))) ], ''), ], diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index 0a2c4bdccbfe..6bf96fdbe857 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -16,8 +16,6 @@ FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ToolParser, ToolParserManager) -from vllm.entrypoints.openai.tool_parsers.utils import ( - extract_intermediate_diff) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer @@ -60,20 +58,27 @@ def __init__(self, tokenizer: AnyTokenizer): # initialize properties used for state when parsing tool calls in # streaming mode self.json_decoder: json.JSONDecoder = json.JSONDecoder() - self.tool_call_first_attribute_name: re.Pattern[str] = re.compile(r'.*\s*"name"\s*:\s*') - self.string_value_pattern: re.Pattern[str] = re.compile(r'\s*"(.*?)(? end_index: additional_content = self.raw_tool_calls[end_index:] @@ -203,7 +212,8 @@ def extract_tool_calls_streaming( # we are in tool calls pass - if self.current_tool_name_finished and self.current_tool_arguments_finished: + if (self.current_tool_name_finished + and self.current_tool_arguments_finished): if self.tools_parsing_finished: return self._none_or_additional_content(additional_content) # let's find the next tool starting position @@ -219,12 +229,15 @@ def extract_tool_calls_streaming( if self.current_tool_start_index >= len(self.raw_tool_calls): # tool call has not started return self._none_or_additional_content(additional_content) - raw_current_tool_call = self.raw_tool_calls[self.current_tool_start_index:] + raw_current_tool_call = self.raw_tool_calls[self. + current_tool_start_index:] if self.current_element_streaming is None: # we are waiting for the next argument to be parsed - match_name = self.tool_call_first_attribute_name.match(raw_current_tool_call) - match_arguments = self.tool_call_first_attribute_arguments.match(raw_current_tool_call) + match_name = self.tool_call_first_attribute_name.match( + raw_current_tool_call) + match_arguments = self.tool_call_first_attribute_arguments.match( + raw_current_tool_call) if not self.current_tool_name_finished and match_name: if match_name.end() <= self.previous_attribute_end_index: return self._none_or_additional_content(additional_content) @@ -234,21 +247,26 @@ def extract_tool_calls_streaming( if match_arguments.end() <= self.previous_attribute_end_index: return self._none_or_additional_content(additional_content) self.current_element_streaming = "arguments" - self.current_attribute_start_index = match_arguments.end() - 1 # the `{` is the last IN the match part. We want it as the start index element + self.current_attribute_start_index = match_arguments.end() - 1 + # the `{` is the last IN the match part. + # We want it as the start index element else: # let's wait for more deltas return self._none_or_additional_content(additional_content) if self.current_element_streaming == "name": try: - function_name, name_end_index = self._extracted_complete_name(raw_current_tool_call, self.current_attribute_start_index) + function_name, name_end_index = self._extracted_complete_name( + raw_current_tool_call, self.current_attribute_start_index) except IndexError: # name value has not started being generated return self._none_or_additional_content(additional_content) if function_name == "": return self._none_or_additional_content(additional_content) else: - assert name_end_index is not None # because the function name was successfully retrieved + assert name_end_index is not None + # because the function name was successfully retrieved + self.current_tool_name_finished = True self.current_element_streaming = None self.current_attribute_start_index = -1 @@ -256,75 +274,99 @@ def extract_tool_calls_streaming( delta = DeltaMessage( content=additional_content, tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - type="function", - id=MistralToolCall.generate_random_id(), - function=DeltaFunctionCall( - name=function_name).model_dump( - exclude_none=True)) - ]) + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=MistralToolCall.generate_random_id(), + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True), + ) + ], + ) return delta if self.current_element_streaming == "arguments": - try: - diff, arguments_end_index = self._extract_argument_fragment(raw_current_tool_call, self.current_attribute_start_index, delta_text) - self.current_tool_arguments_finished = arguments_end_index != -1 - if self.current_tool_arguments_finished: - self.current_element_streaming = None - self.current_attribute_start_index = -1 - self.previous_attribute_end_index = arguments_end_index - delta = DeltaMessage( - content=additional_content, - tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=diff). - model_dump(exclude_none=True)) - ]) - return delta - except IndexError: - # arguments value has not started being generated - return self._none_or_additional_content(additional_content) - + try: + diff, arguments_end_index = self._extract_argument_fragment( + raw_current_tool_call, + self.current_attribute_start_index, + delta_text, + ) + self.current_tool_arguments_finished = arguments_end_index != -1 + if self.current_tool_arguments_finished: + self.current_element_streaming = None + self.current_attribute_start_index = -1 + self.previous_attribute_end_index = arguments_end_index + delta = DeltaMessage( + content=additional_content, + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=diff).model_dump(exclude_none=True), + ) + ], + ) + return delta + except IndexError: + # arguments value has not started being generated + return self._none_or_additional_content(additional_content) - def _extracted_complete_name(self, raw_current_tool_call: str, current_attribute_start_index: int) -> tuple[str, Union[int, None]]: + def _extracted_complete_name( + self, raw_current_tool_call: str, + current_attribute_start_index: int + ) -> tuple[str, Union[int, None]]: """ Extract the complete function name from the current tool call. Args: raw_current_tool_call: The raw JSON string of the current tool call - current_attribute_start_index: The starting index of the name attribute in the raw_current_tool_call string + current_attribute_start_index: The starting index of the + name attribute in the raw_current_tool_call string Returns: tuple: - - The function name, or "" if extraction failed - - The end index of the name in raw_current_tool_call, or None if extraction failed + - The function name, or "" if extraction failed + - The end index of the name in raw_current_tool_call, + or None if extraction failed """ - partial_name_value = raw_current_tool_call[current_attribute_start_index:] + partial_name_value = raw_current_tool_call[ + current_attribute_start_index:] if match := self.string_value_pattern.match(partial_name_value): return match.group(1), match.end() + current_attribute_start_index return "", None - def _extract_argument_fragment(self, raw_current_tool_call: str, current_attribute_start_index: int, delta: str) -> tuple[str, int]: + def _extract_argument_fragment(self, raw_current_tool_call: str, + current_attribute_start_index: int, + delta: str) -> tuple[str, int]: """ Extract the relevant argument fragment from the current streaming delta. Args: raw_current_tool_call: The raw JSON string of the current tool call - current_attribute_start_index: The starting index of the arguments attribute in the raw string + current_attribute_start_index: The starting index + of the arguments attribute in the raw string delta: The new text added in this streaming step Returns: tuple: - - The extracted argument diff text to be sent in the streaming response - - The end index of the arguments in the raw string, or -1 if not yet complete + - The extracted argument diff text + to be sent in the streaming response + - The end index of the arguments in the raw string, + or -1 if not yet complete """ - partial_arguments_value = raw_current_tool_call[current_attribute_start_index:] + partial_arguments_value = raw_current_tool_call[ + current_attribute_start_index:] try: - _, end_index = self.json_decoder.raw_decode(partial_arguments_value) - return delta[:len(delta) + end_index - len(partial_arguments_value)], current_attribute_start_index + end_index + _, end_index = self.json_decoder.raw_decode( + partial_arguments_value) + return ( + delta[:len(delta) + end_index - len(partial_arguments_value)], + current_attribute_start_index + end_index, + ) except json.decoder.JSONDecodeError: # The arguments object is not complete - + # delta contains data from before the argument start if len(delta) > len(partial_arguments_value): return delta[-len(partial_arguments_value):], -1 @@ -334,16 +376,19 @@ def _extract_argument_fragment(self, raw_current_tool_call: str, current_attribu def _next_tool_starting_position(self) -> int: """ - Find the starting position of the next tool in the raw tool calls string. + Find the starting position of the next tool + in the raw tool calls string. Returns: - The index position where the next tool starts, or -1 if no next tool is found yet + The index position where the next tool starts, + or -1 if no next tool is found yet """ assert self.current_tool_start_index >= 0 current_tool_call = self.raw_tool_calls[self.current_tool_start_index:] try: _, end_index = self.json_decoder.raw_decode(current_tool_call) - return self.current_tool_start_index + end_index + current_tool_call[end_index:].find("{") + return (self.current_tool_start_index + end_index + + current_tool_call[end_index:].find("{")) except json.decoder.JSONDecodeError: # The current tool object is not yet closed return -1 @@ -351,17 +396,20 @@ def _next_tool_starting_position(self) -> int: # The next tool has not started yet # and the delta just closes the current tool call return -1 - - def _none_or_additional_content(self, additional_content: str) -> Union[DeltaMessage, None]: + + def _none_or_additional_content( + self, additional_content: str) -> Union[DeltaMessage, None]: """ - Create a DeltaMessage with additional content if present, otherwise return None. + Create a DeltaMessage with additional content if present, + otherwise return None. Args: additional_content: The text content to include in the message Returns: - A DeltaMessage with the additional content, or None if no content is provided + A DeltaMessage with the additional content, + or None if no content is provided """ if additional_content: return DeltaMessage(content=additional_content) - return None \ No newline at end of file + return None From dd788b6e4c97142d1043a8ba7b495c951caba7fc Mon Sep 17 00:00:00 2001 From: avigny <47987522+avigny@users.noreply.github.com> Date: Mon, 7 Jul 2025 01:00:12 +0200 Subject: [PATCH 05/12] Adding SPDX header Signed-off-by: avigny <47987522+avigny@users.noreply.github.com> --- tests/tool_use/test_mistral_tool_parser.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/tool_use/test_mistral_tool_parser.py b/tests/tool_use/test_mistral_tool_parser.py index b9bbdef57d97..f2884255032b 100644 --- a/tests/tool_use/test_mistral_tool_parser.py +++ b/tests/tool_use/test_mistral_tool_parser.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import json from collections.abc import Generator From 170a04e47a3b7b3286e2e8ff36014a230b61b07e Mon Sep 17 00:00:00 2001 From: avigny <47987522+avigny@users.noreply.github.com> Date: Mon, 7 Jul 2025 01:02:46 +0200 Subject: [PATCH 06/12] Update non streaming tests with v11 tokenizer and tool call format Signed-off-by: avigny <47987522+avigny@users.noreply.github.com> --- tests/tool_use/test_mistral_tool_parser.py | 113 ++++++++++++++++++++- 1 file changed, 111 insertions(+), 2 deletions(-) diff --git a/tests/tool_use/test_mistral_tool_parser.py b/tests/tool_use/test_mistral_tool_parser.py index f2884255032b..47cc2905aa8a 100644 --- a/tests/tool_use/test_mistral_tool_parser.py +++ b/tests/tool_use/test_mistral_tool_parser.py @@ -15,18 +15,25 @@ from vllm.transformers_utils.detokenizer import detokenize_incrementally from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer -MODEL = "mistralai/Mistral-7B-Instruct-v0.3" - @pytest.fixture(scope="module") def mistral_tokenizer(): + MODEL = "mistralai/Mistral-7B-Instruct-v0.3" return get_tokenizer(tokenizer_name=MODEL) +@pytest.fixture(scope="module") +def mistral_v11_tokenizer(): + MODEL = "mistralai/Mistral-Small-3.2-24B-Instruct-2506" + return get_tokenizer(tokenizer_name=MODEL, tokenizer_mode="mistral") @pytest.fixture def mistral_tool_parser(mistral_tokenizer): return MistralToolParser(mistral_tokenizer) +@pytest.fixture +def mistral_v11_tool_parser(mistral_v11_tokenizer): + return MistralToolParser(mistral_v11_tokenizer) + def assert_tool_calls(actual_tool_calls: list[ToolCall], expected_tool_calls: list[ToolCall]): @@ -162,6 +169,108 @@ def test_extract_tool_calls(mistral_tool_parser, model_output, assert extracted_tool_calls.content == expected_content +@pytest.mark.parametrize( + ids=[ + "single_tool_add", "single_tool_weather", + ], + argnames=["model_output", "expected_tool_calls", "expected_content"], + argvalues=[ + ( + '''[TOOL_CALLS]add{"a": 3.5, "b": 4}''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="add", + arguments=json.dumps({ + "a": 3.5, + "b": 4 + }))) + ], + None), + ( + '''[TOOL_CALLS]get_current_weather{"city": "San Francisco", "state": "CA", "unit": "celsius"}''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="get_current_weather", + arguments=json.dumps( + { + "city": "San Francisco", + "state": "CA", + "unit": "celsius" + }))) + ], + None) + ] +) +def test_extract_tool_calls_v11_tokenizer(mistral_v11_tool_parser, model_output, + expected_tool_calls, expected_content): + extracted_tool_calls = mistral_v11_tool_parser.extract_tool_calls( + model_output, request=None) # type: ignore[arg-type] + assert extracted_tool_calls.tools_called + + assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls) + + assert extracted_tool_calls.content == expected_content + + +def _test_extract_tool_calls_streaming(tool_parser, tokenizer, + model_output, expected_tool_calls, + expected_content): + other_content: str = '' + function_names: list[str] = [] + function_args_strs: list[str] = [] + tool_call_idx: int = -1 + tool_call_ids: list[Optional[str]] = [] + + for delta_message in stream_delta_message_generator( + tool_parser, tokenizer, model_output): + # role should never be streamed from tool parser + assert not delta_message.role + + if delta_message.content: + other_content += delta_message.content + + streamed_tool_calls = delta_message.tool_calls + + if streamed_tool_calls and len(streamed_tool_calls) > 0: + # make sure only one diff is present - correct even for parallel + assert len(streamed_tool_calls) == 1 + tool_call = streamed_tool_calls[0] + + # if a new tool is being called, set up empty arguments + if tool_call.index != tool_call_idx: + tool_call_idx = tool_call.index + function_args_strs.append("") + tool_call_ids.append(None) + + # if a tool call ID is streamed, make sure one hasn't been already + if tool_call.id and not tool_call_ids[tool_call.index]: + tool_call_ids[tool_call.index] = tool_call.id + + # if parts of the function start being streamed + if tool_call.function: + # if the function name is defined, set it. it should be streamed + # IN ENTIRETY, exactly one time. + if tool_call.function.name: + assert isinstance(tool_call.function.name, str) + function_names.append(tool_call.function.name) + + if tool_call.function.arguments: + # make sure they're a string and then add them to the list + assert isinstance(tool_call.function.arguments, str) + + function_args_strs[ + tool_call.index] += tool_call.function.arguments + + assert other_content == expected_content + + actual_tool_calls = [ + ToolCall(id=tool_call_id, + function=FunctionCall( + name=function_name, + arguments=partial_json_parser.ensure_json( + function_args_str, Allow.OBJ | Allow.STR))) + for tool_call_id, function_name, function_args_str in zip( + tool_call_ids, function_names, function_args_strs) + ] + assert_tool_calls(actual_tool_calls, expected_tool_calls) @pytest.mark.parametrize( ids=[ From aca50e6002d051ee6b3491e5fa5c8869a4801f2b Mon Sep 17 00:00:00 2001 From: avigny <47987522+avigny@users.noreply.github.com> Date: Tue, 8 Jul 2025 18:53:11 +0200 Subject: [PATCH 07/12] Updating tests for new v11 mistral tokenizer and new tool call format in the model output Signed-off-by: avigny <47987522+avigny@users.noreply.github.com> --- tests/tool_use/test_mistral_tool_parser.py | 457 +++++++++++++-------- 1 file changed, 282 insertions(+), 175 deletions(-) diff --git a/tests/tool_use/test_mistral_tool_parser.py b/tests/tool_use/test_mistral_tool_parser.py index 47cc2905aa8a..37d146fe5585 100644 --- a/tests/tool_use/test_mistral_tool_parser.py +++ b/tests/tool_use/test_mistral_tool_parser.py @@ -7,13 +7,16 @@ import partial_json_parser import pytest +from mistral_common.protocol.instruct.messages import AssistantMessage +from mistral_common.protocol.instruct.tool_calls import FunctionCall, ToolCall +from mistral_common.tokens.instruct.request import InstructRequest from partial_json_parser.core.options import Allow -from vllm.entrypoints.openai.protocol import (DeltaMessage, FunctionCall, - ToolCall) +from vllm.entrypoints.openai.protocol import DeltaMessage from vllm.entrypoints.openai.tool_parsers import MistralToolParser from vllm.transformers_utils.detokenizer import detokenize_incrementally -from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer +from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, + get_tokenizer) @pytest.fixture(scope="module") @@ -21,15 +24,18 @@ def mistral_tokenizer(): MODEL = "mistralai/Mistral-7B-Instruct-v0.3" return get_tokenizer(tokenizer_name=MODEL) + @pytest.fixture(scope="module") def mistral_v11_tokenizer(): MODEL = "mistralai/Mistral-Small-3.2-24B-Instruct-2506" return get_tokenizer(tokenizer_name=MODEL, tokenizer_mode="mistral") + @pytest.fixture def mistral_tool_parser(mistral_tokenizer): return MistralToolParser(mistral_tokenizer) + @pytest.fixture def mistral_v11_tool_parser(mistral_v11_tokenizer): return MistralToolParser(mistral_v11_tokenizer) @@ -45,16 +51,36 @@ def assert_tool_calls(actual_tool_calls: list[ToolCall], assert len(actual_tool_call.id) == 9 assert actual_tool_call.type == "function" - assert actual_tool_call.function == expected_tool_call.function, ( - f'got ${actual_tool_call.function}') + assert ( + actual_tool_call.function.name == expected_tool_call.function.name + ), f"got wrong function name:${actual_tool_call.function.name}" + assert ( + actual_tool_call.function.arguments == + expected_tool_call.function.arguments + ), f"got wrong function argument:${actual_tool_call.function.arguments}" def stream_delta_message_generator( - mistral_tool_parser: MistralToolParser, - mistral_tokenizer: AnyTokenizer, - model_output: str) -> Generator[DeltaMessage, None, None]: - all_token_ids = mistral_tokenizer.encode(model_output, - add_special_tokens=False) + mistral_tool_parser: MistralToolParser, + mistral_tokenizer: AnyTokenizer, + model_output: Optional[str], + tools: Optional[list[tuple[str, str]]], +) -> Generator[DeltaMessage, None, None]: + if isinstance(mistral_tokenizer, MistralTokenizer): + assert tools is not None + assistant_msg = AssistantMessage(tool_calls=[ + ToolCall(function=FunctionCall( + name=name, + arguments=arg, + )) for (name, arg) in tools + ], ) + request = InstructRequest(messages=[assistant_msg], ) + all_token_ids = mistral_tokenizer.instruct.encode_instruct( + request).tokens + else: + assert model_output is not None + all_token_ids = mistral_tokenizer.encode(model_output, + add_special_tokens=False) previous_text = "" previous_tokens = None @@ -66,15 +92,16 @@ def stream_delta_message_generator( current_token_ids = all_token_ids[:i + 1] (new_tokens, delta_text, new_prefix_offset, - new_read_offset) = detokenize_incrementally( + new_read_offset) = (detokenize_incrementally( tokenizer=mistral_tokenizer, all_input_ids=current_token_ids, prev_tokens=previous_tokens, prefix_offset=prefix_offset, read_offset=read_offset, - skip_special_tokens=False, + skip_special_tokens=isinstance(mistral_tokenizer, + MistralTokenizer), spaces_between_special_tokens=True, - ) + )) current_text = previous_text + delta_text @@ -91,8 +118,8 @@ def stream_delta_message_generator( yield delta_message previous_text = current_text - previous_tokens = previous_tokens + new_tokens if previous_tokens\ - else new_tokens + previous_tokens = (previous_tokens + + new_tokens if previous_tokens else new_tokens) prefix_offset = new_prefix_offset read_offset = new_read_offset @@ -108,13 +135,15 @@ def test_extract_tool_calls_no_tools(mistral_tool_parser): @pytest.mark.parametrize( ids=[ - "single_tool_add", "single_tool_weather", "argument_before_name", - "argument_before_name_and_name_in_argument" + "single_tool_add", + "single_tool_weather", + "argument_before_name", + "argument_before_name_and_name_in_argument", ], argnames=["model_output", "expected_tool_calls", "expected_content"], argvalues=[ ( - '''[TOOL_CALLS][{"name": "add", "arguments":{"a": 3.5, "b": 4}}]''', # noqa: E501 + """[TOOL_CALLS][{"name": "add", "arguments":{"a": 3.5, "b": 4}}]""", # noqa: E501 [ ToolCall(function=FunctionCall(name="add", arguments=json.dumps({ @@ -122,45 +151,53 @@ def test_extract_tool_calls_no_tools(mistral_tool_parser): "b": 4 }))) ], - None), + None, + ), ( - '''[TOOL_CALLS] [{"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}}]''', # noqa: E501 + """[TOOL_CALLS] [{"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""", # noqa: E501 [ - ToolCall(function=FunctionCall(name="get_current_weather", - arguments=json.dumps( - { - "city": "San Francisco", - "state": "CA", - "unit": "celsius" - }))) + ToolCall(function=FunctionCall( + name="get_current_weather", + arguments=json.dumps({ + "city": "San Francisco", + "state": "CA", + "unit": "celsius" + }), + )) ], - None), + None, + ), ( - '''[TOOL_CALLS] [{"arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]''', # noqa: E501 + """[TOOL_CALLS] [{"arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""", # noqa: E501 [ - ToolCall(function=FunctionCall(name="get_current_weather", - arguments=json.dumps( - { - "city": "San Francisco", - "state": "CA", - "unit": "celsius" - }))) + ToolCall(function=FunctionCall( + name="get_current_weather", + arguments=json.dumps({ + "city": "San Francisco", + "state": "CA", + "unit": "celsius" + }), + )) ], - None), + None, + ), ( - '''[TOOL_CALLS] [{"arguments":{"name": "John Doe"}, "name": "get_age"}]''', # noqa: E501 + """[TOOL_CALLS] [{"arguments":{"name": "John Doe"}, "name": "get_age"}]""", # noqa: E501 [ - ToolCall(function=FunctionCall(name="get_age", - arguments=json.dumps({ - "name": - "John Doe", - }))) + ToolCall(function=FunctionCall( + name="get_age", + arguments=json.dumps({ + "name": "John Doe", + }), + )) ], - None), + None, + ), ], ) -def test_extract_tool_calls(mistral_tool_parser, model_output, - expected_tool_calls, expected_content): +def test_extract_tool_callspre_v11_tokenizer(mistral_tool_parser, model_output, + expected_tool_calls, + expected_content): extracted_tool_calls = mistral_tool_parser.extract_tool_calls( model_output, request=None) # type: ignore[arg-type] assert extracted_tool_calls.tools_called @@ -169,37 +206,60 @@ def test_extract_tool_calls(mistral_tool_parser, model_output, assert extracted_tool_calls.content == expected_content + @pytest.mark.parametrize( ids=[ - "single_tool_add", "single_tool_weather", + "single_tool_add", + "single_tool_weather", + # "multiple_tool_calls", # Was already broken ], argnames=["model_output", "expected_tool_calls", "expected_content"], argvalues=[ ( - '''[TOOL_CALLS]add{"a": 3.5, "b": 4}''', # noqa: E501 + """[TOOL_CALLS]add_this_and_that{"a": 3.5, "b": 4}""", # noqa: E501 [ - ToolCall(function=FunctionCall(name="add", - arguments=json.dumps({ - "a": 3.5, - "b": 4 - }))) + ToolCall(function=FunctionCall( + name="add_this_and_that", + arguments=json.dumps({ + "a": 3.5, + "b": 4 + }), + )) ], - None), + None, + ), ( - '''[TOOL_CALLS]get_current_weather{"city": "San Francisco", "state": "CA", "unit": "celsius"}''', # noqa: E501 + """[TOOL_CALLS]get_current_weather{"city": "San Francisco", "state": "CA", "unit": "celsius"}""", # noqa: E501 [ - ToolCall(function=FunctionCall(name="get_current_weather", - arguments=json.dumps( - { - "city": "San Francisco", - "state": "CA", - "unit": "celsius" - }))) + ToolCall(function=FunctionCall( + name="get_current_weather", + arguments=json.dumps({ + "city": "San Francisco", + "state": "CA", + "unit": "celsius" + }), + )) ], - None) - ] + None, + ), + # ( + # '''[TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]multiply{"a": 3, "b": 6}''', # noqa: E501 + # [ + # ToolCall(function=FunctionCall(name="add", + # arguments=json.dumps({ + # "a": 3.5, + # "b": 4 + # }))), + # ToolCall(function=FunctionCall(name="multiply", + # arguments=json.dumps({ + # "a": 3, + # "b": 6 + # }))) + # ], + # None) # Was already broken + ], ) -def test_extract_tool_calls_v11_tokenizer(mistral_v11_tool_parser, model_output, +def test_extract_tool_calls(mistral_v11_tool_parser, model_output, expected_tool_calls, expected_content): extracted_tool_calls = mistral_v11_tool_parser.extract_tool_calls( model_output, request=None) # type: ignore[arg-type] @@ -210,17 +270,17 @@ def test_extract_tool_calls_v11_tokenizer(mistral_v11_tool_parser, model_output, assert extracted_tool_calls.content == expected_content -def _test_extract_tool_calls_streaming(tool_parser, tokenizer, - model_output, expected_tool_calls, - expected_content): - other_content: str = '' +def _test_extract_tool_calls_streaming(tool_parser, tokenizer, model_output, + tools, expected_tool_calls, + expected_content): + other_content: str = "" function_names: list[str] = [] function_args_strs: list[str] = [] tool_call_idx: int = -1 tool_call_ids: list[Optional[str]] = [] for delta_message in stream_delta_message_generator( - tool_parser, tokenizer, model_output): + tool_parser, tokenizer, model_output, tools): # role should never be streamed from tool parser assert not delta_message.role @@ -238,6 +298,7 @@ def _test_extract_tool_calls_streaming(tool_parser, tokenizer, if tool_call.index != tool_call_idx: tool_call_idx = tool_call.index function_args_strs.append("") + function_names.append("") tool_call_ids.append(None) # if a tool call ID is streamed, make sure one hasn't been already @@ -250,7 +311,7 @@ def _test_extract_tool_calls_streaming(tool_parser, tokenizer, # IN ENTIRETY, exactly one time. if tool_call.function.name: assert isinstance(tool_call.function.name, str) - function_names.append(tool_call.function.name) + function_names[tool_call.index] += tool_call.function.name if tool_call.function.arguments: # make sure they're a string and then add them to the list @@ -262,16 +323,19 @@ def _test_extract_tool_calls_streaming(tool_parser, tokenizer, assert other_content == expected_content actual_tool_calls = [ - ToolCall(id=tool_call_id, - function=FunctionCall( - name=function_name, - arguments=partial_json_parser.ensure_json( - function_args_str, Allow.OBJ | Allow.STR))) - for tool_call_id, function_name, function_args_str in zip( + ToolCall( + id=tool_call_id, + function=FunctionCall( + name=function_name, + arguments=partial_json_parser.ensure_json( + function_args_str, Allow.OBJ | Allow.STR), + ), + ) for tool_call_id, function_name, function_args_str in zip( tool_call_ids, function_names, function_args_strs) ] assert_tool_calls(actual_tool_calls, expected_tool_calls) + @pytest.mark.parametrize( ids=[ "no_tools", @@ -284,9 +348,9 @@ def _test_extract_tool_calls_streaming(tool_parser, tokenizer, ], argnames=["model_output", "expected_tool_calls", "expected_content"], argvalues=[ - ('''This is a test''', [], '''This is a test'''), + ("""This is a test""", [], """This is a test"""), ( - '''[TOOL_CALLS] [ {"name":"add" , "arguments" : {"a": 3, "b": 4} } ]''', # noqa: E501 + """[TOOL_CALLS] [ {"name":"add" , "arguments" : {"a": 3, "b": 4} } ]""", # noqa: E501 [ ToolCall(function=FunctionCall(name="add", arguments=json.dumps({ @@ -294,9 +358,10 @@ def _test_extract_tool_calls_streaming(tool_parser, tokenizer, "b": 4 }))) ], - ""), + "", + ), ( - '''[TOOL_CALLS] [{"name": "add", "arguments":{"a": "3", "b": "4"}}]''', # noqa: E501 + """[TOOL_CALLS] [{"name": "add", "arguments":{"a": "3", "b": "4"}}]""", # noqa: E501 [ ToolCall(function=FunctionCall(name="add", arguments=json.dumps({ @@ -304,118 +369,160 @@ def _test_extract_tool_calls_streaming(tool_parser, tokenizer, "b": "4" }))) ], - ""), + "", + ), ( - '''[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}}]''', # noqa: E501 + """[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""", # noqa: E501 [ - ToolCall(function=FunctionCall(name="get_current_weather", - arguments=json.dumps( - { - "city": "San Francisco", - "state": "CA", - "unit": "celsius" - }))) + ToolCall(function=FunctionCall( + name="get_current_weather", + arguments=json.dumps({ + "city": "San Francisco", + "state": "CA", + "unit": "celsius" + }), + )) ], - ""), + "", + ), ( - '''[TOOL_CALLS] [{"arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]''', # noqa: E501 + """[TOOL_CALLS] [{"arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""", # noqa: E501 [ - ToolCall(function=FunctionCall(name="get_current_weather", - arguments=json.dumps( - { - "city": "San Francisco", - "state": "CA", - "unit": "celsius" - }))) + ToolCall(function=FunctionCall( + name="get_current_weather", + arguments=json.dumps({ + "city": "San Francisco", + "state": "CA", + "unit": "celsius" + }), + )) ], - ''), + "", + ), ( - '''[TOOL_CALLS] [{"arguments": {"name": "John Doe"}, "name": "get_age"}]''', # noqa: E501 + """[TOOL_CALLS] [{"arguments": {"name": "John Doe"}, "name": "get_age"}]""", # noqa: E501 [ - ToolCall(function=FunctionCall(name="get_age", - arguments=json.dumps({ - "name": - "John Doe", - }))) + ToolCall(function=FunctionCall( + name="get_age", + arguments=json.dumps({ + "name": "John Doe", + }), + )) ], - ''), + "", + ), ( - '''[TOOL_CALLS][{"name": "add", "arguments": {"a": 3.5, "b": 4}}, {"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}]''', # noqa: E501 + """[TOOL_CALLS][{"name": "add", "arguments": {"a": 3.5, "b": 4}}, {"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}]""", # noqa: E501 [ ToolCall(function=FunctionCall(name="add", arguments=json.dumps({ "a": 3.5, "b": 4 }))), - ToolCall(function=FunctionCall(name="get_current_weather", - arguments=json.dumps( - { - "city": "San Francisco", - "state": "CA", - "unit": "celsius" - }))) + ToolCall(function=FunctionCall( + name="get_current_weather", + arguments=json.dumps({ + "city": "San Francisco", + "state": "CA", + "unit": "celsius" + }), + )), ], - ''), + "", + ), ], ) -def test_extract_tool_calls_streaming(mistral_tool_parser, mistral_tokenizer, - model_output, expected_tool_calls, - expected_content): - other_content: str = '' - function_names: list[str] = [] - function_args_strs: list[str] = [] - tool_call_idx: int = -1 - tool_call_ids: list[Optional[str]] = [] - - for delta_message in stream_delta_message_generator( - mistral_tool_parser, mistral_tokenizer, model_output): - # role should never be streamed from tool parser - assert not delta_message.role - - if delta_message.content: - other_content += delta_message.content - - streamed_tool_calls = delta_message.tool_calls - - if streamed_tool_calls and len(streamed_tool_calls) > 0: - # make sure only one diff is present - correct even for parallel - assert len(streamed_tool_calls) == 1 - tool_call = streamed_tool_calls[0] - - # if a new tool is being called, set up empty arguments - if tool_call.index != tool_call_idx: - tool_call_idx = tool_call.index - function_args_strs.append("") - tool_call_ids.append(None) - - # if a tool call ID is streamed, make sure one hasn't been already - if tool_call.id and not tool_call_ids[tool_call.index]: - tool_call_ids[tool_call.index] = tool_call.id +def test_extract_tool_calls_streaming_pre_v11_tokenizer( + mistral_tool_parser, + mistral_tokenizer, + model_output, + expected_tool_calls, + expected_content, +): + _test_extract_tool_calls_streaming( + mistral_tool_parser, + mistral_tokenizer, + model_output, + None, + expected_tool_calls, + expected_content, + ) - # if parts of the function start being streamed - if tool_call.function: - # if the function name is defined, set it. it should be streamed - # IN ENTIRETY, exactly one time. - if tool_call.function.name: - assert isinstance(tool_call.function.name, str) - function_names.append(tool_call.function.name) - if tool_call.function.arguments: - # make sure they're a string and then add them to the list - assert isinstance(tool_call.function.arguments, str) - - function_args_strs[ - tool_call.index] += tool_call.function.arguments - - assert other_content == expected_content - - actual_tool_calls = [ - ToolCall(id=tool_call_id, - function=FunctionCall( - name=function_name, - arguments=partial_json_parser.ensure_json( - function_args_str, Allow.OBJ | Allow.STR))) - for tool_call_id, function_name, function_args_str in zip( - tool_call_ids, function_names, function_args_strs) - ] - assert_tool_calls(actual_tool_calls, expected_tool_calls) +@pytest.mark.parametrize( + ids=[ + "single_tool_add", + "single_tool_add_strings", + "multiple_tools", + ], + argnames=["tools", "expected_tool_calls", "expected_content"], + argvalues=[ + ( + [("add", '{"a": 3, "b": 4}')], + # [TOOL_CALLS]add{"a": 3, "b": 4} + [ + ToolCall(function=FunctionCall(name="add", + arguments=json.dumps({ + "a": 3, + "b": 4 + }))) + ], + "", + ), + ( + [("add_two_strings", '{"a": "3", "b": "4"}')], + # [TOOL_CALLS]add{"a": "3", "b": "4"} + [ + ToolCall(function=FunctionCall( + name="add_two_strings", + arguments=json.dumps({ + "a": "3", + "b": "4" + }), + )) + ], + "", + ), + ( + [ + ("add", '{"a": 3.5, "b": 4}'), + ( + "get_current_weather", + '{"city": "San Francisco", "state": "CA", "unit": "celsius"}', # noqa: E501 + ), + ], + # [TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]get_current_weather{"city": "San Francisco", "state": "CA", "unit": "celsius"} # noqa: E501 + [ + ToolCall(function=FunctionCall(name="add", + arguments=json.dumps({ + "a": 3.5, + "b": 4 + }))), + ToolCall(function=FunctionCall( + name="get_current_weather", + arguments=json.dumps({ + "city": "San Francisco", + "state": "CA", + "unit": "celsius" + }), + )), + ], + "", + ), + ], +) +def test_extract_tool_calls_streaming( + mistral_v11_tool_parser, + mistral_v11_tokenizer, + tools, + expected_tool_calls, + expected_content, +): + _test_extract_tool_calls_streaming( + mistral_v11_tool_parser, + mistral_v11_tokenizer, + None, + tools, + expected_tool_calls, + expected_content, + ) From d79a234130e4fd0bec26558dcfac7692d5c4957c Mon Sep 17 00:00:00 2001 From: avigny <47987522+avigny@users.noreply.github.com> Date: Tue, 8 Jul 2025 18:54:07 +0200 Subject: [PATCH 08/12] Adding support for the new tool call format in mistral models Signed-off-by: avigny <47987522+avigny@users.noreply.github.com> --- .../tool_parsers/mistral_tool_parser.py | 125 +++++++++++++++++- 1 file changed, 121 insertions(+), 4 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index 9f94c6efca98..8743e5feedbb 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -78,14 +78,16 @@ def __init__(self, tokenizer: AnyTokenizer): self.tools_parsing_finished: bool = False self.current_tool_id: int = -1 + self.current_element_streaming: Union[Literal["name", "arguments"], + None] = None + + # For pre v11 tokenizer tool calls self.current_tool_start_index: int = -1 # index in the `self.raw_tool_calls` string self.current_attribute_start_index: int = -1 # index in the `self.raw_current_tool_call` string self.previous_attribute_end_index: int = 0 # index in the `self.raw_current_tool_call` string - self.current_element_streaming: Union[Literal["name", "arguments"], - None] = None self.current_tool_name_finished: bool = False self.current_tool_arguments_finished: bool = False @@ -201,13 +203,128 @@ def extract_tool_calls_streaming( request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: - # if the tool call token is not in the tokens generated so far, append - # output to contents since it's not a tool if self.bot_token not in current_text: + # if the tool call token is not in the tokens generated so far, + # append output to contents since it's not a tool return DeltaMessage(content=delta_text) # if the tool call token ID IS in the tokens generated so far, that # means we're parsing as tool calls now + if _is_fn_name_regex_support(self.model_tokenizer): + return self._extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=previous_token_ids, + current_token_ids=current_token_ids, + delta_token_ids=delta_token_ids, + request=request, + ) + else: + return self._extract_tool_calls_streaming_pre_v11_tokenizer( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=previous_token_ids, + current_token_ids=current_token_ids, + delta_token_ids=delta_token_ids, + request=request, + ) + + def _extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + additional_content: str = "" + if self.current_tool_id == -1: + # this is the first tool call + assert self.bot_token in delta_text + if not delta_text.startswith(self.bot_token): + additional_content += delta_text.split(self.bot_token)[0] + delta_text = self.bot_token + "".join( + delta_text.split(self.bot_token)[1:]) + + delta_tool_calls = self._generate_delta_tool_call(delta_text) + delta = DeltaMessage( + content=additional_content, + tool_calls=delta_tool_calls, + ) + return delta + + def _generate_delta_tool_call(self, + delta_text: str) -> list[DeltaToolCall]: + if delta_text == "" or delta_text is None: + return [] + delta_function_name = None + tool_id = None + if self.current_element_streaming is None and delta_text.startswith( + self.bot_token): + self.current_tool_id += 1 + tool_id = MistralToolCall.generate_random_id() + self.current_element_streaming = 'name' + delta_text = delta_text.replace(self.bot_token, "", 1) + if self.current_element_streaming == 'name': + if "{" in delta_text: + delta_function_name = delta_text.split("{")[0] + delta_text = delta_text[len(delta_function_name):] + self.current_element_streaming = 'arguments' + else: + delta_function_name = delta_text + return [ + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=tool_id, + function=DeltaFunctionCall( + name=delta_function_name).model_dump( + exclude_none=True), + ) + ] + if self.current_element_streaming == 'arguments': + next_function_text = None + if self.bot_token in delta_text: + # current tool call is over + if delta_text.startswith(self.bot_token): + delta_arguments = "" + else: + delta_arguments = delta_text.split(self.bot_token)[0] + next_function_text = delta_text[len(delta_arguments):] + self.current_element_streaming = None + else: + delta_arguments = delta_text + ret = [] + if delta_function_name or delta_arguments: + ret += [ + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=tool_id, + function=DeltaFunctionCall( + name=delta_function_name, + arguments=delta_arguments).model_dump( + exclude_none=True), + ) + ] + if next_function_text: + ret += self._generate_delta_tool_call(next_function_text) + return ret + + def _extract_tool_calls_streaming_pre_v11_tokenizer( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: additional_content: str = "" if self.bot_token in delta_text: self.raw_tool_calls += (delta_text.split( From 7cf17c075f7b6ae421da340a8bfa0fbba9334496 Mon Sep 17 00:00:00 2001 From: avigny <47987522+avigny@users.noreply.github.com> Date: Tue, 8 Jul 2025 19:20:35 +0200 Subject: [PATCH 09/12] Comment update Signed-off-by: avigny <47987522+avigny@users.noreply.github.com> --- tests/tool_use/test_mistral_tool_parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tool_use/test_mistral_tool_parser.py b/tests/tool_use/test_mistral_tool_parser.py index 37d146fe5585..b30f5675012a 100644 --- a/tests/tool_use/test_mistral_tool_parser.py +++ b/tests/tool_use/test_mistral_tool_parser.py @@ -471,7 +471,7 @@ def test_extract_tool_calls_streaming_pre_v11_tokenizer( ), ( [("add_two_strings", '{"a": "3", "b": "4"}')], - # [TOOL_CALLS]add{"a": "3", "b": "4"} + # [TOOL_CALLS]add_two_strings{"a": "3", "b": "4"} [ ToolCall(function=FunctionCall( name="add_two_strings", From 138cef34528c9f67c0a06724673b6c1a358d260c Mon Sep 17 00:00:00 2001 From: avigny <47987522+avigny@users.noreply.github.com> Date: Tue, 8 Jul 2025 20:09:21 +0200 Subject: [PATCH 10/12] Repairing multi tool calls in non streaming mode Signed-off-by: avigny <47987522+avigny@users.noreply.github.com> --- tests/tool_use/test_mistral_tool_parser.py | 32 +++++++++---------- .../tool_parsers/mistral_tool_parser.py | 25 ++++++++------- 2 files changed, 29 insertions(+), 28 deletions(-) diff --git a/tests/tool_use/test_mistral_tool_parser.py b/tests/tool_use/test_mistral_tool_parser.py index b30f5675012a..024c29efd2a1 100644 --- a/tests/tool_use/test_mistral_tool_parser.py +++ b/tests/tool_use/test_mistral_tool_parser.py @@ -211,7 +211,7 @@ def test_extract_tool_callspre_v11_tokenizer(mistral_tool_parser, model_output, ids=[ "single_tool_add", "single_tool_weather", - # "multiple_tool_calls", # Was already broken + "multiple_tool_calls", ], argnames=["model_output", "expected_tool_calls", "expected_content"], argvalues=[ @@ -242,21 +242,21 @@ def test_extract_tool_callspre_v11_tokenizer(mistral_tool_parser, model_output, ], None, ), - # ( - # '''[TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]multiply{"a": 3, "b": 6}''', # noqa: E501 - # [ - # ToolCall(function=FunctionCall(name="add", - # arguments=json.dumps({ - # "a": 3.5, - # "b": 4 - # }))), - # ToolCall(function=FunctionCall(name="multiply", - # arguments=json.dumps({ - # "a": 3, - # "b": 6 - # }))) - # ], - # None) # Was already broken + ( + '''[TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]multiply{"a": 3, "b": 6}''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="add", + arguments=json.dumps({ + "a": 3.5, + "b": 4 + }))), + ToolCall(function=FunctionCall(name="multiply", + arguments=json.dumps({ + "a": 3, + "b": 6 + }))) + ], + None) ], ) def test_extract_tool_calls(mistral_v11_tool_parser, model_output, diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index 8743e5feedbb..7b2420619600 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -143,19 +143,20 @@ def extract_tool_calls( # jsons is difficult try: if self.fn_name_regex: - matches = self.fn_name_regex.findall(tool_content) - function_call_arr = [] - for match in matches: - fn_name = match[0] - args = match[1] - - # fn_name is encoded outside serialized json dump - # only arguments are serialized - function_call_arr.append({ - "name": fn_name, - "arguments": json.loads(args) - }) + for single_tool_content in model_output.split(self.bot_token): + matches = self.fn_name_regex.findall(single_tool_content) + + for match in matches: + fn_name = match[0] + args = match[1] + + # fn_name is encoded outside serialized json dump + # only arguments are serialized + function_call_arr.append({ + "name": fn_name, + "arguments": json.loads(args) + }) else: function_call_arr = json.loads(tool_content) except json.JSONDecodeError: From e3e841462b31b3008114a073d60b8b205b7de134 Mon Sep 17 00:00:00 2001 From: avigny <47987522+avigny@users.noreply.github.com> Date: Tue, 8 Jul 2025 21:10:13 +0200 Subject: [PATCH 11/12] CI repair Signed-off-by: avigny <47987522+avigny@users.noreply.github.com> --- .../openai/tool_parsers/mistral_tool_parser.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index 7b2420619600..626ba7b20973 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -144,8 +144,10 @@ def extract_tool_calls( try: if self.fn_name_regex: function_call_arr = [] - for single_tool_content in model_output.split(self.bot_token): - matches = self.fn_name_regex.findall(single_tool_content) + for single_tool_content in model_output.split( + self.bot_token): + matches = self.fn_name_regex.findall( + single_tool_content) for match in matches: fn_name = match[0] @@ -154,8 +156,10 @@ def extract_tool_calls( # fn_name is encoded outside serialized json dump # only arguments are serialized function_call_arr.append({ - "name": fn_name, - "arguments": json.loads(args) + "name": + fn_name, + "arguments": + json.loads(args) }) else: function_call_arr = json.loads(tool_content) @@ -315,6 +319,8 @@ def _generate_delta_tool_call(self, if next_function_text: ret += self._generate_delta_tool_call(next_function_text) return ret + # Should not happen + return [] def _extract_tool_calls_streaming_pre_v11_tokenizer( self, From 8433789b9f7a7384bb8bdd39a604a52d366ab984 Mon Sep 17 00:00:00 2001 From: avigny <47987522+avigny@users.noreply.github.com> Date: Wed, 16 Jul 2025 17:02:51 +0200 Subject: [PATCH 12/12] Repair Test Signed-off-by: avigny <47987522+avigny@users.noreply.github.com> --- tests/tool_use/test_mistral_tool_parser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tool_use/test_mistral_tool_parser.py b/tests/tool_use/test_mistral_tool_parser.py index 024c29efd2a1..8628e577ffd0 100644 --- a/tests/tool_use/test_mistral_tool_parser.py +++ b/tests/tool_use/test_mistral_tool_parser.py @@ -412,7 +412,7 @@ def _test_extract_tool_calls_streaming(tool_parser, tokenizer, model_output, "", ), ( - """[TOOL_CALLS][{"name": "add", "arguments": {"a": 3.5, "b": 4}}, {"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}]""", # noqa: E501 + """[TOOL_CALLS][{"name": "add", "arguments": {"a": 3.5, "b": 4}}, {"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""", # noqa: E501 [ ToolCall(function=FunctionCall(name="add", arguments=json.dumps({