Skip to content

Commit 8745a7a

Browse files
committed
Pass just one toolset into the run
1 parent 735df29 commit 8745a7a

File tree

7 files changed

+28
-24
lines changed

7 files changed

+28
-24
lines changed

docs/agents.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -808,7 +808,7 @@ with capture_run_messages() as messages: # (2)!
808808
result = agent.run_sync('Please get me the volume of a box with size 6.')
809809
except UnexpectedModelBehavior as e:
810810
print('An error occurred:', e)
811-
#> An error occurred: Tool exceeded max retries count of 1
811+
#> An error occurred: Tool 'calc_volume' exceeded max retries count of 1
812812
print('cause:', repr(e.__cause__))
813813
#> cause: ModelRetry('Please try again.')
814814
print('messages:', messages)

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from pydantic_ai._function_schema import _takes_ctx as is_takes_ctx # type: ignore
1919
from pydantic_ai._utils import is_async_callable, run_in_executor
20-
from pydantic_ai.toolset import AbstractToolset, CombinedToolset, RunToolset
20+
from pydantic_ai.toolset import AbstractToolset, RunToolset
2121
from pydantic_graph import BaseNode, Graph, GraphRunContext
2222
from pydantic_graph.nodes import End, NodeRunEndT
2323

@@ -107,7 +107,6 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):
107107
get_instructions: Callable[[RunContext[DepsT]], Awaitable[str | None]]
108108

109109
output_schema: _output.OutputSchema[OutputDataT]
110-
output_toolset: RunToolset[DepsT]
111110
output_validators: list[_output.OutputValidator[DepsT, OutputDataT]]
112111

113112
history_processors: Sequence[HistoryProcessor[DepsT]]
@@ -249,11 +248,7 @@ async def _prepare_request_parameters(
249248
) -> models.ModelRequestParameters:
250249
"""Build tools and create an agent model."""
251250
run_context = build_run_context(ctx)
252-
ctx.deps.toolset = toolset = await ctx.deps.toolset.prepare_for_run(run_context)
253-
ctx.deps.output_toolset = output_toolset = await ctx.deps.output_toolset.prepare_for_run(run_context)
254-
255-
# This will raise errors for any name conflicts
256-
CombinedToolset[DepsT]([output_toolset, toolset])
251+
ctx.deps.toolset = await ctx.deps.toolset.prepare_for_run(run_context)
257252

258253
output_schema = ctx.deps.output_schema
259254
output_object = None
@@ -263,10 +258,18 @@ async def _prepare_request_parameters(
263258
# ToolOrTextOutputSchema, NativeOutputSchema, and PromptedOutputSchema all inherit from TextOutputSchema
264259
allow_text_output = isinstance(output_schema, _output.TextOutputSchema)
265260

261+
function_tools: list[ToolDefinition] = []
262+
output_tools: list[ToolDefinition] = []
263+
for tool_def in ctx.deps.toolset.tool_defs:
264+
if tool_def.kind == 'output':
265+
output_tools.append(tool_def)
266+
else:
267+
function_tools.append(tool_def)
268+
266269
return models.ModelRequestParameters(
267-
function_tools=toolset.tool_defs,
270+
function_tools=function_tools,
268271
output_mode=output_schema.mode,
269-
output_tools=output_toolset.tool_defs,
272+
output_tools=output_tools,
270273
output_object=output_object,
271274
allow_text_output=allow_text_output,
272275
)
@@ -487,7 +490,7 @@ async def _handle_tool_calls( # noqa: C901
487490
final_result: result.FinalResult[NodeRunEndT] | None = None
488491
parts: list[_messages.ModelRequestPart] = []
489492

490-
toolset = CombinedToolset([ctx.deps.toolset, ctx.deps.output_toolset])
493+
toolset = ctx.deps.toolset
491494

492495
unknown_calls: list[_messages.ToolCallPart] = []
493496
tool_calls_by_kind: dict[ToolKind, list[_messages.ToolCallPart]] = defaultdict(list)

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -686,11 +686,8 @@ async def main():
686686
run_step=state.run_step,
687687
)
688688

689-
run_toolset = await self._toolset.prepare_for_run(run_context)
690-
run_output_toolset = await output_toolset.prepare_for_run(run_context)
691-
692-
# This will raise errors for any name conflicts
693-
CombinedToolset([run_output_toolset, run_toolset])
689+
toolset = CombinedToolset([output_toolset, self._toolset])
690+
run_toolset = await toolset.prepare_for_run(run_context)
694691

695692
model_settings = merge_model_settings(self.model_settings, model_settings)
696693
usage_limits = usage_limits or _usage.UsageLimits()
@@ -738,7 +735,6 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
738735
end_strategy=self.end_strategy,
739736
output_schema=output_schema,
740737
output_validators=output_validators,
741-
output_toolset=run_output_toolset,
742738
history_processors=self.history_processors,
743739
toolset=run_toolset,
744740
tracer=tracer,

pydantic_ai_slim/pydantic_ai/toolset.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,11 @@ def tool_names(self) -> list[str]:
368368
def get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator:
369369
return self._toolset_for_tool_name(name).get_tool_args_validator(ctx, name)
370370

371+
def validate_tool_args(
372+
self, ctx: RunContext[AgentDepsT], name: str, args: str | dict[str, Any] | None, allow_partial: bool = False
373+
) -> dict[str, Any]:
374+
return self._toolset_for_tool_name(name).validate_tool_args(ctx, name, args, allow_partial)
375+
371376
def max_retries_for_tool(self, name: str) -> int:
372377
return self._toolset_for_tool_name(name).max_retries_for_tool(name)
373378

@@ -678,9 +683,9 @@ def _on_error(self, name: str, e: Exception) -> Never:
678683
max_retries = self.max_retries_for_tool(name)
679684
current_retry = self._retries.get(name, 0)
680685
if current_retry == max_retries:
681-
raise UnexpectedModelBehavior(f'Tool exceeded max retries count of {max_retries}') from e
686+
raise UnexpectedModelBehavior(f'Tool {name!r} exceeded max retries count of {max_retries}') from e
682687
else:
683-
self._retries[name] = current_retry + 1
688+
self._retries[name] = current_retry + 1 # TODO: Reset on successful call!
684689
raise e
685690

686691
def _validate_tool_name(self, name: str) -> None:

tests/models/test_model_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def validate_output(ctx: RunContext[None], output: OutputModel) -> OutputModel:
157157
call_count += 1
158158
raise ModelRetry('Fail')
159159

160-
with pytest.raises(UnexpectedModelBehavior, match='Exceeded maximum retries'):
160+
with pytest.raises(UnexpectedModelBehavior, match="Tool 'final_result' exceeded max retries count of 2"):
161161
agent.run_sync('Hello', model=TestModel())
162162

163163
assert call_count == 3
@@ -200,7 +200,7 @@ class ResultModel(BaseModel):
200200

201201
agent = Agent('test', output_type=ResultModel, retries=2)
202202

203-
with pytest.raises(UnexpectedModelBehavior, match='Exceeded maximum retries'):
203+
with pytest.raises(UnexpectedModelBehavior, match="Tool 'final_result' exceeded max retries count of 2"):
204204
agent.run_sync('Hello', model=TestModel(custom_output_args={'foo': 'a', 'bar': 1}))
205205

206206

tests/test_agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2293,7 +2293,7 @@ def another_tool(y: int) -> int:
22932293
tool_name='another_tool', content=2, tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc)
22942294
),
22952295
RetryPromptPart(
2296-
content="Unknown tool name: 'unknown_tool'. Available tools: regular_tool, another_tool, final_result",
2296+
content="Unknown tool name: 'unknown_tool'. Available tools: final_result, regular_tool, another_tool",
22972297
tool_name='unknown_tool',
22982298
tool_call_id=IsStr(),
22992299
timestamp=IsDatetime(),
@@ -2380,7 +2380,7 @@ def another_tool(y: int) -> int: # pragma: no cover
23802380
),
23812381
RetryPromptPart(
23822382
tool_name='unknown_tool',
2383-
content="Unknown tool name: 'unknown_tool'. Available tools: regular_tool, another_tool, final_result",
2383+
content="Unknown tool name: 'unknown_tool'. Available tools: final_result, regular_tool, another_tool",
23842384
timestamp=IsNow(tz=timezone.utc),
23852385
tool_call_id=IsStr(),
23862386
),

tests/test_tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1171,7 +1171,7 @@ def infinite_retry_tool(ctx: RunContext[None]) -> int:
11711171
call_retries.append(ctx.retry)
11721172
raise ModelRetry('Please try again.')
11731173

1174-
with pytest.raises(UnexpectedModelBehavior, match='Tool exceeded max retries count of 5'):
1174+
with pytest.raises(UnexpectedModelBehavior, match="Tool 'infinite_retry_tool' exceeded max retries count of 5"):
11751175
agent.run_sync('Begin infinite retry loop!')
11761176

11771177
# There are extra 0s here because the toolset is prepared once ahead of the graph run, before the user prompt part is added in.

0 commit comments

Comments
 (0)