Commit 131a325

Improve test coverage
1 parent dea8050 commit 131a325

10 files changed, +204 −91 lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 3 additions & 13 deletions

@@ -538,16 +538,17 @@ async def _handle_text_response(
         text = '\n\n'.join(texts)
         try:
+            run_context = build_run_context(ctx)
             if isinstance(output_schema, _output.TextOutputSchema):
-                run_context = build_run_context(ctx)
                 result_data = await output_schema.process(text, run_context)
             else:
                 m = _messages.RetryPromptPart(
                     content='Plain text responses are not permitted, please include your response in a tool call',
                 )
                 raise ToolRetryError(m)

-            result_data = await _validate_output(result_data, ctx, None)
+            for validator in ctx.deps.output_validators:
+                result_data = await validator.validate(result_data, run_context)
         except ToolRetryError as e:
             ctx.state.increment_retries(ctx.deps.max_result_retries, e)
             return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry]))

@@ -834,17 +835,6 @@ async def _call_tool(
     return await toolset.call_tool(run_context, tool_call.tool_name, args_dict)


-async def _validate_output(
-    result_data: T,
-    ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
-    tool_call: _messages.ToolCallPart | None,
-) -> T:
-    for validator in ctx.deps.output_validators:
-        run_context = build_run_context(ctx)
-        result_data = await validator.validate(result_data, tool_call, run_context)
-    return result_data
-
-
 @dataclasses.dataclass
 class _RunMessages:
     messages: list[_messages.ModelMessage]
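
The first hunk inlines the removed `_validate_output` helper: the `RunContext` is built once before the type check, and each validator now receives it directly instead of a separate `tool_call` argument. A minimal sketch of what this loop runs from the public API side — the agent and validator names here are hypothetical, not part of this commit:

    from pydantic_ai import Agent, ModelRetry, RunContext

    agent = Agent('test')

    @agent.output_validator
    async def check_polite(ctx: RunContext[None], output: str) -> str:
        # Raising ModelRetry here flows through the loop above: it becomes a
        # RetryPromptPart and the model is asked to try again.
        if 'please' not in output:
            raise ModelRetry('Rephrase the answer politely.')
        return output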

pydantic_ai_slim/pydantic_ai/_output.py

Lines changed: 12 additions & 17 deletions

@@ -4,7 +4,7 @@
 import json
 from abc import ABC, abstractmethod
 from collections.abc import Awaitable, Sequence
-from dataclasses import dataclass, field, replace
+from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast, overload

 from pydantic import TypeAdapter, ValidationError

@@ -83,28 +83,21 @@ def __post_init__(self):
     async def validate(
         self,
         result: T,
-        tool_call: _messages.ToolCallPart | None,
         run_context: RunContext[AgentDepsT],
         wrap_validation_errors: bool = True,
     ) -> T:
         """Validate a result but calling the function.

         Args:
             result: The result data after Pydantic validation the message content.
-            tool_call: The original tool call message, `None` if there was no tool call.
             run_context: The current run context.
             wrap_validation_errors: If true, wrap the validation errors in a retry message.

         Returns:
             Result of either the validated result data (ok) or a retry message (Err).
         """
         if self._takes_ctx:
-            ctx = (
-                replace(run_context, tool_name=tool_call.tool_name, tool_call_id=tool_call.tool_call_id)
-                if tool_call
-                else run_context
-            )
-            args = ctx, result
+            args = run_context, result
         else:
             args = (result,)

@@ -117,10 +110,12 @@ async def validate(
             result_data = await _utils.run_in_executor(function, *args)
         except ModelRetry as r:
             if wrap_validation_errors:
-                m = _messages.RetryPromptPart(content=r.message)
-                if tool_call is not None:
-                    m.tool_name = tool_call.tool_name
-                    m.tool_call_id = tool_call.tool_call_id
+                m = _messages.RetryPromptPart(
+                    content=r.message,
+                    tool_name=run_context.tool_name,
+                )
+                if run_context.tool_call_id:
+                    m.tool_call_id = run_context.tool_call_id
                 raise ToolRetryError(m) from r
             else:
                 raise r

@@ -190,7 +185,7 @@ def build(  # noqa: C901

         if output := next((output for output in outputs if isinstance(output, NativeOutput)), None):
             if len(outputs) > 1:
-                raise UserError('`NativeOutput` must be the only output type.')
+                raise UserError('`NativeOutput` must be the only output type.')  # pragma: no cover

             return NativeOutputSchema(
                 processor=cls._build_processor(

@@ -203,7 +198,7 @@ def build(  # noqa: C901
             )
         elif output := next((output for output in outputs if isinstance(output, PromptedOutput)), None):
             if len(outputs) > 1:
-                raise UserError('`PromptedOutput` must be the only output type.')
+                raise UserError('`PromptedOutput` must be the only output type.')  # pragma: no cover

             return PromptedOutputSchema(
                 processor=cls._build_processor(

@@ -940,7 +935,7 @@ async def call_tool(
     ) -> Any:
         output = await self.processors[name].call(tool_args, ctx)
         for validator in self.output_validators:
-            output = await validator.validate(output, None, ctx, wrap_validation_errors=False)
+            output = await validator.validate(output, ctx, wrap_validation_errors=False)
         return output

@@ -965,7 +960,7 @@ def _flatten_output_spec(output_spec: OutputSpec[T]) -> Sequence[_OutputSpecItem[T]]:
     for output in outputs:
         if isinstance(output, Sequence):
             outputs_flat.extend(_flatten_output_spec(cast(OutputSpec[T], output)))
-        if union_types := _utils.get_union_args(output):
+        elif union_types := _utils.get_union_args(output):
             outputs_flat.extend(union_types)
         else:
             outputs_flat.append(cast(_OutputSpecItem[T], output))
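
The final hunk's `if` → `elif` matters: with the old `if`, a nested sequence was flattened by the first branch and then, failing the union check, appended again whole by the `else` branch. A standalone sketch of the corrected control flow — a hypothetical re-implementation for illustration, not the library code:

    from typing import Union, get_args, get_origin

    def flatten(spec: object) -> list[object]:
        flat: list[object] = []
        for output in spec if isinstance(spec, (list, tuple)) else [spec]:
            if isinstance(output, (list, tuple)):
                flat.extend(flatten(output))  # recurse into nested sequences
            elif get_origin(output) is Union:  # elif: the sequence itself is never re-appended
                flat.extend(get_args(output))
            else:
                flat.append(output)
        return flat

    print(flatten([Union[int, str], bool]))  # [int, str, bool]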

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 3 additions & 4 deletions

@@ -370,11 +370,10 @@ def __init__(
             output_retries = result_retries

         if mcp_servers := _deprecated_kwargs.pop('mcp_servers', None):
+            if toolsets is not None:  # pragma: no cover
+                raise TypeError('`mcp_servers` and `toolsets` cannot be set at the same time.')
             warnings.warn('`mcp_servers` is deprecated, use `toolsets` instead', DeprecationWarning)
-            if toolsets is None:
-                toolsets = mcp_servers
-            else:
-                toolsets = [*toolsets, *mcp_servers]
+            toolsets = mcp_servers

         _utils.validate_empty_kwargs(_deprecated_kwargs)
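
Behaviour after this hunk: `mcp_servers` still works but now maps directly to `toolsets`, and passing both raises. A sketch of the two spellings, using the server module path from the tests below:

    from pydantic_ai import Agent
    from pydantic_ai.mcp import MCPServerStdio

    server = MCPServerStdio('python', ['-m', 'tests.mcp_server'])

    # Deprecated spelling: still accepted, but emits a DeprecationWarning.
    agent_old = Agent('test', mcp_servers=[server])  # type: ignore[call-arg]
    # Preferred spelling after this commit:
    agent_new = Agent('test', toolsets=[server])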

pydantic_ai_slim/pydantic_ai/mcp.py

Lines changed: 7 additions & 6 deletions

@@ -59,6 +59,7 @@ class MCPServer(AbstractToolset[Any], ABC):
     process_tool_call: ToolProcessFunc[Any] | None = None
     allow_sampling: bool = True
     max_retries: int = 1
+    sampling_model: models.Model | None = None
     # } end of "abstract fields"

     _running_count: int = 0

@@ -67,7 +68,6 @@ class MCPServer(AbstractToolset[Any], ABC):
     _read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
     _write_stream: MemoryObjectSendStream[SessionMessage]
     _exit_stack: AsyncExitStack
-    sampling_model: models.Model | None = None

     @abstractmethod
     @asynccontextmanager

@@ -83,11 +83,6 @@ async def client_streams(
         raise NotImplementedError('MCP Server subclasses must implement this method.')
         yield

-    @property
-    def is_running(self) -> bool:
-        """Check if the MCP server is running."""
-        return bool(self._running_count)
-
     @property
     def name(self) -> str:
         return repr(self)

@@ -373,6 +368,9 @@ async def main():
     max_retries: int = 1
     """The maximum number of times to retry a tool call."""

+    sampling_model: models.Model | None = None
+    """The model to use for sampling."""
+
     @asynccontextmanager
     async def client_streams(
         self,

@@ -471,6 +469,9 @@ class _MCPServerHTTP(MCPServer):
     max_retries: int = 1
     """The maximum number of times to retry a tool call."""

+    sampling_model: models.Model | None = None
+    """The model to use for sampling."""
+
     @property
     @abstractmethod
     def _transport_client(
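
With `sampling_model` promoted to a documented field on the concrete server classes, it can be set per server at construction time. A minimal sketch, mirroring the new test below:

    from pydantic_ai.mcp import MCPServerStdio
    from pydantic_ai.models.test import TestModel

    # Each server can now carry its own sampling model from the start:
    server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], sampling_model=TestModel())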

pydantic_ai_slim/pydantic_ai/result.py

Lines changed: 6 additions & 6 deletions

@@ -113,8 +113,8 @@ async def _validate_response(
             return await self._toolset.call_tool(run_context, tool_call.tool_name, args_dict)
         elif deferred_tool_calls := self._toolset.get_deferred_tool_calls(message.parts):
             if not self._output_schema.allows_deferred_tool_calls:
-                raise exceptions.UserError(
-                    'There are deferred tool calls but DeferredToolCalls is not among output types.'
+                raise exceptions.UserError(  # pragma: no cover
+                    'A deferred tool call was present, but `DeferredToolCalls` is not among output types. To resolve this, add `DeferredToolCalls` to the list of output types for this agent.'
                 )
             return cast(OutputDataT, deferred_tool_calls)
         elif isinstance(self._output_schema, TextOutputSchema):

@@ -124,7 +124,7 @@ async def _validate_response(
                 text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False
             )
             for validator in self._output_validators:
-                result_data = await validator.validate(result_data, None, self._run_ctx)
+                result_data = await validator.validate(result_data, self._run_ctx)
             return result_data
         else:
             raise exceptions.UnexpectedModelBehavior(  # pragma: no cover

@@ -450,7 +450,7 @@ async def validate_structured_output(
         elif deferred_tool_calls := self._toolset.get_deferred_tool_calls(message.parts):
             if not self._output_schema.allows_deferred_tool_calls:
                 raise exceptions.UserError(
-                    'There are deferred tool calls but DeferredToolCalls is not among output types.'
+                    'A deferred tool call was present, but `DeferredToolCalls` is not among output types. To resolve this, add `DeferredToolCalls` to the list of output types for this agent.'
                 )
             return cast(OutputDataT, deferred_tool_calls)
         elif isinstance(self._output_schema, TextOutputSchema):

@@ -460,7 +460,7 @@ async def validate_structured_output(
                 text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False
             )
             for validator in self._output_validators:
-                result_data = await validator.validate(result_data, None, self._run_ctx)  # pragma: no cover
+                result_data = await validator.validate(result_data, self._run_ctx)  # pragma: no cover
             return result_data
         else:
             raise exceptions.UnexpectedModelBehavior(  # pragma: no cover

@@ -469,7 +469,7 @@ async def validate_structured_output(
     async def _validate_text_output(self, text: str) -> str:
         for validator in self._output_validators:
-            text = await validator.validate(text, None, self._run_ctx)  # pragma: no cover
+            text = await validator.validate(text, self._run_ctx)  # pragma: no cover
         return text

     async def _marked_completed(self, message: _messages.ModelResponse) -> None:
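
The reworded error now tells the user exactly what to change. Sketched below, assuming `DeferredToolCalls` is importable from `pydantic_ai.output` (the import path is not shown in this diff):

    from pydantic_ai import Agent
    from pydantic_ai.output import DeferredToolCalls  # import path assumed

    # Allow deferred tool calls alongside plain text output, as the message suggests:
    agent = Agent('test', output_type=[str, DeferredToolCalls])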

pydantic_ai_slim/pydantic_ai/toolsets/_run.py

Lines changed: 16 additions & 15 deletions

@@ -122,21 +122,22 @@ def _with_retry(self, name: str, ctx: RunContext[AgentDepsT]) -> Iterator[RunContext[AgentDepsT]]:
             if current_retry == max_retries:
                 raise UnexpectedModelBehavior(f'Tool {name!r} exceeded max retries count of {max_retries}') from e
             else:
-                if ctx.tool_call_id:
-                    if isinstance(e, ValidationError):
-                        m = _messages.RetryPromptPart(
-                            tool_name=name,
-                            content=e.errors(include_url=False, include_context=False),
-                            tool_call_id=ctx.tool_call_id,
-                        )
-                        e = ToolRetryError(m)
-                    elif isinstance(e, ModelRetry):
-                        m = _messages.RetryPromptPart(
-                            tool_name=name,
-                            content=e.message,
-                            tool_call_id=ctx.tool_call_id,
-                        )
-                        e = ToolRetryError(m)
+                if isinstance(e, ValidationError):
+                    m = _messages.RetryPromptPart(
+                        tool_name=name,
+                        content=e.errors(include_url=False, include_context=False),
+                    )
+                    if ctx.tool_call_id:  # pragma: no branch
+                        m.tool_call_id = ctx.tool_call_id
+                    e = ToolRetryError(m)
+                elif isinstance(e, ModelRetry):
+                    m = _messages.RetryPromptPart(
+                        tool_name=name,
+                        content=e.message,
+                    )
+                    if ctx.tool_call_id:  # pragma: no branch
+                        m.tool_call_id = ctx.tool_call_id
+                    e = ToolRetryError(m)

                 self._retries[name] = current_retry + 1
                 raise e
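
After this restructuring, a `RetryPromptPart` is built even when `ctx.tool_call_id` is absent, with the id attached only when present. A sketch of a tool that exercises this retry path — the tool name and logic are hypothetical:

    from pydantic_ai import Agent, ModelRetry

    agent = Agent('test', retries=2)

    @agent.tool_plain
    def fetch(url: str) -> str:
        if not url.startswith('https://'):
            # Surfaces as a RetryPromptPart with tool_name='fetch'; the
            # tool_call_id is copied from the run context when one exists.
            raise ModelRetry('Use an https URL.')
        return 'ok'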

tests/test_agent.py

Lines changed: 47 additions & 0 deletions

@@ -24,6 +24,7 @@
     ToolOutputSchema,
 )
 from pydantic_ai.agent import AgentRunResult
+from pydantic_ai.mcp import MCPServerStdio
 from pydantic_ai.messages import (
     BinaryContent,
     ImageUrl,

@@ -46,7 +47,9 @@
 from pydantic_ai.profiles import ModelProfile
 from pydantic_ai.result import Usage
 from pydantic_ai.tools import ToolDefinition
+from pydantic_ai.toolsets.combined import CombinedToolset
 from pydantic_ai.toolsets.function import FunctionToolset
+from pydantic_ai.toolsets.prefixed import PrefixedToolset

 from .conftest import IsDatetime, IsNow, IsStr, TestEnv

@@ -3439,6 +3442,14 @@ def test_deprecated_kwargs_still_work():
         assert issubclass(w[0].category, DeprecationWarning)
         assert '`result_retries` is deprecated' in str(w[0].message)

+    with warnings.catch_warnings(record=True) as w:
+        warnings.simplefilter('always')
+
+        agent = Agent('test', mcp_servers=[MCPServerStdio('python', ['-m', 'tests.mcp_server'])])  # type: ignore[call-arg]
+        assert len(w) == 1
+        assert issubclass(w[0].category, DeprecationWarning)
+        assert '`mcp_servers` is deprecated' in str(w[0].message)
+

 def test_deprecated_kwargs_mixed_valid_invalid():
     """Test that mix of valid deprecated and invalid kwargs raises error for invalid ones."""

@@ -3583,3 +3594,39 @@ async def only_if_plan_presented(
         ),
     ]
 )
+
+
+async def test_reentrant_context_manager():
+    agent = Agent('test')
+    async with agent:
+        async with agent:
+            pass
+
+
+def test_set_mcp_sampling_model():
+    test_model = TestModel()
+    server1 = MCPServerStdio('python', ['-m', 'tests.mcp_server'])
+    server2 = MCPServerStdio('python', ['-m', 'tests.mcp_server'], sampling_model=test_model)
+    toolset = CombinedToolset([server1, PrefixedToolset(server2, 'prefix_')])
+    agent = Agent(None, toolsets=[toolset])
+
+    with pytest.raises(UserError, match='No sampling model provided and no model set on the agent.'):
+        agent.set_mcp_sampling_model()
+    assert server1.sampling_model is None
+    assert server2.sampling_model is test_model
+
+    agent.model = test_model
+    agent.set_mcp_sampling_model()
+    assert server1.sampling_model is test_model
+    assert server2.sampling_model is test_model
+
+    function_model = FunctionModel(lambda messages, info: ModelResponse(parts=[TextPart('Hello')]))
+    with agent.override(model=function_model):
+        agent.set_mcp_sampling_model()
+        assert server1.sampling_model is function_model
+        assert server2.sampling_model is function_model
+
+    function_model2 = FunctionModel(lambda messages, info: ModelResponse(parts=[TextPart('Goodbye')]))
+    agent.set_mcp_sampling_model(function_model2)
+    assert server1.sampling_model is function_model2
+    assert server2.sampling_model is function_model2

tests/test_examples.py

Lines changed: 0 additions & 2 deletions

@@ -262,8 +262,6 @@ def rich_prompt_ask(prompt: str, *_args: Any, **_kwargs: Any) -> str:


 class MockMCPServer(AbstractToolset[Any]):
-    is_running = True
-
     async def __aenter__(self) -> MockMCPServer:
         return self