Skip to content

Commit 74a56ae

Browse files
committed
Fix retry error wrapping
1 parent 84cd954 commit 74a56ae

File tree

3 files changed

+25
-28
lines changed

3 files changed

+25
-28
lines changed

docs/output.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,8 +200,8 @@ async def hand_off_to_sql_agent(ctx: RunContext, query: str) -> list[Row]:
200200
return output
201201
except UnexpectedModelBehavior as e:
202202
# Bubble up potentially retryable errors to the router agent
203-
if (cause := e.__cause__) and hasattr(cause, 'tool_retry'):
204-
raise ModelRetry(f'SQL agent failed: {cause.tool_retry.content}') from e
203+
if (cause := e.__cause__) and isinstance(cause, ModelRetry):
204+
raise ModelRetry(f'SQL agent failed: {cause.message}') from e
205205
else:
206206
raise
207207

pydantic_ai_slim/pydantic_ai/_agent_graph.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,13 @@ class GraphAgentState:
7979
retries: int
8080
run_step: int
8181

82-
def increment_retries(self, max_result_retries: int, error: Exception | None = None) -> None:
82+
def increment_retries(self, max_result_retries: int, error: BaseException | None = None) -> None:
8383
self.retries += 1
8484
if self.retries > max_result_retries:
8585
message = f'Exceeded maximum retries ({max_result_retries}) for result validation'
8686
if error:
87+
if isinstance(error, exceptions.UnexpectedModelBehavior) and error.__cause__ is not None:
88+
error = error.__cause__
8789
raise exceptions.UnexpectedModelBehavior(message) from error
8890
else:
8991
raise exceptions.UnexpectedModelBehavior(message)

pydantic_ai_slim/pydantic_ai/toolset.py

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from dataclasses import dataclass, field, replace
88
from functools import partial
99
from types import TracebackType
10-
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Protocol, assert_never, overload
10+
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Protocol, overload
1111

1212
from pydantic import ValidationError
1313
from pydantic.json_schema import GenerateJsonSchema
@@ -689,38 +689,33 @@ def _with_retry(self, name: str, ctx: RunContext[AgentDepsT]) -> Iterator[RunCon
689689
ctx = replace(ctx, tool_name=name, retry=self._retries.get(name, 0), retries={})
690690
yield ctx
691691
except (ValidationError, ModelRetry, UnexpectedModelBehavior, ToolRetryError) as e:
692-
if isinstance(e, ToolRetryError):
693-
pass
694-
elif isinstance(e, ValidationError):
695-
if ctx.tool_call_id:
696-
m = _messages.RetryPromptPart(
697-
tool_name=name,
698-
content=e.errors(include_url=False, include_context=False),
699-
tool_call_id=ctx.tool_call_id,
700-
)
701-
e = ToolRetryError(m)
702-
elif isinstance(e, ModelRetry):
703-
if ctx.tool_call_id:
704-
m = _messages.RetryPromptPart(
705-
tool_name=name,
706-
content=e.message,
707-
tool_call_id=ctx.tool_call_id,
708-
)
709-
e = ToolRetryError(m)
710-
elif isinstance(e, UnexpectedModelBehavior):
711-
if e.__cause__ is not None:
712-
e = e.__cause__
713-
else:
714-
assert_never(e)
715-
716692
try:
717693
max_retries = self._max_retries_for_tool(name)
718694
except Exception:
719695
max_retries = 1
720696
current_retry = self._retries.get(name, 0)
721697

698+
if isinstance(e, UnexpectedModelBehavior) and e.__cause__ is not None:
699+
e = e.__cause__
700+
722701
if current_retry == max_retries:
723702
raise UnexpectedModelBehavior(f'Tool {name!r} exceeded max retries count of {max_retries}') from e
724703
else:
704+
if ctx.tool_call_id:
705+
if isinstance(e, ValidationError):
706+
m = _messages.RetryPromptPart(
707+
tool_name=name,
708+
content=e.errors(include_url=False, include_context=False),
709+
tool_call_id=ctx.tool_call_id,
710+
)
711+
e = ToolRetryError(m)
712+
elif isinstance(e, ModelRetry):
713+
m = _messages.RetryPromptPart(
714+
tool_name=name,
715+
content=e.message,
716+
tool_call_id=ctx.tool_call_id,
717+
)
718+
e = ToolRetryError(m)
719+
725720
self._retries[name] = current_retry + 1
726721
raise e

0 commit comments

Comments
 (0)