Skip to content

Commit 8039c20

Browse files
authored
Support NativeOutput and PromptedOutput modes in addition to ToolOutput (#1628)
1 parent 812a469 commit 8039c20

File tree

77 files changed

+10214
-379
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

77 files changed

+10214
-379
lines changed

docs/api/output.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# `pydantic_ai.output`
2+
3+
::: pydantic_ai.output
4+
options:
5+
inherited_members: true
6+
members:
7+
- OutputDataT
8+
- ToolOutput
9+
- NativeOutput
10+
- PromptedOutput
11+
- TextOutput

docs/api/result.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,4 @@
44
options:
55
inherited_members: true
66
members:
7-
- OutputDataT
87
- StreamedRunResult

docs/output.md

Lines changed: 168 additions & 15 deletions
Large diffs are not rendered by default.

mkdocs.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ nav:
6464
- api/agent.md
6565
- api/tools.md
6666
- api/common_tools.md
67+
- api/output.md
6768
- api/result.md
6869
- api/messages.md
6970
- api/exceptions.md

pydantic_ai_slim/pydantic_ai/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
)
1313
from .format_prompt import format_as_xml
1414
from .messages import AudioUrl, BinaryContent, DocumentUrl, ImageUrl, VideoUrl
15-
from .result import ToolOutput
15+
from .output import NativeOutput, PromptedOutput, TextOutput, ToolOutput
1616
from .tools import RunContext, Tool
1717

1818
__all__ = (
@@ -41,8 +41,11 @@
4141
# tools
4242
'Tool',
4343
'RunContext',
44-
# result
44+
# output
4545
'ToolOutput',
46+
'NativeOutput',
47+
'PromptedOutput',
48+
'TextOutput',
4649
# format_prompt
4750
'format_as_xml',
4851
)

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 29 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from pydantic_graph.nodes import End, NodeRunEndT
1919

2020
from . import _output, _system_prompt, exceptions, messages as _messages, models, result, usage as _usage
21-
from .result import OutputDataT
21+
from .output import OutputDataT, OutputSpec
2222
from .settings import ModelSettings, merge_model_settings
2323
from .tools import RunContext, Tool, ToolDefinition, ToolsPrepareFunc
2424

@@ -102,7 +102,7 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):
102102
end_strategy: EndStrategy
103103
get_instructions: Callable[[RunContext[DepsT]], Awaitable[str | None]]
104104

105-
output_schema: _output.OutputSchema[OutputDataT] | None
105+
output_schema: _output.OutputSchema[OutputDataT]
106106
output_validators: list[_output.OutputValidator[DepsT, OutputDataT]]
107107

108108
history_processors: Sequence[HistoryProcessor[DepsT]]
@@ -286,10 +286,23 @@ async def add_mcp_server_tools(server: MCPServer) -> None:
286286
function_tool_defs = await ctx.deps.prepare_tools(run_context, function_tool_defs) or []
287287

288288
output_schema = ctx.deps.output_schema
289+
290+
output_tools = []
291+
output_object = None
292+
if isinstance(output_schema, _output.ToolOutputSchema):
293+
output_tools = output_schema.tool_defs()
294+
elif isinstance(output_schema, _output.NativeOutputSchema):
295+
output_object = output_schema.object_def
296+
297+
# ToolOrTextOutputSchema, NativeOutputSchema, and PromptedOutputSchema all inherit from TextOutputSchema
298+
allow_text_output = isinstance(output_schema, _output.TextOutputSchema)
299+
289300
return models.ModelRequestParameters(
290301
function_tools=function_tool_defs,
291-
allow_text_output=_output.allow_text_output(output_schema),
292-
output_tools=output_schema.tool_defs() if output_schema is not None else [],
302+
output_mode=output_schema.mode,
303+
output_tools=output_tools,
304+
output_object=output_object,
305+
allow_text_output=allow_text_output,
293306
)
294307

295308

@@ -484,7 +497,7 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
484497
# when the model has already returned text along side tool calls
485498
# in this scenario, if text responses are allowed, we return text from the most recent model
486499
# response, if any
487-
if _output.allow_text_output(ctx.deps.output_schema):
500+
if isinstance(ctx.deps.output_schema, _output.TextOutputSchema):
488501
for message in reversed(ctx.state.message_history):
489502
if isinstance(message, _messages.ModelResponse):
490503
last_texts = [p.content for p in message.parts if isinstance(p, _messages.TextPart)]
@@ -507,10 +520,11 @@ async def _handle_tool_calls(
507520
output_schema = ctx.deps.output_schema
508521
run_context = build_run_context(ctx)
509522

510-
# first, look for the output tool call
511523
final_result: result.FinalResult[NodeRunEndT] | None = None
512524
parts: list[_messages.ModelRequestPart] = []
513-
if output_schema is not None:
525+
526+
# first, look for the output tool call
527+
if isinstance(output_schema, _output.ToolOutputSchema):
514528
for call, output_tool in output_schema.find_tool(tool_calls):
515529
try:
516530
result_data = await output_tool.process(call, run_context)
@@ -568,9 +582,9 @@ async def _handle_text_response(
568582

569583
text = '\n\n'.join(texts)
570584
try:
571-
if _output.allow_text_output(output_schema):
572-
# The following cast is safe because we know `str` is an allowed result type
573-
result_data = cast(NodeRunEndT, text)
585+
if isinstance(output_schema, _output.TextOutputSchema):
586+
run_context = build_run_context(ctx)
587+
result_data = await output_schema.process(text, run_context)
574588
else:
575589
m = _messages.RetryPromptPart(
576590
content='Plain text responses are not permitted, please include your response in a tool call',
@@ -669,7 +683,7 @@ async def process_function_tools( # noqa C901
669683
yield event
670684
call_index_to_event_id[len(calls_to_run)] = event.call_id
671685
calls_to_run.append((mcp_tool, call))
672-
elif output_schema is not None and call.tool_name in output_schema.tools:
686+
elif call.tool_name in output_schema.tools:
673687
# if tool_name is in output_schema, it means we found a output tool but an error occurred in
674688
# validation, we don't add another part here
675689
if output_tool_name is not None:
@@ -809,7 +823,9 @@ def _unknown_tool(
809823
) -> _messages.RetryPromptPart:
810824
ctx.state.increment_retries(ctx.deps.max_result_retries)
811825
tool_names = list(ctx.deps.function_tools.keys())
812-
if output_schema := ctx.deps.output_schema:
826+
827+
output_schema = ctx.deps.output_schema
828+
if isinstance(output_schema, _output.ToolOutputSchema):
813829
tool_names.extend(output_schema.tool_names())
814830

815831
if tool_names:
@@ -886,7 +902,7 @@ def get_captured_run_messages() -> _RunMessages:
886902
def build_agent_graph(
887903
name: str | None,
888904
deps_type: type[DepsT],
889-
output_type: _output.OutputType[OutputT],
905+
output_type: OutputSpec[OutputT],
890906
) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, result.FinalResult[OutputT]], result.FinalResult[OutputT]]:
891907
"""Build the execution [Graph][pydantic_graph.Graph] for a given agent."""
892908
nodes = (

pydantic_ai_slim/pydantic_ai/_cli.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,13 @@
1414

1515
from typing_inspection.introspection import get_literal_values
1616

17-
from pydantic_ai.result import OutputDataT
18-
from pydantic_ai.tools import AgentDepsT
19-
2017
from . import __version__
18+
from ._run_context import AgentDepsT
2119
from .agent import Agent
2220
from .exceptions import UserError
2321
from .messages import ModelMessage
2422
from .models import KnownModelName, infer_model
23+
from .output import OutputDataT
2524

2625
try:
2726
import argcomplete

pydantic_ai_slim/pydantic_ai/_function_schema.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,8 @@
1919
from pydantic_core import SchemaValidator, core_schema
2020
from typing_extensions import Concatenate, ParamSpec, TypeIs, TypeVar, get_origin
2121

22-
from pydantic_ai.tools import RunContext
23-
2422
from ._griffe import doc_descriptions
23+
from ._run_context import RunContext
2524
from ._utils import check_object_json_schema, is_async_callable, is_model_like, run_in_executor
2625

2726
if TYPE_CHECKING:
@@ -281,6 +280,4 @@ def _build_schema(
281280

282281
def _is_call_ctx(annotation: Any) -> bool:
283282
"""Return whether the annotation is the `RunContext` class, parameterized or not."""
284-
from .tools import RunContext
285-
286283
return annotation is RunContext or get_origin(annotation) is RunContext

0 commit comments

Comments
 (0)