|
7 | 7 | from dataclasses import dataclass, field, replace
|
8 | 8 | from functools import partial
|
9 | 9 | from types import TracebackType
|
10 |
| -from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Protocol, assert_never, overload |
| 10 | +from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Protocol, overload |
11 | 11 |
|
12 | 12 | from pydantic import ValidationError
|
13 | 13 | from pydantic.json_schema import GenerateJsonSchema
|
@@ -689,38 +689,33 @@ def _with_retry(self, name: str, ctx: RunContext[AgentDepsT]) -> Iterator[RunCon
|
689 | 689 | ctx = replace(ctx, tool_name=name, retry=self._retries.get(name, 0), retries={})
|
690 | 690 | yield ctx
|
691 | 691 | except (ValidationError, ModelRetry, UnexpectedModelBehavior, ToolRetryError) as e:
|
692 |
| - if isinstance(e, ToolRetryError): |
693 |
| - pass |
694 |
| - elif isinstance(e, ValidationError): |
695 |
| - if ctx.tool_call_id: |
696 |
| - m = _messages.RetryPromptPart( |
697 |
| - tool_name=name, |
698 |
| - content=e.errors(include_url=False, include_context=False), |
699 |
| - tool_call_id=ctx.tool_call_id, |
700 |
| - ) |
701 |
| - e = ToolRetryError(m) |
702 |
| - elif isinstance(e, ModelRetry): |
703 |
| - if ctx.tool_call_id: |
704 |
| - m = _messages.RetryPromptPart( |
705 |
| - tool_name=name, |
706 |
| - content=e.message, |
707 |
| - tool_call_id=ctx.tool_call_id, |
708 |
| - ) |
709 |
| - e = ToolRetryError(m) |
710 |
| - elif isinstance(e, UnexpectedModelBehavior): |
711 |
| - if e.__cause__ is not None: |
712 |
| - e = e.__cause__ |
713 |
| - else: |
714 |
| - assert_never(e) |
715 |
| - |
716 | 692 | try:
|
717 | 693 | max_retries = self._max_retries_for_tool(name)
|
718 | 694 | except Exception:
|
719 | 695 | max_retries = 1
|
720 | 696 | current_retry = self._retries.get(name, 0)
|
721 | 697 |
|
| 698 | + if isinstance(e, UnexpectedModelBehavior) and e.__cause__ is not None: |
| 699 | + e = e.__cause__ |
| 700 | + |
722 | 701 | if current_retry == max_retries:
|
723 | 702 | raise UnexpectedModelBehavior(f'Tool {name!r} exceeded max retries count of {max_retries}') from e
|
724 | 703 | else:
|
| 704 | + if ctx.tool_call_id: |
| 705 | + if isinstance(e, ValidationError): |
| 706 | + m = _messages.RetryPromptPart( |
| 707 | + tool_name=name, |
| 708 | + content=e.errors(include_url=False, include_context=False), |
| 709 | + tool_call_id=ctx.tool_call_id, |
| 710 | + ) |
| 711 | + e = ToolRetryError(m) |
| 712 | + elif isinstance(e, ModelRetry): |
| 713 | + m = _messages.RetryPromptPart( |
| 714 | + tool_name=name, |
| 715 | + content=e.message, |
| 716 | + tool_call_id=ctx.tool_call_id, |
| 717 | + ) |
| 718 | + e = ToolRetryError(m) |
| 719 | + |
725 | 720 | self._retries[name] = current_retry + 1
|
726 | 721 | raise e
|
0 commit comments