From a72c8879eedb7bea13299f3e98a214ae315ab2c5 Mon Sep 17 00:00:00 2001 From: estelle Date: Thu, 19 Jun 2025 17:25:15 +0200 Subject: [PATCH 1/9] Allow an ID property by renaming our internal ID to _kg_builder_id --- CHANGELOG.md | 4 +++ .../experimental/components/types.py | 17 ++++------- src/neo4j_graphrag/neo4j_queries.py | 12 ++++---- .../test_kg_writer_component_e2e.py | 12 ++++---- .../experimental/components/test_types.py | 28 ------------------- 5 files changed, 22 insertions(+), 51 deletions(-) delete mode 100644 tests/unit/experimental/components/test_types.py diff --git a/CHANGELOG.md b/CHANGELOG.md index cbd83dedd..2d0336ce8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,10 @@ - The `SchemaProperty` model has been renamed `PropertyType`. - `SchemaConfig` has been removed in favor of `GraphSchema` (used in the `SchemaBuilder` and `EntityRelationExtractor` classes). `entities`, `relations` and `potential_schema` fields have also been renamed `node_types`, `relationship_types` and `patterns` respectively. +#### Other + +- The node internal `id` property that's used to create relationships between nodes has been renamed to `__kg_builder_id`. This releases the `id` name for domain-specific meaningful `id` property. + ## 1.7.0 diff --git a/src/neo4j_graphrag/experimental/components/types.py b/src/neo4j_graphrag/experimental/components/types.py index 3c07d401c..f761f8e6e 100644 --- a/src/neo4j_graphrag/experimental/components/types.py +++ b/src/neo4j_graphrag/experimental/components/types.py @@ -14,14 +14,18 @@ # limitations under the License. from __future__ import annotations +import logging import uuid from typing import Any, Dict, Optional -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field from neo4j_graphrag.experimental.pipeline.component import DataModel +logger = logging.getLogger(__name__) + + class DocumentInfo(DataModel): """A document loaded by a DataLoader. @@ -79,7 +83,7 @@ class Neo4jNode(BaseModel): """Represents a Neo4j node. Attributes: - id (str): The element ID of the node. + id (str): The element ID of the node. This ID is used to refer to the node for relationship creation. label (str): The label of the node. properties (dict[str, Any]): A dictionary of properties attached to the node. embedding_properties (Optional[dict[str, list[float]]]): A list of embedding properties attached to the node. @@ -90,15 +94,6 @@ class Neo4jNode(BaseModel): properties: dict[str, Any] = {} embedding_properties: Optional[dict[str, list[float]]] = None - @field_validator("properties", "embedding_properties") - @classmethod - def check_for_id_properties( - cls, v: Optional[dict[str, Any]] - ) -> Optional[dict[str, Any]]: - if v and "id" in v.keys(): - raise TypeError("'id' as a property name is not allowed") - return v - @property def token(self) -> str: return self.label diff --git a/src/neo4j_graphrag/neo4j_queries.py b/src/neo4j_graphrag/neo4j_queries.py index 4e6e4ef31..7421eeb0a 100644 --- a/src/neo4j_graphrag/neo4j_queries.py +++ b/src/neo4j_graphrag/neo4j_queries.py @@ -54,7 +54,7 @@ UPSERT_NODE_QUERY = ( "UNWIND $rows AS row " - "CREATE (n:__KGBuilder__ {id: row.id}) " + "CREATE (n:__KGBuilder__ {__kg_builder_id: row.id}) " "SET n += row.properties " "WITH n, row CALL apoc.create.addLabels(n, row.labels) YIELD node " "WITH node as n, row CALL { " @@ -68,7 +68,7 @@ UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE = ( "UNWIND $rows AS row " - "CREATE (n:__KGBuilder__ {id: row.id}) " + "CREATE (n:__KGBuilder__ {__kg_builder_id: row.id}) " "SET n += row.properties " "WITH n, row CALL apoc.create.addLabels(n, row.labels) YIELD node " "WITH node as n, row CALL (n, row) { " @@ -82,8 +82,8 @@ UPSERT_RELATIONSHIP_QUERY = ( "UNWIND $rows as row " - "MATCH (start:__KGBuilder__ {id: row.start_node_id}) " - "MATCH (end:__KGBuilder__ {id: row.end_node_id}) " + "MATCH (start:__KGBuilder__ {__kg_builder_id: row.start_node_id}) " + "MATCH (end:__KGBuilder__ {__kg_builder_id: row.end_node_id}) " "WITH start, end, row " "CALL apoc.merge.relationship(start, row.type, {}, row.properties, end, row.properties) YIELD rel " "WITH rel, row CALL { " @@ -96,8 +96,8 @@ UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE = ( "UNWIND $rows as row " - "MATCH (start:__KGBuilder__ {id: row.start_node_id}) " - "MATCH (end:__KGBuilder__ {id: row.end_node_id}) " + "MATCH (start:__KGBuilder__ {__kg_builder_id: row.start_node_id}) " + "MATCH (end:__KGBuilder__ {__kg_builder_id: row.end_node_id}) " "WITH start, end, row " "CALL apoc.merge.relationship(start, row.type, {}, row.properties, end, row.properties) YIELD rel " "WITH rel, row CALL (rel, row) { " diff --git a/tests/e2e/experimental/test_kg_writer_component_e2e.py b/tests/e2e/experimental/test_kg_writer_component_e2e.py index 75f8e5578..5653ca6c7 100644 --- a/tests/e2e/experimental/test_kg_writer_component_e2e.py +++ b/tests/e2e/experimental/test_kg_writer_component_e2e.py @@ -61,7 +61,7 @@ async def test_kg_writer(driver: neo4j.Driver) -> None: assert res.status == "SUCCESS" query = """ - MATCH (a:MyLabel {id: '1'})-[r:MY_RELATIONSHIP]->(b:MyLabel {id: '2'}) + MATCH (a:MyLabel {__kg_builder_id: '1'})-[r:MY_RELATIONSHIP]->(b:MyLabel {__kg_builder_id: '2'}) RETURN a, r, b """ record = driver.execute_query(query).records[0] @@ -69,7 +69,7 @@ async def test_kg_writer(driver: neo4j.Driver) -> None: node_a = record["a"] assert start_node.label in list(node_a.labels) - assert start_node.id == str(node_a.get("id")) + assert start_node.id == str(node_a.get("__kg_builder_id")) for key, val in start_node.properties.items(): assert key in node_a.keys() assert val == node_a.get(key) @@ -80,18 +80,18 @@ async def test_kg_writer(driver: neo4j.Driver) -> None: node_b = record["b"] assert end_node.label in list(node_b.labels) - assert end_node.id == str(node_b.get("id")) + assert end_node.id == str(node_b.get("__kg_builder_id")) for key, val in end_node.properties.items(): assert key in node_b.keys() assert val == node_b.get(key) rel = record["r"] assert rel.type == relationship.type - assert relationship.start_node_id == rel.start_node.get("id") - assert relationship.end_node_id == rel.end_node.get("id") + assert relationship.start_node_id == rel.start_node.get("__kg_builder_id") + assert relationship.end_node_id == rel.end_node.get("__kg_builder_id") query = """ - MATCH (c:MyLabel {id: '3'}) + MATCH (c:MyLabel {__kg_builder_id: '3'}) RETURN c """ records = driver.execute_query(query).records diff --git a/tests/unit/experimental/components/test_types.py b/tests/unit/experimental/components/test_types.py deleted file mode 100644 index 4e6d7766b..000000000 --- a/tests/unit/experimental/components/test_types.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright (c) "Neo4j" -# Neo4j Sweden AB [https://neo4j.com] -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# # -# https://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import pytest -from neo4j_graphrag.experimental.components.types import Neo4jNode - - -def test_neo4j_node_invalid_property() -> None: - with pytest.raises(TypeError) as excinfo: - Neo4jNode(id="0", label="Label", properties={"id": "1"}) - assert "'id' as a property name is not allowed" in str(excinfo) - - -def test_neo4j_node_invalid_embedding_property() -> None: - with pytest.raises(TypeError) as excinfo: - Neo4jNode(id="0", label="Label", embedding_properties={"id": [1.0, 2.0, 3.0]}) - assert "'id' as a property name is not allowed" in str(excinfo) From 6159c6e34dc14b4a9a74eb8b55a70806f3d0668c Mon Sep 17 00:00:00 2001 From: estelle Date: Thu, 19 Jun 2025 17:37:22 +0200 Subject: [PATCH 2/9] Prune nodes with invalid label/id and relationships with invalid type before inserting (APOC raises errors on them) --- docs/source/user_guide_kg_builder.rst | 2 + .../components/pruners/graph_pruner.py | 2 +- .../experimental/components/graph_pruning.py | 17 +++++++ .../components/test_graph_pruning.py | 45 +++++++++++++++++-- 4 files changed, 62 insertions(+), 4 deletions(-) diff --git a/docs/source/user_guide_kg_builder.rst b/docs/source/user_guide_kg_builder.rst index f810d9ed2..148216e50 100644 --- a/docs/source/user_guide_kg_builder.rst +++ b/docs/source/user_guide_kg_builder.rst @@ -1048,8 +1048,10 @@ _________________ In addition to the user-defined configuration options described above, the `GraphPruning` component performs the following cleanup operations: +- Nodes with invalid label or ID are pruned. - Nodes with missing required properties are pruned. - Nodes with no remaining properties are pruned. +- Relationships with invalid type are pruned. - Relationships with invalid source or target nodes (i.e., nodes no longer present in the graph) are pruned. - Relationships with incorrect direction have their direction corrected. diff --git a/examples/customize/build_graph/components/pruners/graph_pruner.py b/examples/customize/build_graph/components/pruners/graph_pruner.py index 7273188ad..c40924f41 100644 --- a/examples/customize/build_graph/components/pruners/graph_pruner.py +++ b/examples/customize/build_graph/components/pruners/graph_pruner.py @@ -1,4 +1,4 @@ -"""This example demonstrates how to use the GraphPruner component.""" +"""This example demonstrates how to use the GraphPruning component.""" import asyncio diff --git a/src/neo4j_graphrag/experimental/components/graph_pruning.py b/src/neo4j_graphrag/experimental/components/graph_pruning.py index d49f48ce0..c8bf647f3 100644 --- a/src/neo4j_graphrag/experimental/components/graph_pruning.py +++ b/src/neo4j_graphrag/experimental/components/graph_pruning.py @@ -41,6 +41,7 @@ class PruningReason(str, enum.Enum): NO_PROPERTY_LEFT = "NO_PROPERTY_LEFT" INVALID_START_OR_END_NODE = "INVALID_START_OR_END_NODE" INVALID_PATTERN = "INVALID_PATTERN" + MISSING_LABEL = "MISSING_LABEL" ItemType = TypeVar("ItemType") @@ -198,6 +199,17 @@ def _validate_node( schema_entity: Optional[NodeType], additional_node_types: bool, ) -> Optional[Neo4jNode]: + if not node.label: + pruning_stats.add_pruned_node(node, reason=PruningReason.MISSING_LABEL) + return None + if not node.id: + pruning_stats.add_pruned_node( + node, + reason=PruningReason.MISSING_REQUIRED_PROPERTY, + missing_required_properties=["id"], + details="The node was extracted without a valid ID.", + ) + return None if not schema_entity: # node type not declared in the schema if additional_node_types: @@ -262,6 +274,11 @@ def _validate_relationship( patterns: tuple[tuple[str, str, str], ...], additional_patterns: bool, ) -> Optional[Neo4jRelationship]: + if not rel.type: + pruning_stats.add_pruned_relationship( + rel, reason=PruningReason.MISSING_LABEL + ) + return None # validate start/end node IDs are valid nodes if rel.start_node_id not in valid_nodes or rel.end_node_id not in valid_nodes: logger.debug( diff --git a/tests/unit/experimental/components/test_graph_pruning.py b/tests/unit/experimental/components/test_graph_pruning.py index 04730f2e4..c4c779cc4 100644 --- a/tests/unit/experimental/components/test_graph_pruning.py +++ b/tests/unit/experimental/components/test_graph_pruning.py @@ -148,6 +148,20 @@ def node_type_required_name() -> NodeType: True, Neo4jNode(id="1", label="Location", properties={"name": "New York"}), ), + # node label not valid + ( + Neo4jNode(id="1", label="", properties={"name": "New York"}), + "node_type_required_name", + True, + None, + ), + # node ID not valid + ( + Neo4jNode(id="", label="Location", properties={"name": "New York"}), + "node_type_required_name", + True, + None, + ), ], ) def test_graph_pruning_validate_node( @@ -193,6 +207,16 @@ def neo4j_relationship() -> Neo4jRelationship: ) +@pytest.fixture +def neo4j_relationship_invalid_type() -> Neo4jRelationship: + return Neo4jRelationship( + start_node_id="1", + end_node_id="2", + type="", + properties={}, + ) + + @pytest.fixture def neo4j_reversed_relationship( neo4j_relationship: Neo4jRelationship, @@ -253,7 +277,7 @@ def neo4j_reversed_relationship( True, # additional_patterns None, ), - # invalid type addition allowed + # invalid type, addition allowed ( "neo4j_relationship", { @@ -266,7 +290,7 @@ def neo4j_reversed_relationship( True, # additional_patterns "neo4j_relationship", ), - # invalid type addition allowed but invalid node ID + # invalid type, addition allowed but invalid node ID ( "neo4j_relationship", { @@ -278,7 +302,7 @@ def neo4j_reversed_relationship( True, # additional_patterns None, ), - # invalid_type_addition_not_allowed + # invalid type, addition not allowed ( "neo4j_relationship", { @@ -321,6 +345,21 @@ def neo4j_reversed_relationship( False, # additional_patterns None, ), + # invalid extracted type + ( + "neo4j_relationship_invalid_type", # relationship, + { # valid_nodes + "1": "Person", + "2": "Location", + }, + RelationshipType( # relationship_type + label="REL", + ), + True, # additional_relationship_types + (("Person", "REL", "Location"),), # patterns + True, # additional_patterns + None, # expected_relationship + ), ], ) def test_graph_pruning_validate_relationship( From 984a3cd0eafcaab505ee1f527cd3a3e2ce2442c2 Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 23 Jun 2025 16:05:38 +0200 Subject: [PATCH 3/9] Do not pollute the graph with internal/temp properties --- CHANGELOG.md | 4 +- .../components/entity_relation_extractor.py | 3 +- .../experimental/components/kg_writer.py | 86 +++++++++++-------- src/neo4j_graphrag/neo4j_queries.py | 16 ++-- .../test_entity_relation_extractor.py | 2 - .../experimental/components/test_kg_writer.py | 62 +++++++++++-- 6 files changed, 118 insertions(+), 55 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2d0336ce8..4be0b1bb7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,8 +28,8 @@ #### Other -- The node internal `id` property that's used to create relationships between nodes has been renamed to `__kg_builder_id`. This releases the `id` name for domain-specific meaningful `id` property. - +- The `id` property on `__KG_Builder__` nodes is removed. +- The `chunk_index` property on `__Entity__` nodes is removed. Use the `FROM_CHUNK` relationship instead. ## 1.7.0 diff --git a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py index b1ed9bbb1..9ee61fed8 100644 --- a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py +++ b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py @@ -146,12 +146,11 @@ def update_ids( """Make node IDs unique across chunks, document and pipeline runs by prefixing them with a unique prefix. """ - prefix = f"{chunk.chunk_id}" + prefix = chunk.chunk_id for node in graph.nodes: node.id = f"{prefix}:{node.id}" if node.properties is None: node.properties = {} - node.properties.update({"chunk_index": chunk.index}) for rel in graph.relationships: rel.start_node_id = f"{prefix}:{rel.start_node_id}" rel.end_node_id = f"{prefix}:{rel.end_node_id}" diff --git a/src/neo4j_graphrag/experimental/components/kg_writer.py b/src/neo4j_graphrag/experimental/components/kg_writer.py index 4a859b36e..fcd8a62ba 100644 --- a/src/neo4j_graphrag/experimental/components/kg_writer.py +++ b/src/neo4j_graphrag/experimental/components/kg_writer.py @@ -125,12 +125,8 @@ def __init__( self.is_version_5_23_or_above = is_version_5_23_or_above(version_tuple) def _db_setup(self) -> None: - # create index on __KGBuilder__.id - # used when creating the relationships - self.driver.execute_query( - "CREATE INDEX __entity__id IF NOT EXISTS FOR (n:__KGBuilder__) ON (n.id)", - database_=self.neo4j_database, - ) + # not used for now + pass @staticmethod def _nodes_to_rows( @@ -148,45 +144,62 @@ def _nodes_to_rows( def _upsert_nodes( self, nodes: list[Neo4jNode], lexical_graph_config: LexicalGraphConfig - ) -> None: + ) -> dict[str, str]: """Upserts a single node into the Neo4j database." Args: nodes (list[Neo4jNode]): The nodes batch to upsert into the database. """ parameters = {"rows": self._nodes_to_rows(nodes, lexical_graph_config)} - if self.is_version_5_23_or_above: - self.driver.execute_query( - UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE, - parameters_=parameters, - database_=self.neo4j_database, - ) - else: - self.driver.execute_query( - UPSERT_NODE_QUERY, - parameters_=parameters, - database_=self.neo4j_database, - ) + query = ( + UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE + if self.is_version_5_23_or_above + else UPSERT_NODE_QUERY + ) + records, _, _ = self.driver.execute_query( + query, + parameters_=parameters, + database_=self.neo4j_database, + ) + print("RECORDS", records) + return {r["_internal_id"]: r["element_id"] for r in records} - def _upsert_relationships(self, rels: list[Neo4jRelationship]) -> None: + @staticmethod + def _relationships_to_rows( + relationships: list[Neo4jRelationship], node_id_mapping: dict[str, str] + ) -> list[dict[str, Any]]: + return [ + { + **relationship.model_dump(), + "start_node_element_id": node_id_mapping.get( + relationship.start_node_id, "" + ), + "end_node_element_id": node_id_mapping.get( + relationship.end_node_id, "" + ), + } + for relationship in relationships + ] + + def _upsert_relationships( + self, rels: list[Neo4jRelationship], node_id_mapping: dict[str, str] + ) -> None: """Upserts a single relationship into the Neo4j database. Args: rels (list[Neo4jRelationship]): The relationships batch to upsert into the database. """ - parameters = {"rows": [rel.model_dump() for rel in rels]} - if self.is_version_5_23_or_above: - self.driver.execute_query( - UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE, - parameters_=parameters, - database_=self.neo4j_database, - ) - else: - self.driver.execute_query( - UPSERT_RELATIONSHIP_QUERY, - parameters_=parameters, - database_=self.neo4j_database, - ) + parameters = {"rows": self._relationships_to_rows(rels, node_id_mapping)} + query = ( + UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE + if self.is_version_5_23_or_above + else UPSERT_RELATIONSHIP_QUERY + ) + self.driver.execute_query( + query, + parameters_=parameters, + database_=self.neo4j_database, + ) @validate_call async def run( @@ -203,11 +216,14 @@ async def run( try: self._db_setup() + node_id_mapping = {} + for batch in batched(graph.nodes, self.batch_size): - self._upsert_nodes(batch, lexical_graph_config) + batch_mapping = self._upsert_nodes(batch, lexical_graph_config) + node_id_mapping.update(batch_mapping) for batch in batched(graph.relationships, self.batch_size): - self._upsert_relationships(batch) + self._upsert_relationships(batch, node_id_mapping) return KGWriterModel( status="SUCCESS", diff --git a/src/neo4j_graphrag/neo4j_queries.py b/src/neo4j_graphrag/neo4j_queries.py index 7421eeb0a..4863462f3 100644 --- a/src/neo4j_graphrag/neo4j_queries.py +++ b/src/neo4j_graphrag/neo4j_queries.py @@ -54,7 +54,7 @@ UPSERT_NODE_QUERY = ( "UNWIND $rows AS row " - "CREATE (n:__KGBuilder__ {__kg_builder_id: row.id}) " + "CREATE (n:__KGBuilder__) " "SET n += row.properties " "WITH n, row CALL apoc.create.addLabels(n, row.labels) YIELD node " "WITH node as n, row CALL { " @@ -63,12 +63,12 @@ "CALL db.create.setNodeVectorProperty(n, emb, row.embedding_properties[emb]) " "RETURN count(*) as nbEmb " "} " - "RETURN elementId(n)" + "RETURN row.id as _internal_id, elementId(n) as element_id" ) UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE = ( "UNWIND $rows AS row " - "CREATE (n:__KGBuilder__ {__kg_builder_id: row.id}) " + "CREATE (n:__KGBuilder__) " "SET n += row.properties " "WITH n, row CALL apoc.create.addLabels(n, row.labels) YIELD node " "WITH node as n, row CALL (n, row) { " @@ -77,13 +77,13 @@ "CALL db.create.setNodeVectorProperty(n, emb, row.embedding_properties[emb]) " "RETURN count(*) as nbEmb " "} " - "RETURN elementId(n)" + "RETURN row.id as _internal_id, elementId(n) as element_id" ) UPSERT_RELATIONSHIP_QUERY = ( "UNWIND $rows as row " - "MATCH (start:__KGBuilder__ {__kg_builder_id: row.start_node_id}) " - "MATCH (end:__KGBuilder__ {__kg_builder_id: row.end_node_id}) " + "MATCH (start:__KGBuilder__), (end:__KGBuilder__) " + "WHERE elementId(start) = row.start_node_element_id AND elementId(end) = row.end_node_element_id " "WITH start, end, row " "CALL apoc.merge.relationship(start, row.type, {}, row.properties, end, row.properties) YIELD rel " "WITH rel, row CALL { " @@ -96,8 +96,8 @@ UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE = ( "UNWIND $rows as row " - "MATCH (start:__KGBuilder__ {__kg_builder_id: row.start_node_id}) " - "MATCH (end:__KGBuilder__ {__kg_builder_id: row.end_node_id}) " + "MATCH (start:__KGBuilder__), (end:__KGBuilder__) " + "WHERE elementId(start) = row.start_node_element_id AND elementId(end) = row.end_node_element_id " "WITH start, end, row " "CALL apoc.merge.relationship(start, row.type, {}, row.properties, end, row.properties) YIELD rel " "WITH rel, row CALL (rel, row) { " diff --git a/tests/unit/experimental/components/test_entity_relation_extractor.py b/tests/unit/experimental/components/test_entity_relation_extractor.py index f76ab5c9c..a93bf0bcf 100644 --- a/tests/unit/experimental/components/test_entity_relation_extractor.py +++ b/tests/unit/experimental/components/test_entity_relation_extractor.py @@ -109,7 +109,6 @@ async def test_extractor_happy_path_non_empty_result() -> None: entity = result.nodes[2] assert entity.id == f"{chunk_entity.id}:0" assert entity.label == "Person" - assert entity.properties == {"chunk_index": 0} assert len(result.relationships) == 2 assert result.relationships[0].type == "FROM_DOCUMENT" assert result.relationships[0].start_node_id == f"{chunk_entity.id}" @@ -213,7 +212,6 @@ async def test_extractor_llm_badly_formatted_json_gets_fixed() -> None: assert len(res.nodes) == 1 assert res.nodes[0].label == "Person" - assert res.nodes[0].properties == {"chunk_index": 0} assert res.nodes[0].embedding_properties is None assert res.relationships == [] diff --git a/tests/unit/experimental/components/test_kg_writer.py b/tests/unit/experimental/components/test_kg_writer.py index f149a264f..768b1a543 100644 --- a/tests/unit/experimental/components/test_kg_writer.py +++ b/tests/unit/experimental/components/test_kg_writer.py @@ -56,9 +56,17 @@ def test_batched() -> None: return_value=None, ) def test_upsert_nodes(_: Mock, driver: MagicMock) -> None: + driver.execute_query.return_value = ( + [{"_internal_id": "1", "element_id": "#1"}], + None, + None, + ) neo4j_writer = Neo4jWriter(driver=driver) node = Neo4jNode(id="1", label="Label", properties={"key": "value"}) - neo4j_writer._upsert_nodes(nodes=[node], lexical_graph_config=LexicalGraphConfig()) + result = neo4j_writer._upsert_nodes( + nodes=[node], lexical_graph_config=LexicalGraphConfig() + ) + assert result == {"1": "#1"} driver.execute_query.assert_called_once_with( UPSERT_NODE_QUERY, parameters_={ @@ -88,6 +96,11 @@ def test_upsert_nodes_with_embedding( _: Mock, driver: MagicMock, ) -> None: + driver.execute_query.return_value = ( + [{"_internal_id": "1", "element_id": "#1"}], + None, + None, + ) neo4j_writer = Neo4jWriter(driver=driver) node = Neo4jNode( id="1", @@ -95,7 +108,6 @@ def test_upsert_nodes_with_embedding( properties={"key": "value"}, embedding_properties={"embeddingProp": [1.0, 2.0, 3.0]}, ) - driver.execute_query.return_value.records = [{"elementId(n)": 1}] neo4j_writer._upsert_nodes(nodes=[node], lexical_graph_config=LexicalGraphConfig()) driver.execute_query.assert_any_call( UPSERT_NODE_QUERY, @@ -130,7 +142,9 @@ def test_upsert_relationship(_: Mock, driver: MagicMock) -> None: type="RELATIONSHIP", properties={"key": "value"}, ) - neo4j_writer._upsert_relationships(rels=[rel]) + neo4j_writer._upsert_relationships( + rels=[rel], node_id_mapping={"1": "#1", "2": "#2"} + ) parameters = { "rows": [ { @@ -139,6 +153,8 @@ def test_upsert_relationship(_: Mock, driver: MagicMock) -> None: "end_node_id": "2", "properties": {"key": "value"}, "embedding_properties": None, + "start_node_element_id": "#1", + "end_node_element_id": "#2", } ] } @@ -167,7 +183,9 @@ def test_upsert_relationship_with_embedding(_: Mock, driver: MagicMock) -> None: embedding_properties={"embeddingProp": [1.0, 2.0, 3.0]}, ) driver.execute_query.return_value.records = [{"elementId(r)": "rel_elem_id"}] - neo4j_writer._upsert_relationships(rels=[rel]) + neo4j_writer._upsert_relationships( + rels=[rel], node_id_mapping={"1": "#1", "2": "#2"} + ) parameters = { "rows": [ { @@ -176,6 +194,8 @@ def test_upsert_relationship_with_embedding(_: Mock, driver: MagicMock) -> None: "end_node_id": "2", "properties": {"key": "value"}, "embedding_properties": {"embeddingProp": [1.0, 2.0, 3.0]}, + "start_node_element_id": "#1", + "end_node_element_id": "#2", } ] } @@ -196,6 +216,14 @@ def test_upsert_relationship_with_embedding(_: Mock, driver: MagicMock) -> None: return_value=None, ) async def test_run(_: Mock, driver: MagicMock) -> None: + driver.execute_query.return_value = ( + [ + {"_internal_id": "1", "element_id": "#1"}, + {"_internal_id": "2", "element_id": "#2"}, + ], + None, + None, + ) neo4j_writer = Neo4jWriter(driver=driver) node = Neo4jNode(id="1", label="Label") rel = Neo4jRelationship(start_node_id="1", end_node_id="2", type="RELATIONSHIP") @@ -224,6 +252,8 @@ async def test_run(_: Mock, driver: MagicMock) -> None: "end_node_id": "2", "properties": {}, "embedding_properties": None, + "start_node_element_id": "#1", + "end_node_element_id": "#2", } ] } @@ -242,7 +272,14 @@ async def test_run(_: Mock, driver: MagicMock) -> None: async def test_run_is_version_below_5_23(_: Mock) -> None: driver = MagicMock() driver.execute_query = Mock( - return_value=([{"versions": ["5.22.0"], "edition": "enterprise"}], None, None) + side_effect=( + # get_version + ([{"versions": ["5.22.0"], "edition": "enterpise"}], None, None), + # upsert nodes + ([{"_internal_id": "1", "element_id": "#1"}], None, None), + # upsert relationships + (None, None, None), + ) ) neo4j_writer = Neo4jWriter(driver=driver) @@ -252,6 +289,8 @@ async def test_run_is_version_below_5_23(_: Mock) -> None: graph = Neo4jGraph(nodes=[node], relationships=[rel]) await neo4j_writer.run(graph=graph) + print(driver.execute_query.call_args_list) + driver.execute_query.assert_any_call( UPSERT_NODE_QUERY, parameters_={ @@ -275,6 +314,8 @@ async def test_run_is_version_below_5_23(_: Mock) -> None: "end_node_id": "2", "properties": {}, "embedding_properties": None, + "start_node_element_id": "#1", + "end_node_element_id": "", } ] } @@ -293,7 +334,14 @@ async def test_run_is_version_below_5_23(_: Mock) -> None: async def test_run_is_version_5_23_or_above(_: Mock) -> None: driver = MagicMock() driver.execute_query = Mock( - return_value=([{"versions": ["5.23.0"], "edition": "enterpise"}], None, None) + side_effect=( + # get_version + ([{"versions": ["5.23.0"], "edition": "enterpise"}], None, None), + # upsert nodes + ([{"_internal_id": "1", "element_id": "#1"}], None, None), + # upsert relationships + (None, None, None), + ) ) neo4j_writer = Neo4jWriter(driver=driver) @@ -327,6 +375,8 @@ async def test_run_is_version_5_23_or_above(_: Mock) -> None: "end_node_id": "2", "properties": {}, "embedding_properties": None, + "start_node_element_id": "#1", + "end_node_element_id": "", } ] } From d4dc91f60514eb24e28561adbf338ddfb35646cc Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 23 Jun 2025 18:14:50 +0200 Subject: [PATCH 4/9] Fix e2e --- .../experimental/components/kg_writer.py | 1 - .../experimental/test_kg_writer_component_e2e.py | 16 +++++++--------- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/kg_writer.py b/src/neo4j_graphrag/experimental/components/kg_writer.py index fcd8a62ba..40040f21a 100644 --- a/src/neo4j_graphrag/experimental/components/kg_writer.py +++ b/src/neo4j_graphrag/experimental/components/kg_writer.py @@ -161,7 +161,6 @@ def _upsert_nodes( parameters_=parameters, database_=self.neo4j_database, ) - print("RECORDS", records) return {r["_internal_id"]: r["element_id"] for r in records} @staticmethod diff --git a/tests/e2e/experimental/test_kg_writer_component_e2e.py b/tests/e2e/experimental/test_kg_writer_component_e2e.py index 5653ca6c7..8876219fa 100644 --- a/tests/e2e/experimental/test_kg_writer_component_e2e.py +++ b/tests/e2e/experimental/test_kg_writer_component_e2e.py @@ -30,13 +30,13 @@ async def test_kg_writer(driver: neo4j.Driver) -> None: start_node = Neo4jNode( id="1", label="MyLabel", - properties={"chunk": 1}, + properties={"id": "1"}, embedding_properties={"vectorProperty": [1.0, 2.0, 3.0]}, ) end_node = Neo4jNode( id="2", label="MyLabel", - properties={}, + properties={"id": "2"}, embedding_properties=None, ) relationship = Neo4jRelationship( @@ -45,7 +45,7 @@ async def test_kg_writer(driver: neo4j.Driver) -> None: node_with_two_embeddings = Neo4jNode( id="3", label="MyLabel", - properties={"chunk": 1}, + properties={"id": "3"}, embedding_properties={ "vectorProperty": [1.0, 2.0, 3.0], "otherVectorProperty": [10.0, 20.0, 30.0], @@ -61,7 +61,7 @@ async def test_kg_writer(driver: neo4j.Driver) -> None: assert res.status == "SUCCESS" query = """ - MATCH (a:MyLabel {__kg_builder_id: '1'})-[r:MY_RELATIONSHIP]->(b:MyLabel {__kg_builder_id: '2'}) + MATCH (a:MyLabel {id: '1'})-[r:MY_RELATIONSHIP]->(b:MyLabel {id: '2'}) RETURN a, r, b """ record = driver.execute_query(query).records[0] @@ -69,7 +69,7 @@ async def test_kg_writer(driver: neo4j.Driver) -> None: node_a = record["a"] assert start_node.label in list(node_a.labels) - assert start_node.id == str(node_a.get("__kg_builder_id")) + assert start_node.id == str(node_a.get("id")) for key, val in start_node.properties.items(): assert key in node_a.keys() assert val == node_a.get(key) @@ -80,18 +80,16 @@ async def test_kg_writer(driver: neo4j.Driver) -> None: node_b = record["b"] assert end_node.label in list(node_b.labels) - assert end_node.id == str(node_b.get("__kg_builder_id")) + assert end_node.id == str(node_b.get("id")) for key, val in end_node.properties.items(): assert key in node_b.keys() assert val == node_b.get(key) rel = record["r"] assert rel.type == relationship.type - assert relationship.start_node_id == rel.start_node.get("__kg_builder_id") - assert relationship.end_node_id == rel.end_node.get("__kg_builder_id") query = """ - MATCH (c:MyLabel {__kg_builder_id: '3'}) + MATCH (c:MyLabel {id: '3'}) RETURN c """ records = driver.execute_query(query).records From 040f8968d0ad25b621d2a892beeafed160045941 Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 24 Jun 2025 09:51:18 +0200 Subject: [PATCH 5/9] Improve e2e test --- .../e2e/experimental/test_kg_writer_component_e2e.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/e2e/experimental/test_kg_writer_component_e2e.py b/tests/e2e/experimental/test_kg_writer_component_e2e.py index 8876219fa..7ea02eae0 100644 --- a/tests/e2e/experimental/test_kg_writer_component_e2e.py +++ b/tests/e2e/experimental/test_kg_writer_component_e2e.py @@ -30,13 +30,13 @@ async def test_kg_writer(driver: neo4j.Driver) -> None: start_node = Neo4jNode( id="1", label="MyLabel", - properties={"id": "1"}, + properties={"id": "abc"}, embedding_properties={"vectorProperty": [1.0, 2.0, 3.0]}, ) end_node = Neo4jNode( id="2", label="MyLabel", - properties={"id": "2"}, + properties={"id": "def"}, embedding_properties=None, ) relationship = Neo4jRelationship( @@ -45,7 +45,7 @@ async def test_kg_writer(driver: neo4j.Driver) -> None: node_with_two_embeddings = Neo4jNode( id="3", label="MyLabel", - properties={"id": "3"}, + properties={"id": "ghi"}, embedding_properties={ "vectorProperty": [1.0, 2.0, 3.0], "otherVectorProperty": [10.0, 20.0, 30.0], @@ -61,7 +61,7 @@ async def test_kg_writer(driver: neo4j.Driver) -> None: assert res.status == "SUCCESS" query = """ - MATCH (a:MyLabel {id: '1'})-[r:MY_RELATIONSHIP]->(b:MyLabel {id: '2'}) + MATCH (a:MyLabel {id: 'abc'})-[r:MY_RELATIONSHIP]->(b:MyLabel {id: 'def'}) RETURN a, r, b """ record = driver.execute_query(query).records[0] @@ -87,9 +87,11 @@ async def test_kg_writer(driver: neo4j.Driver) -> None: rel = record["r"] assert rel.type == relationship.type + assert rel.start_node.get("id") == start_node.properties.get("id") + assert rel.end_node.get("id") == end_node.properties.get("id") query = """ - MATCH (c:MyLabel {id: '3'}) + MATCH (c:MyLabel {id: 'ghi'}) RETURN c """ records = driver.execute_query(query).records From f40cdd6c450841149aee2dbca546cbe6a9fa8690 Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 24 Jun 2025 10:01:40 +0200 Subject: [PATCH 6/9] Improve e2e test --- tests/e2e/experimental/test_kg_writer_component_e2e.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/e2e/experimental/test_kg_writer_component_e2e.py b/tests/e2e/experimental/test_kg_writer_component_e2e.py index 7ea02eae0..38aab3f1d 100644 --- a/tests/e2e/experimental/test_kg_writer_component_e2e.py +++ b/tests/e2e/experimental/test_kg_writer_component_e2e.py @@ -69,7 +69,7 @@ async def test_kg_writer(driver: neo4j.Driver) -> None: node_a = record["a"] assert start_node.label in list(node_a.labels) - assert start_node.id == str(node_a.get("id")) + assert start_node.properties.get("id") == str(node_a.get("id")) for key, val in start_node.properties.items(): assert key in node_a.keys() assert val == node_a.get(key) @@ -80,7 +80,7 @@ async def test_kg_writer(driver: neo4j.Driver) -> None: node_b = record["b"] assert end_node.label in list(node_b.labels) - assert end_node.id == str(node_b.get("id")) + assert end_node.properties.get("id") == str(node_b.get("id")) for key, val in end_node.properties.items(): assert key in node_b.keys() assert val == node_b.get(key) From a5587a7693301fa8013e4c8ba4b5f876f440a910 Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 24 Jun 2025 16:41:32 +0200 Subject: [PATCH 7/9] Rename property, clean db --- CHANGELOG.md | 2 +- .../experimental/components/kg_writer.py | 77 +++++----- src/neo4j_graphrag/neo4j_queries.py | 138 +++++++++++------- tests/e2e/docker-compose.yml | 3 +- .../experimental/components/test_kg_writer.py | 55 +++---- 5 files changed, 142 insertions(+), 133 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4be0b1bb7..0d88d8cf2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,7 +28,7 @@ #### Other -- The `id` property on `__KG_Builder__` nodes is removed. +- The reserved `id` property on `__KGBuilder__` nodes is removed. - The `chunk_index` property on `__Entity__` nodes is removed. Use the `FROM_CHUNK` relationship instead. ## 1.7.0 diff --git a/src/neo4j_graphrag/experimental/components/kg_writer.py b/src/neo4j_graphrag/experimental/components/kg_writer.py index 40040f21a..1f46e5ab3 100644 --- a/src/neo4j_graphrag/experimental/components/kg_writer.py +++ b/src/neo4j_graphrag/experimental/components/kg_writer.py @@ -29,10 +29,9 @@ ) from neo4j_graphrag.experimental.pipeline.component import Component, DataModel from neo4j_graphrag.neo4j_queries import ( - UPSERT_NODE_QUERY, - UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE, - UPSERT_RELATIONSHIP_QUERY, - UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE, + upsert_node_query, + upsert_relationship_query, + db_cleaning_query, ) from neo4j_graphrag.utils.version_utils import ( get_version, @@ -117,16 +116,19 @@ def __init__( driver: neo4j.Driver, neo4j_database: Optional[str] = None, batch_size: int = 1000, + clean_db: bool = True, ): self.driver = driver_config.override_user_agent(driver) self.neo4j_database = neo4j_database self.batch_size = batch_size + self._clean_db = clean_db version_tuple, _, _ = get_version(self.driver, self.neo4j_database) self.is_version_5_23_or_above = is_version_5_23_or_above(version_tuple) def _db_setup(self) -> None: - # not used for now - pass + self.driver.execute_query(""" + CREATE INDEX __entity__tmp_internal_id IF NOT EXISTS FOR (n:__KGBuilder__) ON (n.__tmp_internal_id) + """) @staticmethod def _nodes_to_rows( @@ -144,55 +146,38 @@ def _nodes_to_rows( def _upsert_nodes( self, nodes: list[Neo4jNode], lexical_graph_config: LexicalGraphConfig - ) -> dict[str, str]: - """Upserts a single node into the Neo4j database." + ) -> None: + """Upserts a batch of nodes into the Neo4j database. Args: nodes (list[Neo4jNode]): The nodes batch to upsert into the database. """ parameters = {"rows": self._nodes_to_rows(nodes, lexical_graph_config)} - query = ( - UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE - if self.is_version_5_23_or_above - else UPSERT_NODE_QUERY + query = upsert_node_query( + support_variable_scope_clause=self.is_version_5_23_or_above ) - records, _, _ = self.driver.execute_query( + self.driver.execute_query( query, parameters_=parameters, database_=self.neo4j_database, ) - return {r["_internal_id"]: r["element_id"] for r in records} + return None @staticmethod def _relationships_to_rows( - relationships: list[Neo4jRelationship], node_id_mapping: dict[str, str] + relationships: list[Neo4jRelationship], ) -> list[dict[str, Any]]: - return [ - { - **relationship.model_dump(), - "start_node_element_id": node_id_mapping.get( - relationship.start_node_id, "" - ), - "end_node_element_id": node_id_mapping.get( - relationship.end_node_id, "" - ), - } - for relationship in relationships - ] - - def _upsert_relationships( - self, rels: list[Neo4jRelationship], node_id_mapping: dict[str, str] - ) -> None: - """Upserts a single relationship into the Neo4j database. + return [relationship.model_dump() for relationship in relationships] + + def _upsert_relationships(self, rels: list[Neo4jRelationship]) -> None: + """Upserts a batch of relationships into the Neo4j database. Args: rels (list[Neo4jRelationship]): The relationships batch to upsert into the database. """ - parameters = {"rows": self._relationships_to_rows(rels, node_id_mapping)} - query = ( - UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE - if self.is_version_5_23_or_above - else UPSERT_RELATIONSHIP_QUERY + parameters = {"rows": self._relationships_to_rows(rels)} + query = upsert_relationship_query( + support_variable_scope_clause=self.is_version_5_23_or_above ) self.driver.execute_query( query, @@ -200,6 +185,14 @@ def _upsert_relationships( database_=self.neo4j_database, ) + def _db_cleaning(self) -> None: + query = db_cleaning_query( + support_variable_scope_clause=self.is_version_5_23_or_above, + batch_size=self.batch_size, + ) + with self.driver.session() as session: + session.run(query) + @validate_call async def run( self, @@ -215,14 +208,14 @@ async def run( try: self._db_setup() - node_id_mapping = {} - for batch in batched(graph.nodes, self.batch_size): - batch_mapping = self._upsert_nodes(batch, lexical_graph_config) - node_id_mapping.update(batch_mapping) + self._upsert_nodes(batch, lexical_graph_config) for batch in batched(graph.relationships, self.batch_size): - self._upsert_relationships(batch, node_id_mapping) + self._upsert_relationships(batch) + + if self._clean_db: + self._db_cleaning() return KGWriterModel( status="SUCCESS", diff --git a/src/neo4j_graphrag/neo4j_queries.py b/src/neo4j_graphrag/neo4j_queries.py index 4863462f3..9bf3663ca 100644 --- a/src/neo4j_graphrag/neo4j_queries.py +++ b/src/neo4j_graphrag/neo4j_queries.py @@ -52,61 +52,89 @@ "YIELD node, score" ) -UPSERT_NODE_QUERY = ( - "UNWIND $rows AS row " - "CREATE (n:__KGBuilder__) " - "SET n += row.properties " - "WITH n, row CALL apoc.create.addLabels(n, row.labels) YIELD node " - "WITH node as n, row CALL { " - "WITH n, row WITH n, row WHERE row.embedding_properties IS NOT NULL " - "UNWIND keys(row.embedding_properties) as emb " - "CALL db.create.setNodeVectorProperty(n, emb, row.embedding_properties[emb]) " - "RETURN count(*) as nbEmb " - "} " - "RETURN row.id as _internal_id, elementId(n) as element_id" -) -UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE = ( - "UNWIND $rows AS row " - "CREATE (n:__KGBuilder__) " - "SET n += row.properties " - "WITH n, row CALL apoc.create.addLabels(n, row.labels) YIELD node " - "WITH node as n, row CALL (n, row) { " - "WITH n, row WITH n, row WHERE row.embedding_properties IS NOT NULL " - "UNWIND keys(row.embedding_properties) as emb " - "CALL db.create.setNodeVectorProperty(n, emb, row.embedding_properties[emb]) " - "RETURN count(*) as nbEmb " - "} " - "RETURN row.id as _internal_id, elementId(n) as element_id" -) +def _call_subquery_syntax( + support_variable_scope_clause: bool, variable_list: list[str] +) -> str: + """A helper function to return the CALL subquery syntax: + - Either CALL { WITH + - or CALL (variables) { + """ + variables = ",".join(variable_list) + if support_variable_scope_clause: + return f"CALL ({variables}) {{ " + if variables: + return f"CALL {{ WITH {variables} " + return "CALL { " + + +def upsert_node_query(support_variable_scope_clause: bool) -> str: + """Build the Cypher query to upsert a batch of nodes: + - Create the new node + - Set its label(s) and properties + - Set its embedding properties if any + - Return the node elementId + """ + call_prefix = _call_subquery_syntax( + support_variable_scope_clause, variable_list=["n", "row"] + ) + return ( + "UNWIND $rows AS row " + "CREATE (n:__KGBuilder__ {__tmp_internal_id: row.id}) " + "SET n += row.properties " + "WITH n, row CALL apoc.create.addLabels(n, row.labels) YIELD node " + "WITH node as n, row " + f"{call_prefix} " + "WITH n, row WHERE row.embedding_properties IS NOT NULL " + "UNWIND keys(row.embedding_properties) as emb " + "CALL db.create.setNodeVectorProperty(n, emb, row.embedding_properties[emb]) " + "RETURN count(*) as nbEmb " + "} " + "RETURN elementId(n) as element_id" + ) -UPSERT_RELATIONSHIP_QUERY = ( - "UNWIND $rows as row " - "MATCH (start:__KGBuilder__), (end:__KGBuilder__) " - "WHERE elementId(start) = row.start_node_element_id AND elementId(end) = row.end_node_element_id " - "WITH start, end, row " - "CALL apoc.merge.relationship(start, row.type, {}, row.properties, end, row.properties) YIELD rel " - "WITH rel, row CALL { " - "WITH rel, row WITH rel, row WHERE row.embedding_properties IS NOT NULL " - "UNWIND keys(row.embedding_properties) as emb " - "CALL db.create.setRelationshipVectorProperty(rel, emb, row.embedding_properties[emb]) " - "} " - "RETURN elementId(rel)" -) -UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE = ( - "UNWIND $rows as row " - "MATCH (start:__KGBuilder__), (end:__KGBuilder__) " - "WHERE elementId(start) = row.start_node_element_id AND elementId(end) = row.end_node_element_id " - "WITH start, end, row " - "CALL apoc.merge.relationship(start, row.type, {}, row.properties, end, row.properties) YIELD rel " - "WITH rel, row CALL (rel, row) { " - "WITH rel, row WITH rel, row WHERE row.embedding_properties IS NOT NULL " - "UNWIND keys(row.embedding_properties) as emb " - "CALL db.create.setRelationshipVectorProperty(rel, emb, row.embedding_properties[emb]) " - "} " - "RETURN elementId(rel)" -) +def upsert_relationship_query(support_variable_scope_clause: bool) -> str: + """Build the Cypher query to upsert a batch of relationships: + - Create the new relationship: + only one relationship of a specific type is allowed between the same two nodes + - Set its properties + - Set its embedding properties if any + - Return the node elementId + """ + call_prefix = _call_subquery_syntax( + support_variable_scope_clause, variable_list=["rel", "row"] + ) + return ( + "UNWIND $rows as row " + "MATCH (start:__KGBuilder__ {__tmp_internal_id: row.start_node_id}), " + " (end:__KGBuilder__ {__tmp_internal_id: row.end_node_id}) " + "WITH start, end, row " + "CALL apoc.merge.relationship(start, row.type, {}, row.properties, end, row.properties) YIELD rel " + "WITH rel, row " + f"{call_prefix} " + "WITH rel, row WHERE row.embedding_properties IS NOT NULL " + "UNWIND keys(row.embedding_properties) as emb " + "CALL db.create.setRelationshipVectorProperty(rel, emb, row.embedding_properties[emb]) " + "} " + "RETURN elementId(rel)" + ) + + +def db_cleaning_query(support_variable_scope_clause: bool, batch_size: int) -> str: + """Removes the temporary __tmp_internal_id property from all nodes.""" + call_prefix = _call_subquery_syntax( + support_variable_scope_clause, variable_list=["n"] + ) + return ( + "MATCH (n:__KGBuilder__) " + "WHERE n.__tmp_internal_id IS NOT NULL " + f"{call_prefix} " + " SET n.__tmp_internal_id = NULL " + "} " + f"IN TRANSACTIONS OF {batch_size} ROWS" + ) + # Deprecated, remove along with upsert_vector UPSERT_VECTOR_ON_NODE_QUERY = ( @@ -150,13 +178,15 @@ def _get_hybrid_query(neo4j_version_is_5_23_or_above: bool) -> str: Construct a cypher query for hybrid search. Args: - neo4j_version_is_5_23_or_above (bool): Whether or not the Neo4j version is 5.23 or above; + neo4j_version_is_5_23_or_above (bool): Whether the Neo4j version is 5.23 or above; determines which call syntax is used. Returns: str: The constructed Cypher query string. """ - call_prefix = "CALL () { " if neo4j_version_is_5_23_or_above else "CALL { " + call_prefix = _call_subquery_syntax( + neo4j_version_is_5_23_or_above, variable_list=[] + ) query_body = ( f"{NODE_VECTOR_INDEX_QUERY} " "WITH collect({node:node, score:score}) AS nodes, max(score) AS vector_index_max_score " diff --git a/tests/e2e/docker-compose.yml b/tests/e2e/docker-compose.yml index 6ceda8af7..52c9f4db4 100644 --- a/tests/e2e/docker-compose.yml +++ b/tests/e2e/docker-compose.yml @@ -32,8 +32,9 @@ services: - 7474:7474 environment: NEO4J_AUTH: neo4j/password - NEO4J_ACCEPT_LICENSE_AGREEMENT: "eval" + NEO4J_ACCEPT_LICENSE_AGREEMENT: "yes" NEO4J_PLUGINS: "[\"apoc\"]" + NEO4J_server_memory_heap_max__size: 6G qdrant: image: qdrant/qdrant ports: diff --git a/tests/unit/experimental/components/test_kg_writer.py b/tests/unit/experimental/components/test_kg_writer.py index 768b1a543..24b8fe206 100644 --- a/tests/unit/experimental/components/test_kg_writer.py +++ b/tests/unit/experimental/components/test_kg_writer.py @@ -26,10 +26,8 @@ Neo4jRelationship, ) from neo4j_graphrag.neo4j_queries import ( - UPSERT_NODE_QUERY, - UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE, - UPSERT_RELATIONSHIP_QUERY, - UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE, + upsert_node_query, + upsert_relationship_query, ) @@ -57,18 +55,17 @@ def test_batched() -> None: ) def test_upsert_nodes(_: Mock, driver: MagicMock) -> None: driver.execute_query.return_value = ( - [{"_internal_id": "1", "element_id": "#1"}], + [{"element_id": "#1"}], None, None, ) neo4j_writer = Neo4jWriter(driver=driver) node = Neo4jNode(id="1", label="Label", properties={"key": "value"}) - result = neo4j_writer._upsert_nodes( + neo4j_writer._upsert_nodes( nodes=[node], lexical_graph_config=LexicalGraphConfig() ) - assert result == {"1": "#1"} driver.execute_query.assert_called_once_with( - UPSERT_NODE_QUERY, + upsert_node_query(False), parameters_={ "rows": [ { @@ -97,7 +94,7 @@ def test_upsert_nodes_with_embedding( driver: MagicMock, ) -> None: driver.execute_query.return_value = ( - [{"_internal_id": "1", "element_id": "#1"}], + [{"element_id": "#1"}], None, None, ) @@ -110,7 +107,7 @@ def test_upsert_nodes_with_embedding( ) neo4j_writer._upsert_nodes(nodes=[node], lexical_graph_config=LexicalGraphConfig()) driver.execute_query.assert_any_call( - UPSERT_NODE_QUERY, + upsert_node_query(False), parameters_={ "rows": [ { @@ -143,7 +140,7 @@ def test_upsert_relationship(_: Mock, driver: MagicMock) -> None: properties={"key": "value"}, ) neo4j_writer._upsert_relationships( - rels=[rel], node_id_mapping={"1": "#1", "2": "#2"} + rels=[rel], ) parameters = { "rows": [ @@ -153,13 +150,11 @@ def test_upsert_relationship(_: Mock, driver: MagicMock) -> None: "end_node_id": "2", "properties": {"key": "value"}, "embedding_properties": None, - "start_node_element_id": "#1", - "end_node_element_id": "#2", } ] } driver.execute_query.assert_called_once_with( - UPSERT_RELATIONSHIP_QUERY, + upsert_relationship_query(False), parameters_=parameters, database_=None, ) @@ -184,7 +179,7 @@ def test_upsert_relationship_with_embedding(_: Mock, driver: MagicMock) -> None: ) driver.execute_query.return_value.records = [{"elementId(r)": "rel_elem_id"}] neo4j_writer._upsert_relationships( - rels=[rel], node_id_mapping={"1": "#1", "2": "#2"} + rels=[rel], ) parameters = { "rows": [ @@ -194,13 +189,11 @@ def test_upsert_relationship_with_embedding(_: Mock, driver: MagicMock) -> None: "end_node_id": "2", "properties": {"key": "value"}, "embedding_properties": {"embeddingProp": [1.0, 2.0, 3.0]}, - "start_node_element_id": "#1", - "end_node_element_id": "#2", } ] } driver.execute_query.assert_any_call( - UPSERT_RELATIONSHIP_QUERY, + upsert_relationship_query(False), parameters_=parameters, database_=None, ) @@ -218,8 +211,8 @@ def test_upsert_relationship_with_embedding(_: Mock, driver: MagicMock) -> None: async def test_run(_: Mock, driver: MagicMock) -> None: driver.execute_query.return_value = ( [ - {"_internal_id": "1", "element_id": "#1"}, - {"_internal_id": "2", "element_id": "#2"}, + {"element_id": "#1"}, + {"element_id": "#2"}, ], None, None, @@ -230,7 +223,7 @@ async def test_run(_: Mock, driver: MagicMock) -> None: graph = Neo4jGraph(nodes=[node], relationships=[rel]) await neo4j_writer.run(graph=graph) driver.execute_query.assert_any_call( - UPSERT_NODE_QUERY, + upsert_node_query(False), parameters_={ "rows": [ { @@ -252,13 +245,11 @@ async def test_run(_: Mock, driver: MagicMock) -> None: "end_node_id": "2", "properties": {}, "embedding_properties": None, - "start_node_element_id": "#1", - "end_node_element_id": "#2", } ] } driver.execute_query.assert_any_call( - UPSERT_RELATIONSHIP_QUERY, + upsert_relationship_query(False), parameters_=parameters_, database_=None, ) @@ -289,10 +280,8 @@ async def test_run_is_version_below_5_23(_: Mock) -> None: graph = Neo4jGraph(nodes=[node], relationships=[rel]) await neo4j_writer.run(graph=graph) - print(driver.execute_query.call_args_list) - driver.execute_query.assert_any_call( - UPSERT_NODE_QUERY, + upsert_node_query(False), parameters_={ "rows": [ { @@ -314,13 +303,11 @@ async def test_run_is_version_below_5_23(_: Mock) -> None: "end_node_id": "2", "properties": {}, "embedding_properties": None, - "start_node_element_id": "#1", - "end_node_element_id": "", } ] } driver.execute_query.assert_any_call( - UPSERT_RELATIONSHIP_QUERY, + upsert_relationship_query(False), parameters_=parameters_, database_=None, ) @@ -338,7 +325,7 @@ async def test_run_is_version_5_23_or_above(_: Mock) -> None: # get_version ([{"versions": ["5.23.0"], "edition": "enterpise"}], None, None), # upsert nodes - ([{"_internal_id": "1", "element_id": "#1"}], None, None), + ([{"element_id": "#1"}], None, None), # upsert relationships (None, None, None), ) @@ -353,7 +340,7 @@ async def test_run_is_version_5_23_or_above(_: Mock) -> None: await neo4j_writer.run(graph=graph) driver.execute_query.assert_any_call( - UPSERT_NODE_QUERY_VARIABLE_SCOPE_CLAUSE, + upsert_node_query(True), parameters_={ "rows": [ { @@ -375,13 +362,11 @@ async def test_run_is_version_5_23_or_above(_: Mock) -> None: "end_node_id": "2", "properties": {}, "embedding_properties": None, - "start_node_element_id": "#1", - "end_node_element_id": "", } ] } driver.execute_query.assert_any_call( - UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE, + upsert_relationship_query(True), parameters_=parameters_, database_=None, ) From aa7435dcc895796cd214cc762b57fcb6d76705c3 Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 24 Jun 2025 16:59:33 +0200 Subject: [PATCH 8/9] Doc update --- CHANGELOG.md | 2 ++ docs/source/user_guide_kg_builder.rst | 5 +++++ src/neo4j_graphrag/experimental/components/types.py | 2 +- src/neo4j_graphrag/neo4j_queries.py | 3 ++- tests/e2e/experimental/test_kg_writer_component_e2e.py | 2 ++ tests/unit/experimental/components/test_kg_writer.py | 4 +--- 6 files changed, 13 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0d88d8cf2..c5eb72654 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,8 @@ - The reserved `id` property on `__KGBuilder__` nodes is removed. - The `chunk_index` property on `__Entity__` nodes is removed. Use the `FROM_CHUNK` relationship instead. +- The `__entity__id` index is not used anymore and can be dropped from the database (it has been replaced by `__entity__tmp_internal_id`). + ## 1.7.0 diff --git a/docs/source/user_guide_kg_builder.rst b/docs/source/user_guide_kg_builder.rst index 148216e50..c31255070 100644 --- a/docs/source/user_guide_kg_builder.rst +++ b/docs/source/user_guide_kg_builder.rst @@ -1080,6 +1080,11 @@ to a Neo4j database: Adjust the batch_size parameter of `Neo4jWriter` to optimize insert performance. This parameter controls the number of nodes or relationships inserted per batch, with a default value of 1000. +.. note:: Index + + In order to improve the ingestion performances, a index called `__entity__tmp_internal_id` is automatically added to the database. + + See :ref:`neo4jgraph`. diff --git a/src/neo4j_graphrag/experimental/components/types.py b/src/neo4j_graphrag/experimental/components/types.py index f761f8e6e..363767ef3 100644 --- a/src/neo4j_graphrag/experimental/components/types.py +++ b/src/neo4j_graphrag/experimental/components/types.py @@ -83,7 +83,7 @@ class Neo4jNode(BaseModel): """Represents a Neo4j node. Attributes: - id (str): The element ID of the node. This ID is used to refer to the node for relationship creation. + id (str): The ID of the node. This ID is used to refer to the node for relationship creation. label (str): The label of the node. properties (dict[str, Any]): A dictionary of properties attached to the node. embedding_properties (Optional[dict[str, list[float]]]): A list of embedding properties attached to the node. diff --git a/src/neo4j_graphrag/neo4j_queries.py b/src/neo4j_graphrag/neo4j_queries.py index 9bf3663ca..a0e9b419e 100644 --- a/src/neo4j_graphrag/neo4j_queries.py +++ b/src/neo4j_graphrag/neo4j_queries.py @@ -132,7 +132,8 @@ def db_cleaning_query(support_variable_scope_clause: bool, batch_size: int) -> s f"{call_prefix} " " SET n.__tmp_internal_id = NULL " "} " - f"IN TRANSACTIONS OF {batch_size} ROWS" + f"IN TRANSACTIONS OF {batch_size} ROWS " + "ON ERROR CONTINUE" ) diff --git a/tests/e2e/experimental/test_kg_writer_component_e2e.py b/tests/e2e/experimental/test_kg_writer_component_e2e.py index 38aab3f1d..ccba2ae69 100644 --- a/tests/e2e/experimental/test_kg_writer_component_e2e.py +++ b/tests/e2e/experimental/test_kg_writer_component_e2e.py @@ -87,6 +87,8 @@ async def test_kg_writer(driver: neo4j.Driver) -> None: rel = record["r"] assert rel.type == relationship.type + assert relationship.start_node_id == rel.start_node.get("id") + assert relationship.end_node_id == rel.end_node.get("id") assert rel.start_node.get("id") == start_node.properties.get("id") assert rel.end_node.get("id") == end_node.properties.get("id") diff --git a/tests/unit/experimental/components/test_kg_writer.py b/tests/unit/experimental/components/test_kg_writer.py index 24b8fe206..5473078af 100644 --- a/tests/unit/experimental/components/test_kg_writer.py +++ b/tests/unit/experimental/components/test_kg_writer.py @@ -61,9 +61,7 @@ def test_upsert_nodes(_: Mock, driver: MagicMock) -> None: ) neo4j_writer = Neo4jWriter(driver=driver) node = Neo4jNode(id="1", label="Label", properties={"key": "value"}) - neo4j_writer._upsert_nodes( - nodes=[node], lexical_graph_config=LexicalGraphConfig() - ) + neo4j_writer._upsert_nodes(nodes=[node], lexical_graph_config=LexicalGraphConfig()) driver.execute_query.assert_called_once_with( upsert_node_query(False), parameters_={ From e19c4349291baec1e5770961a6aa08842e064eb1 Mon Sep 17 00:00:00 2001 From: estelle Date: Thu, 26 Jun 2025 11:24:37 +0200 Subject: [PATCH 9/9] Review comments + e2e test --- docs/source/user_guide_kg_builder.rst | 6 +++--- tests/e2e/experimental/test_kg_writer_component_e2e.py | 2 -- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/docs/source/user_guide_kg_builder.rst b/docs/source/user_guide_kg_builder.rst index c31255070..366f285c9 100644 --- a/docs/source/user_guide_kg_builder.rst +++ b/docs/source/user_guide_kg_builder.rst @@ -1048,10 +1048,10 @@ _________________ In addition to the user-defined configuration options described above, the `GraphPruning` component performs the following cleanup operations: -- Nodes with invalid label or ID are pruned. +- Nodes with empty label or ID are pruned. - Nodes with missing required properties are pruned. - Nodes with no remaining properties are pruned. -- Relationships with invalid type are pruned. +- Relationships with empty type are pruned. - Relationships with invalid source or target nodes (i.e., nodes no longer present in the graph) are pruned. - Relationships with incorrect direction have their direction corrected. @@ -1082,7 +1082,7 @@ This parameter controls the number of nodes or relationships inserted per batch, .. note:: Index - In order to improve the ingestion performances, a index called `__entity__tmp_internal_id` is automatically added to the database. + In order to improve the ingestion performances, an index called `__entity__tmp_internal_id` is automatically added to the database. See :ref:`neo4jgraph`. diff --git a/tests/e2e/experimental/test_kg_writer_component_e2e.py b/tests/e2e/experimental/test_kg_writer_component_e2e.py index ccba2ae69..38aab3f1d 100644 --- a/tests/e2e/experimental/test_kg_writer_component_e2e.py +++ b/tests/e2e/experimental/test_kg_writer_component_e2e.py @@ -87,8 +87,6 @@ async def test_kg_writer(driver: neo4j.Driver) -> None: rel = record["r"] assert rel.type == relationship.type - assert relationship.start_node_id == rel.start_node.get("id") - assert relationship.end_node_id == rel.end_node.get("id") assert rel.start_node.get("id") == start_node.properties.get("id") assert rel.end_node.get("id") == end_node.properties.get("id")