Skip to content

Commit 7800990

Browse files
authored
Use contextvars for agent overriding, rather than a local attribute (#2118)
1 parent 5b94841 commit 7800990

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import warnings
77
from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
88
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager
9+
from contextvars import ContextVar
910
from copy import deepcopy
1011
from types import FrameType
1112
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, cast, final, overload
@@ -157,8 +158,6 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
157158
_mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)
158159
_default_retries: int = dataclasses.field(repr=False)
159160
_max_result_retries: int = dataclasses.field(repr=False)
160-
_override_deps: _utils.Option[AgentDepsT] = dataclasses.field(default=None, repr=False)
161-
_override_model: _utils.Option[models.Model] = dataclasses.field(default=None, repr=False)
162161

163162
@overload
164163
def __init__(
@@ -367,6 +366,9 @@ def __init__(
367366
else:
368367
self._register_tool(Tool(tool))
369368

369+
self._override_deps: ContextVar[_utils.Option[AgentDepsT]] = ContextVar('_override_deps', default=None)
370+
self._override_model: ContextVar[_utils.Option[models.Model]] = ContextVar('_override_model', default=None)
371+
370372
@staticmethod
371373
def instrument_all(instrument: InstrumentationSettings | bool = True) -> None:
372374
"""Set the instrumentation options for all agents where `instrument` is not set."""
@@ -1113,24 +1115,22 @@ def override(
11131115
model: The model to use instead of the model passed to the agent run.
11141116
"""
11151117
if _utils.is_set(deps):
1116-
override_deps_before = self._override_deps
1117-
self._override_deps = _utils.Some(deps)
1118+
deps_token = self._override_deps.set(_utils.Some(deps))
11181119
else:
1119-
override_deps_before = _utils.UNSET
1120+
deps_token = None
11201121

11211122
if _utils.is_set(model):
1122-
override_model_before = self._override_model
1123-
self._override_model = _utils.Some(models.infer_model(model))
1123+
model_token = self._override_model.set(_utils.Some(models.infer_model(model)))
11241124
else:
1125-
override_model_before = _utils.UNSET
1125+
model_token = None
11261126

11271127
try:
11281128
yield
11291129
finally:
1130-
if _utils.is_set(override_deps_before):
1131-
self._override_deps = override_deps_before
1132-
if _utils.is_set(override_model_before):
1133-
self._override_model = override_model_before
1130+
if deps_token is not None:
1131+
self._override_deps.reset(deps_token)
1132+
if model_token is not None:
1133+
self._override_model.reset(model_token)
11341134

11351135
@overload
11361136
def instructions(
@@ -1604,7 +1604,7 @@ def _get_model(self, model: models.Model | models.KnownModelName | str | None) -
16041604
The model used
16051605
"""
16061606
model_: models.Model
1607-
if some_model := self._override_model:
1607+
if some_model := self._override_model.get():
16081608
# we don't want `override()` to cover up errors from the model not being defined, hence this check
16091609
if model is None and self.model is None:
16101610
raise exceptions.UserError(
@@ -1633,7 +1633,7 @@ def _get_deps(self: Agent[T, OutputDataT], deps: T) -> T:
16331633
16341634
We could do runtime type checking of deps against `self._deps_type`, but that's a slippery slope.
16351635
"""
1636-
if some_deps := self._override_deps:
1636+
if some_deps := self._override_deps.get():
16371637
return some_deps.value
16381638
else:
16391639
return deps

0 commit comments

Comments
 (0)