Commit 1c2d221: "Address feedback"
Parent: ecf6f75

File tree: 13 files changed (+101, -72 lines)

docs/mcp/client.md
Lines changed: 3 additions & 1 deletion

@@ -137,7 +137,9 @@ _(This example is complete, it can be run "as is" with Python 3.10+ — you'll n
 The other transport offered by MCP is the [stdio transport](https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio) where the server is run as a subprocess and communicates with the client over `stdin` and `stdout`. In this case, you'd use the [`MCPServerStdio`][pydantic_ai.mcp.MCPServerStdio] class.
 
 !!! note
-    When using [`MCPServerStdio`][pydantic_ai.mcp.MCPServerStdio] servers, the [`async with agent`][pydantic_ai.Agent.__aenter__] context manager is responsible for starting and stopping the server.
+    When using [`MCPServerStdio`][pydantic_ai.mcp.MCPServerStdio] servers as `toolsets` on an [`Agent`][pydantic_ai.Agent], you can use the [`async with agent`][pydantic_ai.Agent.__aenter__] context manager to start and stop the server around the context where it'll be used. You can also use [`async with server`][pydantic_ai.mcp.MCPServerStdio.__aenter__] to manage the starting and stopping of a specific server, for example if you'd like to use it with multiple agents.
+
+    If you don't explicitly start the server using one of these context managers, it will automatically be started when it's needed (e.g. to list the available tools or call a specific tool), but it's more efficient to do so around the entire context where you expect it to be used.
 
 ```python {title="mcp_stdio_client.py" py="3.10"}
 from pydantic_ai import Agent
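
For illustration, a minimal sketch of the two patterns described in the note above. The server command, arguments, and model names are placeholders, and the `toolsets` argument is assumed from this change; adapt them to your setup.

```python
import asyncio

from pydantic_ai import Agent
from pydantic_ai.mcp import MCPServerStdio

# Placeholder command/args: substitute a real MCP server you have installed.
server = MCPServerStdio('python', args=['-m', 'my_mcp_server'])


async def run_via_agent():
    # Pattern 1: the agent starts and stops the stdio server around its runs.
    agent = Agent('openai:gpt-4o', toolsets=[server])
    async with agent:
        result = await agent.run('What tools do you have available?')
    print(result.output)


async def run_shared_server():
    # Pattern 2: manage the server directly so two agents share one subprocess.
    agent_a = Agent('openai:gpt-4o', toolsets=[server])
    agent_b = Agent('anthropic:claude-3-5-sonnet-latest', toolsets=[server])
    async with server:
        await agent_a.run('...')
        await agent_b.run('...')


if __name__ == '__main__':
    asyncio.run(run_via_agent())
```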

pydantic_ai_slim/pydantic_ai/_agent_graph.py
Lines changed: 10 additions & 8 deletions

@@ -639,14 +639,16 @@ async def process_function_tools( # noqa: C901
     # Then, we handle function tool calls
     calls_to_run: list[_messages.ToolCallPart] = []
     if final_result and ctx.deps.end_strategy == 'early':
-        for call in tool_calls_by_kind['function']:
-            output_parts.append(
+        output_parts.extend(
+            [
                 _messages.ToolReturnPart(
                     tool_name=call.tool_name,
                     content='Tool not executed - a final result was already processed.',
                     tool_call_id=call.tool_call_id,
                 )
-            )
+                for call in tool_calls_by_kind['function']
+            ]
+        )
     else:
         calls_to_run.extend(tool_calls_by_kind['function'])
 
@@ -776,8 +778,8 @@ async def _call_function_tool(
     def process_content(content: Any) -> Any:
         if isinstance(content, _messages.ToolReturn):
             raise exceptions.UserError(
-                f"{tool_call.tool_name}'s return contains invalid nested ToolReturn objects. "
-                f'ToolReturn should be used directly.'
+                f'The return value of tool {tool_call.tool_name!r} contains invalid nested `ToolReturn` objects. '
+                f'`ToolReturn` should be used directly.'
             )
         elif isinstance(content, _messages.MultiModalContentTypes):
             if isinstance(content, _messages.BinaryContent):
@@ -792,8 +794,8 @@ def process_content(content: Any) -> Any:
                 )
             )
             return f'See file {identifier}'
-        else:
-            return content
+
+        return content
 
     if isinstance(tool_result, _messages.ToolReturn):
         if (
@@ -805,7 +807,7 @@ def process_content(content: Any) -> Any:
             )
         ):
             raise exceptions.UserError(
-                f"{tool_call.tool_name}'s `return_value` contains invalid nested MultiModalContentTypes objects. "
+                f'The `return_value` of tool {tool_call.tool_name!r} contains invalid nested `MultiModalContentTypes` objects. '
                 f'Please use `content` instead.'
             )

pydantic_ai_slim/pydantic_ai/_output.py
Lines changed: 6 additions & 6 deletions

@@ -294,7 +294,7 @@ def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT
 @dataclass(init=False)
 class OutputSchemaWithoutMode(BaseOutputSchema[OutputDataT]):
     processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT]
-    _toolset: OutputToolset[Any] | None = None
+    _toolset: OutputToolset[Any] | None
 
     def __init__(
         self,
@@ -477,7 +477,7 @@ async def process(
 
 @dataclass(init=False)
 class ToolOutputSchema(OutputSchema[OutputDataT]):
-    _toolset: OutputToolset[Any] | None = None
+    _toolset: OutputToolset[Any] | None
 
     def __init__(self, toolset: OutputToolset[Any] | None, allows_deferred_tool_calls: bool):
         super().__init__(allows_deferred_tool_calls)
@@ -834,8 +834,8 @@ class OutputToolset(CallableToolset[AgentDepsT]):
 
     _tool_defs: list[ToolDefinition]
     processors: dict[str, ObjectOutputProcessor[Any]]
-    max_retries: int = field(default=1)
-    output_validators: list[OutputValidator[AgentDepsT, Any]] = field(default_factory=list)
+    max_retries: int
+    output_validators: list[OutputValidator[AgentDepsT, Any]]
 
     @classmethod
     def build(
@@ -910,12 +910,12 @@ def __init__(
         tool_defs: list[ToolDefinition],
         processors: dict[str, ObjectOutputProcessor[Any]],
         max_retries: int = 1,
-        output_validators: list[OutputValidator[AgentDepsT, Any]] = [],
+        output_validators: list[OutputValidator[AgentDepsT, Any]] | None = None,
     ):
         self.processors = processors
         self._tool_defs = tool_defs
         self.max_retries = max_retries
-        self.output_validators = output_validators
+        self.output_validators = output_validators or []
 
     async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]:
         return RunToolset(self, ctx)
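
The `output_validators or []` change above removes a mutable default argument. As a standalone sketch (not pydantic-ai code, hypothetical function names) of why that matters: a default list is created once at definition time and shared by every call that doesn't pass its own, while a `None` sentinel gives each call a fresh list.

```python
def add_validator_shared(v, validators=[]):  # mutable default: one list shared across calls
    validators.append(v)
    return validators


def add_validator_isolated(v, validators=None):  # the pattern used in the diff above
    validators = validators or []
    validators.append(v)
    return validators


print(add_validator_shared('a'), add_validator_shared('b'))      # ['a', 'b'] ['a', 'b']
print(add_validator_isolated('a'), add_validator_isolated('b'))  # ['a'] ['b']
```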

pydantic_ai_slim/pydantic_ai/agent.py
Lines changed: 14 additions & 10 deletions

@@ -4,6 +4,7 @@
 import inspect
 import json
 import warnings
+from asyncio import Lock
 from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
 from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager
 from contextvars import ContextVar
@@ -166,6 +167,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
     _prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False)
     _max_result_retries: int = dataclasses.field(repr=False)
 
+    _enter_lock: Lock = dataclasses.field(repr=False)
     _entered_count: int = dataclasses.field(repr=False)
     _exit_stack: AsyncExitStack | None = dataclasses.field(repr=False)
 
@@ -433,8 +435,9 @@
         self._override_deps: ContextVar[_utils.Option[AgentDepsT]] = ContextVar('_override_deps', default=None)
         self._override_model: ContextVar[_utils.Option[models.Model]] = ContextVar('_override_model', default=None)
 
-        self._exit_stack = None
+        self._enter_lock = Lock()
         self._entered_count = 0
+        self._exit_stack = None
 
     @staticmethod
     def instrument_all(instrument: InstrumentationSettings | bool = True) -> None:
@@ -1795,18 +1798,19 @@ def is_end_node(
         return isinstance(node, End)
 
     async def __aenter__(self) -> Self:
-        """Enter the agent. This will start all [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] registered with the agent so they can be used in a run."""
-        if self._entered_count == 0:
-            self._exit_stack = AsyncExitStack()
-            await self._exit_stack.enter_async_context(self._toolset)
-        self._entered_count += 1
+        async with self._enter_lock:
+            if self._entered_count == 0:
+                self._exit_stack = AsyncExitStack()
+                await self._exit_stack.enter_async_context(self._toolset)
+            self._entered_count += 1
         return self
 
     async def __aexit__(self, *args: Any) -> bool | None:
-        self._entered_count -= 1
-        if self._entered_count <= 0 and self._exit_stack is not None:
-            await self._exit_stack.aclose()
-            self._exit_stack = None
+        async with self._enter_lock:
+            self._entered_count -= 1
+            if self._entered_count == 0 and self._exit_stack is not None:
+                await self._exit_stack.aclose()
+                self._exit_stack = None
 
     def set_mcp_sampling_model(self, model: models.Model | models.KnownModelName | str | None = None) -> None:
         """Set the sampling model on all MCP servers registered with the agent.

pydantic_ai_slim/pydantic_ai/mcp.py
Lines changed: 35 additions & 23 deletions

@@ -3,9 +3,10 @@
 import base64
 import functools
 from abc import ABC, abstractmethod
+from asyncio import Lock
 from collections.abc import AsyncIterator, Awaitable, Sequence
 from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Any, Callable
 
@@ -60,12 +61,18 @@ class MCPServer(CallableToolset[Any], ABC):
     sampling_model: models.Model | None = None
     # } end of "abstract fields"
 
-    _running_count: int = 0
+    _enter_lock: Lock = field(compare=False)
+    _running_count: int
+    _exit_stack: AsyncExitStack | None
 
     _client: ClientSession
     _read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
     _write_stream: MemoryObjectSendStream[SessionMessage]
-    _exit_stack: AsyncExitStack
+
+    def __post_init__(self):
+        self._enter_lock = Lock()
+        self._running_count = 0
+        self._exit_stack = None
 
     @abstractmethod
     @asynccontextmanager
@@ -86,7 +93,7 @@ def name(self) -> str:
         return repr(self)
 
     @property
-    def tool_name_conflict_hint(self) -> str:
+    def _tool_name_conflict_hint(self) -> str:
         return 'Consider setting `tool_prefix` to avoid name conflicts.'
 
     async def list_tools(self) -> list[mcp_types.Tool]:
@@ -188,30 +195,35 @@ def _max_retries_for_tool(self, name: str) -> int:
         return self.max_retries
 
     async def __aenter__(self) -> Self:
-        if self._running_count == 0:
-            self._exit_stack = AsyncExitStack()
-
-            self._read_stream, self._write_stream = await self._exit_stack.enter_async_context(self.client_streams())
-            client = ClientSession(
-                read_stream=self._read_stream,
-                write_stream=self._write_stream,
-                sampling_callback=self._sampling_callback if self.allow_sampling else None,
-                logging_callback=self.log_handler,
-            )
-            self._client = await self._exit_stack.enter_async_context(client)
+        async with self._enter_lock:
+            if self._running_count == 0:
+                self._exit_stack = AsyncExitStack()
+
+                self._read_stream, self._write_stream = await self._exit_stack.enter_async_context(
+                    self.client_streams()
+                )
+                client = ClientSession(
+                    read_stream=self._read_stream,
+                    write_stream=self._write_stream,
+                    sampling_callback=self._sampling_callback if self.allow_sampling else None,
+                    logging_callback=self.log_handler,
+                )
+                self._client = await self._exit_stack.enter_async_context(client)
 
-            with anyio.fail_after(self.timeout):
-                await self._client.initialize()
+                with anyio.fail_after(self.timeout):
+                    await self._client.initialize()
 
-            if log_level := self.log_level:
-                await self._client.set_logging_level(log_level)
-        self._running_count += 1
+                if log_level := self.log_level:
+                    await self._client.set_logging_level(log_level)
+            self._running_count += 1
         return self
 
     async def __aexit__(self, *args: Any) -> bool | None:
-        self._running_count -= 1
-        if self._running_count <= 0:
-            await self._exit_stack.aclose()
+        async with self._enter_lock:
+            self._running_count -= 1
+            if self._running_count <= 0 and self._exit_stack is not None:
+                await self._exit_stack.aclose()
+                self._exit_stack = None
 
     @property
     def is_running(self) -> bool:

pydantic_ai_slim/pydantic_ai/tools.py
Lines changed: 1 addition & 0 deletions

@@ -364,6 +364,7 @@ class ToolDefinition:
 
     kind: ToolKind = field(default='function')
     """The kind of tool:
+
     - `'function'`: a tool that can be executed by Pydantic AI and has its result returned to the model
     - `'output'`: a tool that passes through an output value that ends the run
     - `'deferred'`: a tool that will be executed not by Pydantic AI, but by the upstream service that called the agent, such as a web application that supports frontend-defined tools provided to Pydantic AI via e.g. [AG-UI](https://docs.ag-ui.com/concepts/tools#frontend-defined-tools).

pydantic_ai_slim/pydantic_ai/toolsets/__init__.py
Lines changed: 1 addition & 1 deletion

@@ -27,7 +27,7 @@ def name(self) -> str:
         return self.__class__.__name__.replace('Toolset', ' toolset')
 
     @property
-    def tool_name_conflict_hint(self) -> str:
+    def _tool_name_conflict_hint(self) -> str:
         return 'Consider renaming the tool or wrapping the toolset in a `PrefixedToolset` to avoid name conflicts.'
 
     async def __aenter__(self) -> Self:

pydantic_ai_slim/pydantic_ai/toolsets/_callable.py
Lines changed: 2 additions & 2 deletions

@@ -2,7 +2,7 @@
 
 from abc import ABC, abstractmethod
 from dataclasses import replace
-from typing import TYPE_CHECKING, Any, Literal
+from typing import TYPE_CHECKING, Any
 
 from pydantic_core import SchemaValidator
 
@@ -28,7 +28,7 @@ async def _call_tool(self, ctx: RunContext[AgentDepsT], name: str, tool_args: di
     async def call_tool(self, call: ToolCallPart, ctx: RunContext[AgentDepsT], allow_partial: bool = False) -> Any:
         ctx = replace(ctx, tool_name=call.tool_name, tool_call_id=call.tool_call_id)
 
-        pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
+        pyd_allow_partial = 'trailing-strings' if allow_partial else 'off'
         validator = self._get_tool_args_validator(ctx, call.tool_name)
         if isinstance(call.args, str):
             args_dict = validator.validate_json(call.args or '{}', allow_partial=pyd_allow_partial)

pydantic_ai_slim/pydantic_ai/toolsets/combined.py
Lines changed: 20 additions & 14 deletions

@@ -3,7 +3,7 @@
 import asyncio
 from collections.abc import Sequence
 from contextlib import AsyncExitStack
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Any, Callable
 
 from typing_extensions import Self
@@ -22,12 +22,16 @@ class CombinedToolset(AbstractToolset[AgentDepsT]):
 
     toolsets: list[AbstractToolset[AgentDepsT]]
     _toolset_per_tool_name: dict[str, AbstractToolset[AgentDepsT]]
-    _exit_stack: AsyncExitStack | None
+
+    _enter_lock: asyncio.Lock = field(compare=False)
     _entered_count: int
+    _exit_stack: AsyncExitStack | None
 
     def __init__(self, toolsets: Sequence[AbstractToolset[AgentDepsT]]):
-        self._exit_stack = None
+        self._enter_lock = asyncio.Lock()
         self._entered_count = 0
+        self._exit_stack = None
+
         self.toolsets = list(toolsets)
 
         self._toolset_per_tool_name = {}
@@ -36,28 +40,30 @@ def __init__(self, toolsets: Sequence[AbstractToolset[AgentDepsT]]):
            try:
                existing_toolset = self._toolset_per_tool_name[name]
                raise UserError(
-                    f'{toolset.name} defines a tool whose name conflicts with existing tool from {existing_toolset.name}: {name!r}. {toolset.tool_name_conflict_hint}'
+                    f'{toolset.name} defines a tool whose name conflicts with existing tool from {existing_toolset.name}: {name!r}. {toolset._tool_name_conflict_hint}'
                )
            except KeyError:
                pass
            self._toolset_per_tool_name[name] = toolset
 
    async def __aenter__(self) -> Self:
-        if self._entered_count == 0:
-            self._exit_stack = AsyncExitStack()
-            for toolset in self.toolsets:
-                await self._exit_stack.enter_async_context(toolset)
-        self._entered_count += 1
+        async with self._enter_lock:
+            if self._entered_count == 0:
+                self._exit_stack = AsyncExitStack()
+                for toolset in self.toolsets:
+                    await self._exit_stack.enter_async_context(toolset)
+            self._entered_count += 1
        return self
 
    async def __aexit__(self, *args: Any) -> bool | None:
-        self._entered_count -= 1
-        if self._entered_count <= 0 and self._exit_stack is not None:
-            await self._exit_stack.aclose()
-            self._exit_stack = None
+        async with self._enter_lock:
+            self._entered_count -= 1
+            if self._entered_count == 0 and self._exit_stack is not None:
+                await self._exit_stack.aclose()
+                self._exit_stack = None
 
    async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[AgentDepsT]:
-        toolsets_for_run = await asyncio.gather(*[toolset.prepare_for_run(ctx) for toolset in self.toolsets])
+        toolsets_for_run = await asyncio.gather(*(toolset.prepare_for_run(ctx) for toolset in self.toolsets))
        combined_for_run = CombinedToolset(toolsets_for_run)
        return RunToolset(combined_for_run, ctx)
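
`CombinedToolset.__aenter__` above enters every child toolset on a single `AsyncExitStack`, which tears them down in reverse order on exit (and unwinds the ones already entered if a later one fails to start). A standalone sketch of that behaviour with hypothetical names:

```python
import asyncio
from contextlib import AsyncExitStack, asynccontextmanager


@asynccontextmanager
async def fake_toolset(name: str):
    print(f'start {name}')
    try:
        yield name
    finally:
        print(f'stop {name}')


async def main():
    async with AsyncExitStack() as stack:
        for name in ('a', 'b', 'c'):
            await stack.enter_async_context(fake_toolset(name))
        print('all toolsets running')
    # Prints: start a, start b, start c, all toolsets running, stop c, stop b, stop a


asyncio.run(main())
```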

pydantic_ai_slim/pydantic_ai/toolsets/deferred.py
Lines changed: 4 additions & 1 deletion

@@ -11,7 +11,10 @@
 
 
 class DeferredToolset(AbstractToolset[AgentDepsT]):
-    """A toolset that holds deferred tool."""
+    """A toolset that holds deferred tools.
+
+    See [`ToolDefinition.kind`][pydantic_ai.tools.ToolDefinition.kind] for more information about deferred tools.
+    """
 
     _tool_defs: list[ToolDefinition]