Skip to content

Commit a6eba43

Browse files
committed
Add Agent.prepare_output_tools
1 parent 1cb7f32 commit a6eba43

File tree

3 files changed

+104
-2
lines changed

3 files changed

+104
-2
lines changed

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
159159
_mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)
160160
_toolset: AbstractToolset[AgentDepsT] = dataclasses.field(repr=False)
161161
_prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False)
162+
_prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False)
162163
_max_result_retries: int = dataclasses.field(repr=False)
163164
_override_deps: _utils.Option[AgentDepsT] = dataclasses.field(default=None, repr=False)
164165
_override_model: _utils.Option[models.Model] = dataclasses.field(default=None, repr=False)
@@ -181,6 +182,7 @@ def __init__(
181182
output_retries: int | None = None,
182183
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
183184
prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
185+
prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
184186
mcp_servers: Sequence[MCPServer] = (),
185187
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
186188
defer_model_check: bool = False,
@@ -212,6 +214,7 @@ def __init__(
212214
result_retries: int | None = None,
213215
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
214216
prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
217+
prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
215218
mcp_servers: Sequence[MCPServer] = (),
216219
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
217220
defer_model_check: bool = False,
@@ -238,6 +241,7 @@ def __init__(
238241
output_retries: int | None = None,
239242
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
240243
prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
244+
prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
241245
mcp_servers: Sequence[
242246
MCPServer
243247
] = (), # TODO: Deprecate argument, MCPServers can be passed directly to toolsets
@@ -270,9 +274,12 @@ def __init__(
270274
output_retries: The maximum number of retries to allow for result validation, defaults to `retries`.
271275
tools: Tools to register with the agent, you can also register tools via the decorators
272276
[`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain].
273-
prepare_tools: custom method to prepare the tool definition of all tools for each step.
277+
prepare_tools: Custom function to prepare the tool definition of all tools for each step, except output tools.
274278
This is useful if you want to customize the definition of multiple tools or you want to register
275279
a subset of tools for a given step. See [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc]
280+
prepare_output_tools: Custom function to prepare the tool definition of all output tools for each step.
281+
This is useful if you want to customize the definition of multiple output tools or you want to register
282+
a subset of output tools for a given step. See [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc]
276283
mcp_servers: MCP servers to register with the agent. You should register a [`MCPServer`][pydantic_ai.mcp.MCPServer]
277284
for each server you want the agent to connect to.
278285
toolsets: Toolsets to register with the agent.
@@ -365,6 +372,7 @@ def __init__(
365372

366373
self._max_result_retries = output_retries if output_retries is not None else retries
367374
self._prepare_tools = prepare_tools
375+
self._prepare_output_tools = prepare_output_tools
368376

369377
self._output_toolset = OutputToolset(self._output_schema, max_retries=self._max_result_retries)
370378
self._function_toolset = FunctionToolset(tools, max_retries=retries)
@@ -682,6 +690,8 @@ async def main():
682690
output_toolset = OutputToolset[AgentDepsT](
683691
output_schema, max_retries=self._max_result_retries, output_validators=output_validators
684692
)
693+
if self._prepare_output_tools:
694+
output_toolset = PreparedToolset(output_toolset, self._prepare_output_tools)
685695

686696
# Build the graph
687697
graph: Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[Any]] = (

tests/test_agent.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
22
import re
33
import sys
4+
from dataclasses import dataclass
45
from datetime import timezone
56
from typing import Any, Callable, Union
67

@@ -3491,3 +3492,94 @@ def baz() -> str:
34913492
result = agent.run_sync('Hello', toolsets=[])
34923493
assert available_tools[-1] == snapshot(['baz'])
34933494
assert result.output == snapshot('{"baz":"Hello from baz"}')
3495+
3496+
3497+
def test_prepare_output_tools():
3498+
@dataclass
3499+
class AgentDeps:
3500+
plan_presented: bool = False
3501+
3502+
async def present_plan(ctx: RunContext[AgentDeps], plan: str) -> str:
3503+
"""
3504+
Present the plan to the user.
3505+
"""
3506+
ctx.deps.plan_presented = True
3507+
return plan
3508+
3509+
async def run_sql(ctx: RunContext[AgentDeps], purpose: str, query: str) -> str:
3510+
"""
3511+
Run an SQL query.
3512+
"""
3513+
return 'SQL query executed successfully'
3514+
3515+
async def only_if_plan_presented(
3516+
ctx: RunContext[AgentDeps], tool_defs: list[ToolDefinition]
3517+
) -> list[ToolDefinition]:
3518+
return tool_defs if ctx.deps.plan_presented else []
3519+
3520+
agent = Agent(
3521+
model='test',
3522+
deps_type=AgentDeps,
3523+
tools=[present_plan],
3524+
output_type=[ToolOutput(run_sql, name='run_sql')],
3525+
prepare_output_tools=only_if_plan_presented,
3526+
)
3527+
3528+
result = agent.run_sync('Hello', deps=AgentDeps())
3529+
assert result.output == snapshot('SQL query executed successfully')
3530+
assert result.all_messages() == snapshot(
3531+
[
3532+
ModelRequest(
3533+
parts=[
3534+
UserPromptPart(
3535+
content='Hello',
3536+
timestamp=IsDatetime(),
3537+
)
3538+
]
3539+
),
3540+
ModelResponse(
3541+
parts=[
3542+
ToolCallPart(
3543+
tool_name='present_plan',
3544+
args={'plan': 'a'},
3545+
tool_call_id=IsStr(),
3546+
)
3547+
],
3548+
usage=Usage(requests=1, request_tokens=51, response_tokens=5, total_tokens=56),
3549+
model_name='test',
3550+
timestamp=IsDatetime(),
3551+
),
3552+
ModelRequest(
3553+
parts=[
3554+
ToolReturnPart(
3555+
tool_name='present_plan',
3556+
content='a',
3557+
tool_call_id=IsStr(),
3558+
timestamp=IsDatetime(),
3559+
)
3560+
]
3561+
),
3562+
ModelResponse(
3563+
parts=[
3564+
ToolCallPart(
3565+
tool_name='run_sql',
3566+
args={'purpose': 'a', 'query': 'a'},
3567+
tool_call_id=IsStr(),
3568+
)
3569+
],
3570+
usage=Usage(requests=1, request_tokens=52, response_tokens=12, total_tokens=64),
3571+
model_name='test',
3572+
timestamp=IsDatetime(),
3573+
),
3574+
ModelRequest(
3575+
parts=[
3576+
ToolReturnPart(
3577+
tool_name='run_sql',
3578+
content='Final result processed.',
3579+
tool_call_id=IsStr(),
3580+
timestamp=IsDatetime(),
3581+
)
3582+
]
3583+
),
3584+
]
3585+
)

tests/test_tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -589,7 +589,7 @@ def test_tool_return_conflict():
589589
# this raises an error
590590
with pytest.raises(
591591
UserError,
592-
match="Function toolset defines a tool whose name conflicts with existing tool from OutputToolset: 'ctx_tool'. Consider renaming the tool or wrapping the toolset in a `PrefixedToolset` to avoid name conflicts.",
592+
match="Function toolset defines a tool whose name conflicts with existing tool from Output toolset: 'ctx_tool'. Consider renaming the tool or wrapping the toolset in a `PrefixedToolset` to avoid name conflicts.",
593593
):
594594
Agent('test', tools=[ctx_tool], deps_type=int, output_type=ToolOutput(int, name='ctx_tool'))
595595

0 commit comments

Comments
 (0)