Skip to content

Commit 701ac1b

Browse files
proeverDouweM
andauthored
Yield events for unknown tool calls (#1960)
Co-authored-by: Douwe Maan <me@douwe.me>
1 parent 6651510 commit 701ac1b

File tree

3 files changed

+131
-7
lines changed

3 files changed

+131
-7
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,7 @@ async def process_function_tools( # noqa C901
648648
# if tool_name is in output_schema, it means we found a output tool but an error occurred in
649649
# validation, we don't add another part here
650650
if output_tool_name is not None:
651+
yield _messages.FunctionToolCallEvent(call)
651652
if found_used_output_tool:
652653
content = 'Output tool not used - a final result was already processed.'
653654
else:
@@ -658,9 +659,14 @@ async def process_function_tools( # noqa C901
658659
content=content,
659660
tool_call_id=call.tool_call_id,
660661
)
662+
yield _messages.FunctionToolResultEvent(part, tool_call_id=call.tool_call_id)
661663
output_parts.append(part)
662664
else:
663-
output_parts.append(_unknown_tool(call.tool_name, call.tool_call_id, ctx))
665+
yield _messages.FunctionToolCallEvent(call)
666+
667+
part = _unknown_tool(call.tool_name, call.tool_call_id, ctx)
668+
yield _messages.FunctionToolResultEvent(part, tool_call_id=call.tool_call_id)
669+
output_parts.append(part)
664670

665671
if not calls_to_run:
666672
return

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations as _annotations
22

33
import base64
4-
import uuid
54
from abc import ABC, abstractmethod
65
from collections.abc import Sequence
76
from dataclasses import dataclass, field, replace
@@ -888,13 +887,13 @@ class FunctionToolCallEvent:
888887

889888
part: ToolCallPart
890889
"""The (function) tool call to make."""
891-
call_id: str = field(init=False)
892-
"""An ID used for matching details about the call to its result. If present, defaults to the part's tool_call_id."""
893890
event_kind: Literal['function_tool_call'] = 'function_tool_call'
894891
"""Event type identifier, used as a discriminator."""
895892

896-
def __post_init__(self):
897-
self.call_id = self.part.tool_call_id or str(uuid.uuid4())
893+
@property
894+
def call_id(self) -> str:
895+
"""An ID used for matching details about the call to its result. If present, defaults to the part's tool_call_id."""
896+
return self.part.tool_call_id
898897

899898
__repr__ = _utils.dataclasses_no_defaults_repr
900899

tests/test_streaming.py

Lines changed: 120 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from collections.abc import AsyncIterator
77
from copy import deepcopy
88
from datetime import timezone
9-
from typing import Union
9+
from typing import Any, Union
1010

1111
import pytest
1212
from inline_snapshot import snapshot
@@ -15,6 +15,8 @@
1515
from pydantic_ai import Agent, UnexpectedModelBehavior, UserError, capture_run_messages
1616
from pydantic_ai.agent import AgentRun
1717
from pydantic_ai.messages import (
18+
FunctionToolCallEvent,
19+
FunctionToolResultEvent,
1820
ModelMessage,
1921
ModelRequest,
2022
ModelResponse,
@@ -921,3 +923,120 @@ def output_validator(data: OutputType | NotOutputType) -> OutputType | NotOutput
921923
async for output in stream.stream_output(debounce_by=None):
922924
outputs.append(output)
923925
assert outputs == [OutputType(value='a (validated)'), OutputType(value='a (validated)')]
926+
927+
928+
async def test_unknown_tool_call_events():
929+
"""Test that unknown tool calls emit both FunctionToolCallEvent and FunctionToolResultEvent during streaming."""
930+
931+
def call_mixed_tools(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
932+
"""Mock function that calls both known and unknown tools."""
933+
return ModelResponse(
934+
parts=[
935+
ToolCallPart('unknown_tool', {'arg': 'value'}),
936+
ToolCallPart('known_tool', {'x': 5}),
937+
]
938+
)
939+
940+
agent = Agent(FunctionModel(call_mixed_tools))
941+
942+
@agent.tool_plain
943+
def known_tool(x: int) -> int:
944+
return x * 2
945+
946+
event_parts: list[Any] = []
947+
948+
try:
949+
async with agent.iter('test') as agent_run:
950+
async for node in agent_run: # pragma: no branch
951+
if Agent.is_call_tools_node(node):
952+
async with node.stream(agent_run.ctx) as event_stream:
953+
async for event in event_stream:
954+
event_parts.append(event)
955+
956+
except UnexpectedModelBehavior:
957+
pass
958+
959+
assert event_parts == snapshot(
960+
[
961+
FunctionToolCallEvent(
962+
part=ToolCallPart(
963+
tool_name='unknown_tool',
964+
args={'arg': 'value'},
965+
tool_call_id=IsStr(),
966+
),
967+
),
968+
FunctionToolResultEvent(
969+
result=RetryPromptPart(
970+
content="Unknown tool name: 'unknown_tool'. Available tools: known_tool",
971+
tool_name='unknown_tool',
972+
tool_call_id=IsStr(),
973+
timestamp=IsNow(tz=timezone.utc),
974+
),
975+
tool_call_id=IsStr(),
976+
),
977+
FunctionToolCallEvent(
978+
part=ToolCallPart(tool_name='known_tool', args={'x': 5}, tool_call_id=IsStr()),
979+
),
980+
FunctionToolResultEvent(
981+
result=ToolReturnPart(
982+
tool_name='known_tool',
983+
content=10,
984+
tool_call_id=IsStr(),
985+
timestamp=IsNow(tz=timezone.utc),
986+
),
987+
tool_call_id=IsStr(),
988+
),
989+
FunctionToolCallEvent(
990+
part=ToolCallPart(
991+
tool_name='unknown_tool',
992+
args={'arg': 'value'},
993+
tool_call_id=IsStr(),
994+
),
995+
),
996+
]
997+
)
998+
999+
1000+
async def test_output_tool_validation_failure_events():
1001+
"""Test that output tools that fail validation emit events during streaming."""
1002+
1003+
def call_final_result_with_bad_data(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
1004+
"""Mock function that calls final_result tool with invalid data."""
1005+
assert info.output_tools is not None
1006+
return ModelResponse(
1007+
parts=[
1008+
ToolCallPart('final_result', {'bad_value': 'invalid'}), # Invalid field name
1009+
ToolCallPart('final_result', {'value': 'valid'}), # Valid field name
1010+
]
1011+
)
1012+
1013+
agent = Agent(FunctionModel(call_final_result_with_bad_data), output_type=OutputType)
1014+
1015+
event_parts: list[Any] = []
1016+
async with agent.iter('test') as agent_run:
1017+
async for node in agent_run:
1018+
if Agent.is_call_tools_node(node):
1019+
async with node.stream(agent_run.ctx) as event_stream:
1020+
async for event in event_stream:
1021+
event_parts.append(event)
1022+
1023+
assert event_parts == snapshot(
1024+
[
1025+
FunctionToolCallEvent(
1026+
part=ToolCallPart(
1027+
tool_name='final_result',
1028+
args={'bad_value': 'invalid'},
1029+
tool_call_id=IsStr(),
1030+
),
1031+
),
1032+
FunctionToolResultEvent(
1033+
result=ToolReturnPart(
1034+
tool_name='final_result',
1035+
content='Output tool not used - result failed validation.',
1036+
tool_call_id=IsStr(),
1037+
timestamp=IsNow(tz=timezone.utc),
1038+
),
1039+
tool_call_id=IsStr(),
1040+
),
1041+
]
1042+
)

0 commit comments

Comments
 (0)