|
6 | 6 | import warnings
|
7 | 7 | from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence
|
8 | 8 | from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager
|
| 9 | +from contextvars import ContextVar |
9 | 10 | from copy import deepcopy
|
10 | 11 | from types import FrameType
|
11 | 12 | from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, cast, final, overload
|
@@ -157,8 +158,6 @@ class Agent(Generic[AgentDepsT, OutputDataT]):
|
157 | 158 | _mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False)
|
158 | 159 | _default_retries: int = dataclasses.field(repr=False)
|
159 | 160 | _max_result_retries: int = dataclasses.field(repr=False)
|
160 |
| - _override_deps: _utils.Option[AgentDepsT] = dataclasses.field(default=None, repr=False) |
161 |
| - _override_model: _utils.Option[models.Model] = dataclasses.field(default=None, repr=False) |
162 | 161 |
|
163 | 162 | @overload
|
164 | 163 | def __init__(
|
@@ -367,6 +366,9 @@ def __init__(
|
367 | 366 | else:
|
368 | 367 | self._register_tool(Tool(tool))
|
369 | 368 |
|
| 369 | + self._override_deps: ContextVar[_utils.Option[AgentDepsT]] = ContextVar('_override_deps', default=None) |
| 370 | + self._override_model: ContextVar[_utils.Option[models.Model]] = ContextVar('_override_model', default=None) |
| 371 | + |
370 | 372 | @staticmethod
|
371 | 373 | def instrument_all(instrument: InstrumentationSettings | bool = True) -> None:
|
372 | 374 | """Set the instrumentation options for all agents where `instrument` is not set."""
|
@@ -1113,24 +1115,22 @@ def override(
|
1113 | 1115 | model: The model to use instead of the model passed to the agent run.
|
1114 | 1116 | """
|
1115 | 1117 | if _utils.is_set(deps):
|
1116 |
| - override_deps_before = self._override_deps |
1117 |
| - self._override_deps = _utils.Some(deps) |
| 1118 | + deps_token = self._override_deps.set(_utils.Some(deps)) |
1118 | 1119 | else:
|
1119 |
| - override_deps_before = _utils.UNSET |
| 1120 | + deps_token = None |
1120 | 1121 |
|
1121 | 1122 | if _utils.is_set(model):
|
1122 |
| - override_model_before = self._override_model |
1123 |
| - self._override_model = _utils.Some(models.infer_model(model)) |
| 1123 | + model_token = self._override_model.set(_utils.Some(models.infer_model(model))) |
1124 | 1124 | else:
|
1125 |
| - override_model_before = _utils.UNSET |
| 1125 | + model_token = None |
1126 | 1126 |
|
1127 | 1127 | try:
|
1128 | 1128 | yield
|
1129 | 1129 | finally:
|
1130 |
| - if _utils.is_set(override_deps_before): |
1131 |
| - self._override_deps = override_deps_before |
1132 |
| - if _utils.is_set(override_model_before): |
1133 |
| - self._override_model = override_model_before |
| 1130 | + if deps_token is not None: |
| 1131 | + self._override_deps.reset(deps_token) |
| 1132 | + if model_token is not None: |
| 1133 | + self._override_model.reset(model_token) |
1134 | 1134 |
|
1135 | 1135 | @overload
|
1136 | 1136 | def instructions(
|
@@ -1604,7 +1604,7 @@ def _get_model(self, model: models.Model | models.KnownModelName | str | None) -
|
1604 | 1604 | The model used
|
1605 | 1605 | """
|
1606 | 1606 | model_: models.Model
|
1607 |
| - if some_model := self._override_model: |
| 1607 | + if some_model := self._override_model.get(): |
1608 | 1608 | # we don't want `override()` to cover up errors from the model not being defined, hence this check
|
1609 | 1609 | if model is None and self.model is None:
|
1610 | 1610 | raise exceptions.UserError(
|
@@ -1633,7 +1633,7 @@ def _get_deps(self: Agent[T, OutputDataT], deps: T) -> T:
|
1633 | 1633 |
|
1634 | 1634 | We could do runtime type checking of deps against `self._deps_type`, but that's a slippery slope.
|
1635 | 1635 | """
|
1636 |
| - if some_deps := self._override_deps: |
| 1636 | + if some_deps := self._override_deps.get(): |
1637 | 1637 | return some_deps.value
|
1638 | 1638 | else:
|
1639 | 1639 | return deps
|
|
0 commit comments