Skip to content

Commit 31614cd

Browse files
committed
Use pydantic model for the tool parameters
1 parent 31c36dd commit 31614cd

File tree

1 file changed

+113
-118
lines changed

1 file changed

+113
-118
lines changed

src/neo4j_graphrag/tool.py

Lines changed: 113 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from abc import ABC
22
from enum import Enum
3-
from typing import Any, Dict, List, Callable, Optional
3+
from typing import Any, Dict, List, Callable, Optional, Union, ClassVar, Type
4+
from pydantic import BaseModel, Field, model_validator
45

56

67
class ParameterType(str, Enum):
@@ -14,56 +15,61 @@ class ParameterType(str, Enum):
1415
ARRAY = "array"
1516

1617

17-
class ToolParameter:
18-
"""Base class for all tool parameters."""
19-
20-
def __init__(self, description: str, required: bool = False):
21-
self.description = description
22-
self.required = required
23-
24-
def to_dict(self) -> Dict[str, Any]:
25-
"""Convert the parameter to a dictionary format."""
26-
raise NotImplementedError("Subclasses must implement to_dict")
18+
class ToolParameter(BaseModel):
19+
"""Base class for all tool parameters using Pydantic."""
20+
description: str
21+
required: bool = False
22+
type: ClassVar[ParameterType]
23+
24+
def model_dump_tool(self) -> Dict[str, Any]:
25+
"""Convert the parameter to a dictionary format for tool usage."""
26+
result = {"type": self.type, "description": self.description}
27+
return result
28+
29+
@classmethod
30+
def from_dict(cls, data: Dict[str, Any]) -> "ToolParameter":
31+
"""Create a parameter from a dictionary."""
32+
param_type = data.get("type")
33+
if not param_type:
34+
raise ValueError("Parameter type is required")
35+
36+
# Find the appropriate class based on the type
37+
param_classes = {
38+
ParameterType.STRING: StringParameter,
39+
ParameterType.INTEGER: IntegerParameter,
40+
ParameterType.NUMBER: NumberParameter,
41+
ParameterType.BOOLEAN: BooleanParameter,
42+
ParameterType.OBJECT: ObjectParameter,
43+
ParameterType.ARRAY: ArrayParameter,
44+
}
45+
46+
param_class = param_classes.get(param_type)
47+
if not param_class:
48+
raise ValueError(f"Unknown parameter type: {param_type}")
49+
50+
return param_class.model_validate(data)
2751

2852

2953
class StringParameter(ToolParameter):
3054
"""String parameter for tools."""
31-
32-
def __init__(
33-
self, description: str, required: bool = False, enum: Optional[List[str]] = None
34-
):
35-
super().__init__(description, required)
36-
self.enum = enum
37-
38-
def to_dict(self) -> Dict[str, Any]:
39-
result: Dict[str, Any] = {
40-
"type": ParameterType.STRING,
41-
"description": self.description,
42-
}
55+
type: ClassVar[ParameterType] = ParameterType.STRING
56+
enum: Optional[List[str]] = None
57+
58+
def model_dump_tool(self) -> Dict[str, Any]:
59+
result = super().model_dump_tool()
4360
if self.enum:
4461
result["enum"] = self.enum
4562
return result
4663

4764

4865
class IntegerParameter(ToolParameter):
4966
"""Integer parameter for tools."""
50-
51-
def __init__(
52-
self,
53-
description: str,
54-
required: bool = False,
55-
minimum: Optional[int] = None,
56-
maximum: Optional[int] = None,
57-
):
58-
super().__init__(description, required)
59-
self.minimum = minimum
60-
self.maximum = maximum
61-
62-
def to_dict(self) -> Dict[str, Any]:
63-
result: Dict[str, Any] = {
64-
"type": ParameterType.INTEGER,
65-
"description": self.description,
66-
}
67+
type: ClassVar[ParameterType] = ParameterType.INTEGER
68+
minimum: Optional[int] = None
69+
maximum: Optional[int] = None
70+
71+
def model_dump_tool(self) -> Dict[str, Any]:
72+
result = super().model_dump_tool()
6773
if self.minimum is not None:
6874
result["minimum"] = self.minimum
6975
if self.maximum is not None:
@@ -73,23 +79,12 @@ def to_dict(self) -> Dict[str, Any]:
7379

7480
class NumberParameter(ToolParameter):
7581
"""Number parameter for tools."""
76-
77-
def __init__(
78-
self,
79-
description: str,
80-
required: bool = False,
81-
minimum: Optional[float] = None,
82-
maximum: Optional[float] = None,
83-
):
84-
super().__init__(description, required)
85-
self.minimum = minimum
86-
self.maximum = maximum
87-
88-
def to_dict(self) -> Dict[str, Any]:
89-
result: Dict[str, Any] = {
90-
"type": ParameterType.NUMBER,
91-
"description": self.description,
92-
}
82+
type: ClassVar[ParameterType] = ParameterType.NUMBER
83+
minimum: Optional[float] = None
84+
maximum: Optional[float] = None
85+
86+
def model_dump_tool(self) -> Dict[str, Any]:
87+
result = super().model_dump_tool()
9388
if self.minimum is not None:
9489
result["minimum"] = self.minimum
9590
if self.maximum is not None:
@@ -99,77 +94,71 @@ def to_dict(self) -> Dict[str, Any]:
9994

10095
class BooleanParameter(ToolParameter):
10196
"""Boolean parameter for tools."""
97+
type: ClassVar[ParameterType] = ParameterType.BOOLEAN
98+
10299

103-
def to_dict(self) -> Dict[str, Any]:
104-
return {"type": ParameterType.BOOLEAN, "description": self.description}
100+
class ArrayParameter(ToolParameter):
101+
"""Array parameter for tools."""
102+
type: ClassVar[ParameterType] = ParameterType.ARRAY
103+
items: "ToolParameter"
104+
min_items: Optional[int] = None
105+
max_items: Optional[int] = None
106+
107+
def model_dump_tool(self) -> Dict[str, Any]:
108+
result = super().model_dump_tool()
109+
result["items"] = self.items.model_dump_tool()
110+
if self.min_items is not None:
111+
result["minItems"] = self.min_items
112+
if self.max_items is not None:
113+
result["maxItems"] = self.max_items
114+
return result
115+
116+
@model_validator(mode="after")
117+
def validate_items(self) -> "ArrayParameter":
118+
if not isinstance(self.items, ToolParameter):
119+
if isinstance(self.items, dict):
120+
self.items = ToolParameter.from_dict(self.items)
121+
else:
122+
raise ValueError(f"Items must be a ToolParameter or dict, got {type(self.items)}")
123+
return self
105124

106125

107126
class ObjectParameter(ToolParameter):
108127
"""Object parameter for tools."""
109-
110-
def __init__(
111-
self,
112-
description: str,
113-
properties: Dict[str, ToolParameter],
114-
required: bool = False,
115-
required_properties: Optional[List[str]] = None,
116-
additional_properties: bool = True,
117-
):
118-
super().__init__(description, required)
119-
self.properties = properties
120-
self.required_properties = required_properties or []
121-
self.additional_properties = additional_properties
122-
123-
def to_dict(self) -> Dict[str, Any]:
128+
type: ClassVar[ParameterType] = ParameterType.OBJECT
129+
properties: Dict[str, ToolParameter]
130+
required_properties: List[str] = Field(default_factory=list)
131+
additional_properties: bool = True
132+
133+
def model_dump_tool(self) -> Dict[str, Any]:
124134
properties_dict: Dict[str, Any] = {}
125135
for name, param in self.properties.items():
126-
properties_dict[name] = param.to_dict()
127-
128-
result: Dict[str, Any] = {
129-
"type": ParameterType.OBJECT,
130-
"description": self.description,
131-
"properties": properties_dict,
132-
}
136+
properties_dict[name] = param.model_dump_tool()
133137

138+
result = super().model_dump_tool()
139+
result["properties"] = properties_dict
140+
134141
if self.required_properties:
135142
result["required"] = self.required_properties
136-
143+
137144
if not self.additional_properties:
138145
result["additionalProperties"] = False
139-
140-
return result
141-
142-
143-
class ArrayParameter(ToolParameter):
144-
"""Array parameter for tools."""
145-
146-
def __init__(
147-
self,
148-
description: str,
149-
items: ToolParameter,
150-
required: bool = False,
151-
min_items: Optional[int] = None,
152-
max_items: Optional[int] = None,
153-
):
154-
super().__init__(description, required)
155-
self.items = items
156-
self.min_items = min_items
157-
self.max_items = max_items
158-
159-
def to_dict(self) -> Dict[str, Any]:
160-
result: Dict[str, Any] = {
161-
"type": ParameterType.ARRAY,
162-
"description": self.description,
163-
"items": self.items.to_dict(),
164-
}
165-
166-
if self.min_items is not None:
167-
result["minItems"] = self.min_items
168-
169-
if self.max_items is not None:
170-
result["maxItems"] = self.max_items
171-
146+
172147
return result
148+
149+
@model_validator(mode="after")
150+
def validate_properties(self) -> "ObjectParameter":
151+
validated_properties = {}
152+
for name, param in self.properties.items():
153+
if not isinstance(param, ToolParameter):
154+
if isinstance(param, dict):
155+
validated_properties[name] = ToolParameter.from_dict(param)
156+
else:
157+
raise ValueError(f"Property {name} must be a ToolParameter or dict, got {type(param)}")
158+
else:
159+
validated_properties[name] = param
160+
self.properties = validated_properties
161+
return self
173162

174163

175164
class Tool(ABC):
@@ -179,12 +168,18 @@ def __init__(
179168
self,
180169
name: str,
181170
description: str,
182-
parameters: ObjectParameter,
171+
parameters: Union[ObjectParameter, Dict[str, Any]],
183172
execute_func: Callable[..., Any],
184173
):
185174
self._name = name
186175
self._description = description
187-
self._parameters = parameters
176+
177+
# Allow parameters to be provided as a dictionary
178+
if isinstance(parameters, dict):
179+
self._parameters = ObjectParameter.model_validate(parameters)
180+
else:
181+
self._parameters = parameters
182+
188183
self._execute_func = execute_func
189184

190185
def get_name(self) -> str:
@@ -209,7 +204,7 @@ def get_parameters(self) -> Dict[str, Any]:
209204
Returns:
210205
Dict[str, Any]: Dictionary containing parameter schema information.
211206
"""
212-
return self._parameters.to_dict()
207+
return self._parameters.model_dump_tool()
213208

214209
def execute(self, query: str, **kwargs: Any) -> Any:
215210
"""Execute the tool with the given query and additional parameters.

0 commit comments

Comments
 (0)