Skip to content

Commit cefe6a3

Browse files
authored
Make AgentStream.stream_output (available inside agent.iter) stream validated output data instead of raising validation errors (#2134)
1 parent d10f036 commit cefe6a3

File tree

3 files changed

+33
-17
lines changed

3 files changed

+33
-17
lines changed

examples/pydantic_ai_examples/stream_whales.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from typing import Annotated
1212

1313
import logfire
14-
from pydantic import Field, ValidationError
14+
from pydantic import Field
1515
from rich.console import Console
1616
from rich.live import Live
1717
from rich.table import Table
@@ -51,20 +51,7 @@ async def main():
5151
) as result:
5252
console.print('Response:', style='green')
5353

54-
async for message, last in result.stream_structured(debounce_by=0.01):
55-
try:
56-
whales = await result.validate_structured_output(
57-
message, allow_partial=not last
58-
)
59-
except ValidationError as exc:
60-
if all(
61-
e['type'] == 'missing' and e['loc'] == ('response',)
62-
for e in exc.errors()
63-
):
64-
continue
65-
else:
66-
raise
67-
54+
async for whales in result.stream(debounce_by=0.01):
6855
table = Table(
6956
title='Species of Whale',
7057
caption='Streaming Structured responses from GPT-4',

pydantic_ai_slim/pydantic_ai/result.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,12 @@ async def stream_output(self, *, debounce_by: float | None = 0.1) -> AsyncIterat
5959
"""Asynchronously stream the (validated) agent outputs."""
6060
async for response in self.stream_responses(debounce_by=debounce_by):
6161
if self._final_result_event is not None:
62-
yield await self._validate_response(response, self._final_result_event.tool_name, allow_partial=True)
62+
try:
63+
yield await self._validate_response(
64+
response, self._final_result_event.tool_name, allow_partial=True
65+
)
66+
except ValidationError:
67+
pass
6368
if self._final_result_event is not None: # pragma: no branch
6469
yield await self._validate_response(
6570
self._raw_stream_response.get(), self._final_result_event.tool_name, allow_partial=False

tests/test_streaming.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1056,7 +1056,7 @@ def call_final_result_with_bad_data(messages: list[ModelMessage], info: AgentInf
10561056
)
10571057

10581058

1059-
async def test_stream_prompted_output():
1059+
async def test_stream_structured_output():
10601060
class CityLocation(BaseModel):
10611061
city: str
10621062
country: str | None = None
@@ -1079,6 +1079,30 @@ class CityLocation(BaseModel):
10791079
assert result.is_complete
10801080

10811081

1082+
async def test_iter_stream_structured_output():
1083+
class CityLocation(BaseModel):
1084+
city: str
1085+
country: str | None = None
1086+
1087+
m = TestModel(custom_output_text='{"city": "Mexico City", "country": "Mexico"}')
1088+
1089+
agent = Agent(m, output_type=PromptedOutput(CityLocation))
1090+
1091+
async with agent.iter('') as run:
1092+
async for node in run:
1093+
if agent.is_model_request_node(node):
1094+
async with node.stream(run.ctx) as stream:
1095+
assert [c async for c in stream.stream_output(debounce_by=None)] == snapshot(
1096+
[
1097+
CityLocation(city='Mexico '),
1098+
CityLocation(city='Mexico City'),
1099+
CityLocation(city='Mexico City'),
1100+
CityLocation(city='Mexico City', country='Mexico'),
1101+
CityLocation(city='Mexico City', country='Mexico'),
1102+
]
1103+
)
1104+
1105+
10821106
def test_function_tool_event_tool_call_id_properties():
10831107
"""Ensure that the `tool_call_id` property on function tool events mirrors the underlying part's ID."""
10841108
# Prepare a ToolCallPart with a fixed ID

0 commit comments

Comments
 (0)