From 215b39f0a16fe37ea15665f21ef107253e4222b5 Mon Sep 17 00:00:00 2001
From: David Montague <35119617+dmontagu@users.noreply.github.com>
Date: Wed, 25 Jun 2025 15:49:16 -0600
Subject: [PATCH 1/2] Make it possible to stream structured while using .iter

---
 pydantic_ai_slim/pydantic_ai/agent.py  | 60 ++++++++++++++++++++++++++
 pydantic_ai_slim/pydantic_ai/result.py |  4 +-
 2 files changed, 62 insertions(+), 2 deletions(-)

diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py
index a04ae8646..e28a0b8b7 100644
--- a/pydantic_ai_slim/pydantic_ai/agent.py
+++ b/pydantic_ai_slim/pydantic_ai/agent.py
@@ -2052,6 +2052,66 @@ def usage(self) -> _usage.Usage:
         """Get usage statistics for the run so far, including token usage, model requests, and so on."""
         return self._graph_run.state.usage
 
+    async def validate_structured_output(
+        self, response: _messages.ModelResponse, *, allow_partial: bool = False
+    ) -> FinalResult[OutputDataT] | None:
+        """Validate a structured result message.
+
+        This is analogous to `pydantic_ai.result.StreamedRunResult.validate_structured_output`, but can be used
+        when iterating over an agent run even when you aren't on the final response.
+
+        Args:
+            response: The structured result message.
+            output_tool_name: If provided, this must be the name of the output tool to use for validation.
+            allow_partial: If true, allow partial validation.
+
+        Returns:
+            A `FinalResult` where the `output` field is the output that would be associated with this response,
+            or `None` if no final output was produced. Note that if you are validating an intermediate output
+            from a streaming response, this may not actually be the "final" output of the agent run, despite
+            the name of the return type.
+        """
+        output_schema = self.ctx.deps.output_schema
+        output_validators = self.ctx.deps.output_validators
+        run_context = _agent_graph.build_run_context(self.ctx)
+
+        call = None
+        tool_name: str | None = None
+        tool_call_id: str | None = None
+        if isinstance(output_schema, _output.ToolOutputSchema):
+            tool_name: str | None = None
+            for part, _ in output_schema.find_tool(response.parts):
+                tool_name = part.tool_name
+                tool_call_id = part.tool_call_id
+                break
+            if tool_name is None:
+                return None  # no output tool was present, so there's no final output
+
+            match = output_schema.find_named_tool(response.parts, tool_name)
+            if match is None:
+                raise exceptions.UnexpectedModelBehavior(  # pragma: no cover
+                    f'Invalid response, unable to find tool {tool_name!r}; expected one of {output_schema.tool_names()}'
+                )
+            call, output_tool = match
+            result_data = await output_tool.process(
+                call, run_context, allow_partial=allow_partial, wrap_validation_errors=False
+            )
+        elif isinstance(output_schema, _output.TextOutputSchema):
+            text = '\n\n'.join(x.content for x in response.parts if isinstance(x, _messages.TextPart))
+
+            result_data = await output_schema.process(
+                text, run_context, allow_partial=allow_partial, wrap_validation_errors=False
+            )
+        else:
+            raise ValueError(
+                f'Unexpected output schema type: {type(output_schema)}; expected ToolOutputSchema or TextOutputSchema'
+            )
+
+        result_data = result_data.value
+        for validator in output_validators:
+            result_data = await validator.validate(result_data, call, run_context)  # pragma: no cover
+        return FinalResult(result_data, tool_name=tool_name, tool_call_id=tool_call_id)
+
     def __repr__(self) -> str:  # pragma: no cover
         result = self._graph_run.result
         result_repr = '' if result is None else repr(result.output)
diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py
index a8b46f029..45285302b 100644
--- a/pydantic_ai_slim/pydantic_ai/result.py
+++ b/pydantic_ai_slim/pydantic_ai/result.py
@@ -95,7 +95,7 @@ async def _validate_response(
         match = self._output_schema.find_named_tool(message.parts, output_tool_name)
         if match is None:
             raise exceptions.UnexpectedModelBehavior(  # pragma: no cover
-                f'Invalid response, unable to find tool: {self._output_schema.tool_names()}'
+                f'Invalid response, unable to find tool {output_tool_name!r}; expected one of {self._output_schema.tool_names()}'
             )
 
         call, output_tool = match
@@ -413,7 +413,7 @@ async def validate_structured_output(
         match = self._output_schema.find_named_tool(message.parts, self._output_tool_name)
         if match is None:
             raise exceptions.UnexpectedModelBehavior(  # pragma: no cover
-                f'Invalid response, unable to find tool: {self._output_schema.tool_names()}'
+                f'Invalid response, unable to find tool {self._output_tool_name!r}; expected one of {self._output_schema.tool_names()}'
             )
 
         call, output_tool = match

From 3ebc6a20d465ee17b35186fa94134464fceb87ce Mon Sep 17 00:00:00 2001
From: David Montague <35119617+dmontagu@users.noreply.github.com>
Date: Wed, 25 Jun 2025 15:59:31 -0600
Subject: [PATCH 2/2] Rework the tool_name stuff

---
 pydantic_ai_slim/pydantic_ai/agent.py | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py
index e28a0b8b7..c6bac1059 100644
--- a/pydantic_ai_slim/pydantic_ai/agent.py
+++ b/pydantic_ai_slim/pydantic_ai/agent.py
@@ -2053,7 +2053,7 @@ def usage(self) -> _usage.Usage:
         return self._graph_run.state.usage
 
     async def validate_structured_output(
-        self, response: _messages.ModelResponse, *, allow_partial: bool = False
+        self, response: _messages.ModelResponse, *, tool_name: str | None = None, allow_partial: bool = False
     ) -> FinalResult[OutputDataT] | None:
         """Validate a structured result message.
 
@@ -2062,7 +2062,10 @@ async def validate_structured_output(
 
         Args:
             response: The structured result message.
-            output_tool_name: If provided, this must be the name of the output tool to use for validation.
+            tool_name: If provided, this should be the name of the tool that will produce the output.
+                This is only included so that you can skip the tool-lookup step if you already know which tool
+                produced the output, which may be the case if calling this method in a loop while streaming.
+                (You can get the `tool_name` from the return value of this method when it is not `None`.)
             allow_partial: If true, allow partial validation.
 
         Returns:
@@ -2076,14 +2079,14 @@ async def validate_structured_output(
         run_context = _agent_graph.build_run_context(self.ctx)
 
         call = None
-        tool_name: str | None = None
         tool_call_id: str | None = None
         if isinstance(output_schema, _output.ToolOutputSchema):
-            tool_name: str | None = None
-            for part, _ in output_schema.find_tool(response.parts):
-                tool_name = part.tool_name
-                tool_call_id = part.tool_call_id
-                break
+            if tool_name is None:
+                # infer the tool name from the response parts
+                for part, _ in output_schema.find_tool(response.parts):
+                    tool_name = part.tool_name
+                    tool_call_id = part.tool_call_id
+                    break
             if tool_name is None:
                 return None  # no output tool was present, so there's no final output
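
Usage sketch (not part of either patch above): a minimal, hypothetical example of how the new
`validate_structured_output` method, added to the run object yielded by `agent.iter()` (apparently
`AgentRun`), might be called while iterating over a run. Only `validate_structured_output`, its
`tool_name`/`allow_partial` parameters, and the returned `FinalResult` come from the diffs; the model
name, the `CityInfo` output type, and the use of `Agent.is_call_tools_node` and `node.model_response`
are assumptions based on the existing pydantic-ai API.

import asyncio

from pydantic import BaseModel

from pydantic_ai import Agent


class CityInfo(BaseModel):  # assumed example output type, not from the patches
    city: str
    country: str


agent = Agent('openai:gpt-4o', output_type=CityInfo)  # assumed model name


async def main() -> None:
    async with agent.iter('Tell me about the capital of France') as run:
        tool_name: str | None = None
        async for node in run:
            if Agent.is_call_tools_node(node):
                # `node.model_response` is the ModelResponse produced by the preceding
                # model request; validate it even though the run may not be finished yet.
                # allow_partial=True tolerates incomplete structured output.
                final = await run.validate_structured_output(
                    node.model_response,
                    tool_name=tool_name,  # reuse a previously discovered tool name (PATCH 2/2)
                    allow_partial=True,
                )
                if final is not None:
                    tool_name = final.tool_name
                    print('intermediate output:', final.output)
        print('final result:', run.result)


if __name__ == '__main__':
    asyncio.run(main())

In a streaming setup the same call would presumably be made repeatedly against partial
`ModelResponse` snapshots taken from a model-request stream, which is what `allow_partial=True` and
the `tool_name` fast path introduced in PATCH 2/2 exist for; the node-level loop above is just the
simplest self-contained illustration.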