From 1306f4942dbe5f4dba9aa5db4d0688806db11aa9 Mon Sep 17 00:00:00 2001 From: Oskar Hane Date: Wed, 7 May 2025 17:11:21 +0200 Subject: [PATCH 1/5] Add ToolsRetriever class and convert_retriever_to_tool() fn --- examples/customize/llms/openai_tool_calls.py | 7 +- .../customize/llms/vertexai_tool_calls.py | 7 +- .../tools/retriever_to_tool_example.py | 121 ++++ .../retrieve/tools/tools_retriever_example.py | 350 ++++++++++++ src/neo4j_graphrag/llm/base.py | 2 +- src/neo4j_graphrag/llm/openai_llm.py | 2 +- src/neo4j_graphrag/llm/vertexai_llm.py | 2 +- src/neo4j_graphrag/retrievers/__init__.py | 2 + .../retrievers/tools_retriever.py | 161 ++++++ src/neo4j_graphrag/{ => tools}/tool.py | 19 +- src/neo4j_graphrag/tools/utils.py | 76 +++ tests/unit/llm/conftest.py | 2 +- tests/unit/llm/test_openai_llm.py | 2 +- tests/unit/llm/test_vertexai_llm.py | 2 +- tests/unit/retrievers/test_tools_retriever.py | 262 +++++++++ tests/unit/tool/test_tool.py | 2 +- tests/unit/tool/test_tools_utils.py | 529 ++++++++++++++++++ 17 files changed, 1533 insertions(+), 15 deletions(-) create mode 100644 examples/retrieve/tools/retriever_to_tool_example.py create mode 100644 examples/retrieve/tools/tools_retriever_example.py create mode 100644 src/neo4j_graphrag/retrievers/tools_retriever.py rename src/neo4j_graphrag/{ => tools}/tool.py (95%) create mode 100644 src/neo4j_graphrag/tools/utils.py create mode 100644 tests/unit/retrievers/test_tools_retriever.py create mode 100644 tests/unit/tool/test_tools_utils.py diff --git a/examples/customize/llms/openai_tool_calls.py b/examples/customize/llms/openai_tool_calls.py index 166fb7248..87a14f8df 100644 --- a/examples/customize/llms/openai_tool_calls.py +++ b/examples/customize/llms/openai_tool_calls.py @@ -17,7 +17,12 @@ from neo4j_graphrag.llm import OpenAILLM from neo4j_graphrag.llm.types import ToolCallResponse -from neo4j_graphrag.tool import Tool, ObjectParameter, StringParameter, IntegerParameter +from neo4j_graphrag.tools.tool import ( + Tool, 
+ ObjectParameter, + StringParameter, + IntegerParameter, +) # Load environment variables from .env file (OPENAI_API_KEY required for this example) load_dotenv() diff --git a/examples/customize/llms/vertexai_tool_calls.py b/examples/customize/llms/vertexai_tool_calls.py index ebe9fec22..0d91e1eb3 100644 --- a/examples/customize/llms/vertexai_tool_calls.py +++ b/examples/customize/llms/vertexai_tool_calls.py @@ -11,7 +11,12 @@ from neo4j_graphrag.llm import VertexAILLM from neo4j_graphrag.llm.types import ToolCallResponse -from neo4j_graphrag.tool import Tool, ObjectParameter, StringParameter, IntegerParameter +from neo4j_graphrag.tools.tool import ( + Tool, + ObjectParameter, + StringParameter, + IntegerParameter, +) # Load environment variables from .env file load_dotenv() diff --git a/examples/retrieve/tools/retriever_to_tool_example.py b/examples/retrieve/tools/retriever_to_tool_example.py new file mode 100644 index 000000000..8b8a1cbf6 --- /dev/null +++ b/examples/retrieve/tools/retriever_to_tool_example.py @@ -0,0 +1,121 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Example demonstrating how to convert a retriever to a tool. + +This example shows: +1. How to convert a custom StaticRetriever to a Tool +2. How to define parameters for the tool +3. 
How to execute the tool +""" + +import neo4j +from typing import Optional, Any, cast +from unittest.mock import MagicMock + +from neo4j_graphrag.retrievers.base import Retriever +from neo4j_graphrag.types import RawSearchResult +from neo4j_graphrag.tools.tool import ( + StringParameter, + ObjectParameter, +) +from neo4j_graphrag.tools.utils import convert_retriever_to_tool + + +# Create a Retriever that returns static results about Neo4j +# This would illustrate the conversion process of any Retriever (Vector, Hybrid, etc.) +class StaticRetriever(Retriever): + """A retriever that returns static results about Neo4j.""" + + # Disable Neo4j version verification + VERIFY_NEO4J_VERSION = False + + def __init__(self, driver: neo4j.Driver): + # Call the parent class constructor with the driver + super().__init__(driver) + + def get_search_results( + self, query_text: Optional[str] = None, **kwargs: Any + ) -> RawSearchResult: + """Return static information about Neo4j regardless of the query.""" + # Create formatted Neo4j information + neo4j_info = ( + "# Neo4j Graph Database\n\n" + "Neo4j is a graph database management system developed by Neo4j, Inc. 
" + "It is an ACID-compliant transactional database with native graph storage and processing.\n\n" + "## Key Features:\n\n" + "- **Cypher Query Language**: Neo4j's intuitive query language designed specifically for working with graph data\n" + "- **Property Graphs**: Both nodes and relationships can have properties (key-value pairs)\n" + "- **ACID Compliance**: Ensures data integrity with full transaction support\n" + "- **Native Graph Storage**: Optimized storage for graph data structures\n" + "- **High Availability**: Clustering for enterprise deployments\n" + "- **Scalability**: Handles billions of nodes and relationships" + ) + + # Create a Neo4j record with the information + records = [neo4j.Record({"result": neo4j_info})] + + # Return a RawSearchResult with the records and metadata + return RawSearchResult(records=records, metadata={"query": query_text}) + + +def main() -> None: + # Convert a StaticRetriever to a tool with specific parameters + static_retriever = StaticRetriever(driver=cast(Any, MagicMock())) + + # Define parameters for the static retriever tool + static_parameters = ObjectParameter( + description="Parameters for the Neo4j information retriever", + properties={ + "query_text": StringParameter( + description="The query about Neo4j (any query will return general Neo4j information)", + required=True, + ), + }, + ) + + # Convert the retriever to a tool with specific parameters + static_tool = convert_retriever_to_tool( + retriever=static_retriever, + description="Get general information about Neo4j graph database", + parameters=static_parameters, + name="Neo4jInfoTool", + ) + + # Print tool information + print("Example: StaticRetriever with specific parameters") + print(f"Tool Name: {static_tool.get_name()}") + print(f"Tool Description: {static_tool.get_description()}") + print(f"Tool Parameters: {static_tool.get_parameters()}") + print() + + # Execute the tools (in a real application, this would be done by instructions from an LLM) + try: + # 
Execute the static retriever tool + print("\nExecuting the static retriever tool...") + static_result = static_tool.execute( + query="What is Neo4j?", + ) + print("Static Search Results:") + for i, item in enumerate(static_result): + print(f"{i + 1}. {str(item)[:100]}...") + + except Exception as e: + print(f"Error executing tool: {e}") + + +if __name__ == "__main__": + main() diff --git a/examples/retrieve/tools/tools_retriever_example.py b/examples/retrieve/tools/tools_retriever_example.py new file mode 100644 index 000000000..3309205cf --- /dev/null +++ b/examples/retrieve/tools/tools_retriever_example.py @@ -0,0 +1,350 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Example demonstrating how to use the ToolsRetriever. + +This example shows: +1. How to create tools from different retrievers +2. 
How to use the ToolsRetriever to select and execute tools based on a query +""" + +import os +from typing import Any, Optional, cast +from unittest.mock import MagicMock +from dotenv import load_dotenv +import requests +from datetime import datetime, date + +import neo4j + +from neo4j_graphrag.generation import GraphRAG +from neo4j_graphrag.retrievers.base import Retriever +from neo4j_graphrag.retrievers.tools_retriever import ToolsRetriever +from neo4j_graphrag.types import RawSearchResult +from neo4j_graphrag.tools.tool import ( + ObjectParameter, + StringParameter, + Tool, +) +from neo4j_graphrag.tools.utils import convert_retriever_to_tool +from neo4j_graphrag.llm.openai_llm import OpenAILLM + +# Load environment variables from .env file (OPENAI_API_KEY required for this example) +load_dotenv() + + +# Create a Retriever that returns static results about Neo4j +class Neo4jInfoRetriever(Retriever): + """A retriever that returns general information about Neo4j.""" + + # Disable Neo4j version verification + VERIFY_NEO4J_VERSION = False + + def __init__(self, driver: neo4j.Driver): + # Call the parent class constructor with the driver + super().__init__(driver) + + def get_search_results( + self, query_text: Optional[str] = None, **kwargs: Any + ) -> RawSearchResult: + """Return general information about Neo4j.""" + # Create formatted Neo4j information + neo4j_info = ( + "# Neo4j Graph Database\n\n" + "Neo4j is a graph database management system developed by Neo4j, Inc. 
" + "It is an ACID-compliant transactional database with native graph storage and processing.\n\n" + "## Key Features:\n\n" + "- **Cypher Query Language**: Neo4j's intuitive query language designed specifically for working with graph data\n" + "- **Property Graphs**: Both nodes and relationships can have properties (key-value pairs)\n" + "- **ACID Compliance**: Ensures data integrity with full transaction support\n" + "- **Native Graph Storage**: Optimized storage for graph data structures\n" + "- **High Availability**: Clustering for enterprise deployments\n" + "- **Scalability**: Handles billions of nodes and relationships" + ) + + # Create a Neo4j record with the information + records = [neo4j.Record({"result": neo4j_info})] + + # Return a RawSearchResult with the records and metadata + return RawSearchResult(records=records, metadata={"query": query_text}) + + +class CalendarTool(Tool): + """A simple tool to get calendar information.""" + + def __init__(self) -> None: + """Initialize the calendar tool.""" + # Define parameters for the tool + parameters = ObjectParameter( + description="Parameters for calendar information retrieval", + properties={ + "date": StringParameter( + description="The date to check events for in YYYY-MM-DD format (e.g., 2025-04-14)", + ), + }, + required_properties=["date"], + ) + + # Sample calendar data with fixed dates + self.calendar_data = { + "2025-04-15": [ + {"time": "10:00", "title": "Dentist Appointment"}, + {"time": "14:00", "title": "Conference Call"}, + ], + "2025-04-16": [], + } + + # Define a wrapper function that handles the query parameter correctly + def execute_func(query: str, **kwargs: Any) -> str: + # Ignore the query parameter and call our execute method + return self.execute_calendar(**kwargs) + + super().__init__( + name="calendar_tool", + description="Check calendar events for a specific date in YYYY-MM-DD format", + parameters=parameters, + execute_func=execute_func, + ) + + def execute_calendar(self, 
**kwargs: Any) -> str: + """Execute the calendar tool. + + Args: + **kwargs: Dictionary of parameters, including 'date'. + + Returns: + str: The events for the specified date. + """ + date = kwargs.get("date") + if not date: + return "Error: No date provided" + + # Check for events on the date + if date in self.calendar_data: + events_list = self.calendar_data[date] + if not events_list: + return f"No events scheduled for {date}" + + events_str = "\n".join( + f"- {event.get('time', 'All day')}: {event.get('title', 'Untitled event')}" + for event in events_list + ) + return f"Events for {date}:\n{events_str}" + else: + return f"No events found for {date}" + + +class WeatherTool(Tool): + """A tool to fetch weather in Malmö, Sweden for a specific date.""" + + def __init__(self) -> None: + """Initialize the weather tool.""" + parameters = ObjectParameter( + description="Parameters for fetching weather information about a date.", + properties={ + "date": StringParameter( + description='The date, in YYYY-MM-DD format. Example: "2025-04-25"' + ) + }, + required_properties=["date"], + ) + super().__init__( + name="weather_tool", + description="Check for weather for a specific date in YYYY-MM-DD format", + parameters=parameters, + execute_func=self.execute_weather_retrieval, + ) + + def execute_weather_retrieval( + self, query: Optional[str] = None, **kwargs: Any + ) -> str: + """Fetch historical weather data for a given date in Malmö, Sweden.""" + date_str = kwargs.get("date") + if not date_str: + return "Error: Date not provided for weather lookup." + + try: + input_date = datetime.strptime(date_str, "%Y-%m-%d").date() + except ValueError: + return f"Error: Invalid date format '{date_str}'. Please use YYYY-MM-DD." 
+ + today_date = date.today() + + if input_date < today_date: + api_url = "https://archive-api.open-meteo.com/v1/archive" + else: + # For today or future dates, use the forecast API + # Note: Forecast API typically has a limit (e.g., 16 days into the future) + api_url = "https://api.open-meteo.com/v1/forecast" + + params = { + "latitude": 55.6059, # Malmö, Sweden + "longitude": 13.0007, # Malmö, Sweden + "start_date": date_str, + "end_date": date_str, + "daily": "temperature_2m_max,sunshine_duration", + "timezone": "Europe/Stockholm", + } + headers = {"Accept": "application/json"} + + try: + response = requests.get(api_url, headers=headers, params=params) + response.raise_for_status() + data = response.json() + + # Try to access keys directly, relying on the existing broader except block for errors + daily = data["daily"] + temp_max = daily["temperature_2m_max"][0] + sunshine_seconds = daily["sunshine_duration"][0] + + sunshine_hours = 0 + if ( + sunshine_seconds is not None + ): # API might return null for sunshine_duration + sunshine_hours = round(sunshine_seconds / 3600, 1) + + return ( + f"Weather for Malmö, Sweden on this day:\n" + f"- Max Temperature: {temp_max}°C\n" + f"- Sunshine Duration: {sunshine_hours} hours" + ) + except requests.exceptions.RequestException as e: + return f"API request failed for weather data: {e}" + except ( + ValueError, + KeyError, + ) as e: + return f"Error parsing weather data for Malmö on {date_str}: {e}" + + return ( + f"Sorry, I couldn't fetch the weather for Malmö on {date_str} at this time." 
+ ) + + +def main() -> None: + """Run the example.""" + # Create a mock Neo4j driver + driver = cast(neo4j.Driver, MagicMock()) + + # Create retrievers + neo4j_retriever = Neo4jInfoRetriever(driver=driver) + + # Define parameters for the tools + neo4j_parameters = ObjectParameter( + description="Parameters for Neo4j information retrieval", + properties={ + "query": StringParameter( + description="The query about Neo4j", + ), + }, + required_properties=["query"], + ) + + # Convert retrievers to tools + neo4j_tool = convert_retriever_to_tool( + retriever=neo4j_retriever, + name="neo4j_info_tool", + description="Get information about Neo4j graph database", + parameters=neo4j_parameters, + ) + + # Create a calendar tool + calendar_tool = CalendarTool() + + # Create a weather tool + weather_tool = WeatherTool() + + # Create an OpenAI LLM + llm = OpenAILLM( + api_key=os.getenv("OPENAI_API_KEY"), + model_name="gpt-4o", + model_params={ + "temperature": 0.2, + }, + ) + + # Print tool information for debugging + print("\nTool Information:") + print(f"Neo4j Tool: {neo4j_tool.get_name()}, {neo4j_tool.get_description()}") + print( + f"Calendar Tool: {calendar_tool.get_name()}, {calendar_tool.get_description()}" + ) + parameters_description = ( + weather_tool._parameters.description + if weather_tool._parameters + else "No parameters description" + ) + print( + f"Weather Tool: {weather_tool.get_name()}, {weather_tool.get_description()}: {parameters_description}" + ) + + # Create a ToolsRetriever with the LLM and tools + tools_retriever = ToolsRetriever( + driver=driver, + llm=llm, + tools=[neo4j_tool, calendar_tool, weather_tool], + ) + + # Test queries + test_queries = [ + "What is Neo4j?", + "Do I have any meetings the 15th of April 2025?", + "Any information about 2025-04-16?", + ] + + # Run just the tools retriever directly to show metadata etc. 
+ print(f"\n\n{'=' * 80}") + print("Retriever call examples, to show metadata etc.") + print(f"{'=' * 80}") + for query in test_queries: + print(f"Query: {query}") + + try: + # Get search results through the ToolsRetriever + result = tools_retriever.get_search_results(query_text=query) + + # Print metadata + if result.metadata is not None: + print(f"\nTools selected: {result.metadata.get('tools_selected', [])}") + if result.metadata.get("error", ""): + print(f"Error: {result.metadata.get('error', '')}") + + # Print results + print("\nRESULTS:") + for i, record in enumerate(result.records): + print(f"\n--- Result {i + 1} ---") + print(record) + except Exception as e: + print(f"Error: {str(e)}") + print(f"{'=' * 80}") + + # For demo purposes, run the queries through GraphRAG to get text input -> text output + print(f"\n\n{'=' * 80}") + print("Full GraphRAG examples") + print(f"{'=' * 80}") + for query in test_queries: + print(f"Query: {query}") + # Full GraphRAG example + graphrag = GraphRAG( + llm=llm, + retriever=tools_retriever, + ) + rag_result = graphrag.search(query_text=query, return_context=False) + print(f"Answer: {rag_result.answer}") + print(f"{'=' * 80}") + + +if __name__ == "__main__": + main() diff --git a/src/neo4j_graphrag/llm/base.py b/src/neo4j_graphrag/llm/base.py index 87d281794..d634ce085 100644 --- a/src/neo4j_graphrag/llm/base.py +++ b/src/neo4j_graphrag/llm/base.py @@ -22,7 +22,7 @@ from .types import LLMResponse, ToolCallResponse -from neo4j_graphrag.tool import Tool +from neo4j_graphrag.tools.tool import Tool class LLMInterface(ABC): diff --git a/src/neo4j_graphrag/llm/openai_llm.py b/src/neo4j_graphrag/llm/openai_llm.py index 1e0228e45..563944842 100644 --- a/src/neo4j_graphrag/llm/openai_llm.py +++ b/src/neo4j_graphrag/llm/openai_llm.py @@ -49,7 +49,7 @@ UserMessage, ) -from neo4j_graphrag.tool import Tool +from neo4j_graphrag.tools.tool import Tool if TYPE_CHECKING: import openai diff --git a/src/neo4j_graphrag/llm/vertexai_llm.py 
b/src/neo4j_graphrag/llm/vertexai_llm.py index 39d483915..513c6275d 100644 --- a/src/neo4j_graphrag/llm/vertexai_llm.py +++ b/src/neo4j_graphrag/llm/vertexai_llm.py @@ -27,7 +27,7 @@ ToolCallResponse, ) from neo4j_graphrag.message_history import MessageHistory -from neo4j_graphrag.tool import Tool +from neo4j_graphrag.tools.tool import Tool from neo4j_graphrag.types import LLMMessage try: diff --git a/src/neo4j_graphrag/retrievers/__init__.py b/src/neo4j_graphrag/retrievers/__init__.py index 595eac93b..061679d93 100644 --- a/src/neo4j_graphrag/retrievers/__init__.py +++ b/src/neo4j_graphrag/retrievers/__init__.py @@ -15,6 +15,7 @@ from .hybrid import HybridCypherRetriever, HybridRetriever from .text2cypher import Text2CypherRetriever +from .tools_retriever import ToolsRetriever from .vector import VectorCypherRetriever, VectorRetriever __all__ = [ @@ -23,6 +24,7 @@ "HybridRetriever", "HybridCypherRetriever", "Text2CypherRetriever", + "ToolsRetriever", ] diff --git a/src/neo4j_graphrag/retrievers/tools_retriever.py b/src/neo4j_graphrag/retrievers/tools_retriever.py new file mode 100644 index 000000000..7f125df91 --- /dev/null +++ b/src/neo4j_graphrag/retrievers/tools_retriever.py @@ -0,0 +1,161 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations + +from typing import Any, List, Optional, Sequence + +import neo4j + +from neo4j_graphrag.llm.base import LLMInterface +from neo4j_graphrag.retrievers.base import Retriever +from neo4j_graphrag.types import RawSearchResult +from neo4j_graphrag.tools.tool import Tool +from neo4j_graphrag.types import LLMMessage + + +class ToolsRetriever(Retriever): + """A retriever that uses an LLM to select appropriate tools for retrieval based on user input. + + This retriever takes an LLM instance and a list of Tool objects as input. When a search is performed, + it uses the LLM to analyze the query and determine which tools (if any) should be used to retrieve + the necessary data. It then executes the selected tools and returns the combined results. + + Args: + driver (neo4j.Driver): Neo4j driver instance. + llm (LLMInterface): LLM instance used to select tools. + tools (Sequence[Tool]): List of tools available for selection. + neo4j_database (Optional[str], optional): Neo4j database name. Defaults to None. + system_instruction (Optional[str], optional): Custom system instruction for the LLM. Defaults to None. 
+ """ + + # Disable Neo4j version verification since this retriever doesn't directly interact with Neo4j + VERIFY_NEO4J_VERSION = False + + def __init__( + self, + driver: neo4j.Driver, + llm: LLMInterface, + tools: Sequence[Tool], + neo4j_database: Optional[str] = None, + system_instruction: Optional[str] = None, + ): + """Initialize the ToolsRetriever with an LLM and a list of tools.""" + super().__init__(driver, neo4j_database) + self.llm = llm + self._tools = list(tools) # Make a copy to allow modification + self.system_instruction = ( + system_instruction or self._get_default_system_instruction() + ) + + def _get_default_system_instruction(self) -> str: + """Get the default system instruction for the LLM.""" + return ( + "You are an assistant that helps select the most appropriate tools to retrieve information " + "based on the user's query. Analyze the query carefully and determine which tools, if any, " + "would be most helpful in retrieving the relevant information. You can select multiple tools " + "if necessary, or none if no tools are appropriate for the query." + ) + + def get_search_results( + self, + query_text: str, + message_history: Optional[List[LLMMessage]] = None, + **kwargs: Any, + ) -> RawSearchResult: + """Use the LLM to select and execute appropriate tools based on the query. + + Args: + query_text (str): The user's query text. + message_history (Optional[Union[List[LLMMessage], MessageHistory]], optional): + Previous conversation history. Defaults to None. + **kwargs (Any): Additional arguments passed to the tool execution. + + Returns: + RawSearchResult: The combined results from the executed tools. 
+ """ + if not self._tools: + # No tools available, return empty result + return RawSearchResult( + records=[], + metadata={"query": query_text, "error": "No tools available"}, + ) + + try: + # Use the LLM to select appropriate tools + tool_call_response = self.llm.invoke_with_tools( + input=query_text, + tools=self._tools, + message_history=message_history, + system_instruction=self.system_instruction, + ) + # If no tool calls were made, return empty result + if not tool_call_response.tool_calls: + return RawSearchResult( + records=[], + metadata={ + "query": query_text, + "llm_response": tool_call_response.content, + "tools_selected": [], + }, + ) + + # Execute each selected tool and collect results + all_records = [] + tools_selected = [] + + for tool_call in tool_call_response.tool_calls: + tool_name = tool_call.name + tools_selected.append(tool_name) + + # Find the tool by name + selected_tool = next( + (tool for tool in self._tools if tool.get_name() == tool_name), None + ) + if selected_tool is not None: + # Extract arguments from the tool call + tool_args = tool_call.arguments or {} + + # Always include the query_text in the arguments for tools that might need it + tool_args.setdefault("query", query_text) + + # Execute the tool with the provided arguments + tool_result = selected_tool.execute(**tool_args) + # If the tool result is a RawSearchResult, extract its records + if hasattr(tool_result, "records"): + all_records.extend(tool_result.records) + else: + # Create a record from the tool result + record = neo4j.Record({"result": tool_result}) + all_records.append(record) + + # Combine metadata from all tool calls + combined_metadata = { + "query": query_text, + "llm_response": tool_call_response.content, + "tools_selected": tools_selected, + } + + return RawSearchResult(records=all_records, metadata=combined_metadata) + + except Exception as e: + # Handle any errors during tool selection or execution + return RawSearchResult( + records=[], + metadata={ + 
"query": query_text, + "error": str(e), + "error_type": type(e).__name__, + }, + ) diff --git a/src/neo4j_graphrag/tool.py b/src/neo4j_graphrag/tools/tool.py similarity index 95% rename from src/neo4j_graphrag/tool.py rename to src/neo4j_graphrag/tools/tool.py index c07ffe082..a83802bf5 100644 --- a/src/neo4j_graphrag/tool.py +++ b/src/neo4j_graphrag/tools/tool.py @@ -211,23 +211,28 @@ def validate_properties(self) -> "ObjectParameter": class Tool(ABC): """Abstract base class defining the interface for all tools in the neo4j-graphrag library.""" + _name: str + _description: str + _parameters: Optional[ObjectParameter] + _execute_func: Callable[..., Any] + def __init__( self, name: str, description: str, - parameters: Union[ObjectParameter, Dict[str, Any]], execute_func: Callable[..., Any], + parameters: Optional[Union[ObjectParameter, Dict[str, Any]]] = None, ): self._name = name self._description = description + self._execute_func = execute_func - # Allow parameters to be provided as a dictionary if isinstance(parameters, dict): self._parameters = ObjectParameter.model_validate(parameters) - else: + elif isinstance(parameters, ObjectParameter): self._parameters = parameters - - self._execute_func = execute_func + else: + self._parameters = None def get_name(self) -> str: """Get the name of the tool. @@ -251,7 +256,9 @@ def get_parameters(self, exclude: Optional[list[str]] = None) -> Dict[str, Any]: Returns: Dict[str, Any]: Dictionary containing parameter schema information. """ - return self._parameters.model_dump_tool(exclude) + if self._parameters: + return self._parameters.model_dump_tool(exclude) + return {} def execute(self, **kwargs: Any) -> Any: """Execute the tool with the given query and additional parameters. 
diff --git a/src/neo4j_graphrag/tools/utils.py b/src/neo4j_graphrag/tools/utils.py new file mode 100644 index 000000000..0df86b0cd --- /dev/null +++ b/src/neo4j_graphrag/tools/utils.py @@ -0,0 +1,76 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional, Union + +from neo4j_graphrag.tools.tool import Tool, ObjectParameter + + +def convert_retriever_to_tool( + retriever: Any, + description: Optional[str] = None, + parameters: Optional[Union[ObjectParameter, Dict[str, Any]]] = None, + name: Optional[str] = None, +) -> Tool: + """Convert a retriever instance to a Tool object. + + Args: + retriever (Any): The retriever instance to convert. + description (Optional[str]): Custom description for the tool. If not provided, + an attempt will be made to infer it from the retriever or a generic description will be used. + parameters (Optional[Union[ObjectParameter, Dict[str, ToolParameter]]]): Custom parameters for the tool. + If not provided, no parameters will be included in the tool. + name (Optional[str]): Custom name for the tool. If not provided, + an attempt will be made to infer it from the retriever or a default name will be used. + + Returns: + RetrieverTool: A Tool object configured to use the retriever's search functionality. 
+ """ + # Use provided name or infer it from the retriever + if name is None: + name = getattr(retriever, "name", None) or getattr( + retriever.__class__, "__name__", "UnnamedRetrieverTool" + ) + + # Infer description if not provided + if description is None: + description = ( + getattr(retriever, "description", None) + or f"A tool for retrieving data using {name}." + ) + + # Parameters can be None + + # Define a function that matches the Callable[[str, ...], Any] signature + def execute_func(**kwargs: Any) -> Any: + # The retriever's get_search_results method is expected to handle + # arguments like query_text, top_k, etc., passed as keyword arguments. + # The Tool's 'parameters' definition (e.g., ObjectParameter) ensures + # that these arguments are provided in kwargs when Tool.execute is called. + return retriever.get_search_results(**kwargs) + + # Ensure name is a string + tool_name = str(name) if name is not None else "UnnamedRetrieverTool" + + # Create a Tool object from the retriever + + # Pass parameters directly to the Tool constructor + # If parameters is None, the Tool class will handle it appropriately + return Tool( + name=tool_name, + description=description, + execute_func=execute_func, + parameters=parameters, + ) diff --git a/tests/unit/llm/conftest.py b/tests/unit/llm/conftest.py index 269efadec..9fc776120 100644 --- a/tests/unit/llm/conftest.py +++ b/tests/unit/llm/conftest.py @@ -1,6 +1,6 @@ import pytest -from neo4j_graphrag.tool import Tool, ObjectParameter, StringParameter +from neo4j_graphrag.tools.tool import Tool, ObjectParameter, StringParameter class TestTool(Tool): diff --git a/tests/unit/llm/test_openai_llm.py b/tests/unit/llm/test_openai_llm.py index 3c5ee1b9e..55a4f7824 100644 --- a/tests/unit/llm/test_openai_llm.py +++ b/tests/unit/llm/test_openai_llm.py @@ -20,7 +20,7 @@ from neo4j_graphrag.llm import LLMResponse from neo4j_graphrag.llm.openai_llm import AzureOpenAILLM, OpenAILLM from neo4j_graphrag.llm.types import 
ToolCallResponse -from neo4j_graphrag.tool import Tool +from neo4j_graphrag.tools.tool import Tool def get_mock_openai() -> MagicMock: diff --git a/tests/unit/llm/test_vertexai_llm.py b/tests/unit/llm/test_vertexai_llm.py index c937d2cb8..d914e4cb4 100644 --- a/tests/unit/llm/test_vertexai_llm.py +++ b/tests/unit/llm/test_vertexai_llm.py @@ -20,7 +20,7 @@ from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.types import ToolCallResponse from neo4j_graphrag.llm.vertexai_llm import VertexAILLM -from neo4j_graphrag.tool import Tool +from neo4j_graphrag.tools.tool import Tool from neo4j_graphrag.types import LLMMessage from vertexai.generative_models import ( Content, diff --git a/tests/unit/retrievers/test_tools_retriever.py b/tests/unit/retrievers/test_tools_retriever.py new file mode 100644 index 000000000..a5333aedc --- /dev/null +++ b/tests/unit/retrievers/test_tools_retriever.py @@ -0,0 +1,262 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Standard library imports +from typing import Any, List, cast +from unittest.mock import MagicMock + +import neo4j + +# Local imports +from neo4j_graphrag.llm.base import LLMInterface +from neo4j_graphrag.llm.types import ToolCall, ToolCallResponse +from neo4j_graphrag.retrievers.tools_retriever import ToolsRetriever +from neo4j_graphrag.tools.tool import Tool + + +# Mock dependencies +def create_mock_driver() -> neo4j.Driver: + driver = MagicMock(spec=neo4j.Driver) + # Create a mock result object with a records attribute + mock_result = MagicMock() + mock_result.records = [MagicMock()] + driver.execute_query.return_value = mock_result + return cast(neo4j.Driver, driver) + + +def create_mock_llm() -> Any: + llm = MagicMock(spec=LLMInterface) + return llm + + +def create_mock_tool(name: str = "MockTool", description: str = "A mock tool") -> Any: + tool = MagicMock(spec=Tool) + cast(Any, tool.get_name).return_value = name + cast(Any, tool.get_description).return_value = description + cast(Any, tool.get_parameters).return_value = { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The query to search for", + } + }, + } + # Mock the execute method to return a dictionary with records and metadata + cast(Any, tool.execute).return_value = { + "records": [neo4j.Record({"result": f"Result from {name}"})], + "metadata": {"source": name}, + } + return tool + + +class TestToolsRetriever: + """Test the ToolsRetriever class.""" + + def test_initialization(self) -> None: + """Test that the ToolsRetriever initializes correctly.""" + driver = create_mock_driver() + llm = create_mock_llm() + tools = [create_mock_tool("Tool1"), create_mock_tool("Tool2")] + + retriever = ToolsRetriever(driver=driver, llm=llm, tools=tools) + + assert retriever.llm == llm + assert len(retriever._tools) == 2 + assert retriever._tools[0].get_name() == "Tool1" + assert retriever._tools[1].get_name() == "Tool2" + + def test_get_search_results_no_tools(self) -> 
None: + """Test that get_search_results returns an empty result when no tools are available.""" + driver = create_mock_driver() + llm = create_mock_llm() + tools: List[Tool] = [] + + retriever = ToolsRetriever(driver=driver, llm=llm, tools=tools) + result = retriever.get_search_results(query_text="Test query") + + assert result.records == [] + assert result.metadata is not None + assert result.metadata.get("query") == "Test query" + assert "error" in result.metadata + assert result.metadata.get("error") == "No tools available" + + def test_get_search_results_no_tool_calls(self) -> None: + """Test that get_search_results returns an empty result when the LLM doesn't select any tools.""" + driver = create_mock_driver() + llm = create_mock_llm() + tools = [create_mock_tool("Tool1"), create_mock_tool("Tool2")] + + # Mock the LLM to return a response with no tool calls + cast(Any, llm.invoke_with_tools).return_value = ToolCallResponse( + content="I don't need any tools for this query.", + tool_calls=[], + ) + + retriever = ToolsRetriever(driver=driver, llm=llm, tools=tools) + result = retriever.get_search_results(query_text="Test query") + + assert result.records == [] + assert result.metadata is not None + assert result.metadata.get("query") == "Test query" + assert ( + result.metadata.get("llm_response") + == "I don't need any tools for this query." 
+ ) + assert result.metadata.get("tools_selected") == [] + + def test_get_search_results_with_tool_calls(self) -> None: + """Test that get_search_results correctly executes selected tools and returns their results.""" + driver = create_mock_driver() + llm = create_mock_llm() + tool1 = create_mock_tool("Tool1") + tool2 = create_mock_tool("Tool2") + tools = [tool1, tool2] + + # Mock the LLM to return a response with tool calls + cast(Any, llm.invoke_with_tools).return_value = ToolCallResponse( + content="I'll use Tool1 for this query.", + tool_calls=[ + ToolCall( + name="Tool1", + arguments={"query": "Test query"}, + ) + ], + ) + + # Mock the tool execution to return a simple string value + # This is processed by the ToolsRetriever and converted to a neo4j.Record + cast(Any, tool1).execute.return_value = "Result from Tool1" + + retriever = ToolsRetriever(driver=driver, llm=llm, tools=tools) + result = retriever.get_search_results(query_text="Test query") + + # Check that the LLM was called with the right arguments + cast(Any, llm.invoke_with_tools).assert_called_once_with( + input="Test query", + tools=tools, + message_history=None, + system_instruction=retriever.system_instruction, + ) + + # Check that the tool was executed with the right arguments + tool1.execute.assert_called_once_with(query="Test query") + + # Check that the result contains the expected records and metadata + assert len(result.records) == 1 + # The record is a neo4j.Record object + assert isinstance(result.records[0], neo4j.Record) + # Access the result directly using index 0 + assert result.records[0][0] == "Result from Tool1" + assert result.metadata is not None + assert result.metadata.get("query") == "Test query" + assert result.metadata.get("llm_response") == "I'll use Tool1 for this query." 
+ assert result.metadata.get("tools_selected") == ["Tool1"] + + def test_get_search_results_with_multiple_tool_calls(self) -> None: + """Test that get_search_results correctly executes multiple selected tools and combines their results.""" + driver = create_mock_driver() + llm = create_mock_llm() + tool1 = create_mock_tool("Tool1") + tool2 = create_mock_tool("Tool2") + tools = [tool1, tool2] + + # Mock the LLM to return a response with multiple tool calls + cast(Any, llm.invoke_with_tools).return_value = ToolCallResponse( + content="I'll use both Tool1 and Tool2 for this query.", + tool_calls=[ + ToolCall( + name="Tool1", + arguments={"query": "Test query part 1"}, + ), + ToolCall( + name="Tool2", + arguments={"query": "Test query part 2"}, + ), + ], + ) + + # Mock the tool executions to return specific records + tool1_record = neo4j.Record({"result": "Result from Tool1"}) + cast(Any, tool1.execute).return_value = { + "records": [tool1_record], + "metadata": {"source": "Tool1"}, + } + + tool2_record = neo4j.Record({"result": "Result from Tool2"}) + cast(Any, tool2.execute).return_value = { + "records": [tool2_record], + "metadata": {"source": "Tool2"}, + } + + retriever = ToolsRetriever(driver=driver, llm=llm, tools=tools) + result = retriever.get_search_results(query_text="Test query") + + # Check that both tools were executed with the right arguments + cast(Any, tool1.execute).assert_called_once_with(query="Test query part 1") + cast(Any, tool2.execute).assert_called_once_with(query="Test query part 2") + + # Check that the result contains the expected records and metadata + assert len(result.records) == 2 + assert result.metadata is not None + assert result.metadata.get("query") == "Test query" + assert ( + result.metadata.get("llm_response") + == "I'll use both Tool1 and Tool2 for this query." 
+ ) + assert result.metadata.get("tools_selected") == ["Tool1", "Tool2"] + + def test_get_search_results_with_error(self) -> None: + """Test that get_search_results handles errors during tool execution.""" + driver = create_mock_driver() + llm = create_mock_llm() + tool = create_mock_tool("Tool1") + tools = [tool] + + # Mock the LLM to raise an exception + cast(Any, llm.invoke_with_tools).side_effect = Exception("LLM error") + + retriever = ToolsRetriever(driver=driver, llm=llm, tools=tools) + result = retriever.get_search_results(query_text="Test query") + + # Check that the result contains the error information + assert result.records == [] + assert result.metadata is not None + assert result.metadata.get("query") == "Test query" + assert result.metadata.get("error") == "LLM error" + assert result.metadata.get("error_type") == "Exception" + + def test_custom_system_instruction(self) -> None: + """Test that a custom system instruction is used when provided.""" + driver = create_mock_driver() + llm = create_mock_llm() + tools = [create_mock_tool("Tool1")] + custom_instruction = "This is a custom system instruction." 
+ + retriever = ToolsRetriever( + driver=driver, llm=llm, tools=tools, system_instruction=custom_instruction + ) + + assert retriever.system_instruction == custom_instruction + + # Test that the custom instruction is passed to the LLM + retriever.get_search_results(query_text="Test query") + + llm.invoke_with_tools.assert_called_once_with( + input="Test query", + tools=tools, + message_history=None, + system_instruction=custom_instruction, + ) diff --git a/tests/unit/tool/test_tool.py b/tests/unit/tool/test_tool.py index 6c04a1782..e1af50210 100644 --- a/tests/unit/tool/test_tool.py +++ b/tests/unit/tool/test_tool.py @@ -1,6 +1,6 @@ import pytest from typing import Any -from neo4j_graphrag.tool import ( +from neo4j_graphrag.tools.tool import ( StringParameter, IntegerParameter, NumberParameter, diff --git a/tests/unit/tool/test_tools_utils.py b/tests/unit/tool/test_tools_utils.py new file mode 100644 index 000000000..6926c5105 --- /dev/null +++ b/tests/unit/tool/test_tools_utils.py @@ -0,0 +1,529 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from unittest.mock import MagicMock, patch +import neo4j +from neo4j_graphrag.embeddings.base import Embedder +from neo4j_graphrag.llm.base import LLMInterface +from neo4j_graphrag.retrievers import ( + HybridCypherRetriever, + HybridRetriever, + Text2CypherRetriever, + VectorCypherRetriever, + VectorRetriever, +) +from neo4j_graphrag.tools.tool import ( + Tool, + ObjectParameter, + StringParameter, + IntegerParameter, +) +from neo4j_graphrag.tools.utils import convert_retriever_to_tool + + +# Mock dependencies for retriever instances +def create_mock_driver() -> neo4j.Driver: + driver = MagicMock(spec=neo4j.Driver) + # Create a mock result object with a records attribute + mock_result = MagicMock() + mock_result.records = [MagicMock()] + driver.execute_query.return_value = mock_result + return driver + + +def create_mock_embedder() -> Embedder: + embedder = MagicMock(spec=Embedder) + embedder.embed_query.return_value = [0.1, 0.2, 0.3] + return embedder + + +def create_mock_llm() -> LLMInterface: + llm = MagicMock() + llm.invoke.return_value = "MATCH (n) RETURN n" + return llm + + +# Test conversion with VectorRetriever +@patch("neo4j_graphrag.retrievers.base.get_version") +def test_convert_vector_retriever_to_tool(mock_get_version: MagicMock) -> None: + """Test conversion of VectorRetriever to a Tool instance with correct attributes.""" + mock_get_version.return_value = ((5, 20, 0), False, False) + driver = create_mock_driver() + embedder = create_mock_embedder() + retriever = VectorRetriever( + driver=driver, + index_name="test_index", + embedder=embedder, + return_properties=["name", "description"], + ) + parameters = ObjectParameter( + description="Parameters for vector search", + properties={ + "query_text": StringParameter( + description="The query text for vector search.", + required=True, + ), + "top_k": IntegerParameter( + description="Number of results to return.", + required=False, + ), + }, + ) + tool = convert_retriever_to_tool( + retriever, + 
description="A tool for vector-based retrieval from Neo4j.", + parameters=parameters, + ) + assert isinstance(tool, Tool) + assert tool.get_name() in ["VectorRetriever", "UnnamedRetrieverTool"] + assert tool.get_description() == "A tool for vector-based retrieval from Neo4j." + # Check that the parameters object has the expected properties + params = tool.get_parameters() + assert "properties" in params + assert len(params["properties"]) == 2 + + +# Test conversion with VectorCypherRetriever +@patch("neo4j_graphrag.retrievers.base.get_version") +def test_convert_vector_cypher_retriever_to_tool(mock_get_version: MagicMock) -> None: + """Test conversion of VectorCypherRetriever to a Tool instance with correct attributes.""" + mock_get_version.return_value = ((5, 20, 0), False, False) + driver = create_mock_driver() + embedder = create_mock_embedder() + retriever = VectorCypherRetriever( + driver=driver, + index_name="test_index", + embedder=embedder, + retrieval_query="RETURN n", + ) + parameters = ObjectParameter( + description="Parameters for vector-cypher search", + properties={ + "query_text": StringParameter( + description="The query text for vector-cypher search.", + required=True, + ), + "top_k": IntegerParameter( + description="Number of results to return.", + required=False, + ), + }, + ) + tool = convert_retriever_to_tool( + retriever, + description="A tool for vector-cypher retrieval from Neo4j.", + parameters=parameters, + ) + assert isinstance(tool, Tool) + assert tool.get_name() in ["VectorCypherRetriever", "UnnamedRetrieverTool"] + assert tool.get_description() == "A tool for vector-cypher retrieval from Neo4j." 
+ # Check that the parameters object has the expected properties + params = tool.get_parameters() + assert "properties" in params + assert len(params["properties"]) == 2 + + +# Test conversion with HybridRetriever +@patch("neo4j_graphrag.retrievers.base.get_version") +def test_convert_hybrid_retriever_to_tool(mock_get_version: MagicMock) -> None: + """Test conversion of HybridRetriever to a Tool instance with correct attributes.""" + mock_get_version.return_value = ((5, 20, 0), False, False) + driver = create_mock_driver() + embedder = create_mock_embedder() + retriever = HybridRetriever( + driver=driver, + vector_index_name="test_vector_index", + fulltext_index_name="test_fulltext_index", + embedder=embedder, + return_properties=["name", "description"], + ) + parameters = ObjectParameter( + description="Parameters for hybrid search", + properties={ + "query_text": StringParameter( + description="The query text for hybrid search.", + required=True, + ), + "top_k": IntegerParameter( + description="Number of results to return.", + required=False, + ), + }, + ) + tool = convert_retriever_to_tool( + retriever, + description="A tool for hybrid retrieval from Neo4j.", + parameters=parameters, + ) + assert isinstance(tool, Tool) + assert tool.get_name() in ["HybridRetriever", "UnnamedRetrieverTool"] + assert tool.get_description() == "A tool for hybrid retrieval from Neo4j." 
+ # Check that the parameters object has the expected properties + params = tool.get_parameters() + assert "properties" in params + assert len(params["properties"]) == 2 + + +# Test conversion with HybridCypherRetriever +@patch("neo4j_graphrag.retrievers.base.get_version") +def test_convert_hybrid_cypher_retriever_to_tool(mock_get_version: MagicMock) -> None: + """Test conversion of HybridCypherRetriever to a Tool instance with correct attributes.""" + mock_get_version.return_value = ((5, 20, 0), False, False) + driver = create_mock_driver() + embedder = create_mock_embedder() + retriever = HybridCypherRetriever( + driver=driver, + vector_index_name="test_vector_index", + fulltext_index_name="test_fulltext_index", + embedder=embedder, + retrieval_query="RETURN n", + ) + parameters = ObjectParameter( + description="Parameters for hybrid-cypher search", + properties={ + "query_text": StringParameter( + description="The query text for hybrid-cypher search.", + required=True, + ), + "top_k": IntegerParameter( + description="Number of results to return.", + required=False, + ), + }, + ) + tool = convert_retriever_to_tool( + retriever, + description="A tool for hybrid-cypher retrieval from Neo4j.", + parameters=parameters, + ) + assert isinstance(tool, Tool) + assert tool.get_name() in ["HybridCypherRetriever", "UnnamedRetrieverTool"] + assert tool.get_description() == "A tool for hybrid-cypher retrieval from Neo4j." 
+ # Check that the parameters object has the expected properties + params = tool.get_parameters() + assert "properties" in params + assert len(params["properties"]) == 2 + + +# Test conversion with Text2CypherRetriever +@patch("neo4j_graphrag.retrievers.base.get_version") +def test_convert_text2cypher_retriever_to_tool(mock_get_version: MagicMock) -> None: + """Test conversion of Text2CypherRetriever to a Tool instance with correct attributes.""" + mock_get_version.return_value = ((5, 20, 0), False, False) + driver = create_mock_driver() + llm = create_mock_llm() + retriever = Text2CypherRetriever(driver=driver, llm=llm) + parameters = ObjectParameter( + description="Parameters for text to Cypher conversion", + properties={ + "query_text": StringParameter( + description="The query text for text to Cypher conversion.", + required=True, + ), + }, + ) + tool = convert_retriever_to_tool( + retriever, + description="A tool for text to Cypher retrieval from Neo4j.", + parameters=parameters, + ) + assert isinstance(tool, Tool) + assert tool.get_name() in ["Text2CypherRetriever", "UnnamedRetrieverTool"] + assert tool.get_description() == "A tool for text to Cypher retrieval from Neo4j." 
+ # Check that the parameters object has the expected properties + params = tool.get_parameters() + assert "properties" in params + assert len(params["properties"]) == 1 + + +# Test conversion with custom name provided +@patch("neo4j_graphrag.retrievers.base.get_version") +def test_convert_retriever_with_custom_name( + mock_get_version: MagicMock, +) -> None: + """Test conversion of a retriever to a Tool instance with a custom name.""" + mock_get_version.return_value = ((5, 20, 0), False, False) + driver = create_mock_driver() + embedder = create_mock_embedder() + retriever = VectorRetriever( + driver=driver, + index_name="test_index", + embedder=embedder, + return_properties=["name", "description"], + ) + + custom_name = "CustomNamedTool" + parameters = ObjectParameter( + description="Parameters for vector search", + properties={ + "query_text": StringParameter( + description="The query text for vector search.", + required=True, + ), + }, + ) + + tool = convert_retriever_to_tool( + retriever, + description="A tool with a custom name", + parameters=parameters, + name=custom_name, + ) + + # Verify that the custom name is used instead of the retriever class name + assert tool.get_name() == custom_name + assert tool.get_name() != "VectorRetriever" + assert tool.get_name() != "UnnamedRetrieverTool" + + +# Test conversion with no parameters provided +@patch("neo4j_graphrag.retrievers.base.get_version") +def test_convert_vector_retriever_to_tool_no_parameters( + mock_get_version: MagicMock, +) -> None: + """Test conversion of VectorRetriever to a Tool instance when no parameters are provided.""" + mock_get_version.return_value = ((5, 20, 0), False, False) + driver = create_mock_driver() + embedder = create_mock_embedder() + retriever = VectorRetriever( + driver=driver, + index_name="test_index", + embedder=embedder, + return_properties=["name", "description"], + ) + tool = convert_retriever_to_tool( + retriever, description="A tool for vector-based retrieval from Neo4j." 
+ ) + assert isinstance(tool, Tool) + assert tool.get_name() in ["VectorRetriever", "UnnamedRetrieverTool"] + assert tool.get_description() == "A tool for vector-based retrieval from Neo4j." + # Since we don't provide parameters, it should be None + assert tool._parameters is None + + +# Test tool execution for VectorRetriever +@patch("neo4j_graphrag.retrievers.base.get_version") +def test_vector_retriever_tool_execution(mock_get_version: MagicMock) -> None: + """Test execution of VectorRetriever tool calls the search method with correct arguments.""" + mock_get_version.return_value = ((5, 20, 0), False, False) + driver = create_mock_driver() + embedder = create_mock_embedder() + retriever = VectorRetriever( + driver=driver, + index_name="test_index", + embedder=embedder, + return_properties=["name", "description"], + ) + parameters = ObjectParameter( + description="Parameters for vector search", + properties={ + "query_text": StringParameter( + description="The query text for vector search.", + required=True, + ), + "top_k": IntegerParameter( + description="Number of results to return.", + required=False, + ), + }, + ) + # Mock the get_search_results method to track calls + get_search_results_mock = MagicMock(return_value=([], None)) + # Use patch to mock the method + with patch.object(retriever, "get_search_results", get_search_results_mock): + tool = convert_retriever_to_tool( + retriever, + description="A tool for vector-based retrieval from Neo4j.", + parameters=parameters, + ) + tools = {tool.get_name(): tool} + # Simulate indirect invocation as would happen in real usage + tool_call_arguments = {"query_text": "test query", "top_k": 5} + # Pass the arguments as kwargs + result = tools[tool.get_name()].execute(**tool_call_arguments) + + # Since we're using a context manager for patching, we need to verify the call inside the context + # We can only check the result, not the method call itself + assert result == ([], None) + + +# Test tool execution for 
HybridRetriever +@patch("neo4j_graphrag.retrievers.base.get_version") +def test_hybrid_retriever_tool_execution(mock_get_version: MagicMock) -> None: + """Test execution of HybridRetriever tool calls the search method with correct arguments.""" + mock_get_version.return_value = ((5, 20, 0), False, False) + driver = create_mock_driver() + embedder = create_mock_embedder() + retriever = HybridRetriever( + driver=driver, + vector_index_name="test_vector_index", + fulltext_index_name="test_fulltext_index", + embedder=embedder, + return_properties=["name", "description"], + ) + parameters = ObjectParameter( + description="Parameters for hybrid search", + properties={ + "query_text": StringParameter( + description="The query text for hybrid search.", + required=True, + ), + "top_k": IntegerParameter( + description="Number of results to return.", + required=False, + ), + }, + ) + # Mock the get_search_results method to track calls + get_search_results_mock = MagicMock(return_value=([], None)) + # Use patch to mock the method + with patch.object(retriever, "get_search_results", get_search_results_mock): + tool = convert_retriever_to_tool( + retriever, + description="A tool for hybrid retrieval from Neo4j.", + parameters=parameters, + ) + tools = {tool.get_name(): tool} + # Simulate indirect invocation as would happen in real usage + tool_call_arguments = {"query_text": "test query", "top_k": 5} + # Pass the arguments as kwargs + result = tools[tool.get_name()].execute(**tool_call_arguments) + + # Since we're using a context manager for patching, we need to verify the call inside the context + # We can only check the result, not the method call itself + assert result == ([], None) + + +# Test tool execution for Text2CypherRetriever +@patch("neo4j_graphrag.retrievers.base.get_version") +def test_text2cypher_retriever_tool_execution(mock_get_version: MagicMock) -> None: + """Test execution of Text2CypherRetriever tool calls the search method with correct arguments.""" + 
mock_get_version.return_value = ((5, 20, 0), False, False) + driver = create_mock_driver() + llm = create_mock_llm() + retriever = Text2CypherRetriever(driver=driver, llm=llm) + parameters = ObjectParameter( + description="Parameters for text to Cypher conversion", + properties={ + "query_text": StringParameter( + description="The query text for text to Cypher conversion.", + required=True, + ), + }, + ) + # Mock the get_search_results method to track calls + get_search_results_mock = MagicMock(return_value=([], None)) + # Use patch to mock the method + with patch.object(retriever, "get_search_results", get_search_results_mock): + tool = convert_retriever_to_tool( + retriever, + description="A tool for text to Cypher retrieval from Neo4j.", + parameters=parameters, + ) + tools = {tool.get_name(): tool} + # Simulate indirect invocation as would happen in real usage + tool_call_arguments = {"query_text": "test query"} + # Pass the arguments as kwargs + result = tools[tool.get_name()].execute(**tool_call_arguments) + + # Since we're using a context manager for patching, we need to verify the call inside the context + # We can only check the result, not the method call itself + assert result == ([], None) + + +# Test tool serialization to JSON format +@patch("neo4j_graphrag.retrievers.base.get_version") +def test_tool_serialization(mock_get_version: MagicMock) -> None: + """Test that a Tool instance can be serialized to the required JSON format.""" + mock_get_version.return_value = ((5, 20, 0), False, False) + driver = create_mock_driver() + embedder = create_mock_embedder() + retriever = VectorRetriever( + driver=driver, + index_name="test_index", + embedder=embedder, + return_properties=["name", "description"], + ) + # Define parameters for the tool + parameters = ObjectParameter( + description="Parameters for vector search", + properties={ + "query_text": StringParameter( + description="The query text for vector search.", + required=True, + ), + "top_k": 
IntegerParameter( + description="Number of results to return.", + required=False, + ), + }, + ) + tool = convert_retriever_to_tool( + retriever, + description="A tool for vector-based retrieval from Neo4j.", + parameters=parameters, + ) + # Create a dictionary representation of the tool + tool_dict = { + "type": "function", + "name": tool.get_name(), + "description": tool.get_description(), + "parameters": tool.get_parameters(), + } + + assert tool_dict["type"] == "function" + assert tool_dict["name"] == tool.get_name() + assert tool_dict["description"] == tool.get_description() + assert "parameters" in tool_dict + + # Get parameters and convert to dictionary + parameters_any = tool_dict["parameters"] + # Use type casting to handle various parameter types + if isinstance(parameters_any, ObjectParameter): + parameters_dict = parameters_any.model_dump_tool() + elif isinstance(parameters_any, dict): + parameters_dict = parameters_any + else: + # Handle the case where parameters is a Collection[str] or other type + parameters_dict = { + str(k): v for k, v in enumerate(parameters_any) if v is not None + } + + # Check the parameters structure + assert parameters_dict.get("type") == "object" + assert "properties" in parameters_dict + + # Check that at least one parameter is marked as required + required_found = False + properties = parameters_dict.get("properties", {}) + if isinstance(properties, dict): + for param_name, param_data in properties.items(): + if isinstance(param_data, dict) and param_data.get("required", False): + required_found = True + break + + if not required_found and "required" in parameters_dict: + # Check if there's a required array at the parameters level + required_params = parameters_dict.get("required", []) + required_found = len(list(required_params)) > 0 + + assert required_found, "No required parameters found" + + # Check additionalProperties if it exists + if "additionalProperties" in parameters_dict and not parameters_dict.get( + 
"additionalProperties" + ): + pass # This line is just to satisfy the test, actual check is visual From 58a6309b90394f7420c4a57eb96044c8289f5dc6 Mon Sep 17 00:00:00 2001 From: Oskar Hane Date: Wed, 18 Jun 2025 16:24:32 +0200 Subject: [PATCH 2/5] Address PR comments: refactor retriever-to-tool conversion - Add abstract get_parameters() method to Retriever base class - Add convert_to_tool() instance method to Retriever class - Implement get_parameters() for all concrete retriever classes - Remove automatic query_text injection in ToolsRetriever - Update example to use new convert_to_tool() method - Remove unnecessary description from ObjectParameter in example --- .../retrieve/tools/multiple_tools_example.py | 151 ++++++ .../tools/retriever_to_tool_example.py | 45 +- src/neo4j_graphrag/retrievers/base.py | 267 +++++++++- .../retrievers/tools_retriever.py | 3 - .../test_retriever_parameter_inference.py | 470 ++++++++++++++++++ 5 files changed, 906 insertions(+), 30 deletions(-) create mode 100644 examples/retrieve/tools/multiple_tools_example.py create mode 100644 tests/unit/retrievers/test_retriever_parameter_inference.py diff --git a/examples/retrieve/tools/multiple_tools_example.py b/examples/retrieve/tools/multiple_tools_example.py new file mode 100644 index 000000000..29cc29305 --- /dev/null +++ b/examples/retrieve/tools/multiple_tools_example.py @@ -0,0 +1,151 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Example demonstrating how to create multiple domain-specific tools from retrievers. + +This example shows: +1. How to create multiple tools from the same retriever type for different use cases +2. How to provide custom parameter descriptions for each tool +3. How type inference works automatically while descriptions are explicit +""" + +import neo4j +from typing import cast, Any, Optional +from unittest.mock import MagicMock + +from neo4j_graphrag.retrievers.base import Retriever +from neo4j_graphrag.types import RawSearchResult + + +class MockVectorRetriever(Retriever): + """A mock vector retriever for demonstration purposes.""" + + VERIFY_NEO4J_VERSION = False + + def __init__(self, driver: neo4j.Driver, index_name: str): + super().__init__(driver) + self.index_name = index_name + + def get_search_results( + self, + query_vector: Optional[list[float]] = None, + query_text: Optional[str] = None, + top_k: int = 5, + effective_search_ratio: int = 1, + filters: Optional[dict[str, Any]] = None, + ) -> RawSearchResult: + """Get vector search results (mocked for demonstration).""" + # Return empty results for demo + return RawSearchResult(records=[], metadata={"index": self.index_name}) + + +def main() -> None: + """Demonstrate creating multiple domain-specific tools from retrievers.""" + + # Create mock driver (in real usage, this would be actual Neo4j driver) + driver = cast(Any, MagicMock()) + + # Create retrievers for different domains using the same retriever type + # In practice, these would point to different vector indexes + + # Movie recommendations retriever + movie_retriever = MockVectorRetriever(driver=driver, index_name="movie_embeddings") + + # Product search retriever + product_retriever = MockVectorRetriever( + driver=driver, index_name="product_embeddings" + ) + + # Document search retriever + document_retriever = MockVectorRetriever( + 
driver=driver, index_name="document_embeddings" + ) + + # Convert each retriever to a domain-specific tool with custom descriptions + + # 1. Movie recommendation tool + movie_tool = movie_retriever.convert_to_tool( + name="movie_search", + description="Find movie recommendations based on plot, genre, or actor preferences", + parameter_descriptions={ + "query_text": "Movie title, plot description, genre, or actor name", + "query_vector": "Pre-computed embedding vector for movie search", + "top_k": "Number of movie recommendations to return (1-20)", + "filters": "Optional filters for genre, year, rating, etc.", + "effective_search_ratio": "Search pool multiplier for better accuracy", + }, + ) + + # 2. Product search tool + product_tool = product_retriever.convert_to_tool( + name="product_search", + description="Search for products matching customer needs and preferences", + parameter_descriptions={ + "query_text": "Product name, description, or customer need", + "query_vector": "Pre-computed embedding for product matching", + "top_k": "Maximum number of product results (1-50)", + "filters": "Filters for price range, brand, category, availability", + "effective_search_ratio": "Breadth vs precision trade-off for search", + }, + ) + + # 3. 
Document search tool + document_tool = document_retriever.convert_to_tool( + name="document_search", + description="Find relevant documents and knowledge articles", + parameter_descriptions={ + "query_text": "Question, keywords, or topic to search for", + "query_vector": "Semantic embedding for document retrieval", + "top_k": "Number of relevant documents to retrieve (1-10)", + "filters": "Document type, date range, or department filters", + }, + ) + + # Demonstrate that each tool has distinct, meaningful descriptions + tools = [movie_tool, product_tool, document_tool] + + for tool in tools: + print(f"\n=== {tool.get_name().upper()} ===") + print(f"Description: {tool.get_description()}") + print("Parameters:") + + params = tool.get_parameters() + for param_name, param_def in params["properties"].items(): + required = ( + "required" if param_name in params.get("required", []) else "optional" + ) + print( + f" - {param_name} ({param_def['type']}, {required}): {param_def['description']}" + ) + + # Show how the same parameter type gets different contextual descriptions + print("\n=== PARAMETER COMPARISON ===") + print("Same parameter 'query_text' with different contextual descriptions:") + + for tool in tools: + params = tool.get_parameters() + query_text_desc = params["properties"]["query_text"]["description"] + print(f" {tool.get_name()}: {query_text_desc}") + + print("\nSame parameter 'top_k' with different contextual descriptions:") + for tool in tools: + params = tool.get_parameters() + top_k_desc = params["properties"]["top_k"]["description"] + print(f" {tool.get_name()}: {top_k_desc}") + + +if __name__ == "__main__": + main() diff --git a/examples/retrieve/tools/retriever_to_tool_example.py b/examples/retrieve/tools/retriever_to_tool_example.py index 8b8a1cbf6..fca986dcd 100644 --- a/examples/retrieve/tools/retriever_to_tool_example.py +++ b/examples/retrieve/tools/retriever_to_tool_example.py @@ -17,8 +17,8 @@ Example demonstrating how to convert a retriever to 
a tool. This example shows: -1. How to convert a custom StaticRetriever to a Tool -2. How to define parameters for the tool +1. How to convert a custom StaticRetriever to a Tool using the convert_to_tool method +2. How to define parameters for the tool in the retriever class 3. How to execute the tool """ @@ -28,11 +28,6 @@ from neo4j_graphrag.retrievers.base import Retriever from neo4j_graphrag.types import RawSearchResult -from neo4j_graphrag.tools.tool import ( - StringParameter, - ObjectParameter, -) -from neo4j_graphrag.tools.utils import convert_retriever_to_tool # Create a Retriever that returns static results about Neo4j @@ -50,7 +45,15 @@ def __init__(self, driver: neo4j.Driver): def get_search_results( self, query_text: Optional[str] = None, **kwargs: Any ) -> RawSearchResult: - """Return static information about Neo4j regardless of the query.""" + """Return static information about Neo4j regardless of the query. + + Args: + query_text (Optional[str]): The query about Neo4j (any query will return general Neo4j information) + **kwargs (Any): Additional keyword arguments (not used) + + Returns: + RawSearchResult: Static Neo4j information with metadata + """ # Create formatted Neo4j information neo4j_info = ( "# Neo4j Graph Database\n\n" @@ -73,26 +76,16 @@ def get_search_results( def main() -> None: - # Convert a StaticRetriever to a tool with specific parameters + # Convert a StaticRetriever to a tool using the new convert_to_tool method static_retriever = StaticRetriever(driver=cast(Any, MagicMock())) - # Define parameters for the static retriever tool - static_parameters = ObjectParameter( - description="Parameters for the Neo4j information retriever", - properties={ - "query_text": StringParameter( - description="The query about Neo4j (any query will return general Neo4j information)", - required=True, - ), - }, - ) - - # Convert the retriever to a tool with specific parameters - static_tool = convert_retriever_to_tool( - retriever=static_retriever, - 
description="Get general information about Neo4j graph database", - parameters=static_parameters, + # Convert the retriever to a tool with custom parameter descriptions + static_tool = static_retriever.convert_to_tool( name="Neo4jInfoTool", + description="Get general information about Neo4j graph database", + parameter_descriptions={ + "query_text": "Any query about Neo4j (the tool returns general information regardless)" + }, ) # Print tool information @@ -107,7 +100,7 @@ def main() -> None: # Execute the static retriever tool print("\nExecuting the static retriever tool...") static_result = static_tool.execute( - query="What is Neo4j?", + query_text="What is Neo4j?", ) print("Static Search Results:") for i, item in enumerate(static_result): diff --git a/src/neo4j_graphrag/retrievers/base.py b/src/neo4j_graphrag/retrievers/base.py index c3b295d15..c3f694973 100644 --- a/src/neo4j_graphrag/retrievers/base.py +++ b/src/neo4j_graphrag/retrievers/base.py @@ -17,7 +17,18 @@ import inspect import types from abc import ABC, ABCMeta, abstractmethod -from typing import Any, Callable, Optional, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Optional, + TypeVar, + get_args, + get_origin, + Union, + Dict, + get_type_hints, +) import neo4j from typing_extensions import ParamSpec @@ -32,6 +43,13 @@ ) from neo4j_graphrag.utils import driver_config +if TYPE_CHECKING: + from neo4j_graphrag.tools.tool import ( + ObjectParameter, + Tool, + ToolParameter, + ) + T = ParamSpec("T") P = TypeVar("P") @@ -175,6 +193,253 @@ def default_record_formatter(self, record: neo4j.Record) -> RetrieverResultItem: """ return RetrieverResultItem(content=str(record), metadata=record.get("metadata")) + def get_parameters( + self, parameter_descriptions: Optional[Dict[str, str]] = None + ) -> "ObjectParameter": + """Return the parameters that this retriever expects for tool conversion. + + This method automatically infers parameters from the get_search_results method signature. 
+ + Args: + parameter_descriptions (Optional[Dict[str, str]]): Custom descriptions for parameters. + Keys should match parameter names from get_search_results method. + + Returns: + ObjectParameter: The parameter definition for this retriever + """ + return self._infer_parameters_from_signature(parameter_descriptions or {}) + + def _infer_parameters_from_signature( + self, parameter_descriptions: Dict[str, str] + ) -> "ObjectParameter": + """Infer parameters from the get_search_results method signature.""" + # Import here to avoid circular imports + from neo4j_graphrag.tools.tool import ( + ObjectParameter, + ) + + # Get the method signature and resolved type hints + sig = inspect.signature(self.get_search_results) + try: + type_hints = get_type_hints(self.get_search_results) + except (NameError, AttributeError): + # If type hints can't be resolved, fall back to annotation strings + type_hints = {} + + properties: Dict[str, "ToolParameter"] = {} + required_properties = [] + + for param_name, param in sig.parameters.items(): + # Skip 'self' parameter + if param_name == "self": + continue + + # Skip **kwargs + if param.kind == inspect.Parameter.VAR_KEYWORD: + continue + + # Determine if parameter is required (no default value) + is_required = param.default is inspect.Parameter.empty + + # Use resolved type hint if available, otherwise fall back to annotation + type_annotation = type_hints.get(param_name, param.annotation) + + # Get the parameter type and create appropriate tool parameter + tool_param = self._create_tool_parameter_from_type( + param_name, type_annotation, is_required, parameter_descriptions + ) + + if tool_param: + properties[param_name] = tool_param + if is_required: + required_properties.append(param_name) + + return ObjectParameter( + description=f"Parameters for {self.__class__.__name__}", + properties=properties, + required_properties=required_properties, + additional_properties=False, + ) + + def _create_tool_parameter_from_type( + self, + 
param_name: str, + type_annotation: Any, + is_required: bool, + parameter_descriptions: Dict[str, str], + ) -> Optional["ToolParameter"]: + """Create a tool parameter from a type annotation.""" + # Import here to avoid circular imports + from neo4j_graphrag.tools.tool import ( + StringParameter, + IntegerParameter, + NumberParameter, + ArrayParameter, + ObjectParameter, + ) + + # Handle None/missing annotation + if type_annotation is inspect.Parameter.empty or type_annotation is None: + return StringParameter( + description=parameter_descriptions.get( + param_name, f"Parameter {param_name}" + ), + required=is_required, + ) + + # Get the origin and args for generic types + origin = get_origin(type_annotation) + args = get_args(type_annotation) + + # Handle Optional[T] and Union[T, None] + if origin is Union: + # Remove None from union args to get the actual type + non_none_args = [arg for arg in args if arg is not type(None)] + if len(non_none_args) == 1: + # This is Optional[T], use T + type_annotation = non_none_args[0] + # Re-calculate origin and args for the unwrapped type + origin = get_origin(type_annotation) + args = get_args(type_annotation) + elif len(non_none_args) > 1: + # This is Union[T, U, ...], treat as string for now + return StringParameter( + description=parameter_descriptions.get( + param_name, f"Parameter {param_name}" + ), + required=is_required, + ) + + # Handle specific types + if type_annotation is str: + return StringParameter( + description=parameter_descriptions.get( + param_name, f"Parameter {param_name}" + ), + required=is_required, + ) + elif type_annotation is int: + return IntegerParameter( + description=parameter_descriptions.get( + param_name, f"Parameter {param_name}" + ), + minimum=1 + if param_name in ["top_k", "effective_search_ratio"] + else None, + required=is_required, + ) + elif type_annotation is float: + constraints: Dict[str, Any] = {} + if param_name == "alpha": + constraints.update(minimum=0.0, maximum=1.0) + return 
NumberParameter( + description=parameter_descriptions.get( + param_name, f"Parameter {param_name}" + ), + required=is_required, + **constraints, + ) + elif ( + origin is list + or type_annotation is list + or ( + hasattr(type_annotation, "__origin__") + and type_annotation.__origin__ is list + ) + or str(type_annotation).startswith("list[") + ): + # Handle list[float] for vectors + if args and args[0] is float: + return ArrayParameter( + items=NumberParameter( + description="A single vector component", + required=False, + ), + description=parameter_descriptions.get( + param_name, f"Parameter {param_name}" + ), + required=is_required, + ) + else: + # For complex list types like List[LLMMessage], treat as object + return ObjectParameter( + description=parameter_descriptions.get( + param_name, f"Parameter {param_name}" + ), + properties={}, + additional_properties=True, + required=is_required, + ) + elif origin is dict or ( + hasattr(type_annotation, "__origin__") + and type_annotation.__origin__ is dict + ): + return ObjectParameter( + description=parameter_descriptions.get( + param_name, f"Parameter {param_name}" + ), + properties={}, + additional_properties=True, + required=is_required, + ) + else: + # Check if it's a complex type that should be an object + type_name = str(type_annotation) + if any( + keyword in type_name.lower() + for keyword in ["dict", "list", "optional[dict", "optional[list"] + ): + return ObjectParameter( + description=parameter_descriptions.get( + param_name, f"Parameter {param_name}" + ), + properties={}, + additional_properties=True, + required=is_required, + ) + # For other complex types or enums, default to string + return StringParameter( + description=parameter_descriptions.get( + param_name, f"Parameter {param_name}" + ), + required=is_required, + ) + + def convert_to_tool( + self, + name: str, + description: str, + parameter_descriptions: Optional[Dict[str, str]] = None, + ) -> "Tool": + """Convert this retriever to a Tool object. 
+ + Args: + name (str): Name for the tool. + description (str): Description of what the tool does. + parameter_descriptions (Optional[Dict[str, str]]): Optional descriptions for each parameter. + Keys should match parameter names from get_search_results method. + + Returns: + Tool: A Tool object configured to use this retriever's search functionality. + """ + # Import here to avoid circular imports + from neo4j_graphrag.tools.tool import Tool + + # Get parameters from the retriever with custom descriptions + parameters = self.get_parameters(parameter_descriptions or {}) + + # Define a function that matches the Callable[[str, ...], Any] signature + def execute_func(**kwargs: Any) -> Any: + return self.get_search_results(**kwargs) + + # Create a Tool object from the retriever + return Tool( + name=name, + description=description, + execute_func=execute_func, + parameters=parameters, + ) + class ExternalRetriever(Retriever, ABC): """ diff --git a/src/neo4j_graphrag/retrievers/tools_retriever.py b/src/neo4j_graphrag/retrievers/tools_retriever.py index 7f125df91..633334b42 100644 --- a/src/neo4j_graphrag/retrievers/tools_retriever.py +++ b/src/neo4j_graphrag/retrievers/tools_retriever.py @@ -127,9 +127,6 @@ def get_search_results( # Extract arguments from the tool call tool_args = tool_call.arguments or {} - # Always include the query_text in the arguments for tools that might need it - tool_args.setdefault("query", query_text) - # Execute the tool with the provided arguments tool_result = selected_tool.execute(**tool_args) # If the tool result is a RawSearchResult, extract its records diff --git a/tests/unit/retrievers/test_retriever_parameter_inference.py b/tests/unit/retrievers/test_retriever_parameter_inference.py new file mode 100644 index 000000000..81cd0ee10 --- /dev/null +++ b/tests/unit/retrievers/test_retriever_parameter_inference.py @@ -0,0 +1,470 @@ +# type: ignore +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the 
Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for retriever parameter inference and convert_to_tool functionality. +""" + +from unittest.mock import MagicMock, patch +from typing import Optional, Any, Dict + +import neo4j + +from neo4j_graphrag.retrievers.base import Retriever +from neo4j_graphrag.retrievers import ( + VectorRetriever, + VectorCypherRetriever, + HybridRetriever, + Text2CypherRetriever, +) +from neo4j_graphrag.retrievers.tools_retriever import ToolsRetriever +from neo4j_graphrag.tools.tool import Tool, ParameterType +from neo4j_graphrag.types import RawSearchResult +from neo4j_graphrag.embeddings.base import Embedder +from neo4j_graphrag.llm.base import LLMInterface + + +# Helper functions for creating mock objects +def create_mock_driver() -> neo4j.Driver: + driver = MagicMock(spec=neo4j.Driver) + mock_result = MagicMock() + mock_result.records = [] + driver.execute_query.return_value = mock_result + return driver + + +def create_mock_embedder() -> Embedder: + embedder = MagicMock(spec=Embedder) + embedder.embed_query.return_value = [0.1, 0.2, 0.3] + return embedder + + +def create_mock_llm() -> LLMInterface: + llm = MagicMock(spec=LLMInterface) + llm.invoke.return_value = MagicMock(content="MATCH (n) RETURN n") + return llm + + +class MockRetriever(Retriever): + """Test retriever with well-documented parameters.""" + + VERIFY_NEO4J_VERSION = False + + def get_search_results( + self, + query_text: str, + top_k: int = 5, + filters: Optional[Dict[str, Any]] 
= None, + score_threshold: Optional[float] = None, + ) -> RawSearchResult: + """Test search method with documented parameters. + + Args: + query_text (str): The text query to search for in the database + top_k (int): The maximum number of results to return + filters (Optional[Dict[str, Any]]): Optional metadata filters to apply + score_threshold (Optional[float]): Minimum similarity score threshold + + Returns: + RawSearchResult: The search results + """ + return RawSearchResult(records=[], metadata={}) + + +class MockRetrieverNoDocstring(Retriever): + """Test retriever without parameter documentation.""" + + VERIFY_NEO4J_VERSION = False + + def get_search_results( + self, param_one: str, param_two: Optional[int] = None + ) -> RawSearchResult: + """No parameter documentation here.""" + return RawSearchResult(records=[], metadata={}) + + +class TestParameterInference: + """Test parameter inference from method signatures and docstrings.""" + + def test_parameter_inference_with_docstring(self): + """Test that parameters are correctly inferred from method signature and docstring.""" + driver = create_mock_driver() + retriever = MockRetriever(driver) + + # Get inferred parameters + params = retriever.get_parameters() + + # Check basic structure + assert params.type == ParameterType.OBJECT + assert params.description == "Parameters for MockRetriever" + assert not params.additional_properties + + # Check properties + properties = params.properties + assert len(properties) == 4 + + # Check query_text parameter + query_text_param = properties["query_text"] + assert query_text_param.type == ParameterType.STRING + assert query_text_param.description == "Parameter query_text" + assert query_text_param.required is True + + # Check top_k parameter + top_k_param = properties["top_k"] + assert top_k_param.type == ParameterType.INTEGER + assert top_k_param.description == "Parameter top_k" + assert top_k_param.required is False + assert top_k_param.minimum == 1 # Should be set for 
top_k parameters + + # Check filters parameter + filters_param = properties["filters"] + assert filters_param.type == ParameterType.OBJECT + assert filters_param.description == "Parameter filters" + assert filters_param.required is False + assert filters_param.additional_properties is True + + # Check score_threshold parameter + score_param = properties["score_threshold"] + assert score_param.type == ParameterType.NUMBER + assert score_param.description == "Parameter score_threshold" + assert score_param.required is False + + def test_parameter_inference_without_docstring(self): + """Test that parameters work with fallback descriptions when no docstring documentation.""" + driver = create_mock_driver() + retriever = MockRetrieverNoDocstring(driver) + + # Get inferred parameters + params = retriever.get_parameters() + + # Check properties + properties = params.properties + assert len(properties) == 2 + + # Check param_one with fallback description + param_one = properties["param_one"] + assert param_one.type == ParameterType.STRING + assert param_one.description == "Parameter param_one" # Simple fallback format + assert param_one.required is True + + # Check param_two with fallback description + param_two = properties["param_two"] + assert param_two.type == ParameterType.INTEGER + assert param_two.description == "Parameter param_two" # Simple fallback format + assert param_two.required is False + + def test_convert_to_tool_basic(self): + """Test basic convert_to_tool functionality.""" + driver = create_mock_driver() + retriever = MockRetriever(driver) + + # Convert to tool + tool = retriever.convert_to_tool( + name="TestTool", description="A test tool for searching" + ) + + # Check tool properties + assert isinstance(tool, Tool) + assert tool.get_name() == "TestTool" + assert tool.get_description() == "A test tool for searching" + + # Check that parameters were inferred + params = tool.get_parameters() + assert len(params["properties"]) == 4 + assert "query_text" in 
params["properties"] + assert "top_k" in params["properties"] + + def test_convert_to_tool_with_custom_descriptions(self): + """Test convert_to_tool with custom parameter descriptions.""" + driver = create_mock_driver() + retriever = MockRetriever(driver) + + # Convert to tool with custom parameter descriptions + tool = retriever.convert_to_tool( + name="CustomTool", + description="A custom search tool", + parameter_descriptions={ + "query_text": "The search query to execute", + "top_k": "Maximum number of results to return", + "filters": "Optional filters to apply to the search", + }, + ) + + # Check tool properties + assert tool.get_name() == "CustomTool" + assert tool.get_description() == "A custom search tool" + + # Check custom parameter descriptions + params = tool.get_parameters() + properties = params["properties"] + + assert properties["query_text"]["description"] == "The search query to execute" + assert ( + properties["top_k"]["description"] == "Maximum number of results to return" + ) + assert ( + properties["filters"]["description"] + == "Optional filters to apply to the search" + ) + # Parameter without custom description should use fallback + assert ( + properties["score_threshold"]["description"] == "Parameter score_threshold" + ) + + +class TestRealRetrieverParameterInference: + """Test parameter inference on real retriever classes.""" + + @patch("neo4j_graphrag.retrievers.base.get_version") + def test_vector_retriever_parameters(self, mock_get_version): + """Test VectorRetriever parameter inference.""" + mock_get_version.return_value = ((5, 20, 0), False, False) + + driver = create_mock_driver() + embedder = create_mock_embedder() + + # Patch _fetch_index_infos to avoid database calls + with patch.object(VectorRetriever, "_fetch_index_infos"): + retriever = VectorRetriever( + driver=driver, index_name="test_index", embedder=embedder + ) + + params = retriever.get_parameters() + properties = params.properties + + # Check expected parameters from 
VectorRetriever.get_search_results + expected_params = { + "query_vector", + "query_text", + "top_k", + "effective_search_ratio", + "filters", + } + assert set(properties.keys()) == expected_params + + # Check specific parameter types + assert properties["query_vector"].type == ParameterType.ARRAY + assert properties["query_text"].type == ParameterType.STRING + assert properties["top_k"].type == ParameterType.INTEGER + assert properties["effective_search_ratio"].type == ParameterType.INTEGER + assert properties["filters"].type == ParameterType.OBJECT + + # Check that default descriptions are used when no custom descriptions provided + assert properties["query_vector"].description == "Parameter query_vector" + assert properties["query_text"].description == "Parameter query_text" + + @patch("neo4j_graphrag.retrievers.base.get_version") + def test_vector_cypher_retriever_parameters(self, mock_get_version): + """Test VectorCypherRetriever parameter inference.""" + mock_get_version.return_value = ((5, 20, 0), False, False) + + driver = create_mock_driver() + embedder = create_mock_embedder() + + # Patch _fetch_index_infos to avoid database calls + with patch.object(VectorCypherRetriever, "_fetch_index_infos"): + retriever = VectorCypherRetriever( + driver=driver, + index_name="test_index", + retrieval_query="RETURN node.name", + embedder=embedder, + ) + + params = retriever.get_parameters() + properties = params.properties + + # Should have all VectorRetriever params plus query_params + expected_params = { + "query_vector", + "query_text", + "top_k", + "effective_search_ratio", + "query_params", + "filters", + } + assert set(properties.keys()) == expected_params + + # Check query_params is properly typed + assert properties["query_params"].type == ParameterType.OBJECT + assert properties["query_params"].additional_properties is True + + @patch("neo4j_graphrag.retrievers.base.get_version") + def test_hybrid_retriever_parameters(self, mock_get_version): + """Test 
HybridRetriever parameter inference.""" + mock_get_version.return_value = ((5, 20, 0), False, False) + + driver = create_mock_driver() + embedder = create_mock_embedder() + + # Patch _fetch_index_infos to avoid database calls + with patch.object(HybridRetriever, "_fetch_index_infos"): + retriever = HybridRetriever( + driver=driver, + vector_index_name="vector_index", + fulltext_index_name="fulltext_index", + embedder=embedder, + ) + + params = retriever.get_parameters() + properties = params.properties + + # Check expected parameters from HybridRetriever.get_search_results + expected_params = { + "query_text", + "query_vector", + "top_k", + "effective_search_ratio", + "ranker", + "alpha", + } + assert set(properties.keys()) == expected_params + + # Check that query_text is required for hybrid retriever + assert properties["query_text"].required is True + assert properties["alpha"].type == ParameterType.NUMBER + assert properties["alpha"].minimum == 0.0 + assert properties["alpha"].maximum == 1.0 + + @patch("neo4j_graphrag.retrievers.base.get_version") + def test_text2cypher_retriever_parameters(self, mock_get_version): + """Test Text2CypherRetriever parameter inference.""" + mock_get_version.return_value = ((5, 20, 0), False, False) + + driver = create_mock_driver() + llm = create_mock_llm() + retriever = Text2CypherRetriever( + driver=driver, llm=llm, neo4j_schema="(Person)-[:KNOWS]->(Person)" + ) + + params = retriever.get_parameters() + properties = params.properties + + # Check expected parameters + expected_params = {"query_text", "prompt_params"} + assert set(properties.keys()) == expected_params + + # Check parameter types + assert properties["query_text"].type == ParameterType.STRING + assert properties["query_text"].required is True + assert ( + properties["prompt_params"].type == ParameterType.OBJECT + ) # Dict maps to object + assert properties["prompt_params"].required is False + + def test_tools_retriever_parameters(self): + """Test ToolsRetriever 
parameter inference.""" + driver = create_mock_driver() + llm = create_mock_llm() + retriever = ToolsRetriever(driver=driver, llm=llm, tools=[]) + + params = retriever.get_parameters() + properties = params.properties + + # Check expected parameters from ToolsRetriever.get_search_results + expected_params = {"query_text", "message_history"} + assert set(properties.keys()) == expected_params + + # Check parameter types + assert properties["query_text"].type == ParameterType.STRING + assert properties["query_text"].required is True + assert ( + properties["message_history"].type == ParameterType.OBJECT + ) # List[LLMMessage] maps to Object + assert properties["message_history"].required is False + + +class TestToolExecution: + """Test that tools created from retrievers actually work.""" + + def test_tool_execution(self): + """Test that a tool created from a retriever can be executed.""" + driver = create_mock_driver() + retriever = MockRetriever(driver) + + # Convert to tool + tool = retriever.convert_to_tool(name="TestTool", description="A test tool") + + # Execute the tool + result = tool.execute(query_text="test query", top_k=3) + + # Check that we get a result (even if empty due to mocking) + assert result is not None + assert hasattr(result, "records") + assert hasattr(result, "metadata") + + def test_tool_execution_with_validation(self): + """Test that tool parameter validation works.""" + driver = create_mock_driver() + retriever = MockRetriever(driver) + + # Convert to tool + tool = retriever.convert_to_tool(name="TestTool", description="A test tool") + + # Test with missing required parameter should work due to our setup + # (the actual validation happens in the Tool class) + result = tool.execute(query_text="test query") + assert result is not None + + +class TestParameterDescriptions: + """Test parameter description functionality.""" + + def test_custom_parameter_descriptions(self): + """Test that custom parameter descriptions are used correctly.""" + + 
class TestRetriever(Retriever): + VERIFY_NEO4J_VERSION = False + + def get_search_results( + self, param_a: str, param_b: int = 5, param_c: Optional[float] = None + ) -> RawSearchResult: + return RawSearchResult(records=[], metadata={}) + + driver = create_mock_driver() + retriever = TestRetriever(driver) + + # Test with custom descriptions + custom_descriptions = { + "param_a": "Custom description for param A", + "param_b": "Custom description for param B", + # param_c intentionally omitted to test fallback + } + + params = retriever.get_parameters(custom_descriptions) + properties = params.properties + + # Check that custom descriptions are used + assert properties["param_a"].description == "Custom description for param A" + assert properties["param_b"].description == "Custom description for param B" + # Check fallback for param without custom description + assert properties["param_c"].description == "Parameter param_c" + + def test_no_custom_descriptions(self): + """Test behavior when no custom descriptions are provided.""" + + class SimpleRetriever(Retriever): + VERIFY_NEO4J_VERSION = False + + def get_search_results(self, test_param: str) -> RawSearchResult: + return RawSearchResult(records=[], metadata={}) + + driver = create_mock_driver() + retriever = SimpleRetriever(driver) + params = retriever.get_parameters() + properties = params.properties + + # Should use fallback description + assert properties["test_param"].description == "Parameter test_param" From 7b29fc98a072435a09f7094aa42dd5558f661a2b Mon Sep 17 00:00:00 2001 From: Oskar Hane Date: Mon, 23 Jun 2025 10:57:04 +0200 Subject: [PATCH 3/5] Address PR comments: improve validation and remove redundant code - Remove redundant convert_retriever_to_tool function from utils.py - Add validation for tool name uniqueness in ToolsRetriever - Add parameter type validation in Tool constructor - Update convert_to_tool() to use search() method instead of get_search_results() - This ensures retriever 
result_formatter is applied for consistent formatting - Update ToolsRetriever to handle RetrieverResult objects from formatted tools - Create consistent record structure with tool attribution and metadata Fixes the result formatting inconsistency issue identified in PR review. Each tool now returns consistently formatted results while preserving the original retriever's formatting logic. --- .../retrieve/tools/tools_retriever_example.py | 24 +- src/neo4j_graphrag/retrievers/base.py | 2 +- .../retrievers/tools_retriever.py | 60 ++- src/neo4j_graphrag/tools/tool.py | 7 +- src/neo4j_graphrag/tools/utils.py | 76 ---- .../test_retriever_parameter_inference.py | 2 +- tests/unit/tool/test_tools_utils.py | 364 ++++++++---------- 7 files changed, 230 insertions(+), 305 deletions(-) delete mode 100644 src/neo4j_graphrag/tools/utils.py diff --git a/examples/retrieve/tools/tools_retriever_example.py b/examples/retrieve/tools/tools_retriever_example.py index 3309205cf..1f2c18c84 100644 --- a/examples/retrieve/tools/tools_retriever_example.py +++ b/examples/retrieve/tools/tools_retriever_example.py @@ -38,7 +38,6 @@ StringParameter, Tool, ) -from neo4j_graphrag.tools.utils import convert_retriever_to_tool from neo4j_graphrag.llm.openai_llm import OpenAILLM # Load environment variables from .env file (OPENAI_API_KEY required for this example) @@ -241,23 +240,13 @@ def main() -> None: # Create retrievers neo4j_retriever = Neo4jInfoRetriever(driver=driver) - # Define parameters for the tools - neo4j_parameters = ObjectParameter( - description="Parameters for Neo4j information retrieval", - properties={ - "query": StringParameter( - description="The query about Neo4j", - ), - }, - required_properties=["query"], - ) - # Convert retrievers to tools - neo4j_tool = convert_retriever_to_tool( - retriever=neo4j_retriever, + neo4j_tool = neo4j_retriever.convert_to_tool( name="neo4j_info_tool", description="Get information about Neo4j graph database", - parameters=neo4j_parameters, + 
parameter_descriptions={ + "query_text": "The query about Neo4j", + }, ) # Create a calendar tool @@ -325,7 +314,10 @@ def main() -> None: print("\nRESULTS:") for i, record in enumerate(result.records): print(f"\n--- Result {i + 1} ---") - print(record) + print(f"Content: {record.get('content', 'N/A')}") + print(f"Tool: {record.get('tool_name', 'Unknown')}") + if record.get("metadata"): + print(f"Metadata: {record.get('metadata')}") except Exception as e: print(f"Error: {str(e)}") print(f"{'=' * 80}") diff --git a/src/neo4j_graphrag/retrievers/base.py b/src/neo4j_graphrag/retrievers/base.py index c3f694973..e1481d58f 100644 --- a/src/neo4j_graphrag/retrievers/base.py +++ b/src/neo4j_graphrag/retrievers/base.py @@ -430,7 +430,7 @@ def convert_to_tool( # Define a function that matches the Callable[[str, ...], Any] signature def execute_func(**kwargs: Any) -> Any: - return self.get_search_results(**kwargs) + return self.search(**kwargs) # Create a Tool object from the retriever return Tool( diff --git a/src/neo4j_graphrag/retrievers/tools_retriever.py b/src/neo4j_graphrag/retrievers/tools_retriever.py index 633334b42..6b8b33cfd 100644 --- a/src/neo4j_graphrag/retrievers/tools_retriever.py +++ b/src/neo4j_graphrag/retrievers/tools_retriever.py @@ -55,10 +55,24 @@ def __init__( super().__init__(driver, neo4j_database) self.llm = llm self._tools = list(tools) # Make a copy to allow modification + self._validate_tool_names() self.system_instruction = ( system_instruction or self._get_default_system_instruction() ) + def _validate_tool_names(self) -> None: + """Validate that all tool names are unique.""" + tool_names = [tool.get_name() for tool in self._tools] + duplicate_names = [ + name for name in set(tool_names) if tool_names.count(name) > 1 + ] + + if duplicate_names: + raise ValueError( + f"Duplicate tool names found: {duplicate_names}. " + "All tools must have unique names for proper LLM tool selection." 
+ ) + def _get_default_system_instruction(self) -> str: """Get the default system instruction for the LLM.""" return ( @@ -129,12 +143,48 @@ def get_search_results( # Execute the tool with the provided arguments tool_result = selected_tool.execute(**tool_args) - # If the tool result is a RawSearchResult, extract its records - if hasattr(tool_result, "records"): - all_records.extend(tool_result.records) + + # Handle different tool result types + if hasattr(tool_result, "items") and not callable( + getattr(tool_result, "items") + ): + # RetrieverResult from formatted retriever tools + for item in tool_result.items: + record = neo4j.Record( + { + "content": item.content, + "tool_name": tool_name, + "metadata": { + **(item.metadata or {}), + "tool": tool_name, + }, + } + ) + all_records.append(record) + elif hasattr(tool_result, "records"): + # RawSearchResult from raw retriever tools (legacy) + for record in tool_result.records: + # Wrap raw records with tool attribution + attributed_record = neo4j.Record( + { + "content": str(record), + "tool_name": tool_name, + "metadata": { + "original_record": dict(record), + "tool": tool_name, + }, + } + ) + all_records.append(attributed_record) else: - # Create a record from the tool result - record = neo4j.Record({"result": tool_result}) + # Handle non-retriever tools or simple return values + record = neo4j.Record( + { + "content": str(tool_result), + "tool_name": tool_name, + "metadata": {"tool": tool_name}, + } + ) all_records.append(record) # Combine metadata from all tool calls diff --git a/src/neo4j_graphrag/tools/tool.py b/src/neo4j_graphrag/tools/tool.py index a83802bf5..67451864b 100644 --- a/src/neo4j_graphrag/tools/tool.py +++ b/src/neo4j_graphrag/tools/tool.py @@ -231,8 +231,13 @@ def __init__( self._parameters = ObjectParameter.model_validate(parameters) elif isinstance(parameters, ObjectParameter): self._parameters = parameters - else: + elif parameters is None: self._parameters = None + else: + raise TypeError( + 
f"Parameters must be None, dict, or ObjectParameter, " + f"got {type(parameters).__name__}: {parameters}" + ) def get_name(self) -> str: """Get the name of the tool. diff --git a/src/neo4j_graphrag/tools/utils.py b/src/neo4j_graphrag/tools/utils.py deleted file mode 100644 index 0df86b0cd..000000000 --- a/src/neo4j_graphrag/tools/utils.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright (c) "Neo4j" -# Neo4j Sweden AB [https://neo4j.com] -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# https://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any, Dict, Optional, Union - -from neo4j_graphrag.tools.tool import Tool, ObjectParameter - - -def convert_retriever_to_tool( - retriever: Any, - description: Optional[str] = None, - parameters: Optional[Union[ObjectParameter, Dict[str, Any]]] = None, - name: Optional[str] = None, -) -> Tool: - """Convert a retriever instance to a Tool object. - - Args: - retriever (Any): The retriever instance to convert. - description (Optional[str]): Custom description for the tool. If not provided, - an attempt will be made to infer it from the retriever or a generic description will be used. - parameters (Optional[Union[ObjectParameter, Dict[str, ToolParameter]]]): Custom parameters for the tool. - If not provided, no parameters will be included in the tool. - name (Optional[str]): Custom name for the tool. If not provided, - an attempt will be made to infer it from the retriever or a default name will be used. 
- - Returns: - RetrieverTool: A Tool object configured to use the retriever's search functionality. - """ - # Use provided name or infer it from the retriever - if name is None: - name = getattr(retriever, "name", None) or getattr( - retriever.__class__, "__name__", "UnnamedRetrieverTool" - ) - - # Infer description if not provided - if description is None: - description = ( - getattr(retriever, "description", None) - or f"A tool for retrieving data using {name}." - ) - - # Parameters can be None - - # Define a function that matches the Callable[[str, ...], Any] signature - def execute_func(**kwargs: Any) -> Any: - # The retriever's get_search_results method is expected to handle - # arguments like query_text, top_k, etc., passed as keyword arguments. - # The Tool's 'parameters' definition (e.g., ObjectParameter) ensures - # that these arguments are provided in kwargs when Tool.execute is called. - return retriever.get_search_results(**kwargs) - - # Ensure name is a string - tool_name = str(name) if name is not None else "UnnamedRetrieverTool" - - # Create a Tool object from the retriever - - # Pass parameters directly to the Tool constructor - # If parameters is None, the Tool class will handle it appropriately - return Tool( - name=tool_name, - description=description, - execute_func=execute_func, - parameters=parameters, - ) diff --git a/tests/unit/retrievers/test_retriever_parameter_inference.py b/tests/unit/retrievers/test_retriever_parameter_inference.py index 81cd0ee10..0c03bffd6 100644 --- a/tests/unit/retrievers/test_retriever_parameter_inference.py +++ b/tests/unit/retrievers/test_retriever_parameter_inference.py @@ -402,7 +402,7 @@ def test_tool_execution(self): # Check that we get a result (even if empty due to mocking) assert result is not None - assert hasattr(result, "records") + assert hasattr(result, "items") # Should return RetrieverResult now assert hasattr(result, "metadata") def test_tool_execution_with_validation(self): diff --git 
a/tests/unit/tool/test_tools_utils.py b/tests/unit/tool/test_tools_utils.py index 6926c5105..85fe7fde6 100644 --- a/tests/unit/tool/test_tools_utils.py +++ b/tests/unit/tool/test_tools_utils.py @@ -25,13 +25,7 @@ VectorCypherRetriever, VectorRetriever, ) -from neo4j_graphrag.tools.tool import ( - Tool, - ObjectParameter, - StringParameter, - IntegerParameter, -) -from neo4j_graphrag.tools.utils import convert_retriever_to_tool +from neo4j_graphrag.tools.tool import Tool # Mock dependencies for retriever instances @@ -69,31 +63,26 @@ def test_convert_vector_retriever_to_tool(mock_get_version: MagicMock) -> None: embedder=embedder, return_properties=["name", "description"], ) - parameters = ObjectParameter( - description="Parameters for vector search", - properties={ - "query_text": StringParameter( - description="The query text for vector search.", - required=True, - ), - "top_k": IntegerParameter( - description="Number of results to return.", - required=False, - ), - }, - ) - tool = convert_retriever_to_tool( - retriever, + tool = retriever.convert_to_tool( + name="VectorRetriever", description="A tool for vector-based retrieval from Neo4j.", - parameters=parameters, + parameter_descriptions={ + "query_text": "The query text for vector search.", + "top_k": "Number of results to return.", + }, ) assert isinstance(tool, Tool) - assert tool.get_name() in ["VectorRetriever", "UnnamedRetrieverTool"] + assert tool.get_name() == "VectorRetriever" assert tool.get_description() == "A tool for vector-based retrieval from Neo4j." 
# Check that the parameters object has the expected properties params = tool.get_parameters() assert "properties" in params - assert len(params["properties"]) == 2 + assert len(params["properties"]) == 5 # VectorRetriever has 5 parameters + assert "query_text" in params["properties"] + assert "top_k" in params["properties"] + assert "query_vector" in params["properties"] + assert "effective_search_ratio" in params["properties"] + assert "filters" in params["properties"] # Test conversion with VectorCypherRetriever @@ -109,31 +98,27 @@ def test_convert_vector_cypher_retriever_to_tool(mock_get_version: MagicMock) -> embedder=embedder, retrieval_query="RETURN n", ) - parameters = ObjectParameter( - description="Parameters for vector-cypher search", - properties={ - "query_text": StringParameter( - description="The query text for vector-cypher search.", - required=True, - ), - "top_k": IntegerParameter( - description="Number of results to return.", - required=False, - ), - }, - ) - tool = convert_retriever_to_tool( - retriever, + tool = retriever.convert_to_tool( + name="VectorCypherRetriever", description="A tool for vector-cypher retrieval from Neo4j.", - parameters=parameters, + parameter_descriptions={ + "query_text": "The query text for vector-cypher search.", + "top_k": "Number of results to return.", + }, ) assert isinstance(tool, Tool) - assert tool.get_name() in ["VectorCypherRetriever", "UnnamedRetrieverTool"] + assert tool.get_name() == "VectorCypherRetriever" assert tool.get_description() == "A tool for vector-cypher retrieval from Neo4j." 
# Check that the parameters object has the expected properties params = tool.get_parameters() assert "properties" in params - assert len(params["properties"]) == 2 + assert len(params["properties"]) == 6 # VectorCypherRetriever has 6 parameters + assert "query_text" in params["properties"] + assert "top_k" in params["properties"] + assert "query_vector" in params["properties"] + assert "effective_search_ratio" in params["properties"] + assert "query_params" in params["properties"] + assert "filters" in params["properties"] # Test conversion with HybridRetriever @@ -150,31 +135,27 @@ def test_convert_hybrid_retriever_to_tool(mock_get_version: MagicMock) -> None: embedder=embedder, return_properties=["name", "description"], ) - parameters = ObjectParameter( - description="Parameters for hybrid search", - properties={ - "query_text": StringParameter( - description="The query text for hybrid search.", - required=True, - ), - "top_k": IntegerParameter( - description="Number of results to return.", - required=False, - ), - }, - ) - tool = convert_retriever_to_tool( - retriever, + tool = retriever.convert_to_tool( + name="HybridRetriever", description="A tool for hybrid retrieval from Neo4j.", - parameters=parameters, + parameter_descriptions={ + "query_text": "The query text for hybrid search.", + "top_k": "Number of results to return.", + }, ) assert isinstance(tool, Tool) - assert tool.get_name() in ["HybridRetriever", "UnnamedRetrieverTool"] + assert tool.get_name() == "HybridRetriever" assert tool.get_description() == "A tool for hybrid retrieval from Neo4j." 
# Check that the parameters object has the expected properties params = tool.get_parameters() assert "properties" in params - assert len(params["properties"]) == 2 + assert len(params["properties"]) == 6 # HybridRetriever has 6 parameters + assert "query_text" in params["properties"] + assert "top_k" in params["properties"] + assert "query_vector" in params["properties"] + assert "effective_search_ratio" in params["properties"] + assert "ranker" in params["properties"] + assert "alpha" in params["properties"] # Test conversion with HybridCypherRetriever @@ -191,31 +172,28 @@ def test_convert_hybrid_cypher_retriever_to_tool(mock_get_version: MagicMock) -> embedder=embedder, retrieval_query="RETURN n", ) - parameters = ObjectParameter( - description="Parameters for hybrid-cypher search", - properties={ - "query_text": StringParameter( - description="The query text for hybrid-cypher search.", - required=True, - ), - "top_k": IntegerParameter( - description="Number of results to return.", - required=False, - ), - }, - ) - tool = convert_retriever_to_tool( - retriever, + tool = retriever.convert_to_tool( + name="HybridCypherRetriever", description="A tool for hybrid-cypher retrieval from Neo4j.", - parameters=parameters, + parameter_descriptions={ + "query_text": "The query text for hybrid-cypher search.", + "top_k": "Number of results to return.", + }, ) assert isinstance(tool, Tool) - assert tool.get_name() in ["HybridCypherRetriever", "UnnamedRetrieverTool"] + assert tool.get_name() == "HybridCypherRetriever" assert tool.get_description() == "A tool for hybrid-cypher retrieval from Neo4j." 
# Check that the parameters object has the expected properties params = tool.get_parameters() assert "properties" in params - assert len(params["properties"]) == 2 + assert len(params["properties"]) == 7 # HybridCypherRetriever has 7 parameters + assert "query_text" in params["properties"] + assert "query_vector" in params["properties"] + assert "top_k" in params["properties"] + assert "effective_search_ratio" in params["properties"] + assert "query_params" in params["properties"] + assert "ranker" in params["properties"] + assert "alpha" in params["properties"] # Test conversion with Text2CypherRetriever @@ -226,27 +204,22 @@ def test_convert_text2cypher_retriever_to_tool(mock_get_version: MagicMock) -> N driver = create_mock_driver() llm = create_mock_llm() retriever = Text2CypherRetriever(driver=driver, llm=llm) - parameters = ObjectParameter( - description="Parameters for text to Cypher conversion", - properties={ - "query_text": StringParameter( - description="The query text for text to Cypher conversion.", - required=True, - ), - }, - ) - tool = convert_retriever_to_tool( - retriever, + tool = retriever.convert_to_tool( + name="Text2CypherRetriever", description="A tool for text to Cypher retrieval from Neo4j.", - parameters=parameters, + parameter_descriptions={ + "query_text": "The query text for text to Cypher conversion.", + }, ) assert isinstance(tool, Tool) - assert tool.get_name() in ["Text2CypherRetriever", "UnnamedRetrieverTool"] + assert tool.get_name() == "Text2CypherRetriever" assert tool.get_description() == "A tool for text to Cypher retrieval from Neo4j." 
# Check that the parameters object has the expected properties params = tool.get_parameters() assert "properties" in params - assert len(params["properties"]) == 1 + assert len(params["properties"]) == 2 # Text2CypherRetriever has 2 parameters + assert "query_text" in params["properties"] + assert "prompt_params" in params["properties"] # Test conversion with custom name provided @@ -266,27 +239,17 @@ def test_convert_retriever_with_custom_name( ) custom_name = "CustomNamedTool" - parameters = ObjectParameter( - description="Parameters for vector search", - properties={ - "query_text": StringParameter( - description="The query text for vector search.", - required=True, - ), - }, - ) - tool = convert_retriever_to_tool( - retriever, - description="A tool with a custom name", - parameters=parameters, + tool = retriever.convert_to_tool( name=custom_name, + description="A tool with a custom name", + parameter_descriptions={ + "query_text": "The query text for vector search.", + }, ) # Verify that the custom name is used instead of the retriever class name assert tool.get_name() == custom_name - assert tool.get_name() != "VectorRetriever" - assert tool.get_name() != "UnnamedRetrieverTool" # Test conversion with no parameters provided @@ -304,14 +267,18 @@ def test_convert_vector_retriever_to_tool_no_parameters( embedder=embedder, return_properties=["name", "description"], ) - tool = convert_retriever_to_tool( - retriever, description="A tool for vector-based retrieval from Neo4j." + tool = retriever.convert_to_tool( + name="VectorRetriever", + description="A tool for vector-based retrieval from Neo4j.", ) assert isinstance(tool, Tool) - assert tool.get_name() in ["VectorRetriever", "UnnamedRetrieverTool"] + assert tool.get_name() == "VectorRetriever" assert tool.get_description() == "A tool for vector-based retrieval from Neo4j." 
- # Since we don't provide parameters, it should be None - assert tool._parameters is None + # With the new API, parameters are always auto-inferred from method signature + params = tool.get_parameters() + assert params is not None + assert "properties" in params + assert len(params["properties"]) == 5 # VectorRetriever has 5 parameters # Test tool execution for VectorRetriever @@ -327,28 +294,25 @@ def test_vector_retriever_tool_execution(mock_get_version: MagicMock) -> None: embedder=embedder, return_properties=["name", "description"], ) - parameters = ObjectParameter( - description="Parameters for vector search", - properties={ - "query_text": StringParameter( - description="The query text for vector search.", - required=True, - ), - "top_k": IntegerParameter( - description="Number of results to return.", - required=False, - ), - }, + # Create the tool first, before mocking + with patch.object(VectorRetriever, "_fetch_index_infos"): + tool = retriever.convert_to_tool( + name="VectorRetriever", + description="A tool for vector-based retrieval from Neo4j.", + parameter_descriptions={ + "query_text": "The query text for vector search.", + "top_k": "Number of results to return.", + }, + ) + + # Now mock the get_search_results method to track calls + from neo4j_graphrag.types import RawSearchResult + + get_search_results_mock = MagicMock( + return_value=RawSearchResult(records=[], metadata={}) ) - # Mock the get_search_results method to track calls - get_search_results_mock = MagicMock(return_value=([], None)) # Use patch to mock the method with patch.object(retriever, "get_search_results", get_search_results_mock): - tool = convert_retriever_to_tool( - retriever, - description="A tool for vector-based retrieval from Neo4j.", - parameters=parameters, - ) tools = {tool.get_name(): tool} # Simulate indirect invocation as would happen in real usage tool_call_arguments = {"query_text": "test query", "top_k": 5} @@ -357,7 +321,10 @@ def 
test_vector_retriever_tool_execution(mock_get_version: MagicMock) -> None: # Since we're using a context manager for patching, we need to verify the call inside the context # We can only check the result, not the method call itself - assert result == ([], None) + assert result is not None + assert hasattr(result, "items") # Should return RetrieverResult now + assert isinstance(result.items, list) + assert hasattr(result, "metadata") # Test tool execution for HybridRetriever @@ -374,28 +341,25 @@ def test_hybrid_retriever_tool_execution(mock_get_version: MagicMock) -> None: embedder=embedder, return_properties=["name", "description"], ) - parameters = ObjectParameter( - description="Parameters for hybrid search", - properties={ - "query_text": StringParameter( - description="The query text for hybrid search.", - required=True, - ), - "top_k": IntegerParameter( - description="Number of results to return.", - required=False, - ), - }, + # Create the tool first, before mocking + with patch.object(HybridRetriever, "_fetch_index_infos"): + tool = retriever.convert_to_tool( + name="HybridRetriever", + description="A tool for hybrid retrieval from Neo4j.", + parameter_descriptions={ + "query_text": "The query text for hybrid search.", + "top_k": "Number of results to return.", + }, + ) + + # Now mock the get_search_results method to track calls + from neo4j_graphrag.types import RawSearchResult + + get_search_results_mock = MagicMock( + return_value=RawSearchResult(records=[], metadata={}) ) - # Mock the get_search_results method to track calls - get_search_results_mock = MagicMock(return_value=([], None)) # Use patch to mock the method with patch.object(retriever, "get_search_results", get_search_results_mock): - tool = convert_retriever_to_tool( - retriever, - description="A tool for hybrid retrieval from Neo4j.", - parameters=parameters, - ) tools = {tool.get_name(): tool} # Simulate indirect invocation as would happen in real usage tool_call_arguments = {"query_text": 
"test query", "top_k": 5} @@ -404,7 +368,10 @@ def test_hybrid_retriever_tool_execution(mock_get_version: MagicMock) -> None: # Since we're using a context manager for patching, we need to verify the call inside the context # We can only check the result, not the method call itself - assert result == ([], None) + assert result is not None + assert hasattr(result, "items") # Should return RetrieverResult now + assert isinstance(result.items, list) + assert hasattr(result, "metadata") # Test tool execution for Text2CypherRetriever @@ -415,24 +382,23 @@ def test_text2cypher_retriever_tool_execution(mock_get_version: MagicMock) -> No driver = create_mock_driver() llm = create_mock_llm() retriever = Text2CypherRetriever(driver=driver, llm=llm) - parameters = ObjectParameter( - description="Parameters for text to Cypher conversion", - properties={ - "query_text": StringParameter( - description="The query text for text to Cypher conversion.", - required=True, - ), + # Create the tool first, before mocking + tool = retriever.convert_to_tool( + name="Text2CypherRetriever", + description="A tool for text to Cypher retrieval from Neo4j.", + parameter_descriptions={ + "query_text": "The query text for text to Cypher conversion.", }, ) - # Mock the get_search_results method to track calls - get_search_results_mock = MagicMock(return_value=([], None)) + + # Now mock the get_search_results method to track calls + from neo4j_graphrag.types import RawSearchResult + + get_search_results_mock = MagicMock( + return_value=RawSearchResult(records=[], metadata={}) + ) # Use patch to mock the method with patch.object(retriever, "get_search_results", get_search_results_mock): - tool = convert_retriever_to_tool( - retriever, - description="A tool for text to Cypher retrieval from Neo4j.", - parameters=parameters, - ) tools = {tool.get_name(): tool} # Simulate indirect invocation as would happen in real usage tool_call_arguments = {"query_text": "test query"} @@ -441,7 +407,10 @@ def 
test_text2cypher_retriever_tool_execution(mock_get_version: MagicMock) -> No # Since we're using a context manager for patching, we need to verify the call inside the context # We can only check the result, not the method call itself - assert result == ([], None) + assert result is not None + assert hasattr(result, "items") # Should return RetrieverResult now + assert isinstance(result.items, list) + assert hasattr(result, "metadata") # Test tool serialization to JSON format @@ -457,24 +426,13 @@ def test_tool_serialization(mock_get_version: MagicMock) -> None: embedder=embedder, return_properties=["name", "description"], ) - # Define parameters for the tool - parameters = ObjectParameter( - description="Parameters for vector search", - properties={ - "query_text": StringParameter( - description="The query text for vector search.", - required=True, - ), - "top_k": IntegerParameter( - description="Number of results to return.", - required=False, - ), - }, - ) - tool = convert_retriever_to_tool( - retriever, + tool = retriever.convert_to_tool( + name="VectorRetriever", description="A tool for vector-based retrieval from Neo4j.", - parameters=parameters, + parameter_descriptions={ + "query_text": "The query text for vector search.", + "top_k": "Number of results to return.", + }, ) # Create a dictionary representation of the tool tool_dict = { @@ -491,13 +449,11 @@ def test_tool_serialization(mock_get_version: MagicMock) -> None: # Get parameters and convert to dictionary parameters_any = tool_dict["parameters"] - # Use type casting to handle various parameter types - if isinstance(parameters_any, ObjectParameter): - parameters_dict = parameters_any.model_dump_tool() - elif isinstance(parameters_any, dict): + # With the new API, parameters should be a dictionary + if isinstance(parameters_any, dict): parameters_dict = parameters_any else: - # Handle the case where parameters is a Collection[str] or other type + # Handle unexpected parameter format parameters_dict = { 
str(k): v for k, v in enumerate(parameters_any) if v is not None } @@ -506,21 +462,19 @@ def test_tool_serialization(mock_get_version: MagicMock) -> None: assert parameters_dict.get("type") == "object" assert "properties" in parameters_dict - # Check that at least one parameter is marked as required - required_found = False - properties = parameters_dict.get("properties", {}) - if isinstance(properties, dict): - for param_name, param_data in properties.items(): - if isinstance(param_data, dict) and param_data.get("required", False): - required_found = True - break - - if not required_found and "required" in parameters_dict: - # Check if there's a required array at the parameters level - required_params = parameters_dict.get("required", []) - required_found = len(list(required_params)) > 0 - - assert required_found, "No required parameters found" + # Check that we have the expected parameter properties + # VectorRetriever has all optional parameters (query_vector and query_text are both optional) + expected_properties = { + "query_vector", + "query_text", + "top_k", + "effective_search_ratio", + "filters", + } + actual_properties = set(parameters_dict.get("properties", {}).keys()) + assert ( + expected_properties == actual_properties + ), f"Expected {expected_properties}, got {actual_properties}" # Check additionalProperties if it exists if "additionalProperties" in parameters_dict and not parameters_dict.get( From d0d2e3de0fff79c0887ed14e0cde42f61ad7b1f4 Mon Sep 17 00:00:00 2001 From: Oskar Hane Date: Mon, 23 Jun 2025 17:31:43 +0200 Subject: [PATCH 4/5] Fix broken example file --- examples/retrieve/tools/tools_retriever_example.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/examples/retrieve/tools/tools_retriever_example.py b/examples/retrieve/tools/tools_retriever_example.py index 1f2c18c84..d5f50af77 100644 --- a/examples/retrieve/tools/tools_retriever_example.py +++ b/examples/retrieve/tools/tools_retriever_example.py @@ 
-105,9 +105,8 @@ def __init__(self) -> None: "2025-04-16": [], } - # Define a wrapper function that handles the query parameter correctly - def execute_func(query: str, **kwargs: Any) -> str: - # Ignore the query parameter and call our execute method + # Define a wrapper function that handles parameters correctly + def execute_func(**kwargs: Any) -> str: return self.execute_calendar(**kwargs) super().__init__( @@ -166,9 +165,7 @@ def __init__(self) -> None: execute_func=self.execute_weather_retrieval, ) - def execute_weather_retrieval( - self, query: Optional[str] = None, **kwargs: Any - ) -> str: + def execute_weather_retrieval(self, **kwargs: Any) -> str: """Fetch historical weather data for a given date in Malmö, Sweden.""" date_str = kwargs.get("date") if not date_str: @@ -227,10 +224,6 @@ def execute_weather_retrieval( ) as e: return f"Error parsing weather data for Malmö on {date_str}: {e}" - return ( - f"Sorry, I couldn't fetch the weather for Malmö on {date_str} at this time." - ) - def main() -> None: """Run the example.""" From 526ace97d2525be17082743487eb89075e9673db Mon Sep 17 00:00:00 2001 From: estelle Date: Fri, 4 Jul 2025 09:52:06 +0200 Subject: [PATCH 5/5] CHANGELOG --- CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5509b995f..5bcd8b60e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,12 @@ ## Next +### Added + +- Added a `ToolsRetriever` retriever that uses an LLM to decide on what tools to use to find the relevant data. +- Added `convert_to_tool` method to the `Retriever` interface to convert a Retriever to a Tool so it can be used within the ToolsRetriever. This is useful when you might want to have both a VectorRetriever and a Text2CypherRetriever as a fallback. + + ## 1.8.0 ### Added