6 changes: 4 additions & 2 deletions python/mirascope/llm/clients/anthropic/_utils/decode.py
@@ -157,7 +157,9 @@ def process_event(
f"Received input_json_delta for {self.current_block_param['type']} block"
)
self.accumulated_tool_json += delta.partial_json
yield ToolCallChunk(delta=delta.partial_json)
yield ToolCallChunk(
id=self.current_block_param["id"], delta=delta.partial_json
)
elif delta.type == "thinking_delta":
if self.current_block_param["type"] != "thinking": # pragma: no cover
raise RuntimeError(
@@ -194,7 +196,7 @@ def process_event(
if self.accumulated_tool_json
else {}
)
yield ToolCallEndChunk()
yield ToolCallEndChunk(id=self.current_block_param["id"])
elif block_type == "thinking":
yield ThoughtEndChunk()
else:
21 changes: 15 additions & 6 deletions python/mirascope/llm/clients/google/_utils/decode.py
@@ -128,6 +128,7 @@ class _GoogleChunkProcessor:

def __init__(self) -> None:
self.current_content_type: Literal["text", "tool_call", "thought"] | None = None
self.current_tool_id: str | None = None
self.accumulated_parts: list[genai_types.Part] = []
self.reconstructed_content = genai_types.Content(parts=[])

@@ -150,11 +151,13 @@ def process_chunk(
yield TextEndChunk() # pragma: no cover
self.current_content_type = None # pragma: no cover
elif self.current_content_type == "tool_call" and not part.function_call:
# In testing, Gemini never emits tool calls and text in the same message
# (even when specifically asked in system and user prompt), so
# the following code is uncovered but included for completeness
yield ToolCallEndChunk() # pragma: no cover
if self.current_tool_id is None:
raise RuntimeError(
"Missing tool_id when ending tool call"
) # pragma: no cover
yield ToolCallEndChunk(id=self.current_tool_id) # pragma: no cover
self.current_content_type = None # pragma: no cover
self.current_tool_id = None # pragma: no cover

if part.thought:
if self.current_content_type is None:
@@ -179,17 +182,23 @@ def process_chunk(
"Required name missing on Google function call"
) # pragma: no cover

tool_id = function_call.id or UNKNOWN_TOOL_ID
self.current_content_type = "tool_call"
self.current_tool_id = tool_id
yield ToolCallStartChunk(
id=function_call.id or UNKNOWN_TOOL_ID,
id=tool_id,
name=function_call.name,
)

yield ToolCallChunk(
id=tool_id,
delta=json.dumps(function_call.args)
if function_call.args
else "{}",
)
yield ToolCallEndChunk()
yield ToolCallEndChunk(id=tool_id)
self.current_content_type = None
self.current_tool_id = None

if candidate.finish_reason:
if self.current_content_type == "text":
@@ -84,6 +84,7 @@ class _OpenAIChunkProcessor:
def __init__(self) -> None:
self.current_content_type: Literal["text", "tool_call"] | None = None
self.current_tool_index: int | None = None
self.current_tool_id: str | None = None
self.refusal_encountered = False

def process_chunk(self, chunk: openai_types.ChatCompletionChunk) -> ChunkIterator:
@@ -127,8 +128,13 @@ def process_chunk(self, chunk: openai_types.ChatCompletionChunk) -> ChunkIterator:
self.current_tool_index is not None
and self.current_tool_index < index
):
yield ToolCallEndChunk()
if self.current_tool_id is None:
raise RuntimeError(
"Missing tool_id when ending tool call"
) # pragma: no cover
yield ToolCallEndChunk(id=self.current_tool_id)
self.current_tool_index = None
self.current_tool_id = None

if self.current_tool_index is None:
if not tool_call_delta.function or not (
@@ -144,19 +150,31 @@ def process_chunk(self, chunk: openai_types.ChatCompletionChunk) -> ChunkIterator:
f"Missing id for tool call at index {index}"
) # pragma: no cover

self.current_tool_id = tool_id
yield ToolCallStartChunk(
id=tool_id,
name=name,
)

if tool_call_delta.function and tool_call_delta.function.arguments:
yield ToolCallChunk(delta=tool_call_delta.function.arguments)
if self.current_tool_id is None:
raise RuntimeError(
"Missing tool_id when processing tool call chunk"
) # pragma: no cover
yield ToolCallChunk(
id=self.current_tool_id,
delta=tool_call_delta.function.arguments,
)

if choice.finish_reason:
if self.current_content_type == "text":
yield TextEndChunk()
elif self.current_content_type == "tool_call":
yield ToolCallEndChunk()
if self.current_tool_id is None:
raise RuntimeError(
"Missing tool_id when ending tool call at finish"
) # pragma: no cover
yield ToolCallEndChunk(id=self.current_tool_id)
elif self.current_content_type is not None: # pragma: no cover
raise NotImplementedError()

@@ -142,9 +142,9 @@ def process_chunk(self, event: ResponseStreamEvent) -> ChunkIterator:
)
self.current_content_type = "tool_call"
elif event.type == "response.function_call_arguments.delta":
yield ToolCallChunk(delta=event.delta)
yield ToolCallChunk(id=self.current_tool_call_id, delta=event.delta)
elif event.type == "response.function_call_arguments.done":
yield ToolCallEndChunk()
yield ToolCallEndChunk(id=self.current_tool_call_id)
self.current_content_type = None
elif (
event.type == "response.reasoning_text.delta"
8 changes: 6 additions & 2 deletions python/mirascope/llm/content/tool_call.py
@@ -47,7 +47,9 @@ class ToolCallChunk:
type: Literal["tool_call_chunk"] = "tool_call_chunk"

content_type: Literal["tool_call"] = "tool_call"
"""The type of content reconstructed by this chunk."""

Collaborator:
The reason we don't have id on ToolCallChunk or ToolCallEndChunk is that (in my observation/understanding) providers always stream tool calls one at a time, so we can put the id on the ToolCallStartChunk and safely infer that every subsequent ToolCallChunk and ToolCallEndChunk belongs to the tool call with that id. This matches my understanding of LLMs as fundamentally serial generators that produce output one token at a time (i.e. even when tool calls end up parallel by being included in the same response, they were still generated serially).

Are you seeing a provider that streams tool call chunks interleaved? If so which one?

(Aside: I'm guessing these PR descriptions are all auto-generated by Graphite? They have that smooth AI vibe of offering a plausible explanation based on the code change without truly understanding why the change is being made. It would be nice if the descriptions carried more of the specific human intent.)

Collaborator (author):
Thanks for the question. Grok streams tool calls in parallel. We often see call_a begin, then call_b begin, then more deltas from both IDs mixed together. We rely on the id on every ToolCallChunk and ToolCallEndChunk to keep each call stitched to the right arguments. If we drop the id, the chunks from different calls become ambiguous.
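
A minimal, hypothetical sketch of the id-keyed accumulation described above. The chunk classes below are simplified stand-ins for the real Mirascope types, and the interleaved stream is invented for illustration:

from dataclasses import dataclass


@dataclass
class ToolCallStartChunk:
    id: str
    name: str


@dataclass
class ToolCallChunk:
    id: str
    delta: str


@dataclass
class ToolCallEndChunk:
    id: str


def accumulate(chunks: list) -> dict[str, str]:
    """Stitch interleaved argument deltas back onto the right call by id."""
    args_by_id: dict[str, str] = {}
    finished: dict[str, str] = {}
    for chunk in chunks:
        if isinstance(chunk, ToolCallStartChunk):
            args_by_id[chunk.id] = ""
        elif isinstance(chunk, ToolCallChunk):
            args_by_id[chunk.id] += chunk.delta
        elif isinstance(chunk, ToolCallEndChunk):
            # Default to "{}" for tool calls that streamed no arguments.
            finished[chunk.id] = args_by_id.pop(chunk.id) or "{}"
    return finished


# Interleaved stream of the kind observed from Grok: call_a and call_b both
# start before either finishes, and their deltas arrive mixed together.
stream = [
    ToolCallStartChunk(id="call_a", name="get_weather"),
    ToolCallStartChunk(id="call_b", name="get_time"),
    ToolCallChunk(id="call_a", delta='{"city": '),
    ToolCallChunk(id="call_b", delta='{"tz": "UTC"}'),
    ToolCallChunk(id="call_a", delta='"Tokyo"}'),
    ToolCallEndChunk(id="call_b"),
    ToolCallEndChunk(id="call_a"),
]
assert accumulate(stream) == {
    "call_a": '{"city": "Tokyo"}',
    "call_b": '{"tz": "UTC"}',
}

Without an id on every chunk, the two argument strings above could not be separated once their deltas interleave.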

id: str
"""The unique identifier for the tool call this chunk belongs to."""

delta: str
"""The incremental json args added in this chunk."""
@@ -60,4 +62,6 @@ class ToolCallEndChunk:
type: Literal["tool_call_end_chunk"] = "tool_call_end_chunk"

content_type: Literal["tool_call"] = "tool_call"
"""The type of content reconstructed by this chunk."""

id: str
"""The unique identifier for the tool call being ended."""
37 changes: 19 additions & 18 deletions python/mirascope/llm/responses/base_stream_response.py
@@ -215,6 +215,7 @@ def __init__(

self._chunk_iterator = chunk_iterator
self._current_content: Text | Thought | ToolCall | None = None
self._current_tool_calls_by_id: dict[str, ToolCall] = {}

self._processing_format_tool: bool = False

@@ -306,39 +307,39 @@ def _handle_tool_call_chunk(
self, chunk: ToolCallStartChunk | ToolCallChunk | ToolCallEndChunk
) -> None:
if chunk.type == "tool_call_start_chunk":
if self._current_content:
if self._current_content and self._current_content.type != "tool_call":
raise RuntimeError(
"Received tool_call_start_chunk while processing another chunk"
)
self._current_content = ToolCall(
tool_call = ToolCall(
id=chunk.id,
name=chunk.name,
args="",
)
self._current_tool_calls_by_id[chunk.id] = tool_call
self._current_content = tool_call

elif chunk.type == "tool_call_chunk":
if (
self._current_content is None
or self._current_content.type != "tool_call"
):
tool_call = self._current_tool_calls_by_id.get(chunk.id)
if tool_call is None:
raise RuntimeError(
"Received tool_call_chunk while not processing tool call."
f"Received tool_call_chunk for unknown tool call id: {chunk.id}"
)
self._current_content.args += chunk.delta
tool_call.args += chunk.delta

elif chunk.type == "tool_call_end_chunk":
if (
self._current_content is None
or self._current_content.type != "tool_call"
):
tool_call = self._current_tool_calls_by_id.get(chunk.id)
if tool_call is None:
raise RuntimeError(
"Received tool_call_end_chunk while not processing tool call."
f"Received tool_call_end_chunk for unknown tool call id: {chunk.id}"
)
if not self._current_content.args:
self._current_content.args = "{}"
self._content.append(self._current_content)
self._tool_calls.append(self._current_content)
self._current_content = None
if not tool_call.args:
tool_call.args = "{}"
self._content.append(tool_call)
self._tool_calls.append(tool_call)
del self._current_tool_calls_by_id[chunk.id]
if self._current_content is tool_call:
self._current_content = None

def _pretty_chunk(self, chunk: AssistantContentChunk, spacer: str) -> str:
match chunk.type:
29 changes: 21 additions & 8 deletions python/tests/e2e/conftest.py
@@ -5,6 +5,7 @@

from __future__ import annotations

import gzip
import hashlib
import inspect
import json
@@ -179,16 +180,28 @@ def sanitize_response(response: dict[str, Any]) -> dict[str, Any]:
response = deepcopy(response)

if "body" in response and "string" in response["body"]:
body_str = response["body"]["string"]
if isinstance(body_str, bytes):
raw_body = response["body"]["string"]
was_gzip = False

if isinstance(raw_body, bytes):
try:
body_str = body_str.decode()
body_str = raw_body.decode()
except UnicodeDecodeError:
# Body is likely compressed (gzip) or binary data
# Skip sanitization for these responses
return response
try:
decompressed = gzip.decompress(raw_body)
body_str = decompressed.decode()
was_gzip = True
except (OSError, UnicodeDecodeError):
# Binary payload we cannot sanitize
return response
else:
body_str = raw_body

if "access_token" in body_str or "id_token" in body_str:
def _encode(text: str) -> bytes:
data = text.encode()
return gzip.compress(data) if was_gzip else data

try:
body_json = json.loads(body_str)
if "access_token" in body_json:
@@ -197,7 +210,7 @@ def sanitize_response(response: dict[str, Any]) -> dict[str, Any]:
body_json["id_token"] = "<filtered>"
if "refresh_token" in body_json:
body_json["refresh_token"] = "<filtered>"
response["body"]["string"] = json.dumps(body_json).encode()
response["body"]["string"] = _encode(json.dumps(body_json))
except (json.JSONDecodeError, KeyError):
body_str = re.sub(
r'"access_token":\s*"[^"]+"',
@@ -212,7 +225,7 @@ def sanitize_response(response: dict[str, Any]) -> dict[str, Any]:
'"refresh_token": "<filtered>"',
body_str,
)
response["body"]["string"] = body_str.encode()
response["body"]["string"] = _encode(body_str)

return response
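
For reference, a standalone sketch (an assumed helper, not the actual conftest fixture) of the gzip round-trip the updated sanitizer performs: decompress the recorded body, filter tokens, and re-compress so the cassette keeps its original encoding:

import gzip
import json


def sanitize_gzip_body(raw_body: bytes) -> bytes:
    # Decompress, filter the sensitive fields, then re-compress so the
    # stored response body keeps its original gzip encoding.
    body = json.loads(gzip.decompress(raw_body).decode())
    for key in ("access_token", "id_token", "refresh_token"):
        if key in body:
            body[key] = "<filtered>"
    return gzip.compress(json.dumps(body).encode())


original = gzip.compress(json.dumps({"access_token": "secret", "ok": True}).encode())
cleaned = json.loads(gzip.decompress(sanitize_gzip_body(original)).decode())
assert cleaned == {"access_token": "<filtered>", "ok": True}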
