Skip to content

Commit 8a3febb

Browse files
committed
Let toolsets be overridden in run/iter/run_stream/run_sync
1 parent 6607b00 commit 8a3febb

File tree

2 files changed

+83
-14
lines changed

2 files changed

+83
-14
lines changed

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 43 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,10 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
155155
)
156156
_function_toolset: FunctionToolset[AgentDepsT] = dataclasses.field(repr=False)
157157
_output_toolset: OutputToolset[AgentDepsT] = dataclasses.field(repr=False)
158+
_user_toolsets: Sequence[AbstractToolset[AgentDepsT]] = dataclasses.field(repr=False)
159+
_mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)
158160
_toolset: AbstractToolset[AgentDepsT] = dataclasses.field(repr=False)
161+
_prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False)
159162
_max_result_retries: int = dataclasses.field(repr=False)
160163
_override_deps: _utils.Option[AgentDepsT] = dataclasses.field(default=None, repr=False)
161164
_override_model: _utils.Option[models.Model] = dataclasses.field(default=None, repr=False)
@@ -179,7 +182,7 @@ def __init__(
179182
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
180183
prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
181184
mcp_servers: Sequence[MCPServer] = (),
182-
toolsets: Sequence[AbstractToolset[AgentDepsT]] = (),
185+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
183186
defer_model_check: bool = False,
184187
end_strategy: EndStrategy = 'early',
185188
instrument: InstrumentationSettings | bool | None = None,
@@ -210,7 +213,7 @@ def __init__(
210213
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (),
211214
prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None,
212215
mcp_servers: Sequence[MCPServer] = (),
213-
toolsets: Sequence[AbstractToolset[AgentDepsT]] = (),
216+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
214217
defer_model_check: bool = False,
215218
end_strategy: EndStrategy = 'early',
216219
instrument: InstrumentationSettings | bool | None = None,
@@ -238,7 +241,7 @@ def __init__(
238241
mcp_servers: Sequence[
239242
MCPServer
240243
] = (), # TODO: Deprecate argument, MCPServers can be passed directly to toolsets
241-
toolsets: Sequence[AbstractToolset[AgentDepsT]] = (),
244+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
242245
defer_model_check: bool = False,
243246
end_strategy: EndStrategy = 'early',
244247
instrument: InstrumentationSettings | bool | None = None,
@@ -361,19 +364,18 @@ def __init__(
361364
self._system_prompt_dynamic_functions = {}
362365

363366
self._max_result_retries = output_retries if output_retries is not None else retries
367+
self._prepare_tools = prepare_tools
364368

365-
self._output_toolset = OutputToolset[AgentDepsT](self._output_schema, max_retries=self._max_result_retries)
366-
self._function_toolset = FunctionToolset[AgentDepsT](tools, max_retries=retries)
369+
self._output_toolset = OutputToolset(self._output_schema, max_retries=self._max_result_retries)
370+
self._function_toolset = FunctionToolset(tools, max_retries=retries)
371+
self._user_toolsets = toolsets or ()
372+
# TODO: Set max_retries on MCPServer
373+
self._mcp_servers = mcp_servers
367374

368375
# This will raise errors for any name conflicts
369-
# TODO: Also include toolsets (not mcp_serves as we won't have tool defs yet)
370-
CombinedToolset[AgentDepsT]([self._output_toolset, self._function_toolset])
371-
372-
# TODO: Set max_retries on MCPServer
373-
toolset = CombinedToolset[AgentDepsT]([self._function_toolset, *toolsets, *mcp_servers])
374-
if prepare_tools:
375-
toolset = PreparedToolset[AgentDepsT](toolset, prepare_tools)
376-
self._toolset = toolset
376+
self._toolset = CombinedToolset(
377+
[self._output_toolset, self._function_toolset, *self._user_toolsets, *self._mcp_servers]
378+
)
377379

378380
self.history_processors = history_processors or []
379381

@@ -395,6 +397,7 @@ async def run(
395397
usage_limits: _usage.UsageLimits | None = None,
396398
usage: _usage.Usage | None = None,
397399
infer_name: bool = True,
400+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
398401
) -> AgentRunResult[OutputDataT]: ...
399402

400403
@overload
@@ -410,6 +413,7 @@ async def run(
410413
usage_limits: _usage.UsageLimits | None = None,
411414
usage: _usage.Usage | None = None,
412415
infer_name: bool = True,
416+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
413417
) -> AgentRunResult[RunOutputDataT]: ...
414418

415419
@overload
@@ -426,6 +430,7 @@ async def run(
426430
usage_limits: _usage.UsageLimits | None = None,
427431
usage: _usage.Usage | None = None,
428432
infer_name: bool = True,
433+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
429434
) -> AgentRunResult[RunOutputDataT]: ...
430435

431436
async def run(
@@ -440,6 +445,7 @@ async def run(
440445
usage_limits: _usage.UsageLimits | None = None,
441446
usage: _usage.Usage | None = None,
442447
infer_name: bool = True,
448+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
443449
**_deprecated_kwargs: Never,
444450
) -> AgentRunResult[Any]:
445451
"""Run the agent with a user prompt in async mode.
@@ -470,6 +476,7 @@ async def main():
470476
usage_limits: Optional limits on model request count or token usage.
471477
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
472478
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
479+
toolsets: Optional toolsets to use for this run instead of the `toolsets` set when creating the agent.
473480
474481
Returns:
475482
The result of the run.
@@ -494,6 +501,7 @@ async def main():
494501
model_settings=model_settings,
495502
usage_limits=usage_limits,
496503
usage=usage,
504+
toolsets=toolsets,
497505
) as agent_run:
498506
async for _ in agent_run:
499507
pass
@@ -514,6 +522,7 @@ def iter(
514522
usage_limits: _usage.UsageLimits | None = None,
515523
usage: _usage.Usage | None = None,
516524
infer_name: bool = True,
525+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
517526
**_deprecated_kwargs: Never,
518527
) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, OutputDataT]]: ...
519528

@@ -530,6 +539,7 @@ def iter(
530539
usage_limits: _usage.UsageLimits | None = None,
531540
usage: _usage.Usage | None = None,
532541
infer_name: bool = True,
542+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
533543
**_deprecated_kwargs: Never,
534544
) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ...
535545

@@ -547,6 +557,7 @@ def iter(
547557
usage_limits: _usage.UsageLimits | None = None,
548558
usage: _usage.Usage | None = None,
549559
infer_name: bool = True,
560+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
550561
) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, Any]]: ...
551562

552563
@asynccontextmanager
@@ -562,6 +573,7 @@ async def iter(
562573
usage_limits: _usage.UsageLimits | None = None,
563574
usage: _usage.Usage | None = None,
564575
infer_name: bool = True,
576+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
565577
**_deprecated_kwargs: Never,
566578
) -> AsyncIterator[AgentRun[AgentDepsT, Any]]:
567579
"""A contextmanager which can be used to iterate over the agent graph's nodes as they are executed.
@@ -636,6 +648,7 @@ async def main():
636648
usage_limits: Optional limits on model request count or token usage.
637649
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
638650
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
651+
toolsets: Optional toolsets to use for this run instead of the `toolsets` set when creating the agent.
639652
640653
Returns:
641654
The result of the run.
@@ -693,7 +706,11 @@ async def main():
693706
run_step=state.run_step,
694707
)
695708

696-
toolset = CombinedToolset([output_toolset, self._toolset])
709+
user_toolsets = self._user_toolsets if toolsets is None else toolsets
710+
toolset = CombinedToolset([self._function_toolset, *user_toolsets, *self._mcp_servers])
711+
if self._prepare_tools:
712+
toolset = PreparedToolset(toolset, self._prepare_tools)
713+
toolset = CombinedToolset([output_toolset, toolset])
697714
run_toolset = await toolset.prepare_for_run(run_context)
698715

699716
model_settings = merge_model_settings(self.model_settings, model_settings)
@@ -814,6 +831,7 @@ def run_sync(
814831
usage_limits: _usage.UsageLimits | None = None,
815832
usage: _usage.Usage | None = None,
816833
infer_name: bool = True,
834+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
817835
) -> AgentRunResult[OutputDataT]: ...
818836

819837
@overload
@@ -829,6 +847,7 @@ def run_sync(
829847
usage_limits: _usage.UsageLimits | None = None,
830848
usage: _usage.Usage | None = None,
831849
infer_name: bool = True,
850+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
832851
) -> AgentRunResult[RunOutputDataT]: ...
833852

834853
@overload
@@ -845,6 +864,7 @@ def run_sync(
845864
usage_limits: _usage.UsageLimits | None = None,
846865
usage: _usage.Usage | None = None,
847866
infer_name: bool = True,
867+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
848868
) -> AgentRunResult[RunOutputDataT]: ...
849869

850870
def run_sync(
@@ -859,6 +879,7 @@ def run_sync(
859879
usage_limits: _usage.UsageLimits | None = None,
860880
usage: _usage.Usage | None = None,
861881
infer_name: bool = True,
882+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
862883
**_deprecated_kwargs: Never,
863884
) -> AgentRunResult[Any]:
864885
"""Synchronously run the agent with a user prompt.
@@ -888,6 +909,7 @@ def run_sync(
888909
usage_limits: Optional limits on model request count or token usage.
889910
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
890911
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
912+
toolsets: Optional toolsets to use for this run instead of the `toolsets` set when creating the agent.
891913
892914
Returns:
893915
The result of the run.
@@ -914,6 +936,7 @@ def run_sync(
914936
usage_limits=usage_limits,
915937
usage=usage,
916938
infer_name=False,
939+
toolsets=toolsets,
917940
)
918941
)
919942

@@ -929,6 +952,7 @@ def run_stream(
929952
usage_limits: _usage.UsageLimits | None = None,
930953
usage: _usage.Usage | None = None,
931954
infer_name: bool = True,
955+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
932956
) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, OutputDataT]]: ...
933957

934958
@overload
@@ -944,6 +968,7 @@ def run_stream(
944968
usage_limits: _usage.UsageLimits | None = None,
945969
usage: _usage.Usage | None = None,
946970
infer_name: bool = True,
971+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
947972
) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunOutputDataT]]: ...
948973

949974
@overload
@@ -960,6 +985,7 @@ def run_stream(
960985
usage_limits: _usage.UsageLimits | None = None,
961986
usage: _usage.Usage | None = None,
962987
infer_name: bool = True,
988+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
963989
) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunOutputDataT]]: ...
964990

965991
@asynccontextmanager
@@ -975,6 +1001,7 @@ async def run_stream( # noqa C901
9751001
usage_limits: _usage.UsageLimits | None = None,
9761002
usage: _usage.Usage | None = None,
9771003
infer_name: bool = True,
1004+
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
9781005
**_deprecated_kwargs: Never,
9791006
) -> AsyncIterator[result.StreamedRunResult[AgentDepsT, Any]]:
9801007
"""Run the agent with a user prompt in async mode, returning a streamed response.
@@ -1002,6 +1029,7 @@ async def main():
10021029
usage_limits: Optional limits on model request count or token usage.
10031030
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
10041031
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
1032+
toolsets: Optional toolsets to use for this run instead of the `toolsets` set when creating the agent.
10051033
10061034
Returns:
10071035
The result of the run.
@@ -1032,6 +1060,7 @@ async def main():
10321060
usage_limits=usage_limits,
10331061
usage=usage,
10341062
infer_name=False,
1063+
toolsets=toolsets,
10351064
) as agent_run:
10361065
first_node = agent_run.next_node # start with the first node
10371066
assert isinstance(first_node, _agent_graph.UserPromptNode) # the first node should be a user prompt node

tests/test_agent.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from pydantic_ai.profiles import ModelProfile
4646
from pydantic_ai.result import Usage
4747
from pydantic_ai.tools import ToolDefinition
48+
from pydantic_ai.toolset import FunctionToolset
4849

4950
from .conftest import IsDatetime, IsNow, IsStr, TestEnv
5051

@@ -3451,3 +3452,42 @@ def test_deprecated_kwargs_mixed_valid_invalid():
34513452
with warnings.catch_warnings():
34523453
warnings.simplefilter('ignore', DeprecationWarning) # Ignore the deprecation warning for result_tool_name
34533454
Agent('test', result_tool_name='test', foo='value1', bar='value2') # type: ignore[call-arg]
3455+
3456+
3457+
def test_override_toolsets():
3458+
foo_toolset = FunctionToolset()
3459+
3460+
@foo_toolset.tool
3461+
def foo() -> str:
3462+
return 'Hello from foo'
3463+
3464+
bar_toolset = FunctionToolset()
3465+
3466+
@bar_toolset.tool
3467+
def bar() -> str:
3468+
return 'Hello from bar'
3469+
3470+
available_tools: list[list[str]] = []
3471+
3472+
async def prepare_tools(ctx: RunContext[None], tool_defs: list[ToolDefinition]) -> list[ToolDefinition]:
3473+
nonlocal available_tools
3474+
available_tools.append([tool_def.name for tool_def in tool_defs])
3475+
return tool_defs
3476+
3477+
agent = Agent('test', toolsets=[foo_toolset], prepare_tools=prepare_tools)
3478+
3479+
@agent.tool_plain
3480+
def baz() -> str:
3481+
return 'Hello from baz'
3482+
3483+
result = agent.run_sync('Hello')
3484+
assert available_tools[-1] == snapshot(['baz', 'foo'])
3485+
assert result.output == snapshot('{"baz":"Hello from baz","foo":"Hello from foo"}')
3486+
3487+
result = agent.run_sync('Hello', toolsets=[bar_toolset])
3488+
assert available_tools[-1] == snapshot(['baz', 'bar'])
3489+
assert result.output == snapshot('{"baz":"Hello from baz","bar":"Hello from bar"}')
3490+
3491+
result = agent.run_sync('Hello', toolsets=[])
3492+
assert available_tools[-1] == snapshot(['baz'])
3493+
assert result.output == snapshot('{"baz":"Hello from baz"}')

0 commit comments

Comments
 (0)