Skip to content

Commit 49658f3

Browse files
authored
Use contextvars for tracking the MCP sampling model (#2117)
1 parent 7800990 commit 49658f3

File tree

3 files changed

+26
-6
lines changed

3 files changed

+26
-6
lines changed

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1740,7 +1740,7 @@ async def run_mcp_servers(
17401740
try:
17411741
for mcp_server in self._mcp_servers:
17421742
if sampling_model is not None: # pragma: no branch
1743-
mcp_server.sampling_model = sampling_model
1743+
exit_stack.enter_context(mcp_server.override_sampling_model(sampling_model))
17441744
await exit_stack.enter_async_context(mcp_server)
17451745
yield
17461746
finally:

pydantic_ai_slim/pydantic_ai/mcp.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
import base64
44
import functools
55
from abc import ABC, abstractmethod
6-
from collections.abc import AsyncIterator, Awaitable, Sequence
7-
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
6+
from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
7+
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager
8+
from contextvars import ContextVar
89
from dataclasses import dataclass
910
from pathlib import Path
1011
from types import TracebackType
@@ -60,6 +61,22 @@ class MCPServer(ABC):
6061
_exit_stack: AsyncExitStack
6162
sampling_model: models.Model | None = None
6263

64+
def __post_init__(self):
65+
self._override_sampling_model: ContextVar[models.Model | None] = ContextVar(
66+
'_override_sampling_model', default=None
67+
)
68+
69+
@contextmanager
70+
def override_sampling_model(
71+
self,
72+
model: models.Model,
73+
) -> Iterator[None]:
74+
token = self._override_sampling_model.set(model)
75+
try:
76+
yield
77+
finally:
78+
self._override_sampling_model.reset(token)
79+
6380
@abstractmethod
6481
@asynccontextmanager
6582
async def client_streams(
@@ -184,7 +201,8 @@ async def _sampling_callback(
184201
self, context: RequestContext[ClientSession, Any], params: mcp_types.CreateMessageRequestParams
185202
) -> mcp_types.CreateMessageResult | mcp_types.ErrorData:
186203
"""MCP sampling callback."""
187-
if self.sampling_model is None:
204+
sampling_model = self._override_sampling_model.get() or self.sampling_model
205+
if sampling_model is None:
188206
raise ValueError('Sampling model is not set') # pragma: no cover
189207

190208
pai_messages = _mcp.map_from_mcp_params(params)
@@ -196,15 +214,15 @@ async def _sampling_callback(
196214
if stop_sequences := params.stopSequences: # pragma: no branch
197215
model_settings['stop_sequences'] = stop_sequences
198216

199-
model_response = await self.sampling_model.request(
217+
model_response = await sampling_model.request(
200218
pai_messages,
201219
model_settings,
202220
models.ModelRequestParameters(),
203221
)
204222
return mcp_types.CreateMessageResult(
205223
role='assistant',
206224
content=_mcp.map_from_model_response(model_response),
207-
model=self.sampling_model.model_name,
225+
model=sampling_model.model_name,
208226
)
209227

210228
def _map_tool_result_part(

tests/test_examples.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import shutil
77
import sys
88
from collections.abc import AsyncIterator, Iterable, Sequence
9+
from contextlib import nullcontext
910
from dataclasses import dataclass
1011
from inspect import FrameInfo
1112
from io import StringIO
@@ -258,6 +259,7 @@ def rich_prompt_ask(prompt: str, *_args: Any, **_kwargs: Any) -> str:
258259

259260
class MockMCPServer:
260261
is_running = True
262+
override_sampling_model = nullcontext
261263

262264
async def __aenter__(self) -> MockMCPServer:
263265
return self

0 commit comments

Comments
 (0)