Skip to content

Commit b9233de

Browse files
stevenhhayk-corpusantDouweM
authored
Add process_tool_call hook to MCP servers to modify tool args, metadata, and return value (#2000)
Co-authored-by: Hayk Martiros <hayk@corpusant.ai> Co-authored-by: Douwe Maan <douwe@pydantic.dev>
1 parent 388ecc2 commit b9233de

File tree

5 files changed

+159
-11
lines changed

5 files changed

+159
-11
lines changed

docs/mcp/client.md

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,48 @@ async def main():
166166

167167
1. See [MCP Run Python](run-python.md) for more information.
168168

169+
## Tool call customisation
170+
171+
The MCP servers provide the ability to set a `process_tool_call` which allows
172+
the customisation of tool call requests and their responses.
173+
174+
A common use case for this is to inject metadata to the requests which the server
175+
call needs.
176+
177+
```python {title="mcp_process_tool_call.py" py="3.10"}
178+
from typing import Any
179+
180+
from pydantic_ai import Agent
181+
from pydantic_ai.mcp import CallToolFunc, MCPServerStdio, ToolResult
182+
from pydantic_ai.models.test import TestModel
183+
from pydantic_ai.tools import RunContext
184+
185+
186+
async def process_tool_call(
187+
ctx: RunContext[int],
188+
call_tool: CallToolFunc,
189+
tool_name: str,
190+
args: dict[str, Any],
191+
) -> ToolResult:
192+
"""A tool call processor that passes along the deps."""
193+
return await call_tool(tool_name, args, metadata={'deps': ctx.deps})
194+
195+
196+
server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], process_tool_call=process_tool_call)
197+
agent = Agent(
198+
model=TestModel(call_tools=['echo_deps']),
199+
deps_type=int,
200+
mcp_servers=[server]
201+
)
202+
203+
204+
async def main():
205+
async with agent.run_mcp_servers():
206+
result = await agent.run('Echo with deps set to 42', deps=42)
207+
print(result.output)
208+
#> {"echo_deps":{"echo":"This is an echo message","deps":42}}
209+
```
210+
169211
## Using Tool Prefixes to Avoid Naming Conflicts
170212

171213
When connecting to multiple MCP servers that might provide tools with the same name, you can use the `tool_prefix` parameter to avoid naming conflicts. This parameter adds a prefix to all tool names from a specific server.

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -762,7 +762,12 @@ async def run_tool(ctx: RunContext[DepsT], **args: Any) -> Any:
762762
# some weird edge case occurs.
763763
if not server.is_running: # pragma: no cover
764764
raise exceptions.UserError(f'MCP server is not running: {server}')
765-
result = await server.call_tool(tool_name, args)
765+
766+
if server.process_tool_call is not None:
767+
result = await server.process_tool_call(ctx, server.call_tool, tool_name, args)
768+
else:
769+
result = await server.call_tool(tool_name, args)
770+
766771
return result
767772

768773
for server in ctx.deps.mcp_servers:

pydantic_ai_slim/pydantic_ai/mcp.py

Lines changed: 62 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import functools
55
import json
66
from abc import ABC, abstractmethod
7-
from collections.abc import AsyncIterator, Sequence
7+
from collections.abc import AsyncIterator, Awaitable, Sequence
88
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
99
from dataclasses import dataclass
1010
from pathlib import Path
@@ -20,18 +20,23 @@
2020
from mcp.types import (
2121
AudioContent,
2222
BlobResourceContents,
23+
CallToolRequest,
24+
CallToolRequestParams,
25+
CallToolResult,
26+
ClientRequest,
2327
Content,
2428
EmbeddedResource,
2529
ImageContent,
2630
LoggingLevel,
31+
RequestParams,
2732
TextContent,
2833
TextResourceContents,
2934
)
3035
from typing_extensions import Self, assert_never, deprecated
3136

3237
from pydantic_ai.exceptions import ModelRetry
3338
from pydantic_ai.messages import BinaryContent
34-
from pydantic_ai.tools import ToolDefinition
39+
from pydantic_ai.tools import RunContext, ToolDefinition
3540

3641
try:
3742
from mcp.client.session import ClientSession
@@ -61,6 +66,9 @@ class MCPServer(ABC):
6166
e.g. if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
6267
"""
6368

69+
process_tool_call: ProcessToolCallback | None = None
70+
"""Hook to customize tool calling and optionally pass extra metadata."""
71+
6472
_client: ClientSession
6573
_read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
6674
_write_stream: MemoryObjectSendStream[SessionMessage]
@@ -114,13 +122,17 @@ async def list_tools(self) -> list[ToolDefinition]:
114122
]
115123

116124
async def call_tool(
117-
self, tool_name: str, arguments: dict[str, Any]
118-
) -> str | BinaryContent | dict[str, Any] | list[Any] | Sequence[str | BinaryContent | dict[str, Any] | list[Any]]:
125+
self,
126+
tool_name: str,
127+
arguments: dict[str, Any],
128+
metadata: dict[str, Any] | None = None,
129+
) -> ToolResult:
119130
"""Call a tool on the server.
120131
121132
Args:
122133
tool_name: The name of the tool to call.
123134
arguments: The arguments to pass to the tool.
135+
metadata: Request-level metadata (optional)
124136
125137
Returns:
126138
The result of the tool call.
@@ -129,7 +141,20 @@ async def call_tool(
129141
ModelRetry: If the tool call fails.
130142
"""
131143
try:
132-
result = await self._client.call_tool(self.get_unprefixed_tool_name(tool_name), arguments)
144+
# meta param is not provided by session yet, so build and can send_request directly.
145+
result = await self._client.send_request(
146+
ClientRequest(
147+
CallToolRequest(
148+
method='tools/call',
149+
params=CallToolRequestParams(
150+
name=self.get_unprefixed_tool_name(tool_name),
151+
arguments=arguments,
152+
_meta=RequestParams.Meta(**metadata) if metadata else None,
153+
),
154+
)
155+
),
156+
CallToolResult,
157+
)
133158
except McpError as e:
134159
raise ModelRetry(e.error.message)
135160

@@ -269,6 +294,9 @@ async def main():
269294
e.g. if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
270295
"""
271296

297+
process_tool_call: ProcessToolCallback | None = None
298+
"""Hook to customize tool calling and optionally pass extra metadata."""
299+
272300
timeout: float = 5
273301
""" The timeout in seconds to wait for the client to initialize."""
274302

@@ -363,6 +391,9 @@ class _MCPServerHTTP(MCPServer):
363391
For example, if `tool_prefix='foo'`, then a tool named `bar` will be registered as `foo_bar`
364392
"""
365393

394+
process_tool_call: ProcessToolCallback | None = None
395+
"""Hook to customize tool calling and optionally pass extra metadata."""
396+
366397
@property
367398
@abstractmethod
368399
def _transport_client(
@@ -521,3 +552,29 @@ async def main():
521552
@property
522553
def _transport_client(self):
523554
return streamablehttp_client # pragma: no cover
555+
556+
557+
ToolResult = (
558+
str | BinaryContent | dict[str, Any] | list[Any] | Sequence[str | BinaryContent | dict[str, Any] | list[Any]]
559+
)
560+
"""The result type of a tool call."""
561+
562+
CallToolFunc = Callable[[str, dict[str, Any], dict[str, Any] | None], Awaitable[ToolResult]]
563+
"""A function type that represents a tool call."""
564+
565+
ProcessToolCallback = Callable[
566+
[
567+
RunContext[Any],
568+
CallToolFunc,
569+
str,
570+
dict[str, Any],
571+
],
572+
Awaitable[ToolResult],
573+
]
574+
"""A process tool callback.
575+
576+
It accepts a run context, the original tool call function, a tool name, and arguments.
577+
578+
Allows wrapping an MCP server tool call to customize it, including adding extra request
579+
metadata.
580+
"""

tests/mcp_server.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from typing import Any
44

55
from mcp.server.fastmcp import Context, FastMCP, Image
6+
from mcp.server.session import ServerSessionT
7+
from mcp.shared.context import LifespanContextT, RequestT
68
from mcp.types import BlobResourceContents, EmbeddedResource, TextResourceContents
79
from pydantic import AnyUrl
810

@@ -118,6 +120,22 @@ async def get_log_level(ctx: Context) -> str: # type: ignore
118120
return log_level
119121

120122

123+
@mcp.tool()
124+
async def echo_deps(ctx: Context[ServerSessionT, LifespanContextT, RequestT]) -> dict[str, Any]:
125+
"""Echo the run context.
126+
127+
Args:
128+
ctx: Context object containing request and session information.
129+
130+
Returns:
131+
Dictionary with an echo message and the deps.
132+
"""
133+
await ctx.info('This is an info message')
134+
135+
deps: Any = getattr(ctx.request_context.meta, 'deps')
136+
return {'echo': 'This is an echo message', 'deps': deps}
137+
138+
121139
@mcp._mcp_server.set_logging_level() # pyright: ignore[reportPrivateUsage]
122140
async def set_logging_level(level: str) -> None:
123141
global log_level

tests/test_mcp.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import re
44
from pathlib import Path
5+
from typing import Any, Final
56
from unittest.mock import AsyncMock, patch
67

78
import pytest
@@ -19,19 +20,22 @@
1920
ToolReturnPart,
2021
UserPromptPart,
2122
)
23+
from pydantic_ai.models.test import TestModel
24+
from pydantic_ai.tools import RunContext
2225
from pydantic_ai.usage import Usage
2326

2427
from .conftest import IsDatetime, IsStr, try_import
2528

2629
with try_import() as imports_successful:
2730
from mcp import ErrorData, McpError
2831

29-
from pydantic_ai.mcp import MCPServerSSE, MCPServerStdio
32+
from pydantic_ai.mcp import CallToolFunc, MCPServerSSE, MCPServerStdio, ToolResult
3033
from pydantic_ai.models.google import GoogleModel
3134
from pydantic_ai.models.openai import OpenAIModel
3235
from pydantic_ai.providers.google import GoogleProvider
3336
from pydantic_ai.providers.openai import OpenAIProvider
3437

38+
TOOL_COUNT: Final[int] = 12
3539

3640
pytestmark = [
3741
pytest.mark.skipif(not imports_successful(), reason='mcp and openai not installed'),
@@ -51,7 +55,7 @@ async def test_stdio_server():
5155
server = MCPServerStdio('python', ['-m', 'tests.mcp_server'])
5256
async with server:
5357
tools = await server.list_tools()
54-
assert len(tools) == 11
58+
assert len(tools) == TOOL_COUNT
5559
assert tools[0].name == 'celsius_to_fahrenheit'
5660
assert tools[0].description.startswith('Convert Celsius to Fahrenheit.')
5761

@@ -72,7 +76,29 @@ async def test_stdio_server_with_cwd():
7276
server = MCPServerStdio('python', ['mcp_server.py'], cwd=test_dir)
7377
async with server:
7478
tools = await server.list_tools()
75-
assert len(tools) == snapshot(11)
79+
assert len(tools) == TOOL_COUNT
80+
81+
82+
async def test_process_tool_call() -> None:
83+
called: bool = False
84+
85+
async def process_tool_call(
86+
ctx: RunContext[int],
87+
call_tool: CallToolFunc,
88+
tool_name: str,
89+
args: dict[str, Any],
90+
) -> ToolResult:
91+
"""A process_tool_call that sets a flag and sends deps as metadata."""
92+
nonlocal called
93+
called = True
94+
return await call_tool(tool_name, args, {'deps': ctx.deps})
95+
96+
server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], process_tool_call=process_tool_call)
97+
async with server:
98+
agent = Agent(deps_type=int, model=TestModel(call_tools=['echo_deps']), mcp_servers=[server])
99+
result = await agent.run('Echo with deps set to 42', deps=42)
100+
assert result.output == snapshot('{"echo_deps":{"echo":"This is an echo message","deps":42}}')
101+
assert called, 'process_tool_call should have been called'
76102

77103

78104
def test_sse_server():
@@ -217,7 +243,7 @@ async def test_log_level_unset():
217243
assert server._get_log_level() is None # pyright: ignore[reportPrivateUsage]
218244
async with server:
219245
tools = await server.list_tools()
220-
assert len(tools) == snapshot(11)
246+
assert len(tools) == TOOL_COUNT
221247
assert tools[10].name == 'get_log_level'
222248

223249
result = await server.call_tool('get_log_level', {})
@@ -945,7 +971,7 @@ async def test_mcp_server_raises_mcp_error(allow_model_requests: None, agent: Ag
945971
async with agent.run_mcp_servers():
946972
with patch.object(
947973
server._client, # pyright: ignore[reportPrivateUsage]
948-
'call_tool',
974+
'send_request',
949975
new=AsyncMock(side_effect=mcp_error),
950976
):
951977
with pytest.raises(ModelRetry, match='Test MCP error conversion'):

0 commit comments

Comments
 (0)