Skip to content

Commit 89fc266

Browse files
committed
Turn RunContext.retries from a defaultdict into a dict again as the 0 being stored on read broke a test
1 parent acddb8d commit 89fc266

File tree

4 files changed

+6
-8
lines changed

4 files changed

+6
-8
lines changed

pydantic_ai_slim/pydantic_ai/_run_context.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations as _annotations
22

33
import dataclasses
4-
from collections import defaultdict
54
from collections.abc import Sequence
65
from dataclasses import field
76
from typing import TYPE_CHECKING, Generic
@@ -34,7 +33,7 @@ class RunContext(Generic[AgentDepsT]):
3433
"""The original user prompt passed to the run."""
3534
messages: list[_messages.ModelMessage] = field(default_factory=list)
3635
"""Messages exchanged in the conversation so far."""
37-
retries: defaultdict[str, int] = field(default_factory=lambda: defaultdict(int))
36+
retries: dict[str, int] = field(default_factory=dict)
3837
"""Number of retries for each tool so far."""
3938
tool_call_id: str | None = None
4039
"""The ID of the tool call."""

pydantic_ai_slim/pydantic_ai/toolsets/_run.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
from collections import defaultdict
43
from collections.abc import Iterable, Iterator
54
from contextlib import contextmanager
65
from dataclasses import dataclass, replace
@@ -25,7 +24,7 @@ class RunToolset(WrapperToolset[AgentDepsT]):
2524
ctx: RunContext[AgentDepsT]
2625
_tool_defs: list[ToolDefinition]
2726
_tool_names: list[str]
28-
_retries: defaultdict[str, int]
27+
_retries: dict[str, int]
2928
_original: AbstractToolset[AgentDepsT]
3029

3130
def __init__(
@@ -108,14 +107,14 @@ def _with_retry(self, name: str, ctx: RunContext[AgentDepsT]) -> Iterator[RunCon
108107
msg = 'No tools available.'
109108
raise ModelRetry(f'Unknown tool name: {name!r}. {msg}')
110109

111-
ctx = replace(ctx, tool_name=name, retry=self._retries[name], retries={})
110+
ctx = replace(ctx, tool_name=name, retry=self._retries.get(name, 0), retries={})
112111
yield ctx
113112
except (ValidationError, ModelRetry, UnexpectedModelBehavior, ToolRetryError) as e:
114113
try:
115114
max_retries = self._max_retries_for_tool(name)
116115
except Exception:
117116
max_retries = 1
118-
current_retry = self._retries[name]
117+
current_retry = self._retries.get(name, 0)
119118

120119
if isinstance(e, UnexpectedModelBehavior) and e.__cause__ is not None:
121120
e = e.__cause__

pydantic_ai_slim/pydantic_ai/toolsets/function.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[Agent
188188

189189
async def _prepare_tool_def(self, ctx: RunContext[AgentDepsT], tool_def: ToolDefinition) -> ToolDefinition | None:
190190
tool_name = tool_def.name
191-
ctx = replace(ctx, tool_name=tool_name, retry=ctx.retries[tool_name])
191+
ctx = replace(ctx, tool_name=tool_name, retry=ctx.retries.get(tool_name, 0))
192192
return await self.tools[tool_name].prepare_tool_def(ctx)
193193

194194
@property

tests/test_tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1155,7 +1155,7 @@ async def prepare_tool_defs(
11551155
ctx: RunContext[None], tool_defs: list[ToolDefinition]
11561156
) -> Union[list[ToolDefinition], None]:
11571157
nonlocal prepare_tools_retries
1158-
retry = ctx.retries['infinite_retry_tool']
1158+
retry = ctx.retries.get('infinite_retry_tool', 0)
11591159
prepare_tools_retries.append(retry)
11601160
return tool_defs
11611161

0 commit comments

Comments
 (0)