1
1
from abc import ABC
2
2
from enum import Enum
3
- from typing import Any , Dict , List , Callable , Optional
3
+ from typing import Any , Dict , List , Callable , Optional , Union , ClassVar , Type
4
+ from pydantic import BaseModel , Field , model_validator
4
5
5
6
6
7
class ParameterType (str , Enum ):
@@ -14,56 +15,61 @@ class ParameterType(str, Enum):
14
15
ARRAY = "array"
15
16
16
17
17
- class ToolParameter :
18
- """Base class for all tool parameters."""
19
-
20
- def __init__ (self , description : str , required : bool = False ):
21
- self .description = description
22
- self .required = required
23
-
24
- def to_dict (self ) -> Dict [str , Any ]:
25
- """Convert the parameter to a dictionary format."""
26
- raise NotImplementedError ("Subclasses must implement to_dict" )
18
+ class ToolParameter (BaseModel ):
19
+ """Base class for all tool parameters using Pydantic."""
20
+ description : str
21
+ required : bool = False
22
+ type : ClassVar [ParameterType ]
23
+
24
+ def model_dump_tool (self ) -> Dict [str , Any ]:
25
+ """Convert the parameter to a dictionary format for tool usage."""
26
+ result = {"type" : self .type , "description" : self .description }
27
+ return result
28
+
29
+ @classmethod
30
+ def from_dict (cls , data : Dict [str , Any ]) -> "ToolParameter" :
31
+ """Create a parameter from a dictionary."""
32
+ param_type = data .get ("type" )
33
+ if not param_type :
34
+ raise ValueError ("Parameter type is required" )
35
+
36
+ # Find the appropriate class based on the type
37
+ param_classes = {
38
+ ParameterType .STRING : StringParameter ,
39
+ ParameterType .INTEGER : IntegerParameter ,
40
+ ParameterType .NUMBER : NumberParameter ,
41
+ ParameterType .BOOLEAN : BooleanParameter ,
42
+ ParameterType .OBJECT : ObjectParameter ,
43
+ ParameterType .ARRAY : ArrayParameter ,
44
+ }
45
+
46
+ param_class = param_classes .get (param_type )
47
+ if not param_class :
48
+ raise ValueError (f"Unknown parameter type: { param_type } " )
49
+
50
+ return param_class .model_validate (data )
27
51
28
52
29
53
class StringParameter (ToolParameter ):
30
54
"""String parameter for tools."""
31
-
32
- def __init__ (
33
- self , description : str , required : bool = False , enum : Optional [List [str ]] = None
34
- ):
35
- super ().__init__ (description , required )
36
- self .enum = enum
37
-
38
- def to_dict (self ) -> Dict [str , Any ]:
39
- result : Dict [str , Any ] = {
40
- "type" : ParameterType .STRING ,
41
- "description" : self .description ,
42
- }
55
+ type : ClassVar [ParameterType ] = ParameterType .STRING
56
+ enum : Optional [List [str ]] = None
57
+
58
+ def model_dump_tool (self ) -> Dict [str , Any ]:
59
+ result = super ().model_dump_tool ()
43
60
if self .enum :
44
61
result ["enum" ] = self .enum
45
62
return result
46
63
47
64
48
65
class IntegerParameter (ToolParameter ):
49
66
"""Integer parameter for tools."""
50
-
51
- def __init__ (
52
- self ,
53
- description : str ,
54
- required : bool = False ,
55
- minimum : Optional [int ] = None ,
56
- maximum : Optional [int ] = None ,
57
- ):
58
- super ().__init__ (description , required )
59
- self .minimum = minimum
60
- self .maximum = maximum
61
-
62
- def to_dict (self ) -> Dict [str , Any ]:
63
- result : Dict [str , Any ] = {
64
- "type" : ParameterType .INTEGER ,
65
- "description" : self .description ,
66
- }
67
+ type : ClassVar [ParameterType ] = ParameterType .INTEGER
68
+ minimum : Optional [int ] = None
69
+ maximum : Optional [int ] = None
70
+
71
+ def model_dump_tool (self ) -> Dict [str , Any ]:
72
+ result = super ().model_dump_tool ()
67
73
if self .minimum is not None :
68
74
result ["minimum" ] = self .minimum
69
75
if self .maximum is not None :
@@ -73,23 +79,12 @@ def to_dict(self) -> Dict[str, Any]:
73
79
74
80
class NumberParameter (ToolParameter ):
75
81
"""Number parameter for tools."""
76
-
77
- def __init__ (
78
- self ,
79
- description : str ,
80
- required : bool = False ,
81
- minimum : Optional [float ] = None ,
82
- maximum : Optional [float ] = None ,
83
- ):
84
- super ().__init__ (description , required )
85
- self .minimum = minimum
86
- self .maximum = maximum
87
-
88
- def to_dict (self ) -> Dict [str , Any ]:
89
- result : Dict [str , Any ] = {
90
- "type" : ParameterType .NUMBER ,
91
- "description" : self .description ,
92
- }
82
+ type : ClassVar [ParameterType ] = ParameterType .NUMBER
83
+ minimum : Optional [float ] = None
84
+ maximum : Optional [float ] = None
85
+
86
+ def model_dump_tool (self ) -> Dict [str , Any ]:
87
+ result = super ().model_dump_tool ()
93
88
if self .minimum is not None :
94
89
result ["minimum" ] = self .minimum
95
90
if self .maximum is not None :
@@ -99,77 +94,71 @@ def to_dict(self) -> Dict[str, Any]:
99
94
100
95
class BooleanParameter (ToolParameter ):
101
96
"""Boolean parameter for tools."""
97
+ type : ClassVar [ParameterType ] = ParameterType .BOOLEAN
98
+
102
99
103
- def to_dict (self ) -> Dict [str , Any ]:
104
- return {"type" : ParameterType .BOOLEAN , "description" : self .description }
100
+ class ArrayParameter (ToolParameter ):
101
+ """Array parameter for tools."""
102
+ type : ClassVar [ParameterType ] = ParameterType .ARRAY
103
+ items : "ToolParameter"
104
+ min_items : Optional [int ] = None
105
+ max_items : Optional [int ] = None
106
+
107
+ def model_dump_tool (self ) -> Dict [str , Any ]:
108
+ result = super ().model_dump_tool ()
109
+ result ["items" ] = self .items .model_dump_tool ()
110
+ if self .min_items is not None :
111
+ result ["minItems" ] = self .min_items
112
+ if self .max_items is not None :
113
+ result ["maxItems" ] = self .max_items
114
+ return result
115
+
116
+ @model_validator (mode = "after" )
117
+ def validate_items (self ) -> "ArrayParameter" :
118
+ if not isinstance (self .items , ToolParameter ):
119
+ if isinstance (self .items , dict ):
120
+ self .items = ToolParameter .from_dict (self .items )
121
+ else :
122
+ raise ValueError (f"Items must be a ToolParameter or dict, got { type (self .items )} " )
123
+ return self
105
124
106
125
107
126
class ObjectParameter (ToolParameter ):
108
127
"""Object parameter for tools."""
109
-
110
- def __init__ (
111
- self ,
112
- description : str ,
113
- properties : Dict [str , ToolParameter ],
114
- required : bool = False ,
115
- required_properties : Optional [List [str ]] = None ,
116
- additional_properties : bool = True ,
117
- ):
118
- super ().__init__ (description , required )
119
- self .properties = properties
120
- self .required_properties = required_properties or []
121
- self .additional_properties = additional_properties
122
-
123
- def to_dict (self ) -> Dict [str , Any ]:
128
+ type : ClassVar [ParameterType ] = ParameterType .OBJECT
129
+ properties : Dict [str , ToolParameter ]
130
+ required_properties : List [str ] = Field (default_factory = list )
131
+ additional_properties : bool = True
132
+
133
+ def model_dump_tool (self ) -> Dict [str , Any ]:
124
134
properties_dict : Dict [str , Any ] = {}
125
135
for name , param in self .properties .items ():
126
- properties_dict [name ] = param .to_dict ()
127
-
128
- result : Dict [str , Any ] = {
129
- "type" : ParameterType .OBJECT ,
130
- "description" : self .description ,
131
- "properties" : properties_dict ,
132
- }
136
+ properties_dict [name ] = param .model_dump_tool ()
133
137
138
+ result = super ().model_dump_tool ()
139
+ result ["properties" ] = properties_dict
140
+
134
141
if self .required_properties :
135
142
result ["required" ] = self .required_properties
136
-
143
+
137
144
if not self .additional_properties :
138
145
result ["additionalProperties" ] = False
139
-
140
- return result
141
-
142
-
143
- class ArrayParameter (ToolParameter ):
144
- """Array parameter for tools."""
145
-
146
- def __init__ (
147
- self ,
148
- description : str ,
149
- items : ToolParameter ,
150
- required : bool = False ,
151
- min_items : Optional [int ] = None ,
152
- max_items : Optional [int ] = None ,
153
- ):
154
- super ().__init__ (description , required )
155
- self .items = items
156
- self .min_items = min_items
157
- self .max_items = max_items
158
-
159
- def to_dict (self ) -> Dict [str , Any ]:
160
- result : Dict [str , Any ] = {
161
- "type" : ParameterType .ARRAY ,
162
- "description" : self .description ,
163
- "items" : self .items .to_dict (),
164
- }
165
-
166
- if self .min_items is not None :
167
- result ["minItems" ] = self .min_items
168
-
169
- if self .max_items is not None :
170
- result ["maxItems" ] = self .max_items
171
-
146
+
172
147
return result
148
+
149
+ @model_validator (mode = "after" )
150
+ def validate_properties (self ) -> "ObjectParameter" :
151
+ validated_properties = {}
152
+ for name , param in self .properties .items ():
153
+ if not isinstance (param , ToolParameter ):
154
+ if isinstance (param , dict ):
155
+ validated_properties [name ] = ToolParameter .from_dict (param )
156
+ else :
157
+ raise ValueError (f"Property { name } must be a ToolParameter or dict, got { type (param )} " )
158
+ else :
159
+ validated_properties [name ] = param
160
+ self .properties = validated_properties
161
+ return self
173
162
174
163
175
164
class Tool (ABC ):
@@ -179,12 +168,18 @@ def __init__(
179
168
self ,
180
169
name : str ,
181
170
description : str ,
182
- parameters : ObjectParameter ,
171
+ parameters : Union [ ObjectParameter , Dict [ str , Any ]] ,
183
172
execute_func : Callable [..., Any ],
184
173
):
185
174
self ._name = name
186
175
self ._description = description
187
- self ._parameters = parameters
176
+
177
+ # Allow parameters to be provided as a dictionary
178
+ if isinstance (parameters , dict ):
179
+ self ._parameters = ObjectParameter .model_validate (parameters )
180
+ else :
181
+ self ._parameters = parameters
182
+
188
183
self ._execute_func = execute_func
189
184
190
185
def get_name (self ) -> str :
@@ -209,7 +204,7 @@ def get_parameters(self) -> Dict[str, Any]:
209
204
Returns:
210
205
Dict[str, Any]: Dictionary containing parameter schema information.
211
206
"""
212
- return self ._parameters .to_dict ()
207
+ return self ._parameters .model_dump_tool ()
213
208
214
209
def execute (self , query : str , ** kwargs : Any ) -> Any :
215
210
"""Execute the tool with the given query and additional parameters.
0 commit comments