|
6 | 6 | from collections.abc import AsyncIterator
|
7 | 7 | from copy import deepcopy
|
8 | 8 | from datetime import timezone
|
9 |
| -from typing import Union |
| 9 | +from typing import Any, Union |
10 | 10 |
|
11 | 11 | import pytest
|
12 | 12 | from inline_snapshot import snapshot
|
|
15 | 15 | from pydantic_ai import Agent, UnexpectedModelBehavior, UserError, capture_run_messages
|
16 | 16 | from pydantic_ai.agent import AgentRun
|
17 | 17 | from pydantic_ai.messages import (
|
| 18 | + FunctionToolCallEvent, |
| 19 | + FunctionToolResultEvent, |
18 | 20 | ModelMessage,
|
19 | 21 | ModelRequest,
|
20 | 22 | ModelResponse,
|
@@ -921,3 +923,120 @@ def output_validator(data: OutputType | NotOutputType) -> OutputType | NotOutput
|
921 | 923 | async for output in stream.stream_output(debounce_by=None):
|
922 | 924 | outputs.append(output)
|
923 | 925 | assert outputs == [OutputType(value='a (validated)'), OutputType(value='a (validated)')]
|
| 926 | + |
| 927 | + |
async def test_unknown_tool_call_events():
    """Test that unknown tool calls emit both FunctionToolCallEvent and FunctionToolResultEvent during streaming.

    The mock model answers every request with a call to an unregistered tool
    followed by a call to `known_tool`. The run is expected to end with
    `UnexpectedModelBehavior` instead of completing; events streamed before
    that failure are still collected and compared against the snapshot.
    """

    def call_mixed_tools(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        """Mock function that calls both known and unknown tools."""
        return ModelResponse(
            parts=[
                ToolCallPart('unknown_tool', {'arg': 'value'}),
                ToolCallPart('known_tool', {'x': 5}),
            ]
        )

    agent = Agent(FunctionModel(call_mixed_tools))

    @agent.tool_plain
    def known_tool(x: int) -> int:
        return x * 2

    event_parts: list[Any] = []

    # Assert the failure explicitly instead of silently swallowing it: the mock
    # model never stops requesting the unknown tool, so the run cannot finish
    # normally (presumably the retry limit is exceeded on the second round --
    # confirm against the agent's retry settings).
    with pytest.raises(UnexpectedModelBehavior):
        async with agent.iter('test') as agent_run:
            async for node in agent_run:  # pragma: no branch
                if Agent.is_call_tools_node(node):
                    async with node.stream(agent_run.ctx) as event_stream:
                        # Append one event at a time so that everything seen
                        # before the exception is retained for the assertion.
                        async for event in event_stream:
                            event_parts.append(event)

    # First round: call + result for each tool. Second round: only the call
    # event for the unknown tool is recorded before the run aborts.
    assert event_parts == snapshot(
        [
            FunctionToolCallEvent(
                part=ToolCallPart(
                    tool_name='unknown_tool',
                    args={'arg': 'value'},
                    tool_call_id=IsStr(),
                ),
            ),
            FunctionToolResultEvent(
                result=RetryPromptPart(
                    content="Unknown tool name: 'unknown_tool'. Available tools: known_tool",
                    tool_name='unknown_tool',
                    tool_call_id=IsStr(),
                    timestamp=IsNow(tz=timezone.utc),
                ),
                tool_call_id=IsStr(),
            ),
            FunctionToolCallEvent(
                part=ToolCallPart(tool_name='known_tool', args={'x': 5}, tool_call_id=IsStr()),
            ),
            FunctionToolResultEvent(
                result=ToolReturnPart(
                    tool_name='known_tool',
                    content=10,
                    tool_call_id=IsStr(),
                    timestamp=IsNow(tz=timezone.utc),
                ),
                tool_call_id=IsStr(),
            ),
            FunctionToolCallEvent(
                part=ToolCallPart(
                    tool_name='unknown_tool',
                    args={'arg': 'value'},
                    tool_call_id=IsStr(),
                ),
            ),
        ]
    )
| 998 | + |
| 999 | + |
async def test_output_tool_validation_failure_events():
    """Check the events streamed when an output tool call fails validation.

    The mock model issues two `final_result` calls in one response: the first
    with an unexpected field name, the second with the expected one. The
    events gathered from the call-tools nodes are compared against a snapshot
    holding the call/result pair for the invalid call.
    """

    def respond_with_invalid_then_valid_output(
        messages: list[ModelMessage], info: AgentInfo
    ) -> ModelResponse:
        """Mock model response: an invalid `final_result` call followed by a valid one."""
        assert info.output_tools is not None
        invalid_call = ToolCallPart(tool_name='final_result', args={'bad_value': 'invalid'})
        valid_call = ToolCallPart(tool_name='final_result', args={'value': 'valid'})
        return ModelResponse(parts=[invalid_call, valid_call])

    model = FunctionModel(respond_with_invalid_then_valid_output)
    agent = Agent(model, output_type=OutputType)

    streamed_events: list[Any] = []
    async with agent.iter('test') as run:
        async for graph_node in run:
            # Only call-tools nodes expose the tool event stream we care about.
            if Agent.is_call_tools_node(graph_node):
                async with graph_node.stream(run.ctx) as tool_events:
                    streamed_events.extend([evt async for evt in tool_events])

    assert streamed_events == snapshot(
        [
            FunctionToolCallEvent(
                part=ToolCallPart(
                    tool_name='final_result',
                    args={'bad_value': 'invalid'},
                    tool_call_id=IsStr(),
                ),
            ),
            FunctionToolResultEvent(
                result=ToolReturnPart(
                    tool_name='final_result',
                    content='Output tool not used - result failed validation.',
                    tool_call_id=IsStr(),
                    timestamp=IsNow(tz=timezone.utc),
                ),
                tool_call_id=IsStr(),
            ),
        ]
    )
0 commit comments