diff --git a/tests/tool_use/test_mistral_tool_parser.py b/tests/tool_use/test_mistral_tool_parser.py
new file mode 100644
index 000000000000..8628e577ffd0
--- /dev/null
+++ b/tests/tool_use/test_mistral_tool_parser.py
@@ -0,0 +1,528 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import json
+from collections.abc import Generator
+from typing import Optional
+
+import partial_json_parser
+import pytest
+from mistral_common.protocol.instruct.messages import AssistantMessage
+from mistral_common.protocol.instruct.tool_calls import FunctionCall, ToolCall
+from mistral_common.tokens.instruct.request import InstructRequest
+from partial_json_parser.core.options import Allow
+
+from vllm.entrypoints.openai.protocol import DeltaMessage
+from vllm.entrypoints.openai.tool_parsers import MistralToolParser
+from vllm.transformers_utils.detokenizer import detokenize_incrementally
+from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
+                                               get_tokenizer)
+
+
+@pytest.fixture(scope="module")
+def mistral_tokenizer():
+    MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
+    return get_tokenizer(tokenizer_name=MODEL)
+
+
+@pytest.fixture(scope="module")
+def mistral_v11_tokenizer():
+    MODEL = "mistralai/Mistral-Small-3.2-24B-Instruct-2506"
+    return get_tokenizer(tokenizer_name=MODEL, tokenizer_mode="mistral")
+
+
+@pytest.fixture
+def mistral_tool_parser(mistral_tokenizer):
+    return MistralToolParser(mistral_tokenizer)
+
+
+@pytest.fixture
+def mistral_v11_tool_parser(mistral_v11_tokenizer):
+    return MistralToolParser(mistral_v11_tokenizer)
+
+
+def assert_tool_calls(actual_tool_calls: list[ToolCall],
+                      expected_tool_calls: list[ToolCall]):
+    assert len(actual_tool_calls) == len(expected_tool_calls)
+
+    for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
+                                                    expected_tool_calls):
+        assert isinstance(actual_tool_call.id, str)
+        assert len(actual_tool_call.id) == 9
+
+        assert actual_tool_call.type == "function"
+        assert (
+            actual_tool_call.function.name == expected_tool_call.function.name
+        ), f"got wrong function name: {actual_tool_call.function.name}"
+        assert (
+            actual_tool_call.function.arguments ==
+            expected_tool_call.function.arguments
+        ), f"got wrong function argument: {actual_tool_call.function.arguments}"
+
+
+def stream_delta_message_generator(
+    mistral_tool_parser: MistralToolParser,
+    mistral_tokenizer: AnyTokenizer,
+    model_output: Optional[str],
+    tools: Optional[list[tuple[str, str]]],
+) -> Generator[DeltaMessage, None, None]:
+    if isinstance(mistral_tokenizer, MistralTokenizer):
+        assert tools is not None
+        assistant_msg = AssistantMessage(tool_calls=[
+            ToolCall(function=FunctionCall(
+                name=name,
+                arguments=arg,
+            )) for (name, arg) in tools
+        ], )
+        request = InstructRequest(messages=[assistant_msg], )
+        all_token_ids = mistral_tokenizer.instruct.encode_instruct(
+            request).tokens
+    else:
+        assert model_output is not None
+        all_token_ids = mistral_tokenizer.encode(model_output,
+                                                 add_special_tokens=False)
+
+    previous_text = ""
+    previous_tokens = None
+    prefix_offset = 0
+    read_offset = 0
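+    # Replay the generation token by token, detokenizing incrementally the
+    # same way the engine does when streaming responses.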
+    for i, delta_token in enumerate(all_token_ids):
+        delta_token_ids = [delta_token]
+        previous_token_ids = all_token_ids[:i]
+        current_token_ids = all_token_ids[:i + 1]
+
+        (new_tokens, delta_text, new_prefix_offset,
+         new_read_offset) = (detokenize_incrementally(
+             tokenizer=mistral_tokenizer,
+             all_input_ids=current_token_ids,
+             prev_tokens=previous_tokens,
+             prefix_offset=prefix_offset,
+             read_offset=read_offset,
+             skip_special_tokens=isinstance(mistral_tokenizer,
+                                            MistralTokenizer),
+             spaces_between_special_tokens=True,
+         ))
+
+        current_text = previous_text + delta_text
+
+        delta_message = mistral_tool_parser.extract_tool_calls_streaming(
+            previous_text,
+            current_text,
+            delta_text,
+            previous_token_ids,
+            current_token_ids,
+            delta_token_ids,
+            request=None,  # type: ignore[arg-type]
+        )
+        if delta_message:
+            yield delta_message
+
+        previous_text = current_text
+        previous_tokens = (previous_tokens +
+                           new_tokens if previous_tokens else new_tokens)
+        prefix_offset = new_prefix_offset
+        read_offset = new_read_offset
+
+
+def test_extract_tool_calls_no_tools(mistral_tool_parser):
+    model_output = "This is a test"
+    extracted_tool_calls = mistral_tool_parser.extract_tool_calls(
+        model_output, request=None)  # type: ignore[arg-type]
+    assert not extracted_tool_calls.tools_called
+    assert extracted_tool_calls.tool_calls == []
+    assert extracted_tool_calls.content == model_output
+
+
+@pytest.mark.parametrize(
+    ids=[
+        "single_tool_add",
+        "single_tool_weather",
+        "argument_before_name",
+        "argument_before_name_and_name_in_argument",
+    ],
+    argnames=["model_output", "expected_tool_calls", "expected_content"],
+    argvalues=[
+        (
+            """[TOOL_CALLS][{"name": "add", "arguments":{"a": 3.5, "b": 4}}]""",  # noqa: E501
+            [
+                ToolCall(function=FunctionCall(name="add",
+                                               arguments=json.dumps({
+                                                   "a": 3.5,
+                                                   "b": 4
+                                               })))
+            ],
+            None,
+        ),
+        (
+            """[TOOL_CALLS] [{"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""",  # noqa: E501
+            [
+                ToolCall(function=FunctionCall(
+                    name="get_current_weather",
+                    arguments=json.dumps({
+                        "city": "San Francisco",
+                        "state": "CA",
+                        "unit": "celsius"
+                    }),
+                ))
+            ],
+            None,
+        ),
+        (
+            """[TOOL_CALLS] [{"arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""",  # noqa: E501
+            [
+                ToolCall(function=FunctionCall(
+                    name="get_current_weather",
+                    arguments=json.dumps({
+                        "city": "San Francisco",
+                        "state": "CA",
+                        "unit": "celsius"
+                    }),
+                ))
+            ],
+            None,
+        ),
+        (
+            """[TOOL_CALLS] [{"arguments":{"name": "John Doe"}, "name": "get_age"}]""",  # noqa: E501
+            [
+                ToolCall(function=FunctionCall(
+                    name="get_age",
+                    arguments=json.dumps({
+                        "name": "John Doe",
+                    }),
+                ))
+            ],
+            None,
+        ),
+    ],
+)
+def test_extract_tool_calls_pre_v11_tokenizer(mistral_tool_parser,
+                                              model_output,
+                                              expected_tool_calls,
+                                              expected_content):
+    extracted_tool_calls = mistral_tool_parser.extract_tool_calls(
+        model_output, request=None)  # type: ignore[arg-type]
+    assert extracted_tool_calls.tools_called
+
+    assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
+
+    assert extracted_tool_calls.content == expected_content
+
+
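+# v11 tokenizers emit tool calls as `[TOOL_CALLS]name{arguments}` rather than
+# as a JSON array, so the expected outputs below carry no brackets.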
+@pytest.mark.parametrize(
+    ids=[
+        "single_tool_add",
+        "single_tool_weather",
+        "multiple_tool_calls",
+    ],
+    argnames=["model_output", "expected_tool_calls", "expected_content"],
+    argvalues=[
+        (
+            """[TOOL_CALLS]add_this_and_that{"a": 3.5, "b": 4}""",  # noqa: E501
+            [
+                ToolCall(function=FunctionCall(
+                    name="add_this_and_that",
+                    arguments=json.dumps({
+                        "a": 3.5,
+                        "b": 4
+                    }),
+                ))
+            ],
+            None,
+        ),
+        (
+            """[TOOL_CALLS]get_current_weather{"city": "San Francisco", "state": "CA", "unit": "celsius"}""",  # noqa: E501
+            [
+                ToolCall(function=FunctionCall(
+                    name="get_current_weather",
+                    arguments=json.dumps({
+                        "city": "San Francisco",
+                        "state": "CA",
+                        "unit": "celsius"
+                    }),
+                ))
+            ],
+            None,
+        ),
+        (
+            '''[TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]multiply{"a": 3, "b": 6}''',  # noqa: E501
+            [
+                ToolCall(function=FunctionCall(name="add",
+                                               arguments=json.dumps({
+                                                   "a": 3.5,
+                                                   "b": 4
+                                               }))),
+                ToolCall(function=FunctionCall(name="multiply",
+                                               arguments=json.dumps({
+                                                   "a": 3,
+                                                   "b": 6
+                                               })))
+            ],
+            None)
+    ],
+)
+def test_extract_tool_calls(mistral_v11_tool_parser, model_output,
+                            expected_tool_calls, expected_content):
+    extracted_tool_calls = mistral_v11_tool_parser.extract_tool_calls(
+        model_output, request=None)  # type: ignore[arg-type]
+    assert extracted_tool_calls.tools_called
+
+    assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
+
+    assert extracted_tool_calls.content == expected_content
+
+
+def _test_extract_tool_calls_streaming(tool_parser, tokenizer, model_output,
+                                       tools, expected_tool_calls,
+                                       expected_content):
+    other_content: str = ""
+    function_names: list[str] = []
+    function_args_strs: list[str] = []
+    tool_call_idx: int = -1
+    tool_call_ids: list[Optional[str]] = []
+
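+    # Collect the streamed deltas; each one may carry either content or
+    # exactly one tool-call fragment.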
+    for delta_message in stream_delta_message_generator(
+            tool_parser, tokenizer, model_output, tools):
+        # role should never be streamed from tool parser
+        assert not delta_message.role
+
+        if delta_message.content:
+            other_content += delta_message.content
+
+        streamed_tool_calls = delta_message.tool_calls
+
+        if streamed_tool_calls and len(streamed_tool_calls) > 0:
+            # make sure only one diff is present - correct even for parallel
+            assert len(streamed_tool_calls) == 1
+            tool_call = streamed_tool_calls[0]
+
+            # if a new tool is being called, set up empty arguments
+            if tool_call.index != tool_call_idx:
+                tool_call_idx = tool_call.index
+                function_args_strs.append("")
+                function_names.append("")
+                tool_call_ids.append(None)
+
+            # if a tool call ID is streamed, make sure one hasn't already
+            # been set
+            if tool_call.id and not tool_call_ids[tool_call.index]:
+                tool_call_ids[tool_call.index] = tool_call.id
+
+            # if parts of the function start being streamed
+            if tool_call.function:
+                # if the function name is defined, set it. it should be
+                # streamed IN ENTIRETY, exactly one time.
+                if tool_call.function.name:
+                    assert isinstance(tool_call.function.name, str)
+                    function_names[tool_call.index] += tool_call.function.name
+
+                if tool_call.function.arguments:
+                    # make sure they're a string and then add them to the list
+                    assert isinstance(tool_call.function.arguments, str)
+
+                    function_args_strs[
+                        tool_call.index] += tool_call.function.arguments
+
+    assert other_content == expected_content
+
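+    # Reassemble the accumulated fragments; ensure_json completes argument
+    # strings that may still be partial JSON.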
+    actual_tool_calls = [
+        ToolCall(
+            id=tool_call_id,
+            function=FunctionCall(
+                name=function_name,
+                arguments=partial_json_parser.ensure_json(
+                    function_args_str, Allow.OBJ | Allow.STR),
+            ),
+        ) for tool_call_id, function_name, function_args_str in zip(
+            tool_call_ids, function_names, function_args_strs)
+    ]
+    assert_tool_calls(actual_tool_calls, expected_tool_calls)
+
+
+@pytest.mark.parametrize(
+    ids=[
+        "no_tools",
+        "single_tool_add",
+        "single_tool_add_strings",
+        "single_tool_weather",
+        "argument_before_name",
+        "argument_before_name_and_name_in_argument",
+        "multiple_tools",
+    ],
+    argnames=["model_output", "expected_tool_calls", "expected_content"],
+    argvalues=[
+        ("""This is a test""", [], """This is a test"""),
+        (
+            """[TOOL_CALLS] [ {"name":"add" , "arguments" : {"a": 3, "b": 4} } ]""",  # noqa: E501
+            [
+                ToolCall(function=FunctionCall(name="add",
+                                               arguments=json.dumps({
+                                                   "a": 3,
+                                                   "b": 4
+                                               })))
+            ],
+            "",
+        ),
+        (
+            """[TOOL_CALLS] [{"name": "add", "arguments":{"a": "3", "b": "4"}}]""",  # noqa: E501
+            [
+                ToolCall(function=FunctionCall(name="add",
+                                               arguments=json.dumps({
+                                                   "a": "3",
+                                                   "b": "4"
+                                               })))
+            ],
+            "",
+        ),
+        (
+            """[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""",  # noqa: E501
+            [
+                ToolCall(function=FunctionCall(
+                    name="get_current_weather",
+                    arguments=json.dumps({
+                        "city": "San Francisco",
+                        "state": "CA",
+                        "unit": "celsius"
+                    }),
+                ))
+            ],
+            "",
+        ),
+        (
+            """[TOOL_CALLS] [{"arguments": {"city": "San Francisco", "state": "CA", "unit": "celsius"}, "name": "get_current_weather"}]""",  # noqa: E501
+            [
+                ToolCall(function=FunctionCall(
+                    name="get_current_weather",
+                    arguments=json.dumps({
+                        "city": "San Francisco",
+                        "state": "CA",
+                        "unit": "celsius"
+                    }),
+                ))
+            ],
+            "",
+        ),
+        (
+            """[TOOL_CALLS] [{"arguments": {"name": "John Doe"}, "name": "get_age"}]""",  # noqa: E501
+            [
+                ToolCall(function=FunctionCall(
+                    name="get_age",
+                    arguments=json.dumps({
+                        "name": "John Doe",
+                    }),
+                ))
+            ],
+            "",
+        ),
+        (
+            """[TOOL_CALLS][{"name": "add", "arguments": {"a": 3.5, "b": 4}}, {"name": "get_current_weather", "arguments":{"city": "San Francisco", "state": "CA", "unit": "celsius"}}]""",  # noqa: E501
+            [
+                ToolCall(function=FunctionCall(name="add",
+                                               arguments=json.dumps({
+                                                   "a": 3.5,
+                                                   "b": 4
+                                               }))),
+                ToolCall(function=FunctionCall(
+                    name="get_current_weather",
+                    arguments=json.dumps({
+                        "city": "San Francisco",
+                        "state": "CA",
+                        "unit": "celsius"
+                    }),
+                )),
+            ],
+            "",
+        ),
+    ],
+)
+def test_extract_tool_calls_streaming_pre_v11_tokenizer(
+    mistral_tool_parser,
+    mistral_tokenizer,
+    model_output,
+    expected_tool_calls,
+    expected_content,
+):
+    _test_extract_tool_calls_streaming(
+        mistral_tool_parser,
+        mistral_tokenizer,
+        model_output,
+        None,
+        expected_tool_calls,
+        expected_content,
+    )
+
+
+@pytest.mark.parametrize(
+    ids=[
+        "single_tool_add",
+        "single_tool_add_strings",
+        "multiple_tools",
+    ],
+    argnames=["tools", "expected_tool_calls", "expected_content"],
+    argvalues=[
+        (
+            [("add", '{"a": 3, "b": 4}')],
+            # [TOOL_CALLS]add{"a": 3, "b": 4}
+            [
+                ToolCall(function=FunctionCall(name="add",
+                                               arguments=json.dumps({
+                                                   "a": 3,
+                                                   "b": 4
+                                               })))
+            ],
+            "",
+        ),
+        (
+            [("add_two_strings", '{"a": "3", "b": "4"}')],
+            # [TOOL_CALLS]add_two_strings{"a": "3", "b": "4"}
+            [
+                ToolCall(function=FunctionCall(
+                    name="add_two_strings",
+                    arguments=json.dumps({
+                        "a": "3",
+                        "b": "4"
+                    }),
+                ))
+            ],
+            "",
+        ),
+        (
+            [
+                ("add", '{"a": 3.5, "b": 4}'),
+                (
+                    "get_current_weather",
+                    '{"city": "San Francisco", "state": "CA", "unit": "celsius"}',  # noqa: E501
+                ),
+            ],
+            # [TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]get_current_weather{"city": "San Francisco", "state": "CA", "unit": "celsius"}  # noqa: E501
+            [
+                ToolCall(function=FunctionCall(name="add",
+                                               arguments=json.dumps({
+                                                   "a": 3.5,
+                                                   "b": 4
+                                               }))),
+                ToolCall(function=FunctionCall(
+                    name="get_current_weather",
+                    arguments=json.dumps({
+                        "city": "San Francisco",
+                        "state": "CA",
+                        "unit": "celsius"
+                    }),
+                )),
+            ],
+            "",
+        ),
+    ],
+)
+def test_extract_tool_calls_streaming(
+    mistral_v11_tool_parser,
+    mistral_v11_tokenizer,
+    tools,
+    expected_tool_calls,
+    expected_content,
+):
+    _test_extract_tool_calls_streaming(
+        mistral_v11_tool_parser,
+        mistral_v11_tokenizer,
+        None,
+        tools,
+        expected_tool_calls,
+        expected_content,
+    )
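For reference, the streaming contract these tests pin down: for pre-v11 outputs the function name arrives exactly once, in full, together with its generated 9-character id, and the arguments then stream as raw JSON fragments (with v11 tokenizers the name itself may also arrive in several fragments). A minimal sketch of the delta sequence one might see for `[TOOL_CALLS] [{"name": "add", "arguments": {"a": 3, "b": 4}}]`; the exact chunking depends on tokenization, so the values below are illustrative only:

```python
# Hypothetical delta sequence; field names follow DeltaToolCall /
# DeltaFunctionCall from vllm.entrypoints.openai.protocol.
expected_deltas = [
    {"index": 0, "id": "a1b2c3d4e", "type": "function",
     "function": {"name": "add"}},                        # name, sent once
    {"index": 0, "function": {"arguments": '{"a": 3'}},   # argument fragment
    {"index": 0, "function": {"arguments": ', "b": 4}'}}, # rest of the JSON
]
```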
diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
index c0691f122904..626ba7b20973 100644
--- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
+++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py
@@ -5,11 +5,9 @@
 from collections.abc import Sequence
 from random import choices
 from string import ascii_letters, digits
-from typing import Union
+from typing import Literal, Union
 
-import partial_json_parser
 import regex as re
-from partial_json_parser.core.options import Allow
 from pydantic import Field
 
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
@@ -19,8 +17,6 @@
                                               FunctionCall, ToolCall)
 from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
     ToolParser, ToolParserManager)
-from vllm.entrypoints.openai.tool_parsers.utils import (
-    extract_intermediate_diff)
 from vllm.logger import init_logger
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
 
@@ -68,11 +64,33 @@ def __init__(self, tokenizer: AnyTokenizer):
 
         # initialize properties used for state when parsing tool calls in
         # streaming mode
-        self.prev_tool_call_arr: list[dict] = []
+        self.json_decoder: json.JSONDecoder = json.JSONDecoder()
+        self.tool_call_first_attribute_name: re.Pattern[str] = re.compile(
+            r'.*\s*"name"\s*:\s*')
+        self.string_value_pattern: re.Pattern[str] = re.compile(
+            r'\s*"(.*?)(?<!\\)"')

     ) -> Union[DeltaMessage, None]:
-        # if the tool call token is not in the tokens generated so far, append
-        # output to contents since it's not a tool
         if self.bot_token not in current_text:
+            # if the tool call token is not in the tokens generated so far,
+            # append output to contents since it's not a tool
             return DeltaMessage(content=delta_text)
 
         # if the tool call token ID IS in the tokens generated so far, that
         # means we're parsing as tool calls now
+        if _is_fn_name_regex_support(self.model_tokenizer):
+            return self._extract_tool_calls_streaming(
+                previous_text=previous_text,
+                current_text=current_text,
+                delta_text=delta_text,
+                previous_token_ids=previous_token_ids,
+                current_token_ids=current_token_ids,
+                delta_token_ids=delta_token_ids,
+                request=request,
+            )
+        else:
+            return self._extract_tool_calls_streaming_pre_v11_tokenizer(
+                previous_text=previous_text,
+                current_text=current_text,
+                delta_text=delta_text,
+                previous_token_ids=previous_token_ids,
+                current_token_ids=current_token_ids,
+                delta_token_ids=delta_token_ids,
+                request=request,
+            )
+
+    def _extract_tool_calls_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: Sequence[int],
+        current_token_ids: Sequence[int],
+        delta_token_ids: Sequence[int],
+        request: ChatCompletionRequest,
+    ) -> Union[DeltaMessage, None]:
+        additional_content: str = ""
+        if self.current_tool_id == -1:
+            # this is the first tool call
+            assert self.bot_token in delta_text
+            if not delta_text.startswith(self.bot_token):
+                additional_content += delta_text.split(self.bot_token)[0]
+                delta_text = self.bot_token + "".join(
+                    delta_text.split(self.bot_token)[1:])
+
+        delta_tool_calls = self._generate_delta_tool_call(delta_text)
+        delta = DeltaMessage(
+            content=additional_content,
+            tool_calls=delta_tool_calls,
+        )
+        return delta
+
+    def _generate_delta_tool_call(self,
+                                  delta_text: str) -> list[DeltaToolCall]:
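+        # v11 grammar: "[TOOL_CALLS]" opens a tool call, the function name
+        # runs until the first "{", and the arguments run until the next
+        # "[TOOL_CALLS]" token (or the end of generation); any trailing text
+        # is handled by recursing on the remainder.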
+        if delta_text == "" or delta_text is None:
+            return []
+        delta_function_name = None
+        tool_id = None
+        if self.current_element_streaming is None and delta_text.startswith(
+                self.bot_token):
+            self.current_tool_id += 1
+            tool_id = MistralToolCall.generate_random_id()
+            self.current_element_streaming = 'name'
+            delta_text = delta_text.replace(self.bot_token, "", 1)
+        if self.current_element_streaming == 'name':
+            if "{" in delta_text:
+                delta_function_name = delta_text.split("{")[0]
+                delta_text = delta_text[len(delta_function_name):]
+                self.current_element_streaming = 'arguments'
+            else:
+                delta_function_name = delta_text
+                return [
+                    DeltaToolCall(
+                        index=self.current_tool_id,
+                        type="function",
+                        id=tool_id,
+                        function=DeltaFunctionCall(
+                            name=delta_function_name).model_dump(
+                                exclude_none=True),
+                    )
+                ]
+        if self.current_element_streaming == 'arguments':
+            next_function_text = None
+            if self.bot_token in delta_text:
+                # current tool call is over
+                if delta_text.startswith(self.bot_token):
+                    delta_arguments = ""
+                else:
+                    delta_arguments = delta_text.split(self.bot_token)[0]
+                next_function_text = delta_text[len(delta_arguments):]
+                self.current_element_streaming = None
+            else:
+                delta_arguments = delta_text
+            ret = []
+            if delta_function_name or delta_arguments:
+                ret += [
+                    DeltaToolCall(
+                        index=self.current_tool_id,
+                        type="function",
+                        id=tool_id,
+                        function=DeltaFunctionCall(
+                            name=delta_function_name,
+                            arguments=delta_arguments).model_dump(
+                                exclude_none=True),
+                    )
+                ]
+            if next_function_text:
+                ret += self._generate_delta_tool_call(next_function_text)
+            return ret
+        # Should not happen
+        return []
+
+    def _extract_tool_calls_streaming_pre_v11_tokenizer(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: Sequence[int],
+        current_token_ids: Sequence[int],
+        delta_token_ids: Sequence[int],
+        request: ChatCompletionRequest,
+    ) -> Union[DeltaMessage, None]:
+        additional_content: str = ""
+        if self.bot_token in delta_text:
+            self.raw_tool_calls += (delta_text.split(
+                self.bot_token)[-1].replace("'", '"').lstrip())
+            if not delta_text.startswith(self.bot_token):
+                # delta contains some text before the bot token
+                additional_content = delta_text.split(self.bot_token)[0]
+        else:
+            self.raw_tool_calls += delta_text.replace("\'", "\"")
+        self.raw_tool_calls = (
+            self.raw_tool_calls.lstrip()
+        )  # leading spaces prevent us from raw_decoding
+
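+        # Pre-v11 outputs are a JSON array of {"name": ..., "arguments": ...}
+        # objects; once raw_decode succeeds on the whole buffer, every tool
+        # call has been received and anything after it is plain content.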
+        if self.current_tool_start_index < 0:
+            if "[" in self.raw_tool_calls:
+                self.current_tool_start_index = self.raw_tool_calls.find(
+                    "[") + 1
+                self.current_tool_id += 1
+            else:
+                # tool calls not started
+                return self._none_or_additional_content(additional_content)
 
-        # handle if we detected the BOT token which means the start of tool
-        # calling
-        if (self.bot_token_id in delta_token_ids
-                and len(delta_token_ids) == 1):
-            # if it's the only token, return None, so we don't send a chat
-            # completion any don't send a control token
-            return None
-
-        # bit mask flags for partial JSON parsing. If the name hasn't been
-        # sent yet, don't allow sending
-        # an incomplete string since OpenAI only ever (as far as I have
-        # seen) allows sending the entire tool/ function name at once.
-        flags = Allow.ALL if self.current_tool_name_sent \
-            else Allow.ALL & ~Allow.STR
         try:
+            _, end_index = self.json_decoder.raw_decode(self.raw_tool_calls)
+            self.tools_parsing_finished = True
+            if len(self.raw_tool_calls) > end_index:
+                additional_content = self.raw_tool_calls[end_index:]
+        except json.decoder.JSONDecodeError:
+            # we are in tool calls
+            pass
+
+        if (self.current_tool_name_finished
+                and self.current_tool_arguments_finished):
+            if self.tools_parsing_finished:
+                return self._none_or_additional_content(additional_content)
+            # let's find the next tool starting position
+            next_tool_start_index = self._next_tool_starting_position()
+            if next_tool_start_index > 0:
+                self.current_tool_id += 1
+                self.current_tool_start_index = next_tool_start_index
+                self.current_attribute_start_index = -1
+                self.previous_attribute_end_index = 0
+                self.current_tool_name_finished = False
+                self.current_tool_arguments_finished = False
+
+        if self.current_tool_start_index >= len(self.raw_tool_calls):
+            # tool call has not started
+            return self._none_or_additional_content(additional_content)
+        raw_current_tool_call = self.raw_tool_calls[self.
+                                                    current_tool_start_index:]
+
+        if self.current_element_streaming is None:
+            # we are waiting for the next attribute to be parsed
+            match_name = self.tool_call_first_attribute_name.match(
+                raw_current_tool_call)
+            match_arguments = self.tool_call_first_attribute_arguments.match(
+                raw_current_tool_call)
+            if not self.current_tool_name_finished and match_name:
+                if match_name.end() <= self.previous_attribute_end_index:
+                    return self._none_or_additional_content(additional_content)
+                self.current_element_streaming = "name"
+                self.current_attribute_start_index = match_name.end()
+            elif not self.current_tool_arguments_finished and match_arguments:
+                if match_arguments.end() <= self.previous_attribute_end_index:
+                    return self._none_or_additional_content(additional_content)
+                self.current_element_streaming = "arguments"
+                self.current_attribute_start_index = match_arguments.end() - 1
+                # the `{` is the last character of the match; we want it to
+                # be the start index of the arguments value
+            else:
+                # let's wait for more deltas
+                return self._none_or_additional_content(additional_content)
 
-            # replace BOT token with empty string, and convert single quotes
-            # to double to allow parsing as JSON since mistral uses single
-            # quotes instead of double for tool calls
-            parsable_arr = current_text.split(self.bot_token)[-1]
-
-            # tool calls are generated in an array, so do partial JSON
-            # parsing on the entire array
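+        # The function name is only ever emitted once and in full: OpenAI
+        # clients expect the complete name in a single delta.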
+        if self.current_element_streaming == "name":
+            try:
-            tool_call_arr: list[dict] = partial_json_parser.loads(
-                parsable_arr, flags)
-        except partial_json_parser.core.exceptions.MalformedJSON:
-            logger.debug('not enough tokens to parse into JSON yet')
-            return None
-
-        # select as the current tool call the one we're on the state at -
-        current_tool_call: dict = tool_call_arr[self.current_tool_id] \
-            if len(tool_call_arr) > 0 else {}
-
-        # case -- if no tokens have been streamed for the tool, e.g.
-        #   only the array brackets, stream nothing
-        if len(tool_call_arr) == 0:
-            return None
-
-        # case: we are starting a new tool in the array
-        #   -> array has > 0 length AND length has moved past cursor
-        elif (len(tool_call_arr) > 0
-              and len(tool_call_arr) > self.current_tool_id + 1):
-
-            # if we're moving on to a new call, first make sure we
-            # haven't missed anything in the previous one that was
-            # auto-generated due to JSON completions, but wasn't
-            # streamed to the client yet.
-            if self.current_tool_id >= 0:
-                diff: Union[str, None] = current_tool_call.get("arguments")
-
-                if diff:
-                    diff = json.dumps(diff, ensure_ascii=False).replace(
-                        self.streamed_args_for_tool[self.current_tool_id],
-                        "")
-                    delta = DeltaMessage(tool_calls=[
-                        DeltaToolCall(index=self.current_tool_id,
-                                      function=DeltaFunctionCall(
-                                          arguments=diff).model_dump(
-                                              exclude_none=True))
-                    ])
-                    self.streamed_args_for_tool[
-                        self.current_tool_id] += diff
-                else:
-                    delta = None
-            else:
-                delta = None
-            # re-set stuff pertaining to progress in the current tool
-            self.current_tool_id = len(tool_call_arr) - 1
-            self.current_tool_name_sent = False
-            self.streamed_args_for_tool.append("")
-            logger.debug("starting on new tool %d", self.current_tool_id)
-            return delta
-
-        # case: update an existing tool - this is handled below
-
-        # if the current tool name hasn't been sent, send if available
-        # - otherwise send nothing
-        if not self.current_tool_name_sent:
-            function_name = current_tool_call.get("name")
-            if function_name:
-
-                delta = DeltaMessage(tool_calls=[
-                    DeltaToolCall(index=self.current_tool_id,
-                                  type="function",
-                                  id=MistralToolCall.generate_random_id(),
-                                  function=DeltaFunctionCall(
-                                      name=function_name).model_dump(
-                                          exclude_none=True))
-                ])
-                self.current_tool_name_sent = True
-            else:
-                delta = None
-
-        # now we know we're on the same tool call and we're streaming
-        # arguments
+                function_name, name_end_index = self._extracted_complete_name(
+                    raw_current_tool_call, self.current_attribute_start_index)
+            except IndexError:
+                # name value has not started being generated
+                return self._none_or_additional_content(additional_content)
+            if function_name == "":
+                return self._none_or_additional_content(additional_content)
+            else:
+                assert name_end_index is not None
+                # because the function name was successfully retrieved
+
+                self.current_tool_name_finished = True
+                self.current_element_streaming = None
+                self.current_attribute_start_index = -1
+                self.previous_attribute_end_index = name_end_index
+                delta = DeltaMessage(
+                    content=additional_content,
+                    tool_calls=[
+                        DeltaToolCall(
+                            index=self.current_tool_id,
+                            type="function",
+                            id=MistralToolCall.generate_random_id(),
+                            function=DeltaFunctionCall(
+                                name=function_name).model_dump(
+                                    exclude_none=True),
+                        )
+                    ],
+                )
+                return delta
+        if self.current_element_streaming == "arguments":
+            try:
+                diff, arguments_end_index = self._extract_argument_fragment(
+                    raw_current_tool_call,
+                    self.current_attribute_start_index,
+                    delta_text,
+                )
+                self.current_tool_arguments_finished = arguments_end_index != -1
+                if self.current_tool_arguments_finished:
+                    self.current_element_streaming = None
+                    self.current_attribute_start_index = -1
+                    self.previous_attribute_end_index = arguments_end_index
+                delta = DeltaMessage(
+                    content=additional_content,
+                    tool_calls=[
+                        DeltaToolCall(
+                            index=self.current_tool_id,
+                            function=DeltaFunctionCall(
+                                arguments=diff).model_dump(exclude_none=True),
+                        )
+                    ],
+                )
+                return delta
+            except IndexError:
+                # arguments value has not started being generated
+                return self._none_or_additional_content(additional_content)
+
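+    # Example: given raw_current_tool_call == '{"name": "get_age", ...' and a
+    # start index just past the colon, the helper below returns
+    # ("get_age", <index one past the closing quote>).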
+    def _extracted_complete_name(
+        self, raw_current_tool_call: str,
+        current_attribute_start_index: int
+    ) -> tuple[str, Union[int, None]]:
+        """
+        Extract the complete function name from the current tool call.
+
+        Args:
+            raw_current_tool_call: The raw JSON string of the current tool call
+            current_attribute_start_index: The starting index of the
+                name attribute in the raw_current_tool_call string
+
+        Returns:
+            tuple:
+                - The function name, or "" if extraction failed
+                - The end index of the name in raw_current_tool_call,
+                  or None if extraction failed
+        """
+        partial_name_value = raw_current_tool_call[
+            current_attribute_start_index:]
+        if match := self.string_value_pattern.match(partial_name_value):
+            return match.group(1), match.end() + current_attribute_start_index
+        return "", None
+
+    def _extract_argument_fragment(self, raw_current_tool_call: str,
+                                   current_attribute_start_index: int,
+                                   delta: str) -> tuple[str, int]:
+        """
+        Extract the relevant argument fragment from the current streaming delta.
+
+        Args:
+            raw_current_tool_call: The raw JSON string of the current tool call
+            current_attribute_start_index: The starting index
+                of the arguments attribute in the raw string
+            delta: The new text added in this streaming step
+
+        Returns:
+            tuple:
+                - The extracted argument diff text
+                  to be sent in the streaming response
+                - The end index of the arguments in the raw string,
+                  or -1 if not yet complete
+        """
+        partial_arguments_value = raw_current_tool_call[
+            current_attribute_start_index:]
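+        # raw_decode parses the leading JSON value and reports where it ends,
+        # so it succeeds only once the arguments object has been closed.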
+        try:
+            _, end_index = self.json_decoder.raw_decode(
+                partial_arguments_value)
+            return (
+                delta[:len(delta) + end_index - len(partial_arguments_value)],
+                current_attribute_start_index + end_index,
+            )
+        except json.decoder.JSONDecodeError:
+            # The arguments object is not complete
+
+            # delta contains data from before the argument start
+            if len(delta) > len(partial_arguments_value):
+                return delta[-len(partial_arguments_value):], -1
+
+            # We can send the whole delta
+            return delta, -1
+
+    def _next_tool_starting_position(self) -> int:
+        """
+        Find the starting position of the next tool
+        in the raw tool calls string.
-            prev_arguments = self.prev_tool_call_arr[
-                self.current_tool_id].get("arguments")
-            cur_arguments = current_tool_call.get("arguments")
-
-            new_text = delta_text.replace("\'", "\"")
-            if ('"}' in new_text):
-                new_text = new_text[:new_text.rindex('"}')]
-
-            if not cur_arguments and not prev_arguments:
-
-                delta = None
-            elif not cur_arguments and prev_arguments:
-                logger.error(
-                    "INVARIANT - impossible to have arguments reset "
-                    "mid-arguments")
-                delta = None
-            elif cur_arguments and not prev_arguments:
-                cur_arguments_json = json.dumps(cur_arguments,
-                                                ensure_ascii=False)[:-2]
-                logger.debug("finding %s in %s", new_text,
-                             cur_arguments_json)
-
-                if (new_text not in cur_arguments_json):
-                    return None
-                arguments_delta = cur_arguments_json[:cur_arguments_json.
-                                                     rindex(new_text) +
-                                                     len(new_text)]
-                logger.debug("First tokens in arguments received: %s",
-                             arguments_delta)
-                delta = DeltaMessage(tool_calls=[
-                    DeltaToolCall(index=self.current_tool_id,
-                                  function=DeltaFunctionCall(
-                                      arguments=arguments_delta).
-                                  model_dump(exclude_none=True))
-                ])
-                self.streamed_args_for_tool[
-                    self.current_tool_id] += arguments_delta
-
-            elif cur_arguments and prev_arguments:
-                cur_args_json = json.dumps(cur_arguments,
-                                           ensure_ascii=False)
-                prev_args_json = json.dumps(prev_arguments,
-                                            ensure_ascii=False)
-                logger.debug("Searching for diff between \n%s\n%s",
-                             cur_args_json, prev_args_json)
-
-                argument_diff = extract_intermediate_diff(
-                    cur_args_json, prev_args_json)
-                logger.debug("got arguments diff: %s", argument_diff)
-                delta = DeltaMessage(tool_calls=[
-                    DeltaToolCall(index=self.current_tool_id,
-                                  function=DeltaFunctionCall(
-                                      arguments=argument_diff).model_dump(
-                                          exclude_none=True))
-                ])
-                self.streamed_args_for_tool[
-                    self.current_tool_id] += argument_diff
-            else:
-                # try parsing it with regular JSON - if it works we're
-                # at the end, and we need to send the difference between
-                # tokens streamed so far and the valid JSON
-                delta = None
+
+        Returns:
+            The index position where the next tool starts,
+            or -1 if no next tool is found yet
+        """
+        assert self.current_tool_start_index >= 0
+        current_tool_call = self.raw_tool_calls[self.current_tool_start_index:]
+        try:
+            _, end_index = self.json_decoder.raw_decode(current_tool_call)
+            return (self.current_tool_start_index + end_index +
+                    current_tool_call[end_index:].find("{"))
+        except json.decoder.JSONDecodeError:
+            # The current tool object is not yet closed
+            return -1
+        except IndexError:
+            # The next tool has not started yet
+            # and the delta just closes the current tool call
+            return -1
+
+    def _none_or_additional_content(
+            self, additional_content: str) -> Union[DeltaMessage, None]:
+        """
+        Create a DeltaMessage with additional content if present,
+        otherwise return None.
-
-            # check to see if the name is defined and has been sent. if so,
-            # stream the name - otherwise keep waiting
-            # finish by setting old and returning None as base case
-            self.prev_tool_call_arr = tool_call_arr
-            return delta
+
+        Args:
+            additional_content: The text content to include in the message
-
-        except Exception:
-            logger.exception("Error trying to handle streaming tool call.")
-            logger.debug(
-                "Skipping chunk as a result of tool streaming extraction "
-                "error")
-            return None
+
+        Returns:
+            A DeltaMessage with the additional content,
+            or None if no content is provided
+        """
+        if additional_content:
+            return DeltaMessage(content=additional_content)
+        return None
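For orientation, a minimal sketch of how a caller drives the streaming entry point. The tokenizer name and the delta source are illustrative only, and `request` is not used by this parser (the tests above pass `None`):

```python
from vllm.entrypoints.openai.tool_parsers import MistralToolParser
from vllm.transformers_utils.tokenizer import get_tokenizer

tokenizer = get_tokenizer("mistralai/Mistral-7B-Instruct-v0.3")
parser = MistralToolParser(tokenizer)

previous_text, previous_ids = "", []
# `deltas` is a hypothetical iterable of (delta_text, delta_token_ids)
# pairs coming from the generation loop.
for delta_text, delta_ids in deltas:
    current_text = previous_text + delta_text
    current_ids = previous_ids + delta_ids
    message = parser.extract_tool_calls_streaming(
        previous_text, current_text, delta_text,
        previous_ids, current_ids, delta_ids,
        request=None)  # unused by this parser
    if message and message.tool_calls:
        ...  # forward the DeltaToolCall fragments to the client
    previous_text, previous_ids = current_text, current_ids
```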