
Commit baa82ae

Support concurrent tool call streaming with ID-based tracking

1 parent: 2401400

7 files changed: 100 additions (+), 62 deletions (-)

python/mirascope/llm/clients/anthropic/_utils/decode.py

Lines changed: 4 additions & 2 deletions

@@ -157,7 +157,9 @@ def process_event(
                     f"Received input_json_delta for {self.current_block_param['type']} block"
                 )
             self.accumulated_tool_json += delta.partial_json
-            yield ToolCallChunk(delta=delta.partial_json)
+            yield ToolCallChunk(
+                id=self.current_block_param["id"], delta=delta.partial_json
+            )
         elif delta.type == "thinking_delta":
             if self.current_block_param["type"] != "thinking":  # pragma: no cover
                 raise RuntimeError(
@@ -194,7 +196,7 @@ def process_event(
                 if self.accumulated_tool_json
                 else {}
             )
-            yield ToolCallEndChunk()
+            yield ToolCallEndChunk(id=self.current_block_param["id"])
         elif block_type == "thinking":
             yield ThoughtEndChunk()
         else:
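Net effect of this decoder change: every chunk in a tool call's lifecycle now carries the id of the content block it came from. A minimal sketch of the resulting event sequence (the import path and the literal id/name values are illustrative, not from this commit):

    from mirascope import llm  # import path assumed

    # One streamed tool call; the same id threads through start, deltas, and
    # end, so consumers can group deltas even when other calls interleave.
    chunks = [
        llm.ToolCallStartChunk(id="toolu_01", name="get_weather"),
        llm.ToolCallChunk(id="toolu_01", delta='{"city": '),
        llm.ToolCallChunk(id="toolu_01", delta='"Paris"}'),
        llm.ToolCallEndChunk(id="toolu_01"),
    ]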

python/mirascope/llm/clients/google/_utils/decode.py

Lines changed: 15 additions & 6 deletions

@@ -128,6 +128,7 @@ class _GoogleChunkProcessor:

     def __init__(self) -> None:
         self.current_content_type: Literal["text", "tool_call", "thought"] | None = None
+        self.current_tool_id: str | None = None
         self.accumulated_parts: list[genai_types.Part] = []
         self.reconstructed_content = genai_types.Content(parts=[])

@@ -150,11 +151,13 @@ def process_chunk(
                 yield TextEndChunk()  # pragma: no cover
                 self.current_content_type = None  # pragma: no cover
             elif self.current_content_type == "tool_call" and not part.function_call:
-                # In testing, Gemini never emits tool calls and text in the same message
-                # (even when specifically asked in system and user prompt), so
-                # the following code is uncovered but included for completeness
-                yield ToolCallEndChunk()  # pragma: no cover
+                if self.current_tool_id is None:
+                    raise RuntimeError(
+                        "Missing tool_id when ending tool call"
+                    )  # pragma: no cover
+                yield ToolCallEndChunk(id=self.current_tool_id)  # pragma: no cover
                 self.current_content_type = None  # pragma: no cover
+                self.current_tool_id = None  # pragma: no cover

             if part.thought:
                 if self.current_content_type is None:
@@ -179,17 +182,23 @@ def process_chunk(
                         "Required name missing on Google function call"
                     )  # pragma: no cover

+                tool_id = function_call.id or UNKNOWN_TOOL_ID
+                self.current_content_type = "tool_call"
+                self.current_tool_id = tool_id
                 yield ToolCallStartChunk(
-                    id=function_call.id or UNKNOWN_TOOL_ID,
+                    id=tool_id,
                     name=function_call.name,
                 )

                 yield ToolCallChunk(
+                    id=tool_id,
                     delta=json.dumps(function_call.args)
                     if function_call.args
                     else "{}",
                 )
-                yield ToolCallEndChunk()
+                yield ToolCallEndChunk(id=tool_id)
+                self.current_content_type = None
+                self.current_tool_id = None

             if candidate.finish_reason:
                 if self.current_content_type == "text":
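Gemini delivers each function call as one complete part rather than as incremental deltas, so the decoder emits the start/chunk/end triple in a single pass and resets its state immediately. A standalone sketch of that pattern (the sentinel value and helper name here are assumptions; the real UNKNOWN_TOOL_ID constant lives elsewhere in the package, and function_call stands in for any object with .id, .name, and .args):

    import json

    UNKNOWN_TOOL_ID = "unknown"  # placeholder; not the package's actual sentinel value

    def emit_tool_call_events(function_call):
        """Sketch: one complete Google function call becomes a start/chunk/end
        triple that all share the same id."""
        tool_id = function_call.id or UNKNOWN_TOOL_ID  # Google may omit the id
        yield ("tool_call_start", tool_id, function_call.name)
        yield ("tool_call_chunk", tool_id, json.dumps(function_call.args or {}))
        yield ("tool_call_end", tool_id)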

python/mirascope/llm/clients/openai/completions/_utils/decode.py

Lines changed: 21 additions & 3 deletions

@@ -84,6 +84,7 @@ class _OpenAIChunkProcessor:
     def __init__(self) -> None:
         self.current_content_type: Literal["text", "tool_call"] | None = None
         self.current_tool_index: int | None = None
+        self.current_tool_id: str | None = None
         self.refusal_encountered = False

     def process_chunk(self, chunk: openai_types.ChatCompletionChunk) -> ChunkIterator:
@@ -127,8 +128,13 @@ def process_chunk(self, chunk: openai_types.ChatCompletionChunk) -> ChunkIterator:
                     self.current_tool_index is not None
                     and self.current_tool_index < index
                 ):
-                    yield ToolCallEndChunk()
+                    if self.current_tool_id is None:
+                        raise RuntimeError(
+                            "Missing tool_id when ending tool call"
+                        )  # pragma: no cover
+                    yield ToolCallEndChunk(id=self.current_tool_id)
                     self.current_tool_index = None
+                    self.current_tool_id = None

                 if self.current_tool_index is None:
                     if not tool_call_delta.function or not (
@@ -144,19 +150,31 @@ def process_chunk(self, chunk: openai_types.ChatCompletionChunk) -> ChunkIterator:
                             f"Missing id for tool call at index {index}"
                         )  # pragma: no cover

+                    self.current_tool_id = tool_id
                     yield ToolCallStartChunk(
                         id=tool_id,
                         name=name,
                     )

                 if tool_call_delta.function and tool_call_delta.function.arguments:
-                    yield ToolCallChunk(delta=tool_call_delta.function.arguments)
+                    if self.current_tool_id is None:
+                        raise RuntimeError(
+                            "Missing tool_id when processing tool call chunk"
+                        )  # pragma: no cover
+                    yield ToolCallChunk(
+                        id=self.current_tool_id,
+                        delta=tool_call_delta.function.arguments,
+                    )

         if choice.finish_reason:
             if self.current_content_type == "text":
                 yield TextEndChunk()
             elif self.current_content_type == "tool_call":
-                yield ToolCallEndChunk()
+                if self.current_tool_id is None:
+                    raise RuntimeError(
+                        "Missing tool_id when ending tool call at finish"
+                    )  # pragma: no cover
+                yield ToolCallEndChunk(id=self.current_tool_id)
             elif self.current_content_type is not None:  # pragma: no cover
                 raise NotImplementedError()

python/mirascope/llm/clients/openai/responses/_utils/decode.py

Lines changed: 2 additions & 2 deletions

@@ -142,9 +142,9 @@ def process_chunk(self, event: ResponseStreamEvent) -> ChunkIterator:
             )
             self.current_content_type = "tool_call"
         elif event.type == "response.function_call_arguments.delta":
-            yield ToolCallChunk(delta=event.delta)
+            yield ToolCallChunk(id=self.current_tool_call_id, delta=event.delta)
         elif event.type == "response.function_call_arguments.done":
-            yield ToolCallEndChunk()
+            yield ToolCallEndChunk(id=self.current_tool_call_id)
             self.current_content_type = None
         elif (
             event.type == "response.reasoning_text.delta"

python/mirascope/llm/content/tool_call.py

Lines changed: 6 additions & 2 deletions

@@ -47,7 +47,9 @@ class ToolCallChunk:
     type: Literal["tool_call_chunk"] = "tool_call_chunk"

     content_type: Literal["tool_call"] = "tool_call"
-    """The type of content reconstructed by this chunk."""
+
+    id: str
+    """The unique identifier for the tool call this chunk belongs to."""

     delta: str
     """The incremental json args added in this chunk."""
@@ -60,4 +62,6 @@ class ToolCallEndChunk:
     type: Literal["tool_call_end_chunk"] = "tool_call_end_chunk"

     content_type: Literal["tool_call"] = "tool_call"
-    """The type of content reconstructed by this chunk."""
+
+    id: str
+    """The unique identifier for the tool call being ended."""

python/mirascope/llm/responses/base_stream_response.py

Lines changed: 19 additions & 18 deletions

@@ -215,6 +215,7 @@ def __init__(

         self._chunk_iterator = chunk_iterator
         self._current_content: Text | Thought | ToolCall | None = None
+        self._current_tool_calls_by_id: dict[str, ToolCall] = {}

         self._processing_format_tool: bool = False

@@ -306,39 +307,39 @@ def _handle_tool_call_chunk(
         self, chunk: ToolCallStartChunk | ToolCallChunk | ToolCallEndChunk
     ) -> None:
         if chunk.type == "tool_call_start_chunk":
-            if self._current_content:
+            if self._current_content and self._current_content.type != "tool_call":
                 raise RuntimeError(
                     "Received tool_call_start_chunk while processing another chunk"
                 )
-            self._current_content = ToolCall(
+            tool_call = ToolCall(
                 id=chunk.id,
                 name=chunk.name,
                 args="",
             )
+            self._current_tool_calls_by_id[chunk.id] = tool_call
+            self._current_content = tool_call

         elif chunk.type == "tool_call_chunk":
-            if (
-                self._current_content is None
-                or self._current_content.type != "tool_call"
-            ):
+            tool_call = self._current_tool_calls_by_id.get(chunk.id)
+            if tool_call is None:
                 raise RuntimeError(
-                    "Received tool_call_chunk while not processing tool call."
+                    f"Received tool_call_chunk for unknown tool call id: {chunk.id}"
                 )
-            self._current_content.args += chunk.delta
+            tool_call.args += chunk.delta

         elif chunk.type == "tool_call_end_chunk":
-            if (
-                self._current_content is None
-                or self._current_content.type != "tool_call"
-            ):
+            tool_call = self._current_tool_calls_by_id.get(chunk.id)
+            if tool_call is None:
                 raise RuntimeError(
-                    "Received tool_call_end_chunk while not processing tool call."
+                    f"Received tool_call_end_chunk for unknown tool call id: {chunk.id}"
                )
-            if not self._current_content.args:
-                self._current_content.args = "{}"
-            self._content.append(self._current_content)
-            self._tool_calls.append(self._current_content)
-            self._current_content = None
+            if not tool_call.args:
+                tool_call.args = "{}"
+            self._content.append(tool_call)
+            self._tool_calls.append(tool_call)
+            del self._current_tool_calls_by_id[chunk.id]
+            if self._current_content is tool_call:
+                self._current_content = None

     def _pretty_chunk(self, chunk: AssistantContentChunk, spacer: str) -> str:
         match chunk.type:
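Because _handle_tool_call_chunk now resolves every chunk through _current_tool_calls_by_id, an interleaved sequence from two concurrent tool calls reconstructs both ToolCall objects instead of raising. A hypothetical sequence the new logic accepts (ids and names are illustrative; the import path is assumed):

    from mirascope import llm  # import path assumed

    # Chunks from two tool calls arriving interleaved; each delta is routed to
    # the ToolCall registered under its id, and ends can arrive in any order.
    interleaved = [
        llm.ToolCallStartChunk(id="call_a", name="tool_one"),
        llm.ToolCallStartChunk(id="call_b", name="tool_two"),
        llm.ToolCallChunk(id="call_a", delta='{"x": 1}'),
        llm.ToolCallChunk(id="call_b", delta='{"y": 2}'),
        llm.ToolCallEndChunk(id="call_a"),
        llm.ToolCallEndChunk(id="call_b"),
    ]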

python/tests/llm/responses/test_stream_response.py

Lines changed: 33 additions & 29 deletions

@@ -443,7 +443,7 @@ class ChunkProcessingTestCase:
                 id="tool_123",
                 name="empty_function",
             ),
-            llm.ToolCallEndChunk(),
+            llm.ToolCallEndChunk(id="tool_123"),
         ],
         expected_contents=[
             [],
@@ -456,9 +456,11 @@ class ChunkProcessingTestCase:
                 id="tool_456",
                 name="test_function",
             ),
-            llm.ToolCallChunk(type="tool_call_chunk", delta='{"key": '),
-            llm.ToolCallChunk(type="tool_call_chunk", delta='"value"}'),
-            llm.ToolCallEndChunk(type="tool_call_end_chunk", content_type="tool_call"),
+            llm.ToolCallChunk(id="tool_456", type="tool_call_chunk", delta='{"key": '),
+            llm.ToolCallChunk(id="tool_456", type="tool_call_chunk", delta='"value"}'),
+            llm.ToolCallEndChunk(
+                id="tool_456", type="tool_call_end_chunk", content_type="tool_call"
+            ),
         ],
         expected_contents=[
             [],
@@ -703,12 +705,12 @@ class InvalidChunkSequenceTestCase:
        expected_error="Received thought_end_chunk while not processing thought",
    ),
    "tool_call_chunk_without_start": InvalidChunkSequenceTestCase(
-        chunks=[llm.ToolCallChunk(delta='{"test": "value"}')],
-        expected_error="Received tool_call_chunk while not processing tool call",
+        chunks=[llm.ToolCallChunk(id="unknown_id", delta='{"test": "value"}')],
+        expected_error="Received tool_call_chunk for unknown tool call id",
    ),
    "tool_call_end_without_start": InvalidChunkSequenceTestCase(
-        chunks=[llm.ToolCallEndChunk()],
-        expected_error="Received tool_call_end_chunk while not processing tool call",
+        chunks=[llm.ToolCallEndChunk(id="unknown_id")],
+        expected_error="Received tool_call_end_chunk for unknown tool call id",
    ),
    "overlapping_text_then_tool_call": InvalidChunkSequenceTestCase(
        chunks=[
@@ -734,9 +736,11 @@ class InvalidChunkSequenceTestCase:
        chunks=[
            llm.TextStartChunk(type="text_start_chunk"),
            llm.TextChunk(type="text_chunk", delta="test"),
-            llm.ToolCallEndChunk(type="tool_call_end_chunk", content_type="tool_call"),
+            llm.ToolCallEndChunk(
+                id="unknown_id", type="tool_call_end_chunk", content_type="tool_call"
+            ),
        ],
-        expected_error="Received tool_call_end_chunk while not processing tool call",
+        expected_error="Received tool_call_end_chunk for unknown tool call id",
    ),
}

@@ -1102,9 +1106,9 @@ def example_format_tool_chunks() -> list[llm.StreamResponseChunk]:
            id="call_format_123",
            name=FORMAT_TOOL_NAME,
        ),
-        llm.ToolCallChunk(delta='{"title": "The Hobbit"'),
-        llm.ToolCallChunk(delta=', "author": "Tolkien"}'),
-        llm.ToolCallEndChunk(),
+        llm.ToolCallChunk(id="call_format_123", delta='{"title": "The Hobbit"'),
+        llm.ToolCallChunk(id="call_format_123", delta=', "author": "Tolkien"}'),
+        llm.ToolCallEndChunk(id="call_format_123"),
    ]


@@ -1122,15 +1126,15 @@ def example_format_tool_chunks_processed() -> list[llm.AssistantContentChunk]:
def example_format_tool_chunks_mixed() -> list[llm.StreamResponseChunk]:
    return [
        llm.ToolCallStartChunk(id="call_007", name="ring_tool"),
-        llm.ToolCallChunk(delta='{"ring_purpose": "to_rule_them_all"}'),
-        llm.ToolCallEndChunk(),
+        llm.ToolCallChunk(id="call_007", delta='{"ring_purpose": "to_rule_them_all"}'),
+        llm.ToolCallEndChunk(id="call_007"),
        llm.ToolCallStartChunk(
            id="call_format_123",
            name=FORMAT_TOOL_NAME,
        ),
-        llm.ToolCallChunk(delta='{"title": "The Hobbit"'),
-        llm.ToolCallChunk(delta=', "author": "Tolkien"}'),
-        llm.ToolCallEndChunk(),
+        llm.ToolCallChunk(id="call_format_123", delta='{"title": "The Hobbit"'),
+        llm.ToolCallChunk(id="call_format_123", delta=', "author": "Tolkien"}'),
+        llm.ToolCallEndChunk(id="call_format_123"),
        llm.TextStartChunk(),
        llm.TextChunk(delta="A wizard is never late."),
        llm.TextEndChunk(),
@@ -1141,8 +1145,8 @@ def example_format_tool_chunks_mixed() -> list[llm.StreamResponseChunk]:
def example_format_tool_chunks_mixed_processed() -> list[llm.AssistantContentChunk]:
    return [
        llm.ToolCallStartChunk(id="call_007", name="ring_tool"),
-        llm.ToolCallChunk(delta='{"ring_purpose": "to_rule_them_all"}'),
-        llm.ToolCallEndChunk(),
+        llm.ToolCallChunk(id="call_007", delta='{"ring_purpose": "to_rule_them_all"}'),
+        llm.ToolCallEndChunk(id="call_007"),
        llm.TextStartChunk(),
        llm.TextChunk(delta='{"title": "The Hobbit"'),
        llm.TextChunk(delta=', "author": "Tolkien"}'),
@@ -1160,7 +1164,7 @@ def example_format_tool_chunks_max_tokens() -> list[llm.StreamResponseChunk]:
            id="call_format_123",
            name=FORMAT_TOOL_NAME,
        ),
-        llm.ToolCallEndChunk(),
+        llm.ToolCallEndChunk(id="call_format_123"),
        llm.responses.FinishReasonChunk(finish_reason=llm.FinishReason.MAX_TOKENS),
    ]

@@ -1314,11 +1318,11 @@ def tool_two(y: str) -> str:

    tool_call_chunks = [
        llm.ToolCallStartChunk(id="call_1", name="tool_one"),
-        llm.ToolCallChunk(delta='{"x": 5}'),
-        llm.ToolCallEndChunk(),
+        llm.ToolCallChunk(id="call_1", delta='{"x": 5}'),
+        llm.ToolCallEndChunk(id="call_1"),
        llm.ToolCallStartChunk(id="call_2", name="tool_two"),
-        llm.ToolCallChunk(delta='{"y": "hello"}'),
-        llm.ToolCallEndChunk(),
+        llm.ToolCallChunk(id="call_2", delta='{"y": "hello"}'),
+        llm.ToolCallEndChunk(id="call_2"),
    ]

    stream_response = llm.StreamResponse(
@@ -1352,11 +1356,11 @@ async def tool_two(y: str) -> str:

    tool_call_chunks = [
        llm.ToolCallStartChunk(id="call_1", name="tool_one"),
-        llm.ToolCallChunk(delta='{"x": 5}'),
-        llm.ToolCallEndChunk(),
+        llm.ToolCallChunk(id="call_1", delta='{"x": 5}'),
+        llm.ToolCallEndChunk(id="call_1"),
        llm.ToolCallStartChunk(id="call_2", name="tool_two"),
-        llm.ToolCallChunk(delta='{"y": "hello"}'),
-        llm.ToolCallEndChunk(),
+        llm.ToolCallChunk(id="call_2", delta='{"y": "hello"}'),
+        llm.ToolCallEndChunk(id="call_2"),
    ]

    async def async_chunk_iter() -> AsyncIterator[llm.AssistantContentChunk]:
