Skip to content

Commit 2e200ac

Browse files
committed
Add DeferredToolset
1 parent 8a3febb commit 2e200ac

File tree

2 files changed

+36
-10
lines changed

2 files changed

+36
-10
lines changed

pydantic_ai_slim/pydantic_ai/toolset.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,30 @@ async def call_tool(
319319
return output
320320

321321

322+
class DeferredToolset(AbstractToolset[AgentDepsT]):
323+
"""A toolset that holds deferred tool."""
324+
325+
_tool_defs: list[ToolDefinition]
326+
327+
def __init__(self, tool_defs: list[ToolDefinition]):
328+
self._tool_defs = tool_defs
329+
330+
@property
331+
def tool_defs(self) -> list[ToolDefinition]:
332+
return [replace(tool_def, kind='deferred') for tool_def in self._tool_defs]
333+
334+
def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator:
335+
raise NotImplementedError('Deferred tools cannot be validated')
336+
337+
def _max_retries_for_tool(self, name: str) -> int:
338+
raise NotImplementedError('Deferred tools cannot be retried')
339+
340+
async def call_tool(
341+
self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any
342+
) -> Any:
343+
raise NotImplementedError('Deferred tools cannot be called')
344+
345+
322346
@dataclass(init=False)
323347
class CombinedToolset(AbstractToolset[AgentDepsT]):
324348
"""A toolset that combines multiple toolsets."""

tests/test_tools.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from pydantic_ai.models.test import TestModel
1919
from pydantic_ai.output import DeferredToolCalls, ToolOutput
2020
from pydantic_ai.tools import ToolDefinition
21+
from pydantic_ai.toolset import DeferredToolset
2122

2223
from .conftest import IsStr
2324

@@ -1181,14 +1182,16 @@ def infinite_retry_tool(ctx: RunContext[None]) -> int:
11811182

11821183

11831184
def test_deferred_tool():
1184-
agent = Agent(TestModel(), output_type=[str, DeferredToolCalls])
1185-
1186-
async def prepare_tool(ctx: RunContext[None], tool_def: ToolDefinition) -> ToolDefinition:
1187-
return replace(tool_def, kind='deferred')
1188-
1189-
@agent.tool_plain(prepare=prepare_tool)
1190-
def my_tool(x: int) -> int:
1191-
return x + 1
1185+
deferred_toolset = DeferredToolset(
1186+
[
1187+
ToolDefinition(
1188+
name='my_tool',
1189+
description='',
1190+
parameters_json_schema={'type': 'object', 'properties': {'x': {'type': 'integer'}}, 'required': ['x']},
1191+
),
1192+
]
1193+
)
1194+
agent = Agent(TestModel(), output_type=[str, DeferredToolCalls], toolsets=[deferred_toolset])
11921195

11931196
result = agent.run_sync('Hello')
11941197
assert result.output == snapshot(
@@ -1199,10 +1202,9 @@ def my_tool(x: int) -> int:
11991202
name='my_tool',
12001203
description='',
12011204
parameters_json_schema={
1202-
'additionalProperties': False,
1205+
'type': 'object',
12031206
'properties': {'x': {'type': 'integer'}},
12041207
'required': ['x'],
1205-
'type': 'object',
12061208
},
12071209
kind='deferred',
12081210
)

0 commit comments

Comments
 (0)