Skip to content

Commit 41130b5

Browse files
authored
Stop sharing tool retry count across all runs of the same agent (#1918)
1 parent ea837b9 commit 41130b5

File tree

5 files changed

+43
-13
lines changed

5 files changed

+43
-13
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -151,10 +151,6 @@ async def _get_first_message(
151151
ctx.state.message_history = history
152152
run_context.messages = history
153153

154-
# TODO: We need to make it so that function_tools are not shared between runs
155-
# See comment on the current_retry field of `Tool` for more details.
156-
for tool in ctx.deps.function_tools.values():
157-
tool.current_retry = 0
158154
return next_message
159155

160156
async def _prepare_messages(

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -646,11 +646,6 @@ async def main():
646646
# typecast reasonable, even though it is possible to violate it with otherwise-type-checked code.
647647
output_validators = cast(list[_output.OutputValidator[AgentDepsT, RunOutputDataT]], self._output_validators)
648648

649-
# TODO: Instead of this, copy the function tools to ensure they don't share current_retry state between agent
650-
# runs. Requires some changes to `Tool` to make them copyable though.
651-
for v in self._function_tools.values():
652-
v.current_retry = 0
653-
654649
model_settings = merge_model_settings(self.model_settings, model_settings)
655650
usage_limits = usage_limits or _usage.UsageLimits()
656651

@@ -679,6 +674,10 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
679674
instructions += '\n' + await instructions_runner.run(run_context)
680675
return instructions.strip()
681676

677+
# Copy the function tools so that retry state is agent-run-specific
678+
# Note that the retry count is reset to 0 when this happens due to the `default=0` and `init=False`.
679+
run_function_tools = {k: dataclasses.replace(v) for k, v in self._function_tools.items()}
680+
682681
graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT](
683682
user_deps=deps,
684683
prompt=user_prompt,
@@ -690,7 +689,7 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
690689
end_strategy=self.end_strategy,
691690
output_schema=output_schema,
692691
output_validators=output_validators,
693-
function_tools=self._function_tools,
692+
function_tools=run_function_tools,
694693
mcp_servers=self._mcp_servers,
695694
default_retries=self._default_retries,
696695
tracer=tracer,

pydantic_ai_slim/pydantic_ai/tools.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,8 +215,10 @@ class Tool(Generic[AgentDepsT]):
215215
This schema may be modified by the `prepare` function or by the Model class prior to including it in an API request.
216216
"""
217217

218-
# TODO: Move this state off the Tool class, which is otherwise stateless.
219-
# This should be tracked inside a specific agent run, not the tool.
218+
# TODO: Consider moving this current_retry state to live on something other than the tool.
219+
# We've worked around this for now by copying instances of the tool when creating new runs,
220+
# but this is a bit fragile. Moving the tool retry counts to live on the agent run state would likely clean things
221+
# up, though is also likely a larger effort to refactor.
220222
current_retry: int = field(default=0, init=False)
221223

222224
def __init__(

tests/models/test_model_test.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@
22

33
from __future__ import annotations as _annotations
44

5+
import asyncio
6+
import dataclasses
57
from datetime import timezone
68
from typing import Annotated, Any, Literal
79

810
import pytest
911
from annotated_types import Ge, Gt, Le, Lt, MaxLen, MinLen
12+
from anyio import Event
1013
from inline_snapshot import snapshot
1114
from pydantic import BaseModel, Field
1215

@@ -160,6 +163,36 @@ def validate_output(ctx: RunContext[None], output: OutputModel) -> OutputModel:
160163
assert call_count == 3
161164

162165

166+
@dataclasses.dataclass
167+
class AgentRunDeps:
168+
run_id: int
169+
170+
171+
@pytest.mark.anyio
172+
async def test_multiple_concurrent_tool_retries():
173+
class OutputModel(BaseModel):
174+
x: int
175+
y: str
176+
177+
agent = Agent('test', deps_type=AgentRunDeps, output_type=OutputModel, retries=2)
178+
retried_run_ids = set[int]()
179+
event = Event()
180+
181+
run_ids = list(range(5)) # fire off 5 run ids that will all retry the tool before they finish
182+
183+
@agent.tool
184+
async def tool_that_must_be_retried(ctx: RunContext[AgentRunDeps]) -> None:
185+
if ctx.deps.run_id not in retried_run_ids:
186+
retried_run_ids.add(ctx.deps.run_id)
187+
raise ModelRetry('Fail')
188+
if len(retried_run_ids) == len(run_ids): # pragma: no branch # won't branch if all runs happen very quickly
189+
event.set()
190+
await event.wait() # ensure a retry is done by all runs before any of them finish their flow
191+
return None
192+
193+
await asyncio.gather(*[agent.run('Hello', model=TestModel(), deps=AgentRunDeps(run_id)) for run_id in run_ids])
194+
195+
163196
def test_output_tool_retry_error_handled_with_custom_args(set_event_loop: None):
164197
class ResultModel(BaseModel):
165198
x: int

tests/test_live.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def cohere(http_client: httpx.AsyncClient, _tmp_path: Path) -> Model:
9393
pytest.param(anthropic, id='anthropic'),
9494
pytest.param(ollama, id='ollama'),
9595
pytest.param(mistral, id='mistral'),
96-
pytest.param(cohere, id='cohere'),
96+
pytest.param(cohere, id='cohere', marks=pytest.mark.skip(reason='Might be causing hangs in CI')),
9797
]
9898
GetModel = Callable[[httpx.AsyncClient, Path], Model]
9999

0 commit comments

Comments
 (0)