23
23
Mapping ,
24
24
Optional ,
25
25
Sequence ,
26
+ Type ,
26
27
Union ,
28
+ cast ,
27
29
)
28
30
29
31
from aiohttp import ClientSession
32
+ from pydantic import BaseModel , Field , create_model
30
33
31
34
from toolbox_core .protocol import ParameterSchema
32
35
@@ -78,6 +81,8 @@ def __init__(
78
81
self .__url = f"{ base_url } /api/tool/{ name } /invoke"
79
82
self .__description = description
80
83
self .__params = params
84
+ self .__pydantic_model = params_to_pydantic_model (name , self .__params )
85
+
81
86
inspect_type_params = [param .to_param () for param in self .__params ]
82
87
83
88
# the following properties are set to help anyone that might inspect it determine usage
@@ -86,6 +91,7 @@ def __init__(
86
91
self .__signature__ = Signature (
87
92
parameters = inspect_type_params , return_annotation = str
88
93
)
94
+
89
95
self .__annotations__ = {p .name : p .annotation for p in inspect_type_params }
90
96
# TODO: self.__qualname__ ??
91
97
@@ -170,6 +176,9 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
170
176
all_args .apply_defaults () # Include default values if not provided
171
177
payload = all_args .arguments
172
178
179
+ # Perform argument type validations using pydantic
180
+ self .__pydantic_model .model_validate (payload )
181
+
173
182
# apply bounded parameters
174
183
for param , value in self .__bound_parameters .items ():
175
184
if asyncio .iscoroutinefunction (value ):
@@ -305,3 +314,19 @@ def identify_required_authn_params(
305
314
if required :
306
315
required_params [param ] = services
307
316
return required_params
317
+
318
+
319
+ def params_to_pydantic_model (
320
+ tool_name : str , params : Sequence [ParameterSchema ]
321
+ ) -> Type [BaseModel ]:
322
+ """Converts the given parameters to a Pydantic BaseModel class."""
323
+ field_definitions = {}
324
+ for field in params :
325
+ field_definitions [field .name ] = cast (
326
+ Any ,
327
+ (
328
+ field .to_param ().annotation ,
329
+ Field (description = field .description ),
330
+ ),
331
+ )
332
+ return create_model (tool_name , ** field_definitions )
0 commit comments