From fa465ef2ebd14d313468ab183589878e58700958 Mon Sep 17 00:00:00 2001 From: Oskar Hane Date: Fri, 4 Apr 2025 09:58:38 +0200 Subject: [PATCH 01/15] Add tool calling to the LLM base class, implement in OpenAI --- examples/tool_calls/openai_tool_calls.py | 95 ++++++++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 examples/tool_calls/openai_tool_calls.py diff --git a/examples/tool_calls/openai_tool_calls.py b/examples/tool_calls/openai_tool_calls.py new file mode 100644 index 000000000..e61344031 --- /dev/null +++ b/examples/tool_calls/openai_tool_calls.py @@ -0,0 +1,95 @@ +""" +Example showing how to use OpenAI tool calls with parameter extraction. +Both synchronous and asynchronous examples are provided. + +To run this example: +1. Make sure you have the OpenAI API key in your .env file: + OPENAI_API_KEY=your-api-key +2. Run: python examples/tool_calls/openai_tool_calls.py +""" + +import asyncio +import json +import os +from typing import Dict, Any + +from dotenv import load_dotenv + +from neo4j_graphrag.llm import OpenAILLM +from neo4j_graphrag.llm.types import ToolCallResponse + +# Load environment variables from .env file +load_dotenv() + +# Define a tool for extracting information from text +TOOLS = [ + { + "type": "function", + "function": { + "name": "extract_person_info", + "description": "Extract information about a person from text", + "parameters": { + "type": "object", + "properties": { + "name": {"type": "string", "description": "The person's full name"}, + "age": {"type": "integer", "description": "The person's age"}, + "occupation": { + "type": "string", + "description": "The person's occupation", + }, + }, + "required": ["name"], + }, + }, + } +] + + +def process_tool_call(response: ToolCallResponse) -> Dict[str, Any]: + """Process the tool call response and return the extracted parameters.""" + if not response.tool_calls: + raise ValueError("No tool calls found in response") + + tool_call = response.tool_calls[0] + print(f"\nTool called: {tool_call.name}") + print(f"Arguments: {tool_call.arguments}") + print(f"Additional content: {response.content or 'None'}") + return tool_call.arguments + + +async def main() -> None: + # Initialize the OpenAI LLM + llm = OpenAILLM( + api_key=os.getenv("OPENAI_API_KEY"), + model_name="gpt-4o", + model_params={"temperature": 0}, + ) + + # Example text containing information about a person + text = "Stella Hane is a 35-year-old software engineer who loves coding." + + print("\n=== Synchronous Tool Call ===") + # Make a synchronous tool call + sync_response = llm.invoke_with_tools( + input=f"Extract information about the person from this text: {text}", + tools=TOOLS, + ) + sync_result = process_tool_call(sync_response) + print("\n=== Synchronous Tool Call Result ===") + print(json.dumps(sync_result, indent=2)) + + print("\n=== Asynchronous Tool Call ===") + # Make an asynchronous tool call with a different text + text2 = "Molly Hane, 32, works as a data scientist and enjoys machine learning." 
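+    # ainvoke_with_tools is the asynchronous counterpart of invoke_with_tools;
+    # it must be awaited, which is why this example runs under asyncio.run().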
+ async_response = await llm.ainvoke_with_tools( + input=f"Extract information about the person from this text: {text2}", + tools=TOOLS, + ) + async_result = process_tool_call(async_response) + print("\n=== Asynchronous Tool Call Result ===") + print(json.dumps(async_result, indent=2)) + + +if __name__ == "__main__": + # Run the async main function + asyncio.run(main()) From f8a7604fc94d7ffcf129928f2ba7e50623ad420b Mon Sep 17 00:00:00 2001 From: Oskar Hane Date: Mon, 14 Apr 2025 15:55:31 +0200 Subject: [PATCH 02/15] Add Tool class To not rely on json schema from openai --- examples/tool_calls/openai_tool_calls.py | 44 ++++++++++++------------ src/neo4j_graphrag/tool.py | 31 +++++++++++++++++ 2 files changed, 53 insertions(+), 22 deletions(-) diff --git a/examples/tool_calls/openai_tool_calls.py b/examples/tool_calls/openai_tool_calls.py index e61344031..420af2c0f 100644 --- a/examples/tool_calls/openai_tool_calls.py +++ b/examples/tool_calls/openai_tool_calls.py @@ -17,32 +17,32 @@ from neo4j_graphrag.llm import OpenAILLM from neo4j_graphrag.llm.types import ToolCallResponse +from neo4j_graphrag.tool import Tool, ObjectParameter, StringParameter, IntegerParameter # Load environment variables from .env file load_dotenv() -# Define a tool for extracting information from text -TOOLS = [ - { - "type": "function", - "function": { - "name": "extract_person_info", - "description": "Extract information about a person from text", - "parameters": { - "type": "object", - "properties": { - "name": {"type": "string", "description": "The person's full name"}, - "age": {"type": "integer", "description": "The person's age"}, - "occupation": { - "type": "string", - "description": "The person's occupation", - }, - }, - "required": ["name"], - }, - }, - } -] + +# Create a custom Tool implementation for person info extraction +parameters = ObjectParameter( + description="Parameters for extracting person information", + properties={ + "name": StringParameter(description="The person's full name"), + "age": IntegerParameter(description="The person's age"), + "occupation": StringParameter(description="The person's occupation"), + }, + required_properties=["name"], + additional_properties=False, +) +person_info_tool = Tool( + name="extract_person_info", + description="Extract information about a person from text", + parameters=parameters, + execute_func=lambda **kwargs: kwargs, +) + +# Create the tool instance +TOOLS = [person_info_tool] def process_tool_call(response: ToolCallResponse) -> Dict[str, Any]: diff --git a/src/neo4j_graphrag/tool.py b/src/neo4j_graphrag/tool.py index 63aac6684..f064233ef 100644 --- a/src/neo4j_graphrag/tool.py +++ b/src/neo4j_graphrag/tool.py @@ -204,6 +204,37 @@ def validate_properties(self) -> "ObjectParameter": self.properties = validated_properties return self +class ArrayParameter(ToolParameter): + """Array parameter for tools.""" + + def __init__( + self, + description: str, + items: ToolParameter, + required: bool = False, + min_items: Optional[int] = None, + max_items: Optional[int] = None, + ): + super().__init__(description, required) + self.items = items + self.min_items = min_items + self.max_items = max_items + + def to_dict(self) -> Dict[str, Any]: + result: Dict[str, Any] = { + "type": ParameterType.ARRAY, + "description": self.description, + "items": self.items.to_dict(), + } + + if self.min_items is not None: + result["minItems"] = self.min_items + + if self.max_items is not None: + result["maxItems"] = self.max_items + + return result + class Tool(ABC): """Abstract 
base class defining the interface for all tools in the neo4j-graphrag library."""

From 04f8a62436a8c74d15851be1712cb4961f29fd6c Mon Sep 17 00:00:00 2001
From: Oskar Hane
Date: Tue, 15 Apr 2025 16:39:44 +0200
Subject: [PATCH 03/15] Print all tool calls in example file

---
 examples/tool_calls/openai_tool_calls.py | 22 ++++++++++++++--------
 1 file changed, 14 insertions(+), 8 deletions(-)

diff --git a/examples/tool_calls/openai_tool_calls.py b/examples/tool_calls/openai_tool_calls.py
index 420af2c0f..9be8b804c 100644
--- a/examples/tool_calls/openai_tool_calls.py
+++ b/examples/tool_calls/openai_tool_calls.py
@@ -45,16 +45,22 @@ TOOLS = [person_info_tool]
 
 
-def process_tool_call(response: ToolCallResponse) -> Dict[str, Any]:
-    """Process the tool call response and return the extracted parameters."""
+def process_tool_calls(response: ToolCallResponse) -> Dict[str, Any]:
+    """Process all tool calls in the response and return the extracted parameters."""
     if not response.tool_calls:
         raise ValueError("No tool calls found in response")
 
-    tool_call = response.tool_calls[0]
-    print(f"\nTool called: {tool_call.name}")
-    print(f"Arguments: {tool_call.arguments}")
+    print(f"\nNumber of tool calls: {len(response.tool_calls)}")
     print(f"Additional content: {response.content or 'None'}")
-    return tool_call.arguments
+
+    results = []
+    for i, tool_call in enumerate(response.tool_calls):
+        print(f"\nTool call #{i+1}: {tool_call.name}")
+        print(f"Arguments: {tool_call.arguments}")
+        results.append(tool_call.arguments)
+
+    # For backward compatibility, return the first tool call's arguments
+    return results[0] if results else {}
 
 
 async def main() -> None:
@@ -74,7 +80,7 @@ async def main() -> None:
         input=f"Extract information about the person from this text: {text}",
         tools=TOOLS,
     )
-    sync_result = process_tool_call(sync_response)
+    sync_result = process_tool_calls(sync_response)
     print("\n=== Synchronous Tool Call Result ===")
     print(json.dumps(sync_result, indent=2))
 
@@ -85,7 +91,7 @@ async def main() -> None:
         input=f"Extract information about the person from this text: {text2}",
         tools=TOOLS,
     )
-    async_result = process_tool_call(async_response)
+    async_result = process_tool_calls(async_response)
     print("\n=== Asynchronous Tool Call Result ===")
     print(json.dumps(async_result, indent=2))

From 504eaf57e51c09346f422ebcce090af091d426c4 Mon Sep 17 00:00:00 2001
From: Oskar Hane
Date: Tue, 15 Apr 2025 16:46:08 +0200
Subject: [PATCH 04/15] Move tool call example file and add link to README

---
 examples/tool_calls/openai_tool_calls.py | 101 -----------------------
 1 file changed, 101 deletions(-)
 delete mode 100644 examples/tool_calls/openai_tool_calls.py

diff --git a/examples/tool_calls/openai_tool_calls.py b/examples/tool_calls/openai_tool_calls.py
deleted file mode 100644
index 9be8b804c..000000000
--- a/examples/tool_calls/openai_tool_calls.py
+++ /dev/null
@@ -1,101 +0,0 @@
-"""
-Example showing how to use OpenAI tool calls with parameter extraction.
-Both synchronous and asynchronous examples are provided.
-
-To run this example:
-1. Make sure you have the OpenAI API key in your .env file:
-   OPENAI_API_KEY=your-api-key
-2. 
Run: python examples/tool_calls/openai_tool_calls.py -""" - -import asyncio -import json -import os -from typing import Dict, Any - -from dotenv import load_dotenv - -from neo4j_graphrag.llm import OpenAILLM -from neo4j_graphrag.llm.types import ToolCallResponse -from neo4j_graphrag.tool import Tool, ObjectParameter, StringParameter, IntegerParameter - -# Load environment variables from .env file -load_dotenv() - - -# Create a custom Tool implementation for person info extraction -parameters = ObjectParameter( - description="Parameters for extracting person information", - properties={ - "name": StringParameter(description="The person's full name"), - "age": IntegerParameter(description="The person's age"), - "occupation": StringParameter(description="The person's occupation"), - }, - required_properties=["name"], - additional_properties=False, -) -person_info_tool = Tool( - name="extract_person_info", - description="Extract information about a person from text", - parameters=parameters, - execute_func=lambda **kwargs: kwargs, -) - -# Create the tool instance -TOOLS = [person_info_tool] - - -def process_tool_calls(response: ToolCallResponse) -> Dict[str, Any]: - """Process all tool calls in the response and return the extracted parameters.""" - if not response.tool_calls: - raise ValueError("No tool calls found in response") - - print(f"\nNumber of tool calls: {len(response.tool_calls)}") - print(f"Additional content: {response.content or 'None'}") - - results = [] - for i, tool_call in enumerate(response.tool_calls): - print(f"\nTool call #{i+1}: {tool_call.name}") - print(f"Arguments: {tool_call.arguments}") - results.append(tool_call.arguments) - - # For backward compatibility, return the first tool call's arguments - return results[0] if results else {} - - -async def main() -> None: - # Initialize the OpenAI LLM - llm = OpenAILLM( - api_key=os.getenv("OPENAI_API_KEY"), - model_name="gpt-4o", - model_params={"temperature": 0}, - ) - - # Example text containing information about a person - text = "Stella Hane is a 35-year-old software engineer who loves coding." - - print("\n=== Synchronous Tool Call ===") - # Make a synchronous tool call - sync_response = llm.invoke_with_tools( - input=f"Extract information about the person from this text: {text}", - tools=TOOLS, - ) - sync_result = process_tool_calls(sync_response) - print("\n=== Synchronous Tool Call Result ===") - print(json.dumps(sync_result, indent=2)) - - print("\n=== Asynchronous Tool Call ===") - # Make an asynchronous tool call with a different text - text2 = "Molly Hane, 32, works as a data scientist and enjoys machine learning." 
- async_response = await llm.ainvoke_with_tools( - input=f"Extract information about the person from this text: {text2}", - tools=TOOLS, - ) - async_result = process_tool_calls(async_response) - print("\n=== Asynchronous Tool Call Result ===") - print(json.dumps(async_result, indent=2)) - - -if __name__ == "__main__": - # Run the async main function - asyncio.run(main()) From c9d63f80529528e14f6a19286e226f5e43abfdd3 Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 15 Apr 2025 14:44:50 +0200 Subject: [PATCH 05/15] Implement tool calling for VertexAILLM --- src/neo4j_graphrag/llm/vertexai_llm.py | 120 ++++++++++++++++++++++++- src/neo4j_graphrag/tool.py | 4 +- 2 files changed, 120 insertions(+), 4 deletions(-) diff --git a/src/neo4j_graphrag/llm/vertexai_llm.py b/src/neo4j_graphrag/llm/vertexai_llm.py index f7c44b21e..991631461 100644 --- a/src/neo4j_graphrag/llm/vertexai_llm.py +++ b/src/neo4j_graphrag/llm/vertexai_llm.py @@ -13,22 +13,33 @@ # limitations under the License. from __future__ import annotations -from typing import Any, List, Optional, Union, cast +from typing import Any, List, Optional, Union, cast, Sequence from pydantic import ValidationError from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.base import LLMInterface -from neo4j_graphrag.llm.types import BaseMessage, LLMResponse, MessageList +from neo4j_graphrag.llm.types import ( + BaseMessage, + LLMResponse, + MessageList, + ToolCall, + ToolCallResponse, +) from neo4j_graphrag.message_history import MessageHistory +from neo4j_graphrag.tool import Tool from neo4j_graphrag.types import LLMMessage try: from vertexai.generative_models import ( Content, + FunctionCall, + FunctionDeclaration, + GenerationResponse, GenerativeModel, Part, ResponseValidationError, + Tool as VertexAITool, ) except ImportError: GenerativeModel = None @@ -176,3 +187,108 @@ async def ainvoke( return LLMResponse(content=response.text) except ResponseValidationError as e: raise LLMGenerationError(e) + + def _to_vertexai_tool(self, tool: Tool) -> VertexAITool: + return VertexAITool( + function_declarations=[ + FunctionDeclaration( + name=tool.get_name(), + description=tool.get_description(), + parameters=tool.get_parameters(), + ) + ] + ) + + def get_tools( + self, tools: Optional[Sequence[Tool]] + ) -> Optional[list[VertexAITool]]: + if not tools: + return None + return [self._to_vertexai_tool(tool) for tool in tools] + + def _get_model( + self, + system_instruction: Optional[str] = None, + tools: Optional[Sequence[Tool]] = None, + ) -> GenerativeModel: + system_message = [system_instruction] if system_instruction is not None else [] + vertex_ai_tools = self.get_tools(tools) + model = GenerativeModel( + model_name=self.model_name, + system_instruction=system_message, + tools=vertex_ai_tools, + **self.options, + ) + return model + + async def _acall_llm( + self, + input: str, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, + system_instruction: Optional[str] = None, + tools: Optional[Sequence[Tool]] = None, + ) -> GenerationResponse: + model = self._get_model(system_instruction, tools) + messages = self.get_messages(input, message_history) + response = await model.generate_content_async(messages, **self.model_params) + return response + + def _call_llm( + self, + input: str, + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, + system_instruction: Optional[str] = None, + tools: Optional[Sequence[Tool]] = None, + ) -> GenerationResponse: + model = 
self._get_model(system_instruction, tools) + messages = self.get_messages(input, message_history) + response = model.generate_content(messages, **self.model_params) + return response + + def _to_tool_call(self, function_call: FunctionCall) -> ToolCall: + return ToolCall( + name=function_call.name, + arguments=function_call.args, + ) + + def _parse_tool_response(self, response) -> ToolCallResponse: + function_calls = response.candidates[0].function_calls + return ToolCallResponse( + tool_calls=[self._to_tool_call(f) for f in function_calls], + content=None, + ) + + def _parse_content_response(self, response) -> LLMResponse: + return LLMResponse( + content=response.text, + ) + + async def ainvoke_with_tools( + self, + input: str, + tools: Sequence[Tool], + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, + system_instruction: Optional[str] = None, + ) -> ToolCallResponse: + response = await self._acall_llm( + input, + message_history=message_history, + system_instruction=system_instruction, + tools=tools, + ) + return self._parse_tool_response(response) + + def invoke_with_tools( + self, + input: str, + tools: Sequence[Tool], + message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, + system_instruction: Optional[str] = None, + ) -> ToolCallResponse: + response = self._call_llm( + input, + message_history=message_history, + system_instruction=system_instruction, + tools=tools, + ) + return self._parse_tool_response(response) diff --git a/src/neo4j_graphrag/tool.py b/src/neo4j_graphrag/tool.py index f064233ef..b68e55107 100644 --- a/src/neo4j_graphrag/tool.py +++ b/src/neo4j_graphrag/tool.py @@ -180,8 +180,8 @@ def model_dump_tool(self) -> Dict[str, Any]: if self.required_properties: result["required"] = self.required_properties - if not self.additional_properties: - result["additionalProperties"] = False + # if not self.additional_properties: + # result["additionalProperties"] = False return result From b96f8421576cdb1ba9fae079c8bedb7189ff2d07 Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 15 Apr 2025 15:02:13 +0200 Subject: [PATCH 06/15] mypy --- src/neo4j_graphrag/llm/vertexai_llm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/neo4j_graphrag/llm/vertexai_llm.py b/src/neo4j_graphrag/llm/vertexai_llm.py index 991631461..83487fe06 100644 --- a/src/neo4j_graphrag/llm/vertexai_llm.py +++ b/src/neo4j_graphrag/llm/vertexai_llm.py @@ -251,14 +251,14 @@ def _to_tool_call(self, function_call: FunctionCall) -> ToolCall: arguments=function_call.args, ) - def _parse_tool_response(self, response) -> ToolCallResponse: + def _parse_tool_response(self, response: GenerationResponse) -> ToolCallResponse: function_calls = response.candidates[0].function_calls return ToolCallResponse( tool_calls=[self._to_tool_call(f) for f in function_calls], content=None, ) - def _parse_content_response(self, response) -> LLMResponse: + def _parse_content_response(self, response: GenerationResponse) -> LLMResponse: return LLMResponse( content=response.text, ) From 97e86bf091f7d2b734c5430d19409a5bbf59833f Mon Sep 17 00:00:00 2001 From: estelle Date: Thu, 24 Apr 2025 13:40:48 +0200 Subject: [PATCH 07/15] Fix merge and tests --- .../retrievers/mcp_retriever.py | 40 +++++++++++++++++++ src/neo4j_graphrag/tool.py | 35 +--------------- 2 files changed, 42 insertions(+), 33 deletions(-) create mode 100644 src/neo4j_graphrag/retrievers/mcp_retriever.py diff --git a/src/neo4j_graphrag/retrievers/mcp_retriever.py 
b/src/neo4j_graphrag/retrievers/mcp_retriever.py new file mode 100644 index 000000000..04be5e794 --- /dev/null +++ b/src/neo4j_graphrag/retrievers/mcp_retriever.py @@ -0,0 +1,40 @@ +from typing import Any + +# from neo4j_graphrag.retrievers.base import Retriever +from neo4j_graphrag.types import RawSearchResult, RetrieverResult, RetrieverResultItem + + +class MCPServerInterface: + def __init__(self, *args, **kwargs): + pass + + def get_tools(self): + return [] + + def execute_tool(self, tool) -> Any: + return "" + + +class MCPRetriever: + def __init__(self, server: MCPServerInterface) -> None: + super().__init__() + self.server = server + self.tools = server.get_tools() + + def search(self, query_text: str) -> RetrieverResult: + """Reimplement the search method because we can't inherit from + the Retriever interface (no need for neo4j.driver here). + + 1. Call llm with a list of tools + 2. Call MCP server for specific tool and LLM-generated arguments + 3. Return all results as context in RetrieverResult + """ + raw_result = RawSearchResult(records=[]) + search_items = [RetrieverResultItem(content=str(record)) for record in raw_result.records] + metadata = raw_result.metadata or {} + metadata["__retriever"] = self.__class__.__name__ + metadata["__tool_results"] = {} + return RetrieverResult( + items=search_items, + metadata=metadata, + ) diff --git a/src/neo4j_graphrag/tool.py b/src/neo4j_graphrag/tool.py index b68e55107..63aac6684 100644 --- a/src/neo4j_graphrag/tool.py +++ b/src/neo4j_graphrag/tool.py @@ -180,8 +180,8 @@ def model_dump_tool(self) -> Dict[str, Any]: if self.required_properties: result["required"] = self.required_properties - # if not self.additional_properties: - # result["additionalProperties"] = False + if not self.additional_properties: + result["additionalProperties"] = False return result @@ -204,37 +204,6 @@ def validate_properties(self) -> "ObjectParameter": self.properties = validated_properties return self -class ArrayParameter(ToolParameter): - """Array parameter for tools.""" - - def __init__( - self, - description: str, - items: ToolParameter, - required: bool = False, - min_items: Optional[int] = None, - max_items: Optional[int] = None, - ): - super().__init__(description, required) - self.items = items - self.min_items = min_items - self.max_items = max_items - - def to_dict(self) -> Dict[str, Any]: - result: Dict[str, Any] = { - "type": ParameterType.ARRAY, - "description": self.description, - "items": self.items.to_dict(), - } - - if self.min_items is not None: - result["minItems"] = self.min_items - - if self.max_items is not None: - result["maxItems"] = self.max_items - - return result - class Tool(ABC): """Abstract base class defining the interface for all tools in the neo4j-graphrag library.""" From 43f5f2909fe035b3867213c728b2803cbb484e92 Mon Sep 17 00:00:00 2001 From: estelle Date: Thu, 24 Apr 2025 13:43:12 +0200 Subject: [PATCH 08/15] Remove unrelated test file --- .../retrievers/mcp_retriever.py | 40 ------------------- 1 file changed, 40 deletions(-) delete mode 100644 src/neo4j_graphrag/retrievers/mcp_retriever.py diff --git a/src/neo4j_graphrag/retrievers/mcp_retriever.py b/src/neo4j_graphrag/retrievers/mcp_retriever.py deleted file mode 100644 index 04be5e794..000000000 --- a/src/neo4j_graphrag/retrievers/mcp_retriever.py +++ /dev/null @@ -1,40 +0,0 @@ -from typing import Any - -# from neo4j_graphrag.retrievers.base import Retriever -from neo4j_graphrag.types import RawSearchResult, RetrieverResult, RetrieverResultItem - - -class 
MCPServerInterface: - def __init__(self, *args, **kwargs): - pass - - def get_tools(self): - return [] - - def execute_tool(self, tool) -> Any: - return "" - - -class MCPRetriever: - def __init__(self, server: MCPServerInterface) -> None: - super().__init__() - self.server = server - self.tools = server.get_tools() - - def search(self, query_text: str) -> RetrieverResult: - """Reimplement the search method because we can't inherit from - the Retriever interface (no need for neo4j.driver here). - - 1. Call llm with a list of tools - 2. Call MCP server for specific tool and LLM-generated arguments - 3. Return all results as context in RetrieverResult - """ - raw_result = RawSearchResult(records=[]) - search_items = [RetrieverResultItem(content=str(record)) for record in raw_result.records] - metadata = raw_result.metadata or {} - metadata["__retriever"] = self.__class__.__name__ - metadata["__tool_results"] = {} - return RetrieverResult( - items=search_items, - metadata=metadata, - ) From dc638c2e24ea139ca96c2e8a65ebd8b1c4adb896 Mon Sep 17 00:00:00 2001 From: estelle Date: Thu, 24 Apr 2025 14:58:58 +0200 Subject: [PATCH 09/15] Add tests --- src/neo4j_graphrag/llm/vertexai_llm.py | 10 +- src/neo4j_graphrag/tool.py | 13 ++- tests/unit/llm/conftest.py | 27 ++++++ tests/unit/llm/test_openai_llm.py | 43 +++------ tests/unit/llm/test_vertexai_llm.py | 123 ++++++++++++++++++++++++- tests/unit/tool/__init__.py | 0 6 files changed, 176 insertions(+), 40 deletions(-) create mode 100644 tests/unit/llm/conftest.py create mode 100644 tests/unit/tool/__init__.py diff --git a/src/neo4j_graphrag/llm/vertexai_llm.py b/src/neo4j_graphrag/llm/vertexai_llm.py index 83487fe06..100ff99ab 100644 --- a/src/neo4j_graphrag/llm/vertexai_llm.py +++ b/src/neo4j_graphrag/llm/vertexai_llm.py @@ -194,12 +194,12 @@ def _to_vertexai_tool(self, tool: Tool) -> VertexAITool: FunctionDeclaration( name=tool.get_name(), description=tool.get_description(), - parameters=tool.get_parameters(), + parameters=tool.get_parameters(exclude=["additional_properties"]), ) ] ) - def get_tools( + def _get_llm_tools( self, tools: Optional[Sequence[Tool]] ) -> Optional[list[VertexAITool]]: if not tools: @@ -212,7 +212,7 @@ def _get_model( tools: Optional[Sequence[Tool]] = None, ) -> GenerativeModel: system_message = [system_instruction] if system_instruction is not None else [] - vertex_ai_tools = self.get_tools(tools) + vertex_ai_tools = self._get_llm_tools(tools) model = GenerativeModel( model_name=self.model_name, system_instruction=system_message, @@ -228,7 +228,7 @@ async def _acall_llm( system_instruction: Optional[str] = None, tools: Optional[Sequence[Tool]] = None, ) -> GenerationResponse: - model = self._get_model(system_instruction, tools) + model = self._get_model(system_instruction=system_instruction, tools=tools) messages = self.get_messages(input, message_history) response = await model.generate_content_async(messages, **self.model_params) return response @@ -240,7 +240,7 @@ def _call_llm( system_instruction: Optional[str] = None, tools: Optional[Sequence[Tool]] = None, ) -> GenerationResponse: - model = self._get_model(system_instruction, tools) + model = self._get_model(system_instruction=system_instruction, tools=tools) messages = self.get_messages(input, message_history) response = model.generate_content(messages, **self.model_params) return response diff --git a/src/neo4j_graphrag/tool.py b/src/neo4j_graphrag/tool.py index 63aac6684..dab14fdaf 100644 --- a/src/neo4j_graphrag/tool.py +++ b/src/neo4j_graphrag/tool.py @@ -169,18 
+169,21 @@ def _preprocess_properties(cls, values: dict[str, Any]) -> dict[str, Any]: values["properties"] = new_props return values - def model_dump_tool(self) -> Dict[str, Any]: + def model_dump_tool(self, exclude: Optional[list[str]] = None) -> Dict[str, Any]: + exclude = exclude or [] properties_dict: Dict[str, Any] = {} for name, param in self.properties.items(): + if name in exclude: + continue properties_dict[name] = param.model_dump_tool() result = super().model_dump_tool() result["properties"] = properties_dict - if self.required_properties: + if self.required_properties and "required" not in exclude: result["required"] = self.required_properties - if not self.additional_properties: + if not self.additional_properties and "additional_properties" not in exclude: result["additionalProperties"] = False return result @@ -242,13 +245,13 @@ def get_description(self) -> str: """ return self._description - def get_parameters(self) -> Dict[str, Any]: + def get_parameters(self, exclude: Optional[list[str]] = None) -> Dict[str, Any]: """Get the parameters the tool accepts in a dictionary format suitable for LLM providers. Returns: Dict[str, Any]: Dictionary containing parameter schema information. """ - return self._parameters.model_dump_tool() + return self._parameters.model_dump_tool(exclude) def execute(self, query: str, **kwargs: Any) -> Any: """Execute the tool with the given query and additional parameters. diff --git a/tests/unit/llm/conftest.py b/tests/unit/llm/conftest.py new file mode 100644 index 000000000..269efadec --- /dev/null +++ b/tests/unit/llm/conftest.py @@ -0,0 +1,27 @@ +import pytest + +from neo4j_graphrag.tool import Tool, ObjectParameter, StringParameter + + +class TestTool(Tool): + """Test tool for unit tests.""" + + def __init__(self, name: str = "test_tool", description: str = "A test tool"): + parameters = ObjectParameter( + description="Test parameters", + properties={"param1": StringParameter(description="Test parameter")}, + required_properties=["param1"], + additional_properties=False, + ) + + super().__init__( + name=name, + description=description, + parameters=parameters, + execute_func=lambda **kwargs: kwargs, + ) + + +@pytest.fixture +def test_tool() -> Tool: + return TestTool() diff --git a/tests/unit/llm/test_openai_llm.py b/tests/unit/llm/test_openai_llm.py index 4220f3b36..3c5ee1b9e 100644 --- a/tests/unit/llm/test_openai_llm.py +++ b/tests/unit/llm/test_openai_llm.py @@ -20,7 +20,7 @@ from neo4j_graphrag.llm import LLMResponse from neo4j_graphrag.llm.openai_llm import AzureOpenAILLM, OpenAILLM from neo4j_graphrag.llm.types import ToolCallResponse -from neo4j_graphrag.tool import Tool, ObjectParameter, StringParameter +from neo4j_graphrag.tool import Tool def get_mock_openai() -> MagicMock: @@ -29,25 +29,6 @@ def get_mock_openai() -> MagicMock: return mock -class TestTool(Tool): - """Test tool for unit tests.""" - - def __init__(self, name: str = "test_tool", description: str = "A test tool"): - parameters = ObjectParameter( - description="Test parameters", - properties={"param1": StringParameter(description="Test parameter")}, - required_properties=["param1"], - additional_properties=False, - ) - - super().__init__( - name=name, - description=description, - parameters=parameters, - execute_func=lambda **kwargs: kwargs, - ) - - @patch("builtins.__import__", side_effect=ImportError) def test_openai_llm_missing_dependency(mock_import: Mock) -> None: with pytest.raises(ImportError): @@ -156,7 +137,9 @@ def 
test_openai_llm_with_message_history_validation_error(mock_import: Mock) -> @patch("builtins.__import__") @patch("json.loads") def test_openai_llm_invoke_with_tools_happy_path( - mock_json_loads: Mock, mock_import: Mock + mock_json_loads: Mock, + mock_import: Mock, + test_tool: Tool, ) -> None: # Set up json.loads to return a dictionary mock_json_loads.return_value = {"param1": "value1"} @@ -183,7 +166,7 @@ def test_openai_llm_invoke_with_tools_happy_path( ) llm = OpenAILLM(api_key="my key", model_name="gpt") - tools = [TestTool()] + tools = [test_tool] res = llm.invoke_with_tools("my text", tools) assert isinstance(res, ToolCallResponse) @@ -196,7 +179,9 @@ def test_openai_llm_invoke_with_tools_happy_path( @patch("builtins.__import__") @patch("json.loads") def test_openai_llm_invoke_with_tools_with_message_history( - mock_json_loads: Mock, mock_import: Mock + mock_json_loads: Mock, + mock_import: Mock, + test_tool: Tool, ) -> None: # Set up json.loads to return a dictionary mock_json_loads.return_value = {"param1": "value1"} @@ -223,7 +208,7 @@ def test_openai_llm_invoke_with_tools_with_message_history( ) llm = OpenAILLM(api_key="my key", model_name="gpt") - tools = [TestTool()] + tools = [test_tool] message_history = [ {"role": "user", "content": "When does the sun come up in the summer?"}, @@ -259,7 +244,9 @@ def test_openai_llm_invoke_with_tools_with_message_history( @patch("builtins.__import__") @patch("json.loads") def test_openai_llm_invoke_with_tools_with_system_instruction( - mock_json_loads: Mock, mock_import: Mock + mock_json_loads: Mock, + mock_import: Mock, + test_tool: Mock, ) -> None: # Set up json.loads to return a dictionary mock_json_loads.return_value = {"param1": "value1"} @@ -286,7 +273,7 @@ def test_openai_llm_invoke_with_tools_with_system_instruction( ) llm = OpenAILLM(api_key="my key", model_name="gpt") - tools = [TestTool()] + tools = [test_tool] system_instruction = "You are a helpful assistant." 
@@ -314,7 +301,7 @@ def test_openai_llm_invoke_with_tools_with_system_instruction( @patch("builtins.__import__") -def test_openai_llm_invoke_with_tools_error(mock_import: Mock) -> None: +def test_openai_llm_invoke_with_tools_error(mock_import: Mock, test_tool: Tool) -> None: mock_openai = get_mock_openai() mock_import.return_value = mock_openai @@ -324,7 +311,7 @@ def test_openai_llm_invoke_with_tools_error(mock_import: Mock) -> None: ) llm = OpenAILLM(api_key="my key", model_name="gpt") - tools = [TestTool()] + tools = [test_tool] with pytest.raises(LLMGenerationError): llm.invoke_with_tools("my text", tools) diff --git a/tests/unit/llm/test_vertexai_llm.py b/tests/unit/llm/test_vertexai_llm.py index 48ebf3505..f640d585e 100644 --- a/tests/unit/llm/test_vertexai_llm.py +++ b/tests/unit/llm/test_vertexai_llm.py @@ -19,9 +19,15 @@ import pytest from neo4j_graphrag.exceptions import LLMGenerationError +from neo4j_graphrag.llm.types import ToolCallResponse from neo4j_graphrag.llm.vertexai_llm import VertexAILLM +from neo4j_graphrag.tool import Tool from neo4j_graphrag.types import LLMMessage -from vertexai.generative_models import Content, Part +from vertexai.generative_models import ( + Content, + GenerationResponse, + Part, +) @patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel", None) @@ -171,4 +177,117 @@ async def test_vertexai_ainvoke_happy_path(GenerativeModelMock: MagicMock) -> No input_text = "may thy knife chip and shatter" response = await llm.ainvoke(input_text) assert response.content == "Return text" - llm.model.generate_content_async.assert_called_once_with([mock.ANY], **model_params) + llm.model.generate_content_async.assert_awaited_once_with( + [mock.ANY], **model_params + ) + + +def test_vertexai_get_llm_tools(test_tool: Tool) -> None: + llm = VertexAILLM(model_name="gemini") + tools = llm._get_llm_tools(tools=[test_tool]) + assert tools is not None + assert len(tools) == 1 + tool = tools[0] + tool_dict = tool.to_dict() + assert len(tool_dict["function_declarations"]) == 1 + assert tool_dict["function_declarations"][0]["name"] == "test_tool" + assert tool_dict["function_declarations"][0]["description"] == "A test tool" + + +@patch("neo4j_graphrag.llm.vertexai_llm.VertexAILLM._parse_tool_response") +@patch("neo4j_graphrag.llm.vertexai_llm.VertexAILLM._call_llm") +def test_vertexai_invoke_with_tools( + mock_call_llm: Mock, mock_parse_tool: Mock, test_tool: Tool, +) -> None: + # Mock the model call response + tool_call_mock = MagicMock() + tool_call_mock.name = "function" + tool_call_mock.args = {} + mock_call_llm.return_value = MagicMock( + candidates=[MagicMock(function_calls=[tool_call_mock])] + ) + mock_parse_tool.return_value = ToolCallResponse(tool_calls=[]) + + llm = VertexAILLM(model_name="gemini") + tools = [test_tool] + + res = llm.invoke_with_tools("my text", tools) + mock_call_llm.assert_called_once_with( + "my text", + message_history=None, + system_instruction=None, + tools=tools, + ) + mock_parse_tool.assert_called_once() + assert isinstance(res, ToolCallResponse) + + +@patch("neo4j_graphrag.llm.vertexai_llm.VertexAILLM._get_model") +def test_vertexai_call_llm_with_tools(mock_model: Mock, test_tool: Tool) -> None: + # Mock the generation response + mock_generate_content = mock_model.return_value.generate_content + mock_generate_content.return_value = MagicMock( + spec=GenerationResponse, + ) + + llm = VertexAILLM(model_name="gemini") + tools = [test_tool] + + res = llm._call_llm("my text", tools=tools) + assert isinstance(res, GenerationResponse) + + 
mock_model.assert_called_once_with( + system_instruction=None, + tools=tools, + ) + + +@patch("neo4j_graphrag.llm.vertexai_llm.VertexAILLM._parse_tool_response") +@patch("neo4j_graphrag.llm.vertexai_llm.VertexAILLM._call_llm") +def test_vertexai_ainvoke_with_tools( + mock_call_llm: Mock, mock_parse_tool: Mock, test_tool: Tool, +) -> None: + # Mock the model call response + tool_call_mock = MagicMock() + tool_call_mock.name = "function" + tool_call_mock.args = {} + mock_call_llm.return_value = AsyncMock( + return_value=MagicMock(candidates=[MagicMock(function_calls=[tool_call_mock])]) + ) + mock_parse_tool.return_value = ToolCallResponse(tool_calls=[]) + + llm = VertexAILLM(model_name="gemini") + tools = [test_tool] + + res = llm.invoke_with_tools("my text", tools) + mock_call_llm.assert_called_once_with( + "my text", + message_history=None, + system_instruction=None, + tools=tools, + ) + mock_parse_tool.assert_called_once() + assert isinstance(res, ToolCallResponse) + + +@pytest.mark.asyncio +@patch("neo4j_graphrag.llm.vertexai_llm.VertexAILLM._get_model") +async def test_vertexai_acall_llm_with_tools(mock_model, test_tool: Tool) -> None: + # Mock the generation response + mock_model.return_value = AsyncMock( + generate_content_async=AsyncMock( + return_value=MagicMock( + spec=GenerationResponse, + ) + ) + ) + + llm = VertexAILLM(model_name="gemini") + tools = [test_tool] + + res = await llm._acall_llm("my text", tools=tools) + mock_model.assert_called_once_with( + system_instruction=None, + tools=tools, + ) + assert isinstance(res, GenerationResponse) diff --git a/tests/unit/tool/__init__.py b/tests/unit/tool/__init__.py new file mode 100644 index 000000000..e69de29bb From d409e2d7f49eecc85270bf146e9bdedf177e9aa6 Mon Sep 17 00:00:00 2001 From: estelle Date: Thu, 24 Apr 2025 15:12:52 +0200 Subject: [PATCH 10/15] Ruff --- tests/unit/llm/test_vertexai_llm.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/unit/llm/test_vertexai_llm.py b/tests/unit/llm/test_vertexai_llm.py index f640d585e..2ebfcc68b 100644 --- a/tests/unit/llm/test_vertexai_llm.py +++ b/tests/unit/llm/test_vertexai_llm.py @@ -197,7 +197,9 @@ def test_vertexai_get_llm_tools(test_tool: Tool) -> None: @patch("neo4j_graphrag.llm.vertexai_llm.VertexAILLM._parse_tool_response") @patch("neo4j_graphrag.llm.vertexai_llm.VertexAILLM._call_llm") def test_vertexai_invoke_with_tools( - mock_call_llm: Mock, mock_parse_tool: Mock, test_tool: Tool, + mock_call_llm: Mock, + mock_parse_tool: Mock, + test_tool: Tool, ) -> None: # Mock the model call response tool_call_mock = MagicMock() @@ -245,7 +247,9 @@ def test_vertexai_call_llm_with_tools(mock_model: Mock, test_tool: Tool) -> None @patch("neo4j_graphrag.llm.vertexai_llm.VertexAILLM._parse_tool_response") @patch("neo4j_graphrag.llm.vertexai_llm.VertexAILLM._call_llm") def test_vertexai_ainvoke_with_tools( - mock_call_llm: Mock, mock_parse_tool: Mock, test_tool: Tool, + mock_call_llm: Mock, + mock_parse_tool: Mock, + test_tool: Tool, ) -> None: # Mock the model call response tool_call_mock = MagicMock() From d5a142aeb0520035a1f14f58108913f262b9eff6 Mon Sep 17 00:00:00 2001 From: estelle Date: Thu, 24 Apr 2025 15:23:35 +0200 Subject: [PATCH 11/15] Mypy --- tests/unit/llm/test_vertexai_llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/llm/test_vertexai_llm.py b/tests/unit/llm/test_vertexai_llm.py index 2ebfcc68b..b475efcc5 100644 --- a/tests/unit/llm/test_vertexai_llm.py +++ b/tests/unit/llm/test_vertexai_llm.py @@ -276,7 +276,7 
@@ def test_vertexai_ainvoke_with_tools( @pytest.mark.asyncio @patch("neo4j_graphrag.llm.vertexai_llm.VertexAILLM._get_model") -async def test_vertexai_acall_llm_with_tools(mock_model, test_tool: Tool) -> None: +async def test_vertexai_acall_llm_with_tools(mock_model: Mock, test_tool: Tool) -> None: # Mock the generation response mock_model.return_value = AsyncMock( generate_content_async=AsyncMock( From a982d5b8523513bac53ae421faa33d907cf6baf6 Mon Sep 17 00:00:00 2001 From: estelle Date: Thu, 24 Apr 2025 15:31:41 +0200 Subject: [PATCH 12/15] Update CHANGELOG --- CHANGELOG.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b149f2d44..b226ed893 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,7 @@ ### Added -- Added tool calling functionality to the LLM base class with OpenAI implementation, enabling structured parameter extraction and function calling. +- Added tool calling functionality to the LLM base class with OpenAI and VertexAI implementations, enabling structured parameter extraction and function calling. - Added support for multi-vector collection in Qdrant driver. - Added a `Pipeline.stream` method to stream pipeline progress. - Added a new semantic match resolver to the KG Builder for entity resolution based on spaCy embeddings and cosine similarities so that nodes with similar textual properties get merged. @@ -13,7 +13,7 @@ ### Changed - Improved log output readability in Retrievers and GraphRAG and added embedded vector to retriever result metadata for debugging. -- Switched from pygraphviz to neo4j-viz +- Switched from pygraphviz to neo4j-viz - Renders interactive graph now on HTML instead of PNG - Removed `get_pygraphviz_graph` method From 633c9354db3f53d081cfb09ce6c70bdc51b1f02d Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 28 Apr 2025 12:00:34 +0200 Subject: [PATCH 13/15] Add example --- examples/README.md | 1 + .../customize/llms/vertexai_tool_calls.py | 95 +++++++++++++++++++ 2 files changed, 96 insertions(+) create mode 100644 examples/customize/llms/vertexai_tool_calls.py diff --git a/examples/README.md b/examples/README.md index b1b06f938..7feb71f3a 100644 --- a/examples/README.md +++ b/examples/README.md @@ -79,6 +79,7 @@ are listed in [the last section of this file](#customize). - [System Instruction](./customize/llms/llm_with_system_instructions.py) - [Tool Calling with OpenAI](./customize/llms/openai_tool_calls.py) +- [Tool Calling with VertexAI](./customize/llms/vertexai_tool_calls.py) ### Prompts diff --git a/examples/customize/llms/vertexai_tool_calls.py b/examples/customize/llms/vertexai_tool_calls.py new file mode 100644 index 000000000..e926e0cec --- /dev/null +++ b/examples/customize/llms/vertexai_tool_calls.py @@ -0,0 +1,95 @@ +""" +Example showing how to use VertexAI tool calls with parameter extraction. +Both synchronous and asynchronous examples are provided. 
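+
+To run this example:
+1. Make sure your Google Cloud credentials are configured for Vertex AI
+   (for example via `gcloud auth application-default login`).
+2. Run: python examples/customize/llms/vertexai_tool_calls.py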
+""" + +import asyncio + +from dotenv import load_dotenv +from vertexai.generative_models import GenerationConfig + +from neo4j_graphrag.llm import VertexAILLM +from neo4j_graphrag.llm.types import ToolCallResponse +from neo4j_graphrag.tool import Tool, ObjectParameter, StringParameter, IntegerParameter + +# Load environment variables from .env file +load_dotenv() + + +# Create a custom Tool implementation for person info extraction +parameters = ObjectParameter( + description="Parameters for extracting person information", + properties={ + "name": StringParameter(description="The person's full name"), + "age": IntegerParameter(description="The person's age"), + "occupation": StringParameter(description="The person's occupation"), + }, + required_properties=["name"], + additional_properties=False, +) + + +def run_tool(name: str, age: int, occupation: str) -> str: + """A simple function that summarizes person information from input parameters.""" + return f"Found person {name} with age {age} and occupation {occupation}" + + +person_info_tool = Tool( + name="extract_person_info", + description="Extract information about a person from text", + parameters=parameters, + execute_func=run_tool, +) + +# Create the tool instance +TOOLS = [person_info_tool] + + +def process_tool_call(response: ToolCallResponse) -> str: + """Process the tool call response and return the extracted parameters.""" + if not response.tool_calls: + raise ValueError("No tool calls found in response") + + tool_call = response.tool_calls[0] + print(f"\nTool called: {tool_call.name}") + print(f"Arguments: {tool_call.arguments}") + print(f"Additional content: {response.content or 'None'}") + return person_info_tool.execute(**tool_call.arguments) + + +async def main() -> None: + # Initialize the VertexAI LLM + generation_config = GenerationConfig(temperature=0.0) + llm = VertexAILLM( + model_name="gemini-1.5-flash-001", + generation_config=generation_config, + ) + + # Example text containing information about a person + text = "Stella Hane is a 35-year-old software engineer who loves coding." + + print("\n=== Synchronous Tool Call ===") + # Make a synchronous tool call + sync_response = llm.invoke_with_tools( + input=f"Extract information about the person from this text: {text}", + tools=TOOLS, + ) + sync_result = process_tool_call(sync_response) + print("\n=== Synchronous Tool Call Result ===") + print(sync_result) + + print("\n=== Asynchronous Tool Call ===") + # Make an asynchronous tool call with a different text + text2 = "Molly Hane, 32, works as a data scientist and enjoys machine learning." 
+ async_response = await llm.ainvoke_with_tools( + input=f"Extract information about the person from this text: {text2}", + tools=TOOLS, + ) + async_result = process_tool_call(async_response) + print("\n=== Asynchronous Tool Call Result ===") + print(async_result) + + +if __name__ == "__main__": + # Run the async main function + asyncio.run(main()) From 2a21a509bb5ada74a30af3cb073131c7c9bf4275 Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 28 Apr 2025 12:16:08 +0200 Subject: [PATCH 14/15] mypy --- examples/customize/llms/vertexai_tool_calls.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/customize/llms/vertexai_tool_calls.py b/examples/customize/llms/vertexai_tool_calls.py index e926e0cec..b8b00da5b 100644 --- a/examples/customize/llms/vertexai_tool_calls.py +++ b/examples/customize/llms/vertexai_tool_calls.py @@ -54,7 +54,7 @@ def process_tool_call(response: ToolCallResponse) -> str: print(f"\nTool called: {tool_call.name}") print(f"Arguments: {tool_call.arguments}") print(f"Additional content: {response.content or 'None'}") - return person_info_tool.execute(**tool_call.arguments) + return person_info_tool.execute(**tool_call.arguments) # type: ignore[no-any-return] async def main() -> None: From 5911179e34e051eed7839c3e37986f47a6e9669a Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 28 Apr 2025 13:37:03 +0200 Subject: [PATCH 15/15] Remove mandatory first "query" parameter in the Tool interface --- src/neo4j_graphrag/tool.py | 5 ++--- tests/unit/tool/test_tool.py | 6 +++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/neo4j_graphrag/tool.py b/src/neo4j_graphrag/tool.py index dab14fdaf..905fb663a 100644 --- a/src/neo4j_graphrag/tool.py +++ b/src/neo4j_graphrag/tool.py @@ -253,14 +253,13 @@ def get_parameters(self, exclude: Optional[list[str]] = None) -> Dict[str, Any]: """ return self._parameters.model_dump_tool(exclude) - def execute(self, query: str, **kwargs: Any) -> Any: + def execute(self, **kwargs: Any) -> Any: """Execute the tool with the given query and additional parameters. Args: - query (str): The query or input for the tool to process. **kwargs (Any): Additional parameters for the tool. Returns: Any: The result of the tool execution. """ - return self._execute_func(query, **kwargs) + return self._execute_func(**kwargs) diff --git a/tests/unit/tool/test_tool.py b/tests/unit/tool/test_tool.py index b3b1d5dd5..6c04a1782 100644 --- a/tests/unit/tool/test_tool.py +++ b/tests/unit/tool/test_tool.py @@ -174,7 +174,7 @@ def test_required_parameter() -> None: def test_tool_class() -> None: - def dummy_func(query: str, **kwargs: Any) -> dict[str, Any]: + def dummy_func(**kwargs: Any) -> dict[str, Any]: return kwargs params = ObjectParameter( @@ -190,7 +190,7 @@ def dummy_func(query: str, **kwargs: Any) -> dict[str, Any]: assert tool.get_name() == "mytool" assert tool.get_description() == "desc" assert tool.get_parameters()["type"] == ParameterType.OBJECT - assert tool.execute("query", a="b") == {"a": "b"} + assert tool.execute(query="query", a="b") == {"query": "query", "a": "b"} # Test parameters as dict params_dict = { @@ -205,4 +205,4 @@ def dummy_func(query: str, **kwargs: Any) -> dict[str, Any]: execute_func=dummy_func, ) assert tool2.get_parameters()["type"] == ParameterType.OBJECT - assert tool2.execute("query", a="b") == {"a": "b"} + assert tool2.execute(a="b") == {"a": "b"}
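
Note (not part of the diffs above): taken together, the series gives the
library a provider-agnostic Tool abstraction (patch 02), tool-calling support
in both the OpenAI and VertexAI LLM classes (patches 01 and 05), and an
execute(**kwargs) signature with no mandatory leading "query" parameter
(patch 15). The following minimal sketch shows the resulting API end to end;
the tool name "greet_person", its lambda body, and the input sentence are
illustrative only, and it assumes OPENAI_API_KEY is available in the
environment.

    import os

    from neo4j_graphrag.llm import OpenAILLM
    from neo4j_graphrag.tool import Tool, ObjectParameter, StringParameter

    # Hypothetical tool: after patch 15, execute() receives the
    # LLM-extracted arguments directly as keyword arguments.
    greet_tool = Tool(
        name="greet_person",
        description="Greet a person by name",
        parameters=ObjectParameter(
            description="Greeting parameters",
            properties={"name": StringParameter(description="The person's name")},
            required_properties=["name"],
            additional_properties=False,
        ),
        execute_func=lambda name: f"Hello, {name}!",
    )

    llm = OpenAILLM(
        api_key=os.getenv("OPENAI_API_KEY"),
        model_name="gpt-4o",
        model_params={"temperature": 0},
    )
    response = llm.invoke_with_tools(
        input="Say hi to Ada Lovelace.",
        tools=[greet_tool],
    )
    if response.tool_calls:
        # Feed the extracted arguments straight back into the tool.
        print(greet_tool.execute(**response.tool_calls[0].arguments))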