From 976dc828be4a5ee4175598f6f3da88f780cad8d5 Mon Sep 17 00:00:00 2001 From: Jeff Cook Date: Fri, 4 Jul 2025 04:42:31 -0600 Subject: [PATCH 01/12] feat: Add streaming support for v11 tool format in Mistral parser Co-authored-by: avigny <47987522+avigny@users.noreply.github.com> Co-authored-by: aider (anthropic/claude-sonnet-4-20250514) Co-authored-by: aider (gemini/gemini-2.5-pro) Signed-off-by: Jeff Cook --- .../tool_parsers/mistral_tool_parser.py | 687 +++++++++++++----- 1 file changed, 503 insertions(+), 184 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index c0691f12290..ab652b494c5 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -3,13 +3,12 @@ import json from collections.abc import Sequence +from enum import Enum from random import choices from string import ascii_letters, digits -from typing import Union +from typing import Literal -import partial_json_parser import regex as re -from partial_json_parser.core.options import Allow from pydantic import Field from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, @@ -19,8 +18,6 @@ FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( ToolParser, ToolParserManager) -from vllm.entrypoints.openai.tool_parsers.utils import ( - extract_intermediate_diff) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer @@ -29,6 +26,15 @@ ALPHANUMERIC = ascii_letters + digits +class StreamingState(Enum): + """Enum for tracking the current streaming parsing state.""" + WAITING_FOR_TOOL_START = "waiting_for_tool_start" + PARSING_NAME = "parsing_name" + PARSING_ARGUMENTS = "parsing_arguments" + TOOL_COMPLETE = "tool_complete" + ALL_TOOLS_COMPLETE = "all_tools_complete" + + class MistralToolCall(ToolCall): id: str = Field( default_factory=lambda: MistralToolCall.generate_random_id()) @@ -68,11 +74,43 @@ def __init__(self, tokenizer: AnyTokenizer): # initialize properties used for state when parsing tool calls in # streaming mode - self.prev_tool_call_arr: list[dict] = [] + self.json_decoder: json.JSONDecoder = json.JSONDecoder() + + # Optimized regex patterns + self.tool_call_first_attribute_name: re.Pattern[str] = re.compile( + r'.*\s*"name"\s*:\s*') + self.string_value_pattern: re.Pattern[str] = re.compile( + r'\s*"(.*?)(? DeltaMessage | None: + """ + Extract tool calls from a streaming response, specifically for the + v11 MistralTokenizer format: ToolName{arguments}. This logic is a + streaming equivalent of the `self.fn_name_regex` used in + non-streaming extraction. 
+ """ + logger.debug("v11 streaming: raw_tool_calls='%s'", self.raw_tool_calls) + logger.debug("v11 streaming: current_tool_name_sent='%s'", + self.current_tool_name_sent) + logger.debug("v11 streaming: prev_args_sent='%s'", self.prev_args_sent) + + # Handle multiple tools separated by commas/whitespace + if self.current_tool_name_finished and self.current_tool_arguments_finished: + if self._should_advance_to_next_v11_tool(): + self._reset_v11_tool_state() + logger.debug("v11 streaming: found next tool, resetting state") + + # Phase 1: Extract and send function name + if not self.current_tool_name_sent: + # Look for function name pattern: name followed by { + brace_index = self.raw_tool_calls.find("{") + if brace_index == -1: + logger.debug("v11 streaming: no opening brace found yet") + return self._none_or_additional_content(additional_content) + + # Extract function name + func_name = self.raw_tool_calls[:brace_index].strip() + # Remove any leading separators from previous tools + func_name = re.sub(r'^[\s,]*', '', func_name) + + if not func_name: + logger.debug("v11 streaming: function name is empty") + return self._none_or_additional_content(additional_content) + + logger.debug("v11 streaming: sending function name='%s'", + func_name) + self.current_tool_name_sent = True + self.current_tool_id += 1 + + return DeltaMessage( + content=additional_content, + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=MistralToolCall.generate_random_id(), + function=DeltaFunctionCall(name=func_name).model_dump( + exclude_none=True), + ) + ], + ) + + # Phase 2: Extract and send argument fragments + if self.current_tool_name_sent and not self.current_tool_arguments_finished: + # Find the arguments part (everything after the first {) + brace_index = self.raw_tool_calls.find("{") + if brace_index == -1: + logger.debug("v11 streaming: no opening brace found for args") + return self._none_or_additional_content(additional_content) + + current_args = self.raw_tool_calls[brace_index:] + logger.debug("v11 streaming: current_args='%s'", current_args) + + # Check if JSON is complete + try: + parsed_obj, end_idx = self.json_decoder.raw_decode( + current_args) + # JSON is complete + self.current_tool_arguments_finished = True + logger.debug("v11 streaming: JSON complete, parsed_obj=%s", + parsed_obj) + except json.decoder.JSONDecodeError: + # JSON still incomplete + logger.debug("v11 streaming: JSON still incomplete") + pass + + # Calculate what's new since last time + if current_args != self.prev_args_sent: + if self.prev_args_sent and current_args.startswith( + self.prev_args_sent): + # Incremental update + new_content = current_args[len(self.prev_args_sent):] + logger.debug("v11 streaming: incremental args='%s'", + new_content) + else: + # First time or reset + new_content = current_args + logger.debug("v11 streaming: first/reset args='%s'", + new_content) + + self.prev_args_sent = current_args + + if new_content: + return DeltaMessage( + content=additional_content, + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=new_content).model_dump( + exclude_none=True), + ) + ], + ) + + return self._none_or_additional_content(additional_content) + + def _should_advance_to_next_v11_tool(self) -> bool: + """Check if we should advance to the next tool in V11 format.""" + # Find pattern: completed_tool, next_tool or completed_tool next_tool + pattern = r'([a-zA-Z0-9_-]+\{[^}]*\})\s*[,\s]*([a-zA-Z0-9_-]+.*)' + match = re.search(pattern, 
self.raw_tool_calls) + return match is not None and len(match.group(2).strip()) > 0 + + def _reset_v11_tool_state(self) -> None: + """Reset V11 tool parsing state for the next tool.""" + self.current_tool_name_finished = False + self.current_tool_arguments_finished = False + self.current_tool_name_sent = False + self.prev_args_sent = "" + + def _determine_next_parsing_element( + self, + raw_current_tool_call: str) -> Literal["name", "arguments"] | None: + """ + Determine the next element to parse based on current state. + + Args: + raw_current_tool_call: The current tool call text + + Returns: + The next element to parse, or None if nothing is ready + """ + # Check for name attribute + if not self.current_tool_name_finished: + match_name = self.tool_call_first_attribute_name.match( + raw_current_tool_call) + if match_name and match_name.end( + ) > self.previous_attribute_end_index: + self.current_attribute_start_index = match_name.end() + return "name" + + # Check for arguments attribute + if not self.current_tool_arguments_finished: + match_arguments = self.tool_call_first_attribute_arguments.match( + raw_current_tool_call) + if match_arguments and match_arguments.end( + ) > self.previous_attribute_end_index: + # The `{` is the last character in the match - we want it as start index + self.current_attribute_start_index = match_arguments.end() - 1 + return "arguments" + + return None + + def _is_current_tool_complete(self) -> bool: + """Check if the current tool parsing is complete.""" + return (self.current_tool_name_finished + and self.current_tool_arguments_finished) + + def _advance_to_next_tool(self) -> bool: + """ + Advance to the next tool if available. + + Returns: + True if successfully advanced to next tool, False otherwise + """ + next_tool_start_index = self._next_tool_starting_position() + if next_tool_start_index > 0: + self.current_tool_id += 1 + self.current_tool_start_index = next_tool_start_index + self.current_attribute_start_index = -1 + self.previous_attribute_end_index = 0 + self.current_tool_name_finished = False + self.current_tool_arguments_finished = False + return True + return False + + def _process_delta_text(self, delta_text: str) -> str: + """ + Process delta text and update raw_tool_calls, returning any additional content. 
+ + Args: + delta_text: The new text delta to process + + Returns: + Any additional content that appears before the bot token + """ + additional_content = "" + + if self.bot_token in delta_text: + # Split only once for efficiency + parts = delta_text.split(self.bot_token, 1) + if len(parts) > 1: + if parts[0]: # Content before bot token + additional_content = parts[0] + # Process content after bot token + tool_content = parts[1].replace("'", '"').lstrip() + self.raw_tool_calls += tool_content + else: + # No bot token in delta, just clean and append + cleaned_delta = delta_text.replace("'", '"') + self.raw_tool_calls += cleaned_delta + # Remove leading spaces only if we have content + if self.raw_tool_calls: + self.raw_tool_calls = self.raw_tool_calls.lstrip() + + return additional_content + + def _should_detect_v11_format(self) -> bool: + """Check if we should attempt V11 format detection.""" + return (self.fn_name_regex is not None and self.current_tool_id == -1 + and not self.v11_tool_format) + + def _detect_v11_format(self) -> None: + """Detect if we're using V11 tool format.""" + stripped_calls = self.raw_tool_calls.lstrip() + if stripped_calls and stripped_calls[0] != "[": + logger.debug("flipping v11 tool format to True ...") + self.v11_tool_format = True + + def _try_parse_json_cached(self, text: str) -> tuple[bool, int]: + """ + Attempt to parse JSON with caching for performance. + + Args: + text: The text to parse as JSON + + Returns: + Tuple of (success, end_index) + """ + if text == self._last_json_parse_input: + return self._last_json_parse_result + + try: + _, end_index = self.json_decoder.raw_decode(text) + result = (True, end_index) + except json.decoder.JSONDecodeError: + result = (False, -1) + + # Cache the result + self._last_json_parse_input = text + self._last_json_parse_result = result + return result + + def _extracted_complete_name( + self, raw_current_tool_call: str, + current_attribute_start_index: int) -> tuple[str, int | None]: + """ + Extract the complete function name from the current tool call. + + Args: + raw_current_tool_call: The raw JSON string of the current tool call + current_attribute_start_index: The starting index of the + name attribute in the raw_current_tool_call string + + Returns: + tuple: + - The function name, or "" if extraction failed + - The end index of the name in raw_current_tool_call, + or None if extraction failed + """ + partial_name_value = raw_current_tool_call[ + current_attribute_start_index:] + if match := self.string_value_pattern.match(partial_name_value): + return match.group(1), match.end() + current_attribute_start_index + return "", None + + def _extract_argument_fragment(self, raw_current_tool_call: str, + current_attribute_start_index: int, + delta: str) -> tuple[str, int]: + """ + Extract the relevant argument fragment from the current streaming delta. 
+ + Args: + raw_current_tool_call: The raw JSON string of the current tool call + current_attribute_start_index: The starting index + of the arguments attribute in the raw string + delta: The new text added in this streaming step + + Returns: + tuple: + - The extracted argument diff text + to be sent in the streaming response + - The end index of the arguments in the raw string, + or -1 if not yet complete + """ + partial_arguments_value = raw_current_tool_call[ + current_attribute_start_index:] + try: + _, end_index = self.json_decoder.raw_decode( + partial_arguments_value) + return ( + delta[:len(delta) + end_index - len(partial_arguments_value)], + current_attribute_start_index + end_index, + ) + except json.decoder.JSONDecodeError: + # The arguments object is not complete + + # delta contains data from before the argument start + if len(delta) > len(partial_arguments_value): + return delta[-len(partial_arguments_value):], -1 + + # We can send the whole delta + return delta, -1 + + def _next_tool_starting_position(self) -> int: + """ + Find the starting position of the next tool + in the raw tool calls string. + + Returns: + The index position where the next tool starts, + or -1 if no next tool is found yet + """ + assert self.current_tool_start_index >= 0 + current_tool_call = self.raw_tool_calls[self.current_tool_start_index:] + try: + _, end_index = self.json_decoder.raw_decode(current_tool_call) + return (self.current_tool_start_index + end_index + + current_tool_call[end_index:].find("{")) + except json.decoder.JSONDecodeError: + # The current tool object is not yet closed + return -1 + except IndexError: + # The next tool has not started yet + # and the delta just closes the current tool call + return -1 + + def _none_or_additional_content( + self, additional_content: str) -> DeltaMessage | None: + """ + Create a DeltaMessage with additional content if present, + otherwise return None. + + Args: + additional_content: The text content to include in the message + + Returns: + A DeltaMessage with the additional content, + or None if no content is provided + """ + if additional_content: + return DeltaMessage(content=additional_content) + return None + def adjust_request( self, request: ChatCompletionRequest) -> ChatCompletionRequest: if not isinstance( @@ -183,187 +569,120 @@ def extract_tool_calls_streaming( current_token_ids: Sequence[int], delta_token_ids: Sequence[int], request: ChatCompletionRequest, - ) -> Union[DeltaMessage, None]: + ) -> DeltaMessage | None: - # if the tool call token is not in the tokens generated so far, append - # output to contents since it's not a tool + # Early return if no tool call token present if self.bot_token not in current_text: return DeltaMessage(content=delta_text) - # if the tool call token ID IS in the tokens generated so far, that - # means we're parsing as tool calls now - - # handle if we detected the BOT token which means the start of tool - # calling - if (self.bot_token_id in delta_token_ids - and len(delta_token_ids) == 1): - # if it's the only token, return None, so we don't send a chat - # completion any don't send a control token - return None - - # bit mask flags for partial JSON parsing. If the name hasn't been - # sent yet, don't allow sending - # an incomplete string since OpenAI only ever (as far as I have - # seen) allows sending the entire tool/ function name at once. 
- flags = Allow.ALL if self.current_tool_name_sent \ - else Allow.ALL & ~Allow.STR - try: - - # replace BOT token with empty string, and convert single quotes - # to double to allow parsing as JSON since mistral uses single - # quotes instead of double for tool calls - parsable_arr = current_text.split(self.bot_token)[-1] + # Process delta text and extract additional content + additional_content = self._process_delta_text(delta_text) - # tool calls are generated in an array, so do partial JSON - # parsing on the entire array - try: - tool_call_arr: list[dict] = partial_json_parser.loads( - parsable_arr, flags) - except partial_json_parser.core.exceptions.MalformedJSON: - logger.debug('not enough tokens to parse into JSON yet') - return None - - # select as the current tool call the one we're on the state at - - current_tool_call: dict = tool_call_arr[self.current_tool_id] \ - if len(tool_call_arr) > 0 else {} - - # case -- if no tokens have been streamed for the tool, e.g. - # only the array brackets, stream nothing - if len(tool_call_arr) == 0: - return None - - # case: we are starting a new tool in the array - # -> array has > 0 length AND length has moved past cursor - elif (len(tool_call_arr) > 0 - and len(tool_call_arr) > self.current_tool_id + 1): - - # if we're moving on to a new call, first make sure we - # haven't missed anything in the previous one that was - # auto-generated due to JSON completions, but wasn't - # streamed to the client yet. - if self.current_tool_id >= 0: - diff: Union[str, None] = current_tool_call.get("arguments") - - if diff: - diff = json.dumps(diff, ensure_ascii=False).replace( - self.streamed_args_for_tool[self.current_tool_id], - "") - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=diff).model_dump( - exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += diff - else: - delta = None - else: - delta = None - # re-set stuff pertaining to progress in the current tool - self.current_tool_id = len(tool_call_arr) - 1 - self.current_tool_name_sent = False - self.streamed_args_for_tool.append("") - logger.debug("starting on new tool %d", self.current_tool_id) - return delta + # Detect and handle V11 format + if self._should_detect_v11_format(): + self._detect_v11_format() - # case: update an existing tool - this is handled below - - # if the current tool name hasn't been sent, send if available - # - otherwise send nothing - if not self.current_tool_name_sent: - function_name = current_tool_call.get("name") - if function_name: - - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - type="function", - id=MistralToolCall.generate_random_id(), - function=DeltaFunctionCall( - name=function_name).model_dump( - exclude_none=True)) - ]) - self.current_tool_name_sent = True - else: - delta = None + if self.v11_tool_format: + return self._extract_tool_calls_streaming_v11( + additional_content, delta_text) - # now we know we're on the same tool call and we're streaming - # arguments + # Check if tool calls have started + if self.current_tool_start_index < 0: + bracket_pos = self.raw_tool_calls.find("[") + if bracket_pos >= 0: + self.current_tool_start_index = bracket_pos + 1 + self.current_tool_id += 1 else: - - prev_arguments = self.prev_tool_call_arr[ - self.current_tool_id].get("arguments") - cur_arguments = current_tool_call.get("arguments") - - new_text = delta_text.replace("\'", "\"") - if ('"}' in new_text): - new_text = 
new_text[:new_text.rindex('"}')] - - if not cur_arguments and not prev_arguments: - - delta = None - elif not cur_arguments and prev_arguments: - logger.error( - "INVARIANT - impossible to have arguments reset " - "mid-arguments") - delta = None - elif cur_arguments and not prev_arguments: - cur_arguments_json = json.dumps(cur_arguments, - ensure_ascii=False)[:-2] - logger.debug("finding %s in %s", new_text, - cur_arguments_json) - - if (new_text not in cur_arguments_json): - return None - arguments_delta = cur_arguments_json[:cur_arguments_json. - rindex(new_text) + - len(new_text)] - logger.debug("First tokens in arguments received: %s", - arguments_delta) - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=arguments_delta). - model_dump(exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += arguments_delta - - elif cur_arguments and prev_arguments: - cur_args_json = json.dumps(cur_arguments, - ensure_ascii=False) - prev_args_json = json.dumps(prev_arguments, - ensure_ascii=False) - logger.debug("Searching for diff between \n%s\n%s", - cur_args_json, prev_args_json) - - argument_diff = extract_intermediate_diff( - cur_args_json, prev_args_json) - logger.debug("got arguments diff: %s", argument_diff) - delta = DeltaMessage(tool_calls=[ - DeltaToolCall(index=self.current_tool_id, - function=DeltaFunctionCall( - arguments=argument_diff).model_dump( - exclude_none=True)) - ]) - self.streamed_args_for_tool[ - self.current_tool_id] += argument_diff - else: - # try parsing it with regular JSON - if it works we're - # at the end, and we need to send the difference between - # tokens streamed so far and the valid JSON - delta = None - - # check to see if the name is defined and has been sent. if so, - # stream the name - otherwise keep waiting - # finish by setting old and returning None as base case - self.prev_tool_call_arr = tool_call_arr - return delta - - except Exception: - logger.exception("Error trying to handle streaming tool call.") - logger.debug( - "Skipping chunk as a result of tool streaming extraction " - "error") - return None + return self._none_or_additional_content(additional_content) + + # Try to parse complete JSON with caching + parse_success, end_index = self._try_parse_json_cached( + self.raw_tool_calls) + if parse_success: + self.tools_parsing_finished = True + if len(self.raw_tool_calls) > end_index: + additional_content = self.raw_tool_calls[end_index:] + + # Handle tool completion and transition to next tool + if self._is_current_tool_complete(): + if self.tools_parsing_finished: + return self._none_or_additional_content(additional_content) + + if self._advance_to_next_tool(): + # Successfully moved to next tool, continue processing + pass + else: + # No next tool ready yet + return self._none_or_additional_content(additional_content) + + if self.current_tool_start_index >= len(self.raw_tool_calls): + # tool call has not started + return self._none_or_additional_content(additional_content) + raw_current_tool_call = self.raw_tool_calls[self. 
+ current_tool_start_index:] + + # Determine what to parse next + if self.current_element_streaming is None: + next_element = self._determine_next_parsing_element( + raw_current_tool_call) + if next_element is None: + return self._none_or_additional_content(additional_content) + self.current_element_streaming = next_element + + if self.current_element_streaming == "name": + try: + function_name, name_end_index = self._extracted_complete_name( + raw_current_tool_call, self.current_attribute_start_index) + except IndexError: + # name value has not started being generated + return self._none_or_additional_content(additional_content) + if function_name == "": + return self._none_or_additional_content(additional_content) + else: + assert name_end_index is not None + # because the function name was successfully retrieved + + self.current_tool_name_finished = True + self.current_element_streaming = None + self.current_attribute_start_index = -1 + self.previous_attribute_end_index = name_end_index + delta = DeltaMessage( + content=additional_content, + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=MistralToolCall.generate_random_id(), + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True), + ) + ], + ) + return delta + if self.current_element_streaming == "arguments": + try: + diff, arguments_end_index = self._extract_argument_fragment( + raw_current_tool_call, + self.current_attribute_start_index, + delta_text, + ) + self.current_tool_arguments_finished = arguments_end_index != -1 + if self.current_tool_arguments_finished: + self.current_element_streaming = None + self.current_attribute_start_index = -1 + self.previous_attribute_end_index = arguments_end_index + delta = DeltaMessage( + content=additional_content, + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=diff).model_dump(exclude_none=True), + ) + ], + ) + return delta + except IndexError: + # arguments value has not started being generated + return self._none_or_additional_content(additional_content) From c78d1fb56998445c7181b29449aa864b71e3a5db Mon Sep 17 00:00:00 2001 From: avigny <47987522+avigny@users.noreply.github.com> Date: Fri, 4 Jul 2025 13:43:44 -0600 Subject: [PATCH 02/12] Bring in tests from #19425 --- tests/tool_use/test_mistral_tool_parser.py | 311 +++++++++++++++++++++ 1 file changed, 311 insertions(+) create mode 100644 tests/tool_use/test_mistral_tool_parser.py diff --git a/tests/tool_use/test_mistral_tool_parser.py b/tests/tool_use/test_mistral_tool_parser.py new file mode 100644 index 00000000000..b9bbdef57d9 --- /dev/null +++ b/tests/tool_use/test_mistral_tool_parser.py @@ -0,0 +1,311 @@ +# SPDX-License-Identifier: Apache-2.0 + +import json +from collections.abc import Generator +from typing import Optional + +import partial_json_parser +import pytest +from partial_json_parser.core.options import Allow + +from vllm.entrypoints.openai.protocol import (DeltaMessage, FunctionCall, + ToolCall) +from vllm.entrypoints.openai.tool_parsers import MistralToolParser +from vllm.transformers_utils.detokenizer import detokenize_incrementally +from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer + +MODEL = "mistralai/Mistral-7B-Instruct-v0.3" + + +@pytest.fixture(scope="module") +def mistral_tokenizer(): + return get_tokenizer(tokenizer_name=MODEL) + + +@pytest.fixture +def mistral_tool_parser(mistral_tokenizer): + return MistralToolParser(mistral_tokenizer) + + +def 
assert_tool_calls(actual_tool_calls: list[ToolCall], + expected_tool_calls: list[ToolCall]): + assert len(actual_tool_calls) == len(expected_tool_calls) + + for actual_tool_call, expected_tool_call in zip(actual_tool_calls, + expected_tool_calls): + assert isinstance(actual_tool_call.id, str) + assert len(actual_tool_call.id) == 9 + + assert actual_tool_call.type == "function" + assert actual_tool_call.function == expected_tool_call.function, ( + f'got ${actual_tool_call.function}') + + +def stream_delta_message_generator( + mistral_tool_parser: MistralToolParser, + mistral_tokenizer: AnyTokenizer, + model_output: str) -> Generator[DeltaMessage, None, None]: + all_token_ids = mistral_tokenizer.encode(model_output, + add_special_tokens=False) + + previous_text = "" + previous_tokens = None + prefix_offset = 0 + read_offset = 0 + for i, delta_token in enumerate(all_token_ids): + delta_token_ids = [delta_token] + previous_token_ids = all_token_ids[:i] + current_token_ids = all_token_ids[:i + 1] + + (new_tokens, delta_text, new_prefix_offset, + new_read_offset) = detokenize_incrementally( + tokenizer=mistral_tokenizer, + all_input_ids=current_token_ids, + prev_tokens=previous_tokens, + prefix_offset=prefix_offset, + read_offset=read_offset, + skip_special_tokens=False, + spaces_between_special_tokens=True, + ) + + current_text = previous_text + delta_text + + delta_message = mistral_tool_parser.extract_tool_calls_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + delta_token_ids, + request=None, # type: ignore[arg-type] + ) + if delta_message: + yield delta_message + + previous_text = current_text + previous_tokens = previous_tokens + new_tokens if previous_tokens\ + else new_tokens + prefix_offset = new_prefix_offset + read_offset = new_read_offset + + +def test_extract_tool_calls_no_tools(mistral_tool_parser): + model_output = "This is a test" + extracted_tool_calls = mistral_tool_parser.extract_tool_calls( + model_output, request=None) # type: ignore[arg-type] + assert not extracted_tool_calls.tools_called + assert extracted_tool_calls.tool_calls == [] + assert extracted_tool_calls.content == model_output + + +@pytest.mark.parametrize( + ids=[ + "single_tool_add", "single_tool_weather", "argument_before_name", + "argument_before_name_and_name_in_argument" + ], + argnames=["model_output", "expected_tool_calls", "expected_content"], + argvalues=[ + ( + '''[TOOL_CALLS][{"name": "add", "arguments":{"a": 3.5, "b": 4}}]''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="add", + arguments=json.dumps({ + "a": 3.5, + "b": 4 + }))) + ], + None), + ( + '''[TOOL_CALLS] [{"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}}]''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="get_current_weather", + arguments=json.dumps( + { + "city": "San Francisco", + "state": "CA", + "unit": "celsius" + }))) + ], + None), + ( + '''[TOOL_CALLS] [{"arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="get_current_weather", + arguments=json.dumps( + { + "city": "San Francisco", + "state": "CA", + "unit": "celsius" + }))) + ], + None), + ( + '''[TOOL_CALLS] [{"arguments":{"name": "John Doe"}, "name": "get_age"}]''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="get_age", + arguments=json.dumps({ + "name": + "John Doe", + }))) + ], + None), + ], +) +def 
test_extract_tool_calls(mistral_tool_parser, model_output, + expected_tool_calls, expected_content): + extracted_tool_calls = mistral_tool_parser.extract_tool_calls( + model_output, request=None) # type: ignore[arg-type] + assert extracted_tool_calls.tools_called + + assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls) + + assert extracted_tool_calls.content == expected_content + + +@pytest.mark.parametrize( + ids=[ + "no_tools", + "single_tool_add", + "single_tool_add_strings", + "single_tool_weather", + "argument_before_name", + "argument_before_name_and_name_in_argument", + "multiple_tools", + ], + argnames=["model_output", "expected_tool_calls", "expected_content"], + argvalues=[ + ('''This is a test''', [], '''This is a test'''), + ( + '''[TOOL_CALLS] [ {"name":"add" , "arguments" : {"a": 3, "b": 4} } ]''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="add", + arguments=json.dumps({ + "a": 3, + "b": 4 + }))) + ], + ""), + ( + '''[TOOL_CALLS] [{"name": "add", "arguments":{"a": "3", "b": "4"}}]''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="add", + arguments=json.dumps({ + "a": "3", + "b": "4" + }))) + ], + ""), + ( + '''[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}}]''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="get_current_weather", + arguments=json.dumps( + { + "city": "San Francisco", + "state": "CA", + "unit": "celsius" + }))) + ], + ""), + ( + '''[TOOL_CALLS] [{"arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="get_current_weather", + arguments=json.dumps( + { + "city": "San Francisco", + "state": "CA", + "unit": "celsius" + }))) + ], + ''), + ( + '''[TOOL_CALLS] [{"arguments": {"name": "John Doe"}, "name": "get_age"}]''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="get_age", + arguments=json.dumps({ + "name": + "John Doe", + }))) + ], + ''), + ( + '''[TOOL_CALLS][{"name": "add", "arguments": {"a": 3.5, "b": 4}}, {"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}]''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="add", + arguments=json.dumps({ + "a": 3.5, + "b": 4 + }))), + ToolCall(function=FunctionCall(name="get_current_weather", + arguments=json.dumps( + { + "city": "San Francisco", + "state": "CA", + "unit": "celsius" + }))) + ], + ''), + ], +) +def test_extract_tool_calls_streaming(mistral_tool_parser, mistral_tokenizer, + model_output, expected_tool_calls, + expected_content): + other_content: str = '' + function_names: list[str] = [] + function_args_strs: list[str] = [] + tool_call_idx: int = -1 + tool_call_ids: list[Optional[str]] = [] + + for delta_message in stream_delta_message_generator( + mistral_tool_parser, mistral_tokenizer, model_output): + # role should never be streamed from tool parser + assert not delta_message.role + + if delta_message.content: + other_content += delta_message.content + + streamed_tool_calls = delta_message.tool_calls + + if streamed_tool_calls and len(streamed_tool_calls) > 0: + # make sure only one diff is present - correct even for parallel + assert len(streamed_tool_calls) == 1 + tool_call = streamed_tool_calls[0] + + # if a new tool is being called, set up empty arguments + if tool_call.index != tool_call_idx: + tool_call_idx = tool_call.index + function_args_strs.append("") + tool_call_ids.append(None) + + # if a tool call 
ID is streamed, make sure one hasn't been already + if tool_call.id and not tool_call_ids[tool_call.index]: + tool_call_ids[tool_call.index] = tool_call.id + + # if parts of the function start being streamed + if tool_call.function: + # if the function name is defined, set it. it should be streamed + # IN ENTIRETY, exactly one time. + if tool_call.function.name: + assert isinstance(tool_call.function.name, str) + function_names.append(tool_call.function.name) + + if tool_call.function.arguments: + # make sure they're a string and then add them to the list + assert isinstance(tool_call.function.arguments, str) + + function_args_strs[ + tool_call.index] += tool_call.function.arguments + + assert other_content == expected_content + + actual_tool_calls = [ + ToolCall(id=tool_call_id, + function=FunctionCall( + name=function_name, + arguments=partial_json_parser.ensure_json( + function_args_str, Allow.OBJ | Allow.STR))) + for tool_call_id, function_name, function_args_str in zip( + tool_call_ids, function_names, function_args_strs) + ] + assert_tool_calls(actual_tool_calls, expected_tool_calls) From 2dedc6e46db63265240d052a954161c06112d120 Mon Sep 17 00:00:00 2001 From: Jeff Cook Date: Sat, 12 Jul 2025 12:26:47 -0600 Subject: [PATCH 03/12] Update vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py Co-authored-by: Aaron Pham --- vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index ab652b494c5..8a8bed9e89b 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -88,7 +88,8 @@ def __init__(self, tokenizer: AnyTokenizer): # Core streaming state self.raw_tool_calls: str = "" - self.streaming_state: StreamingState = StreamingState.WAITING_FOR_TOOL_START + self.streaming_state: StreamingState = \ + StreamingState.WAITING_FOR_TOOL_START # Tool tracking self.current_tool_id: int = -1 From 5966d3746364d69adbd19ecf63c7c95870af6e2f Mon Sep 17 00:00:00 2001 From: Jeff Cook Date: Fri, 4 Jul 2025 14:43:22 -0600 Subject: [PATCH 04/12] fix: prevent infinite loop in Mistral tool parsing by removing processed tools Co-authored-by: aider (anthropic/claude-sonnet-4-20250514) --- vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index 8a8bed9e89b..018a21dc0b2 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -143,6 +143,11 @@ def _extract_tool_calls_streaming_v11( # Handle multiple tools separated by commas/whitespace if self.current_tool_name_finished and self.current_tool_arguments_finished: if self._should_advance_to_next_v11_tool(): + # Remove the completed tool from raw_tool_calls before resetting state + pattern = r'([a-zA-Z0-9_-]+\{[^}]*\})\s*[,\s]*' + match = re.search(pattern, self.raw_tool_calls) + if match: + self.raw_tool_calls = self.raw_tool_calls[match.end(1):] self._reset_v11_tool_state() logger.debug("v11 streaming: found next tool, resetting state") From 9670aee4a3df2256792b3743c9ec61976eca9bf2 Mon Sep 17 00:00:00 2001 From: Jeff Cook Date: Fri, 4 Jul 2025 14:45:05 -0600 Subject: [PATCH 05/12] refactor: improve JSON parsing for 
Mistral tool calls with robust regex and JSON decoding Co-authored-by: aider (anthropic/claude-sonnet-4-20250514) --- .../tool_parsers/mistral_tool_parser.py | 39 +++++++++++++++---- 1 file changed, 31 insertions(+), 8 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index 018a21dc0b2..ae35fee34e3 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -144,10 +144,9 @@ def _extract_tool_calls_streaming_v11( if self.current_tool_name_finished and self.current_tool_arguments_finished: if self._should_advance_to_next_v11_tool(): # Remove the completed tool from raw_tool_calls before resetting state - pattern = r'([a-zA-Z0-9_-]+\{[^}]*\})\s*[,\s]*' - match = re.search(pattern, self.raw_tool_calls) - if match: - self.raw_tool_calls = self.raw_tool_calls[match.end(1):] + completed_tool_end = self._find_completed_v11_tool_end() + if completed_tool_end > 0: + self.raw_tool_calls = self.raw_tool_calls[completed_tool_end:] self._reset_v11_tool_state() logger.debug("v11 streaming: found next tool, resetting state") @@ -243,10 +242,34 @@ def _extract_tool_calls_streaming_v11( def _should_advance_to_next_v11_tool(self) -> bool: """Check if we should advance to the next tool in V11 format.""" - # Find pattern: completed_tool, next_tool or completed_tool next_tool - pattern = r'([a-zA-Z0-9_-]+\{[^}]*\})\s*[,\s]*([a-zA-Z0-9_-]+.*)' - match = re.search(pattern, self.raw_tool_calls) - return match is not None and len(match.group(2).strip()) > 0 + completed_tool_end = self._find_completed_v11_tool_end() + if completed_tool_end <= 0: + return False + + # Check if there's content after the completed tool that looks like another tool + remaining = self.raw_tool_calls[completed_tool_end:].strip() + if remaining.startswith(','): + remaining = remaining[1:].strip() + + # Look for next tool pattern: function_name{ + return bool(re.match(r'[a-zA-Z0-9_-]+\s*\{', remaining)) + + def _find_completed_v11_tool_end(self) -> int: + """Find the end position of the first completed tool in V11 format using JSON parsing.""" + # Look for function name pattern: name followed by { + brace_match = re.search(r'([a-zA-Z0-9_-]+)\s*(\{)', self.raw_tool_calls) + if not brace_match: + return -1 + + # Try to parse the JSON starting from the opening brace + json_start = brace_match.start(2) + json_part = self.raw_tool_calls[json_start:] + + try: + _, end_idx = self.json_decoder.raw_decode(json_part) + return json_start + end_idx + except json.JSONDecodeError: + return -1 def _reset_v11_tool_state(self) -> None: """Reset V11 tool parsing state for the next tool.""" From f280821df3800a2ab2210f6bb486e3b0bf9750d7 Mon Sep 17 00:00:00 2001 From: Jeff Cook Date: Fri, 4 Jul 2025 14:46:35 -0600 Subject: [PATCH 06/12] refactor: Improve quote normalization in tool call parsing to prevent JSON corruption Co-authored-by: aider (anthropic/claude-sonnet-4-20250514) --- .../tool_parsers/mistral_tool_parser.py | 42 ++++++++++++++++++- 1 file changed, 40 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index ae35fee34e3..b6d0f45fcda 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -353,11 +353,11 @@ def _process_delta_text(self, delta_text: str) -> str: if parts[0]: # 
Content before bot token additional_content = parts[0] # Process content after bot token - tool_content = parts[1].replace("'", '"').lstrip() + tool_content = self._normalize_quotes(parts[1]).lstrip() self.raw_tool_calls += tool_content else: # No bot token in delta, just clean and append - cleaned_delta = delta_text.replace("'", '"') + cleaned_delta = self._normalize_quotes(delta_text) self.raw_tool_calls += cleaned_delta # Remove leading spaces only if we have content if self.raw_tool_calls: @@ -365,6 +365,44 @@ def _process_delta_text(self, delta_text: str) -> str: return additional_content + def _normalize_quotes(self, text: str) -> str: + """ + Normalize quotes in tool call text, being careful not to corrupt string values. + + This method attempts to replace structural single quotes with double quotes + while preserving single quotes that are part of string values. + + Args: + text: The text to normalize + + Returns: + Text with normalized quotes + """ + if not text or "'" not in text: + return text + + # For V11 format (function_name{...}), we don't need quote normalization + # as the JSON should already be properly formatted + if self.v11_tool_format: + return text + + # Simple heuristic: if the text looks like it might be valid JSON already, + # try to parse it first before doing any replacements + stripped = text.strip() + if stripped.startswith(('[', '{')): + try: + # Try parsing as-is first + json.loads(stripped) + return text # Already valid JSON, no changes needed + except json.JSONDecodeError: + pass + + # Fallback to the original behavior with a warning + # This preserves backward compatibility but logs the risk + logger.warning("Performing quote normalization on tool call text. " + "This may corrupt string values containing single quotes.") + return text.replace("'", '"') + def _should_detect_v11_format(self) -> bool: """Check if we should attempt V11 format detection.""" return (self.fn_name_regex is not None and self.current_tool_id == -1 From ed3dc1dc47404f7f038df6b66f3d27c2a515c81f Mon Sep 17 00:00:00 2001 From: Jeff Cook Date: Fri, 4 Jul 2025 14:48:46 -0600 Subject: [PATCH 07/12] refactor: remove quote normalization from Mistral tool parser Co-authored-by: aider (anthropic/claude-sonnet-4-20250514) --- .../tool_parsers/mistral_tool_parser.py | 42 +------------------ 1 file changed, 2 insertions(+), 40 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index b6d0f45fcda..3fc65589704 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -353,55 +353,17 @@ def _process_delta_text(self, delta_text: str) -> str: if parts[0]: # Content before bot token additional_content = parts[0] # Process content after bot token - tool_content = self._normalize_quotes(parts[1]).lstrip() + tool_content = parts[1].lstrip() self.raw_tool_calls += tool_content else: # No bot token in delta, just clean and append - cleaned_delta = self._normalize_quotes(delta_text) - self.raw_tool_calls += cleaned_delta + self.raw_tool_calls += delta_text # Remove leading spaces only if we have content if self.raw_tool_calls: self.raw_tool_calls = self.raw_tool_calls.lstrip() return additional_content - def _normalize_quotes(self, text: str) -> str: - """ - Normalize quotes in tool call text, being careful not to corrupt string values. 
- - This method attempts to replace structural single quotes with double quotes - while preserving single quotes that are part of string values. - - Args: - text: The text to normalize - - Returns: - Text with normalized quotes - """ - if not text or "'" not in text: - return text - - # For V11 format (function_name{...}), we don't need quote normalization - # as the JSON should already be properly formatted - if self.v11_tool_format: - return text - - # Simple heuristic: if the text looks like it might be valid JSON already, - # try to parse it first before doing any replacements - stripped = text.strip() - if stripped.startswith(('[', '{')): - try: - # Try parsing as-is first - json.loads(stripped) - return text # Already valid JSON, no changes needed - except json.JSONDecodeError: - pass - - # Fallback to the original behavior with a warning - # This preserves backward compatibility but logs the risk - logger.warning("Performing quote normalization on tool call text. " - "This may corrupt string values containing single quotes.") - return text.replace("'", '"') def _should_detect_v11_format(self) -> bool: """Check if we should attempt V11 format detection.""" From d0895541bbc090c28d28cc1fa8e77a84ad0d6a06 Mon Sep 17 00:00:00 2001 From: Jeff Cook Date: Fri, 4 Jul 2025 14:50:24 -0600 Subject: [PATCH 08/12] refactor: optimize tool call parsing by removing substring operations and using offset-based parsing Co-authored-by: aider (anthropic/claude-sonnet-4-20250514) --- .../tool_parsers/mistral_tool_parser.py | 64 ++++++++----------- 1 file changed, 26 insertions(+), 38 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index 3fc65589704..17c8633345a 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -278,35 +278,30 @@ def _reset_v11_tool_state(self) -> None: self.current_tool_name_sent = False self.prev_args_sent = "" - def _determine_next_parsing_element( - self, - raw_current_tool_call: str) -> Literal["name", "arguments"] | None: + def _determine_next_parsing_element(self) -> Literal["name", "arguments"] | None: """ Determine the next element to parse based on current state. 
- Args: - raw_current_tool_call: The current tool call text - Returns: The next element to parse, or None if nothing is ready """ # Check for name attribute if not self.current_tool_name_finished: match_name = self.tool_call_first_attribute_name.match( - raw_current_tool_call) + self.raw_tool_calls, self.current_tool_start_index) if match_name and match_name.end( - ) > self.previous_attribute_end_index: - self.current_attribute_start_index = match_name.end() + ) > self.current_tool_start_index + self.previous_attribute_end_index: + self.current_attribute_start_index = match_name.end() - self.current_tool_start_index return "name" # Check for arguments attribute if not self.current_tool_arguments_finished: match_arguments = self.tool_call_first_attribute_arguments.match( - raw_current_tool_call) + self.raw_tool_calls, self.current_tool_start_index) if match_arguments and match_arguments.end( - ) > self.previous_attribute_end_index: + ) > self.current_tool_start_index + self.previous_attribute_end_index: # The `{` is the last character in the match - we want it as start index - self.current_attribute_start_index = match_arguments.end() - 1 + self.current_attribute_start_index = match_arguments.end() - 1 - self.current_tool_start_index return "arguments" return None @@ -402,49 +397,44 @@ def _try_parse_json_cached(self, text: str) -> tuple[bool, int]: return result def _extracted_complete_name( - self, raw_current_tool_call: str, - current_attribute_start_index: int) -> tuple[str, int | None]: + self, current_attribute_start_index: int) -> tuple[str, int | None]: """ Extract the complete function name from the current tool call. Args: - raw_current_tool_call: The raw JSON string of the current tool call current_attribute_start_index: The starting index of the - name attribute in the raw_current_tool_call string + name attribute relative to the current tool start Returns: tuple: - The function name, or "" if extraction failed - - The end index of the name in raw_current_tool_call, + - The end index of the name relative to the current tool start, or None if extraction failed """ - partial_name_value = raw_current_tool_call[ - current_attribute_start_index:] - if match := self.string_value_pattern.match(partial_name_value): - return match.group(1), match.end() + current_attribute_start_index + absolute_start = self.current_tool_start_index + current_attribute_start_index + if match := self.string_value_pattern.match(self.raw_tool_calls, absolute_start): + return match.group(1), match.end() - self.current_tool_start_index return "", None - def _extract_argument_fragment(self, raw_current_tool_call: str, - current_attribute_start_index: int, + def _extract_argument_fragment(self, current_attribute_start_index: int, delta: str) -> tuple[str, int]: """ Extract the relevant argument fragment from the current streaming delta. 
Args: - raw_current_tool_call: The raw JSON string of the current tool call current_attribute_start_index: The starting index - of the arguments attribute in the raw string + of the arguments attribute relative to the current tool start delta: The new text added in this streaming step Returns: tuple: - The extracted argument diff text to be sent in the streaming response - - The end index of the arguments in the raw string, + - The end index of the arguments relative to the current tool start, or -1 if not yet complete """ - partial_arguments_value = raw_current_tool_call[ - current_attribute_start_index:] + absolute_start = self.current_tool_start_index + current_attribute_start_index + partial_arguments_value = self.raw_tool_calls[absolute_start:] try: _, end_index = self.json_decoder.raw_decode( partial_arguments_value) @@ -472,11 +462,13 @@ def _next_tool_starting_position(self) -> int: or -1 if no next tool is found yet """ assert self.current_tool_start_index >= 0 - current_tool_call = self.raw_tool_calls[self.current_tool_start_index:] try: - _, end_index = self.json_decoder.raw_decode(current_tool_call) - return (self.current_tool_start_index + end_index + - current_tool_call[end_index:].find("{")) + _, end_index = self.json_decoder.raw_decode( + self.raw_tool_calls, self.current_tool_start_index) + # Look for the next opening brace after the current tool ends + search_start = self.current_tool_start_index + end_index + next_brace = self.raw_tool_calls.find("{", search_start) + return next_brace if next_brace != -1 else -1 except json.decoder.JSONDecodeError: # The current tool object is not yet closed return -1 @@ -647,13 +639,10 @@ def extract_tool_calls_streaming( if self.current_tool_start_index >= len(self.raw_tool_calls): # tool call has not started return self._none_or_additional_content(additional_content) - raw_current_tool_call = self.raw_tool_calls[self. 
- current_tool_start_index:] # Determine what to parse next if self.current_element_streaming is None: - next_element = self._determine_next_parsing_element( - raw_current_tool_call) + next_element = self._determine_next_parsing_element() if next_element is None: return self._none_or_additional_content(additional_content) self.current_element_streaming = next_element @@ -661,7 +650,7 @@ def extract_tool_calls_streaming( if self.current_element_streaming == "name": try: function_name, name_end_index = self._extracted_complete_name( - raw_current_tool_call, self.current_attribute_start_index) + self.current_attribute_start_index) except IndexError: # name value has not started being generated return self._none_or_additional_content(additional_content) @@ -692,7 +681,6 @@ def extract_tool_calls_streaming( if self.current_element_streaming == "arguments": try: diff, arguments_end_index = self._extract_argument_fragment( - raw_current_tool_call, self.current_attribute_start_index, delta_text, ) From dee4d4334ee90ccc030aa3160bd8ec3aa260a750 Mon Sep 17 00:00:00 2001 From: Jeff Cook Date: Fri, 4 Jul 2025 14:52:04 -0600 Subject: [PATCH 09/12] refactor: Replace `X | Y` union syntax with `Union` for Python 3.9 compatibility Co-authored-by: aider (anthropic/claude-sonnet-4-20250514) --- .../openai/tool_parsers/mistral_tool_parser.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index 17c8633345a..ec05e671835 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -6,7 +6,7 @@ from enum import Enum from random import choices from string import ascii_letters, digits -from typing import Literal +from typing import Literal, Union import regex as re from pydantic import Field @@ -98,8 +98,8 @@ def __init__(self, tokenizer: AnyTokenizer): self.previous_attribute_end_index: int = 0 # Legacy state tracking (kept for compatibility) - self.current_element_streaming: Literal["name", - "arguments"] | None = None + self.current_element_streaming: Union[Literal["name", + "arguments"], None] = None self.current_tool_name_finished: bool = False self.current_tool_arguments_finished: bool = False self.tools_parsing_finished: bool = False @@ -128,7 +128,7 @@ def __init__(self, tokenizer: AnyTokenizer): def _extract_tool_calls_streaming_v11( self, additional_content: str, - delta_text: str) -> DeltaMessage | None: + delta_text: str) -> Union[DeltaMessage, None]: """ Extract tool calls from a streaming response, specifically for the v11 MistralTokenizer format: ToolName{arguments}. This logic is a @@ -278,7 +278,7 @@ def _reset_v11_tool_state(self) -> None: self.current_tool_name_sent = False self.prev_args_sent = "" - def _determine_next_parsing_element(self) -> Literal["name", "arguments"] | None: + def _determine_next_parsing_element(self) -> Union[Literal["name", "arguments"], None]: """ Determine the next element to parse based on current state. @@ -397,7 +397,7 @@ def _try_parse_json_cached(self, text: str) -> tuple[bool, int]: return result def _extracted_complete_name( - self, current_attribute_start_index: int) -> tuple[str, int | None]: + self, current_attribute_start_index: int) -> tuple[str, Union[int, None]]: """ Extract the complete function name from the current tool call. 
@@ -478,7 +478,7 @@ def _next_tool_starting_position(self) -> int: return -1 def _none_or_additional_content( - self, additional_content: str) -> DeltaMessage | None: + self, additional_content: str) -> Union[DeltaMessage, None]: """ Create a DeltaMessage with additional content if present, otherwise return None. @@ -590,7 +590,7 @@ def extract_tool_calls_streaming( current_token_ids: Sequence[int], delta_token_ids: Sequence[int], request: ChatCompletionRequest, - ) -> DeltaMessage | None: + ) -> Union[DeltaMessage, None]: # Early return if no tool call token present if self.bot_token not in current_text: From b521f50f05a1c5f2f0c431dcefae1d419531f68f Mon Sep 17 00:00:00 2001 From: Jeff Cook Date: Fri, 4 Jul 2025 14:55:27 -0600 Subject: [PATCH 10/12] feat: add comprehensive tests for Mistral v11 tool format Co-authored-by: aider (claude-opus-4-20250514) --- tests/tool_use/test_mistral_tool_parser.py | 175 ++++++++++++++++++++- 1 file changed, 174 insertions(+), 1 deletion(-) diff --git a/tests/tool_use/test_mistral_tool_parser.py b/tests/tool_use/test_mistral_tool_parser.py index b9bbdef57d9..4b5a19be2eb 100644 --- a/tests/tool_use/test_mistral_tool_parser.py +++ b/tests/tool_use/test_mistral_tool_parser.py @@ -14,7 +14,7 @@ from vllm.transformers_utils.detokenizer import detokenize_incrementally from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer -MODEL = "mistralai/Mistral-7B-Instruct-v0.3" +MODEL = "jeffcookio/Mistral-Small-3.2-24B-Instruct-2506-awq-sym" @pytest.fixture(scope="module") @@ -171,6 +171,14 @@ def test_extract_tool_calls(mistral_tool_parser, model_output, "argument_before_name", "argument_before_name_and_name_in_argument", "multiple_tools", + "v11_single_tool", + "v11_multiple_tools", + "v11_nested_json", + "v11_special_chars", + "v11_empty_args", + "v11_complex_nested", + "v11_with_comma_separator", + "v11_with_whitespace_separator", ], argnames=["model_output", "expected_tool_calls", "expected_content"], argvalues=[ @@ -246,6 +254,97 @@ def test_extract_tool_calls(mistral_tool_parser, model_output, }))) ], ''), + # V11 format tests + ( + '''[TOOL_CALLS] add{"a": 3, "b": 4}''', + [ + ToolCall(function=FunctionCall(name="add", + arguments=json.dumps({ + "a": 3, + "b": 4 + }))) + ], + ""), + ( + '''[TOOL_CALLS] add{"a": 3, "b": 4}, get_weather{"city": "Paris", "unit": "celsius"}''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="add", + arguments=json.dumps({ + "a": 3, + "b": 4 + }))), + ToolCall(function=FunctionCall(name="get_weather", + arguments=json.dumps({ + "city": "Paris", + "unit": "celsius" + }))) + ], + ""), + ( + '''[TOOL_CALLS] process_data{"input": {"nested": {"value": 42, "array": [1, 2, 3]}, "flag": true}}''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="process_data", + arguments=json.dumps({ + "input": { + "nested": { + "value": 42, + "array": [1, 2, 3] + }, + "flag": True + } + }))) + ], + ""), + ( + '''[TOOL_CALLS] send_message{"text": "Hello, it's a nice day!", "recipient": "user@example.com"}''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="send_message", + arguments=json.dumps({ + "text": "Hello, it's a nice day!", + "recipient": "user@example.com" + }))) + ], + ""), + ( + '''[TOOL_CALLS] empty_function{}''', + [ + ToolCall(function=FunctionCall(name="empty_function", + arguments=json.dumps({}))) + ], + ""), + ( + '''[TOOL_CALLS] complex_tool{"data": {"items": [{"id": 1, "props": {"key": "value"}}, {"id": 2, "props": {"key": "other"}}], "meta": {"count": 2}}}''', # noqa: E501 + [ + 
ToolCall(function=FunctionCall(name="complex_tool", + arguments=json.dumps({ + "data": { + "items": [ + {"id": 1, "props": {"key": "value"}}, + {"id": 2, "props": {"key": "other"}} + ], + "meta": {"count": 2} + } + }))) + ], + ""), + ( + '''[TOOL_CALLS] first_tool{"x": 1}, second_tool{"y": 2}''', + [ + ToolCall(function=FunctionCall(name="first_tool", + arguments=json.dumps({"x": 1}))), + ToolCall(function=FunctionCall(name="second_tool", + arguments=json.dumps({"y": 2}))) + ], + ""), + ( + '''[TOOL_CALLS] tool_a{"param": "A"} tool_b{"param": "B"}''', + [ + ToolCall(function=FunctionCall(name="tool_a", + arguments=json.dumps({"param": "A"}))), + ToolCall(function=FunctionCall(name="tool_b", + arguments=json.dumps({"param": "B"}))) + ], + ""), ], ) def test_extract_tool_calls_streaming(mistral_tool_parser, mistral_tokenizer, @@ -309,3 +408,77 @@ def test_extract_tool_calls_streaming(mistral_tool_parser, mistral_tokenizer, tool_call_ids, function_names, function_args_strs) ] assert_tool_calls(actual_tool_calls, expected_tool_calls) + + +@pytest.mark.parametrize( + ids=[ + "v11_single_tool", + "v11_multiple_tools_comma", + "v11_nested_with_quotes", + "v11_escaped_chars", + "v11_mixed_content", + ], + argnames=["model_output", "expected_tool_calls", "expected_content"], + argvalues=[ + ( + '''[TOOL_CALLS] calculate_sum{"numbers": [1, 2, 3, 4, 5]}''', + [ + ToolCall(function=FunctionCall(name="calculate_sum", + arguments=json.dumps({ + "numbers": [1, 2, 3, 4, 5] + }))) + ], + None), + ( + '''[TOOL_CALLS] get_user{"id": 123}, update_profile{"name": "John", "age": 30}''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="get_user", + arguments=json.dumps({"id": 123}))), + ToolCall(function=FunctionCall(name="update_profile", + arguments=json.dumps({ + "name": "John", + "age": 30 + }))) + ], + None), + ( + '''[TOOL_CALLS] parse_json{"content": "{\\"key\\": \\"value\\", \\"nested\\": {\\"item\\": 1}}"}''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="parse_json", + arguments=json.dumps({ + "content": "{\"key\": \"value\", \"nested\": {\"item\": 1}}" + }))) + ], + None), + ( + '''[TOOL_CALLS] format_text{"template": "Hello {name}\\nWelcome!", "vars": {"name": "User"}}''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="format_text", + arguments=json.dumps({ + "template": "Hello {name}\nWelcome!", + "vars": {"name": "User"} + }))) + ], + None), + ( + '''Some content before [TOOL_CALLS] analyze_data{"dataset": "sales_2024", "metrics": ["revenue", "growth"]}''', # noqa: E501 + [ + ToolCall(function=FunctionCall(name="analyze_data", + arguments=json.dumps({ + "dataset": "sales_2024", + "metrics": ["revenue", "growth"] + }))) + ], + "Some content before "), + ], +) +def test_extract_tool_calls_v11_format(mistral_tool_parser, model_output, + expected_tool_calls, expected_content): + """Test extraction of tool calls in v11 format (non-streaming)""" + extracted_tool_calls = mistral_tool_parser.extract_tool_calls( + model_output, request=None) # type: ignore[arg-type] + assert extracted_tool_calls.tools_called + + assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls) + + assert extracted_tool_calls.content == expected_content From 2966baad912b7dc7e0d2fb2f1abaf5d6881041e4 Mon Sep 17 00:00:00 2001 From: Jeff Cook Date: Sun, 13 Jul 2025 15:43:28 -0600 Subject: [PATCH 11/12] ruff/yapf --- .../tool_parsers/mistral_tool_parser.py | 80 ++++++++++++------- 1 file changed, 49 insertions(+), 31 deletions(-) diff --git 
a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index ec05e671835..09606f28e12 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -89,7 +89,7 @@ def __init__(self, tokenizer: AnyTokenizer): # Core streaming state self.raw_tool_calls: str = "" self.streaming_state: StreamingState = \ - StreamingState.WAITING_FOR_TOOL_START + StreamingState.WAITING_FOR_TOOL_START # Tool tracking self.current_tool_id: int = -1 @@ -98,8 +98,8 @@ def __init__(self, tokenizer: AnyTokenizer): self.previous_attribute_end_index: int = 0 # Legacy state tracking (kept for compatibility) - self.current_element_streaming: Union[Literal["name", - "arguments"], None] = None + self.current_element_streaming: Union[Literal["name", "arguments"], + None] = None self.current_tool_name_finished: bool = False self.current_tool_arguments_finished: bool = False self.tools_parsing_finished: bool = False @@ -141,14 +141,16 @@ def _extract_tool_calls_streaming_v11( logger.debug("v11 streaming: prev_args_sent='%s'", self.prev_args_sent) # Handle multiple tools separated by commas/whitespace - if self.current_tool_name_finished and self.current_tool_arguments_finished: - if self._should_advance_to_next_v11_tool(): - # Remove the completed tool from raw_tool_calls before resetting state - completed_tool_end = self._find_completed_v11_tool_end() - if completed_tool_end > 0: - self.raw_tool_calls = self.raw_tool_calls[completed_tool_end:] - self._reset_v11_tool_state() - logger.debug("v11 streaming: found next tool, resetting state") + if self.current_tool_name_finished \ + and self.current_tool_arguments_finished \ + and self._should_advance_to_next_v11_tool(): + # Remove the completed tool from raw_tool_calls + # before resetting state + completed_tool_end = self._find_completed_v11_tool_end() + if completed_tool_end > 0: + self.raw_tool_calls = self.raw_tool_calls[completed_tool_end:] + self._reset_v11_tool_state() + logger.debug("v11 streaming: found next tool, resetting state") # Phase 1: Extract and send function name if not self.current_tool_name_sent: @@ -186,7 +188,8 @@ def _extract_tool_calls_streaming_v11( ) # Phase 2: Extract and send argument fragments - if self.current_tool_name_sent and not self.current_tool_arguments_finished: + if self.current_tool_name_sent and \ + not self.current_tool_arguments_finished: # Find the arguments part (everything after the first {) brace_index = self.raw_tool_calls.find("{") if brace_index == -1: @@ -245,26 +248,31 @@ def _should_advance_to_next_v11_tool(self) -> bool: completed_tool_end = self._find_completed_v11_tool_end() if completed_tool_end <= 0: return False - - # Check if there's content after the completed tool that looks like another tool + + # Check if there's content after the completed tool + # that looks like another tool remaining = self.raw_tool_calls[completed_tool_end:].strip() if remaining.startswith(','): remaining = remaining[1:].strip() - + # Look for next tool pattern: function_name{ return bool(re.match(r'[a-zA-Z0-9_-]+\s*\{', remaining)) def _find_completed_v11_tool_end(self) -> int: - """Find the end position of the first completed tool in V11 format using JSON parsing.""" + """ + Find the end position of the first completed tool in V11 format using + JSON parsing. 
+ """ # Look for function name pattern: name followed by { - brace_match = re.search(r'([a-zA-Z0-9_-]+)\s*(\{)', self.raw_tool_calls) + brace_match = re.search(r'([a-zA-Z0-9_-]+)\s*(\{)', + self.raw_tool_calls) if not brace_match: return -1 - + # Try to parse the JSON starting from the opening brace json_start = brace_match.start(2) json_part = self.raw_tool_calls[json_start:] - + try: _, end_idx = self.json_decoder.raw_decode(json_part) return json_start + end_idx @@ -278,7 +286,8 @@ def _reset_v11_tool_state(self) -> None: self.current_tool_name_sent = False self.prev_args_sent = "" - def _determine_next_parsing_element(self) -> Union[Literal["name", "arguments"], None]: + def _determine_next_parsing_element(self) \ + -> Union[Literal["name", "arguments"], None]: """ Determine the next element to parse based on current state. @@ -290,8 +299,10 @@ def _determine_next_parsing_element(self) -> Union[Literal["name", "arguments"], match_name = self.tool_call_first_attribute_name.match( self.raw_tool_calls, self.current_tool_start_index) if match_name and match_name.end( - ) > self.current_tool_start_index + self.previous_attribute_end_index: - self.current_attribute_start_index = match_name.end() - self.current_tool_start_index + ) > self.current_tool_start_index \ + + self.previous_attribute_end_index: + self.current_attribute_start_index = match_name.end() \ + - self.current_tool_start_index return "name" # Check for arguments attribute @@ -299,9 +310,12 @@ def _determine_next_parsing_element(self) -> Union[Literal["name", "arguments"], match_arguments = self.tool_call_first_attribute_arguments.match( self.raw_tool_calls, self.current_tool_start_index) if match_arguments and match_arguments.end( - ) > self.current_tool_start_index + self.previous_attribute_end_index: - # The `{` is the last character in the match - we want it as start index - self.current_attribute_start_index = match_arguments.end() - 1 - self.current_tool_start_index + ) > self.current_tool_start_index \ + + self.previous_attribute_end_index: + # The `{` is the last character in the match. + # We want it as start index. + self.current_attribute_start_index = match_arguments.end() \ + - 1 - self.current_tool_start_index return "arguments" return None @@ -331,7 +345,8 @@ def _advance_to_next_tool(self) -> bool: def _process_delta_text(self, delta_text: str) -> str: """ - Process delta text and update raw_tool_calls, returning any additional content. + Process delta text and update raw_tool_calls, returning any additional + content. Args: delta_text: The new text delta to process @@ -359,7 +374,6 @@ def _process_delta_text(self, delta_text: str) -> str: return additional_content - def _should_detect_v11_format(self) -> bool: """Check if we should attempt V11 format detection.""" return (self.fn_name_regex is not None and self.current_tool_id == -1 @@ -397,7 +411,8 @@ def _try_parse_json_cached(self, text: str) -> tuple[bool, int]: return result def _extracted_complete_name( - self, current_attribute_start_index: int) -> tuple[str, Union[int, None]]: + self, current_attribute_start_index: int) \ + -> tuple[str, Union[int, None]]: """ Extract the complete function name from the current tool call. 
@@ -411,8 +426,10 @@ def _extracted_complete_name( - The end index of the name relative to the current tool start, or None if extraction failed """ - absolute_start = self.current_tool_start_index + current_attribute_start_index - if match := self.string_value_pattern.match(self.raw_tool_calls, absolute_start): + absolute_start = self.current_tool_start_index \ + + current_attribute_start_index + if match := self.string_value_pattern.match(\ + self.raw_tool_calls, absolute_start): return match.group(1), match.end() - self.current_tool_start_index return "", None @@ -433,7 +450,8 @@ def _extract_argument_fragment(self, current_attribute_start_index: int, - The end index of the arguments relative to the current tool start, or -1 if not yet complete """ - absolute_start = self.current_tool_start_index + current_attribute_start_index + absolute_start = self.current_tool_start_index \ + + current_attribute_start_index partial_arguments_value = self.raw_tool_calls[absolute_start:] try: _, end_index = self.json_decoder.raw_decode( From ef4d46c75d15254c11ed8801e37bcef13b375971 Mon Sep 17 00:00:00 2001 From: Jeff Cook Date: Sun, 13 Jul 2025 16:51:57 -0600 Subject: [PATCH 12/12] Via Grok4: attempt to fix non-streaming and multiple tool calls --- .../tool_parsers/mistral_tool_parser.py | 263 +++++++++++------- 1 file changed, 157 insertions(+), 106 deletions(-) diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index 09606f28e12..ff56868f6da 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -117,7 +117,7 @@ def __init__(self, tokenizer: AnyTokenizer): self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL) if _is_fn_name_regex_support(self.model_tokenizer): self.fn_name_regex = re.compile( - r'([a-zA-Z0-9_-]+)(\{[\s\S]*?\})(?=\s*$|,|\s)', re.DOTALL) + r'([a-zA-Z0-9_-]+)\s*(\{[\s\S]*?\})', re.DOTALL) else: self.fn_name_regex = None @@ -140,106 +140,120 @@ def _extract_tool_calls_streaming_v11( self.current_tool_name_sent) logger.debug("v11 streaming: prev_args_sent='%s'", self.prev_args_sent) - # Handle multiple tools separated by commas/whitespace - if self.current_tool_name_finished \ - and self.current_tool_arguments_finished \ - and self._should_advance_to_next_v11_tool(): - # Remove the completed tool from raw_tool_calls - # before resetting state - completed_tool_end = self._find_completed_v11_tool_end() - if completed_tool_end > 0: - self.raw_tool_calls = self.raw_tool_calls[completed_tool_end:] - self._reset_v11_tool_state() - logger.debug("v11 streaming: found next tool, resetting state") - - # Phase 1: Extract and send function name - if not self.current_tool_name_sent: - # Look for function name pattern: name followed by { - brace_index = self.raw_tool_calls.find("{") - if brace_index == -1: - logger.debug("v11 streaming: no opening brace found yet") - return self._none_or_additional_content(additional_content) - - # Extract function name - func_name = self.raw_tool_calls[:brace_index].strip() - # Remove any leading separators from previous tools - func_name = re.sub(r'^[\s,]*', '', func_name) - - if not func_name: - logger.debug("v11 streaming: function name is empty") - return self._none_or_additional_content(additional_content) - - logger.debug("v11 streaming: sending function name='%s'", - func_name) - self.current_tool_name_sent = True - self.current_tool_id += 1 + result_tool_calls: list[DeltaToolCall] = [] + 
+ while True: + advanced = False + if self.current_tool_name_finished and \ + self.current_tool_arguments_finished and \ + self._should_advance_to_next_v11_tool(): + # Remove the completed tool from raw_tool_calls + # before resetting state + completed_tool_end = self._find_completed_v11_tool_end() + if completed_tool_end > 0: + self.raw_tool_calls = self.raw_tool_calls[ + completed_tool_end:] + self._reset_v11_tool_state() + logger.debug("v11 streaming: found next tool, resetting state") + advanced = True + + sent_something = False + + # Phase 1: Extract and send function name + if not self.current_tool_name_sent: + # Look for function name pattern: name followed by { + brace_index = self.raw_tool_calls.find("{") + if brace_index == -1: + logger.debug("v11 streaming: no opening brace found yet") + break + + # Extract function name + func_name = self.raw_tool_calls[:brace_index].strip() + # Remove any leading separators from previous tools + func_name = re.sub(r'^[\s,]*', '', func_name) + + if not func_name: + logger.debug("v11 streaming: function name is empty") + break + + logger.debug("v11 streaming: sending function name='%s'", + func_name) + self.current_tool_name_sent = True + self.current_tool_name_finished = True + self.current_tool_id += 1 - return DeltaMessage( - content=additional_content, - tool_calls=[ + result_tool_calls.append( DeltaToolCall( index=self.current_tool_id, type="function", id=MistralToolCall.generate_random_id(), function=DeltaFunctionCall(name=func_name).model_dump( exclude_none=True), - ) - ], - ) - - # Phase 2: Extract and send argument fragments - if self.current_tool_name_sent and \ - not self.current_tool_arguments_finished: - # Find the arguments part (everything after the first {) - brace_index = self.raw_tool_calls.find("{") - if brace_index == -1: - logger.debug("v11 streaming: no opening brace found for args") - return self._none_or_additional_content(additional_content) - - current_args = self.raw_tool_calls[brace_index:] - logger.debug("v11 streaming: current_args='%s'", current_args) + )) + sent_something = True + + # Phase 2: Extract and send argument fragments + if self.current_tool_name_sent and \ + not self.current_tool_arguments_finished: + # Find the arguments part (everything after the first {) + brace_index = self.raw_tool_calls.find("{") + if brace_index == -1: + logger.debug( + "v11 streaming: no opening brace found for args") + break + + current_args = self.raw_tool_calls[brace_index:] + logger.debug("v11 streaming: current_args='%s'", current_args) + + actual_args = current_args + try: + parsed_obj, end_idx = self.json_decoder.raw_decode( + current_args) + # JSON is complete + self.current_tool_arguments_finished = True + actual_args = current_args[:end_idx] + logger.debug("v11 streaming: JSON complete, parsed_obj=%s", + parsed_obj) + except json.decoder.JSONDecodeError: + # JSON still incomplete + logger.debug("v11 streaming: JSON still incomplete") + pass + + # Calculate what's new since last time + new_content = "" + if actual_args != self.prev_args_sent: + if self.prev_args_sent and actual_args.startswith( + self.prev_args_sent): + # Incremental update + new_content = actual_args[len(self.prev_args_sent):] + logger.debug("v11 streaming: incremental args='%s'", + new_content) + else: + # First time or reset + new_content = actual_args + logger.debug("v11 streaming: first/reset args='%s'", + new_content) + + self.prev_args_sent = actual_args - # Check if JSON is complete - try: - parsed_obj, end_idx = self.json_decoder.raw_decode( - 
current_args)
-                # JSON is complete
-                self.current_tool_arguments_finished = True
-                logger.debug("v11 streaming: JSON complete, parsed_obj=%s",
-                             parsed_obj)
-            except json.decoder.JSONDecodeError:
-                # JSON still incomplete
-                logger.debug("v11 streaming: JSON still incomplete")
-                pass
-
-            # Calculate what's new since last time
-            if current_args != self.prev_args_sent:
-                if self.prev_args_sent and current_args.startswith(
-                        self.prev_args_sent):
-                    # Incremental update
-                    new_content = current_args[len(self.prev_args_sent):]
-                    logger.debug("v11 streaming: incremental args='%s'",
-                                 new_content)
-                else:
-                    # First time or reset
-                    new_content = current_args
-                    logger.debug("v11 streaming: first/reset args='%s'",
-                                 new_content)
+                    self.prev_args_sent = actual_args
 
-                self.prev_args_sent = current_args
+                if new_content:
+                    result_tool_calls.append(
+                        DeltaToolCall(
+                            index=self.current_tool_id,
+                            function=DeltaFunctionCall(
+                                arguments=new_content).model_dump(
+                                    exclude_none=True),
+                        ))
+                    sent_something = True
 
-            if new_content:
-                return DeltaMessage(
-                    content=additional_content,
-                    tool_calls=[
-                        DeltaToolCall(
-                            index=self.current_tool_id,
-                            function=DeltaFunctionCall(
-                                arguments=new_content).model_dump(
-                                    exclude_none=True),
-                        )
-                    ],
-                )
+            if not sent_something and not advanced:
+                break
 
+        if result_tool_calls:
+            return DeltaMessage(
+                content=additional_content,
+                tool_calls=result_tool_calls,
+            )
         return self._none_or_additional_content(additional_content)
 
@@ -260,8 +274,8 @@ def _should_advance_to_next_v11_tool(self) -> bool:
 
     def _find_completed_v11_tool_end(self) -> int:
         """
-        Find the end position of the first completed tool in V11 format using
-        JSON parsing.
+        Find the end position of the first completed tool in V11 format
+        using JSON parsing.
         """
         # Look for function name pattern: name followed by {
         brace_match = re.search(r'([a-zA-Z0-9_-]+)\s*(\{)',
@@ -550,19 +564,55 @@ def extract_tool_calls(
             # jsons is difficult
             try:
                 if self.fn_name_regex:
-                    matches = self.fn_name_regex.findall(tool_content)
-
-                    function_call_arr = []
-                    for match in matches:
-                        fn_name = match[0]
-                        args = match[1]
-
-                        # fn_name is encoded outside serialized json dump
-                        # only arguments are serialized
-                        function_call_arr.append({
-                            "name": fn_name,
-                            "arguments": json.loads(args)
-                        })
+                    function_call_arr = []
+                    pos, tool_str = 0, tool_content
+
+                    while pos < len(tool_str):
+                        # skip ws
+                        while pos < len(tool_str) and tool_str[pos].isspace():
+                            pos += 1
+                        if pos >= len(tool_str):
+                            break
+
+                        # match name
+                        match_name = re.match(r'([a-zA-Z0-9_-]+)',
+                                              tool_str[pos:])
+                        if not match_name:
+                            break
+                        fn_name = match_name.group(0)
+                        pos += match_name.end()
+
+                        # skip ws
+                        while pos < len(tool_str) and tool_str[pos].isspace():
+                            pos += 1
+
+                        if pos >= len(tool_str) or tool_str[pos] != '{':
+                            break
+
+                        # pos stays at '{' so raw_decode sees the full object
+
+                        # parse args
+                        try:
+                            args_obj, end_idx = self.json_decoder.raw_decode(
+                                tool_str[pos:])
+                            function_call_arr.append({
+                                "name": fn_name,
+                                "arguments": args_obj
+                            })
+                            pos += end_idx
+                        except json.JSONDecodeError:
+                            break
+
+                        # skip ws
+                        while pos < len(tool_str) and tool_str[pos].isspace():
+                            pos += 1
+
+                        # optional comma
+                        if pos < len(tool_str) and tool_str[pos] == ',':
+                            pos += 1
+                            while pos < len(
+                                    tool_str) and tool_str[pos].isspace():
+                                pos += 1
                 else:
                     function_call_arr = json.loads(tool_content)
             except json.JSONDecodeError:
@@ -570,7 +620,8 @@ def extract_tool_calls(
                 # NOTE: This use case should not happen if the model is trained
                 # correctly. It's a easy possible fix so it's included, but
                 # can be brittle for very complex / highly nested tool calls
-                raw_tool_call = self.tool_call_regex.findall(tool_content)[0]
+                raw_tool_call = self.tool_call_regex.search(
+                    tool_content).group(0)
                 function_call_arr = json.loads(raw_tool_call)
 
             # Tool Call
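
The v11 wire format these patches target is simply `ToolName{json-args}`, optionally repeated with comma or whitespace separators. Below is a minimal standalone sketch of the raw_decode-based scan used in the non-streaming path (illustrative only, not part of the patch; the helper name `parse_v11_tool_calls` is made up here):

    import json
    import re

    def parse_v11_tool_calls(tool_content: str) -> list[tuple[str, dict]]:
        """Split 'name{...}, name{...}' into (name, arguments) pairs."""
        decoder = json.JSONDecoder()
        name_re = re.compile(r'[a-zA-Z0-9_-]+')
        calls = []
        pos = 0
        while pos < len(tool_content):
            # skip whitespace and comma separators between tool calls
            while pos < len(tool_content) and tool_content[pos] in ' \t\r\n,':
                pos += 1
            match = name_re.match(tool_content, pos)
            if not match:
                break
            brace = match.end()
            # allow optional whitespace between the name and its arguments
            while brace < len(tool_content) and tool_content[brace].isspace():
                brace += 1
            if brace >= len(tool_content) or tool_content[brace] != '{':
                break
            # raw_decode starts at the '{' and returns the absolute end index
            args, end = decoder.raw_decode(tool_content, brace)
            calls.append((match.group(0), args))
            pos = end
        return calls

    # parse_v11_tool_calls('add{"a": 3, "b": 4}, get_weather{"city": "Paris"}')
    # -> [('add', {'a': 3, 'b': 4}), ('get_weather', {'city': 'Paris'})]

The streaming path applies the same idea incrementally: raw_decode either succeeds (the current tool's arguments are complete) or raises JSONDecodeError (keep buffering), and only the new suffix of the argument string is emitted as a delta.
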