diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py
index a04ae8646..c6bac1059 100644
--- a/pydantic_ai_slim/pydantic_ai/agent.py
+++ b/pydantic_ai_slim/pydantic_ai/agent.py
@@ -2052,6 +2052,69 @@ def usage(self) -> _usage.Usage:
         """Get usage statistics for the run so far, including token usage, model requests, and so on."""
         return self._graph_run.state.usage

+    async def validate_structured_output(
+        self, response: _messages.ModelResponse, *, tool_name: str | None = None, allow_partial: bool = False
+    ) -> FinalResult[OutputDataT] | None:
+        """Validate a structured result message.
+
+        This is analogous to `pydantic_ai.result.StreamedRunResult.validate_structured_output`, but can be used
+        when iterating over an agent run even when you aren't on the final response.
+
+        Args:
+            response: The structured result message.
+            tool_name: If provided, this should be the name of the tool that will produce the output.
+                This is only included so that you can skip the tool-lookup step if you already know which tool
+                produced the output, which may be the case if calling this method in a loop while streaming.
+                (You can get the `tool_name` from the return value of this method when it is not `None`.)
+            allow_partial: If true, allow partial validation.
+
+        Returns:
+            A `FinalResult` where the `output` field is the output that would be associated with this response,
+            or `None` if no final output was produced. Note that if you are validating an intermediate output
+            from a streaming response, this may not actually be the "final" output of the agent run, despite
+            the name of the return type.
+        """
+        output_schema = self.ctx.deps.output_schema
+        output_validators = self.ctx.deps.output_validators
+        run_context = _agent_graph.build_run_context(self.ctx)
+
+        call = None
+        tool_call_id: str | None = None
+        if isinstance(output_schema, _output.ToolOutputSchema):
+            if tool_name is None:
+                # infer the tool name from the response parts
+                for part, _ in output_schema.find_tool(response.parts):
+                    tool_name = part.tool_name
+                    tool_call_id = part.tool_call_id
+                    break
+            if tool_name is None:
+                return None  # no output tool was present, so there's no final output
+
+            match = output_schema.find_named_tool(response.parts, tool_name)
+            if match is None:
+                raise exceptions.UnexpectedModelBehavior(  # pragma: no cover
+                    f'Invalid response, unable to find tool {tool_name!r}; expected one of {output_schema.tool_names()}'
+                )
+            call, output_tool = match
+            result_data = await output_tool.process(
+                call, run_context, allow_partial=allow_partial, wrap_validation_errors=False
+            )
+        elif isinstance(output_schema, _output.TextOutputSchema):
+            text = '\n\n'.join(x.content for x in response.parts if isinstance(x, _messages.TextPart))
+
+            result_data = await output_schema.process(
+                text, run_context, allow_partial=allow_partial, wrap_validation_errors=False
+            )
+        else:
+            raise ValueError(
+                f'Unexpected output schema type: {type(output_schema)}; expected ToolOutputSchema or TextOutputSchema'
+            )
+
+        result_data = result_data.value
+        for validator in output_validators:
+            result_data = await validator.validate(result_data, call, run_context)  # pragma: no cover
+        return FinalResult(result_data, tool_name=tool_name, tool_call_id=tool_call_id)
+
     def __repr__(self) -> str:  # pragma: no cover
         result = self._graph_run.result
         result_repr = '' if result is None else repr(result.output)
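Not part of the patch: a minimal usage sketch of the new `AgentRun.validate_structured_output` method, for review purposes. It assumes the existing `Agent.iter()` / `CallToolsNode` iteration surface; the `CityLocation` output model, the model name, and the prompt are illustrative only — nothing here beyond the method itself comes from this diff.

```python
import asyncio

from pydantic import BaseModel

from pydantic_ai import Agent


class CityLocation(BaseModel):
    """Hypothetical output model, used only for this sketch."""

    city: str
    country: str


agent = Agent('openai:gpt-4o', output_type=CityLocation)  # model name is illustrative


async def main():
    async with agent.iter('Where were the 2012 Olympics held?') as agent_run:
        async for node in agent_run:
            # A CallToolsNode carries the ModelResponse from the preceding model
            # request, so it can be validated before the run has finished.
            if Agent.is_call_tools_node(node):
                result = await agent_run.validate_structured_output(node.model_response)
                if result is not None:
                    print(result.output)  # e.g. city='London' country='United Kingdom'


if __name__ == '__main__':
    asyncio.run(main())
```

When validating partial responses in a streaming loop instead, pass `allow_partial=True`, and feed `result.tool_name` back via the `tool_name` argument to skip the tool lookup on subsequent calls; the method returns `None` while no output tool call is present.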
diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py
index a8b46f029..45285302b 100644
--- a/pydantic_ai_slim/pydantic_ai/result.py
+++ b/pydantic_ai_slim/pydantic_ai/result.py
@@ -95,7 +95,7 @@ async def _validate_response(
             match = self._output_schema.find_named_tool(message.parts, output_tool_name)
             if match is None:
                 raise exceptions.UnexpectedModelBehavior(  # pragma: no cover
-                    f'Invalid response, unable to find tool: {self._output_schema.tool_names()}'
+                    f'Invalid response, unable to find tool {output_tool_name!r}; expected one of {self._output_schema.tool_names()}'
                 )
             call, output_tool = match
@@ -413,7 +413,7 @@ async def validate_structured_output(
             match = self._output_schema.find_named_tool(message.parts, self._output_tool_name)
             if match is None:
                 raise exceptions.UnexpectedModelBehavior(  # pragma: no cover
-                    f'Invalid response, unable to find tool: {self._output_schema.tool_names()}'
+                    f'Invalid response, unable to find tool {self._output_tool_name!r}; expected one of {self._output_schema.tool_names()}'
                 )
             call, output_tool = match