Skip to content

Commit e6575a9

Browse files
committed
Add tests
1 parent 778962c commit e6575a9

File tree

6 files changed

+588
-14
lines changed

6 files changed

+588
-14
lines changed

pydantic_ai_slim/pydantic_ai/_output.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ async def validate(
114114
content=r.message,
115115
tool_name=run_context.tool_name,
116116
)
117-
if run_context.tool_call_id:
117+
if run_context.tool_call_id: # pragma: no cover
118118
m.tool_call_id = run_context.tool_call_id
119119
raise ToolRetryError(m) from r
120120
else:

pydantic_ai_slim/pydantic_ai/mcp.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,11 @@ async def __aexit__(
218218
if self._running_count <= 0:
219219
await self._exit_stack.aclose()
220220

221+
@property
222+
def is_running(self) -> bool:
223+
"""Check if the MCP server is running."""
224+
return bool(self._running_count)
225+
221226
async def _sampling_callback(
222227
self, context: RequestContext[ClientSession, Any], params: mcp_types.CreateMessageRequestParams
223228
) -> mcp_types.CreateMessageResult | mcp_types.ErrorData:

pydantic_ai_slim/pydantic_ai/toolsets/combined.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ class CombinedToolset(AbstractToolset[AgentDepsT]):
2424
toolsets: list[AbstractToolset[AgentDepsT]]
2525
_toolset_per_tool_name: dict[str, AbstractToolset[AgentDepsT]]
2626
_exit_stack: AsyncExitStack | None
27-
_running_count: int
27+
_entered_count: int
2828

2929
def __init__(self, toolsets: Sequence[AbstractToolset[AgentDepsT]]):
3030
self._exit_stack = None
31-
self._running_count = 0
31+
self._entered_count = 0
3232
self.toolsets = list(toolsets)
3333

3434
self._toolset_per_tool_name = {}
@@ -44,18 +44,18 @@ def __init__(self, toolsets: Sequence[AbstractToolset[AgentDepsT]]):
4444
self._toolset_per_tool_name[name] = toolset
4545

4646
async def __aenter__(self) -> Self:
47-
if self._running_count == 0:
47+
if self._entered_count == 0:
4848
self._exit_stack = AsyncExitStack()
4949
for toolset in self.toolsets:
5050
await self._exit_stack.enter_async_context(toolset)
51-
self._running_count += 1
51+
self._entered_count += 1
5252
return self
5353

5454
async def __aexit__(
5555
self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None
5656
) -> bool | None:
57-
self._running_count -= 1
58-
if self._running_count <= 0 and self._exit_stack is not None:
57+
self._entered_count -= 1
58+
if self._entered_count <= 0 and self._exit_stack is not None:
5959
await self._exit_stack.aclose()
6060
self._exit_stack = None
6161
return None

tests/test_agent.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3600,11 +3600,24 @@ async def only_if_plan_presented(
36003600
)
36013601

36023602

3603-
async def test_reentrant_context_manager():
3604-
agent = Agent('test')
3603+
async def test_context_manager():
3604+
try:
3605+
from pydantic_ai.mcp import MCPServerStdio
3606+
except ImportError:
3607+
return
3608+
3609+
server1 = MCPServerStdio('python', ['-m', 'tests.mcp_server'])
3610+
server2 = MCPServerStdio('python', ['-m', 'tests.mcp_server'])
3611+
toolset = CombinedToolset([server1, PrefixedToolset(server2, 'prefix')])
3612+
agent = Agent('test', toolsets=[toolset])
3613+
36053614
async with agent:
3615+
assert server1.is_running
3616+
assert server2.is_running
3617+
36063618
async with agent:
3607-
pass
3619+
assert server1.is_running
3620+
assert server2.is_running
36083621

36093622

36103623
def test_set_mcp_sampling_model():
@@ -3616,7 +3629,7 @@ def test_set_mcp_sampling_model():
36163629
test_model = TestModel()
36173630
server1 = MCPServerStdio('python', ['-m', 'tests.mcp_server'])
36183631
server2 = MCPServerStdio('python', ['-m', 'tests.mcp_server'], sampling_model=test_model)
3619-
toolset = CombinedToolset([server1, PrefixedToolset(server2, 'prefix_')])
3632+
toolset = CombinedToolset([server1, PrefixedToolset(server2, 'prefix')])
36203633
agent = Agent(None, toolsets=[toolset])
36213634

36223635
with pytest.raises(UserError, match='No sampling model provided and no model set on the agent.'):

tests/test_logfire.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -554,7 +554,7 @@ async def add_numbers(x: int, y: int) -> int:
554554
assert result.output == snapshot('{"add_numbers":84}')
555555
except UnexpectedModelBehavior:
556556
if not tool_error:
557-
raise
557+
raise # pragma: no cover
558558

559559
summary = get_logfire_summary()
560560

0 commit comments

Comments
 (0)