Skip to content

Commit 84cd954

Browse files
committed
Stop double counting retries and reset on success
1 parent ad6e826 commit 84cd954

File tree

6 files changed

+98
-84
lines changed

6 files changed

+98
-84
lines changed

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast
1313

1414
from opentelemetry.trace import Tracer
15-
from pydantic import ValidationError
1615
from typing_extensions import TypeGuard, TypeVar, assert_never
1716

1817
from pydantic_ai._function_schema import _takes_ctx as is_takes_ctx # type: ignore
@@ -610,7 +609,11 @@ async def process_function_tools( # noqa: C901
610609
else:
611610
try:
612611
result_data = await _call_tool(toolset, call, run_context)
612+
except exceptions.UnexpectedModelBehavior as e:
613+
ctx.state.increment_retries(ctx.deps.max_result_retries, e)
614+
raise e
613615
except _output.ToolRetryError as e:
616+
ctx.state.increment_retries(ctx.deps.max_result_retries, e)
614617
yield _messages.FunctionToolCallEvent(call)
615618
parts.append(e.tool_retry)
616619
yield _messages.FunctionToolResultEvent(e.tool_retry, tool_call_id=call.tool_call_id)
@@ -792,26 +795,8 @@ async def _call_tool(
792795
toolset: AbstractToolset[DepsT], tool_call: _messages.ToolCallPart, run_context: RunContext[DepsT]
793796
) -> Any:
794797
run_context = dataclasses.replace(run_context, tool_call_id=tool_call.tool_call_id)
795-
796-
try:
797-
args_dict = toolset.validate_tool_args(run_context, tool_call.tool_name, tool_call.args)
798-
response_content = await toolset.call_tool(run_context, tool_call.tool_name, args_dict)
799-
except (ValidationError, exceptions.ModelRetry) as e:
800-
if isinstance(e, ValidationError):
801-
m = _messages.RetryPromptPart(
802-
tool_name=tool_call.tool_name,
803-
content=e.errors(include_url=False, include_context=False),
804-
tool_call_id=tool_call.tool_call_id,
805-
)
806-
else:
807-
m = _messages.RetryPromptPart(
808-
tool_name=tool_call.tool_name,
809-
content=e.message,
810-
tool_call_id=tool_call.tool_call_id,
811-
)
812-
raise _output.ToolRetryError(m)
813-
814-
return response_content
798+
args_dict = toolset.validate_tool_args(run_context, tool_call.tool_name, tool_call.args)
799+
return await toolset.call_tool(run_context, tool_call.tool_name, args_dict)
815800

816801

817802
async def _validate_output(

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,7 @@ def __init__(
364364
self._function_toolset = FunctionToolset[AgentDepsT](tools, max_retries=retries)
365365

366366
# This will raise errors for any name conflicts
367+
# TODO: Also include toolsets (not mcp_serves as we won't have tool defs yet)
367368
CombinedToolset[AgentDepsT]([self._output_toolset, self._function_toolset])
368369

369370
# TODO: Set max_retries on MCPServer

pydantic_ai_slim/pydantic_ai/mcp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,14 +182,14 @@ async def list_tool_defs(self) -> list[ToolDefinition]:
182182
for mcp_tool in mcp_tools
183183
]
184184

185-
def get_tool_args_validator(self, ctx: RunContext[Any], name: str) -> pydantic_core.SchemaValidator:
185+
def _get_tool_args_validator(self, ctx: RunContext[Any], name: str) -> pydantic_core.SchemaValidator:
186186
return pydantic_core.SchemaValidator(
187187
schema=pydantic_core.core_schema.dict_schema(
188188
pydantic_core.core_schema.str_schema(), pydantic_core.core_schema.any_schema()
189189
)
190190
)
191191

192-
def max_retries_for_tool(self, name: str) -> int:
192+
def _max_retries_for_tool(self, name: str) -> int:
193193
return 1
194194

195195
def set_mcp_sampling_model(self, model: models.Model) -> None:

pydantic_ai_slim/pydantic_ai/toolset.py

Lines changed: 84 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,20 @@
22

33
import asyncio
44
from abc import ABC, abstractmethod
5-
from collections.abc import Awaitable, Sequence
6-
from contextlib import AsyncExitStack
5+
from collections.abc import Awaitable, Iterator, Sequence
6+
from contextlib import AsyncExitStack, contextmanager
77
from dataclasses import dataclass, field, replace
88
from functools import partial
99
from types import TracebackType
10-
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Protocol, overload
10+
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Protocol, assert_never, overload
1111

1212
from pydantic import ValidationError
1313
from pydantic.json_schema import GenerateJsonSchema
1414
from pydantic_core import SchemaValidator
15-
from typing_extensions import Never, Self
15+
from typing_extensions import Self
1616

17-
from ._output import BaseOutputSchema, OutputValidator
17+
from . import messages as _messages
18+
from ._output import BaseOutputSchema, OutputValidator, ToolRetryError
1819
from ._run_context import AgentDepsT, RunContext
1920
from .exceptions import ModelRetry, UnexpectedModelBehavior, UserError
2021
from .tools import (
@@ -70,21 +71,21 @@ def tool_names(self) -> list[str]:
7071
return [tool_def.name for tool_def in self.tool_defs]
7172

7273
@abstractmethod
73-
def get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator:
74+
def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator:
7475
raise NotImplementedError()
7576

7677
def validate_tool_args(
7778
self, ctx: RunContext[AgentDepsT], name: str, args: str | dict[str, Any] | None, allow_partial: bool = False
7879
) -> dict[str, Any]:
7980
pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off'
80-
validator = self.get_tool_args_validator(ctx, name)
81+
validator = self._get_tool_args_validator(ctx, name)
8182
if isinstance(args, str):
8283
return validator.validate_json(args or '{}', allow_partial=pyd_allow_partial)
8384
else:
8485
return validator.validate_python(args or {}, allow_partial=pyd_allow_partial)
8586

8687
@abstractmethod
87-
def max_retries_for_tool(self, name: str) -> int:
88+
def _max_retries_for_tool(self, name: str) -> int:
8889
raise NotImplementedError()
8990

9091
@abstractmethod
@@ -273,10 +274,10 @@ async def _prepare_tool_def(self, ctx: RunContext[AgentDepsT], tool_def: ToolDef
273274
def tool_defs(self) -> list[ToolDefinition]:
274275
return [tool.tool_def for tool in self.tools.values()]
275276

276-
def get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator:
277+
def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator:
277278
return self.tools[name].function_schema.validator
278279

279-
def max_retries_for_tool(self, name: str) -> int:
280+
def _max_retries_for_tool(self, name: str) -> int:
280281
tool = self.tools[name]
281282
return tool.max_retries if tool.max_retries is not None else self.max_retries
282283

@@ -298,10 +299,10 @@ class OutputToolset(AbstractToolset[AgentDepsT]):
298299
def tool_defs(self) -> list[ToolDefinition]:
299300
return [tool.tool_def for tool in self.output_schema.tools.values()]
300301

301-
def get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator:
302+
def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator:
302303
return self.output_schema.tools[name].processor.validator
303304

304-
def max_retries_for_tool(self, name: str) -> int:
305+
def _max_retries_for_tool(self, name: str) -> int:
305306
return self.max_retries
306307

307308
async def call_tool(
@@ -365,16 +366,16 @@ def tool_defs(self) -> list[ToolDefinition]:
365366
def tool_names(self) -> list[str]:
366367
return list(self._toolset_per_tool_name.keys())
367368

368-
def get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator:
369-
return self._toolset_for_tool_name(name).get_tool_args_validator(ctx, name)
369+
def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator:
370+
return self._toolset_for_tool_name(name)._get_tool_args_validator(ctx, name)
370371

371372
def validate_tool_args(
372373
self, ctx: RunContext[AgentDepsT], name: str, args: str | dict[str, Any] | None, allow_partial: bool = False
373374
) -> dict[str, Any]:
374375
return self._toolset_for_tool_name(name).validate_tool_args(ctx, name, args, allow_partial)
375376

376-
def max_retries_for_tool(self, name: str) -> int:
377-
return self._toolset_for_tool_name(name).max_retries_for_tool(name)
377+
def _max_retries_for_tool(self, name: str) -> int:
378+
return self._toolset_for_tool_name(name)._max_retries_for_tool(name)
378379

379380
async def call_tool(
380381
self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any
@@ -419,11 +420,11 @@ async def __aexit__(
419420
def tool_defs(self) -> list[ToolDefinition]:
420421
return self.wrapped.tool_defs
421422

422-
def get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator:
423-
return self.wrapped.get_tool_args_validator(ctx, name)
423+
def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator:
424+
return self.wrapped._get_tool_args_validator(ctx, name)
424425

425-
def max_retries_for_tool(self, name: str) -> int:
426-
return self.wrapped.max_retries_for_tool(name)
426+
def _max_retries_for_tool(self, name: str) -> int:
427+
return self.wrapped._max_retries_for_tool(name)
427428

428429
async def call_tool(
429430
self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any
@@ -452,11 +453,11 @@ async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[Agent
452453
def tool_defs(self) -> list[ToolDefinition]:
453454
return [replace(tool_def, name=self._prefixed_tool_name(tool_def.name)) for tool_def in super().tool_defs]
454455

455-
def get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator:
456-
return super().get_tool_args_validator(ctx, self._unprefixed_tool_name(name))
456+
def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator:
457+
return super()._get_tool_args_validator(ctx, self._unprefixed_tool_name(name))
457458

458-
def max_retries_for_tool(self, name: str) -> int:
459-
return super().max_retries_for_tool(self._unprefixed_tool_name(name))
459+
def _max_retries_for_tool(self, name: str) -> int:
460+
return super()._max_retries_for_tool(self._unprefixed_tool_name(name))
460461

461462
async def call_tool(
462463
self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any
@@ -519,11 +520,11 @@ async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[Agent
519520
def tool_defs(self) -> list[ToolDefinition]:
520521
return self._tool_defs
521522

522-
def get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator:
523-
return super().get_tool_args_validator(ctx, self._map_name(name))
523+
def _get_tool_args_validator(self, ctx: RunContext[AgentDepsT], name: str) -> SchemaValidator:
524+
return super()._get_tool_args_validator(ctx, self._map_name(name))
524525

525-
def max_retries_for_tool(self, name: str) -> int:
526-
return super().max_retries_for_tool(self._map_name(name))
526+
def _max_retries_for_tool(self, name: str) -> int:
527+
return super()._max_retries_for_tool(self._map_name(name))
527528

528529
async def call_tool(
529530
self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any
@@ -660,40 +661,66 @@ def tool_names(self) -> list[str]:
660661
def validate_tool_args(
661662
self, ctx: RunContext[AgentDepsT], name: str, args: str | dict[str, Any] | None, allow_partial: bool = False
662663
) -> dict[str, Any]:
663-
try:
664-
self._validate_tool_name(name)
665-
666-
ctx = replace(ctx, tool_name=name, retry=self._retries.get(name, 0))
664+
with self._with_retry(name, ctx) as ctx:
667665
return super().validate_tool_args(ctx, name, args, allow_partial)
668-
except ValidationError as e:
669-
return self._on_error(name, e)
670666

671667
async def call_tool(
672668
self, ctx: RunContext[AgentDepsT], name: str, tool_args: dict[str, Any], *args: Any, **kwargs: Any
673669
) -> Any:
670+
with self._with_retry(name, ctx) as ctx:
671+
try:
672+
output = await super().call_tool(ctx, name, tool_args, *args, **kwargs)
673+
except Exception as e:
674+
raise e
675+
else:
676+
self._retries.pop(name, None)
677+
return output
678+
679+
@contextmanager
680+
def _with_retry(self, name: str, ctx: RunContext[AgentDepsT]) -> Iterator[RunContext[AgentDepsT]]:
674681
try:
675-
self._validate_tool_name(name)
676-
677-
ctx = replace(ctx, tool_name=name, retry=self._retries.get(name, 0))
678-
return await super().call_tool(ctx, name, tool_args, *args, **kwargs)
679-
except ModelRetry as e:
680-
return self._on_error(name, e)
681-
682-
def _on_error(self, name: str, e: Exception) -> Never:
683-
max_retries = self.max_retries_for_tool(name)
684-
current_retry = self._retries.get(name, 0)
685-
if current_retry == max_retries:
686-
raise UnexpectedModelBehavior(f'Tool {name!r} exceeded max retries count of {max_retries}') from e
687-
else:
688-
self._retries[name] = current_retry + 1 # TODO: Reset on successful call!
689-
raise e
682+
if name not in self.tool_names:
683+
if self.tool_names:
684+
msg = f'Available tools: {", ".join(self.tool_names)}'
685+
else:
686+
msg = 'No tools available.'
687+
raise ModelRetry(f'Unknown tool name: {name!r}. {msg}')
688+
689+
ctx = replace(ctx, tool_name=name, retry=self._retries.get(name, 0), retries={})
690+
yield ctx
691+
except (ValidationError, ModelRetry, UnexpectedModelBehavior, ToolRetryError) as e:
692+
if isinstance(e, ToolRetryError):
693+
pass
694+
elif isinstance(e, ValidationError):
695+
if ctx.tool_call_id:
696+
m = _messages.RetryPromptPart(
697+
tool_name=name,
698+
content=e.errors(include_url=False, include_context=False),
699+
tool_call_id=ctx.tool_call_id,
700+
)
701+
e = ToolRetryError(m)
702+
elif isinstance(e, ModelRetry):
703+
if ctx.tool_call_id:
704+
m = _messages.RetryPromptPart(
705+
tool_name=name,
706+
content=e.message,
707+
tool_call_id=ctx.tool_call_id,
708+
)
709+
e = ToolRetryError(m)
710+
elif isinstance(e, UnexpectedModelBehavior):
711+
if e.__cause__ is not None:
712+
e = e.__cause__
713+
else:
714+
assert_never(e)
690715

691-
def _validate_tool_name(self, name: str) -> None:
692-
if name in self.tool_names:
693-
return
716+
try:
717+
max_retries = self._max_retries_for_tool(name)
718+
except Exception:
719+
max_retries = 1
720+
current_retry = self._retries.get(name, 0)
694721

695-
if self.tool_names:
696-
msg = f'Available tools: {", ".join(self.tool_names)}'
697-
else:
698-
msg = 'No tools available.'
699-
raise ModelRetry(f'Unknown tool name: {name!r}. {msg}')
722+
if current_retry == max_retries:
723+
raise UnexpectedModelBehavior(f'Tool {name!r} exceeded max retries count of {max_retries}') from e
724+
else:
725+
self._retries[name] = current_retry + 1
726+
raise e

tests/models/test_model_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import asyncio
66
import dataclasses
7+
import re
78
from datetime import timezone
89
from typing import Annotated, Any, Literal
910

@@ -157,7 +158,7 @@ def validate_output(ctx: RunContext[None], output: OutputModel) -> OutputModel:
157158
call_count += 1
158159
raise ModelRetry('Fail')
159160

160-
with pytest.raises(UnexpectedModelBehavior, match="Tool 'final_result' exceeded max retries count of 2"):
161+
with pytest.raises(UnexpectedModelBehavior, match=re.escape('Exceeded maximum retries (2) for result validation')):
161162
agent.run_sync('Hello', model=TestModel())
162163

163164
assert call_count == 3
@@ -200,7 +201,7 @@ class ResultModel(BaseModel):
200201

201202
agent = Agent('test', output_type=ResultModel, retries=2)
202203

203-
with pytest.raises(UnexpectedModelBehavior, match="Tool 'final_result' exceeded max retries count of 2"):
204+
with pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(2\) for result validation'):
204205
agent.run_sync('Hello', model=TestModel(custom_output_args={'foo': 'a', 'bar': 1}))
205206

206207

tests/test_examples.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -270,10 +270,10 @@ async def __aexit__(self, *args: Any) -> None:
270270
def tool_defs(self) -> list[ToolDefinition]:
271271
return []
272272

273-
def get_tool_args_validator(self, ctx: RunContext[Any], name: str) -> SchemaValidator:
273+
def _get_tool_args_validator(self, ctx: RunContext[Any], name: str) -> SchemaValidator:
274274
return SchemaValidator(core_schema.any_schema()) # pragma: lax no cover
275275

276-
def max_retries_for_tool(self, name: str) -> int:
276+
def _max_retries_for_tool(self, name: str) -> int:
277277
return 0 # pragma: lax no cover
278278

279279
async def call_tool(

0 commit comments

Comments
 (0)