From 6864aadee75285bd339c2503bf30a346ff485e4a Mon Sep 17 00:00:00 2001 From: Oskar Hane Date: Fri, 14 Mar 2025 15:42:46 +0100 Subject: [PATCH 1/4] Add CypherRetriever for parameterized Cypher queries MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds a new CypherRetriever that enables direct database access through parameterized Cypher queries. Key features include: - Type-safe parameter validation - Support for optional parameters - Custom result formatting - Documentation and examples 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- docs/source/api.rst | 7 + docs/source/user_guide_rag.rst | 172 ++++++++++++++ examples/retrieve/cypher_retriever.py | 205 ++++++++++++++++ src/neo4j_graphrag/retrievers/__init__.py | 2 + src/neo4j_graphrag/retrievers/cypher.py | 268 +++++++++++++++++++++ src/neo4j_graphrag/types.py | 32 ++- tests/e2e/retrievers/test_cypher_e2e.py | 220 +++++++++++++++++ tests/unit/retrievers/test_cypher.py | 272 ++++++++++++++++++++++ 8 files changed, 1177 insertions(+), 1 deletion(-) create mode 100644 examples/retrieve/cypher_retriever.py create mode 100644 src/neo4j_graphrag/retrievers/cypher.py create mode 100644 tests/e2e/retrievers/test_cypher_e2e.py create mode 100644 tests/unit/retrievers/test_cypher.py diff --git a/docs/source/api.rst b/docs/source/api.rst index f27bf3af7..21b5286a1 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -187,6 +187,13 @@ Text2CypherRetriever :members: search +CypherRetriever +=============== + +.. 
autoclass:: neo4j_graphrag.retrievers.CypherRetriever + :members: search + + ******************* External Retrievers ******************* diff --git a/docs/source/user_guide_rag.rst b/docs/source/user_guide_rag.rst index b51c019a1..73a6f1028 100644 --- a/docs/source/user_guide_rag.rst +++ b/docs/source/user_guide_rag.rst @@ -349,6 +349,8 @@ We provide implementations for the following retrievers: - Same as HybridRetriever with a retrieval query similar to VectorCypherRetriever. * - :ref:`Text2Cypher ` - Translates the user question into a Cypher query to be run against a Neo4j database (or Knowledge Graph). The results of the query are then passed to the LLM to generate the final answer. + * - :ref:`CypherRetriever <cypher-retriever-user-guide>` + - Uses a predefined Cypher query template with parameterized inputs to retrieve data from the database. * - :ref:`WeaviateNeo4jRetriever ` - Use this retriever when vectors are saved in a Weaviate vector database * - :ref:`PineconeNeo4jRetriever ` @@ -849,6 +851,176 @@ LLMs can be different. See :ref:`text2cypherretriever`. +.. _cypher-retriever-user-guide: + +Cypher Retriever +=============================== + +The `CypherRetriever` allows you to define a templated Cypher query with parameterized inputs. This retriever is useful when you need direct database access with dynamic parameters, but without the complexity of LLM-generated queries or vector similarity search. + +Basic Usage +----------- + +The simplest usage involves defining a query with parameters: + +.. 
code:: python + + from neo4j_graphrag.retrievers import CypherRetriever + + # Create a retriever for finding movies by title + retriever = CypherRetriever( + driver=driver, + query="MATCH (m:Movie {title: $movie_title}) RETURN m", + parameters={ + "movie_title": { + "type": "string", + "description": "Title of a movie" + } + } + ) + + # Use the retriever with specific parameter values + results = retriever.search(parameters={"movie_title": "The Matrix"}) + +Parameter Types +--------------- + +The CypherRetriever supports these parameter types: + +- `string`: For text values +- `number`: For numeric values (integer or floating point) +- `integer`: For whole number values +- `boolean`: For true/false values +- `array`: For lists of values + +Optional Parameters +------------------- + +You can make parameters optional by setting `"required": False` in the parameter definition: + +.. code:: python + + retriever = CypherRetriever( + driver=driver, + query=""" + MATCH (m:Movie) + WHERE ($title IS NULL OR m.title CONTAINS $title) + AND ($year IS NULL OR m.released = $year) + RETURN m + """, + parameters={ + "title": { + "type": "string", + "description": "Movie title to search for", + "required": False + }, + "year": { + "type": "integer", + "description": "Release year", + "required": False + } + } + ) + + # Search with only one parameter + results = retriever.search(parameters={"title": "Matrix"}) + +Complex Queries +--------------- + +You can build more complex queries with multiple parameters and conditions: + +.. 
code:: python + + retriever = CypherRetriever( + driver=driver, + query=""" + MATCH (m:Movie) + WHERE ($title IS NULL OR m.title CONTAINS $title) + AND ($min_year IS NULL OR m.released >= $min_year) + AND ($max_year IS NULL OR m.released <= $max_year) + AND ($min_rating IS NULL OR m.rating >= $min_rating) + RETURN m + ORDER BY m.rating DESC + LIMIT $limit + """, + parameters={ + "title": { + "type": "string", + "description": "Partial movie title to search for", + "required": False + }, + "min_year": { + "type": "integer", + "description": "Minimum release year", + "required": False + }, + "max_year": { + "type": "integer", + "description": "Maximum release year", + "required": False + }, + "min_rating": { + "type": "number", + "description": "Minimum movie rating", + "required": False + }, + "limit": { + "type": "integer", + "description": "Maximum number of results to return", + "required": True + } + } + ) + +Custom Result Formatting +------------------------ + +You can customize how the results are formatted using a result formatter: + +.. code:: python + + def movie_formatter(record): + movie = record["m"] + return RetrieverResultItem( + content=f"{movie['title']} ({movie['released']})", + metadata={ + "rating": movie.get("rating"), + "tagline": movie.get("tagline"), + } + ) + + retriever = CypherRetriever( + driver=driver, + query="MATCH (m:Movie) WHERE m.title CONTAINS $title RETURN m", + parameters={"title": {"type": "string", "description": "Movie title"}}, + result_formatter=movie_formatter + ) + +Graph Traversals +---------------- + +The CypherRetriever is particularly useful for complex graph traversals: + +.. code:: python + + retriever = CypherRetriever( + driver=driver, + query=""" + MATCH (m:Movie {title: $movie_title})<-[r:ACTED_IN]-(a:Person) + RETURN a.name as actor, r.roles as roles + ORDER BY a.name + """, + parameters={ + "movie_title": { + "type": "string", + "description": "Title of a movie" + } + } + ) + +See :ref:`cypherretriever`. + .. 
_custom-retriever: Custom Retriever diff --git a/examples/retrieve/cypher_retriever.py b/examples/retrieve/cypher_retriever.py new file mode 100644 index 000000000..38af3d471 --- /dev/null +++ b/examples/retrieve/cypher_retriever.py @@ -0,0 +1,205 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Example of using CypherRetriever for parametrized Cypher queries. + +This example demonstrates how to use CypherRetriever to define a retriever with +a templated Cypher query that accepts parameters at runtime. 
+""" + +import neo4j +from neo4j_graphrag.retrievers import CypherRetriever +from neo4j_graphrag.types import RetrieverResultItem + +# Connect to Neo4j +# Replace with your own connection details +NEO4J_URI = "bolt://localhost:7687" +NEO4J_USER = "neo4j" +NEO4J_PASSWORD = "password" # Change this in production + +driver = neo4j.GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD)) + +# Simple example: Find a movie by title +def find_movie_by_title(): + retriever = CypherRetriever( + driver=driver, + query="MATCH (m:Movie {title: $movie_title}) RETURN m", + parameters={ + "movie_title": { + "type": "string", + "description": "Title of a movie" + } + } + ) + + # Use the retriever to search for a movie + result = retriever.search(parameters={"movie_title": "The Matrix"}) + + print("=== Find Movie by Title ===") + for item in result.items: + print(f"Movie: {item.content}") + print() + + +# Advanced example: Find movies with multiple criteria +def find_movies_by_criteria(): + # Custom formatter to extract specific information + def movie_formatter(record): + movie = record["m"] + return RetrieverResultItem( + content=f"{movie['title']} ({movie['released']})", + metadata={ + "rating": movie.get("rating"), + "tagline": movie.get("tagline"), + } + ) + + # Create a more complex retriever with multiple parameters + retriever = CypherRetriever( + driver=driver, + query=""" + MATCH (m:Movie) + WHERE ($title IS NULL OR m.title CONTAINS $title) + AND ($min_year IS NULL OR m.released >= $min_year) + AND ($max_year IS NULL OR m.released <= $max_year) + AND ($min_rating IS NULL OR m.rating >= $min_rating) + RETURN m + ORDER BY m.rating DESC + LIMIT $limit + """, + parameters={ + "title": { + "type": "string", + "description": "Partial movie title to search for", + "required": False + }, + "min_year": { + "type": "integer", + "description": "Minimum release year", + "required": False + }, + "max_year": { + "type": "integer", + "description": "Maximum release year", + 
"required": False + }, + "min_rating": { + "type": "number", + "description": "Minimum movie rating", + "required": False + }, + "limit": { + "type": "integer", + "description": "Maximum number of results to return", + "required": True + } + }, + result_formatter=movie_formatter + ) + + # Search with optional parameters + result = retriever.search( + parameters={ + "title": "Matrix", + "min_year": 1990, + "min_rating": 7.5, + "limit": 5 + } + ) + + print("=== Find Movies by Criteria ===") + for item in result.items: + print(f"Movie: {item.content}") + if item.metadata: + if "rating" in item.metadata: + print(f" Rating: {item.metadata['rating']}") + if "tagline" in item.metadata: + print(f" Tagline: {item.metadata['tagline']}") + print() + + +# Example with relationship traversal +def find_actors_in_movie(): + retriever = CypherRetriever( + driver=driver, + query=""" + MATCH (m:Movie {title: $movie_title})<-[r:ACTED_IN]-(a:Person) + RETURN a.name as actor, r.roles as roles + ORDER BY a.name + """, + parameters={ + "movie_title": { + "type": "string", + "description": "Title of a movie" + } + } + ) + + result = retriever.search(parameters={"movie_title": "The Matrix"}) + + print("=== Find Actors in Movie ===") + for item in result.items: + record = eval(item.content) # Simple way to parse the string representation + actor = record.get("actor", "Unknown") + roles = record.get("roles", []) + roles_str = ", ".join(roles) if roles else "Unknown role" + print(f"Actor: {actor} as {roles_str}") + print() + + +if __name__ == "__main__": + try: + # Setup: Make sure we have some movie data + with driver.session() as session: + # Check if data exists + result = session.run("MATCH (m:Movie) RETURN count(m) as count") + count = result.single()["count"] + + if count == 0: + print("No movie data found. 
Creating sample data...") + # Create sample data if none exists + session.run(""" + CREATE (TheMatrix:Movie {title:'The Matrix', released:1999, tagline:'Welcome to the Real World', rating: 8.7}) + CREATE (Keanu:Person {name:'Keanu Reeves', born:1964}) + CREATE (Carrie:Person {name:'Carrie-Anne Moss', born:1967}) + CREATE (Laurence:Person {name:'Laurence Fishburne', born:1961}) + CREATE (Hugo:Person {name:'Hugo Weaving', born:1960}) + CREATE (Keanu)-[:ACTED_IN {roles:['Neo']}]->(TheMatrix) + CREATE (Carrie)-[:ACTED_IN {roles:['Trinity']}]->(TheMatrix) + CREATE (Laurence)-[:ACTED_IN {roles:['Morpheus']}]->(TheMatrix) + CREATE (Hugo)-[:ACTED_IN {roles:['Agent Smith']}]->(TheMatrix) + CREATE (TheMatrixReloaded:Movie {title:'The Matrix Reloaded', released:2003, tagline:'Free your mind', rating: 7.2}) + CREATE (Keanu)-[:ACTED_IN {roles:['Neo']}]->(TheMatrixReloaded) + CREATE (Carrie)-[:ACTED_IN {roles:['Trinity']}]->(TheMatrixReloaded) + CREATE (Laurence)-[:ACTED_IN {roles:['Morpheus']}]->(TheMatrixReloaded) + CREATE (Hugo)-[:ACTED_IN {roles:['Agent Smith']}]->(TheMatrixReloaded) + CREATE (TheMatrixRevolutions:Movie {title:'The Matrix Revolutions', released:2003, tagline:'Everything that has a beginning has an end', rating: 6.8}) + CREATE (Keanu)-[:ACTED_IN {roles:['Neo']}]->(TheMatrixRevolutions) + CREATE (Carrie)-[:ACTED_IN {roles:['Trinity']}]->(TheMatrixRevolutions) + CREATE (Laurence)-[:ACTED_IN {roles:['Morpheus']}]->(TheMatrixRevolutions) + CREATE (Hugo)-[:ACTED_IN {roles:['Agent Smith']}]->(TheMatrixRevolutions) + """) + print("Sample data created.") + else: + print(f"Found {count} movies in the database.") + + # Run the examples + find_movie_by_title() + find_movies_by_criteria() + find_actors_in_movie() + + finally: + # Close the driver + driver.close() \ No newline at end of file diff --git a/src/neo4j_graphrag/retrievers/__init__.py b/src/neo4j_graphrag/retrievers/__init__.py index 595eac93b..2e957552f 100644 --- a/src/neo4j_graphrag/retrievers/__init__.py 
+++ b/src/neo4j_graphrag/retrievers/__init__.py @@ -16,6 +16,7 @@ from .hybrid import HybridCypherRetriever, HybridRetriever from .text2cypher import Text2CypherRetriever from .vector import VectorCypherRetriever, VectorRetriever +from .cypher import CypherRetriever __all__ = [ "VectorRetriever", @@ -23,6 +24,7 @@ "HybridRetriever", "HybridCypherRetriever", "Text2CypherRetriever", + "CypherRetriever", ] diff --git a/src/neo4j_graphrag/retrievers/cypher.py b/src/neo4j_graphrag/retrievers/cypher.py new file mode 100644 index 000000000..97d4658b0 --- /dev/null +++ b/src/neo4j_graphrag/retrievers/cypher.py @@ -0,0 +1,268 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations + +import logging +import re +from typing import Any, Callable, Dict, List, Optional, Union + +import neo4j +from neo4j.exceptions import CypherSyntaxError +from pydantic import ValidationError + +from neo4j_graphrag.exceptions import RetrieverInitializationError, SearchValidationError +from neo4j_graphrag.retrievers.base import Retriever +from neo4j_graphrag.types import ( + CypherParameterDefinition, + CypherParameterType, + CypherRetrieverModel, + CypherSearchModel, + Neo4jDriverModel, + RawSearchResult, + RetrieverResultItem, +) + +logger = logging.getLogger(__name__) + + +class CypherRetriever(Retriever): + """ + Allows for the retrieval of records from a Neo4j database using a parameterized Cypher query. + + This retriever enables direct execution of predefined Cypher queries with dynamic parameters. + It ensures type safety through parameter validation and provides the standard retriever result format. + + Example: + + .. code-block:: python + + import neo4j + from neo4j_graphrag.retrievers import CypherRetriever + + driver = neo4j.GraphDatabase.driver(URI, auth=AUTH) + + # Create a retriever for finding movies by title + retriever = CypherRetriever( + driver=driver, + query="MATCH (m:Movie {title: $movie_title}) RETURN m", + parameters={ + "movie_title": { + "type": "string", + "description": "Title of a movie" + } + } + ) + + # Use the retriever with specific parameter values + results = retriever.search(parameters={"movie_title": "The Matrix"}) + + Args: + driver (neo4j.Driver): The Neo4j Python driver. + query (str): Cypher query with parameter placeholders. + parameters (Dict[str, Dict]): Parameter definitions with types and descriptions. + Each parameter should have a 'type' and 'description' field. + Supported types: 'string', 'number', 'integer', 'boolean', 'array'. + result_formatter (Optional[Callable[[neo4j.Record], RetrieverResultItem]]): + Custom function to transform a neo4j.Record to a RetrieverResultItem. 
+ neo4j_database (Optional[str]): The name of the Neo4j database to use. + + Raises: + RetrieverInitializationError: If validation of the input arguments fail. + """ + + def __init__( + self, + driver: neo4j.Driver, + query: str, + parameters: Dict[str, Dict], + result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None, + neo4j_database: Optional[str] = None, + ) -> None: + # Convert parameter dictionaries to CypherParameterDefinition objects + param_definitions = {} + for param_name, param_def in parameters.items(): + param_type = param_def.get("type", "string") + description = param_def.get("description", "") + required = param_def.get("required", True) + + try: + param_definitions[param_name] = CypherParameterDefinition( + type=param_type, + description=description, + required=required + ) + except ValidationError as e: + raise RetrieverInitializationError( + f"Invalid parameter definition for {param_name}: {e.errors()}" + ) from e + + try: + driver_model = Neo4jDriverModel(driver=driver) + validated_data = CypherRetrieverModel( + driver_model=driver_model, + query=query, + parameters=param_definitions, + result_formatter=result_formatter, + neo4j_database=neo4j_database, + ) + except ValidationError as e: + raise RetrieverInitializationError(e.errors()) from e + + # Validate that the query is syntactically valid Cypher + self._validate_cypher_query(query) + + # Validate that all parameters in the query are defined + self._validate_query_parameters(query, param_definitions) + + super().__init__(validated_data.driver_model.driver, validated_data.neo4j_database) + self.query = validated_data.query + self.parameters = validated_data.parameters + self.result_formatter = validated_data.result_formatter + + def _validate_cypher_query(self, query: str) -> None: + """ + Validates that the query is syntactically valid Cypher. + + Args: + query (str): The Cypher query to validate. 
+ + Raises: + RetrieverInitializationError: If the query is not valid Cypher. + """ + # We can't fully validate the query without executing it, but we can check for basic syntax + if not query.strip(): + raise RetrieverInitializationError("Query cannot be empty") + + # Check for presence of common Cypher keywords + if not any(keyword in query.upper() for keyword in ["MATCH", "RETURN", "CREATE", "MERGE", "WITH"]): + raise RetrieverInitializationError( + "Query does not appear to be valid Cypher. " + "It should contain at least one of: MATCH, RETURN, CREATE, MERGE, WITH" + ) + + def _validate_query_parameters(self, query: str, parameters: Dict[str, CypherParameterDefinition]) -> None: + """ + Validates that all parameters in the query are defined in the parameters dictionary. + + Args: + query (str): The Cypher query to validate. + parameters (Dict[str, CypherParameterDefinition]): The parameter definitions. + + Raises: + RetrieverInitializationError: If any parameters in the query are not defined. + """ + # Find all parameters in the query (starting with $) + param_pattern = r'\$([a-zA-Z0-9_]+)' + query_params = set(re.findall(param_pattern, query)) + + # Check that all parameters in the query are defined + undefined_params = query_params - set(parameters.keys()) + if undefined_params: + raise RetrieverInitializationError( + f"The following parameters are used in the query but not defined: {', '.join(undefined_params)}" + ) + + def _validate_parameter_values(self, parameters: Dict[str, Any]) -> None: + """ + Validates that parameter values match their defined types. + + Args: + parameters (Dict[str, Any]): The parameter values to validate. + + Raises: + SearchValidationError: If any parameter values do not match their defined types. 
+ """ + # Check that all required parameters are provided + for param_name, param_def in self.parameters.items(): + if param_def.required and param_name not in parameters: + raise SearchValidationError(f"Required parameter '{param_name}' is missing") + + # Validate the type of each parameter + for param_name, param_value in parameters.items(): + if param_name not in self.parameters: + raise SearchValidationError(f"Unexpected parameter: {param_name}") + + param_def = self.parameters[param_name] + if param_value is None and not param_def.required: + continue # an optional parameter may be passed as an explicit NULL (matches the documented `$param IS NULL` pattern) + if param_def.type == CypherParameterType.STRING: + if not isinstance(param_value, str): + raise SearchValidationError( + f"Parameter '{param_name}' should be of type string, got {type(param_value).__name__}" + ) + elif param_def.type == CypherParameterType.NUMBER: + if isinstance(param_value, bool) or not isinstance(param_value, (int, float)): # bool subclasses int but is not a number; mirrors the INTEGER branch + raise SearchValidationError( + f"Parameter '{param_name}' should be of type number, got {type(param_value).__name__}" + ) + elif param_def.type == CypherParameterType.INTEGER: + if not isinstance(param_value, int) or isinstance(param_value, bool): + raise SearchValidationError( + f"Parameter '{param_name}' should be of type integer, got {type(param_value).__name__}" + ) + elif param_def.type == CypherParameterType.BOOLEAN: + if not isinstance(param_value, bool): + raise SearchValidationError( + f"Parameter '{param_name}' should be of type boolean, got {type(param_value).__name__}" + ) + elif param_def.type == CypherParameterType.ARRAY: + if not isinstance(param_value, (list, tuple)): + raise SearchValidationError( + f"Parameter '{param_name}' should be of type array, got {type(param_value).__name__}" + ) + + def get_search_results(self, parameters: Dict[str, Any]) -> RawSearchResult: + """ + Executes the Cypher query with the provided parameters and returns the results. + + Args: + parameters (Dict[str, Any]): Parameter values to use in the query. + Each parameter should match the type specified in the parameter definitions. 
+ + Raises: + SearchValidationError: If validation of the parameters fails. + + Returns: + RawSearchResult: The results of the query as a list of neo4j.Record and an optional metadata dict. + """ + try: + validated_data = CypherSearchModel(parameters=parameters) + except ValidationError as e: + raise SearchValidationError(e.errors()) from e + + # Validate parameter values against their definitions + self._validate_parameter_values(validated_data.parameters) + + logger.debug("CypherRetriever query: %s", self.query) + logger.debug("CypherRetriever parameters: %s", validated_data.parameters) + + try: + records, _, _ = self.driver.execute_query( + query_=self.query, + parameters_={**dict.fromkeys(self.parameters), **validated_data.parameters}, # omitted optional parameters are sent as NULL so `$param IS NULL` does not fail with a missing-parameter error + database_=self.neo4j_database, + routing_=neo4j.RoutingControl.READ, + ) + except CypherSyntaxError as e: + raise SearchValidationError(f"Cypher syntax error: {e.message}") from e + except Exception as e: + raise SearchValidationError(f"Failed to execute query: {str(e)}") from e + + return RawSearchResult( + records=records, + metadata={ + "cypher": self.query, + }, + ) \ No newline at end of file diff --git a/src/neo4j_graphrag/types.py b/src/neo4j_graphrag/types.py index 1c0b74542..fc3f708db 100644 --- a/src/neo4j_graphrag/types.py +++ b/src/neo4j_graphrag/types.py @@ -16,7 +16,7 @@ import warnings from enum import Enum -from typing import Any, Callable, Literal, Optional, TypedDict, Union +from typing import Any, Callable, Literal, Optional, TypedDict, Union, Dict import neo4j from pydantic import ( @@ -312,3 +312,33 @@ def validate_session_id(cls, v: Union[str, int]) -> Union[str, int]: class LLMMessage(TypedDict): role: Literal["system", "user", "assistant"] content: str + + +class CypherParameterType(str, Enum): + """Enumeration of parameter types.""" + STRING = "string" + NUMBER = "number" + INTEGER = "integer" + BOOLEAN = "boolean" + ARRAY = "array" + + +class CypherParameterDefinition(BaseModel): + """Definition of a Cypher query parameter.""" + type: 
CypherParameterType + description: str + required: bool = True + + +class CypherRetrieverModel(BaseModel): + """Model for validating CypherRetriever arguments.""" + driver_model: Neo4jDriverModel + query: str + parameters: Dict[str, CypherParameterDefinition] + result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None + neo4j_database: Optional[str] = None + + +class CypherSearchModel(BaseModel): + """Model for validating search parameters.""" + parameters: Dict[str, Any] diff --git a/tests/e2e/retrievers/test_cypher_e2e.py b/tests/e2e/retrievers/test_cypher_e2e.py new file mode 100644 index 000000000..b97a43dfe --- /dev/null +++ b/tests/e2e/retrievers/test_cypher_e2e.py @@ -0,0 +1,220 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +import os +import random +import string +from typing import Generator + +import neo4j +import pytest +from neo4j.exceptions import Neo4jError + +from neo4j_graphrag.retrievers import CypherRetriever +from neo4j_graphrag.types import RetrieverResultItem + + +# Fixture to create test data +@pytest.fixture +def sample_data(driver: neo4j.Driver) -> Generator[str, None, None]: + # Generate a random prefix for category names to avoid conflicts between test runs + prefix = ''.join(random.choices(string.ascii_lowercase, k=8)) + category_name = f"Category_{prefix}" + + # Create test data + try: + with driver.session() as session: + session.run( + """ + CREATE (c:Category {name: $category_name}) + CREATE (p1:Product {name: "Product1", price: 10.99, stock: 100, featured: true}) + CREATE (p2:Product {name: "Product2", price: 25.50, stock: 50, featured: false}) + CREATE (p3:Product {name: "Product3", price: 5.99, stock: 200, featured: true}) + CREATE (p1)-[:BELONGS_TO]->(c) + CREATE (p2)-[:BELONGS_TO]->(c) + CREATE (p3)-[:BELONGS_TO]->(c) + """, + category_name=category_name + ) + except Neo4jError as e: + pytest.fail(f"Failed to create test data: {e}") + + yield category_name + + # Clean up test data + try: + with driver.session() as session: + session.run( + """ + MATCH (p:Product)-[:BELONGS_TO]->(c:Category {name: $category_name}) + DETACH DELETE p, c + """, + category_name=category_name + ) + except Neo4jError as e: + pytest.fail(f"Failed to clean up test data: {e}") + + +def test_cypher_retriever_basic_query(driver: neo4j.Driver, sample_data: str) -> None: + """Test basic query with CypherRetriever.""" + retriever = CypherRetriever( + driver=driver, + query="MATCH (p:Product) WHERE p.price > $min_price RETURN p ORDER BY p.price", + parameters={ + "min_price": { + "type": "number", + "description": "Minimum product price" + } + } + ) + + # Execute the query + result = retriever.search(parameters={"min_price": 10.0}) + + # Verify the results + assert 
len(result.items) == 2 + assert "Product1" in result.items[0].content or "Product2" in result.items[0].content + assert "cypher" in result.metadata + + +def test_cypher_retriever_multiple_parameters(driver: neo4j.Driver, sample_data: str) -> None: + """Test query with multiple parameters.""" + retriever = CypherRetriever( + driver=driver, + query=""" + MATCH (p:Product) + WHERE p.price >= $min_price AND p.price <= $max_price + AND p.stock > $min_stock + RETURN p + """, + parameters={ + "min_price": { + "type": "number", + "description": "Minimum product price" + }, + "max_price": { + "type": "number", + "description": "Maximum product price" + }, + "min_stock": { + "type": "integer", + "description": "Minimum stock quantity" + } + } + ) + + # Execute the query with parameters + result = retriever.search(parameters={ + "min_price": 5.0, + "max_price": 15.0, + "min_stock": 50 + }) + + # Verify the results + assert len(result.items) == 2 + assert any("Product1" in item.content for item in result.items) + assert any("Product3" in item.content for item in result.items) + + +def test_cypher_retriever_optional_parameters(driver: neo4j.Driver, sample_data: str) -> None: + """Test query with optional parameters.""" + retriever = CypherRetriever( + driver=driver, + query=""" + MATCH (p:Product) + WHERE ($featured IS NULL OR p.featured = $featured) + RETURN p + """, + parameters={ + "featured": { + "type": "boolean", + "description": "Filter for featured products", + "required": False + } + } + ) + + # Execute the query with the optional parameter + result_with_param = retriever.search(parameters={"featured": True}) + + # Verify the results with parameter + assert len(result_with_param.items) == 2 + assert all("'featured': True" in item.content for item in result_with_param.items) # presumably content is the Python repr of the record, so the property prints as 'featured': True + + # Execute the query without the optional parameter + result_without_param = retriever.search(parameters={}) + + # Verify the results without parameter (should return all products) + assert 
len(result_without_param.items) == 3 + + +def test_cypher_retriever_relationship_traversal(driver: neo4j.Driver, sample_data: str) -> None: + """Test query with relationship traversal.""" + retriever = CypherRetriever( + driver=driver, + query=""" + MATCH (p:Product)-[:BELONGS_TO]->(c:Category {name: $category_name}) + RETURN p.name as product, p.price as price, c.name as category + """, + parameters={ + "category_name": { + "type": "string", + "description": "Category name" + } + } + ) + + # Execute the query + result = retriever.search(parameters={"category_name": sample_data}) + + # Verify the results + assert len(result.items) == 3 + assert all(sample_data in item.content for item in result.items) + + +def test_cypher_retriever_custom_formatter(driver: neo4j.Driver, sample_data: str) -> None: + """Test query with custom result formatter.""" + # Custom formatter that extracts product info in a structured format + def product_formatter(record): + product = record["p"] + return RetrieverResultItem( + content=f"{product['name']} - ${product['price']}", + metadata={ + "price": product["price"], + "stock": product["stock"], + "featured": product["featured"] + } + ) + + retriever = CypherRetriever( + driver=driver, + query="MATCH (p:Product) RETURN p", + parameters={}, + result_formatter=product_formatter + ) + + # Execute the query + result = retriever.search(parameters={}) + + # Verify the results + assert len(result.items) == 3 + + # Check custom formatting + for item in result.items: + assert " - $" in item.content + assert "price" in item.metadata + assert "stock" in item.metadata + assert "featured" in item.metadata \ No newline at end of file diff --git a/tests/unit/retrievers/test_cypher.py b/tests/unit/retrievers/test_cypher.py new file mode 100644 index 000000000..4dd4a2eb7 --- /dev/null +++ b/tests/unit/retrievers/test_cypher.py @@ -0,0 +1,272 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, 
Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import MagicMock, patch + +import pytest +import neo4j +from neo4j import Record + +from neo4j_graphrag.exceptions import RetrieverInitializationError, SearchValidationError +from neo4j_graphrag.retrievers.cypher import CypherRetriever +from neo4j_graphrag.types import RetrieverResultItem + + +class TestCypherRetriever(unittest.TestCase): + @classmethod + def setUpClass(cls): + # Patch the Neo4jDriverModel.check_driver method to pass validation with MagicMock + cls.patcher1 = patch('neo4j_graphrag.types.Neo4jDriverModel.check_driver') + cls.mock_check_driver = cls.patcher1.start() + cls.mock_check_driver.side_effect = lambda v: v + + # Patch the version check in the Retriever base class to avoid Neo4j version validation + cls.patcher2 = patch('neo4j_graphrag.retrievers.base.Retriever.VERIFY_NEO4J_VERSION', False) + cls.patcher2.start() + + @classmethod + def tearDownClass(cls): + cls.patcher1.stop() + cls.patcher2.stop() + def setUp(self): + # Create a mock driver + self.driver = MagicMock(spec=neo4j.Driver) + self.driver.execute_query.return_value = ( + [Record({"m": {"title": "Test Movie"}, "score": 0.9})], + None, + None, + ) + + # Sample query and parameters + self.valid_query = "MATCH (m:Movie {title: $movie_title}) RETURN m" + self.valid_parameters = { + "movie_title": {"type": "string", "description": "Title of a movie"} + } + + def test_init_success(self): + # Test successful initialization + retriever = 
CypherRetriever( + driver=self.driver, + query=self.valid_query, + parameters=self.valid_parameters, + ) + assert retriever.query == self.valid_query + assert "movie_title" in retriever.parameters + + def test_init_empty_query(self): + # Test initialization with empty query + with pytest.raises(RetrieverInitializationError): + CypherRetriever( + driver=self.driver, + query="", + parameters=self.valid_parameters, + ) + + def test_init_invalid_query(self): + # Test initialization with invalid query + with pytest.raises(RetrieverInitializationError): + CypherRetriever( + driver=self.driver, + query="SELECT * FROM movies", # SQL, not Cypher + parameters=self.valid_parameters, + ) + + def test_init_undefined_parameters(self): + # Test initialization with undefined parameters in query + with pytest.raises(RetrieverInitializationError): + CypherRetriever( + driver=self.driver, + query="MATCH (m:Movie {title: $movie_title, year: $year}) RETURN m", + parameters=self.valid_parameters, # Missing 'year' parameter + ) + + def test_init_invalid_parameter_type(self): + # Test initialization with invalid parameter type + with pytest.raises(RetrieverInitializationError): + CypherRetriever( + driver=self.driver, + query=self.valid_query, + parameters={"movie_title": {"type": "invalid_type", "description": "Title of a movie"}}, + ) + + def test_search_success(self): + # Test successful search + retriever = CypherRetriever( + driver=self.driver, + query=self.valid_query, + parameters=self.valid_parameters, + ) + result = retriever.search(parameters={"movie_title": "The Matrix"}) + + # Assert driver.execute_query was called with the right parameters + self.driver.execute_query.assert_called_once() + assert result.items + assert result.metadata and "cypher" in result.metadata + assert result.metadata["cypher"] == self.valid_query + + def test_search_missing_required_parameter(self): + # Test search with missing required parameter + retriever = CypherRetriever( + driver=self.driver, + 
query=self.valid_query, + parameters=self.valid_parameters, + ) + with pytest.raises(SearchValidationError): + retriever.search(parameters={}) # Missing 'movie_title' + + def test_search_unexpected_parameter(self): + # Test search with unexpected parameter + retriever = CypherRetriever( + driver=self.driver, + query=self.valid_query, + parameters=self.valid_parameters, + ) + with pytest.raises(SearchValidationError): + retriever.search(parameters={"movie_title": "The Matrix", "year": 1999}) # 'year' not defined + + def test_search_type_mismatch(self): + # Test search with parameter type mismatch + retriever = CypherRetriever( + driver=self.driver, + query=self.valid_query, + parameters=self.valid_parameters, + ) + with pytest.raises(SearchValidationError): + retriever.search(parameters={"movie_title": 123}) # Integer, expected string + + def test_different_parameter_types(self): + # Test with different parameter types + query = "MATCH (m:Movie) WHERE m.title = $title AND m.year = $year AND m.rating > $rating " \ + "AND m.is_available = $available AND m.genres IN $genres RETURN m" + parameters = { + "title": {"type": "string", "description": "Movie title"}, + "year": {"type": "integer", "description": "Release year"}, + "rating": {"type": "number", "description": "Minimum rating"}, + "available": {"type": "boolean", "description": "Is the movie available"}, + "genres": {"type": "array", "description": "List of genres"}, + } + + retriever = CypherRetriever( + driver=self.driver, + query=query, + parameters=parameters, + ) + + # Valid parameters of different types + result = retriever.search( + parameters={ + "title": "The Matrix", + "year": 1999, + "rating": 8.5, + "available": True, + "genres": ["Action", "Sci-Fi"] + } + ) + + assert result.items + + # Test integer type validation + with pytest.raises(SearchValidationError): + retriever.search( + parameters={ + "title": "The Matrix", + "year": "1999", # String, expected integer + "rating": 8.5, + "available": True, 
+ "genres": ["Action", "Sci-Fi"] + } + ) + + # Test number type validation + with pytest.raises(SearchValidationError): + retriever.search( + parameters={ + "title": "The Matrix", + "year": 1999, + "rating": "8.5", # String, expected number + "available": True, + "genres": ["Action", "Sci-Fi"] + } + ) + + # Test boolean type validation + with pytest.raises(SearchValidationError): + retriever.search( + parameters={ + "title": "The Matrix", + "year": 1999, + "rating": 8.5, + "available": "yes", # String, expected boolean + "genres": ["Action", "Sci-Fi"] + } + ) + + # Test array type validation + with pytest.raises(SearchValidationError): + retriever.search( + parameters={ + "title": "The Matrix", + "year": 1999, + "rating": 8.5, + "available": True, + "genres": "Action, Sci-Fi" # String, expected array + } + ) + + def test_custom_result_formatter(self): + # Test with custom result formatter + def custom_formatter(record): + return RetrieverResultItem( + content=f"Movie: {record['m']['title']}", + metadata={"score": record["score"]} + ) + + retriever = CypherRetriever( + driver=self.driver, + query=self.valid_query, + parameters=self.valid_parameters, + result_formatter=custom_formatter, + ) + + result = retriever.search(parameters={"movie_title": "The Matrix"}) + assert result.items[0].content == "Movie: Test Movie" + assert result.items[0].metadata["score"] == 0.9 + + def test_optional_parameters(self): + # Test with optional parameters + query = "MATCH (m:Movie {title: $title}) WHERE m.year = $year RETURN m" + parameters = { + "title": {"type": "string", "description": "Movie title", "required": True}, + "year": {"type": "integer", "description": "Release year", "required": False}, + } + + retriever = CypherRetriever( + driver=self.driver, + query=query, + parameters=parameters, + ) + + # Should succeed with only required parameters + result = retriever.search(parameters={"title": "The Matrix"}) + assert result.items + + # Should also succeed with optional 
parameters + result = retriever.search(parameters={"title": "The Matrix", "year": 1999}) + assert result.items + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From 482ef083ac1fc714f1ad0aa10f9f7b25e790c3ef Mon Sep 17 00:00:00 2001 From: Oskar Hane Date: Fri, 14 Mar 2025 15:54:11 +0100 Subject: [PATCH 2/4] Apply ruff formatting to code --- examples/retrieve/cypher_retriever.py | 60 ++++++------- src/neo4j_graphrag/retrievers/cypher.py | 84 +++++++++-------- src/neo4j_graphrag/types.py | 4 + tests/e2e/retrievers/test_cypher_e2e.py | 114 +++++++++++------------- tests/unit/retrievers/test_cypher.py | 83 ++++++++++------- 5 files changed, 182 insertions(+), 163 deletions(-) diff --git a/examples/retrieve/cypher_retriever.py b/examples/retrieve/cypher_retriever.py index 38af3d471..050c2ef44 100644 --- a/examples/retrieve/cypher_retriever.py +++ b/examples/retrieve/cypher_retriever.py @@ -31,22 +31,20 @@ driver = neo4j.GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD)) + # Simple example: Find a movie by title def find_movie_by_title(): retriever = CypherRetriever( driver=driver, query="MATCH (m:Movie {title: $movie_title}) RETURN m", parameters={ - "movie_title": { - "type": "string", - "description": "Title of a movie" - } - } + "movie_title": {"type": "string", "description": "Title of a movie"} + }, ) - + # Use the retriever to search for a movie result = retriever.search(parameters={"movie_title": "The Matrix"}) - + print("=== Find Movie by Title ===") for item in result.items: print(f"Movie: {item.content}") @@ -63,9 +61,9 @@ def movie_formatter(record): metadata={ "rating": movie.get("rating"), "tagline": movie.get("tagline"), - } + }, ) - + # Create a more complex retriever with multiple parameters retriever = CypherRetriever( driver=driver, @@ -83,42 +81,37 @@ def movie_formatter(record): "title": { "type": "string", "description": "Partial movie title to search for", - "required": False + "required": False, }, 
"min_year": { "type": "integer", "description": "Minimum release year", - "required": False + "required": False, }, "max_year": { "type": "integer", "description": "Maximum release year", - "required": False + "required": False, }, "min_rating": { "type": "number", "description": "Minimum movie rating", - "required": False + "required": False, }, "limit": { "type": "integer", "description": "Maximum number of results to return", - "required": True - } + "required": True, + }, }, - result_formatter=movie_formatter + result_formatter=movie_formatter, ) - + # Search with optional parameters result = retriever.search( - parameters={ - "title": "Matrix", - "min_year": 1990, - "min_rating": 7.5, - "limit": 5 - } + parameters={"title": "Matrix", "min_year": 1990, "min_rating": 7.5, "limit": 5} ) - + print("=== Find Movies by Criteria ===") for item in result.items: print(f"Movie: {item.content}") @@ -140,15 +133,12 @@ def find_actors_in_movie(): ORDER BY a.name """, parameters={ - "movie_title": { - "type": "string", - "description": "Title of a movie" - } - } + "movie_title": {"type": "string", "description": "Title of a movie"} + }, ) - + result = retriever.search(parameters={"movie_title": "The Matrix"}) - + print("=== Find Actors in Movie ===") for item in result.items: record = eval(item.content) # Simple way to parse the string representation @@ -166,7 +156,7 @@ def find_actors_in_movie(): # Check if data exists result = session.run("MATCH (m:Movie) RETURN count(m) as count") count = result.single()["count"] - + if count == 0: print("No movie data found. 
Creating sample data...") # Create sample data if none exists @@ -194,12 +184,12 @@ def find_actors_in_movie(): print("Sample data created.") else: print(f"Found {count} movies in the database.") - + # Run the examples find_movie_by_title() find_movies_by_criteria() find_actors_in_movie() - + finally: # Close the driver - driver.close() \ No newline at end of file + driver.close() diff --git a/src/neo4j_graphrag/retrievers/cypher.py b/src/neo4j_graphrag/retrievers/cypher.py index 97d4658b0..9b2fff005 100644 --- a/src/neo4j_graphrag/retrievers/cypher.py +++ b/src/neo4j_graphrag/retrievers/cypher.py @@ -16,13 +16,16 @@ import logging import re -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, Optional import neo4j from neo4j.exceptions import CypherSyntaxError from pydantic import ValidationError -from neo4j_graphrag.exceptions import RetrieverInitializationError, SearchValidationError +from neo4j_graphrag.exceptions import ( + RetrieverInitializationError, + SearchValidationError, +) from neo4j_graphrag.retrievers.base import Retriever from neo4j_graphrag.types import ( CypherParameterDefinition, @@ -40,12 +43,12 @@ class CypherRetriever(Retriever): """ Allows for the retrieval of records from a Neo4j database using a parameterized Cypher query. - + This retriever enables direct execution of predefined Cypher queries with dynamic parameters. It ensures type safety through parameter validation and provides the standard retriever result format. Example: - + .. code-block:: python import neo4j @@ -59,7 +62,7 @@ class CypherRetriever(Retriever): query="MATCH (m:Movie {title: $movie_title}) RETURN m", parameters={ "movie_title": { - "type": "string", + "type": "string", "description": "Title of a movie" } } @@ -74,7 +77,7 @@ class CypherRetriever(Retriever): parameters (Dict[str, Dict]): Parameter definitions with types and descriptions. Each parameter should have a 'type' and 'description' field. 
Supported types: 'string', 'number', 'integer', 'boolean', 'array'. - result_formatter (Optional[Callable[[neo4j.Record], RetrieverResultItem]]): + result_formatter (Optional[Callable[[neo4j.Record], RetrieverResultItem]]): Custom function to transform a neo4j.Record to a RetrieverResultItem. neo4j_database (Optional[str]): The name of the Neo4j database to use. @@ -87,7 +90,9 @@ def __init__( driver: neo4j.Driver, query: str, parameters: Dict[str, Dict], - result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None, + result_formatter: Optional[ + Callable[[neo4j.Record], RetrieverResultItem] + ] = None, neo4j_database: Optional[str] = None, ) -> None: # Convert parameter dictionaries to CypherParameterDefinition objects @@ -96,12 +101,10 @@ def __init__( param_type = param_def.get("type", "string") description = param_def.get("description", "") required = param_def.get("required", True) - + try: param_definitions[param_name] = CypherParameterDefinition( - type=param_type, - description=description, - required=required + type=param_type, description=description, required=required ) except ValidationError as e: raise RetrieverInitializationError( @@ -122,51 +125,58 @@ def __init__( # Validate that the query is syntactically valid Cypher self._validate_cypher_query(query) - + # Validate that all parameters in the query are defined self._validate_query_parameters(query, param_definitions) - super().__init__(validated_data.driver_model.driver, validated_data.neo4j_database) + super().__init__( + validated_data.driver_model.driver, validated_data.neo4j_database + ) self.query = validated_data.query self.parameters = validated_data.parameters self.result_formatter = validated_data.result_formatter - + def _validate_cypher_query(self, query: str) -> None: """ Validates that the query is syntactically valid Cypher. - + Args: query (str): The Cypher query to validate. - + Raises: RetrieverInitializationError: If the query is not valid Cypher. 
""" # We can't fully validate the query without executing it, but we can check for basic syntax if not query.strip(): raise RetrieverInitializationError("Query cannot be empty") - + # Check for presence of common Cypher keywords - if not any(keyword in query.upper() for keyword in ["MATCH", "RETURN", "CREATE", "MERGE", "WITH"]): + if not any( + keyword in query.upper() + for keyword in ["MATCH", "RETURN", "CREATE", "MERGE", "WITH"] + ): raise RetrieverInitializationError( "Query does not appear to be valid Cypher. " "It should contain at least one of: MATCH, RETURN, CREATE, MERGE, WITH" ) - def _validate_query_parameters(self, query: str, parameters: Dict[str, CypherParameterDefinition]) -> None: + def _validate_query_parameters( + self, query: str, parameters: Dict[str, CypherParameterDefinition] + ) -> None: """ Validates that all parameters in the query are defined in the parameters dictionary. - + Args: query (str): The Cypher query to validate. parameters (Dict[str, CypherParameterDefinition]): The parameter definitions. - + Raises: RetrieverInitializationError: If any parameters in the query are not defined. """ # Find all parameters in the query (starting with $) - param_pattern = r'\$([a-zA-Z0-9_]+)' + param_pattern = r"\$([a-zA-Z0-9_]+)" query_params = set(re.findall(param_pattern, query)) - + # Check that all parameters in the query are defined undefined_params = query_params - set(parameters.keys()) if undefined_params: @@ -177,25 +187,27 @@ def _validate_query_parameters(self, query: str, parameters: Dict[str, CypherPar def _validate_parameter_values(self, parameters: Dict[str, Any]) -> None: """ Validates that parameter values match their defined types. - + Args: parameters (Dict[str, Any]): The parameter values to validate. - + Raises: SearchValidationError: If any parameter values do not match their defined types. 
""" # Check that all required parameters are provided for param_name, param_def in self.parameters.items(): if param_def.required and param_name not in parameters: - raise SearchValidationError(f"Required parameter '{param_name}' is missing") + raise SearchValidationError( + f"Required parameter '{param_name}' is missing" + ) # Validate the type of each parameter for param_name, param_value in parameters.items(): if param_name not in self.parameters: raise SearchValidationError(f"Unexpected parameter: {param_name}") - + param_def = self.parameters[param_name] - + # Type validation if param_def.type == CypherParameterType.STRING: if not isinstance(param_value, str): @@ -226,14 +238,14 @@ def _validate_parameter_values(self, parameters: Dict[str, Any]) -> None: def get_search_results(self, parameters: Dict[str, Any]) -> RawSearchResult: """ Executes the Cypher query with the provided parameters and returns the results. - + Args: parameters (Dict[str, Any]): Parameter values to use in the query. Each parameter should match the type specified in the parameter definitions. - + Raises: SearchValidationError: If validation of the parameters fails. - + Returns: RawSearchResult: The results of the query as a list of neo4j.Record and an optional metadata dict. 
""" @@ -241,13 +253,13 @@ def get_search_results(self, parameters: Dict[str, Any]) -> RawSearchResult: validated_data = CypherSearchModel(parameters=parameters) except ValidationError as e: raise SearchValidationError(e.errors()) from e - + # Validate parameter values against their definitions self._validate_parameter_values(validated_data.parameters) - + logger.debug("CypherRetriever query: %s", self.query) logger.debug("CypherRetriever parameters: %s", validated_data.parameters) - + try: records, _, _ = self.driver.execute_query( query_=self.query, @@ -259,10 +271,10 @@ def get_search_results(self, parameters: Dict[str, Any]) -> RawSearchResult: raise SearchValidationError(f"Cypher syntax error: {e.message}") from e except Exception as e: raise SearchValidationError(f"Failed to execute query: {str(e)}") from e - + return RawSearchResult( records=records, metadata={ "cypher": self.query, }, - ) \ No newline at end of file + ) diff --git a/src/neo4j_graphrag/types.py b/src/neo4j_graphrag/types.py index fc3f708db..df3c82dd2 100644 --- a/src/neo4j_graphrag/types.py +++ b/src/neo4j_graphrag/types.py @@ -316,6 +316,7 @@ class LLMMessage(TypedDict): class CypherParameterType(str, Enum): """Enumeration of parameter types.""" + STRING = "string" NUMBER = "number" INTEGER = "integer" @@ -325,6 +326,7 @@ class CypherParameterType(str, Enum): class CypherParameterDefinition(BaseModel): """Definition of a Cypher query parameter.""" + type: CypherParameterType description: str required: bool = True @@ -332,6 +334,7 @@ class CypherParameterDefinition(BaseModel): class CypherRetrieverModel(BaseModel): """Model for validating CypherRetriever arguments.""" + driver_model: Neo4jDriverModel query: str parameters: Dict[str, CypherParameterDefinition] @@ -341,4 +344,5 @@ class CypherRetrieverModel(BaseModel): class CypherSearchModel(BaseModel): """Model for validating search parameters.""" + parameters: Dict[str, Any] diff --git a/tests/e2e/retrievers/test_cypher_e2e.py 
b/tests/e2e/retrievers/test_cypher_e2e.py index b97a43dfe..51eed554d 100644 --- a/tests/e2e/retrievers/test_cypher_e2e.py +++ b/tests/e2e/retrievers/test_cypher_e2e.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json -import os import random import string from typing import Generator @@ -31,9 +29,9 @@ @pytest.fixture def sample_data(driver: neo4j.Driver) -> Generator[str, None, None]: # Generate a random prefix for category names to avoid conflicts between test runs - prefix = ''.join(random.choices(string.ascii_lowercase, k=8)) + prefix = "".join(random.choices(string.ascii_lowercase, k=8)) category_name = f"Category_{prefix}" - + # Create test data try: with driver.session() as session: @@ -47,13 +45,13 @@ def sample_data(driver: neo4j.Driver) -> Generator[str, None, None]: CREATE (p2)-[:BELONGS_TO]->(c) CREATE (p3)-[:BELONGS_TO]->(c) """, - category_name=category_name + category_name=category_name, ) except Neo4jError as e: pytest.fail(f"Failed to create test data: {e}") - + yield category_name - + # Clean up test data try: with driver.session() as session: @@ -62,7 +60,7 @@ def sample_data(driver: neo4j.Driver) -> Generator[str, None, None]: MATCH (p:Product)-[:BELONGS_TO]->(c:Category {name: $category_name}) DETACH DELETE p, c """, - category_name=category_name + category_name=category_name, ) except Neo4jError as e: pytest.fail(f"Failed to clean up test data: {e}") @@ -74,23 +72,24 @@ def test_cypher_retriever_basic_query(driver: neo4j.Driver, sample_data: str) -> driver=driver, query="MATCH (p:Product) WHERE p.price > $min_price RETURN p ORDER BY p.price", parameters={ - "min_price": { - "type": "number", - "description": "Minimum product price" - } - } + "min_price": {"type": "number", "description": "Minimum product price"} + }, ) - + # Execute the query result = retriever.search(parameters={"min_price": 10.0}) - + # Verify the results assert len(result.items) == 2 - 
assert "Product1" in result.items[0].content or "Product2" in result.items[0].content + assert ( + "Product1" in result.items[0].content or "Product2" in result.items[0].content + ) assert "cypher" in result.metadata -def test_cypher_retriever_multiple_parameters(driver: neo4j.Driver, sample_data: str) -> None: +def test_cypher_retriever_multiple_parameters( + driver: neo4j.Driver, sample_data: str +) -> None: """Test query with multiple parameters.""" retriever = CypherRetriever( driver=driver, @@ -101,35 +100,26 @@ def test_cypher_retriever_multiple_parameters(driver: neo4j.Driver, sample_data: RETURN p """, parameters={ - "min_price": { - "type": "number", - "description": "Minimum product price" - }, - "max_price": { - "type": "number", - "description": "Maximum product price" - }, - "min_stock": { - "type": "integer", - "description": "Minimum stock quantity" - } - } + "min_price": {"type": "number", "description": "Minimum product price"}, + "max_price": {"type": "number", "description": "Maximum product price"}, + "min_stock": {"type": "integer", "description": "Minimum stock quantity"}, + }, ) - + # Execute the query with parameters - result = retriever.search(parameters={ - "min_price": 5.0, - "max_price": 15.0, - "min_stock": 50 - }) - + result = retriever.search( + parameters={"min_price": 5.0, "max_price": 15.0, "min_stock": 50} + ) + # Verify the results assert len(result.items) == 2 assert any("Product1" in item.content for item in result.items) assert any("Product3" in item.content for item in result.items) -def test_cypher_retriever_optional_parameters(driver: neo4j.Driver, sample_data: str) -> None: +def test_cypher_retriever_optional_parameters( + driver: neo4j.Driver, sample_data: str +) -> None: """Test query with optional parameters.""" retriever = CypherRetriever( driver=driver, @@ -142,26 +132,28 @@ def test_cypher_retriever_optional_parameters(driver: neo4j.Driver, sample_data: "featured": { "type": "boolean", "description": "Filter for 
featured products", - "required": False + "required": False, } - } + }, ) - + # Execute the query with the optional parameter result_with_param = retriever.search(parameters={"featured": True}) - + # Verify the results with parameter assert len(result_with_param.items) == 2 assert all("featured: true" in item.content for item in result_with_param.items) - + # Execute the query without the optional parameter result_without_param = retriever.search(parameters={}) - + # Verify the results without parameter (should return all products) assert len(result_without_param.items) == 3 -def test_cypher_retriever_relationship_traversal(driver: neo4j.Driver, sample_data: str) -> None: +def test_cypher_retriever_relationship_traversal( + driver: neo4j.Driver, sample_data: str +) -> None: """Test query with relationship traversal.""" retriever = CypherRetriever( driver=driver, @@ -170,23 +162,23 @@ def test_cypher_retriever_relationship_traversal(driver: neo4j.Driver, sample_da RETURN p.name as product, p.price as price, c.name as category """, parameters={ - "category_name": { - "type": "string", - "description": "Category name" - } - } + "category_name": {"type": "string", "description": "Category name"} + }, ) - + # Execute the query result = retriever.search(parameters={"category_name": sample_data}) - + # Verify the results assert len(result.items) == 3 assert all(sample_data in item.content for item in result.items) -def test_cypher_retriever_custom_formatter(driver: neo4j.Driver, sample_data: str) -> None: +def test_cypher_retriever_custom_formatter( + driver: neo4j.Driver, sample_data: str +) -> None: """Test query with custom result formatter.""" + # Custom formatter that extracts product info in a structured format def product_formatter(record): product = record["p"] @@ -195,26 +187,26 @@ def product_formatter(record): metadata={ "price": product["price"], "stock": product["stock"], - "featured": product["featured"] - } + "featured": product["featured"], + }, ) - + 
retriever = CypherRetriever( driver=driver, query="MATCH (p:Product) RETURN p", parameters={}, - result_formatter=product_formatter + result_formatter=product_formatter, ) - + # Execute the query result = retriever.search(parameters={}) - + # Verify the results assert len(result.items) == 3 - + # Check custom formatting for item in result.items: assert " - $" in item.content assert "price" in item.metadata assert "stock" in item.metadata - assert "featured" in item.metadata \ No newline at end of file + assert "featured" in item.metadata diff --git a/tests/unit/retrievers/test_cypher.py b/tests/unit/retrievers/test_cypher.py index 4dd4a2eb7..edb67fb02 100644 --- a/tests/unit/retrievers/test_cypher.py +++ b/tests/unit/retrievers/test_cypher.py @@ -20,7 +20,10 @@ import neo4j from neo4j import Record -from neo4j_graphrag.exceptions import RetrieverInitializationError, SearchValidationError +from neo4j_graphrag.exceptions import ( + RetrieverInitializationError, + SearchValidationError, +) from neo4j_graphrag.retrievers.cypher import CypherRetriever from neo4j_graphrag.types import RetrieverResultItem @@ -29,18 +32,21 @@ class TestCypherRetriever(unittest.TestCase): @classmethod def setUpClass(cls): # Patch the Neo4jDriverModel.check_driver method to pass validation with MagicMock - cls.patcher1 = patch('neo4j_graphrag.types.Neo4jDriverModel.check_driver') + cls.patcher1 = patch("neo4j_graphrag.types.Neo4jDriverModel.check_driver") cls.mock_check_driver = cls.patcher1.start() cls.mock_check_driver.side_effect = lambda v: v - + # Patch the version check in the Retriever base class to avoid Neo4j version validation - cls.patcher2 = patch('neo4j_graphrag.retrievers.base.Retriever.VERIFY_NEO4J_VERSION', False) + cls.patcher2 = patch( + "neo4j_graphrag.retrievers.base.Retriever.VERIFY_NEO4J_VERSION", False + ) cls.patcher2.start() - + @classmethod def tearDownClass(cls): cls.patcher1.stop() cls.patcher2.stop() + def setUp(self): # Create a mock driver self.driver = 
MagicMock(spec=neo4j.Driver) @@ -99,7 +105,12 @@ def test_init_invalid_parameter_type(self): CypherRetriever( driver=self.driver, query=self.valid_query, - parameters={"movie_title": {"type": "invalid_type", "description": "Title of a movie"}}, + parameters={ + "movie_title": { + "type": "invalid_type", + "description": "Title of a movie", + } + }, ) def test_search_success(self): @@ -110,7 +121,7 @@ def test_search_success(self): parameters=self.valid_parameters, ) result = retriever.search(parameters={"movie_title": "The Matrix"}) - + # Assert driver.execute_query was called with the right parameters self.driver.execute_query.assert_called_once() assert result.items @@ -135,7 +146,9 @@ def test_search_unexpected_parameter(self): parameters=self.valid_parameters, ) with pytest.raises(SearchValidationError): - retriever.search(parameters={"movie_title": "The Matrix", "year": 1999}) # 'year' not defined + retriever.search( + parameters={"movie_title": "The Matrix", "year": 1999} + ) # 'year' not defined def test_search_type_mismatch(self): # Test search with parameter type mismatch @@ -145,12 +158,16 @@ def test_search_type_mismatch(self): parameters=self.valid_parameters, ) with pytest.raises(SearchValidationError): - retriever.search(parameters={"movie_title": 123}) # Integer, expected string + retriever.search( + parameters={"movie_title": 123} + ) # Integer, expected string def test_different_parameter_types(self): # Test with different parameter types - query = "MATCH (m:Movie) WHERE m.title = $title AND m.year = $year AND m.rating > $rating " \ - "AND m.is_available = $available AND m.genres IN $genres RETURN m" + query = ( + "MATCH (m:Movie) WHERE m.title = $title AND m.year = $year AND m.rating > $rating " + "AND m.is_available = $available AND m.genres IN $genres RETURN m" + ) parameters = { "title": {"type": "string", "description": "Movie title"}, "year": {"type": "integer", "description": "Release year"}, @@ -158,13 +175,13 @@ def 
test_different_parameter_types(self): "available": {"type": "boolean", "description": "Is the movie available"}, "genres": {"type": "array", "description": "List of genres"}, } - + retriever = CypherRetriever( driver=self.driver, query=query, parameters=parameters, ) - + # Valid parameters of different types result = retriever.search( parameters={ @@ -172,12 +189,12 @@ def test_different_parameter_types(self): "year": 1999, "rating": 8.5, "available": True, - "genres": ["Action", "Sci-Fi"] + "genres": ["Action", "Sci-Fi"], } ) - + assert result.items - + # Test integer type validation with pytest.raises(SearchValidationError): retriever.search( @@ -186,10 +203,10 @@ def test_different_parameter_types(self): "year": "1999", # String, expected integer "rating": 8.5, "available": True, - "genres": ["Action", "Sci-Fi"] + "genres": ["Action", "Sci-Fi"], } ) - + # Test number type validation with pytest.raises(SearchValidationError): retriever.search( @@ -198,10 +215,10 @@ def test_different_parameter_types(self): "year": 1999, "rating": "8.5", # String, expected number "available": True, - "genres": ["Action", "Sci-Fi"] + "genres": ["Action", "Sci-Fi"], } ) - + # Test boolean type validation with pytest.raises(SearchValidationError): retriever.search( @@ -210,10 +227,10 @@ def test_different_parameter_types(self): "year": 1999, "rating": 8.5, "available": "yes", # String, expected boolean - "genres": ["Action", "Sci-Fi"] + "genres": ["Action", "Sci-Fi"], } ) - + # Test array type validation with pytest.raises(SearchValidationError): retriever.search( @@ -222,7 +239,7 @@ def test_different_parameter_types(self): "year": 1999, "rating": 8.5, "available": True, - "genres": "Action, Sci-Fi" # String, expected array + "genres": "Action, Sci-Fi", # String, expected array } ) @@ -231,16 +248,16 @@ def test_custom_result_formatter(self): def custom_formatter(record): return RetrieverResultItem( content=f"Movie: {record['m']['title']}", - metadata={"score": record["score"]} + 
metadata={"score": record["score"]}, ) - + retriever = CypherRetriever( driver=self.driver, query=self.valid_query, parameters=self.valid_parameters, result_formatter=custom_formatter, ) - + result = retriever.search(parameters={"movie_title": "The Matrix"}) assert result.items[0].content == "Movie: Test Movie" assert result.items[0].metadata["score"] == 0.9 @@ -250,23 +267,27 @@ def test_optional_parameters(self): query = "MATCH (m:Movie {title: $title}) WHERE m.year = $year RETURN m" parameters = { "title": {"type": "string", "description": "Movie title", "required": True}, - "year": {"type": "integer", "description": "Release year", "required": False}, + "year": { + "type": "integer", + "description": "Release year", + "required": False, + }, } - + retriever = CypherRetriever( driver=self.driver, query=query, parameters=parameters, ) - + # Should succeed with only required parameters result = retriever.search(parameters={"title": "The Matrix"}) assert result.items - + # Should also succeed with optional parameters result = retriever.search(parameters={"title": "The Matrix", "year": 1999}) assert result.items if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() From 6fe041efbf22bc7b5cd8d490b42b76e1b521c123 Mon Sep 17 00:00:00 2001 From: Oskar Hane Date: Fri, 14 Mar 2025 16:44:04 +0100 Subject: [PATCH 3/4] Fix type errors and improve type annotations in CypherRetriever and tests --- CHANGELOG.md | 4 + examples/retrieve/cypher_retriever.py | 12 ++- src/neo4j_graphrag/retrievers/cypher.py | 137 ++++++++++++++++++------ tests/e2e/retrievers/test_cypher_e2e.py | 11 +- tests/unit/retrievers/test_cypher.py | 40 ++++--- 5 files changed, 145 insertions(+), 59 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 390b32e26..20933baf7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,10 @@ ## Next +### Added + +- Added CypherRetriever for executing parameterized Cypher queries with strong type validation. 
+ ## 1.6.0 ### Added diff --git a/examples/retrieve/cypher_retriever.py b/examples/retrieve/cypher_retriever.py index 050c2ef44..3b45d05b7 100644 --- a/examples/retrieve/cypher_retriever.py +++ b/examples/retrieve/cypher_retriever.py @@ -20,6 +20,7 @@ """ import neo4j +from neo4j import Record from neo4j_graphrag.retrievers import CypherRetriever from neo4j_graphrag.types import RetrieverResultItem @@ -33,7 +34,7 @@ # Simple example: Find a movie by title -def find_movie_by_title(): +def find_movie_by_title() -> None: retriever = CypherRetriever( driver=driver, query="MATCH (m:Movie {title: $movie_title}) RETURN m", @@ -52,9 +53,9 @@ def find_movie_by_title(): # Advanced example: Find movies with multiple criteria -def find_movies_by_criteria(): +def find_movies_by_criteria() -> None: # Custom formatter to extract specific information - def movie_formatter(record): + def movie_formatter(record: Record) -> RetrieverResultItem: movie = record["m"] return RetrieverResultItem( content=f"{movie['title']} ({movie['released']})", @@ -124,7 +125,7 @@ def movie_formatter(record): # Example with relationship traversal -def find_actors_in_movie(): +def find_actors_in_movie() -> None: retriever = CypherRetriever( driver=driver, query=""" @@ -155,7 +156,8 @@ def find_actors_in_movie(): with driver.session() as session: # Check if data exists result = session.run("MATCH (m:Movie) RETURN count(m) as count") - count = result.single()["count"] + record = result.single() + count = record["count"] if record else 0 if count == 0: print("No movie data found. 
Creating sample data...") diff --git a/src/neo4j_graphrag/retrievers/cypher.py b/src/neo4j_graphrag/retrievers/cypher.py index 9b2fff005..369cf4766 100644 --- a/src/neo4j_graphrag/retrievers/cypher.py +++ b/src/neo4j_graphrag/retrievers/cypher.py @@ -16,9 +16,10 @@ import logging import re -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, List, Optional import neo4j +from pydantic_core import ErrorDetails from neo4j.exceptions import CypherSyntaxError from pydantic import ValidationError @@ -89,7 +90,7 @@ def __init__( self, driver: neo4j.Driver, query: str, - parameters: Dict[str, Dict], + parameters: Dict[str, Dict[str, Any]], result_formatter: Optional[ Callable[[neo4j.Record], RetrieverResultItem] ] = None, @@ -108,7 +109,12 @@ def __init__( ) except ValidationError as e: raise RetrieverInitializationError( - f"Invalid parameter definition for {param_name}: {e.errors()}" + [ErrorDetails( + loc=("parameters", param_name), + msg=f"Invalid parameter definition: {e.errors()}", + type="validation_error", + input=param_def + )] ) from e try: @@ -148,17 +154,28 @@ def _validate_cypher_query(self, query: str) -> None: """ # We can't fully validate the query without executing it, but we can check for basic syntax if not query.strip(): - raise RetrieverInitializationError("Query cannot be empty") + raise RetrieverInitializationError([ + ErrorDetails( + loc=("query",), + msg="Query cannot be empty", + type="value_error.empty", + input="" + ) + ]) # Check for presence of common Cypher keywords if not any( keyword in query.upper() for keyword in ["MATCH", "RETURN", "CREATE", "MERGE", "WITH"] ): - raise RetrieverInitializationError( - "Query does not appear to be valid Cypher. " - "It should contain at least one of: MATCH, RETURN, CREATE, MERGE, WITH" - ) + raise RetrieverInitializationError([ + ErrorDetails( + loc=("query",), + msg="Query does not appear to be valid Cypher. 
It should contain at least one of: MATCH, RETURN, CREATE, MERGE, WITH", + type="value_error.invalid_cypher", + input="" + ) + ]) def _validate_query_parameters( self, query: str, parameters: Dict[str, CypherParameterDefinition] @@ -180,9 +197,14 @@ def _validate_query_parameters( # Check that all parameters in the query are defined undefined_params = query_params - set(parameters.keys()) if undefined_params: - raise RetrieverInitializationError( - f"The following parameters are used in the query but not defined: {', '.join(undefined_params)}" - ) + raise RetrieverInitializationError([ + ErrorDetails( + loc=("parameters",), + msg=f"The following parameters are used in the query but not defined: {', '.join(undefined_params)}", + type="value_error.undefined_parameters", + input=undefined_params + ) + ]) def _validate_parameter_values(self, parameters: Dict[str, Any]) -> None: """ @@ -197,43 +219,80 @@ def _validate_parameter_values(self, parameters: Dict[str, Any]) -> None: # Check that all required parameters are provided for param_name, param_def in self.parameters.items(): if param_def.required and param_name not in parameters: - raise SearchValidationError( - f"Required parameter '{param_name}' is missing" - ) + raise SearchValidationError([ + ErrorDetails( + loc=("parameters", param_name), + msg=f"Required parameter '{param_name}' is missing", + type="value_error.missing", + input=None + ) + ]) # Validate the type of each parameter for param_name, param_value in parameters.items(): if param_name not in self.parameters: - raise SearchValidationError(f"Unexpected parameter: {param_name}") + raise SearchValidationError([ + ErrorDetails( + loc=("parameters", param_name), + msg=f"Unexpected parameter: {param_name}", + type="value_error.unexpected", + input=param_name + ) + ]) param_def = self.parameters[param_name] # Type validation if param_def.type == CypherParameterType.STRING: if not isinstance(param_value, str): - raise SearchValidationError( - f"Parameter 
'{param_name}' should be of type string, got {type(param_value).__name__}" - ) + raise SearchValidationError([ + ErrorDetails( + loc=("parameters", param_name), + msg=f"Parameter '{param_name}' should be of type string, got {type(param_value).__name__}", + type="type_error.string", + input=param_value + ) + ]) elif param_def.type == CypherParameterType.NUMBER: if not isinstance(param_value, (int, float)): - raise SearchValidationError( - f"Parameter '{param_name}' should be of type number, got {type(param_value).__name__}" - ) + raise SearchValidationError([ + ErrorDetails( + loc=("parameters", param_name), + msg=f"Parameter '{param_name}' should be of type number, got {type(param_value).__name__}", + type="type_error.number", + input=param_value + ) + ]) elif param_def.type == CypherParameterType.INTEGER: if not isinstance(param_value, int) or isinstance(param_value, bool): - raise SearchValidationError( - f"Parameter '{param_name}' should be of type integer, got {type(param_value).__name__}" - ) + raise SearchValidationError([ + ErrorDetails( + loc=("parameters", param_name), + msg=f"Parameter '{param_name}' should be of type integer, got {type(param_value).__name__}", + type="type_error.integer", + input=param_value + ) + ]) elif param_def.type == CypherParameterType.BOOLEAN: if not isinstance(param_value, bool): - raise SearchValidationError( - f"Parameter '{param_name}' should be of type boolean, got {type(param_value).__name__}" - ) + raise SearchValidationError([ + ErrorDetails( + loc=("parameters", param_name), + msg=f"Parameter '{param_name}' should be of type boolean, got {type(param_value).__name__}", + type="type_error.boolean", + input=param_value + ) + ]) elif param_def.type == CypherParameterType.ARRAY: if not isinstance(param_value, (list, tuple)): - raise SearchValidationError( - f"Parameter '{param_name}' should be of type array, got {type(param_value).__name__}" - ) + raise SearchValidationError([ + ErrorDetails( + loc=("parameters", param_name), 
+ msg=f"Parameter '{param_name}' should be of type array, got {type(param_value).__name__}", + type="type_error.array", + input=param_value + ) + ]) def get_search_results(self, parameters: Dict[str, Any]) -> RawSearchResult: """ @@ -268,9 +327,23 @@ def get_search_results(self, parameters: Dict[str, Any]) -> RawSearchResult: routing_=neo4j.RoutingControl.READ, ) except CypherSyntaxError as e: - raise SearchValidationError(f"Cypher syntax error: {e.message}") from e + raise SearchValidationError([ + ErrorDetails( + loc=("query",), + msg=f"Cypher syntax error: {e.message}", + type="value_error.cypher_syntax", + input=self.query + ) + ]) from e except Exception as e: - raise SearchValidationError(f"Failed to execute query: {str(e)}") from e + raise SearchValidationError([ + ErrorDetails( + loc=("query",), + msg=f"Failed to execute query: {str(e)}", + type="execution_error", + input=self.query + ) + ]) from e return RawSearchResult( records=records, diff --git a/tests/e2e/retrievers/test_cypher_e2e.py b/tests/e2e/retrievers/test_cypher_e2e.py index 51eed554d..75282b633 100644 --- a/tests/e2e/retrievers/test_cypher_e2e.py +++ b/tests/e2e/retrievers/test_cypher_e2e.py @@ -18,6 +18,7 @@ from typing import Generator import neo4j +from neo4j import Record import pytest from neo4j.exceptions import Neo4jError @@ -84,7 +85,7 @@ def test_cypher_retriever_basic_query(driver: neo4j.Driver, sample_data: str) -> assert ( "Product1" in result.items[0].content or "Product2" in result.items[0].content ) - assert "cypher" in result.metadata + assert result.metadata is not None and "cypher" in result.metadata def test_cypher_retriever_multiple_parameters( @@ -180,7 +181,7 @@ def test_cypher_retriever_custom_formatter( """Test query with custom result formatter.""" # Custom formatter that extracts product info in a structured format - def product_formatter(record): + def product_formatter(record: Record) -> RetrieverResultItem: product = record["p"] return RetrieverResultItem( 
content=f"{product['name']} - ${product['price']}", @@ -207,6 +208,6 @@ def product_formatter(record): # Check custom formatting for item in result.items: assert " - $" in item.content - assert "price" in item.metadata - assert "stock" in item.metadata - assert "featured" in item.metadata + assert item.metadata is not None and "price" in item.metadata + assert item.metadata is not None and "stock" in item.metadata + assert item.metadata is not None and "featured" in item.metadata diff --git a/tests/unit/retrievers/test_cypher.py b/tests/unit/retrievers/test_cypher.py index edb67fb02..d1e1a155b 100644 --- a/tests/unit/retrievers/test_cypher.py +++ b/tests/unit/retrievers/test_cypher.py @@ -29,8 +29,13 @@ class TestCypherRetriever(unittest.TestCase): + # Define class attributes for mypy + patcher1: unittest.mock._patch[MagicMock] + patcher2: unittest.mock._patch[bool] + mock_check_driver: MagicMock + @classmethod - def setUpClass(cls): + def setUpClass(cls) -> None: # Patch the Neo4jDriverModel.check_driver method to pass validation with MagicMock cls.patcher1 = patch("neo4j_graphrag.types.Neo4jDriverModel.check_driver") cls.mock_check_driver = cls.patcher1.start() @@ -43,11 +48,11 @@ def setUpClass(cls): cls.patcher2.start() @classmethod - def tearDownClass(cls): + def tearDownClass(cls) -> None: cls.patcher1.stop() cls.patcher2.stop() - def setUp(self): + def setUp(self) -> None: # Create a mock driver self.driver = MagicMock(spec=neo4j.Driver) self.driver.execute_query.return_value = ( @@ -62,7 +67,7 @@ def setUp(self): "movie_title": {"type": "string", "description": "Title of a movie"} } - def test_init_success(self): + def test_init_success(self) -> None: # Test successful initialization retriever = CypherRetriever( driver=self.driver, @@ -72,7 +77,7 @@ def test_init_success(self): assert retriever.query == self.valid_query assert "movie_title" in retriever.parameters - def test_init_empty_query(self): + def test_init_empty_query(self) -> None: # Test 
initialization with empty query with pytest.raises(RetrieverInitializationError): CypherRetriever( @@ -81,7 +86,7 @@ def test_init_empty_query(self): parameters=self.valid_parameters, ) - def test_init_invalid_query(self): + def test_init_invalid_query(self) -> None: # Test initialization with invalid query with pytest.raises(RetrieverInitializationError): CypherRetriever( @@ -90,7 +95,7 @@ def test_init_invalid_query(self): parameters=self.valid_parameters, ) - def test_init_undefined_parameters(self): + def test_init_undefined_parameters(self) -> None: # Test initialization with undefined parameters in query with pytest.raises(RetrieverInitializationError): CypherRetriever( @@ -99,7 +104,7 @@ def test_init_undefined_parameters(self): parameters=self.valid_parameters, # Missing 'year' parameter ) - def test_init_invalid_parameter_type(self): + def test_init_invalid_parameter_type(self) -> None: # Test initialization with invalid parameter type with pytest.raises(RetrieverInitializationError): CypherRetriever( @@ -113,7 +118,7 @@ def test_init_invalid_parameter_type(self): }, ) - def test_search_success(self): + def test_search_success(self) -> None: # Test successful search retriever = CypherRetriever( driver=self.driver, @@ -128,7 +133,7 @@ def test_search_success(self): assert result.metadata and "cypher" in result.metadata assert result.metadata["cypher"] == self.valid_query - def test_search_missing_required_parameter(self): + def test_search_missing_required_parameter(self) -> None: # Test search with missing required parameter retriever = CypherRetriever( driver=self.driver, @@ -138,7 +143,7 @@ def test_search_missing_required_parameter(self): with pytest.raises(SearchValidationError): retriever.search(parameters={}) # Missing 'movie_title' - def test_search_unexpected_parameter(self): + def test_search_unexpected_parameter(self) -> None: # Test search with unexpected parameter retriever = CypherRetriever( driver=self.driver, @@ -150,7 +155,7 @@ def 
test_search_unexpected_parameter(self): parameters={"movie_title": "The Matrix", "year": 1999} ) # 'year' not defined - def test_search_type_mismatch(self): + def test_search_type_mismatch(self) -> None: # Test search with parameter type mismatch retriever = CypherRetriever( driver=self.driver, @@ -162,7 +167,7 @@ def test_search_type_mismatch(self): parameters={"movie_title": 123} ) # Integer, expected string - def test_different_parameter_types(self): + def test_different_parameter_types(self) -> None: # Test with different parameter types query = ( "MATCH (m:Movie) WHERE m.title = $title AND m.year = $year AND m.rating > $rating " @@ -243,9 +248,9 @@ def test_different_parameter_types(self): } ) - def test_custom_result_formatter(self): + def test_custom_result_formatter(self) -> None: # Test with custom result formatter - def custom_formatter(record): + def custom_formatter(record: Record) -> RetrieverResultItem: return RetrieverResultItem( content=f"Movie: {record['m']['title']}", metadata={"score": record["score"]}, @@ -260,9 +265,10 @@ def custom_formatter(record): result = retriever.search(parameters={"movie_title": "The Matrix"}) assert result.items[0].content == "Movie: Test Movie" - assert result.items[0].metadata["score"] == 0.9 + if result.items[0].metadata: + assert result.items[0].metadata.get("score") == 0.9 - def test_optional_parameters(self): + def test_optional_parameters(self) -> None: # Test with optional parameters query = "MATCH (m:Movie {title: $title}) WHERE m.year = $year RETURN m" parameters = { From a53db03926e228730feeb16ed441e1e93bacfb72 Mon Sep 17 00:00:00 2001 From: Oskar Hane Date: Fri, 14 Mar 2025 16:50:59 +0100 Subject: [PATCH 4/4] Fix ruff issues: remove unused import and apply formatting --- src/neo4j_graphrag/retrievers/cypher.py | 232 +++++++++++++----------- tests/unit/retrievers/test_cypher.py | 2 +- 2 files changed, 130 insertions(+), 104 deletions(-) diff --git a/src/neo4j_graphrag/retrievers/cypher.py 
b/src/neo4j_graphrag/retrievers/cypher.py index 369cf4766..0690c0e00 100644 --- a/src/neo4j_graphrag/retrievers/cypher.py +++ b/src/neo4j_graphrag/retrievers/cypher.py @@ -16,7 +16,7 @@ import logging import re -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, Optional import neo4j from pydantic_core import ErrorDetails @@ -109,12 +109,14 @@ def __init__( ) except ValidationError as e: raise RetrieverInitializationError( - [ErrorDetails( - loc=("parameters", param_name), - msg=f"Invalid parameter definition: {e.errors()}", - type="validation_error", - input=param_def - )] + [ + ErrorDetails( + loc=("parameters", param_name), + msg=f"Invalid parameter definition: {e.errors()}", + type="validation_error", + input=param_def, + ) + ] ) from e try: @@ -154,28 +156,32 @@ def _validate_cypher_query(self, query: str) -> None: """ # We can't fully validate the query without executing it, but we can check for basic syntax if not query.strip(): - raise RetrieverInitializationError([ - ErrorDetails( - loc=("query",), - msg="Query cannot be empty", - type="value_error.empty", - input="" - ) - ]) + raise RetrieverInitializationError( + [ + ErrorDetails( + loc=("query",), + msg="Query cannot be empty", + type="value_error.empty", + input="", + ) + ] + ) # Check for presence of common Cypher keywords if not any( keyword in query.upper() for keyword in ["MATCH", "RETURN", "CREATE", "MERGE", "WITH"] ): - raise RetrieverInitializationError([ - ErrorDetails( - loc=("query",), - msg="Query does not appear to be valid Cypher. It should contain at least one of: MATCH, RETURN, CREATE, MERGE, WITH", - type="value_error.invalid_cypher", - input="" - ) - ]) + raise RetrieverInitializationError( + [ + ErrorDetails( + loc=("query",), + msg="Query does not appear to be valid Cypher. 
It should contain at least one of: MATCH, RETURN, CREATE, MERGE, WITH", + type="value_error.invalid_cypher", + input="", + ) + ] + ) def _validate_query_parameters( self, query: str, parameters: Dict[str, CypherParameterDefinition] @@ -197,14 +203,16 @@ def _validate_query_parameters( # Check that all parameters in the query are defined undefined_params = query_params - set(parameters.keys()) if undefined_params: - raise RetrieverInitializationError([ - ErrorDetails( - loc=("parameters",), - msg=f"The following parameters are used in the query but not defined: {', '.join(undefined_params)}", - type="value_error.undefined_parameters", - input=undefined_params - ) - ]) + raise RetrieverInitializationError( + [ + ErrorDetails( + loc=("parameters",), + msg=f"The following parameters are used in the query but not defined: {', '.join(undefined_params)}", + type="value_error.undefined_parameters", + input=undefined_params, + ) + ] + ) def _validate_parameter_values(self, parameters: Dict[str, Any]) -> None: """ @@ -219,80 +227,94 @@ def _validate_parameter_values(self, parameters: Dict[str, Any]) -> None: # Check that all required parameters are provided for param_name, param_def in self.parameters.items(): if param_def.required and param_name not in parameters: - raise SearchValidationError([ - ErrorDetails( - loc=("parameters", param_name), - msg=f"Required parameter '{param_name}' is missing", - type="value_error.missing", - input=None - ) - ]) + raise SearchValidationError( + [ + ErrorDetails( + loc=("parameters", param_name), + msg=f"Required parameter '{param_name}' is missing", + type="value_error.missing", + input=None, + ) + ] + ) # Validate the type of each parameter for param_name, param_value in parameters.items(): if param_name not in self.parameters: - raise SearchValidationError([ - ErrorDetails( - loc=("parameters", param_name), - msg=f"Unexpected parameter: {param_name}", - type="value_error.unexpected", - input=param_name - ) - ]) + raise 
SearchValidationError( + [ + ErrorDetails( + loc=("parameters", param_name), + msg=f"Unexpected parameter: {param_name}", + type="value_error.unexpected", + input=param_name, + ) + ] + ) param_def = self.parameters[param_name] # Type validation if param_def.type == CypherParameterType.STRING: if not isinstance(param_value, str): - raise SearchValidationError([ - ErrorDetails( - loc=("parameters", param_name), - msg=f"Parameter '{param_name}' should be of type string, got {type(param_value).__name__}", - type="type_error.string", - input=param_value - ) - ]) + raise SearchValidationError( + [ + ErrorDetails( + loc=("parameters", param_name), + msg=f"Parameter '{param_name}' should be of type string, got {type(param_value).__name__}", + type="type_error.string", + input=param_value, + ) + ] + ) elif param_def.type == CypherParameterType.NUMBER: if not isinstance(param_value, (int, float)): - raise SearchValidationError([ - ErrorDetails( - loc=("parameters", param_name), - msg=f"Parameter '{param_name}' should be of type number, got {type(param_value).__name__}", - type="type_error.number", - input=param_value - ) - ]) + raise SearchValidationError( + [ + ErrorDetails( + loc=("parameters", param_name), + msg=f"Parameter '{param_name}' should be of type number, got {type(param_value).__name__}", + type="type_error.number", + input=param_value, + ) + ] + ) elif param_def.type == CypherParameterType.INTEGER: if not isinstance(param_value, int) or isinstance(param_value, bool): - raise SearchValidationError([ - ErrorDetails( - loc=("parameters", param_name), - msg=f"Parameter '{param_name}' should be of type integer, got {type(param_value).__name__}", - type="type_error.integer", - input=param_value - ) - ]) + raise SearchValidationError( + [ + ErrorDetails( + loc=("parameters", param_name), + msg=f"Parameter '{param_name}' should be of type integer, got {type(param_value).__name__}", + type="type_error.integer", + input=param_value, + ) + ] + ) elif param_def.type == 
CypherParameterType.BOOLEAN: if not isinstance(param_value, bool): - raise SearchValidationError([ - ErrorDetails( - loc=("parameters", param_name), - msg=f"Parameter '{param_name}' should be of type boolean, got {type(param_value).__name__}", - type="type_error.boolean", - input=param_value - ) - ]) + raise SearchValidationError( + [ + ErrorDetails( + loc=("parameters", param_name), + msg=f"Parameter '{param_name}' should be of type boolean, got {type(param_value).__name__}", + type="type_error.boolean", + input=param_value, + ) + ] + ) elif param_def.type == CypherParameterType.ARRAY: if not isinstance(param_value, (list, tuple)): - raise SearchValidationError([ - ErrorDetails( - loc=("parameters", param_name), - msg=f"Parameter '{param_name}' should be of type array, got {type(param_value).__name__}", - type="type_error.array", - input=param_value - ) - ]) + raise SearchValidationError( + [ + ErrorDetails( + loc=("parameters", param_name), + msg=f"Parameter '{param_name}' should be of type array, got {type(param_value).__name__}", + type="type_error.array", + input=param_value, + ) + ] + ) def get_search_results(self, parameters: Dict[str, Any]) -> RawSearchResult: """ @@ -327,23 +349,27 @@ def get_search_results(self, parameters: Dict[str, Any]) -> RawSearchResult: routing_=neo4j.RoutingControl.READ, ) except CypherSyntaxError as e: - raise SearchValidationError([ - ErrorDetails( - loc=("query",), - msg=f"Cypher syntax error: {e.message}", - type="value_error.cypher_syntax", - input=self.query - ) - ]) from e + raise SearchValidationError( + [ + ErrorDetails( + loc=("query",), + msg=f"Cypher syntax error: {e.message}", + type="value_error.cypher_syntax", + input=self.query, + ) + ] + ) from e except Exception as e: - raise SearchValidationError([ - ErrorDetails( - loc=("query",), - msg=f"Failed to execute query: {str(e)}", - type="execution_error", - input=self.query - ) - ]) from e + raise SearchValidationError( + [ + ErrorDetails( + loc=("query",), + 
msg=f"Failed to execute query: {str(e)}", + type="execution_error", + input=self.query, + ) + ] + ) from e return RawSearchResult( records=records, diff --git a/tests/unit/retrievers/test_cypher.py b/tests/unit/retrievers/test_cypher.py index d1e1a155b..cc1a1c86d 100644 --- a/tests/unit/retrievers/test_cypher.py +++ b/tests/unit/retrievers/test_cypher.py @@ -33,7 +33,7 @@ class TestCypherRetriever(unittest.TestCase): patcher1: unittest.mock._patch[MagicMock] patcher2: unittest.mock._patch[bool] mock_check_driver: MagicMock - + @classmethod def setUpClass(cls) -> None: # Patch the Neo4jDriverModel.check_driver method to pass validation with MagicMock