Skip to content

Commit ebf6f40

Browse files
committed
Remove Agent sampling_model field (and method argument) in favor of Agent.set_mcp_sampling_model
1 parent 7e3331b commit ebf6f40

File tree

11 files changed

+68
-124
lines changed

11 files changed

+68
-124
lines changed

docs/mcp/client.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,7 @@ agent = Agent('openai:gpt-4o', toolsets=[server])
364364

365365
async def main():
366366
async with agent:
367+
agent.set_mcp_sampling_model()
367368
result = await agent.run('Create an image of a robot in a punk style.')
368369
print(result.output)
369370
#> Image file written to robot_punk.svg.

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,6 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]):
115115
history_processors: Sequence[HistoryProcessor[DepsT]]
116116

117117
toolset: RunToolset[DepsT]
118-
sampling_model: models.Model
119118

120119
tracer: Tracer
121120
instrumentation_settings: InstrumentationSettings | None = None
@@ -562,7 +561,6 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT
562561
deps=ctx.deps.user_deps,
563562
model=ctx.deps.model,
564563
usage=ctx.state.usage,
565-
sampling_model=ctx.deps.sampling_model,
566564
prompt=ctx.deps.prompt,
567565
messages=ctx.state.message_history,
568566
run_step=ctx.state.run_step,

pydantic_ai_slim/pydantic_ai/_run_context.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,6 @@ class RunContext(Generic[AgentDepsT]):
2727
"""The model used in this run."""
2828
usage: Usage
2929
"""LLM usage associated with the run."""
30-
sampling_model: Model
31-
"""The model used for MCP sampling."""
3230
prompt: str | Sequence[_messages.UserContent] | None = None
3331
"""The original user prompt passed to the run."""
3432
messages: list[_messages.ModelMessage] = field(default_factory=list)

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 31 additions & 65 deletions
Large diffs are not rendered by default.

pydantic_ai_slim/pydantic_ai/mcp.py

Lines changed: 22 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,8 @@
33
import base64
44
import functools
55
from abc import ABC, abstractmethod
6-
from collections.abc import AsyncIterator, Iterator, Sequence
7-
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager, nullcontext
8-
from contextvars import ContextVar
6+
from collections.abc import AsyncIterator, Sequence
7+
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
98
from dataclasses import dataclass
109
from pathlib import Path
1110
from types import TracebackType
@@ -70,22 +69,6 @@ class MCPServer(AbstractToolset[Any], ABC):
7069
_exit_stack: AsyncExitStack
7170
sampling_model: models.Model | None = None
7271

73-
def __post_init__(self):
74-
self._override_sampling_model: ContextVar[models.Model | None] = ContextVar(
75-
'_override_sampling_model', default=None
76-
)
77-
78-
@contextmanager
79-
def override_sampling_model(
80-
self,
81-
model: models.Model,
82-
) -> Iterator[None]:
83-
token = self._override_sampling_model.set(model)
84-
try:
85-
yield
86-
finally:
87-
self._override_sampling_model.reset(token)
88-
8972
@abstractmethod
9073
@asynccontextmanager
9174
async def client_streams(
@@ -149,28 +132,23 @@ async def call_tool(
149132
Raises:
150133
ModelRetry: If the tool call fails.
151134
"""
152-
sampling_contextmanager = (
153-
nullcontext() if self._get_sampling_model() else self.override_sampling_model(ctx.sampling_model)
154-
)
155-
with sampling_contextmanager:
156-
async with self: # Ensure server is running
157-
try:
158-
# meta param is not provided by session yet, so build and can send_request directly.
159-
result = await self._client.send_request(
160-
mcp_types.ClientRequest(
161-
mcp_types.CallToolRequest(
162-
method='tools/call',
163-
params=mcp_types.CallToolRequestParams(
164-
name=name,
165-
arguments=tool_args,
166-
_meta=mcp_types.RequestParams.Meta(**metadata) if metadata else None,
167-
),
168-
)
169-
),
170-
mcp_types.CallToolResult,
171-
)
172-
except McpError as e:
173-
raise exceptions.ModelRetry(e.error.message)
135+
async with self: # Ensure server is running
136+
try:
137+
result = await self._client.send_request(
138+
mcp_types.ClientRequest(
139+
mcp_types.CallToolRequest(
140+
method='tools/call',
141+
params=mcp_types.CallToolRequestParams(
142+
name=name,
143+
arguments=tool_args,
144+
_meta=mcp_types.RequestParams.Meta(**metadata) if metadata else None,
145+
),
146+
)
147+
),
148+
mcp_types.CallToolResult,
149+
)
150+
except McpError as e:
151+
raise exceptions.ModelRetry(e.error.message)
174152

175153
content = [self._map_tool_result_part(part) for part in result.content]
176154

@@ -245,15 +223,11 @@ async def __aexit__(
245223
if self._running_count <= 0:
246224
await self._exit_stack.aclose()
247225

248-
def _get_sampling_model(self) -> models.Model | None:
249-
return self._override_sampling_model.get() or self.sampling_model
250-
251226
async def _sampling_callback(
252227
self, context: RequestContext[ClientSession, Any], params: mcp_types.CreateMessageRequestParams
253228
) -> mcp_types.CreateMessageResult | mcp_types.ErrorData:
254229
"""MCP sampling callback."""
255-
sampling_model = self._get_sampling_model()
256-
if sampling_model is None:
230+
if self.sampling_model is None:
257231
raise ValueError('Sampling model is not set') # pragma: no cover
258232

259233
pai_messages = _mcp.map_from_mcp_params(params)
@@ -265,15 +239,15 @@ async def _sampling_callback(
265239
if stop_sequences := params.stopSequences: # pragma: no branch
266240
model_settings['stop_sequences'] = stop_sequences
267241

268-
model_response = await sampling_model.request(
242+
model_response = await self.sampling_model.request(
269243
pai_messages,
270244
model_settings,
271245
models.ModelRequestParameters(),
272246
)
273247
return mcp_types.CreateMessageResult(
274248
role='assistant',
275249
content=_mcp.map_from_model_response(model_response),
276-
model=sampling_model.model_name,
250+
model=self.sampling_model.model_name,
277251
)
278252

279253
def _map_tool_result_part(

pydantic_ai_slim/pydantic_ai/toolsets/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from abc import ABC, abstractmethod
44
from types import TracebackType
5-
from typing import TYPE_CHECKING, Any, Generic, Literal
5+
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal
66

77
from pydantic_core import SchemaValidator
88
from typing_extensions import Self
@@ -78,3 +78,6 @@ async def call_tool(
7878
self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any
7979
) -> Any:
8080
raise NotImplementedError()
81+
82+
def accept(self, visitor: Callable[[AbstractToolset[AgentDepsT]], Any]) -> Any:
83+
return visitor(self)

pydantic_ai_slim/pydantic_ai/toolsets/combined.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from contextlib import AsyncExitStack
66
from dataclasses import dataclass
77
from types import TracebackType
8-
from typing import Any
8+
from typing import Any, Callable
99

1010
from pydantic_core import SchemaValidator
1111
from typing_extensions import Self
@@ -89,6 +89,10 @@ async def call_tool(
8989
) -> Any:
9090
return await self._toolset_for_tool_name(name).call_tool(ctx, name, tool_args, *args, **kwargs)
9191

92+
def accept(self, visitor: Callable[[AbstractToolset[AgentDepsT]], Any]) -> Any:
93+
for toolset in self.toolsets:
94+
toolset.accept(visitor)
95+
9296
def _toolset_for_tool_name(self, name: str) -> AbstractToolset[AgentDepsT]:
9397
try:
9498
return self._toolset_per_tool_name[name]

pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from abc import ABC, abstractmethod
44
from dataclasses import dataclass
55
from types import TracebackType
6-
from typing import TYPE_CHECKING, Any
6+
from typing import TYPE_CHECKING, Any, Callable
77

88
from pydantic_core import SchemaValidator
99
from typing_extensions import Self
@@ -58,5 +58,8 @@ async def call_tool(
5858
) -> Any:
5959
return await self.wrapped.call_tool(ctx, name, tool_args, *args, **kwargs)
6060

61+
def accept(self, visitor: Callable[[AbstractToolset[AgentDepsT]], Any]) -> Any:
62+
return self.wrapped.accept(visitor)
63+
6164
def __getattr__(self, item: str):
6265
return getattr(self.wrapped, item) # pragma: no cover

tests/test_examples.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import shutil
77
import sys
88
from collections.abc import AsyncIterator, Iterable, Sequence
9-
from contextlib import nullcontext
109
from dataclasses import dataclass
1110
from inspect import FrameInfo
1211
from io import StringIO
@@ -264,7 +263,6 @@ def rich_prompt_ask(prompt: str, *_args: Any, **_kwargs: Any) -> str:
264263

265264
class MockMCPServer(AbstractToolset[Any]):
266265
is_running = True
267-
override_sampling_model = nullcontext
268266

269267
async def __aenter__(self) -> MockMCPServer:
270268
return self

tests/test_mcp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def agent(model: Model, mcp_server: MCPServerStdio) -> Agent:
6868

6969
@pytest.fixture
7070
def run_context(model: Model) -> RunContext[int]:
71-
return RunContext(deps=0, model=model, usage=Usage(), sampling_model=model)
71+
return RunContext(deps=0, model=model, usage=Usage())
7272

7373

7474
async def test_stdio_server(run_context: RunContext[int]):

0 commit comments

Comments
 (0)