Skip to content

Commit d896b01

Browse files
rmaceissoftDouweMKludex
authored
Add prepare_tools param to Agent class (#1474)
Co-authored-by: Douwe Maan <me@douwe.me> Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
1 parent c6a0e97 commit d896b01

File tree

5 files changed

+183
-2
lines changed

5 files changed

+183
-2
lines changed

docs/tools.md

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,100 @@ print(test_model.last_model_request_parameters.function_tools)
553553

554554
_(This example is complete, it can be run "as is")_
555555

556+
## Agent-wide Dynamic Tool Preparation {#prepare-tools}
557+
558+
In addition to per-tool `prepare` methods, you can also define an agent-wide `prepare_tools` function. This function is called at each step of a run and allows you to filter or modify the list of all tool definitions available to the agent for that step. This is especially useful if you want to enable or disable multiple tools at once, or apply global logic based on the current context.
559+
560+
The `prepare_tools` function should be of type [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc], which takes the [`RunContext`][pydantic_ai.tools.RunContext] and a list of [`ToolDefinition`][pydantic_ai.tools.ToolDefinition], and returns a new list of tool definitions (or `None` to disable all tools for that step).
561+
562+
!!! note
563+
The list of tool definitions passed to `prepare_tools` includes both regular tools and tools from any MCP servers attached to the agent.
564+
565+
Here's an example that makes all tools strict if the model is an OpenAI model:
566+
567+
```python {title="agent_prepare_tools_customize.py" noqa="I001"}
568+
from dataclasses import replace
569+
from typing import Union
570+
571+
from pydantic_ai import Agent, RunContext
572+
from pydantic_ai.tools import ToolDefinition
573+
from pydantic_ai.models.test import TestModel
574+
575+
576+
async def turn_on_strict_if_openai(
577+
ctx: RunContext[None], tool_defs: list[ToolDefinition]
578+
) -> Union[list[ToolDefinition], None]:
579+
if ctx.model.system == 'openai':
580+
return [replace(tool_def, strict=True) for tool_def in tool_defs]
581+
return tool_defs
582+
583+
584+
test_model = TestModel()
585+
agent = Agent(test_model, prepare_tools=turn_on_strict_if_openai)
586+
587+
588+
@agent.tool_plain
589+
def echo(message: str) -> str:
590+
return message
591+
592+
593+
agent.run_sync('testing...')
594+
assert test_model.last_model_request_parameters.function_tools[0].strict is None
595+
596+
# Set the system attribute of the test_model to 'openai'
597+
test_model._system = 'openai'
598+
599+
agent.run_sync('testing with openai...')
600+
assert test_model.last_model_request_parameters.function_tools[0].strict
601+
```
602+
603+
_(This example is complete, it can be run "as is")_
604+
605+
Here's another example that conditionally filters out the tools by name if the dependency (`ctx.deps`) is `True`:
606+
607+
```python {title="agent_prepare_tools_filter_out.py" noqa="I001"}
608+
from typing import Union
609+
610+
from pydantic_ai import Agent, RunContext
611+
from pydantic_ai.tools import Tool, ToolDefinition
612+
613+
614+
def launch_potato(target: str) -> str:
615+
return f'Potato launched at {target}!'
616+
617+
618+
async def filter_out_tools_by_name(
619+
ctx: RunContext[bool], tool_defs: list[ToolDefinition]
620+
) -> Union[list[ToolDefinition], None]:
621+
if ctx.deps:
622+
return [tool_def for tool_def in tool_defs if tool_def.name != 'launch_potato']
623+
return tool_defs
624+
625+
626+
agent = Agent(
627+
'test',
628+
tools=[Tool(launch_potato)],
629+
prepare_tools=filter_out_tools_by_name,
630+
deps_type=bool,
631+
)
632+
633+
result = agent.run_sync('testing...', deps=False)
634+
print(result.output)
635+
#> {"launch_potato":"Potato launched at a!"}
636+
result = agent.run_sync('testing...', deps=True)
637+
print(result.output)
638+
#> success (no tool calls)
639+
```
640+
641+
_(This example is complete, it can be run "as is")_
642+
643+
You can use `prepare_tools` to:
644+
645+
- Dynamically enable or disable tools based on the current model, dependencies, or other context
646+
- Modify tool definitions globally (e.g., set all tools to strict mode, change descriptions, etc.)
647+
648+
If both per-tool `prepare` and agent-wide `prepare_tools` are used, the per-tool `prepare` is applied first to each tool, and then `prepare_tools` is called with the resulting list of tool definitions.
649+
556650

557651
## Tool Execution and Retries {#tool-retries}
558652

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
)
2727
from .result import OutputDataT, ToolOutput
2828
from .settings import ModelSettings, merge_model_settings
29-
from .tools import RunContext, Tool, ToolDefinition
29+
from .tools import RunContext, Tool, ToolDefinition, ToolsPrepareFunc
3030

3131
if TYPE_CHECKING:
3232
from .mcp import MCPServer
@@ -97,6 +97,8 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):
9797

9898
tracer: Tracer
9999

100+
prepare_tools: ToolsPrepareFunc[DepsT] | None = None
101+
100102

101103
class AgentNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]):
102104
"""The base class for all agent nodes.
@@ -241,6 +243,11 @@ async def add_mcp_server_tools(server: MCPServer) -> None:
241243
*map(add_mcp_server_tools, ctx.deps.mcp_servers),
242244
)
243245

246+
if ctx.deps.prepare_tools:
247+
# Prepare the tools using the provided function
248+
# This also acts over tool definitions pulled from MCP servers
249+
function_tool_defs = await ctx.deps.prepare_tools(run_context, function_tool_defs) or []
250+
244251
output_schema = ctx.deps.output_schema
245252
return models.ModelRequestParameters(
246253
function_tools=function_tool_defs,

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
ToolFuncPlain,
4343
ToolParams,
4444
ToolPrepareFunc,
45+
ToolsPrepareFunc,
4546
)
4647

4748
# Re-exporting like this improves auto-import behavior in PyCharm
@@ -131,6 +132,11 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
131132
The type of data output by agent runs, used to validate the data returned by the model, defaults to `str`.
132133
"""
133134

135+
prepare_tools: ToolsPrepareFunc[AgentDepsT] | None
136+
"""
137+
Function invoked on each step, allowing the tools to be modified and filtered out as needed.
138+
"""
139+
134140
instrument: InstrumentationSettings | bool | None
135141
"""Options to automatically instrument with OpenTelemetry."""
136142

@@ -172,6 +178,7 @@ def __init__(
172178
retries: int = 1,
173179
output_retries: int | None = None,
174180
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
181+
prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
175182
mcp_servers: Sequence[MCPServer] = (),
176183
defer_model_check: bool = False,
177184
end_strategy: EndStrategy = 'early',
@@ -200,6 +207,7 @@ def __init__(
200207
result_tool_description: str | None = None,
201208
result_retries: int | None = None,
202209
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
210+
prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
203211
mcp_servers: Sequence[MCPServer] = (),
204212
defer_model_check: bool = False,
205213
end_strategy: EndStrategy = 'early',
@@ -223,6 +231,7 @@ def __init__(
223231
retries: int = 1,
224232
output_retries: int | None = None,
225233
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
234+
prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
226235
mcp_servers: Sequence[MCPServer] = (),
227236
defer_model_check: bool = False,
228237
end_strategy: EndStrategy = 'early',
@@ -251,6 +260,9 @@ def __init__(
251260
output_retries: The maximum number of retries to allow for result validation, defaults to `retries`.
252261
tools: Tools to register with the agent, you can also register tools via the decorators
253262
[`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain].
263+
prepare_tools: custom method to prepare the tool definition of all tools for each step.
264+
This is useful if you want to customize the definition of multiple tools or you want to register
265+
a subset of tools for a given step. See [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc]
254266
mcp_servers: MCP servers to register with the agent. You should register a [`MCPServer`][pydantic_ai.mcp.MCPServer]
255267
for each server you want the agent to connect to.
256268
defer_model_check: by default, if you provide a [named][pydantic_ai.models.KnownModelName] model,
@@ -334,6 +346,7 @@ def __init__(
334346
self._default_retries = retries
335347
self._max_result_retries = output_retries if output_retries is not None else retries
336348
self._mcp_servers = mcp_servers
349+
self._prepare_tools = prepare_tools
337350
for tool in tools:
338351
if isinstance(tool, Tool):
339352
self._register_tool(tool)
@@ -694,6 +707,7 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
694707
mcp_servers=self._mcp_servers,
695708
default_retries=self._default_retries,
696709
tracer=tracer,
710+
prepare_tools=self._prepare_tools,
697711
get_instructions=get_instructions,
698712
)
699713
start_node = _agent_graph.UserPromptNode[AgentDepsT](

pydantic_ai_slim/pydantic_ai/tools.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
'ToolFuncEither',
3030
'ToolParams',
3131
'ToolPrepareFunc',
32+
'ToolsPrepareFunc',
3233
'Tool',
3334
'ObjectJsonSchema',
3435
'ToolDefinition',
@@ -133,6 +134,37 @@ def hitchhiker(ctx: RunContext[int], answer: str) -> str:
133134
Usage `ToolPrepareFunc[AgentDepsT]`.
134135
"""
135136

137+
ToolsPrepareFunc: TypeAlias = (
138+
'Callable[[RunContext[AgentDepsT], list[ToolDefinition]], Awaitable[list[ToolDefinition] | None]]'
139+
)
140+
"""Definition of a function that can prepare the tool definition of all tools for each step.
141+
This is useful if you want to customize the definition of multiple tools or you want to register
142+
a subset of tools for a given step.
143+
144+
Example — here `turn_on_strict_if_openai` is valid as a `ToolsPrepareFunc`:
145+
146+
```python {noqa="I001"}
147+
from dataclasses import replace
148+
from typing import Union
149+
150+
from pydantic_ai import Agent, RunContext
151+
from pydantic_ai.tools import ToolDefinition
152+
153+
154+
async def turn_on_strict_if_openai(
155+
ctx: RunContext[None], tool_defs: list[ToolDefinition]
156+
) -> Union[list[ToolDefinition], None]:
157+
if ctx.model.system == 'openai':
158+
return [replace(tool_def, strict=True) for tool_def in tool_defs]
159+
return tool_defs
160+
161+
agent = Agent('openai:gpt-4o', prepare_tools=turn_on_strict_if_openai)
162+
```
163+
164+
Usage `ToolsPrepareFunc[AgentDepsT]`.
165+
"""
166+
167+
136168
DocstringFormat = Literal['google', 'numpy', 'sphinx', 'auto']
137169
"""Supported docstring formats.
138170

tests/test_tools.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import json
2-
from dataclasses import dataclass
2+
from dataclasses import dataclass, replace
33
from typing import Annotated, Any, Callable, Literal, Union
44

55
import pydantic_core
@@ -954,3 +954,37 @@ def get_score(data: Data) -> int: ... # pragma: no branch
954954
'strict': None,
955955
}
956956
)
957+
958+
959+
def test_dynamic_tools_agent_wide():
960+
async def prepare_tool_defs(
961+
ctx: RunContext[int], tool_defs: list[ToolDefinition]
962+
) -> Union[list[ToolDefinition], None]:
963+
if ctx.deps == 42:
964+
return []
965+
elif ctx.deps == 43:
966+
return None
967+
elif ctx.deps == 21:
968+
return [replace(tool_def, strict=True) for tool_def in tool_defs]
969+
return tool_defs
970+
971+
agent = Agent('test', deps_type=int, prepare_tools=prepare_tool_defs)
972+
973+
@agent.tool
974+
def foobar(ctx: RunContext[int], x: int, y: str) -> str:
975+
return f'{ctx.deps} {x} {y}'
976+
977+
result = agent.run_sync('', deps=42)
978+
assert result.output == snapshot('success (no tool calls)')
979+
980+
result = agent.run_sync('', deps=43)
981+
assert result.output == snapshot('success (no tool calls)')
982+
983+
with agent.override(model=FunctionModel(get_json_schema)):
984+
result = agent.run_sync('', deps=21)
985+
json_schema = json.loads(result.output)
986+
assert agent._function_tools['foobar'].strict is None
987+
assert json_schema['strict'] is True
988+
989+
result = agent.run_sync('', deps=1)
990+
assert result.output == snapshot('{"foobar":"1 0 a"}')

0 commit comments

Comments
 (0)