2
2
3
3
import asyncio
4
4
from abc import ABC , abstractmethod
5
- from collections .abc import Awaitable , Sequence
6
- from contextlib import AsyncExitStack
5
+ from collections .abc import Awaitable , Iterator , Sequence
6
+ from contextlib import AsyncExitStack , contextmanager
7
7
from dataclasses import dataclass , field , replace
8
8
from functools import partial
9
9
from types import TracebackType
10
- from typing import TYPE_CHECKING , Any , Callable , Generic , Literal , Protocol , overload
10
+ from typing import TYPE_CHECKING , Any , Callable , Generic , Literal , Protocol , assert_never , overload
11
11
12
12
from pydantic import ValidationError
13
13
from pydantic .json_schema import GenerateJsonSchema
14
14
from pydantic_core import SchemaValidator
15
- from typing_extensions import Never , Self
15
+ from typing_extensions import Self
16
16
17
- from ._output import BaseOutputSchema , OutputValidator
17
+ from . import messages as _messages
18
+ from ._output import BaseOutputSchema , OutputValidator , ToolRetryError
18
19
from ._run_context import AgentDepsT , RunContext
19
20
from .exceptions import ModelRetry , UnexpectedModelBehavior , UserError
20
21
from .tools import (
@@ -70,21 +71,21 @@ def tool_names(self) -> list[str]:
70
71
return [tool_def .name for tool_def in self .tool_defs ]
71
72
72
73
@abstractmethod
73
- def get_tool_args_validator (self , ctx : RunContext [AgentDepsT ], name : str ) -> SchemaValidator :
74
+ def _get_tool_args_validator (self , ctx : RunContext [AgentDepsT ], name : str ) -> SchemaValidator :
74
75
raise NotImplementedError ()
75
76
76
77
def validate_tool_args (
77
78
self , ctx : RunContext [AgentDepsT ], name : str , args : str | dict [str , Any ] | None , allow_partial : bool = False
78
79
) -> dict [str , Any ]:
79
80
pyd_allow_partial : Literal ['off' , 'trailing-strings' ] = 'trailing-strings' if allow_partial else 'off'
80
- validator = self .get_tool_args_validator (ctx , name )
81
+ validator = self ._get_tool_args_validator (ctx , name )
81
82
if isinstance (args , str ):
82
83
return validator .validate_json (args or '{}' , allow_partial = pyd_allow_partial )
83
84
else :
84
85
return validator .validate_python (args or {}, allow_partial = pyd_allow_partial )
85
86
86
87
@abstractmethod
87
- def max_retries_for_tool (self , name : str ) -> int :
88
+ def _max_retries_for_tool (self , name : str ) -> int :
88
89
raise NotImplementedError ()
89
90
90
91
@abstractmethod
@@ -273,10 +274,10 @@ async def _prepare_tool_def(self, ctx: RunContext[AgentDepsT], tool_def: ToolDef
273
274
def tool_defs (self ) -> list [ToolDefinition ]:
274
275
return [tool .tool_def for tool in self .tools .values ()]
275
276
276
- def get_tool_args_validator (self , ctx : RunContext [AgentDepsT ], name : str ) -> SchemaValidator :
277
+ def _get_tool_args_validator (self , ctx : RunContext [AgentDepsT ], name : str ) -> SchemaValidator :
277
278
return self .tools [name ].function_schema .validator
278
279
279
- def max_retries_for_tool (self , name : str ) -> int :
280
+ def _max_retries_for_tool (self , name : str ) -> int :
280
281
tool = self .tools [name ]
281
282
return tool .max_retries if tool .max_retries is not None else self .max_retries
282
283
@@ -298,10 +299,10 @@ class OutputToolset(AbstractToolset[AgentDepsT]):
298
299
def tool_defs (self ) -> list [ToolDefinition ]:
299
300
return [tool .tool_def for tool in self .output_schema .tools .values ()]
300
301
301
- def get_tool_args_validator (self , ctx : RunContext [AgentDepsT ], name : str ) -> SchemaValidator :
302
+ def _get_tool_args_validator (self , ctx : RunContext [AgentDepsT ], name : str ) -> SchemaValidator :
302
303
return self .output_schema .tools [name ].processor .validator
303
304
304
- def max_retries_for_tool (self , name : str ) -> int :
305
+ def _max_retries_for_tool (self , name : str ) -> int :
305
306
return self .max_retries
306
307
307
308
async def call_tool (
@@ -365,16 +366,16 @@ def tool_defs(self) -> list[ToolDefinition]:
365
366
def tool_names (self ) -> list [str ]:
366
367
return list (self ._toolset_per_tool_name .keys ())
367
368
368
- def get_tool_args_validator (self , ctx : RunContext [AgentDepsT ], name : str ) -> SchemaValidator :
369
- return self ._toolset_for_tool_name (name ).get_tool_args_validator (ctx , name )
369
+ def _get_tool_args_validator (self , ctx : RunContext [AgentDepsT ], name : str ) -> SchemaValidator :
370
+ return self ._toolset_for_tool_name (name )._get_tool_args_validator (ctx , name )
370
371
371
372
def validate_tool_args (
372
373
self , ctx : RunContext [AgentDepsT ], name : str , args : str | dict [str , Any ] | None , allow_partial : bool = False
373
374
) -> dict [str , Any ]:
374
375
return self ._toolset_for_tool_name (name ).validate_tool_args (ctx , name , args , allow_partial )
375
376
376
- def max_retries_for_tool (self , name : str ) -> int :
377
- return self ._toolset_for_tool_name (name ).max_retries_for_tool (name )
377
+ def _max_retries_for_tool (self , name : str ) -> int :
378
+ return self ._toolset_for_tool_name (name )._max_retries_for_tool (name )
378
379
379
380
async def call_tool (
380
381
self , ctx : RunContext [AgentDepsT ], name : str , tool_args : dict [str , Any ], * args : Any , ** kwargs : Any
@@ -419,11 +420,11 @@ async def __aexit__(
419
420
def tool_defs (self ) -> list [ToolDefinition ]:
420
421
return self .wrapped .tool_defs
421
422
422
- def get_tool_args_validator (self , ctx : RunContext [AgentDepsT ], name : str ) -> SchemaValidator :
423
- return self .wrapped .get_tool_args_validator (ctx , name )
423
+ def _get_tool_args_validator (self , ctx : RunContext [AgentDepsT ], name : str ) -> SchemaValidator :
424
+ return self .wrapped ._get_tool_args_validator (ctx , name )
424
425
425
- def max_retries_for_tool (self , name : str ) -> int :
426
- return self .wrapped .max_retries_for_tool (name )
426
+ def _max_retries_for_tool (self , name : str ) -> int :
427
+ return self .wrapped ._max_retries_for_tool (name )
427
428
428
429
async def call_tool (
429
430
self , ctx : RunContext [AgentDepsT ], name : str , tool_args : dict [str , Any ], * args : Any , ** kwargs : Any
@@ -452,11 +453,11 @@ async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[Agent
452
453
def tool_defs (self ) -> list [ToolDefinition ]:
453
454
return [replace (tool_def , name = self ._prefixed_tool_name (tool_def .name )) for tool_def in super ().tool_defs ]
454
455
455
- def get_tool_args_validator (self , ctx : RunContext [AgentDepsT ], name : str ) -> SchemaValidator :
456
- return super ().get_tool_args_validator (ctx , self ._unprefixed_tool_name (name ))
456
+ def _get_tool_args_validator (self , ctx : RunContext [AgentDepsT ], name : str ) -> SchemaValidator :
457
+ return super ()._get_tool_args_validator (ctx , self ._unprefixed_tool_name (name ))
457
458
458
- def max_retries_for_tool (self , name : str ) -> int :
459
- return super ().max_retries_for_tool (self ._unprefixed_tool_name (name ))
459
+ def _max_retries_for_tool (self , name : str ) -> int :
460
+ return super ()._max_retries_for_tool (self ._unprefixed_tool_name (name ))
460
461
461
462
async def call_tool (
462
463
self , ctx : RunContext [AgentDepsT ], name : str , tool_args : dict [str , Any ], * args : Any , ** kwargs : Any
@@ -519,11 +520,11 @@ async def prepare_for_run(self, ctx: RunContext[AgentDepsT]) -> RunToolset[Agent
519
520
def tool_defs (self ) -> list [ToolDefinition ]:
520
521
return self ._tool_defs
521
522
522
- def get_tool_args_validator (self , ctx : RunContext [AgentDepsT ], name : str ) -> SchemaValidator :
523
- return super ().get_tool_args_validator (ctx , self ._map_name (name ))
523
+ def _get_tool_args_validator (self , ctx : RunContext [AgentDepsT ], name : str ) -> SchemaValidator :
524
+ return super ()._get_tool_args_validator (ctx , self ._map_name (name ))
524
525
525
- def max_retries_for_tool (self , name : str ) -> int :
526
- return super ().max_retries_for_tool (self ._map_name (name ))
526
+ def _max_retries_for_tool (self , name : str ) -> int :
527
+ return super ()._max_retries_for_tool (self ._map_name (name ))
527
528
528
529
async def call_tool (
529
530
self , ctx : RunContext [AgentDepsT ], name : str , tool_args : dict [str , Any ], * args : Any , ** kwargs : Any
@@ -660,40 +661,66 @@ def tool_names(self) -> list[str]:
660
661
def validate_tool_args (
661
662
self , ctx : RunContext [AgentDepsT ], name : str , args : str | dict [str , Any ] | None , allow_partial : bool = False
662
663
) -> dict [str , Any ]:
663
- try :
664
- self ._validate_tool_name (name )
665
-
666
- ctx = replace (ctx , tool_name = name , retry = self ._retries .get (name , 0 ))
664
+ with self ._with_retry (name , ctx ) as ctx :
667
665
return super ().validate_tool_args (ctx , name , args , allow_partial )
668
- except ValidationError as e :
669
- return self ._on_error (name , e )
670
666
671
667
async def call_tool (
672
668
self , ctx : RunContext [AgentDepsT ], name : str , tool_args : dict [str , Any ], * args : Any , ** kwargs : Any
673
669
) -> Any :
670
+ with self ._with_retry (name , ctx ) as ctx :
671
+ try :
672
+ output = await super ().call_tool (ctx , name , tool_args , * args , ** kwargs )
673
+ except Exception as e :
674
+ raise e
675
+ else :
676
+ self ._retries .pop (name , None )
677
+ return output
678
+
679
+ @contextmanager
680
+ def _with_retry (self , name : str , ctx : RunContext [AgentDepsT ]) -> Iterator [RunContext [AgentDepsT ]]:
674
681
try :
675
- self ._validate_tool_name (name )
676
-
677
- ctx = replace (ctx , tool_name = name , retry = self ._retries .get (name , 0 ))
678
- return await super ().call_tool (ctx , name , tool_args , * args , ** kwargs )
679
- except ModelRetry as e :
680
- return self ._on_error (name , e )
681
-
682
- def _on_error (self , name : str , e : Exception ) -> Never :
683
- max_retries = self .max_retries_for_tool (name )
684
- current_retry = self ._retries .get (name , 0 )
685
- if current_retry == max_retries :
686
- raise UnexpectedModelBehavior (f'Tool { name !r} exceeded max retries count of { max_retries } ' ) from e
687
- else :
688
- self ._retries [name ] = current_retry + 1 # TODO: Reset on successful call!
689
- raise e
682
+ if name not in self .tool_names :
683
+ if self .tool_names :
684
+ msg = f'Available tools: { ", " .join (self .tool_names )} '
685
+ else :
686
+ msg = 'No tools available.'
687
+ raise ModelRetry (f'Unknown tool name: { name !r} . { msg } ' )
688
+
689
+ ctx = replace (ctx , tool_name = name , retry = self ._retries .get (name , 0 ), retries = {})
690
+ yield ctx
691
+ except (ValidationError , ModelRetry , UnexpectedModelBehavior , ToolRetryError ) as e :
692
+ if isinstance (e , ToolRetryError ):
693
+ pass
694
+ elif isinstance (e , ValidationError ):
695
+ if ctx .tool_call_id :
696
+ m = _messages .RetryPromptPart (
697
+ tool_name = name ,
698
+ content = e .errors (include_url = False , include_context = False ),
699
+ tool_call_id = ctx .tool_call_id ,
700
+ )
701
+ e = ToolRetryError (m )
702
+ elif isinstance (e , ModelRetry ):
703
+ if ctx .tool_call_id :
704
+ m = _messages .RetryPromptPart (
705
+ tool_name = name ,
706
+ content = e .message ,
707
+ tool_call_id = ctx .tool_call_id ,
708
+ )
709
+ e = ToolRetryError (m )
710
+ elif isinstance (e , UnexpectedModelBehavior ):
711
+ if e .__cause__ is not None :
712
+ e = e .__cause__
713
+ else :
714
+ assert_never (e )
690
715
691
- def _validate_tool_name (self , name : str ) -> None :
692
- if name in self .tool_names :
693
- return
716
+ try :
717
+ max_retries = self ._max_retries_for_tool (name )
718
+ except Exception :
719
+ max_retries = 1
720
+ current_retry = self ._retries .get (name , 0 )
694
721
695
- if self . tool_names :
696
- msg = f'Available tools: { ", " . join ( self . tool_names ) } '
697
- else :
698
- msg = 'No tools available.'
699
- raise ModelRetry ( f'Unknown tool name: { name !r } . { msg } ' )
722
+ if current_retry == max_retries :
723
+ raise UnexpectedModelBehavior ( f'Tool { name !r } exceeded max retries count of { max_retries } ' ) from e
724
+ else :
725
+ self . _retries [ name ] = current_retry + 1
726
+ raise e
0 commit comments