Skip to content

Commit da6aa51

Browse files
committed
Fix formatting errors
1 parent 7820e37 commit da6aa51

File tree

2 files changed

+32
-21
lines changed

2 files changed

+32
-21
lines changed

src/neo4j_graphrag/tool.py

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from abc import ABC
22
from enum import Enum
3-
from typing import Any, Dict, List, Callable, Optional, Union, ClassVar, Type
3+
from typing import Any, Dict, List, Callable, Optional, Union, ClassVar
44
from pydantic import BaseModel, Field, model_validator
55

66

@@ -17,22 +17,23 @@ class ParameterType(str, Enum):
1717

1818
class ToolParameter(BaseModel):
1919
"""Base class for all tool parameters using Pydantic."""
20+
2021
description: str
2122
required: bool = False
2223
type: ClassVar[ParameterType]
23-
24+
2425
def model_dump_tool(self) -> Dict[str, Any]:
2526
"""Convert the parameter to a dictionary format for tool usage."""
2627
result = {"type": self.type, "description": self.description}
2728
return result
28-
29+
2930
@classmethod
3031
def from_dict(cls, data: Dict[str, Any]) -> "ToolParameter":
3132
"""Create a parameter from a dictionary."""
3233
param_type = data.get("type")
3334
if not param_type:
3435
raise ValueError("Parameter type is required")
35-
36+
3637
# Find the appropriate class based on the type
3738
param_classes = {
3839
ParameterType.STRING: StringParameter,
@@ -42,19 +43,20 @@ def from_dict(cls, data: Dict[str, Any]) -> "ToolParameter":
4243
ParameterType.OBJECT: ObjectParameter,
4344
ParameterType.ARRAY: ArrayParameter,
4445
}
45-
46+
4647
param_class = param_classes.get(param_type)
4748
if not param_class:
4849
raise ValueError(f"Unknown parameter type: {param_type}")
49-
50+
5051
return param_class.model_validate(data)
5152

5253

5354
class StringParameter(ToolParameter):
5455
"""String parameter for tools."""
56+
5557
type: ClassVar[ParameterType] = ParameterType.STRING
5658
enum: Optional[List[str]] = None
57-
59+
5860
def model_dump_tool(self) -> Dict[str, Any]:
5961
result = super().model_dump_tool()
6062
if self.enum:
@@ -64,10 +66,11 @@ def model_dump_tool(self) -> Dict[str, Any]:
6466

6567
class IntegerParameter(ToolParameter):
6668
"""Integer parameter for tools."""
69+
6770
type: ClassVar[ParameterType] = ParameterType.INTEGER
6871
minimum: Optional[int] = None
6972
maximum: Optional[int] = None
70-
73+
7174
def model_dump_tool(self) -> Dict[str, Any]:
7275
result = super().model_dump_tool()
7376
if self.minimum is not None:
@@ -79,10 +82,11 @@ def model_dump_tool(self) -> Dict[str, Any]:
7982

8083
class NumberParameter(ToolParameter):
8184
"""Number parameter for tools."""
85+
8286
type: ClassVar[ParameterType] = ParameterType.NUMBER
8387
minimum: Optional[float] = None
8488
maximum: Optional[float] = None
85-
89+
8690
def model_dump_tool(self) -> Dict[str, Any]:
8791
result = super().model_dump_tool()
8892
if self.minimum is not None:
@@ -94,16 +98,18 @@ def model_dump_tool(self) -> Dict[str, Any]:
9498

9599
class BooleanParameter(ToolParameter):
96100
"""Boolean parameter for tools."""
101+
97102
type: ClassVar[ParameterType] = ParameterType.BOOLEAN
98103

99104

100105
class ArrayParameter(ToolParameter):
101106
"""Array parameter for tools."""
107+
102108
type: ClassVar[ParameterType] = ParameterType.ARRAY
103109
items: "ToolParameter"
104110
min_items: Optional[int] = None
105111
max_items: Optional[int] = None
106-
112+
107113
def model_dump_tool(self) -> Dict[str, Any]:
108114
result = super().model_dump_tool()
109115
result["items"] = self.items.model_dump_tool()
@@ -112,40 +118,43 @@ def model_dump_tool(self) -> Dict[str, Any]:
112118
if self.max_items is not None:
113119
result["maxItems"] = self.max_items
114120
return result
115-
121+
116122
@model_validator(mode="after")
117123
def validate_items(self) -> "ArrayParameter":
118124
if not isinstance(self.items, ToolParameter):
119125
if isinstance(self.items, dict):
120126
self.items = ToolParameter.from_dict(self.items)
121127
else:
122-
raise ValueError(f"Items must be a ToolParameter or dict, got {type(self.items)}")
128+
raise ValueError(
129+
f"Items must be a ToolParameter or dict, got {type(self.items)}"
130+
)
123131
return self
124132

125133

126134
class ObjectParameter(ToolParameter):
127135
"""Object parameter for tools."""
136+
128137
type: ClassVar[ParameterType] = ParameterType.OBJECT
129138
properties: Dict[str, ToolParameter]
130139
required_properties: List[str] = Field(default_factory=list)
131140
additional_properties: bool = True
132-
141+
133142
def model_dump_tool(self) -> Dict[str, Any]:
134143
properties_dict: Dict[str, Any] = {}
135144
for name, param in self.properties.items():
136145
properties_dict[name] = param.model_dump_tool()
137146

138147
result = super().model_dump_tool()
139148
result["properties"] = properties_dict
140-
149+
141150
if self.required_properties:
142151
result["required"] = self.required_properties
143-
152+
144153
if not self.additional_properties:
145154
result["additionalProperties"] = False
146-
155+
147156
return result
148-
157+
149158
@model_validator(mode="after")
150159
def validate_properties(self) -> "ObjectParameter":
151160
validated_properties = {}
@@ -154,7 +163,9 @@ def validate_properties(self) -> "ObjectParameter":
154163
if isinstance(param, dict):
155164
validated_properties[name] = ToolParameter.from_dict(param)
156165
else:
157-
raise ValueError(f"Property {name} must be a ToolParameter or dict, got {type(param)}")
166+
raise ValueError(
167+
f"Property {name} must be a ToolParameter or dict, got {type(param)}"
168+
)
158169
else:
159170
validated_properties[name] = param
160171
self.properties = validated_properties
@@ -173,13 +184,13 @@ def __init__(
173184
):
174185
self._name = name
175186
self._description = description
176-
187+
177188
# Allow parameters to be provided as a dictionary
178189
if isinstance(parameters, dict):
179190
self._parameters = ObjectParameter.model_validate(parameters)
180191
else:
181192
self._parameters = parameters
182-
193+
183194
self._execute_func = execute_func
184195

185196
def get_name(self) -> str:

tests/unit/llm/test_openai_llm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
from unittest.mock import MagicMock, Mock, patch
16-
from typing import Any, Dict, cast
16+
from typing import Any, cast
1717

1818
import openai
1919
import pytest

0 commit comments

Comments
 (0)