18
18
from pydantic_graph .nodes import End , NodeRunEndT
19
19
20
20
from . import _output , _system_prompt , exceptions , messages as _messages , models , result , usage as _usage
21
- from .result import OutputDataT
21
+ from .output import OutputDataT , OutputSpec
22
22
from .settings import ModelSettings , merge_model_settings
23
23
from .tools import RunContext , Tool , ToolDefinition , ToolsPrepareFunc
24
24
@@ -102,7 +102,7 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):
102
102
end_strategy : EndStrategy
103
103
get_instructions : Callable [[RunContext [DepsT ]], Awaitable [str | None ]]
104
104
105
- output_schema : _output .OutputSchema [OutputDataT ] | None
105
+ output_schema : _output .OutputSchema [OutputDataT ]
106
106
output_validators : list [_output .OutputValidator [DepsT , OutputDataT ]]
107
107
108
108
history_processors : Sequence [HistoryProcessor [DepsT ]]
@@ -286,10 +286,23 @@ async def add_mcp_server_tools(server: MCPServer) -> None:
286
286
function_tool_defs = await ctx .deps .prepare_tools (run_context , function_tool_defs ) or []
287
287
288
288
output_schema = ctx .deps .output_schema
289
+
290
+ output_tools = []
291
+ output_object = None
292
+ if isinstance (output_schema , _output .ToolOutputSchema ):
293
+ output_tools = output_schema .tool_defs ()
294
+ elif isinstance (output_schema , _output .NativeOutputSchema ):
295
+ output_object = output_schema .object_def
296
+
297
+ # ToolOrTextOutputSchema, NativeOutputSchema, and PromptedOutputSchema all inherit from TextOutputSchema
298
+ allow_text_output = isinstance (output_schema , _output .TextOutputSchema )
299
+
289
300
return models .ModelRequestParameters (
290
301
function_tools = function_tool_defs ,
291
- allow_text_output = _output .allow_text_output (output_schema ),
292
- output_tools = output_schema .tool_defs () if output_schema is not None else [],
302
+ output_mode = output_schema .mode ,
303
+ output_tools = output_tools ,
304
+ output_object = output_object ,
305
+ allow_text_output = allow_text_output ,
293
306
)
294
307
295
308
@@ -484,7 +497,7 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
484
497
# when the model has already returned text along side tool calls
485
498
# in this scenario, if text responses are allowed, we return text from the most recent model
486
499
# response, if any
487
- if _output . allow_text_output (ctx .deps .output_schema ):
500
+ if isinstance (ctx .deps .output_schema , _output . TextOutputSchema ):
488
501
for message in reversed (ctx .state .message_history ):
489
502
if isinstance (message , _messages .ModelResponse ):
490
503
last_texts = [p .content for p in message .parts if isinstance (p , _messages .TextPart )]
@@ -507,10 +520,11 @@ async def _handle_tool_calls(
507
520
output_schema = ctx .deps .output_schema
508
521
run_context = build_run_context (ctx )
509
522
510
- # first, look for the output tool call
511
523
final_result : result .FinalResult [NodeRunEndT ] | None = None
512
524
parts : list [_messages .ModelRequestPart ] = []
513
- if output_schema is not None :
525
+
526
+ # first, look for the output tool call
527
+ if isinstance (output_schema , _output .ToolOutputSchema ):
514
528
for call , output_tool in output_schema .find_tool (tool_calls ):
515
529
try :
516
530
result_data = await output_tool .process (call , run_context )
@@ -568,9 +582,9 @@ async def _handle_text_response(
568
582
569
583
text = '\n \n ' .join (texts )
570
584
try :
571
- if _output . allow_text_output (output_schema ):
572
- # The following cast is safe because we know `str` is an allowed result type
573
- result_data = cast ( NodeRunEndT , text )
585
+ if isinstance (output_schema , _output . TextOutputSchema ):
586
+ run_context = build_run_context ( ctx )
587
+ result_data = await output_schema . process ( text , run_context )
574
588
else :
575
589
m = _messages .RetryPromptPart (
576
590
content = 'Plain text responses are not permitted, please include your response in a tool call' ,
@@ -669,7 +683,7 @@ async def process_function_tools( # noqa C901
669
683
yield event
670
684
call_index_to_event_id [len (calls_to_run )] = event .call_id
671
685
calls_to_run .append ((mcp_tool , call ))
672
- elif output_schema is not None and call .tool_name in output_schema .tools :
686
+ elif call .tool_name in output_schema .tools :
673
687
# if tool_name is in output_schema, it means we found a output tool but an error occurred in
674
688
# validation, we don't add another part here
675
689
if output_tool_name is not None :
@@ -809,7 +823,9 @@ def _unknown_tool(
809
823
) -> _messages .RetryPromptPart :
810
824
ctx .state .increment_retries (ctx .deps .max_result_retries )
811
825
tool_names = list (ctx .deps .function_tools .keys ())
812
- if output_schema := ctx .deps .output_schema :
826
+
827
+ output_schema = ctx .deps .output_schema
828
+ if isinstance (output_schema , _output .ToolOutputSchema ):
813
829
tool_names .extend (output_schema .tool_names ())
814
830
815
831
if tool_names :
@@ -886,7 +902,7 @@ def get_captured_run_messages() -> _RunMessages:
886
902
def build_agent_graph (
887
903
name : str | None ,
888
904
deps_type : type [DepsT ],
889
- output_type : _output . OutputType [OutputT ],
905
+ output_type : OutputSpec [OutputT ],
890
906
) -> Graph [GraphAgentState , GraphAgentDeps [DepsT , result .FinalResult [OutputT ]], result .FinalResult [OutputT ]]:
891
907
"""Build the execution [Graph][pydantic_graph.Graph] for a given agent."""
892
908
nodes = (
0 commit comments