3
3
import base64
4
4
import functools
5
5
from abc import ABC , abstractmethod
6
- from collections .abc import AsyncIterator , Awaitable , Sequence
7
- from contextlib import AbstractAsyncContextManager , AsyncExitStack , asynccontextmanager
6
+ from collections .abc import AsyncIterator , Awaitable , Iterator , Sequence
7
+ from contextlib import AbstractAsyncContextManager , AsyncExitStack , asynccontextmanager , contextmanager
8
+ from contextvars import ContextVar
8
9
from dataclasses import dataclass
9
10
from pathlib import Path
10
11
from types import TracebackType
@@ -60,6 +61,22 @@ class MCPServer(ABC):
60
61
_exit_stack : AsyncExitStack
61
62
sampling_model : models .Model | None = None
62
63
64
+ def __post_init__ (self ):
65
+ self ._override_sampling_model : ContextVar [models .Model | None ] = ContextVar (
66
+ '_override_sampling_model' , default = None
67
+ )
68
+
69
+ @contextmanager
70
+ def override_sampling_model (
71
+ self ,
72
+ model : models .Model ,
73
+ ) -> Iterator [None ]:
74
+ token = self ._override_sampling_model .set (model )
75
+ try :
76
+ yield
77
+ finally :
78
+ self ._override_sampling_model .reset (token )
79
+
63
80
@abstractmethod
64
81
@asynccontextmanager
65
82
async def client_streams (
@@ -184,7 +201,8 @@ async def _sampling_callback(
184
201
self , context : RequestContext [ClientSession , Any ], params : mcp_types .CreateMessageRequestParams
185
202
) -> mcp_types .CreateMessageResult | mcp_types .ErrorData :
186
203
"""MCP sampling callback."""
187
- if self .sampling_model is None :
204
+ sampling_model = self ._override_sampling_model .get () or self .sampling_model
205
+ if sampling_model is None :
188
206
raise ValueError ('Sampling model is not set' ) # pragma: no cover
189
207
190
208
pai_messages = _mcp .map_from_mcp_params (params )
@@ -196,15 +214,15 @@ async def _sampling_callback(
196
214
if stop_sequences := params .stopSequences : # pragma: no branch
197
215
model_settings ['stop_sequences' ] = stop_sequences
198
216
199
- model_response = await self . sampling_model .request (
217
+ model_response = await sampling_model .request (
200
218
pai_messages ,
201
219
model_settings ,
202
220
models .ModelRequestParameters (),
203
221
)
204
222
return mcp_types .CreateMessageResult (
205
223
role = 'assistant' ,
206
224
content = _mcp .map_from_model_response (model_response ),
207
- model = self . sampling_model .model_name ,
225
+ model = sampling_model .model_name ,
208
226
)
209
227
210
228
def _map_tool_result_part (
0 commit comments