diff --git a/CHANGELOG.md b/CHANGELOG.md index cbd83dedd..c5eb72654 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,12 @@ - The `SchemaProperty` model has been renamed `PropertyType`. - `SchemaConfig` has been removed in favor of `GraphSchema` (used in the `SchemaBuilder` and `EntityRelationExtractor` classes). `entities`, `relations` and `potential_schema` fields have also been renamed `node_types`, `relationship_types` and `patterns` respectively. +#### Other + +- The reserved `id` property on `__KGBuilder__` nodes is removed. +- The `chunk_index` property on `__Entity__` nodes is removed. Use the `FROM_CHUNK` relationship instead. +- The `__entity__id` index is no longer used and can be dropped from the database (it has been replaced by `__entity__tmp_internal_id`). + ## 1.7.0 diff --git a/docs/source/user_guide_kg_builder.rst b/docs/source/user_guide_kg_builder.rst index f810d9ed2..366f285c9 100644 --- a/docs/source/user_guide_kg_builder.rst +++ b/docs/source/user_guide_kg_builder.rst @@ -1048,8 +1048,10 @@ _________________ In addition to the user-defined configuration options described above, the `GraphPruning` component performs the following cleanup operations: +- Nodes with an empty label or ID are pruned. - Nodes with missing required properties are pruned. - Nodes with no remaining properties are pruned. +- Relationships with an empty type are pruned. - Relationships with invalid source or target nodes (i.e., nodes no longer present in the graph) are pruned. - Relationships with incorrect direction have their direction corrected. @@ -1078,6 +1080,11 @@ to a Neo4j database: Adjust the batch_size parameter of `Neo4jWriter` to optimize insert performance. This parameter controls the number of nodes or relationships inserted per batch, with a default value of 1000. +.. note:: Index + + To improve ingestion performance, an index called `__entity__tmp_internal_id` is automatically added to the database. + + See :ref:`neo4jgraph`. 
diff --git a/examples/customize/build_graph/components/pruners/graph_pruner.py b/examples/customize/build_graph/components/pruners/graph_pruner.py index 7273188ad..c40924f41 100644 --- a/examples/customize/build_graph/components/pruners/graph_pruner.py +++ b/examples/customize/build_graph/components/pruners/graph_pruner.py @@ -1,4 +1,4 @@ -"""This example demonstrates how to use the GraphPruner component.""" +"""This example demonstrates how to use the GraphPruning component.""" import asyncio diff --git a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py index b1ed9bbb1..9ee61fed8 100644 --- a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py +++ b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py @@ -146,12 +146,11 @@ def update_ids( """Make node IDs unique across chunks, document and pipeline runs by prefixing them with a unique prefix. """ - prefix = f"{chunk.chunk_id}" + prefix = chunk.chunk_id for node in graph.nodes: node.id = f"{prefix}:{node.id}" if node.properties is None: node.properties = {} - node.properties.update({"chunk_index": chunk.index}) for rel in graph.relationships: rel.start_node_id = f"{prefix}:{rel.start_node_id}" rel.end_node_id = f"{prefix}:{rel.end_node_id}" diff --git a/src/neo4j_graphrag/experimental/components/graph_pruning.py b/src/neo4j_graphrag/experimental/components/graph_pruning.py index d49f48ce0..c8bf647f3 100644 --- a/src/neo4j_graphrag/experimental/components/graph_pruning.py +++ b/src/neo4j_graphrag/experimental/components/graph_pruning.py @@ -41,6 +41,7 @@ class PruningReason(str, enum.Enum): NO_PROPERTY_LEFT = "NO_PROPERTY_LEFT" INVALID_START_OR_END_NODE = "INVALID_START_OR_END_NODE" INVALID_PATTERN = "INVALID_PATTERN" + MISSING_LABEL = "MISSING_LABEL" ItemType = TypeVar("ItemType") @@ -198,6 +199,17 @@ def _validate_node( schema_entity: Optional[NodeType], additional_node_types: 
bool, ) -> Optional[Neo4jNode]: + if not node.label: + pruning_stats.add_pruned_node(node, reason=PruningReason.MISSING_LABEL) + return None + if not node.id: + pruning_stats.add_pruned_node( + node, + reason=PruningReason.MISSING_REQUIRED_PROPERTY, + missing_required_properties=["id"], + details="The node was extracted without a valid ID.", + ) + return None if not schema_entity: # node type not declared in the schema if additional_node_types: @@ -262,6 +274,11 @@ def _validate_relationship( patterns: tuple[tuple[str, str, str], ...], additional_patterns: bool, ) -> Optional[Neo4jRelationship]: + if not rel.type: + pruning_stats.add_pruned_relationship( + rel, reason=PruningReason.MISSING_LABEL + ) + return None # validate start/end node IDs are valid nodes if rel.start_node_id not in valid_nodes or rel.end_node_id not in valid_nodes: logger.debug( diff --git a/src/neo4j_graphrag/experimental/components/kg_writer.py b/src/neo4j_graphrag/experimental/components/kg_writer.py index 4a859b36e..1f46e5ab3 100644 --- a/src/neo4j_graphrag/experimental/components/kg_writer.py +++ b/src/neo4j_graphrag/experimental/components/kg_writer.py @@ -29,10 +29,9 @@ ) from neo4j_graphrag.experimental.pipeline.component import Component, DataModel from neo4j_graphrag.neo4j_queries import ( - UPSERT_NODE_QUERY, - UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE, - UPSERT_RELATIONSHIP_QUERY, - UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE, + upsert_node_query, + upsert_relationship_query, + db_cleaning_query, ) from neo4j_graphrag.utils.version_utils import ( get_version, @@ -117,20 +116,19 @@ def __init__( driver: neo4j.Driver, neo4j_database: Optional[str] = None, batch_size: int = 1000, + clean_db: bool = True, ): self.driver = driver_config.override_user_agent(driver) self.neo4j_database = neo4j_database self.batch_size = batch_size + self._clean_db = clean_db version_tuple, _, _ = get_version(self.driver, self.neo4j_database) self.is_version_5_23_or_above = 
is_version_5_23_or_above(version_tuple) def _db_setup(self) -> None: - # create index on __KGBuilder__.id - # used when creating the relationships - self.driver.execute_query( - "CREATE INDEX __entity__id IF NOT EXISTS FOR (n:__KGBuilder__) ON (n.id)", - database_=self.neo4j_database, - ) + self.driver.execute_query(""" + CREATE INDEX __entity__tmp_internal_id IF NOT EXISTS FOR (n:__KGBuilder__) ON (n.__tmp_internal_id) + """) @staticmethod def _nodes_to_rows( @@ -149,44 +147,51 @@ def _nodes_to_rows( def _upsert_nodes( self, nodes: list[Neo4jNode], lexical_graph_config: LexicalGraphConfig ) -> None: - """Upserts a single node into the Neo4j database." + """Upserts a batch of nodes into the Neo4j database. Args: nodes (list[Neo4jNode]): The nodes batch to upsert into the database. """ parameters = {"rows": self._nodes_to_rows(nodes, lexical_graph_config)} - if self.is_version_5_23_or_above: - self.driver.execute_query( - UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE, - parameters_=parameters, - database_=self.neo4j_database, - ) - else: - self.driver.execute_query( - UPSERT_NODE_QUERY, - parameters_=parameters, - database_=self.neo4j_database, - ) + query = upsert_node_query( + support_variable_scope_clause=self.is_version_5_23_or_above + ) + self.driver.execute_query( + query, + parameters_=parameters, + database_=self.neo4j_database, + ) + return None + + @staticmethod + def _relationships_to_rows( + relationships: list[Neo4jRelationship], + ) -> list[dict[str, Any]]: + return [relationship.model_dump() for relationship in relationships] def _upsert_relationships(self, rels: list[Neo4jRelationship]) -> None: - """Upserts a single relationship into the Neo4j database. + """Upserts a batch of relationships into the Neo4j database. Args: rels (list[Neo4jRelationship]): The relationships batch to upsert into the database. 
""" - parameters = {"rows": [rel.model_dump() for rel in rels]} - if self.is_version_5_23_or_above: - self.driver.execute_query( - UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE, - parameters_=parameters, - database_=self.neo4j_database, - ) - else: - self.driver.execute_query( - UPSERT_RELATIONSHIP_QUERY, - parameters_=parameters, - database_=self.neo4j_database, - ) + parameters = {"rows": self._relationships_to_rows(rels)} + query = upsert_relationship_query( + support_variable_scope_clause=self.is_version_5_23_or_above + ) + self.driver.execute_query( + query, + parameters_=parameters, + database_=self.neo4j_database, + ) + + def _db_cleaning(self) -> None: + query = db_cleaning_query( + support_variable_scope_clause=self.is_version_5_23_or_above, + batch_size=self.batch_size, + ) + with self.driver.session() as session: + session.run(query) @validate_call async def run( @@ -209,6 +214,9 @@ async def run( for batch in batched(graph.relationships, self.batch_size): self._upsert_relationships(batch) + if self._clean_db: + self._db_cleaning() + return KGWriterModel( status="SUCCESS", metadata={ diff --git a/src/neo4j_graphrag/experimental/components/types.py b/src/neo4j_graphrag/experimental/components/types.py index 3c07d401c..363767ef3 100644 --- a/src/neo4j_graphrag/experimental/components/types.py +++ b/src/neo4j_graphrag/experimental/components/types.py @@ -14,14 +14,18 @@ # limitations under the License. from __future__ import annotations +import logging import uuid from typing import Any, Dict, Optional -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field from neo4j_graphrag.experimental.pipeline.component import DataModel +logger = logging.getLogger(__name__) + + class DocumentInfo(DataModel): """A document loaded by a DataLoader. @@ -79,7 +83,7 @@ class Neo4jNode(BaseModel): """Represents a Neo4j node. Attributes: - id (str): The element ID of the node. + id (str): The ID of the node. 
This ID is used to refer to the node for relationship creation. label (str): The label of the node. properties (dict[str, Any]): A dictionary of properties attached to the node. embedding_properties (Optional[dict[str, list[float]]]): A list of embedding properties attached to the node. @@ -90,15 +94,6 @@ class Neo4jNode(BaseModel): properties: dict[str, Any] = {} embedding_properties: Optional[dict[str, list[float]]] = None - @field_validator("properties", "embedding_properties") - @classmethod - def check_for_id_properties( - cls, v: Optional[dict[str, Any]] - ) -> Optional[dict[str, Any]]: - if v and "id" in v.keys(): - raise TypeError("'id' as a property name is not allowed") - return v - @property def token(self) -> str: return self.label diff --git a/src/neo4j_graphrag/neo4j_queries.py b/src/neo4j_graphrag/neo4j_queries.py index 4e6e4ef31..a0e9b419e 100644 --- a/src/neo4j_graphrag/neo4j_queries.py +++ b/src/neo4j_graphrag/neo4j_queries.py @@ -52,61 +52,90 @@ "YIELD node, score" ) -UPSERT_NODE_QUERY = ( - "UNWIND $rows AS row " - "CREATE (n:__KGBuilder__ {id: row.id}) " - "SET n += row.properties " - "WITH n, row CALL apoc.create.addLabels(n, row.labels) YIELD node " - "WITH node as n, row CALL { " - "WITH n, row WITH n, row WHERE row.embedding_properties IS NOT NULL " - "UNWIND keys(row.embedding_properties) as emb " - "CALL db.create.setNodeVectorProperty(n, emb, row.embedding_properties[emb]) " - "RETURN count(*) as nbEmb " - "} " - "RETURN elementId(n)" -) -UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE = ( - "UNWIND $rows AS row " - "CREATE (n:__KGBuilder__ {id: row.id}) " - "SET n += row.properties " - "WITH n, row CALL apoc.create.addLabels(n, row.labels) YIELD node " - "WITH node as n, row CALL (n, row) { " - "WITH n, row WITH n, row WHERE row.embedding_properties IS NOT NULL " - "UNWIND keys(row.embedding_properties) as emb " - "CALL db.create.setNodeVectorProperty(n, emb, row.embedding_properties[emb]) " - "RETURN count(*) as nbEmb " - "} " - "RETURN 
elementId(n)" -) +def _call_subquery_syntax( + support_variable_scope_clause: bool, variable_list: list[str] +) -> str: + """A helper function to return the CALL subquery syntax: + - Either CALL { WITH + - or CALL (variables) { + """ + variables = ",".join(variable_list) + if support_variable_scope_clause: + return f"CALL ({variables}) {{ " + if variables: + return f"CALL {{ WITH {variables} " + return "CALL { " + + +def upsert_node_query(support_variable_scope_clause: bool) -> str: + """Build the Cypher query to upsert a batch of nodes: + - Create the new node + - Set its label(s) and properties + - Set its embedding properties if any + - Return the node elementId + """ + call_prefix = _call_subquery_syntax( + support_variable_scope_clause, variable_list=["n", "row"] + ) + return ( + "UNWIND $rows AS row " + "CREATE (n:__KGBuilder__ {__tmp_internal_id: row.id}) " + "SET n += row.properties " + "WITH n, row CALL apoc.create.addLabels(n, row.labels) YIELD node " + "WITH node as n, row " + f"{call_prefix} " + "WITH n, row WHERE row.embedding_properties IS NOT NULL " + "UNWIND keys(row.embedding_properties) as emb " + "CALL db.create.setNodeVectorProperty(n, emb, row.embedding_properties[emb]) " + "RETURN count(*) as nbEmb " + "} " + "RETURN elementId(n) as element_id" + ) -UPSERT_RELATIONSHIP_QUERY = ( - "UNWIND $rows as row " - "MATCH (start:__KGBuilder__ {id: row.start_node_id}) " - "MATCH (end:__KGBuilder__ {id: row.end_node_id}) " - "WITH start, end, row " - "CALL apoc.merge.relationship(start, row.type, {}, row.properties, end, row.properties) YIELD rel " - "WITH rel, row CALL { " - "WITH rel, row WITH rel, row WHERE row.embedding_properties IS NOT NULL " - "UNWIND keys(row.embedding_properties) as emb " - "CALL db.create.setRelationshipVectorProperty(rel, emb, row.embedding_properties[emb]) " - "} " - "RETURN elementId(rel)" -) -UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE = ( - "UNWIND $rows as row " - "MATCH (start:__KGBuilder__ {id: row.start_node_id}) " 
- "MATCH (end:__KGBuilder__ {id: row.end_node_id}) " - "WITH start, end, row " - "CALL apoc.merge.relationship(start, row.type, {}, row.properties, end, row.properties) YIELD rel " - "WITH rel, row CALL (rel, row) { " - "WITH rel, row WITH rel, row WHERE row.embedding_properties IS NOT NULL " - "UNWIND keys(row.embedding_properties) as emb " - "CALL db.create.setRelationshipVectorProperty(rel, emb, row.embedding_properties[emb]) " - "} " - "RETURN elementId(rel)" -) +def upsert_relationship_query(support_variable_scope_clause: bool) -> str: + """Build the Cypher query to upsert a batch of relationships: + - Create the new relationship: + only one relationship of a specific type is allowed between the same two nodes + - Set its properties + - Set its embedding properties if any + - Return the relationship elementId + """ + call_prefix = _call_subquery_syntax( + support_variable_scope_clause, variable_list=["rel", "row"] + ) + return ( + "UNWIND $rows as row " + "MATCH (start:__KGBuilder__ {__tmp_internal_id: row.start_node_id}), " + " (end:__KGBuilder__ {__tmp_internal_id: row.end_node_id}) " + "WITH start, end, row " + "CALL apoc.merge.relationship(start, row.type, {}, row.properties, end, row.properties) YIELD rel " + "WITH rel, row " + f"{call_prefix} " + "WITH rel, row WHERE row.embedding_properties IS NOT NULL " + "UNWIND keys(row.embedding_properties) as emb " + "CALL db.create.setRelationshipVectorProperty(rel, emb, row.embedding_properties[emb]) " + "} " + "RETURN elementId(rel)" + ) + + +def db_cleaning_query(support_variable_scope_clause: bool, batch_size: int) -> str: + """Removes the temporary __tmp_internal_id property from all nodes.""" + call_prefix = _call_subquery_syntax( + support_variable_scope_clause, variable_list=["n"] + ) + return ( + "MATCH (n:__KGBuilder__) " + "WHERE n.__tmp_internal_id IS NOT NULL " + f"{call_prefix} " + " SET n.__tmp_internal_id = NULL " + "} " + f"IN TRANSACTIONS OF {batch_size} ROWS " + "ON ERROR CONTINUE" + ) + # Deprecated, 
remove along with upsert_vector UPSERT_VECTOR_ON_NODE_QUERY = ( @@ -150,13 +179,15 @@ def _get_hybrid_query(neo4j_version_is_5_23_or_above: bool) -> str: Construct a cypher query for hybrid search. Args: - neo4j_version_is_5_23_or_above (bool): Whether or not the Neo4j version is 5.23 or above; + neo4j_version_is_5_23_or_above (bool): Whether the Neo4j version is 5.23 or above; determines which call syntax is used. Returns: str: The constructed Cypher query string. """ - call_prefix = "CALL () { " if neo4j_version_is_5_23_or_above else "CALL { " + call_prefix = _call_subquery_syntax( + neo4j_version_is_5_23_or_above, variable_list=[] + ) query_body = ( f"{NODE_VECTOR_INDEX_QUERY} " "WITH collect({node:node, score:score}) AS nodes, max(score) AS vector_index_max_score " diff --git a/tests/e2e/docker-compose.yml b/tests/e2e/docker-compose.yml index 6ceda8af7..52c9f4db4 100644 --- a/tests/e2e/docker-compose.yml +++ b/tests/e2e/docker-compose.yml @@ -32,8 +32,9 @@ services: - 7474:7474 environment: NEO4J_AUTH: neo4j/password - NEO4J_ACCEPT_LICENSE_AGREEMENT: "eval" + NEO4J_ACCEPT_LICENSE_AGREEMENT: "yes" NEO4J_PLUGINS: "[\"apoc\"]" + NEO4J_server_memory_heap_max__size: 6G qdrant: image: qdrant/qdrant ports: diff --git a/tests/e2e/experimental/test_kg_writer_component_e2e.py b/tests/e2e/experimental/test_kg_writer_component_e2e.py index 75f8e5578..38aab3f1d 100644 --- a/tests/e2e/experimental/test_kg_writer_component_e2e.py +++ b/tests/e2e/experimental/test_kg_writer_component_e2e.py @@ -30,13 +30,13 @@ async def test_kg_writer(driver: neo4j.Driver) -> None: start_node = Neo4jNode( id="1", label="MyLabel", - properties={"chunk": 1}, + properties={"id": "abc"}, embedding_properties={"vectorProperty": [1.0, 2.0, 3.0]}, ) end_node = Neo4jNode( id="2", label="MyLabel", - properties={}, + properties={"id": "def"}, embedding_properties=None, ) relationship = Neo4jRelationship( @@ -45,7 +45,7 @@ async def test_kg_writer(driver: neo4j.Driver) -> None: node_with_two_embeddings = 
Neo4jNode( id="3", label="MyLabel", - properties={"chunk": 1}, + properties={"id": "ghi"}, embedding_properties={ "vectorProperty": [1.0, 2.0, 3.0], "otherVectorProperty": [10.0, 20.0, 30.0], @@ -61,7 +61,7 @@ async def test_kg_writer(driver: neo4j.Driver) -> None: assert res.status == "SUCCESS" query = """ - MATCH (a:MyLabel {id: '1'})-[r:MY_RELATIONSHIP]->(b:MyLabel {id: '2'}) + MATCH (a:MyLabel {id: 'abc'})-[r:MY_RELATIONSHIP]->(b:MyLabel {id: 'def'}) RETURN a, r, b """ record = driver.execute_query(query).records[0] @@ -69,7 +69,7 @@ async def test_kg_writer(driver: neo4j.Driver) -> None: node_a = record["a"] assert start_node.label in list(node_a.labels) - assert start_node.id == str(node_a.get("id")) + assert start_node.properties.get("id") == str(node_a.get("id")) for key, val in start_node.properties.items(): assert key in node_a.keys() assert val == node_a.get(key) @@ -80,18 +80,18 @@ async def test_kg_writer(driver: neo4j.Driver) -> None: node_b = record["b"] assert end_node.label in list(node_b.labels) - assert end_node.id == str(node_b.get("id")) + assert end_node.properties.get("id") == str(node_b.get("id")) for key, val in end_node.properties.items(): assert key in node_b.keys() assert val == node_b.get(key) rel = record["r"] assert rel.type == relationship.type - assert relationship.start_node_id == rel.start_node.get("id") - assert relationship.end_node_id == rel.end_node.get("id") + assert rel.start_node.get("id") == start_node.properties.get("id") + assert rel.end_node.get("id") == end_node.properties.get("id") query = """ - MATCH (c:MyLabel {id: '3'}) + MATCH (c:MyLabel {id: 'ghi'}) RETURN c """ records = driver.execute_query(query).records diff --git a/tests/unit/experimental/components/test_entity_relation_extractor.py b/tests/unit/experimental/components/test_entity_relation_extractor.py index f76ab5c9c..a93bf0bcf 100644 --- a/tests/unit/experimental/components/test_entity_relation_extractor.py +++ 
b/tests/unit/experimental/components/test_entity_relation_extractor.py @@ -109,7 +109,6 @@ async def test_extractor_happy_path_non_empty_result() -> None: entity = result.nodes[2] assert entity.id == f"{chunk_entity.id}:0" assert entity.label == "Person" - assert entity.properties == {"chunk_index": 0} assert len(result.relationships) == 2 assert result.relationships[0].type == "FROM_DOCUMENT" assert result.relationships[0].start_node_id == f"{chunk_entity.id}" @@ -213,7 +212,6 @@ async def test_extractor_llm_badly_formatted_json_gets_fixed() -> None: assert len(res.nodes) == 1 assert res.nodes[0].label == "Person" - assert res.nodes[0].properties == {"chunk_index": 0} assert res.nodes[0].embedding_properties is None assert res.relationships == [] diff --git a/tests/unit/experimental/components/test_graph_pruning.py b/tests/unit/experimental/components/test_graph_pruning.py index 04730f2e4..c4c779cc4 100644 --- a/tests/unit/experimental/components/test_graph_pruning.py +++ b/tests/unit/experimental/components/test_graph_pruning.py @@ -148,6 +148,20 @@ def node_type_required_name() -> NodeType: True, Neo4jNode(id="1", label="Location", properties={"name": "New York"}), ), + # node label not valid + ( + Neo4jNode(id="1", label="", properties={"name": "New York"}), + "node_type_required_name", + True, + None, + ), + # node ID not valid + ( + Neo4jNode(id="", label="Location", properties={"name": "New York"}), + "node_type_required_name", + True, + None, + ), ], ) def test_graph_pruning_validate_node( @@ -193,6 +207,16 @@ def neo4j_relationship() -> Neo4jRelationship: ) +@pytest.fixture +def neo4j_relationship_invalid_type() -> Neo4jRelationship: + return Neo4jRelationship( + start_node_id="1", + end_node_id="2", + type="", + properties={}, + ) + + @pytest.fixture def neo4j_reversed_relationship( neo4j_relationship: Neo4jRelationship, @@ -253,7 +277,7 @@ def neo4j_reversed_relationship( True, # additional_patterns None, ), - # invalid type addition allowed + # invalid 
type, addition allowed ( "neo4j_relationship", { @@ -266,7 +290,7 @@ def neo4j_reversed_relationship( True, # additional_patterns "neo4j_relationship", ), - # invalid type addition allowed but invalid node ID + # invalid type, addition allowed but invalid node ID ( "neo4j_relationship", { @@ -278,7 +302,7 @@ def neo4j_reversed_relationship( True, # additional_patterns None, ), - # invalid_type_addition_not_allowed + # invalid type, addition not allowed ( "neo4j_relationship", { @@ -321,6 +345,21 @@ def neo4j_reversed_relationship( False, # additional_patterns None, ), + # invalid extracted type + ( + "neo4j_relationship_invalid_type", # relationship, + { # valid_nodes + "1": "Person", + "2": "Location", + }, + RelationshipType( # relationship_type + label="REL", + ), + True, # additional_relationship_types + (("Person", "REL", "Location"),), # patterns + True, # additional_patterns + None, # expected_relationship + ), ], ) def test_graph_pruning_validate_relationship( diff --git a/tests/unit/experimental/components/test_kg_writer.py b/tests/unit/experimental/components/test_kg_writer.py index f149a264f..5473078af 100644 --- a/tests/unit/experimental/components/test_kg_writer.py +++ b/tests/unit/experimental/components/test_kg_writer.py @@ -26,10 +26,8 @@ Neo4jRelationship, ) from neo4j_graphrag.neo4j_queries import ( - UPSERT_NODE_QUERY, - UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE, - UPSERT_RELATIONSHIP_QUERY, - UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE, + upsert_node_query, + upsert_relationship_query, ) @@ -56,11 +54,16 @@ def test_batched() -> None: return_value=None, ) def test_upsert_nodes(_: Mock, driver: MagicMock) -> None: + driver.execute_query.return_value = ( + [{"element_id": "#1"}], + None, + None, + ) neo4j_writer = Neo4jWriter(driver=driver) node = Neo4jNode(id="1", label="Label", properties={"key": "value"}) neo4j_writer._upsert_nodes(nodes=[node], lexical_graph_config=LexicalGraphConfig()) driver.execute_query.assert_called_once_with( - 
UPSERT_NODE_QUERY, + upsert_node_query(False), parameters_={ "rows": [ { @@ -88,6 +91,11 @@ def test_upsert_nodes_with_embedding( _: Mock, driver: MagicMock, ) -> None: + driver.execute_query.return_value = ( + [{"element_id": "#1"}], + None, + None, + ) neo4j_writer = Neo4jWriter(driver=driver) node = Neo4jNode( id="1", @@ -95,10 +103,9 @@ def test_upsert_nodes_with_embedding( properties={"key": "value"}, embedding_properties={"embeddingProp": [1.0, 2.0, 3.0]}, ) - driver.execute_query.return_value.records = [{"elementId(n)": 1}] neo4j_writer._upsert_nodes(nodes=[node], lexical_graph_config=LexicalGraphConfig()) driver.execute_query.assert_any_call( - UPSERT_NODE_QUERY, + upsert_node_query(False), parameters_={ "rows": [ { @@ -130,7 +137,9 @@ def test_upsert_relationship(_: Mock, driver: MagicMock) -> None: type="RELATIONSHIP", properties={"key": "value"}, ) - neo4j_writer._upsert_relationships(rels=[rel]) + neo4j_writer._upsert_relationships( + rels=[rel], + ) parameters = { "rows": [ { @@ -143,7 +152,7 @@ def test_upsert_relationship(_: Mock, driver: MagicMock) -> None: ] } driver.execute_query.assert_called_once_with( - UPSERT_RELATIONSHIP_QUERY, + upsert_relationship_query(False), parameters_=parameters, database_=None, ) @@ -167,7 +176,9 @@ def test_upsert_relationship_with_embedding(_: Mock, driver: MagicMock) -> None: embedding_properties={"embeddingProp": [1.0, 2.0, 3.0]}, ) driver.execute_query.return_value.records = [{"elementId(r)": "rel_elem_id"}] - neo4j_writer._upsert_relationships(rels=[rel]) + neo4j_writer._upsert_relationships( + rels=[rel], + ) parameters = { "rows": [ { @@ -180,7 +191,7 @@ def test_upsert_relationship_with_embedding(_: Mock, driver: MagicMock) -> None: ] } driver.execute_query.assert_any_call( - UPSERT_RELATIONSHIP_QUERY, + upsert_relationship_query(False), parameters_=parameters, database_=None, ) @@ -196,13 +207,21 @@ def test_upsert_relationship_with_embedding(_: Mock, driver: MagicMock) -> None: return_value=None, ) async 
def test_run(_: Mock, driver: MagicMock) -> None: + driver.execute_query.return_value = ( + [ + {"element_id": "#1"}, + {"element_id": "#2"}, + ], + None, + None, + ) neo4j_writer = Neo4jWriter(driver=driver) node = Neo4jNode(id="1", label="Label") rel = Neo4jRelationship(start_node_id="1", end_node_id="2", type="RELATIONSHIP") graph = Neo4jGraph(nodes=[node], relationships=[rel]) await neo4j_writer.run(graph=graph) driver.execute_query.assert_any_call( - UPSERT_NODE_QUERY, + upsert_node_query(False), parameters_={ "rows": [ { @@ -228,7 +247,7 @@ async def test_run(_: Mock, driver: MagicMock) -> None: ] } driver.execute_query.assert_any_call( - UPSERT_RELATIONSHIP_QUERY, + upsert_relationship_query(False), parameters_=parameters_, database_=None, ) @@ -242,7 +261,14 @@ async def test_run(_: Mock, driver: MagicMock) -> None: async def test_run_is_version_below_5_23(_: Mock) -> None: driver = MagicMock() driver.execute_query = Mock( - return_value=([{"versions": ["5.22.0"], "edition": "enterprise"}], None, None) + side_effect=( + # get_version + ([{"versions": ["5.22.0"], "edition": "enterpise"}], None, None), + # upsert nodes + ([{"_internal_id": "1", "element_id": "#1"}], None, None), + # upsert relationships + (None, None, None), + ) ) neo4j_writer = Neo4jWriter(driver=driver) @@ -253,7 +279,7 @@ async def test_run_is_version_below_5_23(_: Mock) -> None: await neo4j_writer.run(graph=graph) driver.execute_query.assert_any_call( - UPSERT_NODE_QUERY, + upsert_node_query(False), parameters_={ "rows": [ { @@ -279,7 +305,7 @@ async def test_run_is_version_below_5_23(_: Mock) -> None: ] } driver.execute_query.assert_any_call( - UPSERT_RELATIONSHIP_QUERY, + upsert_relationship_query(False), parameters_=parameters_, database_=None, ) @@ -293,7 +319,14 @@ async def test_run_is_version_below_5_23(_: Mock) -> None: async def test_run_is_version_5_23_or_above(_: Mock) -> None: driver = MagicMock() driver.execute_query = Mock( - return_value=([{"versions": ["5.23.0"], 
"edition": "enterpise"}], None, None) + side_effect=( + # get_version + ([{"versions": ["5.23.0"], "edition": "enterpise"}], None, None), + # upsert nodes + ([{"element_id": "#1"}], None, None), + # upsert relationships + (None, None, None), + ) ) neo4j_writer = Neo4jWriter(driver=driver) @@ -305,7 +338,7 @@ async def test_run_is_version_5_23_or_above(_: Mock) -> None: await neo4j_writer.run(graph=graph) driver.execute_query.assert_any_call( - UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE, + upsert_node_query(True), parameters_={ "rows": [ { @@ -331,7 +364,7 @@ async def test_run_is_version_5_23_or_above(_: Mock) -> None: ] } driver.execute_query.assert_any_call( - UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE, + upsert_relationship_query(True), parameters_=parameters_, database_=None, ) diff --git a/tests/unit/experimental/components/test_types.py b/tests/unit/experimental/components/test_types.py deleted file mode 100644 index 4e6d7766b..000000000 --- a/tests/unit/experimental/components/test_types.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright (c) "Neo4j" -# Neo4j Sweden AB [https://neo4j.com] -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# https://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import pytest -from neo4j_graphrag.experimental.components.types import Neo4jNode - - -def test_neo4j_node_invalid_property() -> None: - with pytest.raises(TypeError) as excinfo: - Neo4jNode(id="0", label="Label", properties={"id": "1"}) - assert "'id' as a property name is not allowed" in str(excinfo) - - -def test_neo4j_node_invalid_embedding_property() -> None: - with pytest.raises(TypeError) as excinfo: - Neo4jNode(id="0", label="Label", embedding_properties={"id": [1.0, 2.0, 3.0]}) - assert "'id' as a property name is not allowed" in str(excinfo)