Skip to content

Commit 9b55b99

Browse files
committed
Make it possible to stream structured while using .iter
1 parent a341e56 commit 9b55b99

File tree

2 files changed

+62
-2
lines changed

2 files changed

+62
-2
lines changed

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2052,6 +2052,66 @@ def usage(self) -> _usage.Usage:
20522052
"""Get usage statistics for the run so far, including token usage, model requests, and so on."""
20532053
return self._graph_run.state.usage
20542054

2055+
async def validate_structured_output(
2056+
self, response: _messages.ModelResponse, *, allow_partial: bool = False
2057+
) -> FinalResult[OutputDataT] | None:
2058+
"""Validate a structured result message.
2059+
2060+
This is analogous to `pydantic_ai.result.StreamedRunResult.validate_structured_output`, but can be used
2061+
when iterating over an agent run even when you aren't on the final response.
2062+
2063+
Args:
2064+
response: The structured result message.
2065+
output_tool_name: If provided, this must be the name of the output tool to use for validation.
2066+
allow_partial: If true, allow partial validation.
2067+
2068+
Returns:
2069+
A `FinalResult` where the `output` field is the output that would be associated with this response,
2070+
or `None` if no final output was produced. Note that if you are validating an intermediate output
2071+
from a streaming response, this may not actually be the "final" output of the agent run, despite
2072+
the name of the return type.
2073+
"""
2074+
output_schema = self.ctx.deps.output_schema
2075+
output_validators = self.ctx.deps.output_validators
2076+
run_context = _agent_graph.build_run_context(self.ctx)
2077+
2078+
call = None
2079+
tool_name: str | None = None
2080+
tool_call_id: str | None = None
2081+
if isinstance(output_schema, _output.ToolOutputSchema):
2082+
tool_name: str | None = None
2083+
for part, _ in output_schema.find_tool(response.parts):
2084+
tool_name = part.tool_name
2085+
tool_call_id = part.tool_call_id
2086+
break
2087+
if tool_name is None:
2088+
return None # no output tool was present, so there's no final output
2089+
2090+
match = output_schema.find_named_tool(response.parts, tool_name)
2091+
if match is None:
2092+
raise exceptions.UnexpectedModelBehavior( # pragma: no cover
2093+
f'Invalid response, unable to find tool {tool_name!r}; expected one of {output_schema.tool_names()}'
2094+
)
2095+
call, output_tool = match
2096+
result_data = await output_tool.process(
2097+
call, run_context, allow_partial=allow_partial, wrap_validation_errors=False
2098+
)
2099+
elif isinstance(output_schema, _output.TextOutputSchema):
2100+
text = '\n\n'.join(x.content for x in response.parts if isinstance(x, _messages.TextPart))
2101+
2102+
result_data = await output_schema.process(
2103+
text, run_context, allow_partial=allow_partial, wrap_validation_errors=False
2104+
)
2105+
else:
2106+
raise exceptions.UserError(
2107+
f'Unexpected output schema type: {type(output_schema)}; expected ToolOutputSchema or TextOutputSchema'
2108+
)
2109+
2110+
result_data = result_data.value
2111+
for validator in output_validators:
2112+
result_data = await validator.validate(result_data, call, run_context) # pragma: no cover
2113+
return FinalResult(result_data, tool_name=tool_name, tool_call_id=tool_call_id)
2114+
20552115
def __repr__(self) -> str: # pragma: no cover
20562116
result = self._graph_run.result
20572117
result_repr = '<run not finished>' if result is None else repr(result.output)

pydantic_ai_slim/pydantic_ai/result.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ async def _validate_response(
9595
match = self._output_schema.find_named_tool(message.parts, output_tool_name)
9696
if match is None:
9797
raise exceptions.UnexpectedModelBehavior( # pragma: no cover
98-
f'Invalid response, unable to find tool: {self._output_schema.tool_names()}'
98+
f'Invalid response, unable to find tool {output_tool_name!r}; expected one of {self._output_schema.tool_names()}'
9999
)
100100

101101
call, output_tool = match
@@ -413,7 +413,7 @@ async def validate_structured_output(
413413
match = self._output_schema.find_named_tool(message.parts, self._output_tool_name)
414414
if match is None:
415415
raise exceptions.UnexpectedModelBehavior( # pragma: no cover
416-
f'Invalid response, unable to find tool: {self._output_schema.tool_names()}'
416+
f'Invalid response, unable to find tool {self._output_tool_name!r}; expected one of {self._output_schema.tool_names()}'
417417
)
418418

419419
call, output_tool = match

0 commit comments

Comments
 (0)