@@ -159,6 +159,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
159
159
_mcp_servers : Sequence [MCPServer ] = dataclasses .field (repr = False )
160
160
_toolset : AbstractToolset [AgentDepsT ] = dataclasses .field (repr = False )
161
161
_prepare_tools : ToolsPrepareFunc [AgentDepsT ] | None = dataclasses .field (repr = False )
162
+ _prepare_output_tools : ToolsPrepareFunc [AgentDepsT ] | None = dataclasses .field (repr = False )
162
163
_max_result_retries : int = dataclasses .field (repr = False )
163
164
_override_deps : _utils .Option [AgentDepsT ] = dataclasses .field (default = None , repr = False )
164
165
_override_model : _utils .Option [models .Model ] = dataclasses .field (default = None , repr = False )
@@ -181,6 +182,7 @@ def __init__(
181
182
output_retries : int | None = None ,
182
183
tools : Sequence [Tool [AgentDepsT ] | ToolFuncEither [AgentDepsT , ...]] = (),
183
184
prepare_tools : ToolsPrepareFunc [AgentDepsT ] | None = None ,
185
+ prepare_output_tools : ToolsPrepareFunc [AgentDepsT ] | None = None ,
184
186
mcp_servers : Sequence [MCPServer ] = (),
185
187
toolsets : Sequence [AbstractToolset [AgentDepsT ]] | None = None ,
186
188
defer_model_check : bool = False ,
@@ -212,6 +214,7 @@ def __init__(
212
214
result_retries : int | None = None ,
213
215
tools : Sequence [Tool [AgentDepsT ] | ToolFuncEither [AgentDepsT , ...]] = (),
214
216
prepare_tools : ToolsPrepareFunc [AgentDepsT ] | None = None ,
217
+ prepare_output_tools : ToolsPrepareFunc [AgentDepsT ] | None = None ,
215
218
mcp_servers : Sequence [MCPServer ] = (),
216
219
toolsets : Sequence [AbstractToolset [AgentDepsT ]] | None = None ,
217
220
defer_model_check : bool = False ,
@@ -238,6 +241,7 @@ def __init__(
238
241
output_retries : int | None = None ,
239
242
tools : Sequence [Tool [AgentDepsT ] | ToolFuncEither [AgentDepsT , ...]] = (),
240
243
prepare_tools : ToolsPrepareFunc [AgentDepsT ] | None = None ,
244
+ prepare_output_tools : ToolsPrepareFunc [AgentDepsT ] | None = None ,
241
245
mcp_servers : Sequence [
242
246
MCPServer
243
247
] = (), # TODO: Deprecate argument, MCPServers can be passed directly to toolsets
@@ -270,9 +274,12 @@ def __init__(
270
274
output_retries: The maximum number of retries to allow for result validation, defaults to `retries`.
271
275
tools: Tools to register with the agent, you can also register tools via the decorators
272
276
[`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain].
273
- prepare_tools: custom method to prepare the tool definition of all tools for each step.
277
+ prepare_tools: Custom function to prepare the tool definition of all tools for each step, except output tools .
274
278
This is useful if you want to customize the definition of multiple tools or you want to register
275
279
a subset of tools for a given step. See [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc]
280
+ prepare_output_tools: Custom function to prepare the tool definition of all output tools for each step.
281
+ This is useful if you want to customize the definition of multiple output tools or you want to register
282
+ a subset of output tools for a given step. See [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc]
276
283
mcp_servers: MCP servers to register with the agent. You should register a [`MCPServer`][pydantic_ai.mcp.MCPServer]
277
284
for each server you want the agent to connect to.
278
285
toolsets: Toolsets to register with the agent.
@@ -365,6 +372,7 @@ def __init__(
365
372
366
373
self ._max_result_retries = output_retries if output_retries is not None else retries
367
374
self ._prepare_tools = prepare_tools
375
+ self ._prepare_output_tools = prepare_output_tools
368
376
369
377
self ._output_toolset = OutputToolset (self ._output_schema , max_retries = self ._max_result_retries )
370
378
self ._function_toolset = FunctionToolset (tools , max_retries = retries )
@@ -682,6 +690,8 @@ async def main():
682
690
output_toolset = OutputToolset [AgentDepsT ](
683
691
output_schema , max_retries = self ._max_result_retries , output_validators = output_validators
684
692
)
693
+ if self ._prepare_output_tools :
694
+ output_toolset = PreparedToolset (output_toolset , self ._prepare_output_tools )
685
695
686
696
# Build the graph
687
697
graph : Graph [_agent_graph .GraphAgentState , _agent_graph .GraphAgentDeps [AgentDepsT , Any ], FinalResult [Any ]] = (
0 commit comments