Skip to content

Commit 39c73c6

Browse files
oskarhanestellasia
authored andcommitted
Add Tool class
To not rely on json schema from openai
1 parent 26886cc commit 39c73c6

File tree

2 files changed

+53
-22
lines changed

2 files changed

+53
-22
lines changed

examples/tool_calls/openai_tool_calls.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,32 +17,32 @@
1717

1818
from neo4j_graphrag.llm import OpenAILLM
1919
from neo4j_graphrag.llm.types import ToolCallResponse
20+
from neo4j_graphrag.tool import Tool, ObjectParameter, StringParameter, IntegerParameter
2021

2122
# Load environment variables from .env file
2223
load_dotenv()
2324

24-
# Define a tool for extracting information from text
25-
TOOLS = [
26-
{
27-
"type": "function",
28-
"function": {
29-
"name": "extract_person_info",
30-
"description": "Extract information about a person from text",
31-
"parameters": {
32-
"type": "object",
33-
"properties": {
34-
"name": {"type": "string", "description": "The person's full name"},
35-
"age": {"type": "integer", "description": "The person's age"},
36-
"occupation": {
37-
"type": "string",
38-
"description": "The person's occupation",
39-
},
40-
},
41-
"required": ["name"],
42-
},
43-
},
44-
}
45-
]
25+
26+
# Create a custom Tool implementation for person info extraction
27+
parameters = ObjectParameter(
28+
description="Parameters for extracting person information",
29+
properties={
30+
"name": StringParameter(description="The person's full name"),
31+
"age": IntegerParameter(description="The person's age"),
32+
"occupation": StringParameter(description="The person's occupation"),
33+
},
34+
required_properties=["name"],
35+
additional_properties=False,
36+
)
37+
person_info_tool = Tool(
38+
name="extract_person_info",
39+
description="Extract information about a person from text",
40+
parameters=parameters,
41+
execute_func=lambda **kwargs: kwargs,
42+
)
43+
44+
# Create the tool instance
45+
TOOLS = [person_info_tool]
4646

4747

4848
def process_tool_call(response: ToolCallResponse) -> Dict[str, Any]:

src/neo4j_graphrag/tool.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,37 @@ def validate_properties(self) -> "ObjectParameter":
204204
self.properties = validated_properties
205205
return self
206206

207+
class ArrayParameter(ToolParameter):
208+
"""Array parameter for tools."""
209+
210+
def __init__(
211+
self,
212+
description: str,
213+
items: ToolParameter,
214+
required: bool = False,
215+
min_items: Optional[int] = None,
216+
max_items: Optional[int] = None,
217+
):
218+
super().__init__(description, required)
219+
self.items = items
220+
self.min_items = min_items
221+
self.max_items = max_items
222+
223+
def to_dict(self) -> Dict[str, Any]:
224+
result: Dict[str, Any] = {
225+
"type": ParameterType.ARRAY,
226+
"description": self.description,
227+
"items": self.items.to_dict(),
228+
}
229+
230+
if self.min_items is not None:
231+
result["minItems"] = self.min_items
232+
233+
if self.max_items is not None:
234+
result["maxItems"] = self.max_items
235+
236+
return result
237+
207238

208239
class Tool(ABC):
209240
"""Abstract base class defining the interface for all tools in the neo4j-graphrag library."""

0 commit comments

Comments
 (0)