Skip to content

Commit 9f9ee55

Browse files
committed
AbstractToolset.call_tool now takes a ToolCallPart
1 parent e6575a9 commit 9f9ee55

20 files changed

+228
-347
lines changed

docs/mcp/client.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,10 +189,12 @@ async def process_tool_call(
189189
ctx: RunContext[int],
190190
call_tool: CallToolFunc,
191191
name: str,
192-
tool_args: dict[str, Any],
192+
tool_args: str | dict[str, Any] None,
193+
*args: Any,
194+
**kwargs: Any
193195
) -> ToolResult:
194196
"""A tool call processor that passes along the deps."""
195-
return await call_tool(name, tool_args, metadata={'deps': ctx.deps})
197+
return await call_tool(name, tool_args, *args, metadata={'deps': ctx.deps}, **kwargs)
196198

197199

198200
server = MCPServerStdio('python', ['mcp_server.py'], process_tool_call=process_tool_call)

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -618,7 +618,7 @@ async def process_function_tools( # noqa: C901
618618
output_parts.append(part)
619619
else:
620620
try:
621-
result_data = await _call_tool(toolset, call, run_context)
621+
result_data = await toolset.call_tool(call, run_context)
622622
except exceptions.UnexpectedModelBehavior as e:
623623
ctx.state.increment_retries(ctx.deps.max_result_retries, e)
624624
raise e
@@ -755,7 +755,7 @@ async def _call_function_tool(
755755

756756
with tracer.start_as_current_span('running tool', attributes=span_attributes) as span:
757757
try:
758-
tool_result = await _call_tool(toolset, tool_call, run_context)
758+
tool_result = await toolset.call_tool(tool_call, run_context)
759759
except ToolRetryError as e:
760760
part = e.tool_retry
761761
if include_content and span.is_recording():
@@ -827,14 +827,6 @@ def process_content(content: Any) -> Any:
827827
return (part, extra_parts)
828828

829829

830-
async def _call_tool(
831-
toolset: AbstractToolset[DepsT], tool_call: _messages.ToolCallPart, run_context: RunContext[DepsT]
832-
) -> Any:
833-
run_context = dataclasses.replace(run_context, tool_call_id=tool_call.tool_call_id)
834-
args_dict = toolset.validate_tool_args(run_context, tool_call.tool_name, tool_call.args)
835-
return await toolset.call_tool(run_context, tool_call.tool_name, args_dict)
836-
837-
838830
@dataclasses.dataclass
839831
class _RunMessages:
840832
messages: list[_messages.ModelMessage]

pydantic_ai_slim/pydantic_ai/_output.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
_OutputSpecItem, # type: ignore[reportPrivateUsage]
3030
)
3131
from .tools import GenerateToolJsonSchema, ObjectJsonSchema, ToolDefinition
32-
from .toolsets import AbstractToolset
32+
from .toolsets._callable import CallableToolset
3333
from .toolsets._run import RunToolset
3434

3535
if TYPE_CHECKING:
@@ -829,7 +829,7 @@ async def process(
829829

830830

831831
@dataclass(init=False)
832-
class OutputToolset(AbstractToolset[AgentDepsT]):
832+
class OutputToolset(CallableToolset[AgentDepsT]):
833833
"""A toolset that contains output tools."""
834834

835835
_tool_defs: list[ToolDefinition]
@@ -930,9 +930,7 @@ def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> Sc
930930
def _max_retries_for_tool(self, name: str) -> int:
931931
return self.max_retries
932932

933-
async def call_tool(
934-
self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any
935-
) -> Any:
933+
async def _call_tool(self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any]) -> Any:
936934
output = await self.processors[name].call(tool_args, ctx)
937935
for validator in self.output_validators:
938936
output = await validator.validate(output, ctx, wrap_validation_errors=False)

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
166166
_prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False)
167167
_max_result_retries: int = dataclasses.field(repr=False)
168168

169-
_running_count: int = dataclasses.field(repr=False)
169+
_entered_count: int = dataclasses.field(repr=False)
170170
_exit_stack: AsyncExitStack | None = dataclasses.field(repr=False)
171171

172172
@overload
@@ -430,7 +430,7 @@ def __init__(
430430
self._override_model: ContextVar[_utils.Option[models.Model]] = ContextVar('_override_model', default=None)
431431

432432
self._exit_stack = None
433-
self._running_count = 0
433+
self._entered_count = 0
434434

435435
@staticmethod
436436
def instrument_all(instrument: InstrumentationSettings | bool = True) -> None:
@@ -1788,18 +1788,17 @@ def is_end_node(
17881788

17891789
async def __aenter__(self) -> Self:
17901790
"""Enter the agent. This will start all [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] registered with the agent so they can be used in a run."""
1791-
if self._running_count == 0:
1791+
if self._entered_count == 0:
17921792
self._exit_stack = AsyncExitStack()
17931793
await self._exit_stack.enter_async_context(self._toolset)
1794-
self._running_count += 1
1794+
self._entered_count += 1
17951795
return self
17961796

17971797
async def __aexit__(self, *args: Any) -> bool | None:
1798-
self._running_count -= 1
1799-
if self._running_count <= 0 and self._exit_stack is not None:
1798+
self._entered_count -= 1
1799+
if self._entered_count <= 0 and self._exit_stack is not None:
18001800
await self._exit_stack.aclose()
18011801
self._exit_stack = None
1802-
return None
18031802

18041803
def set_mcp_sampling_model(self, model: models.Model | models.KnownModelName | str | None = None) -> None:
18051804
"""Set the sampling model on all MCP servers registered with the agent.

pydantic_ai_slim/pydantic_ai/mcp.py

Lines changed: 57 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,10 @@
33
import base64
44
import functools
55
from abc import ABC, abstractmethod
6-
from collections.abc import AsyncIterator, Sequence
6+
from collections.abc import AsyncIterator, Awaitable, Sequence
77
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
88
from dataclasses import dataclass
99
from pathlib import Path
10-
from types import TracebackType
1110
from typing import Any, Callable
1211

1312
import anyio
@@ -19,10 +18,9 @@
1918
from pydantic_ai._run_context import RunContext
2019
from pydantic_ai.tools import ToolDefinition
2120

22-
from .toolsets import AbstractToolset
21+
from .toolsets._callable import CallableToolset
2322
from .toolsets._run import RunToolset
2423
from .toolsets.prefixed import PrefixedToolset
25-
from .toolsets.processed import ProcessedToolset, ToolProcessFunc
2624

2725
try:
2826
from mcp import types as mcp_types
@@ -45,7 +43,7 @@
4543
__all__ = 'MCPServer', 'MCPServerStdio', 'MCPServerHTTP', 'MCPServerSSE', 'MCPServerStreamableHTTP'
4644

4745

48-
class MCPServer(AbstractToolset[Any], ABC):
46+
class MCPServer(CallableToolset[Any], ABC):
4947
"""Base class for attaching agents to MCP servers.
5048
5149
See <https://modelcontextprotocol.io> for more information.
@@ -56,7 +54,7 @@ class MCPServer(AbstractToolset[Any], ABC):
5654
log_level: mcp_types.LoggingLevel | None = None
5755
log_handler: LoggingFnT | None = None
5856
timeout: float = 5
59-
process_tool_call: ToolProcessFunc[Any] | None = None
57+
process_tool_call: ProcessToolCallback | None = None
6058
allow_sampling: bool = True
6159
max_retries: int = 1
6260
sampling_model: models.Model | None = None
@@ -102,14 +100,11 @@ async def list_tools(self) -> list[mcp_types.Tool]:
102100
result = await self._client.list_tools()
103101
return result.tools
104102

105-
async def call_tool(
103+
async def _call_tool(
106104
self,
107105
ctx: RunContext[Any],
108106
name: str,
109107
tool_args: dict[str, Any],
110-
*args: Any,
111-
metadata: dict[str, Any] | None = None,
112-
**kwargs: Any,
113108
) -> ToolResult:
114109
"""Call a tool on the server.
115110
@@ -127,36 +122,41 @@ async def call_tool(
127122
Raises:
128123
ModelRetry: If the tool call fails.
129124
"""
130-
async with self: # Ensure server is running
131-
try:
132-
result = await self._client.send_request(
133-
mcp_types.ClientRequest(
134-
mcp_types.CallToolRequest(
135-
method='tools/call',
136-
params=mcp_types.CallToolRequestParams(
137-
name=name,
138-
arguments=tool_args,
139-
_meta=mcp_types.RequestParams.Meta(**metadata) if metadata else None,
140-
),
141-
)
142-
),
143-
mcp_types.CallToolResult,
144-
)
145-
except McpError as e:
146-
raise exceptions.ModelRetry(e.error.message)
147125

148-
content = [self._map_tool_result_part(part) for part in result.content]
126+
async def _call(name: str, args: dict[str, Any], metadata: dict[str, Any] | None = None) -> ToolResult:
127+
async with self: # Ensure server is running
128+
try:
129+
result = await self._client.send_request(
130+
mcp_types.ClientRequest(
131+
mcp_types.CallToolRequest(
132+
method='tools/call',
133+
params=mcp_types.CallToolRequestParams(
134+
name=name,
135+
arguments=args,
136+
_meta=mcp_types.RequestParams.Meta(**metadata) if metadata else None,
137+
),
138+
)
139+
),
140+
mcp_types.CallToolResult,
141+
)
142+
except McpError as e:
143+
raise exceptions.ModelRetry(e.error.message)
144+
145+
content = [self._map_tool_result_part(part) for part in result.content]
146+
147+
if result.isError:
148+
text = '\n'.join(str(part) for part in content)
149+
raise exceptions.ModelRetry(text)
150+
else:
151+
return content[0] if len(content) == 1 else content
149152

150-
if result.isError:
151-
text = '\n'.join(str(part) for part in content)
152-
raise exceptions.ModelRetry(text)
153+
if self.process_tool_call is not None:
154+
return await self.process_tool_call(ctx, _call, name, tool_args)
153155
else:
154-
return content[0] if len(content) == 1 else content
156+
return await _call(name, tool_args)
155157

156158
async def prepare_for_run(self, ctx: RunContext[Any]) -> RunToolset[Any]:
157159
frozen_toolset = RunToolset(self, ctx, await self.list_tool_defs())
158-
if self.process_tool_call:
159-
frozen_toolset = await ProcessedToolset(frozen_toolset, self.process_tool_call).prepare_for_run(ctx)
160160
if self.tool_prefix:
161161
frozen_toolset = await PrefixedToolset(frozen_toolset, self.tool_prefix).prepare_for_run(ctx)
162162
return RunToolset(frozen_toolset, ctx, original=self)
@@ -208,12 +208,7 @@ async def __aenter__(self) -> Self:
208208
self._running_count += 1
209209
return self
210210

211-
async def __aexit__(
212-
self,
213-
exc_type: type[BaseException] | None,
214-
exc_value: BaseException | None,
215-
traceback: TracebackType | None,
216-
) -> bool | None:
211+
async def __aexit__(self, *args: Any) -> bool | None:
217212
self._running_count -= 1
218213
if self._running_count <= 0:
219214
await self._exit_stack.aclose()
@@ -364,7 +359,7 @@ async def main():
364359
timeout: float = 5
365360
"""The timeout in seconds to wait for the client to initialize."""
366361

367-
process_tool_call: ToolProcessFunc[Any] | None = None
362+
process_tool_call: ProcessToolCallback | None = None
368363
"""Hook to customize tool calling and optionally pass extra metadata."""
369364

370365
allow_sampling: bool = True
@@ -465,7 +460,7 @@ class _MCPServerHTTP(MCPServer):
465460
If the connection cannot be established within this time, the operation will fail.
466461
"""
467462

468-
process_tool_call: ToolProcessFunc[Any] | None = None
463+
process_tool_call: ProcessToolCallback | None = None
469464
"""Hook to customize tool calling and optionally pass extra metadata."""
470465

471466
allow_sampling: bool = True
@@ -642,3 +637,23 @@ def _transport_client(self):
642637
| Sequence[str | messages.BinaryContent | dict[str, Any] | list[Any]]
643638
)
644639
"""The result type of an MCP tool call."""
640+
641+
CallToolFunc = Callable[[str, dict[str, Any], dict[str, Any] | None], Awaitable[ToolResult]]
642+
"""A function type that represents a tool call."""
643+
644+
ProcessToolCallback = Callable[
645+
[
646+
RunContext[Any],
647+
CallToolFunc,
648+
str,
649+
dict[str, Any],
650+
],
651+
Awaitable[ToolResult],
652+
]
653+
"""A process tool callback.
654+
655+
It accepts a run context, the original tool call function, a tool name, and arguments.
656+
657+
Allows wrapping an MCP server tool call to customize it, including adding extra request
658+
metadata.
659+
"""

pydantic_ai_slim/pydantic_ai/result.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import warnings
44
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
55
from copy import copy
6-
from dataclasses import dataclass, field, replace
6+
from dataclasses import dataclass, field
77
from datetime import datetime
88
from typing import Generic, cast
99

@@ -106,11 +106,7 @@ async def _validate_response(
106106
raise exceptions.UnexpectedModelBehavior( # pragma: no cover
107107
f'Invalid response, unable to find tool call for {output_tool_name!r}'
108108
)
109-
run_context = replace(self._run_ctx, tool_call_id=tool_call.tool_call_id)
110-
args_dict = self._toolset.validate_tool_args(
111-
run_context, tool_call.tool_name, tool_call.args, allow_partial=allow_partial
112-
)
113-
return await self._toolset.call_tool(run_context, tool_call.tool_name, args_dict)
109+
return await self._toolset.call_tool(tool_call, self._run_ctx, allow_partial=allow_partial)
114110
elif deferred_tool_calls := self._toolset.get_deferred_tool_calls(message.parts):
115111
if not self._output_schema.allows_deferred_tool_calls:
116112
raise exceptions.UserError( # pragma: no cover
@@ -442,11 +438,7 @@ async def validate_structured_output(
442438
raise exceptions.UnexpectedModelBehavior( # pragma: no cover
443439
f'Invalid response, unable to find tool call for {self._output_tool_name!r}'
444440
)
445-
run_context = replace(self._run_ctx, tool_call_id=tool_call.tool_call_id)
446-
args_dict = self._toolset.validate_tool_args(
447-
run_context, tool_call.tool_name, tool_call.args, allow_partial=allow_partial
448-
)
449-
return await self._toolset.call_tool(run_context, tool_call.tool_name, args_dict)
441+
return await self._toolset.call_tool(tool_call, self._run_ctx, allow_partial=allow_partial)
450442
elif deferred_tool_calls := self._toolset.get_deferred_tool_calls(message.parts):
451443
if not self._output_schema.allows_deferred_tool_calls:
452444
raise exceptions.UserError(

pydantic_ai_slim/pydantic_ai/toolsets/__init__.py

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
from __future__ import annotations
22

33
from abc import ABC, abstractmethod
4-
from types import TracebackType
5-
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal
4+
from typing import TYPE_CHECKING, Any, Callable, Generic
65

7-
from pydantic_core import SchemaValidator
86
from typing_extensions import Self
97

108
from .._run_context import AgentDepsT, RunContext
9+
from ..messages import ToolCallPart
1110
from ..tools import ToolDefinition
1211

1312
if TYPE_CHECKING:
@@ -34,9 +33,7 @@ def tool_name_conflict_hint(self) -> str:
3433
async def __aenter__(self) -> Self:
3534
return self
3635

37-
async def __aexit__(
38-
self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None
39-
) -> bool | None:
36+
async def __aexit__(self, *args: Any) -> bool | None:
4037
return None
4138

4239
@abstractmethod
@@ -55,28 +52,12 @@ def tool_names(self) -> list[str]:
5552
def get_tool_def(self, name: str) -> ToolDefinition | None:
5653
return next((tool_def for tool_def in self.tool_defs if tool_def.name == name), None)
5754

58-
@abstractmethod
59-
def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator:
60-
raise NotImplementedError()
61-
62-
def validate_tool_args(
63-
self, ctx: RunContext[AgentDepsT], name: str, args: str | dict[str, Any] | None, allow_partial: bool = False
64-
) -> dict[str, Any]:
65-
pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
66-
validator = self._get_tool_args_validator(ctx, name)
67-
if isinstance(args, str):
68-
return validator.validate_json(args or '{}', allow_partial=pyd_allow_partial)
69-
else:
70-
return validator.validate_python(args or {}, allow_partial=pyd_allow_partial)
71-
7255
@abstractmethod
7356
def _max_retries_for_tool(self, name: str) -> int:
7457
raise NotImplementedError()
7558

7659
@abstractmethod
77-
async def call_tool(
78-
self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any
79-
) -> Any:
60+
async def call_tool(self, call: ToolCallPart, ctx: RunContext[AgentDepsT], allow_partial: bool = False) -> Any:
8061
raise NotImplementedError()
8162

8263
def accept(self, visitor: Callable[[AbstractToolset[AgentDepsT]], Any]) -> Any:

0 commit comments

Comments
 (0)