Skip to content

Commit badbe23

Browse files
committed
Merge branch 'main' into toolsets
# Conflicts: # pydantic_ai_slim/pydantic_ai/agent.py # pydantic_ai_slim/pydantic_ai/mcp.py
2 parents c5ef5f6 + 49658f3 commit badbe23

File tree

6 files changed

+61
-37
lines changed

6 files changed

+61
-37
lines changed

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
import json
66
import warnings
77
from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
8-
from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager
8+
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager
9+
from contextvars import ContextVar
910
from copy import deepcopy
1011
from types import FrameType
1112
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, cast, final, overload
@@ -164,8 +165,6 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
164165
_prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False)
165166
_prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False)
166167
_max_result_retries: int = dataclasses.field(repr=False)
167-
_override_deps: _utils.Option[AgentDepsT] = dataclasses.field(default=None, repr=False)
168-
_override_model: _utils.Option[models.Model] = dataclasses.field(default=None, repr=False)
169168

170169
@overload
171170
def __init__(
@@ -425,6 +424,9 @@ def __init__(
425424

426425
self.history_processors = history_processors or []
427426

427+
self._override_deps: ContextVar[_utils.Option[AgentDepsT]] = ContextVar('_override_deps', default=None)
428+
self._override_model: ContextVar[_utils.Option[models.Model]] = ContextVar('_override_model', default=None)
429+
428430
@staticmethod
429431
def instrument_all(instrument: InstrumentationSettings | bool = True) -> None:
430432
"""Set the instrumentation options for all agents where `instrument` is not set."""
@@ -1212,24 +1214,22 @@ def override(
12121214
model: The model to use instead of the model passed to the agent run.
12131215
"""
12141216
if _utils.is_set(deps):
1215-
override_deps_before = self._override_deps
1216-
self._override_deps = _utils.Some(deps)
1217+
deps_token = self._override_deps.set(_utils.Some(deps))
12171218
else:
1218-
override_deps_before = _utils.UNSET
1219+
deps_token = None
12191220

12201221
if _utils.is_set(model):
1221-
override_model_before = self._override_model
1222-
self._override_model = _utils.Some(models.infer_model(model))
1222+
model_token = self._override_model.set(_utils.Some(models.infer_model(model)))
12231223
else:
1224-
override_model_before = _utils.UNSET
1224+
model_token = None
12251225

12261226
try:
12271227
yield
12281228
finally:
1229-
if _utils.is_set(override_deps_before):
1230-
self._override_deps = override_deps_before
1231-
if _utils.is_set(override_model_before):
1232-
self._override_model = override_model_before
1229+
if deps_token is not None:
1230+
self._override_deps.reset(deps_token)
1231+
if model_token is not None:
1232+
self._override_model.reset(model_token)
12331233

12341234
@overload
12351235
def instructions(
@@ -1662,7 +1662,7 @@ def _get_model(self, model: models.Model | models.KnownModelName | str | None) -
16621662
The model used
16631663
"""
16641664
model_: models.Model
1665-
if some_model := self._override_model:
1665+
if some_model := self._override_model.get():
16661666
# we don't want `override()` to cover up errors from the model not being defined, hence this check
16671667
if model is None and self.model is None:
16681668
raise exceptions.UserError(
@@ -1691,7 +1691,7 @@ def _get_deps(self: Agent[T, OutputDataT], deps: T) -> T:
16911691
16921692
We could do runtime type checking of deps against `self._deps_type`, but that's a slippery slope.
16931693
"""
1694-
if some_deps := self._override_deps:
1694+
if some_deps := self._override_deps.get():
16951695
return some_deps.value
16961696
else:
16971697
return deps
@@ -1793,10 +1793,11 @@ async def run_toolsets(
17931793
model: models.Model | None = self._get_model(sampling_model)
17941794
except exceptions.UserError: # pragma: no cover
17951795
model = None
1796-
if model is not None: # pragma: no branch
1797-
self._toolset._set_mcp_sampling_model(model) # type: ignore[reportPrivateUsage]
17981796

1799-
async with self._toolset:
1797+
async with AsyncExitStack() as exit_stack:
1798+
if model is not None: # pragma: no branch
1799+
exit_stack.enter_context(self._toolset.override_sampling_model(model))
1800+
await exit_stack.enter_async_context(self._toolset)
18001801
yield
18011802

18021803
@asynccontextmanager

pydantic_ai_slim/pydantic_ai/mcp.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
import base64
44
import functools
55
from abc import ABC, abstractmethod
6-
from collections.abc import AsyncIterator, Sequence
7-
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
6+
from collections.abc import AsyncIterator, Iterator, Sequence
7+
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager
8+
from contextvars import ContextVar
89
from dataclasses import dataclass
910
from pathlib import Path
1011
from types import TracebackType
@@ -69,6 +70,22 @@ class MCPServer(AbstractToolset[Any], ABC):
6970
_exit_stack: AsyncExitStack
7071
sampling_model: models.Model | None = None
7172

73+
def __post_init__(self):
74+
self._override_sampling_model: ContextVar[models.Model | None] = ContextVar(
75+
'_override_sampling_model', default=None
76+
)
77+
78+
@contextmanager
79+
def override_sampling_model(
80+
self,
81+
model: models.Model,
82+
) -> Iterator[None]:
83+
token = self._override_sampling_model.set(model)
84+
try:
85+
yield
86+
finally:
87+
self._override_sampling_model.reset(token)
88+
7289
@abstractmethod
7390
@asynccontextmanager
7491
async def client_streams(
@@ -193,9 +210,6 @@ def _get_tool_args_validator(self, ctx: RunContext[Any], name: str) -> pydantic_
193210
def _max_retries_for_tool(self, name: str) -> int:
194211
return self.max_retries
195212

196-
def _set_mcp_sampling_model(self, model: models.Model) -> None:
197-
self.sampling_model = model
198-
199213
async def __aenter__(self) -> Self:
200214
if self._running_count == 0:
201215
self._exit_stack = AsyncExitStack()
@@ -231,7 +245,8 @@ async def _sampling_callback(
231245
self, context: RequestContext[ClientSession, Any], params: mcp_types.CreateMessageRequestParams
232246
) -> mcp_types.CreateMessageResult | mcp_types.ErrorData:
233247
"""MCP sampling callback."""
234-
if self.sampling_model is None:
248+
sampling_model = self._override_sampling_model.get() or self.sampling_model
249+
if sampling_model is None:
235250
raise ValueError('Sampling model is not set') # pragma: no cover
236251

237252
pai_messages = _mcp.map_from_mcp_params(params)
@@ -243,15 +258,15 @@ async def _sampling_callback(
243258
if stop_sequences := params.stopSequences: # pragma: no branch
244259
model_settings['stop_sequences'] = stop_sequences
245260

246-
model_response = await self.sampling_model.request(
261+
model_response = await sampling_model.request(
247262
pai_messages,
248263
model_settings,
249264
models.ModelRequestParameters(),
250265
)
251266
return mcp_types.CreateMessageResult(
252267
role='assistant',
253268
content=_mcp.map_from_model_response(model_response),
254-
model=self.sampling_model.model_name,
269+
model=sampling_model.model_name,
255270
)
256271

257272
def _map_tool_result_part(

pydantic_ai_slim/pydantic_ai/toolsets/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from __future__ import annotations
22

33
from abc import ABC, abstractmethod
4+
from collections.abc import Iterator
5+
from contextlib import contextmanager
46
from types import TracebackType
57
from typing import TYPE_CHECKING, Any, Generic, Literal
68

@@ -80,5 +82,6 @@ async def call_tool(
8082
) -> Any:
8183
raise NotImplementedError()
8284

83-
def _set_mcp_sampling_model(self, model: Model) -> None:
84-
pass
85+
@contextmanager
86+
def override_sampling_model(self, model: Model) -> Iterator[None]:
87+
yield

pydantic_ai_slim/pydantic_ai/toolsets/combined.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from __future__ import annotations
22

33
import asyncio
4-
from collections.abc import Sequence
5-
from contextlib import AsyncExitStack
4+
from collections.abc import Iterator, Sequence
5+
from contextlib import AsyncExitStack, ExitStack, contextmanager
66
from dataclasses import dataclass
77
from types import TracebackType
88
from typing import TYPE_CHECKING, Any
@@ -92,9 +92,12 @@ async def call_tool(
9292
) -> Any:
9393
return await self._toolset_for_tool_name(name).call_tool(ctx, name, tool_args, *args, **kwargs)
9494

95-
def _set_mcp_sampling_model(self, model: Model) -> None:
96-
for toolset in self.toolsets:
97-
toolset._set_mcp_sampling_model(model)
95+
@contextmanager
96+
def override_sampling_model(self, model: Model) -> Iterator[None]:
97+
with ExitStack() as exit_stack:
98+
for toolset in self.toolsets:
99+
exit_stack.enter_context(toolset.override_sampling_model(model))
100+
yield
98101

99102
def _toolset_for_tool_name(self, name: str) -> AbstractToolset[AgentDepsT]:
100103
try:

pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from __future__ import annotations
22

33
from abc import ABC, abstractmethod
4+
from collections.abc import Iterator
5+
from contextlib import contextmanager
46
from dataclasses import dataclass
57
from types import TracebackType
68
from typing import TYPE_CHECKING, Any
@@ -59,8 +61,10 @@ async def call_tool(
5961
) -> Any:
6062
return await self.wrapped.call_tool(ctx, name, tool_args, *args, **kwargs)
6163

62-
def _set_mcp_sampling_model(self, model: Model) -> None:
63-
self.wrapped._set_mcp_sampling_model(model)
64+
@contextmanager
65+
def override_sampling_model(self, model: Model) -> Iterator[None]:
66+
with self.wrapped.override_sampling_model(model):
67+
yield
6468

6569
def __getattr__(self, item: str):
6670
return getattr(self.wrapped, item) # pragma: no cover

tests/test_examples.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import shutil
77
import sys
88
from collections.abc import AsyncIterator, Iterable, Sequence
9-
from contextlib import nullcontext
109
from dataclasses import dataclass
1110
from inspect import FrameInfo
1211
from io import StringIO
@@ -264,7 +263,6 @@ def rich_prompt_ask(prompt: str, *_args: Any, **_kwargs: Any) -> str:
264263

265264
class MockMCPServer(AbstractToolset[Any]):
266265
is_running = True
267-
override_sampling_model = nullcontext
268266

269267
async def __aenter__(self) -> MockMCPServer:
270268
return self

0 commit comments

Comments
 (0)