Skip to content

Commit 2b3a9e5

Browse files
committed
Merge branch 'main' into toolsets
2 parents 6eae653 + cefe6a3 commit 2b3a9e5

File tree

3 files changed

+33
-17
lines changed

3 files changed

+33
-17
lines changed

examples/pydantic_ai_examples/stream_whales.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from typing import Annotated
1212

1313
import logfire
14-
from pydantic import Field, ValidationError
14+
from pydantic import Field
1515
from rich.console import Console
1616
from rich.live import Live
1717
from rich.table import Table
@@ -51,20 +51,7 @@ async def main():
5151
) as result:
5252
console.print('Response:', style='green')
5353

54-
async for message, last in result.stream_structured(debounce_by=0.01):
55-
try:
56-
whales = await result.validate_structured_output(
57-
message, allow_partial=not last
58-
)
59-
except ValidationError as exc:
60-
if all(
61-
e['type'] == 'missing' and e['loc'] == ('response',)
62-
for e in exc.errors()
63-
):
64-
continue
65-
else:
66-
raise
67-
54+
async for whales in result.stream(debounce_by=0.01):
6855
table = Table(
6956
title='Species of Whale',
7057
caption='Streaming Structured responses from GPT-4',

pydantic_ai_slim/pydantic_ai/result.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,12 @@ async def stream_output(self, *, debounce_by: float | None = 0.1) -> AsyncIterat
6262
"""Asynchronously stream the (validated) agent outputs."""
6363
async for response in self.stream_responses(debounce_by=debounce_by):
6464
if self._final_result_event is not None:
65-
yield await self._validate_response(response, self._final_result_event.tool_name, allow_partial=True)
65+
try:
66+
yield await self._validate_response(
67+
response, self._final_result_event.tool_name, allow_partial=True
68+
)
69+
except ValidationError:
70+
pass
6671
if self._final_result_event is not None: # pragma: no branch
6772
yield await self._validate_response(
6873
self._raw_stream_response.get(), self._final_result_event.tool_name, allow_partial=False

tests/test_streaming.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1061,7 +1061,7 @@ def call_final_result_with_bad_data(messages: list[ModelMessage], info: AgentInf
10611061
)
10621062

10631063

1064-
async def test_stream_prompted_output():
1064+
async def test_stream_structured_output():
10651065
class CityLocation(BaseModel):
10661066
city: str
10671067
country: str | None = None
@@ -1084,6 +1084,30 @@ class CityLocation(BaseModel):
10841084
assert result.is_complete
10851085

10861086

1087+
async def test_iter_stream_structured_output():
1088+
class CityLocation(BaseModel):
1089+
city: str
1090+
country: str | None = None
1091+
1092+
m = TestModel(custom_output_text='{"city": "Mexico City", "country": "Mexico"}')
1093+
1094+
agent = Agent(m, output_type=PromptedOutput(CityLocation))
1095+
1096+
async with agent.iter('') as run:
1097+
async for node in run:
1098+
if agent.is_model_request_node(node):
1099+
async with node.stream(run.ctx) as stream:
1100+
assert [c async for c in stream.stream_output(debounce_by=None)] == snapshot(
1101+
[
1102+
CityLocation(city='Mexico '),
1103+
CityLocation(city='Mexico City'),
1104+
CityLocation(city='Mexico City'),
1105+
CityLocation(city='Mexico City', country='Mexico'),
1106+
CityLocation(city='Mexico City', country='Mexico'),
1107+
]
1108+
)
1109+
1110+
10871111
def test_function_tool_event_tool_call_id_properties():
10881112
"""Ensure that the `tool_call_id` property on function tool events mirrors the underlying part's ID."""
10891113
# Prepare a ToolCallPart with a fixed ID

0 commit comments

Comments
 (0)