From cd2c0c8eb32c1c2b55cb461c488009bba6a8dcf9 Mon Sep 17 00:00:00 2001 From: Koudai Aono Date: Tue, 28 Oct 2025 16:57:58 +0900 Subject: [PATCH] Support concurrent tool call streaming with ID-based tracking --- .../llm/clients/anthropic/_utils/decode.py | 6 +- .../llm/clients/google/_utils/decode.py | 21 +++++-- .../openai/completions/_utils/decode.py | 24 ++++++- .../clients/openai/responses/_utils/decode.py | 4 +- python/mirascope/llm/content/tool_call.py | 8 ++- .../llm/responses/base_stream_response.py | 37 +++++------ python/tests/e2e/conftest.py | 29 ++++++--- .../llm/responses/test_stream_response.py | 62 ++++++++++--------- 8 files changed, 121 insertions(+), 70 deletions(-) diff --git a/python/mirascope/llm/clients/anthropic/_utils/decode.py b/python/mirascope/llm/clients/anthropic/_utils/decode.py index 993eb321c..9dccbbcbf 100644 --- a/python/mirascope/llm/clients/anthropic/_utils/decode.py +++ b/python/mirascope/llm/clients/anthropic/_utils/decode.py @@ -157,7 +157,9 @@ def process_event( f"Received input_json_delta for {self.current_block_param['type']} block" ) self.accumulated_tool_json += delta.partial_json - yield ToolCallChunk(delta=delta.partial_json) + yield ToolCallChunk( + id=self.current_block_param["id"], delta=delta.partial_json + ) elif delta.type == "thinking_delta": if self.current_block_param["type"] != "thinking": # pragma: no cover raise RuntimeError( @@ -194,7 +196,7 @@ def process_event( if self.accumulated_tool_json else {} ) - yield ToolCallEndChunk() + yield ToolCallEndChunk(id=self.current_block_param["id"]) elif block_type == "thinking": yield ThoughtEndChunk() else: diff --git a/python/mirascope/llm/clients/google/_utils/decode.py b/python/mirascope/llm/clients/google/_utils/decode.py index 9cb40294a..cb0f8e0ed 100644 --- a/python/mirascope/llm/clients/google/_utils/decode.py +++ b/python/mirascope/llm/clients/google/_utils/decode.py @@ -128,6 +128,7 @@ class _GoogleChunkProcessor: def __init__(self) -> None: self.current_content_type: Literal["text", "tool_call", "thought"] | None = None + self.current_tool_id: str | None = None self.accumulated_parts: list[genai_types.Part] = [] self.reconstructed_content = genai_types.Content(parts=[]) @@ -150,11 +151,13 @@ def process_chunk( yield TextEndChunk() # pragma: no cover self.current_content_type = None # pragma: no cover elif self.current_content_type == "tool_call" and not part.function_call: - # In testing, Gemini never emits tool calls and text in the same message - # (even when specifically asked in system and user prompt), so - # the following code is uncovered but included for completeness - yield ToolCallEndChunk() # pragma: no cover + if self.current_tool_id is None: + raise RuntimeError( + "Missing tool_id when ending tool call" + ) # pragma: no cover + yield ToolCallEndChunk(id=self.current_tool_id) # pragma: no cover self.current_content_type = None # pragma: no cover + self.current_tool_id = None # pragma: no cover if part.thought: if self.current_content_type is None: @@ -179,17 +182,23 @@ def process_chunk( "Required name missing on Google function call" ) # pragma: no cover + tool_id = function_call.id or UNKNOWN_TOOL_ID + self.current_content_type = "tool_call" + self.current_tool_id = tool_id yield ToolCallStartChunk( - id=function_call.id or UNKNOWN_TOOL_ID, + id=tool_id, name=function_call.name, ) yield ToolCallChunk( + id=tool_id, delta=json.dumps(function_call.args) if function_call.args else "{}", ) - yield ToolCallEndChunk() + yield ToolCallEndChunk(id=tool_id) + 
self.current_content_type = None + self.current_tool_id = None if candidate.finish_reason: if self.current_content_type == "text": diff --git a/python/mirascope/llm/clients/openai/completions/_utils/decode.py b/python/mirascope/llm/clients/openai/completions/_utils/decode.py index f41aab47e..82da7d739 100644 --- a/python/mirascope/llm/clients/openai/completions/_utils/decode.py +++ b/python/mirascope/llm/clients/openai/completions/_utils/decode.py @@ -84,6 +84,7 @@ class _OpenAIChunkProcessor: def __init__(self) -> None: self.current_content_type: Literal["text", "tool_call"] | None = None self.current_tool_index: int | None = None + self.current_tool_id: str | None = None self.refusal_encountered = False def process_chunk(self, chunk: openai_types.ChatCompletionChunk) -> ChunkIterator: @@ -127,8 +128,13 @@ def process_chunk(self, chunk: openai_types.ChatCompletionChunk) -> ChunkIterato self.current_tool_index is not None and self.current_tool_index < index ): - yield ToolCallEndChunk() + if self.current_tool_id is None: + raise RuntimeError( + "Missing tool_id when ending tool call" + ) # pragma: no cover + yield ToolCallEndChunk(id=self.current_tool_id) self.current_tool_index = None + self.current_tool_id = None if self.current_tool_index is None: if not tool_call_delta.function or not ( @@ -144,19 +150,31 @@ def process_chunk(self, chunk: openai_types.ChatCompletionChunk) -> ChunkIterato f"Missing id for tool call at index {index}" ) # pragma: no cover + self.current_tool_id = tool_id yield ToolCallStartChunk( id=tool_id, name=name, ) if tool_call_delta.function and tool_call_delta.function.arguments: - yield ToolCallChunk(delta=tool_call_delta.function.arguments) + if self.current_tool_id is None: + raise RuntimeError( + "Missing tool_id when processing tool call chunk" + ) # pragma: no cover + yield ToolCallChunk( + id=self.current_tool_id, + delta=tool_call_delta.function.arguments, + ) if choice.finish_reason: if self.current_content_type == "text": yield TextEndChunk() elif self.current_content_type == "tool_call": - yield ToolCallEndChunk() + if self.current_tool_id is None: + raise RuntimeError( + "Missing tool_id when ending tool call at finish" + ) # pragma: no cover + yield ToolCallEndChunk(id=self.current_tool_id) elif self.current_content_type is not None: # pragma: no cover raise NotImplementedError() diff --git a/python/mirascope/llm/clients/openai/responses/_utils/decode.py b/python/mirascope/llm/clients/openai/responses/_utils/decode.py index 59fa763ad..de3bef5ab 100644 --- a/python/mirascope/llm/clients/openai/responses/_utils/decode.py +++ b/python/mirascope/llm/clients/openai/responses/_utils/decode.py @@ -142,9 +142,9 @@ def process_chunk(self, event: ResponseStreamEvent) -> ChunkIterator: ) self.current_content_type = "tool_call" elif event.type == "response.function_call_arguments.delta": - yield ToolCallChunk(delta=event.delta) + yield ToolCallChunk(id=self.current_tool_call_id, delta=event.delta) elif event.type == "response.function_call_arguments.done": - yield ToolCallEndChunk() + yield ToolCallEndChunk(id=self.current_tool_call_id) self.current_content_type = None elif ( event.type == "response.reasoning_text.delta" diff --git a/python/mirascope/llm/content/tool_call.py b/python/mirascope/llm/content/tool_call.py index d0c869f7e..b10caa10b 100644 --- a/python/mirascope/llm/content/tool_call.py +++ b/python/mirascope/llm/content/tool_call.py @@ -47,7 +47,9 @@ class ToolCallChunk: type: Literal["tool_call_chunk"] = "tool_call_chunk" content_type: 
Literal["tool_call"] = "tool_call" - """The type of content reconstructed by this chunk.""" + + id: str + """The unique identifier for the tool call this chunk belongs to.""" delta: str """The incremental json args added in this chunk.""" @@ -60,4 +62,6 @@ class ToolCallEndChunk: type: Literal["tool_call_end_chunk"] = "tool_call_end_chunk" content_type: Literal["tool_call"] = "tool_call" - """The type of content reconstructed by this chunk.""" + + id: str + """The unique identifier for the tool call being ended.""" diff --git a/python/mirascope/llm/responses/base_stream_response.py b/python/mirascope/llm/responses/base_stream_response.py index c115ff211..e0c59873f 100644 --- a/python/mirascope/llm/responses/base_stream_response.py +++ b/python/mirascope/llm/responses/base_stream_response.py @@ -215,6 +215,7 @@ def __init__( self._chunk_iterator = chunk_iterator self._current_content: Text | Thought | ToolCall | None = None + self._current_tool_calls_by_id: dict[str, ToolCall] = {} self._processing_format_tool: bool = False @@ -306,39 +307,39 @@ def _handle_tool_call_chunk( self, chunk: ToolCallStartChunk | ToolCallChunk | ToolCallEndChunk ) -> None: if chunk.type == "tool_call_start_chunk": - if self._current_content: + if self._current_content and self._current_content.type != "tool_call": raise RuntimeError( "Received tool_call_start_chunk while processing another chunk" ) - self._current_content = ToolCall( + tool_call = ToolCall( id=chunk.id, name=chunk.name, args="", ) + self._current_tool_calls_by_id[chunk.id] = tool_call + self._current_content = tool_call elif chunk.type == "tool_call_chunk": - if ( - self._current_content is None - or self._current_content.type != "tool_call" - ): + tool_call = self._current_tool_calls_by_id.get(chunk.id) + if tool_call is None: raise RuntimeError( - "Received tool_call_chunk while not processing tool call." + f"Received tool_call_chunk for unknown tool call id: {chunk.id}" ) - self._current_content.args += chunk.delta + tool_call.args += chunk.delta elif chunk.type == "tool_call_end_chunk": - if ( - self._current_content is None - or self._current_content.type != "tool_call" - ): + tool_call = self._current_tool_calls_by_id.get(chunk.id) + if tool_call is None: raise RuntimeError( - "Received tool_call_end_chunk while not processing tool call." 
+ f"Received tool_call_end_chunk for unknown tool call id: {chunk.id}" ) - if not self._current_content.args: - self._current_content.args = "{}" - self._content.append(self._current_content) - self._tool_calls.append(self._current_content) - self._current_content = None + if not tool_call.args: + tool_call.args = "{}" + self._content.append(tool_call) + self._tool_calls.append(tool_call) + del self._current_tool_calls_by_id[chunk.id] + if self._current_content is tool_call: + self._current_content = None def _pretty_chunk(self, chunk: AssistantContentChunk, spacer: str) -> str: match chunk.type: diff --git a/python/tests/e2e/conftest.py b/python/tests/e2e/conftest.py index 6b8d97e3d..7498bb813 100644 --- a/python/tests/e2e/conftest.py +++ b/python/tests/e2e/conftest.py @@ -5,6 +5,7 @@ from __future__ import annotations +import gzip import hashlib import inspect import json @@ -179,16 +180,28 @@ def sanitize_response(response: dict[str, Any]) -> dict[str, Any]: response = deepcopy(response) if "body" in response and "string" in response["body"]: - body_str = response["body"]["string"] - if isinstance(body_str, bytes): + raw_body = response["body"]["string"] + was_gzip = False + + if isinstance(raw_body, bytes): try: - body_str = body_str.decode() + body_str = raw_body.decode() except UnicodeDecodeError: - # Body is likely compressed (gzip) or binary data - # Skip sanitization for these responses - return response + try: + decompressed = gzip.decompress(raw_body) + body_str = decompressed.decode() + was_gzip = True + except (OSError, UnicodeDecodeError): + # Binary payload we cannot sanitize + return response + else: + body_str = raw_body if "access_token" in body_str or "id_token" in body_str: + def _encode(text: str) -> bytes: + data = text.encode() + return gzip.compress(data) if was_gzip else data + try: body_json = json.loads(body_str) if "access_token" in body_json: @@ -197,7 +210,7 @@ def sanitize_response(response: dict[str, Any]) -> dict[str, Any]: body_json["id_token"] = "" if "refresh_token" in body_json: body_json["refresh_token"] = "" - response["body"]["string"] = json.dumps(body_json).encode() + response["body"]["string"] = _encode(json.dumps(body_json)) except (json.JSONDecodeError, KeyError): body_str = re.sub( r'"access_token":\s*"[^"]+"', @@ -212,7 +225,7 @@ def sanitize_response(response: dict[str, Any]) -> dict[str, Any]: '"refresh_token": ""', body_str, ) - response["body"]["string"] = body_str.encode() + response["body"]["string"] = _encode(body_str) return response diff --git a/python/tests/llm/responses/test_stream_response.py b/python/tests/llm/responses/test_stream_response.py index a2dc7b598..7c4b452a4 100644 --- a/python/tests/llm/responses/test_stream_response.py +++ b/python/tests/llm/responses/test_stream_response.py @@ -443,7 +443,7 @@ class ChunkProcessingTestCase: id="tool_123", name="empty_function", ), - llm.ToolCallEndChunk(), + llm.ToolCallEndChunk(id="tool_123"), ], expected_contents=[ [], @@ -456,9 +456,11 @@ class ChunkProcessingTestCase: id="tool_456", name="test_function", ), - llm.ToolCallChunk(type="tool_call_chunk", delta='{"key": '), - llm.ToolCallChunk(type="tool_call_chunk", delta='"value"}'), - llm.ToolCallEndChunk(type="tool_call_end_chunk", content_type="tool_call"), + llm.ToolCallChunk(id="tool_456", type="tool_call_chunk", delta='{"key": '), + llm.ToolCallChunk(id="tool_456", type="tool_call_chunk", delta='"value"}'), + llm.ToolCallEndChunk( + id="tool_456", type="tool_call_end_chunk", content_type="tool_call" + ), ], 
expected_contents=[ [], @@ -703,12 +705,12 @@ class InvalidChunkSequenceTestCase: expected_error="Received thought_end_chunk while not processing thought", ), "tool_call_chunk_without_start": InvalidChunkSequenceTestCase( - chunks=[llm.ToolCallChunk(delta='{"test": "value"}')], - expected_error="Received tool_call_chunk while not processing tool call", + chunks=[llm.ToolCallChunk(id="unknown_id", delta='{"test": "value"}')], + expected_error="Received tool_call_chunk for unknown tool call id", ), "tool_call_end_without_start": InvalidChunkSequenceTestCase( - chunks=[llm.ToolCallEndChunk()], - expected_error="Received tool_call_end_chunk while not processing tool call", + chunks=[llm.ToolCallEndChunk(id="unknown_id")], + expected_error="Received tool_call_end_chunk for unknown tool call id", ), "overlapping_text_then_tool_call": InvalidChunkSequenceTestCase( chunks=[ @@ -734,9 +736,11 @@ class InvalidChunkSequenceTestCase: chunks=[ llm.TextStartChunk(type="text_start_chunk"), llm.TextChunk(type="text_chunk", delta="test"), - llm.ToolCallEndChunk(type="tool_call_end_chunk", content_type="tool_call"), + llm.ToolCallEndChunk( + id="unknown_id", type="tool_call_end_chunk", content_type="tool_call" + ), ], - expected_error="Received tool_call_end_chunk while not processing tool call", + expected_error="Received tool_call_end_chunk for unknown tool call id", ), } @@ -1102,9 +1106,9 @@ def example_format_tool_chunks() -> list[llm.StreamResponseChunk]: id="call_format_123", name=FORMAT_TOOL_NAME, ), - llm.ToolCallChunk(delta='{"title": "The Hobbit"'), - llm.ToolCallChunk(delta=', "author": "Tolkien"}'), - llm.ToolCallEndChunk(), + llm.ToolCallChunk(id="call_format_123", delta='{"title": "The Hobbit"'), + llm.ToolCallChunk(id="call_format_123", delta=', "author": "Tolkien"}'), + llm.ToolCallEndChunk(id="call_format_123"), ] @@ -1122,15 +1126,15 @@ def example_format_tool_chunks_processed() -> list[llm.AssistantContentChunk]: def example_format_tool_chunks_mixed() -> list[llm.StreamResponseChunk]: return [ llm.ToolCallStartChunk(id="call_007", name="ring_tool"), - llm.ToolCallChunk(delta='{"ring_purpose": "to_rule_them_all"}'), - llm.ToolCallEndChunk(), + llm.ToolCallChunk(id="call_007", delta='{"ring_purpose": "to_rule_them_all"}'), + llm.ToolCallEndChunk(id="call_007"), llm.ToolCallStartChunk( id="call_format_123", name=FORMAT_TOOL_NAME, ), - llm.ToolCallChunk(delta='{"title": "The Hobbit"'), - llm.ToolCallChunk(delta=', "author": "Tolkien"}'), - llm.ToolCallEndChunk(), + llm.ToolCallChunk(id="call_format_123", delta='{"title": "The Hobbit"'), + llm.ToolCallChunk(id="call_format_123", delta=', "author": "Tolkien"}'), + llm.ToolCallEndChunk(id="call_format_123"), llm.TextStartChunk(), llm.TextChunk(delta="A wizard is never late."), llm.TextEndChunk(), @@ -1141,8 +1145,8 @@ def example_format_tool_chunks_mixed() -> list[llm.StreamResponseChunk]: def example_format_tool_chunks_mixed_processed() -> list[llm.AssistantContentChunk]: return [ llm.ToolCallStartChunk(id="call_007", name="ring_tool"), - llm.ToolCallChunk(delta='{"ring_purpose": "to_rule_them_all"}'), - llm.ToolCallEndChunk(), + llm.ToolCallChunk(id="call_007", delta='{"ring_purpose": "to_rule_them_all"}'), + llm.ToolCallEndChunk(id="call_007"), llm.TextStartChunk(), llm.TextChunk(delta='{"title": "The Hobbit"'), llm.TextChunk(delta=', "author": "Tolkien"}'), @@ -1160,7 +1164,7 @@ def example_format_tool_chunks_max_tokens() -> list[llm.StreamResponseChunk]: id="call_format_123", name=FORMAT_TOOL_NAME, ), - llm.ToolCallEndChunk(), + 
llm.ToolCallEndChunk(id="call_format_123"), llm.responses.FinishReasonChunk(finish_reason=llm.FinishReason.MAX_TOKENS), ] @@ -1314,11 +1318,11 @@ def tool_two(y: str) -> str: tool_call_chunks = [ llm.ToolCallStartChunk(id="call_1", name="tool_one"), - llm.ToolCallChunk(delta='{"x": 5}'), - llm.ToolCallEndChunk(), + llm.ToolCallChunk(id="call_1", delta='{"x": 5}'), + llm.ToolCallEndChunk(id="call_1"), llm.ToolCallStartChunk(id="call_2", name="tool_two"), - llm.ToolCallChunk(delta='{"y": "hello"}'), - llm.ToolCallEndChunk(), + llm.ToolCallChunk(id="call_2", delta='{"y": "hello"}'), + llm.ToolCallEndChunk(id="call_2"), ] stream_response = llm.StreamResponse( @@ -1352,11 +1356,11 @@ async def tool_two(y: str) -> str: tool_call_chunks = [ llm.ToolCallStartChunk(id="call_1", name="tool_one"), - llm.ToolCallChunk(delta='{"x": 5}'), - llm.ToolCallEndChunk(), + llm.ToolCallChunk(id="call_1", delta='{"x": 5}'), + llm.ToolCallEndChunk(id="call_1"), llm.ToolCallStartChunk(id="call_2", name="tool_two"), - llm.ToolCallChunk(delta='{"y": "hello"}'), - llm.ToolCallEndChunk(), + llm.ToolCallChunk(id="call_2", delta='{"y": "hello"}'), + llm.ToolCallEndChunk(id="call_2"), ] async def async_chunk_iter() -> AsyncIterator[llm.AssistantContentChunk]:
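
Note (not part of the patch): a minimal sketch of the ID-keyed accumulation this change enables, with two tool calls streaming interleaved. The chunk classes and their `id`/`delta`/`name` fields are the ones touched above; the import path, the tool names/ids, and the accumulation loop itself are illustrative assumptions — the loop mirrors what `_handle_tool_call_chunk` now does with `_current_tool_calls_by_id` rather than calling it.

# Hypothetical usage sketch; chunk values are made up for illustration.
from mirascope import llm  # assumed import path, matching the test suite's `llm.` usage

chunks = [
    llm.ToolCallStartChunk(id="call_a", name="get_weather"),
    llm.ToolCallStartChunk(id="call_b", name="get_time"),
    llm.ToolCallChunk(id="call_a", delta='{"city": '),
    llm.ToolCallChunk(id="call_b", delta='{"tz": "UTC"}'),
    llm.ToolCallChunk(id="call_a", delta='"Tokyo"}'),
    llm.ToolCallEndChunk(id="call_b"),
    llm.ToolCallEndChunk(id="call_a"),
]

args_by_id: dict[str, str] = {}
for chunk in chunks:
    if chunk.type == "tool_call_start_chunk":
        args_by_id[chunk.id] = ""
    elif chunk.type == "tool_call_chunk":
        # Deltas are routed by id, so interleaved tool calls no longer collide.
        args_by_id[chunk.id] += chunk.delta
    elif chunk.type == "tool_call_end_chunk":
        # Empty args default to "{}", as in the stream response handler.
        print(chunk.id, "->", args_by_id[chunk.id] or "{}")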