Skip to content

Commit ad6e826

Browse files
committed
Fix streaming tool calls
1 parent 05aa972 commit ad6e826

File tree

5 files changed

+93
-89
lines changed

5 files changed

+93
-89
lines changed

docs/output.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,6 @@ from pydantic import BaseModel
139139

140140
from pydantic_ai import Agent, ModelRetry, RunContext
141141
from pydantic_ai.exceptions import UnexpectedModelBehavior
142-
from pydantic_ai.output import ToolRetryError
143142

144143

145144
class Row(BaseModel):

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 61 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -561,7 +561,7 @@ def multi_modal_content_identifier(identifier: str | bytes) -> str:
561561
return hashlib.sha1(identifier).hexdigest()[:6]
562562

563563

564-
async def process_function_tools(
564+
async def process_function_tools( # noqa: C901
565565
toolset: AbstractToolset[DepsT],
566566
tool_calls: list[_messages.ToolCallPart],
567567
final_result: result.FinalResult[NodeRunEndT] | None,
@@ -646,70 +646,72 @@ async def process_function_tools(
646646

647647
user_parts: list[_messages.UserPromptPart] = []
648648

649-
include_content = (
650-
ctx.deps.instrumentation_settings is not None and ctx.deps.instrumentation_settings.include_content
651-
)
649+
if calls_to_run:
650+
include_content = (
651+
ctx.deps.instrumentation_settings is not None and ctx.deps.instrumentation_settings.include_content
652+
)
652653

653-
# Run all tool tasks in parallel
654-
results_by_index: dict[int, _messages.ModelRequestPart] = {}
655-
with ctx.deps.tracer.start_as_current_span(
656-
'running tools',
657-
attributes={
658-
'tools': [call.tool_name for call in calls_to_run],
659-
'logfire.msg': f'running {len(calls_to_run)} tool{"" if len(calls_to_run) == 1 else "s"}',
660-
},
661-
):
662-
tasks = [
663-
asyncio.create_task(
664-
_call_function_tool(toolset, call, run_context, ctx.deps.tracer, include_content), name=call.tool_name
665-
)
666-
for call in calls_to_run
667-
]
668-
669-
pending = tasks
670-
while pending:
671-
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
672-
for task in done:
673-
index = tasks.index(task)
674-
tool_result = task.result()
675-
yield _messages.FunctionToolResultEvent(tool_result, tool_call_id=tool_result.tool_call_id)
676-
677-
if isinstance(tool_result, _messages.RetryPromptPart):
678-
results_by_index[index] = tool_result
679-
elif isinstance(tool_result, _messages.ToolReturnPart):
680-
681-
def process_content(content: Any) -> Any:
682-
if isinstance(content, _messages.MultiModalContentTypes):
683-
if isinstance(content, _messages.BinaryContent):
684-
identifier = multi_modal_content_identifier(content.data)
654+
# Run all tool tasks in parallel
655+
results_by_index: dict[int, _messages.ModelRequestPart] = {}
656+
with ctx.deps.tracer.start_as_current_span(
657+
'running tools',
658+
attributes={
659+
'tools': [call.tool_name for call in calls_to_run],
660+
'logfire.msg': f'running {len(calls_to_run)} tool{"" if len(calls_to_run) == 1 else "s"}',
661+
},
662+
):
663+
tasks = [
664+
asyncio.create_task(
665+
_call_function_tool(toolset, call, run_context, ctx.deps.tracer, include_content),
666+
name=call.tool_name,
667+
)
668+
for call in calls_to_run
669+
]
670+
671+
pending = tasks
672+
while pending:
673+
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
674+
for task in done:
675+
index = tasks.index(task)
676+
tool_result = task.result()
677+
yield _messages.FunctionToolResultEvent(tool_result, tool_call_id=tool_result.tool_call_id)
678+
679+
if isinstance(tool_result, _messages.RetryPromptPart):
680+
results_by_index[index] = tool_result
681+
elif isinstance(tool_result, _messages.ToolReturnPart):
682+
683+
def process_content(content: Any) -> Any:
684+
if isinstance(content, _messages.MultiModalContentTypes):
685+
if isinstance(content, _messages.BinaryContent):
686+
identifier = multi_modal_content_identifier(content.data)
687+
else:
688+
identifier = multi_modal_content_identifier(content.url)
689+
690+
user_parts.append(
691+
_messages.UserPromptPart(
692+
content=[f'This is file {identifier}:', content],
693+
timestamp=tool_result.timestamp,
694+
part_kind='user-prompt',
695+
)
696+
)
697+
return f'See file {identifier}'
685698
else:
686-
identifier = multi_modal_content_identifier(content.url)
699+
return content
687700

688-
user_parts.append(
689-
_messages.UserPromptPart(
690-
content=[f'This is file {identifier}:', content],
691-
timestamp=tool_result.timestamp,
692-
part_kind='user-prompt',
693-
)
694-
)
695-
return f'See file {identifier}'
701+
if isinstance(tool_result.content, list):
702+
contents = cast(list[Any], tool_result.content) # type: ignore
703+
tool_result.content = [process_content(content) for content in contents]
696704
else:
697-
return content
705+
tool_result.content = process_content(tool_result.content)
698706

699-
if isinstance(tool_result.content, list):
700-
contents = cast(list[Any], tool_result.content) # type: ignore
701-
tool_result.content = [process_content(content) for content in contents]
707+
results_by_index[index] = tool_result
702708
else:
703-
tool_result.content = process_content(tool_result.content)
704-
705-
results_by_index[index] = tool_result
706-
else:
707-
assert_never(tool_result)
709+
assert_never(tool_result)
708710

709-
# We append the results at the end, rather than as they are received, to retain a consistent ordering
710-
# This is mostly just to simplify testing
711-
for k in sorted(results_by_index):
712-
parts.append(results_by_index[k])
711+
# We append the results at the end, rather than as they are received, to retain a consistent ordering
712+
# This is mostly just to simplify testing
713+
for k in sorted(results_by_index):
714+
parts.append(results_by_index[k])
713715

714716
deferred_calls: list[_messages.ToolCallPart] = []
715717
for call in tool_calls_by_kind['deferred']:
@@ -739,6 +741,7 @@ def process_content(content: Any) -> Any:
739741
parts.extend(user_parts)
740742

741743
if final_result:
744+
# TODO: Use some better "box" object
742745
final_result_holder.append(final_result)
743746

744747

tests/models/test_anthropic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1700,7 +1700,7 @@ class CityLocation(BaseModel):
17001700

17011701
agent = Agent(m, output_type=NativeOutput(CityLocation))
17021702

1703-
with pytest.raises(UserError, match='Structured output is not supported by the model.'):
1703+
with pytest.raises(UserError, match='Native structured output is not supported by the model.'):
17041704
await agent.run('What is the largest city in the user country?')
17051705

17061706

tests/test_logfire.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,7 @@ async def my_ret(x: int) -> str:
289289
},
290290
'outer_typed_dict_key': None,
291291
'strict': None,
292+
'kind': 'function',
292293
}
293294
],
294295
'output_mode': 'text',

tests/test_streaming.py

Lines changed: 30 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -613,18 +613,18 @@ def another_tool(y: int) -> int:
613613
timestamp=IsNow(tz=timezone.utc),
614614
tool_call_id=IsStr(),
615615
),
616-
RetryPromptPart(
617-
tool_name='unknown_tool',
618-
content="Unknown tool name: 'unknown_tool'. Available tools: regular_tool, another_tool, final_result",
619-
timestamp=IsNow(tz=timezone.utc),
620-
tool_call_id=IsStr(),
621-
),
622616
ToolReturnPart(
623617
tool_name='regular_tool', content=42, timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr()
624618
),
625619
ToolReturnPart(
626620
tool_name='another_tool', content=2, timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr()
627621
),
622+
RetryPromptPart(
623+
content="Unknown tool name: 'unknown_tool'. Available tools: final_result, regular_tool, another_tool",
624+
tool_name='unknown_tool',
625+
tool_call_id=IsStr(),
626+
timestamp=IsNow(tz=timezone.utc),
627+
),
628628
]
629629
),
630630
]
@@ -712,15 +712,15 @@ def another_tool(y: int) -> int: # pragma: no cover
712712
ModelRequest(
713713
parts=[
714714
ToolReturnPart(
715-
tool_name='regular_tool',
716-
content='Tool not executed - a final result was already processed.',
715+
tool_name='final_result',
716+
content='Final result processed.',
717717
tool_call_id=IsStr(),
718718
timestamp=IsNow(tz=datetime.timezone.utc),
719719
part_kind='tool-return',
720720
),
721721
ToolReturnPart(
722-
tool_name='final_result',
723-
content='Final result processed.',
722+
tool_name='regular_tool',
723+
content='Tool not executed - a final result was already processed.',
724724
tool_call_id=IsStr(),
725725
timestamp=IsNow(tz=datetime.timezone.utc),
726726
part_kind='tool-return',
@@ -733,10 +733,7 @@ def another_tool(y: int) -> int: # pragma: no cover
733733
part_kind='tool-return',
734734
),
735735
RetryPromptPart(
736-
content='Unknown tool name: '
737-
"'unknown_tool'. Available tools: "
738-
'regular_tool, another_tool, '
739-
'final_result',
736+
content="Unknown tool name: 'unknown_tool'. Available tools: final_result, regular_tool, another_tool",
740737
tool_name='unknown_tool',
741738
tool_call_id=IsStr(),
742739
timestamp=IsNow(tz=datetime.timezone.utc),
@@ -975,6 +972,13 @@ def known_tool(x: int) -> int:
975972

976973
assert event_parts == snapshot(
977974
[
975+
FunctionToolCallEvent(
976+
part=ToolCallPart(
977+
tool_name='known_tool',
978+
args={'x': 5},
979+
tool_call_id=IsStr(),
980+
)
981+
),
978982
FunctionToolCallEvent(
979983
part=ToolCallPart(
980984
tool_name='unknown_tool',
@@ -991,9 +995,6 @@ def known_tool(x: int) -> int:
991995
),
992996
tool_call_id=IsStr(),
993997
),
994-
FunctionToolCallEvent(
995-
part=ToolCallPart(tool_name='known_tool', args={'x': 5}, tool_call_id=IsStr()),
996-
),
997998
FunctionToolResultEvent(
998999
result=ToolReturnPart(
9991000
tool_name='known_tool',
@@ -1003,13 +1004,6 @@ def known_tool(x: int) -> int:
10031004
),
10041005
tool_call_id=IsStr(),
10051006
),
1006-
FunctionToolCallEvent(
1007-
part=ToolCallPart(
1008-
tool_name='unknown_tool',
1009-
args={'arg': 'value'},
1010-
tool_call_id=IsStr(),
1011-
),
1012-
),
10131007
]
10141008
)
10151009

@@ -1029,15 +1023,15 @@ def call_final_result_with_bad_data(messages: list[ModelMessage], info: AgentInf
10291023

10301024
agent = Agent(FunctionModel(call_final_result_with_bad_data), output_type=OutputType)
10311025

1032-
event_parts: list[Any] = []
1026+
events: list[Any] = []
10331027
async with agent.iter('test') as agent_run:
10341028
async for node in agent_run:
10351029
if Agent.is_call_tools_node(node):
10361030
async with node.stream(agent_run.ctx) as event_stream:
10371031
async for event in event_stream:
1038-
event_parts.append(event)
1032+
events.append(event)
10391033

1040-
assert event_parts == snapshot(
1034+
assert events == snapshot(
10411035
[
10421036
FunctionToolCallEvent(
10431037
part=ToolCallPart(
@@ -1047,9 +1041,16 @@ def call_final_result_with_bad_data(messages: list[ModelMessage], info: AgentInf
10471041
),
10481042
),
10491043
FunctionToolResultEvent(
1050-
result=ToolReturnPart(
1044+
result=RetryPromptPart(
1045+
content=[
1046+
{
1047+
'type': 'missing',
1048+
'loc': ('value',),
1049+
'msg': 'Field required',
1050+
'input': {'bad_value': 'invalid'},
1051+
}
1052+
],
10511053
tool_name='final_result',
1052-
content='Output tool not used - result failed validation.',
10531054
tool_call_id=IsStr(),
10541055
timestamp=IsNow(tz=timezone.utc),
10551056
),

0 commit comments

Comments
 (0)