From 11423d90dcd01620a94cec80f76f40400d609256 Mon Sep 17 00:00:00 2001 From: estelle Date: Fri, 23 May 2025 12:57:01 +0200 Subject: [PATCH 01/25] Update schema definition --- .../experimental/components/schema.py | 53 ++++++++++++---- .../experimental/components/test_schema.py | 62 ++++++++++++++----- 2 files changed, 88 insertions(+), 27 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index b2007ccb4..6937e2b79 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -23,10 +23,10 @@ from pydantic import ( BaseModel, PrivateAttr, - ValidationError, model_validator, validate_call, ConfigDict, + ValidationError, ) from typing_extensions import Self @@ -67,6 +67,7 @@ class PropertyType(BaseModel): "ZONED_TIME", ] description: str = "" + required: bool = False model_config = ConfigDict( frozen=True, @@ -81,6 +82,7 @@ class NodeType(BaseModel): label: str description: str = "" properties: list[PropertyType] = [] + additional_properties: bool = True @model_validator(mode="before") @classmethod @@ -89,6 +91,16 @@ def validate_input_if_string(cls, data: EntityInputType) -> EntityInputType: return {"label": data} return data + @model_validator(mode="after") + def validate_additional_properties(self) -> Self: + if len(self.properties) == 0 and not self.additional_properties: + warnings.warn( + "Using `additional_properties=False` with no defined " + "properties will cause the model to be pruned during graph cleaning.", + UserWarning, + ) + return self + class RelationshipType(BaseModel): """ @@ -98,6 +110,7 @@ class RelationshipType(BaseModel): label: str description: str = "" properties: list[PropertyType] = [] + additional_properties: bool = True @model_validator(mode="before") @classmethod @@ -106,11 +119,25 @@ def validate_input_if_string(cls, data: RelationInputType) -> RelationInputType: return {"label": data} return data + @model_validator(mode="after") + def validate_additional_properties(self) -> Self: + if len(self.properties) == 0 and not self.additional_properties: + warnings.warn( + "Using `additional_properties=False` with no defined " + "properties will cause the model to be pruned during graph cleaning.", + UserWarning, + ) + return self + class GraphSchema(DataModel): node_types: Tuple[NodeType, ...] - relationship_types: Optional[Tuple[RelationshipType, ...]] = None - patterns: Optional[Tuple[Tuple[str, str, str], ...]] = None + relationship_types: Tuple[RelationshipType, ...] = tuple() + patterns: Tuple[Tuple[str, str, str], ...] = tuple() + + additional_node_types: bool = True + additional_relationship_types: bool = True + additional_patterns: bool = True _node_type_index: dict[str, NodeType] = PrivateAttr() _relationship_type_index: dict[str, RelationshipType] = PrivateAttr() @@ -128,26 +155,26 @@ def check_schema(self) -> Self: else {} ) - relationship_types = self.relationship_types or tuple() - patterns = self.patterns or tuple() + relationship_types = self.relationship_types + patterns = self.patterns if patterns: if not relationship_types: raise SchemaValidationError( - "Relations must also be provided when using a potential schema." + "Relationship types must also be provided when using patterns." ) for entity1, relation, entity2 in patterns: if entity1 not in self._node_type_index: raise SchemaValidationError( - f"Entity '{entity1}' is not defined in the provided entities." + f"Node type '{entity1}' is not defined in the provided node_types." ) if relation not in self._relationship_type_index: raise SchemaValidationError( - f"Relation '{relation}' is not defined in the provided relations." + f"Relationship type '{relation}' is not defined in the provided relationship_types." ) if entity2 not in self._node_type_index: - raise SchemaValidationError( - f"Entity '{entity2}' is not defined in the provided entities." + raise ValidationError( + f"Node type '{entity2}' is not defined in the provided node_types." ) return self @@ -303,12 +330,12 @@ def create_schema_model( return GraphSchema.model_validate( dict( node_types=node_types, - relationship_types=relationship_types, - patterns=patterns, + relationship_types=relationship_types or (), + patterns=patterns or (), ) ) except (ValidationError, SchemaValidationError) as e: - raise SchemaValidationError(e) + raise SchemaValidationError(e) from e @validate_call async def run( diff --git a/tests/unit/experimental/components/test_schema.py b/tests/unit/experimental/components/test_schema.py index 1aec51261..67edb2393 100644 --- a/tests/unit/experimental/components/test_schema.py +++ b/tests/unit/experimental/components/test_schema.py @@ -16,7 +16,7 @@ import json from typing import Tuple -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, patch import pytest @@ -38,6 +38,24 @@ from neo4j_graphrag.utils.file_handler import FileFormat +def test_node_type_raise_warning_if_misconfigured() -> None: + with pytest.warns(UserWarning): + NodeType( + label="test", + properties=[], + additional_properties=False, + ) + + +def test_relationship_type_raise_warning_if_misconfigured() -> None: + with pytest.warns(UserWarning): + RelationshipType( + label="test", + properties=[], + additional_properties=False, + ) + + @pytest.fixture def valid_node_types() -> tuple[NodeType, ...]: return ( @@ -46,8 +64,9 @@ def valid_node_types() -> tuple[NodeType, ...]: description="An individual human being.", properties=[ PropertyType(name="birth date", type="ZONED_DATETIME"), - PropertyType(name="name", type="STRING"), + PropertyType(name="name", type="STRING", required=True), ], + additional_properties=False, ), NodeType( label="ORGANIZATION", @@ -64,9 +83,10 @@ def valid_relationship_types() -> tuple[RelationshipType, ...]: label="EMPLOYED_BY", description="Indicates employment relationship.", properties=[ - PropertyType(name="start_time", type="LOCAL_DATETIME"), + PropertyType(name="start_time", type="LOCAL_DATETIME", required=True), PropertyType(name="end_time", type="LOCAL_DATETIME"), ], + additional_properties=False, ), RelationshipType( label="ORGANIZED_BY", @@ -122,13 +142,16 @@ def test_create_schema_model_valid_data( valid_relationship_types: Tuple[RelationshipType, ...], valid_patterns: Tuple[Tuple[str, str, str], ...], ) -> None: - schema_instance = schema_builder.create_schema_model( + schema = schema_builder.create_schema_model( list(valid_node_types), list(valid_relationship_types), list(valid_patterns) ) - assert schema_instance.node_types == valid_node_types - assert schema_instance.relationship_types == valid_relationship_types - assert schema_instance.patterns == valid_patterns + assert schema.node_types == valid_node_types + assert schema.relationship_types == valid_relationship_types + assert schema.patterns == valid_patterns + assert schema.additional_node_types is True + assert schema.additional_relationship_types is True + assert schema.additional_patterns is True @pytest.mark.asyncio @@ -138,13 +161,21 @@ async def test_run_method( valid_relationship_types: Tuple[RelationshipType, ...], valid_patterns: Tuple[Tuple[str, str, str], ...], ) -> None: - schema = await schema_builder.run( - list(valid_node_types), list(valid_relationship_types), list(valid_patterns) - ) + with patch.object( + schema_builder, + "create_schema_model", + return_value=GraphSchema(node_types=valid_node_types, relationship_types=valid_relationship_types, patterns=valid_patterns), + ): + schema = await schema_builder.run( + list(valid_node_types), list(valid_relationship_types), list(valid_patterns) + ) assert schema.node_types == valid_node_types assert schema.relationship_types == valid_relationship_types assert schema.patterns == valid_patterns + assert schema.additional_node_types is True + assert schema.additional_relationship_types is True + assert schema.additional_patterns is True def test_create_schema_model_invalid_entity( @@ -159,7 +190,7 @@ def test_create_schema_model_invalid_entity( list(valid_relationship_types), list(patterns_with_invalid_entity), ) - assert "Entity 'NON_EXISTENT_ENTITY' is not defined" in str( + assert "Node type 'NON_EXISTENT_ENTITY' is not defined" in str( exc_info.value ), "Should fail due to non-existent entity" @@ -176,7 +207,7 @@ def test_create_schema_model_invalid_relation( list(valid_relationship_types), list(patterns_with_invalid_relation), ) - assert "Relation 'NON_EXISTENT_RELATION' is not defined" in str( + assert "Relationship type 'NON_EXISTENT_RELATION' is not defined" in str( exc_info.value ), "Should fail due to non-existent relation" @@ -191,7 +222,7 @@ def test_create_schema_model_no_potential_schema( ) assert schema_instance.node_types == valid_node_types assert schema_instance.relationship_types == valid_relationship_types - assert schema_instance.patterns is None + assert schema_instance.patterns == () def test_create_schema_model_no_relations_or_potential_schema( @@ -206,14 +237,17 @@ def test_create_schema_model_no_relations_or_potential_schema( assert person is not None assert person.description == "An individual human being." assert len(person.properties) == 2 + assert person.additional_properties is False org = schema_instance.node_type_from_label("ORGANIZATION") assert org is not None assert org.description == "A structured group of people with a common purpose." + assert org.additional_properties is True age = schema_instance.node_type_from_label("AGE") assert age is not None assert age.description == "Age of a person in years." + assert age.additional_properties is True def test_create_schema_model_missing_relations( @@ -225,7 +259,7 @@ def test_create_schema_model_missing_relations( schema_builder.create_schema_model( node_types=valid_node_types, patterns=valid_patterns ) - assert "Relations must also be provided when using a potential schema." in str( + assert "Relationship types must also be provided when using patterns." in str( exc_info.value ), "Should fail due to missing relations" From ea91b6804c0bfa57952addd379b8e203a2360dd8 Mon Sep 17 00:00:00 2001 From: estelle Date: Fri, 23 May 2025 14:15:39 +0200 Subject: [PATCH 02/25] Add graph pruning component and tests (WIP) --- .../experimental/components/graph_pruning.py | 220 ++++++++++++++++++ .../components/test_graph_pruning.py | 155 ++++++++++++ 2 files changed, 375 insertions(+) create mode 100644 src/neo4j_graphrag/experimental/components/graph_pruning.py create mode 100644 tests/unit/experimental/components/test_graph_pruning.py diff --git a/src/neo4j_graphrag/experimental/components/graph_pruning.py b/src/neo4j_graphrag/experimental/components/graph_pruning.py new file mode 100644 index 000000000..9f6ac1c10 --- /dev/null +++ b/src/neo4j_graphrag/experimental/components/graph_pruning.py @@ -0,0 +1,220 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import Optional, Any + +from neo4j_graphrag.experimental.components.schema import ( + GraphSchema, + PropertyType, + NodeType, +) +from neo4j_graphrag.experimental.components.types import ( + Neo4jGraph, + Neo4jNode, + Neo4jRelationship, +) +from neo4j_graphrag.experimental.pipeline import Component + +logger = logging.getLogger(__name__) + + +class GraphPruning(Component): + async def run( + self, + graph: Neo4jGraph, + schema: Optional[GraphSchema] = None, + ) -> Neo4jGraph: + if schema is None: + return graph + return self._clean_graph(graph, schema) + + def _clean_graph( + self, + graph: Neo4jGraph, + schema: GraphSchema, + ) -> Neo4jGraph: + """ + Verify that the graph conforms to the provided schema. + + Remove invalid entities,relationships, and properties. + If an entity is removed, all of its relationships are also removed. + If no valid properties remain for an entity, remove that entity. + """ + # enforce nodes (remove invalid labels, strip invalid properties) + filtered_nodes = self._enforce_nodes(graph.nodes, schema) + if not filtered_nodes: + logger.warning( + "PRUNING: all nodes were pruned, resulting graph will be empty. Check logs for details." + ) + return Neo4jGraph() + + # enforce relationships (remove those referencing invalid nodes or with invalid + # types or with start/end nodes not conforming to the schema, and strip invalid + # properties) + filtered_rels = self._enforce_relationships( + graph.relationships, filtered_nodes, schema + ) + + return Neo4jGraph(nodes=filtered_nodes, relationships=filtered_rels) + + def _validate_node( + self, + node: Neo4jNode, + schema_entity: Optional[NodeType] = None, + additional_node_types: bool = True, + ) -> Optional[Neo4jNode]: + if not schema_entity: + # node type not declared in the schema + if additional_node_types: + # keep node as it is as we do not have any additional info + return node + # it's not in schema + return None + allowed_props = schema_entity.properties + filtered_props = self._enforce_properties( + node.properties, allowed_props, schema_entity.additional_properties + ) + if not filtered_props: + return None + return Neo4jNode( + id=node.id, + label=node.label, + properties=filtered_props, + embedding_properties=node.embedding_properties, + ) + + def _enforce_nodes( + self, extracted_nodes: list[Neo4jNode], schema: GraphSchema + ) -> list[Neo4jNode]: + """ + Filter extracted nodes to be conformant to the schema. + + Keep only those whose label is in schema + (unless schema has additional_node_types=True, default value) + For each valid node, validate properties. If a node is left without + properties, prune it. + """ + valid_nodes = [] + for node in extracted_nodes: + schema_entity = schema.node_type_from_label(node.label) + new_node = self._validate_node( + node, + schema_entity, + additional_node_types=schema.additional_node_types, + ) + if new_node: + valid_nodes.append(new_node) + return valid_nodes + + def _enforce_relationships( + self, + extracted_relationships: list[Neo4jRelationship], + filtered_nodes: list[Neo4jNode], + schema: GraphSchema, + ) -> list[Neo4jRelationship]: + """ + Filter extracted nodes to be conformant to the schema. + + Keep only those whose types are in schema, start/end node conform to schema, + and start/end nodes are in filtered nodes (i.e., kept after node enforcement). + For each valid relationship, filter out properties not present in the schema. + If a relationship direct is incorrect, invert it. + """ + + valid_rels = [] + valid_nodes = {node.id: node.label for node in filtered_nodes} + + patterns = schema.patterns + + for rel in extracted_relationships: + schema_relation = schema.relationship_type_from_label(rel.type) + if schema_relation is None: + if schema.additional_relationship_types: + valid_rels.append(rel) + else: + logger.debug(f"PRUNING:: {rel} as {rel.type} is not in the schema") + continue + + if ( + rel.start_node_id not in valid_nodes + or rel.end_node_id not in valid_nodes + ): + logger.debug( + f"PRUNING:: {rel} as one of {rel.start_node_id} or {rel.end_node_id} is not in the graph" + ) + continue + + start_label = valid_nodes[rel.start_node_id] + end_label = valid_nodes[rel.end_node_id] + + tuple_valid = True + if patterns: + tuple_valid = (start_label, rel.type, end_label) in patterns + reverse_tuple_valid = ( + end_label, + rel.type, + start_label, + ) in patterns + + if ( + not tuple_valid + and not reverse_tuple_valid + and not schema.additional_patterns + ): + logger.debug(f"PRUNING:: {rel} not in the allowed patterns") + continue + + allowed_props = schema_relation.properties + filtered_props = self._enforce_properties( + rel.properties, allowed_props, schema_relation.additional_properties + ) + + valid_rels.append( + Neo4jRelationship( + start_node_id=rel.start_node_id if tuple_valid else rel.end_node_id, + end_node_id=rel.end_node_id if tuple_valid else rel.start_node_id, + type=rel.type, + properties=filtered_props, + embedding_properties=rel.embedding_properties, + ) + ) + + return valid_rels + + def _enforce_properties( + self, + properties: dict[str, Any], + valid_properties: list[PropertyType], + additional_properties: bool, + ) -> dict[str, Any]: + """ + Filter properties. + - Keep only those that exist in schema (i.e., valid properties). + - Check that all required properties are present + """ + valid_prop_names = {prop.name for prop in valid_properties} + filtered_properties = { + key: value + for key, value in properties.items() + if key in valid_prop_names or additional_properties + } + required_prop_names = {prop.name for prop in valid_properties if prop.required} + for req_prop in required_prop_names: + if filtered_properties.get(req_prop) is None: + logger.info( + f"PRUNING:: {req_prop} is required but missing in {properties} - skipping node" + ) + return {} # node will be pruned + return filtered_properties diff --git a/tests/unit/experimental/components/test_graph_pruning.py b/tests/unit/experimental/components/test_graph_pruning.py new file mode 100644 index 000000000..f11f97737 --- /dev/null +++ b/tests/unit/experimental/components/test_graph_pruning.py @@ -0,0 +1,155 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any + +import pytest + +from neo4j_graphrag.experimental.components.graph_pruning import GraphPruning +from neo4j_graphrag.experimental.components.schema import NodeType, PropertyType +from neo4j_graphrag.experimental.components.types import Neo4jNode + + +@pytest.mark.parametrize( + "properties, valid_properties, additional_properties, expected_filtered_properties", + [ + ( + # no required, additional allowed + { + "name": "John Does", + "age": 25, + }, + [ + PropertyType( + name="name", + type="STRING", + ) + ], + True, + { + "name": "John Does", + "age": 25, + }, + ), + ( + # no required, additional not allowed + { + "name": "John Does", + "age": 25, + }, + [ + PropertyType( + name="name", + type="STRING", + ) + ], + False, + { + "name": "John Does", + }, + ), + ( + # required missing + { + "age": 25, + }, + [ + PropertyType( + name="name", + type="STRING", + required=True, + ) + ], + True, + {}, + ), + ], +) +def test_graph_pruning_enforce_properties( + properties: dict[str, Any], + valid_properties: list[PropertyType], + additional_properties: bool, + expected_filtered_properties: dict[str, Any], +) -> None: + prunner = GraphPruning() + filtered_properties = prunner._enforce_properties( + properties, valid_properties, additional_properties=additional_properties + ) + assert filtered_properties == expected_filtered_properties + + +@pytest.fixture(scope="module") +def node_type_no_properties() -> NodeType: + return NodeType(label="Person") + + +@pytest.fixture(scope="module") +def node_type_required_name() -> NodeType: + return NodeType( + label="Person", + properties=[ + PropertyType(name="name", type="STRING", required=True), + PropertyType(name="age", type="INTEGER"), + ], + ) + + +@pytest.mark.parametrize( + "node, entity, additional_node_types, expected_node", + [ + # all good, with default values + ( + Neo4jNode(id="1", label="Person", properties={"name": "John Doe"}), + "node_type_no_properties", + True, + Neo4jNode(id="1", label="Person", properties={"name": "John Doe"}), + ), + # properties empty (missing default) + ( + Neo4jNode(id="1", label="Person", properties={"age": 45}), + "node_type_required_name", + True, + None, + ), + # node label not is schema, additional not allowed + ( + Neo4jNode(id="1", label="Location", properties={"name": "New York"}), + None, + False, + None, + ), + # node label not is schema, additional allowed + ( + Neo4jNode(id="1", label="Location", properties={"name": "New York"}), + None, + True, + Neo4jNode(id="1", label="Location", properties={"name": "New York"}), + ), + ], +) +def test_graph_pruning_validate_node( + node: Neo4jNode, + entity: str, + additional_node_types: bool, + expected_node: Neo4jNode, + request: pytest.FixtureRequest, +) -> None: + e = request.getfixturevalue(entity) if entity else None + + prunner = GraphPruning() + result = prunner._validate_node(node, e, additional_node_types) + if expected_node is not None: + assert result == expected_node + else: + assert result is None From a95b3c85a26387d88752074f85ac269674f3c556 Mon Sep 17 00:00:00 2001 From: estelle Date: Fri, 23 May 2025 14:15:47 +0200 Subject: [PATCH 03/25] Cleaning --- .../components/entity_relation_extractor.py | 184 +------- .../experimental/components/types.py | 6 - .../template_pipeline/simple_kg_builder.py | 3 - .../experimental/pipeline/kg_builder.py | 4 - .../test_entity_relation_extractor.py | 431 ------------------ .../experimental/components/test_schema.py | 6 +- .../experimental/pipeline/test_kg_builder.py | 14 - 7 files changed, 11 insertions(+), 637 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py index 32d65276e..b1ed9bbb1 100644 --- a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py +++ b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py @@ -18,23 +18,20 @@ import enum import json import logging -from typing import Any, List, Optional, Union, cast, Dict +from typing import Any, List, Optional, Union, cast import json_repair from pydantic import ValidationError, validate_call from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.experimental.components.lexical_graph import LexicalGraphBuilder -from neo4j_graphrag.experimental.components.schema import GraphSchema, PropertyType +from neo4j_graphrag.experimental.components.schema import GraphSchema from neo4j_graphrag.experimental.components.types import ( DocumentInfo, LexicalGraphConfig, Neo4jGraph, - Neo4jNode, - Neo4jRelationship, TextChunk, TextChunks, - SchemaEnforcementMode, ) from neo4j_graphrag.experimental.pipeline.component import Component from neo4j_graphrag.experimental.pipeline.exceptions import InvalidJSONError @@ -169,7 +166,6 @@ class LLMEntityRelationExtractor(EntityRelationExtractor): llm (LLMInterface): The language model to use for extraction. prompt_template (ERExtractionTemplate | str): A custom prompt template to use for extraction. create_lexical_graph (bool): Whether to include the text chunks in the graph in addition to the extracted entities and relations. Defaults to True. - enforce_schema (SchemaEnforcementMode): Whether to validate or not the extracted entities/rels against the provided schema. Defaults to None. on_error (OnError): What to do when an error occurs during extraction. Defaults to raising an error. max_concurrency (int): The maximum number of concurrent tasks which can be used to make requests to the LLM. @@ -194,13 +190,11 @@ def __init__( llm: LLMInterface, prompt_template: ERExtractionTemplate | str = ERExtractionTemplate(), create_lexical_graph: bool = True, - enforce_schema: SchemaEnforcementMode = SchemaEnforcementMode.NONE, on_error: OnError = OnError.RAISE, max_concurrency: int = 5, ) -> None: super().__init__(on_error=on_error, create_lexical_graph=create_lexical_graph) self.llm = llm # with response_format={ "type": "json_object" }, - self.enforce_schema = enforce_schema self.max_concurrency = max_concurrency if isinstance(prompt_template, str): template = PromptTemplate(prompt_template, expected_inputs=[]) @@ -284,13 +278,13 @@ async def run_for_chunk( """Run extraction, validation and post processing for a single chunk""" async with sem: chunk_graph = await self.extract_for_chunk(schema, examples, chunk) - final_chunk_graph = self.validate_chunk(chunk_graph, schema) + # final_chunk_graph = self.validate_chunk(chunk_graph, schema) await self.post_process_chunk( - final_chunk_graph, + chunk_graph, chunk, lexical_graph_builder, ) - return final_chunk_graph + return chunk_graph @validate_call async def run( @@ -328,7 +322,7 @@ async def run( elif lexical_graph_config: lexical_graph_builder = LexicalGraphBuilder(config=lexical_graph_config) schema = schema or GraphSchema( - node_types=(), relationship_types=None, patterns=None + node_types=(), ) examples = examples or "" sem = asyncio.Semaphore(self.max_concurrency) @@ -346,169 +340,3 @@ async def run( graph = self.combine_chunk_graphs(lexical_graph, chunk_graphs) logger.debug(f"Extracted graph: {prettify(graph)}") return graph - - def validate_chunk( - self, chunk_graph: Neo4jGraph, schema: GraphSchema - ) -> Neo4jGraph: - """ - Perform validation after entity and relation extraction: - - Enforce schema if schema enforcement mode is on and schema is provided - """ - if self.enforce_schema != SchemaEnforcementMode.NONE: - if not schema or not schema.node_types: # schema is not provided - logger.warning( - "Schema enforcement is ON but the guiding schema is not provided." - ) - else: - # if enforcing_schema is on and schema is provided, clean the graph - return self._clean_graph(chunk_graph, schema) - return chunk_graph - - def _clean_graph( - self, - graph: Neo4jGraph, - schema: GraphSchema, - ) -> Neo4jGraph: - """ - Verify that the graph conforms to the provided schema. - - Remove invalid entities,relationships, and properties. - If an entity is removed, all of its relationships are also removed. - If no valid properties remain for an entity, remove that entity. - """ - # enforce nodes (remove invalid labels, strip invalid properties) - filtered_nodes = self._enforce_nodes(graph.nodes, schema) - - # enforce relationships (remove those referencing invalid nodes or with invalid - # types or with start/end nodes not conforming to the schema, and strip invalid - # properties) - filtered_rels = self._enforce_relationships( - graph.relationships, filtered_nodes, schema - ) - - return Neo4jGraph(nodes=filtered_nodes, relationships=filtered_rels) - - def _enforce_nodes( - self, extracted_nodes: List[Neo4jNode], schema: GraphSchema - ) -> List[Neo4jNode]: - """ - Filter extracted nodes to be conformant to the schema. - - Keep only those whose label is in schema. - For each valid node, filter out properties not present in the schema. - Remove a node if it ends up with no valid properties. - """ - if self.enforce_schema != SchemaEnforcementMode.STRICT: - return extracted_nodes - - valid_nodes = [] - - for node in extracted_nodes: - schema_entity = schema.node_type_from_label(node.label) - if not schema_entity: - continue - allowed_props = schema_entity.properties or [] - if allowed_props: - filtered_props = self._enforce_properties( - node.properties, allowed_props - ) - else: - filtered_props = node.properties - if filtered_props: - valid_nodes.append( - Neo4jNode( - id=node.id, - label=node.label, - properties=filtered_props, - embedding_properties=node.embedding_properties, - ) - ) - - return valid_nodes - - def _enforce_relationships( - self, - extracted_relationships: List[Neo4jRelationship], - filtered_nodes: List[Neo4jNode], - schema: GraphSchema, - ) -> List[Neo4jRelationship]: - """ - Filter extracted nodes to be conformant to the schema. - - Keep only those whose types are in schema, start/end node conform to schema, - and start/end nodes are in filtered nodes (i.e., kept after node enforcement). - For each valid relationship, filter out properties not present in the schema. - If a relationship direct is incorrect, invert it. - """ - if self.enforce_schema != SchemaEnforcementMode.STRICT: - return extracted_relationships - - if schema.relationship_types is None: - return extracted_relationships - - valid_rels = [] - - valid_nodes = {node.id: node.label for node in filtered_nodes} - - patterns = schema.patterns - - for rel in extracted_relationships: - schema_relation = schema.relationship_type_from_label(rel.type) - if not schema_relation: - logger.debug(f"PRUNING:: {rel} as {rel.type} is not in the schema") - continue - - if ( - rel.start_node_id not in valid_nodes - or rel.end_node_id not in valid_nodes - ): - logger.debug( - f"PRUNING:: {rel} as one of {rel.start_node_id} or {rel.end_node_id} is not in the graph" - ) - continue - - start_label = valid_nodes[rel.start_node_id] - end_label = valid_nodes[rel.end_node_id] - - tuple_valid = True - if patterns: - tuple_valid = (start_label, rel.type, end_label) in patterns - reverse_tuple_valid = ( - end_label, - rel.type, - start_label, - ) in patterns - - if not tuple_valid and not reverse_tuple_valid: - logger.debug(f"PRUNING:: {rel} not in the potential schema") - continue - - allowed_props = schema_relation.properties or [] - if allowed_props: - filtered_props = self._enforce_properties(rel.properties, allowed_props) - else: - filtered_props = rel.properties - - valid_rels.append( - Neo4jRelationship( - start_node_id=rel.start_node_id if tuple_valid else rel.end_node_id, - end_node_id=rel.end_node_id if tuple_valid else rel.start_node_id, - type=rel.type, - properties=filtered_props, - embedding_properties=rel.embedding_properties, - ) - ) - - return valid_rels - - def _enforce_properties( - self, properties: Dict[str, Any], valid_properties: List[PropertyType] - ) -> Dict[str, Any]: - """ - Filter properties. - Keep only those that exist in schema (i.e., valid properties). - """ - valid_prop_names = {prop.name for prop in valid_properties} - return { - key: value for key, value in properties.items() if key in valid_prop_names - } diff --git a/src/neo4j_graphrag/experimental/components/types.py b/src/neo4j_graphrag/experimental/components/types.py index b1ed46569..689e6b6ce 100644 --- a/src/neo4j_graphrag/experimental/components/types.py +++ b/src/neo4j_graphrag/experimental/components/types.py @@ -15,7 +15,6 @@ from __future__ import annotations import uuid -from enum import Enum from typing import Any, Dict, Optional from pydantic import BaseModel, Field, field_validator @@ -171,8 +170,3 @@ def lexical_graph_node_labels(self) -> tuple[str, ...]: class GraphResult(DataModel): graph: Neo4jGraph config: LexicalGraphConfig - - -class SchemaEnforcementMode(str, Enum): - NONE = "NONE" - STRICT = "STRICT" diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py index 4aee855bb..75e5fc842 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py @@ -54,7 +54,6 @@ ) from neo4j_graphrag.experimental.components.types import ( LexicalGraphConfig, - SchemaEnforcementMode, ) from neo4j_graphrag.experimental.pipeline.config.object_config import ComponentType from neo4j_graphrag.experimental.pipeline.config.template_pipeline.base import ( @@ -92,7 +91,6 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig): relations: Sequence[RelationInputType] = [] potential_schema: Optional[list[tuple[str, str, str]]] = None schema_: Optional[GraphSchema] = Field(default=None, alias="schema") - enforce_schema: SchemaEnforcementMode = SchemaEnforcementMode.NONE on_error: OnError = OnError.IGNORE prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate() perform_entity_resolution: bool = True @@ -247,7 +245,6 @@ def _get_extractor(self) -> EntityRelationExtractor: return LLMEntityRelationExtractor( llm=self.get_default_llm(), prompt_template=self.prompt_template, - enforce_schema=self.enforce_schema, on_error=self.on_error, ) diff --git a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py index b0231e50f..ba81b042a 100644 --- a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py @@ -28,7 +28,6 @@ from neo4j_graphrag.experimental.components.text_splitters.base import TextSplitter from neo4j_graphrag.experimental.components.types import ( LexicalGraphConfig, - SchemaEnforcementMode, ) from neo4j_graphrag.experimental.pipeline.config.object_config import ComponentType from neo4j_graphrag.experimental.pipeline.config.runner import PipelineRunner @@ -71,7 +70,6 @@ class SimpleKGPipeline: - dict: following the RelationshipType schema, ie with label, description and properties keys potential_schema (Optional[List[tuple]]): DEPRECATED. A list of potential schema relationships. - enforce_schema (str): Validation of the extracted entities/rels against the provided schema. Defaults to "NONE", where schema enforcement will be ignored even if the schema is provided. Possible values "None" or "STRICT". from_pdf (bool): Determines whether to include the PdfLoader in the pipeline. If True, expects `file_path` input in `run` methods. If False, expects `text` input in `run` methods. @@ -93,7 +91,6 @@ def __init__( relations: Optional[Sequence[RelationInputType]] = None, potential_schema: Optional[List[tuple[str, str, str]]] = None, schema: Optional[Union[GraphSchema, dict[str, list[Any]]]] = None, - enforce_schema: str = "NONE", from_pdf: bool = True, text_splitter: Optional[TextSplitter] = None, pdf_loader: Optional[DataLoader] = None, @@ -114,7 +111,6 @@ def __init__( relations=relations or [], potential_schema=potential_schema, schema=schema, - enforce_schema=SchemaEnforcementMode(enforce_schema), from_pdf=from_pdf, pdf_loader=ComponentType(pdf_loader) if pdf_loader else None, kg_writer=ComponentType(kg_writer) if kg_writer else None, diff --git a/tests/unit/experimental/components/test_entity_relation_extractor.py b/tests/unit/experimental/components/test_entity_relation_extractor.py index 64ac0d42d..f76ab5c9c 100644 --- a/tests/unit/experimental/components/test_entity_relation_extractor.py +++ b/tests/unit/experimental/components/test_entity_relation_extractor.py @@ -25,13 +25,11 @@ balance_curly_braces, fix_invalid_json, ) -from neo4j_graphrag.experimental.components.schema import GraphSchema from neo4j_graphrag.experimental.components.types import ( DocumentInfo, Neo4jGraph, TextChunk, TextChunks, - SchemaEnforcementMode, ) from neo4j_graphrag.experimental.pipeline.exceptions import InvalidJSONError from neo4j_graphrag.llm import LLMInterface, LLMResponse @@ -231,435 +229,6 @@ async def test_extractor_custom_prompt() -> None: llm.ainvoke.assert_called_once_with("this is my prompt") -@pytest.mark.asyncio -async def test_extractor_no_schema_enforcement() -> None: - llm = MagicMock(spec=LLMInterface) - llm.ainvoke.return_value = LLMResponse( - content='{"nodes":[{"id":"0","label":"Alien","properties":{"foo":"bar"}}],' - '"relationships":[]}' - ) - - extractor = LLMEntityRelationExtractor( - llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.NONE - ) - - schema = GraphSchema.model_validate( - { - "node_types": [ - { - "label": "Person", - "properties": [{"name": "name", "type": "STRING"}], - } - ], - "relationship_types": [], - "patterns": [], - }, - ) - - chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) - - result: Neo4jGraph = await extractor.run(chunks=chunks, schema=schema) - - assert len(result.nodes) == 1 - assert result.nodes[0].label == "Alien" - assert result.nodes[0].properties == {"chunk_index": 0, "foo": "bar"} - - -@pytest.mark.asyncio -async def test_extractor_schema_enforcement_when_no_schema_provided() -> None: - llm = MagicMock(spec=LLMInterface) - llm.ainvoke.return_value = LLMResponse( - content='{"nodes":[{"id":"0","label":"Alien","properties":{"foo":"bar"}}],' - '"relationships":[]}' - ) - - extractor = LLMEntityRelationExtractor( - llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT - ) - - chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) - - result: Neo4jGraph = await extractor.run(chunks=chunks) - - assert len(result.nodes) == 1 - assert result.nodes[0].label == "Alien" - assert result.nodes[0].properties == {"chunk_index": 0, "foo": "bar"} - - -@pytest.mark.asyncio -async def test_extractor_schema_enforcement_invalid_nodes() -> None: - llm = MagicMock(spec=LLMInterface) - llm.ainvoke.return_value = LLMResponse( - content='{"nodes":[{"id":"0","label":"Alien","properties":{"foo":"bar"}},' - '{"id":"1","label":"Person","properties":{"name":"Alice"}}],' - '"relationships":[]}' - ) - - extractor = LLMEntityRelationExtractor( - llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT - ) - - schema = GraphSchema.model_validate( - { - "node_types": [ - { - "label": "Person", - "properties": [{"name": "name", "type": "STRING"}], - } - ], - "relationship_types": [], - "patterns": [], - }, - ) - - chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) - - result: Neo4jGraph = await extractor.run(chunks=chunks, schema=schema) - - assert len(result.nodes) == 1 - assert result.nodes[0].label == "Person" - assert result.nodes[0].properties == {"chunk_index": 0, "name": "Alice"} - - -@pytest.mark.asyncio -async def test_extraction_schema_enforcement_invalid_node_properties() -> None: - llm = MagicMock(spec=LLMInterface) - llm.ainvoke.return_value = LLMResponse( - content='{"nodes":[{"id":"1","label":"Person","properties":' - '{"name":"Alice","age":30,"foo":"bar"}}],' - '"relationships":[]}' - ) - - extractor = LLMEntityRelationExtractor( - llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT - ) - - schema = GraphSchema.model_validate( - { - "node_types": [ - { - "label": "Person", - "properties": [ - {"name": "name", "type": "STRING"}, - {"name": "age", "type": "STRING"}, - ], - } - ], - "relationship_types": [], - "patterns": [], - }, - ) - - chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) - - result: Neo4jGraph = await extractor.run(chunks, schema=schema) - - # "foo" is removed - assert len(result.nodes) == 1 - assert len(result.nodes[0].properties) == 3 - assert "foo" not in result.nodes[0].properties - - -@pytest.mark.asyncio -async def test_extractor_schema_enforcement_valid_nodes_with_empty_props() -> None: - llm = MagicMock(spec=LLMInterface) - llm.ainvoke.return_value = LLMResponse( - content='{"nodes":[{"id":"1","label":"Person","properties":{"foo":"bar"}}],' - '"relationships":[]}' - ) - - extractor = LLMEntityRelationExtractor( - llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT - ) - - schema = GraphSchema.model_validate( - { - "node_types": [ - { - "label": "Person", - } - ], - } - ) - chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) - - result: Neo4jGraph = await extractor.run(chunks, schema=schema) - - assert len(result.nodes) == 1 - - -@pytest.mark.asyncio -async def test_extractor_schema_enforcement_invalid_relations_wrong_types() -> None: - llm = MagicMock(spec=LLMInterface) - llm.ainvoke.return_value = LLMResponse( - content='{"nodes":[{"id":"1","label":"Person","properties":' - '{"name":"Alice"}},{"id":"2","label":"Person","properties":' - '{"name":"Bob"}}],' - '"relationships":[{"start_node_id":"1","end_node_id":"2",' - '"type":"FRIENDS_WITH","properties":{}}]}' - ) - - extractor = LLMEntityRelationExtractor( - llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT - ) - - schema = GraphSchema.model_validate( - { - "node_types": [ - { - "label": "Person", - "properties": [ - {"name": "name", "type": "STRING"}, - {"name": "age", "type": "STRING"}, - ], - } - ], - "relationship_types": [{"label": "LIKES"}], - "patterns": [], - } - ) - - chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) - - result: Neo4jGraph = await extractor.run(chunks, schema=schema) - - assert len(result.nodes) == 2 - assert len(result.relationships) == 0 - - -@pytest.mark.asyncio -async def test_extractor_schema_enforcement_invalid_relations_wrong_start_node() -> ( - None -): - llm = MagicMock(spec=LLMInterface) - llm.ainvoke.return_value = LLMResponse( - content='{"nodes":[{"id":"1","label":"Person","properties":{"name":"Alice"}},' - '{"id":"2","label":"Person","properties":{"name":"Bob"}}, ' - '{"id":"3","label":"City","properties":{"name":"London"}}],' - '"relationships":[{"start_node_id":"1","end_node_id":"2",' - '"type":"LIVES_IN","properties":{}}]}' - ) - - extractor = LLMEntityRelationExtractor( - llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT - ) - - schema = GraphSchema.model_validate( - { - "node_types": [ - { - "label": "Person", - "properties": [{"name": "name", "type": "STRING"}], - }, - { - "label": "City", - "properties": [{"name": "name", "type": "STRING"}], - }, - ], - "relationship_types": [{"label": "LIVES_IN"}], - "patterns": [("Person", "LIVES_IN", "City")], - } - ) - - chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) - - result: Neo4jGraph = await extractor.run(chunks, schema=schema) - - assert len(result.nodes) == 3 - assert len(result.relationships) == 0 - - -@pytest.mark.asyncio -async def test_extractor_schema_enforcement_invalid_relation_properties() -> None: - llm = MagicMock(spec=LLMInterface) - llm.ainvoke.return_value = LLMResponse( - content='{"nodes":[{"id":"1","label":"Person","properties":{"name":"Alice"}},' - '{"id":"2","label":"Person","properties":{"name":"Bob"}}],' - '"relationships":[{"start_node_id":"1","end_node_id":"2",' - '"type":"LIKES","properties":{"strength":"high","foo":"bar"}}]}' - ) - - extractor = LLMEntityRelationExtractor( - llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT - ) - - schema = GraphSchema.model_validate( - { - "node_types": [ - { - "label": "Person", - "properties": [{"name": "name", "type": "STRING"}], - } - ], - "relationship_types": [ - { - "label": "LIKES", - "properties": [{"name": "strength", "type": "STRING"}], - } - ], - "patterns": [], - } - ) - - chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) - - result: Neo4jGraph = await extractor.run(chunks, schema=schema) - - assert len(result.nodes) == 2 - assert len(result.relationships) == 1 - rel = result.relationships[0] - assert "foo" not in rel.properties - assert rel.properties["strength"] == "high" - - -@pytest.mark.asyncio -async def test_extractor_schema_enforcement_removed_relation_start_end_nodes() -> None: - llm = MagicMock(spec=LLMInterface) - llm.ainvoke.return_value = LLMResponse( - content='{"nodes":[{"id":"1","label":"Alien","properties":{}},' - '{"id":"2","label":"Robot","properties":{}}],' - '"relationships":[{"start_node_id":"1","end_node_id":"2",' - '"type":"LIKES","properties":{}}]}' - ) - - extractor = LLMEntityRelationExtractor( - llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT - ) - - schema = GraphSchema.model_validate( - { - "node_types": [ - { - "label": "Person", - "properties": [{"name": "name", "type": "STRING"}], - } - ], - "relationship_types": [{"label": "LIKES"}], - "patterns": [("Person", "LIKES", "Person")], - } - ) - - chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) - - result: Neo4jGraph = await extractor.run(chunks, schema=schema) - - assert len(result.nodes) == 0 - assert len(result.relationships) == 0 - - -@pytest.mark.asyncio -async def test_extractor_schema_enforcement_inverted_relation_direction() -> None: - llm = MagicMock(spec=LLMInterface) - llm.ainvoke.return_value = LLMResponse( - content='{"nodes":[{"id":"1","label":"Person","properties":{"name":"Alice"}},' - '{"id":"2","label":"City","properties":{"name":"London"}}],' - '"relationships":[{"start_node_id":"2","end_node_id":"1",' - '"type":"LIVES_IN","properties":{}}]}' - ) - - extractor = LLMEntityRelationExtractor( - llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT - ) - - schema = GraphSchema.model_validate( - { - "node_types": [ - { - "label": "Person", - "properties": [{"name": "name", "type": "STRING"}], - }, - { - "label": "City", - "properties": [{"name": "name", "type": "STRING"}], - }, - ], - "relationship_types": [{"label": "LIVES_IN"}], - "patterns": [("Person", "LIVES_IN", "City")], - } - ) - - chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) - - result: Neo4jGraph = await extractor.run(chunks, schema=schema) - - assert len(result.nodes) == 2 - assert len(result.relationships) == 1 - assert result.relationships[0].start_node_id.split(":")[1] == "1" - assert result.relationships[0].end_node_id.split(":")[1] == "2" - - -@pytest.mark.asyncio -async def test_extractor_schema_enforcement_none_relationships_in_schema() -> None: - llm = MagicMock(spec=LLMInterface) - llm.ainvoke.return_value = LLMResponse( - content='{"nodes":[{"id":"1","label":"Person","properties":' - '{"name":"Alice"}},{"id":"2","label":"Person","properties":' - '{"name":"Bob"}}],' - '"relationships":[{"start_node_id":"1","end_node_id":"2",' - '"type":"FRIENDS_WITH","properties":{}}]}' - ) - - extractor = LLMEntityRelationExtractor( - llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT - ) - - schema = GraphSchema.model_validate( - dict( - node_types=[ - { - "label": "Person", - "properties": [{"name": "name", "type": "STRING"}], - } - ], - relationship_types=None, - patterns=None, - ) - ) - - chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) - - result: Neo4jGraph = await extractor.run(chunks, schema=schema) - - assert len(result.nodes) == 2 - assert len(result.relationships) == 1 - assert result.relationships[0].type == "FRIENDS_WITH" - - -@pytest.mark.asyncio -async def test_extractor_schema_enforcement_empty_relationships_in_schema() -> None: - llm = MagicMock(spec=LLMInterface) - llm.ainvoke.return_value = LLMResponse( - content='{"nodes":[{"id":"1","label":"Person","properties":' - '{"name":"Alice"}},{"id":"2","label":"Person","properties":' - '{"name":"Bob"}}],' - '"relationships":[{"start_node_id":"1","end_node_id":"2",' - '"type":"FRIENDS_WITH","properties":{}}]}' - ) - - extractor = LLMEntityRelationExtractor( - llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT - ) - - schema = GraphSchema.model_validate( - dict( - node_types=[ - { - "label": "Person", - "properties": [{"name": "name", "type": "STRING"}], - } - ], - relationship_types=[], - patterns=None, - ) - ) - - chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) - - result: Neo4jGraph = await extractor.run(chunks, schema=schema) - - assert len(result.relationships) == 0 - - def test_fix_invalid_json_empty_result() -> None: json_string = "invalid json" diff --git a/tests/unit/experimental/components/test_schema.py b/tests/unit/experimental/components/test_schema.py index 67edb2393..be960f826 100644 --- a/tests/unit/experimental/components/test_schema.py +++ b/tests/unit/experimental/components/test_schema.py @@ -164,7 +164,11 @@ async def test_run_method( with patch.object( schema_builder, "create_schema_model", - return_value=GraphSchema(node_types=valid_node_types, relationship_types=valid_relationship_types, patterns=valid_patterns), + return_value=GraphSchema( + node_types=valid_node_types, + relationship_types=valid_relationship_types, + patterns=valid_patterns, + ), ): schema = await schema_builder.run( list(valid_node_types), list(valid_relationship_types), list(valid_patterns) diff --git a/tests/unit/experimental/pipeline/test_kg_builder.py b/tests/unit/experimental/pipeline/test_kg_builder.py index b4ece857a..13d789cb2 100644 --- a/tests/unit/experimental/pipeline/test_kg_builder.py +++ b/tests/unit/experimental/pipeline/test_kg_builder.py @@ -149,20 +149,6 @@ def test_simple_kg_pipeline_on_error_invalid_value() -> None: ) -def test_simple_kg_pipeline_enforce_schema_invalid_value() -> None: - llm = MagicMock(spec=LLMInterface) - driver = MagicMock(spec=neo4j.Driver) - embedder = MagicMock(spec=Embedder) - - with pytest.raises(PipelineDefinitionError): - SimpleKGPipeline( - llm=llm, - driver=driver, - embedder=embedder, - enforce_schema="INVALID_VALUE", - ) - - @mock.patch( "neo4j_graphrag.experimental.components.kg_writer.get_version", return_value=((5, 23, 0), False, False), From 869129fcdb3d36e07a287c7358cc71e53dea75d1 Mon Sep 17 00:00:00 2001 From: estelle Date: Fri, 23 May 2025 14:21:31 +0200 Subject: [PATCH 04/25] Add pruner to SimpleKGPipeline --- .../experimental/components/graph_pruning.py | 3 +++ .../template_pipeline/simple_kg_builder.py | 17 ++++++++++++++++- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/src/neo4j_graphrag/experimental/components/graph_pruning.py b/src/neo4j_graphrag/experimental/components/graph_pruning.py index 9f6ac1c10..8d861e60a 100644 --- a/src/neo4j_graphrag/experimental/components/graph_pruning.py +++ b/src/neo4j_graphrag/experimental/components/graph_pruning.py @@ -15,6 +15,8 @@ import logging from typing import Optional, Any +from pydantic import validate_call + from neo4j_graphrag.experimental.components.schema import ( GraphSchema, PropertyType, @@ -31,6 +33,7 @@ class GraphPruning(Component): + @validate_call async def run( self, graph: Neo4jGraph, diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py index 75e5fc842..34217433c 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py @@ -35,6 +35,7 @@ LLMEntityRelationExtractor, OnError, ) +from neo4j_graphrag.experimental.components.graph_pruning import GraphPruning from neo4j_graphrag.experimental.components.kg_writer import KGWriter, Neo4jWriter from neo4j_graphrag.experimental.components.pdf_loader import PdfLoader from neo4j_graphrag.experimental.components.resolver import ( @@ -78,6 +79,7 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig): "chunk_embedder", "schema", "extractor", + "pruner", "writer", "resolver", ] @@ -248,6 +250,9 @@ def _get_extractor(self) -> EntityRelationExtractor: on_error=self.on_error, ) + def _get_pruner(self) -> GraphPruning: + return GraphPruning() + def _get_writer(self) -> KGWriter: if self.kg_writer: return self.kg_writer.parse(self._global_data) # type: ignore @@ -329,9 +334,19 @@ def _get_connections(self) -> list[ConnectionDefinition]: connections.append( ConnectionDefinition( start="extractor", - end="writer", + end="pruner", input_config={ "graph": "extractor", + "schema": "schema", + }, + ) + ) + connections.append( + ConnectionDefinition( + start="pruner", + end="writer", + input_config={ + "graph": "pruner", }, ) ) From fd3cdceed012219948bbf7602ee8902d10bdd74c Mon Sep 17 00:00:00 2001 From: estelle Date: Fri, 23 May 2025 16:45:07 +0200 Subject: [PATCH 05/25] Add test for relationship enforcement --- .../experimental/components/graph_pruning.py | 113 ++++++------- .../components/test_graph_pruning.py | 150 +++++++++++++++++- .../test_simple_kg_builder.py | 15 +- 3 files changed, 217 insertions(+), 61 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/graph_pruning.py b/src/neo4j_graphrag/experimental/components/graph_pruning.py index 8d861e60a..b1c825666 100644 --- a/src/neo4j_graphrag/experimental/components/graph_pruning.py +++ b/src/neo4j_graphrag/experimental/components/graph_pruning.py @@ -21,6 +21,7 @@ GraphSchema, PropertyType, NodeType, + RelationshipType, ) from neo4j_graphrag.experimental.components.types import ( Neo4jGraph, @@ -121,6 +122,56 @@ def _enforce_nodes( valid_nodes.append(new_node) return valid_nodes + def _validate_relationship( + self, + rel: Neo4jRelationship, + valid_nodes: dict[str, str], + relationship_type: Optional[RelationshipType], + additional_relationship_types: bool, + patterns: tuple[tuple[str, str, str], ...], + additional_patterns: bool, + ) -> Optional[Neo4jRelationship]: + if relationship_type is None: + if additional_relationship_types: + return rel + else: + logger.debug(f"PRUNING:: {rel} as {rel.type} is not in the schema") + return None + + if rel.start_node_id not in valid_nodes or rel.end_node_id not in valid_nodes: + logger.debug( + f"PRUNING:: {rel} as one of {rel.start_node_id} or {rel.end_node_id} is not in the graph" + ) + return None + + start_label = valid_nodes[rel.start_node_id] + end_label = valid_nodes[rel.end_node_id] + tuple_valid = True + reverse_tuple_valid = False + if patterns: + tuple_valid = (start_label, rel.type, end_label) in patterns + reverse_tuple_valid = ( + end_label, + rel.type, + start_label, + ) in patterns + + if not tuple_valid and not reverse_tuple_valid and not additional_patterns: + logger.debug(f"PRUNING:: {rel} not in the allowed patterns") + return None + + allowed_props = relationship_type.properties + filtered_props = self._enforce_properties( + rel.properties, allowed_props, relationship_type.additional_properties + ) + return Neo4jRelationship( + start_node_id=rel.end_node_id if reverse_tuple_valid else rel.start_node_id, + end_node_id=rel.start_node_id if reverse_tuple_valid else rel.end_node_id, + type=rel.type, + properties=filtered_props, + embedding_properties=rel.embedding_properties, + ) + def _enforce_relationships( self, extracted_relationships: list[Neo4jRelationship], @@ -138,62 +189,18 @@ def _enforce_relationships( valid_rels = [] valid_nodes = {node.id: node.label for node in filtered_nodes} - - patterns = schema.patterns - for rel in extracted_relationships: schema_relation = schema.relationship_type_from_label(rel.type) - if schema_relation is None: - if schema.additional_relationship_types: - valid_rels.append(rel) - else: - logger.debug(f"PRUNING:: {rel} as {rel.type} is not in the schema") - continue - - if ( - rel.start_node_id not in valid_nodes - or rel.end_node_id not in valid_nodes - ): - logger.debug( - f"PRUNING:: {rel} as one of {rel.start_node_id} or {rel.end_node_id} is not in the graph" - ) - continue - - start_label = valid_nodes[rel.start_node_id] - end_label = valid_nodes[rel.end_node_id] - - tuple_valid = True - if patterns: - tuple_valid = (start_label, rel.type, end_label) in patterns - reverse_tuple_valid = ( - end_label, - rel.type, - start_label, - ) in patterns - - if ( - not tuple_valid - and not reverse_tuple_valid - and not schema.additional_patterns - ): - logger.debug(f"PRUNING:: {rel} not in the allowed patterns") - continue - - allowed_props = schema_relation.properties - filtered_props = self._enforce_properties( - rel.properties, allowed_props, schema_relation.additional_properties + new_rel = self._validate_relationship( + rel, + valid_nodes, + schema_relation, + schema.additional_relationship_types, + schema.patterns, + schema.additional_patterns, ) - - valid_rels.append( - Neo4jRelationship( - start_node_id=rel.start_node_id if tuple_valid else rel.end_node_id, - end_node_id=rel.end_node_id if tuple_valid else rel.start_node_id, - type=rel.type, - properties=filtered_props, - embedding_properties=rel.embedding_properties, - ) - ) - + if new_rel: + valid_rels.append(new_rel) return valid_rels def _enforce_properties( diff --git a/tests/unit/experimental/components/test_graph_pruning.py b/tests/unit/experimental/components/test_graph_pruning.py index f11f97737..c86653842 100644 --- a/tests/unit/experimental/components/test_graph_pruning.py +++ b/tests/unit/experimental/components/test_graph_pruning.py @@ -17,8 +17,12 @@ import pytest from neo4j_graphrag.experimental.components.graph_pruning import GraphPruning -from neo4j_graphrag.experimental.components.schema import NodeType, PropertyType -from neo4j_graphrag.experimental.components.types import Neo4jNode +from neo4j_graphrag.experimental.components.schema import ( + NodeType, + PropertyType, + RelationshipType, +) +from neo4j_graphrag.experimental.components.types import Neo4jNode, Neo4jRelationship @pytest.mark.parametrize( @@ -153,3 +157,145 @@ def test_graph_pruning_validate_node( assert result == expected_node else: assert result is None + + +@pytest.fixture +def neo4j_relationship() -> Neo4jRelationship: + return Neo4jRelationship( + start_node_id="1", + end_node_id="2", + type="REL", + properties={}, + ) + + +@pytest.fixture +def neo4j_reversed_relationship(neo4j_relationship: Neo4jRelationship) -> Neo4jRelationship: + return Neo4jRelationship( + start_node_id=neo4j_relationship.end_node_id, + end_node_id=neo4j_relationship.start_node_id, + type=neo4j_relationship.type, + properties=neo4j_relationship.properties, + ) + + +@pytest.mark.parametrize( + "relationship, valid_nodes, relationship_type, additional_relationship_types, patterns, additional_patterns, expected_relationship", + [ + # all good + ( + "neo4j_relationship", + { + "1": "Person", + "2": "Location", + }, + RelationshipType( + label="REL", + ), + True, + (("Person", "REL", "Location"),), + True, + "neo4j_relationship", + ), + # reverse relationship + ( + "neo4j_reversed_relationship", + { + "1": "Person", + "2": "Location", + }, + RelationshipType( + label="REL", + ), + True, + (("Person", "REL", "Location"),), + True, + "neo4j_relationship", + ), + # invalid type addition allowed + ( + "neo4j_relationship", + { + "1": "Person", + "2": "Location", + }, + None, + True, + (("Person", "REL", "Location"),), + True, + "neo4j_relationship", + ), + # invalid_type_addition_not_allowed + ( + "neo4j_relationship", + { + "1": "Person", + "2": "Location", + }, + None, + False, + (("Person", "REL", "Location"),), + True, + None, + ), + # invalid pattern, addition allowed + ( + "neo4j_relationship", + { + "1": "Person", + "2": "Location", + }, + RelationshipType( + label="REL", + ), + True, + (("Person", "REL", "Person"),), + True, + "neo4j_relationship", + ), + # invalid pattern, addition not allowed + ( + "neo4j_relationship", + { + "1": "Person", + "2": "Location", + }, + RelationshipType( + label="REL", + ), + True, + (("Person", "REL", "Person"),), + False, + None, + ), + ], +) +def test_graph_pruning_validate_relationship( + relationship: str, + valid_nodes: dict[str, str], + relationship_type: RelationshipType, + additional_relationship_types: bool, + patterns: tuple[tuple[str, str, str], ...], + additional_patterns: bool, + expected_relationship: str | None, + request: pytest.FixtureRequest, +) -> None: + relationship_obj = request.getfixturevalue(relationship) + expected_relationship_obj = ( + request.getfixturevalue(expected_relationship) + if expected_relationship + else None + ) + + pruner = GraphPruning() + assert ( + pruner._validate_relationship( + relationship_obj, + valid_nodes, + relationship_type, + additional_relationship_types, + patterns, + additional_patterns, + ) + == expected_relationship_obj + ) diff --git a/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py b/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py index 73bb50eeb..6560fda41 100644 --- a/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py +++ b/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py @@ -222,14 +222,15 @@ def test_simple_kg_pipeline_config_connections_from_pdf() -> None: perform_entity_resolution=False, ) connections = config._get_connections() - assert len(connections) == 6 + assert len(connections) == 7 expected_connections = [ ("pdf_loader", "splitter"), ("pdf_loader", "schema"), ("schema", "extractor"), ("splitter", "chunk_embedder"), ("chunk_embedder", "extractor"), - ("extractor", "writer"), + ("extractor", "pruner"), + ("pruner", "writer"), ] for actual, expected in zip(connections, expected_connections): assert (actual.start, actual.end) == expected @@ -241,12 +242,13 @@ def test_simple_kg_pipeline_config_connections_from_text() -> None: perform_entity_resolution=False, ) connections = config._get_connections() - assert len(connections) == 4 + assert len(connections) == 5 expected_connections = [ ("schema", "extractor"), ("splitter", "chunk_embedder"), ("chunk_embedder", "extractor"), - ("extractor", "writer"), + ("extractor", "pruner"), + ("pruner", "writer"), ] for actual, expected in zip(connections, expected_connections): assert (actual.start, actual.end) == expected @@ -258,14 +260,15 @@ def test_simple_kg_pipeline_config_connections_with_er() -> None: perform_entity_resolution=True, ) connections = config._get_connections() - assert len(connections) == 7 + assert len(connections) == 8 expected_connections = [ ("pdf_loader", "splitter"), ("pdf_loader", "schema"), ("schema", "extractor"), ("splitter", "chunk_embedder"), ("chunk_embedder", "extractor"), - ("extractor", "writer"), + ("extractor", "pruner"), + ("pruner", "writer"), ("writer", "resolver"), ] for actual, expected in zip(connections, expected_connections): From aeb3c8181c25a825af7137b8bc7771654d176aa4 Mon Sep 17 00:00:00 2001 From: estelle Date: Fri, 23 May 2025 16:51:17 +0200 Subject: [PATCH 06/25] Change return model to have some stats about pruned objects --- .../experimental/components/graph_pruning.py | 26 +++++++++++++++---- .../template_pipeline/simple_kg_builder.py | 2 +- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/graph_pruning.py b/src/neo4j_graphrag/experimental/components/graph_pruning.py index b1c825666..82f2d36d3 100644 --- a/src/neo4j_graphrag/experimental/components/graph_pruning.py +++ b/src/neo4j_graphrag/experimental/components/graph_pruning.py @@ -28,21 +28,37 @@ Neo4jNode, Neo4jRelationship, ) -from neo4j_graphrag.experimental.pipeline import Component +from neo4j_graphrag.experimental.pipeline import Component, DataModel logger = logging.getLogger(__name__) +class GraphPruningResult(DataModel): + graph: Neo4jGraph + metadata: dict[str, Any] = {} + + class GraphPruning(Component): @validate_call async def run( self, graph: Neo4jGraph, schema: Optional[GraphSchema] = None, - ) -> Neo4jGraph: - if schema is None: - return graph - return self._clean_graph(graph, schema) + ) -> GraphPruningResult: + if schema is not None: + new_graph = self._clean_graph(graph, schema) + else: + new_graph = graph + return GraphPruningResult( + graph=new_graph, + metadata={ + "stats": { + "pruned_node_count": len(graph.nodes) - len(new_graph.nodes), + "pruned_relationship_count": len(graph.relationships) + - len(new_graph.relationships), + } + }, + ) def _clean_graph( self, diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py index 34217433c..15823f679 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py @@ -346,7 +346,7 @@ def _get_connections(self) -> list[ConnectionDefinition]: start="pruner", end="writer", input_config={ - "graph": "pruner", + "graph": "pruner.graph", }, ) ) From 53b0cdf334da131764cba662eaff692d27adf005 Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 26 May 2025 11:07:42 +0200 Subject: [PATCH 07/25] We need to filter out relationships if start/end node is not valid in all cases (additional_relationship_types or not) --- .../experimental/components/graph_pruning.py | 45 +++--- .../components/test_graph_pruning.py | 131 +++++++++++++++--- 2 files changed, 139 insertions(+), 37 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/graph_pruning.py b/src/neo4j_graphrag/experimental/components/graph_pruning.py index 82f2d36d3..068c955d8 100644 --- a/src/neo4j_graphrag/experimental/components/graph_pruning.py +++ b/src/neo4j_graphrag/experimental/components/graph_pruning.py @@ -147,39 +147,52 @@ def _validate_relationship( patterns: tuple[tuple[str, str, str], ...], additional_patterns: bool, ) -> Optional[Neo4jRelationship]: - if relationship_type is None: - if additional_relationship_types: - return rel - else: - logger.debug(f"PRUNING:: {rel} as {rel.type} is not in the schema") - return None - + # validate start/end node IDs are valid nodes if rel.start_node_id not in valid_nodes or rel.end_node_id not in valid_nodes: logger.debug( - f"PRUNING:: {rel} as one of {rel.start_node_id} or {rel.end_node_id} is not in the graph" + f"PRUNING:: {rel} as one of {rel.start_node_id} or {rel.end_node_id} is not a valid node" ) return None + # validate relationship type + if relationship_type is None: + if not additional_relationship_types: + logger.debug( + f"PRUNING:: {rel} as {rel.type} is not in the schema and `additional_relationship_types` is False" + ) + return None + + # validate pattern start_label = valid_nodes[rel.start_node_id] end_label = valid_nodes[rel.end_node_id] tuple_valid = True reverse_tuple_valid = False if patterns: tuple_valid = (start_label, rel.type, end_label) in patterns + # try to reverse relationship only if initial order is not valid reverse_tuple_valid = ( - end_label, - rel.type, - start_label, - ) in patterns + not tuple_valid + and ( + end_label, + rel.type, + start_label, + ) + in patterns + ) if not tuple_valid and not reverse_tuple_valid and not additional_patterns: logger.debug(f"PRUNING:: {rel} not in the allowed patterns") return None - allowed_props = relationship_type.properties - filtered_props = self._enforce_properties( - rel.properties, allowed_props, relationship_type.additional_properties - ) + # filter properties if we can + if relationship_type is not None: + allowed_props = relationship_type.properties + filtered_props = self._enforce_properties( + rel.properties, allowed_props, relationship_type.additional_properties + ) + else: + filtered_props = rel.properties + return Neo4jRelationship( start_node_id=rel.end_node_id if reverse_tuple_valid else rel.start_node_id, end_node_id=rel.start_node_id if reverse_tuple_valid else rel.end_node_id, diff --git a/tests/unit/experimental/components/test_graph_pruning.py b/tests/unit/experimental/components/test_graph_pruning.py index c86653842..c1acb889b 100644 --- a/tests/unit/experimental/components/test_graph_pruning.py +++ b/tests/unit/experimental/components/test_graph_pruning.py @@ -13,16 +13,25 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any +from unittest.mock import patch, Mock import pytest -from neo4j_graphrag.experimental.components.graph_pruning import GraphPruning +from neo4j_graphrag.experimental.components.graph_pruning import ( + GraphPruning, + GraphPruningResult, +) from neo4j_graphrag.experimental.components.schema import ( NodeType, PropertyType, RelationshipType, + GraphSchema, +) +from neo4j_graphrag.experimental.components.types import ( + Neo4jNode, + Neo4jRelationship, + Neo4jGraph, ) -from neo4j_graphrag.experimental.components.types import Neo4jNode, Neo4jRelationship @pytest.mark.parametrize( @@ -170,7 +179,9 @@ def neo4j_relationship() -> Neo4jRelationship: @pytest.fixture -def neo4j_reversed_relationship(neo4j_relationship: Neo4jRelationship) -> Neo4jRelationship: +def neo4j_reversed_relationship( + neo4j_relationship: Neo4jRelationship, +) -> Neo4jRelationship: return Neo4jRelationship( start_node_id=neo4j_relationship.end_node_id, end_node_id=neo4j_relationship.start_node_id, @@ -184,7 +195,22 @@ def neo4j_reversed_relationship(neo4j_relationship: Neo4jRelationship) -> Neo4jR [ # all good ( - "neo4j_relationship", + "neo4j_relationship", # relationship, + { # valid_nodes + "1": "Person", + "2": "Location", + }, + RelationshipType( # relationship_type + label="REL", + ), + True, # additional_relationship_types + (("Person", "REL", "Location"),), # patterns + True, # additional_patterns + "neo4j_relationship", # expected_relationship + ), + # reverse relationship + ( + "neo4j_reversed_relationship", { "1": "Person", "2": "Location", @@ -192,25 +218,25 @@ def neo4j_reversed_relationship(neo4j_relationship: Neo4jRelationship) -> Neo4jR RelationshipType( label="REL", ), - True, + True, # additional_relationship_types (("Person", "REL", "Location"),), - True, + True, # additional_patterns "neo4j_relationship", ), - # reverse relationship + # invalid start node ID ( "neo4j_reversed_relationship", { - "1": "Person", + "10": "Person", "2": "Location", }, RelationshipType( label="REL", ), - True, + True, # additional_relationship_types (("Person", "REL", "Location"),), - True, - "neo4j_relationship", + True, # additional_patterns + None, ), # invalid type addition allowed ( @@ -219,11 +245,23 @@ def neo4j_reversed_relationship(neo4j_relationship: Neo4jRelationship) -> Neo4jR "1": "Person", "2": "Location", }, - None, - True, + None, # relationship_type + True, # additional_relationship_types (("Person", "REL", "Location"),), - True, + True, # additional_patterns + "neo4j_relationship", + ), + # invalid type addition allowed but invalid node ID + ( "neo4j_relationship", + { + "1": "Person", + }, + None, # relationship_type + True, # additional_relationship_types + (("Person", "REL", "Location"),), + True, # additional_patterns + None, ), # invalid_type_addition_not_allowed ( @@ -232,10 +270,10 @@ def neo4j_reversed_relationship(neo4j_relationship: Neo4jRelationship) -> Neo4jR "1": "Person", "2": "Location", }, - None, - False, + None, # relationship_type + False, # additional_relationship_types (("Person", "REL", "Location"),), - True, + True, # additional_patterns None, ), # invalid pattern, addition allowed @@ -248,9 +286,9 @@ def neo4j_reversed_relationship(neo4j_relationship: Neo4jRelationship) -> Neo4jR RelationshipType( label="REL", ), - True, + True, # additional_relationship_types (("Person", "REL", "Person"),), - True, + True, # additional_patterns "neo4j_relationship", ), # invalid pattern, addition not allowed @@ -263,9 +301,9 @@ def neo4j_reversed_relationship(neo4j_relationship: Neo4jRelationship) -> Neo4jR RelationshipType( label="REL", ), - True, + True, # additional_relationship_types (("Person", "REL", "Person"),), - False, + False, # additional_patterns None, ), ], @@ -299,3 +337,54 @@ def test_graph_pruning_validate_relationship( ) == expected_relationship_obj ) + + +@patch("neo4j_graphrag.experimental.components.graph_pruning.GraphPruning._clean_graph") +@pytest.mark.asyncio +async def test_graph_pruning_run_happy_path( + mock_clean_graph: Mock, node_type_required_name +) -> None: + initial_graph = Neo4jGraph( + nodes=[Neo4jNode(id="1", label="Person"), Neo4jNode(id="2", label="Location")], + ) + schema = GraphSchema(node_types=(node_type_required_name,)) + cleaned_graph = Neo4jGraph(nodes=[Neo4jNode(id="1", label="Person")]) + mock_clean_graph.return_value = cleaned_graph + pruner = GraphPruning() + pruner_result = await pruner.run( + graph=initial_graph, + schema=schema, + ) + assert isinstance(pruner_result, GraphPruningResult) + assert pruner_result.graph == cleaned_graph + mock_clean_graph.assert_called_once_with(initial_graph, schema) + + +@pytest.mark.asyncio +async def test_graph_pruning_run_no_schema() -> None: + initial_graph = Neo4jGraph(nodes=[Neo4jNode(id="1", label="Person")]) + pruner = GraphPruning() + pruner_result = await pruner.run( + graph=initial_graph, + schema=None, + ) + assert isinstance(pruner_result, GraphPruningResult) + assert pruner_result.graph == initial_graph + + +@patch( + "neo4j_graphrag.experimental.components.graph_pruning.GraphPruning._enforce_nodes" +) +def test_graph_pruning_clean_graph( + mock_enforce_nodes: Mock, +) -> None: + mock_enforce_nodes.return_value = [] + initial_graph = Neo4jGraph(nodes=[Neo4jNode(id="1", label="Person")]) + schema = GraphSchema(node_types=()) + pruner = GraphPruning() + cleaned_graph = pruner._clean_graph(initial_graph, schema) + assert cleaned_graph == Neo4jGraph() + mock_enforce_nodes.assert_called_once_with( + [Neo4jNode(id="1", label="Person")], + schema, + ) From fc1d65a7f07ca4a50c456e996c75423a9ccdcf9f Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 26 May 2025 11:27:16 +0200 Subject: [PATCH 08/25] Do not filter based on patterns if relationship type not in schema and additional_relationship_types is allowed --- src/neo4j_graphrag/experimental/components/graph_pruning.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/graph_pruning.py b/src/neo4j_graphrag/experimental/components/graph_pruning.py index 068c955d8..89624634e 100644 --- a/src/neo4j_graphrag/experimental/components/graph_pruning.py +++ b/src/neo4j_graphrag/experimental/components/graph_pruning.py @@ -163,11 +163,11 @@ def _validate_relationship( return None # validate pattern - start_label = valid_nodes[rel.start_node_id] - end_label = valid_nodes[rel.end_node_id] tuple_valid = True reverse_tuple_valid = False - if patterns: + if patterns and relationship_type: + start_label = valid_nodes[rel.start_node_id] + end_label = valid_nodes[rel.end_node_id] tuple_valid = (start_label, rel.type, end_label) in patterns # try to reverse relationship only if initial order is not valid reverse_tuple_valid = ( From df22c7f50b384847e92f20dce80a8ea524547296 Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 26 May 2025 11:27:25 +0200 Subject: [PATCH 09/25] Raise proper error type --- src/neo4j_graphrag/experimental/components/schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index 6937e2b79..75cdc67b8 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -173,7 +173,7 @@ def check_schema(self) -> Self: f"Relationship type '{relation}' is not defined in the provided relationship_types." ) if entity2 not in self._node_type_index: - raise ValidationError( + raise ValueError( f"Node type '{entity2}' is not defined in the provided node_types." ) From c9beb953e88291fc9e8610e633a7797bb31e759e Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 26 May 2025 11:28:51 +0200 Subject: [PATCH 10/25] Ruff/mypy --- .../pipeline/config/template_pipeline/simple_kg_builder.py | 4 +++- tests/unit/experimental/components/test_graph_pruning.py | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py index 15823f679..e1d3af5a0 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py @@ -223,7 +223,9 @@ def _process_schema_with_precedence( if self.relations is not None else None ) - patterns = tuple(self.potential_schema) if self.potential_schema else None + patterns = ( + tuple(self.potential_schema) if self.potential_schema else tuple() + ) return node_types, relationship_types, patterns diff --git a/tests/unit/experimental/components/test_graph_pruning.py b/tests/unit/experimental/components/test_graph_pruning.py index c1acb889b..61e6f692e 100644 --- a/tests/unit/experimental/components/test_graph_pruning.py +++ b/tests/unit/experimental/components/test_graph_pruning.py @@ -342,7 +342,8 @@ def test_graph_pruning_validate_relationship( @patch("neo4j_graphrag.experimental.components.graph_pruning.GraphPruning._clean_graph") @pytest.mark.asyncio async def test_graph_pruning_run_happy_path( - mock_clean_graph: Mock, node_type_required_name + mock_clean_graph: Mock, + node_type_required_name: NodeType, ) -> None: initial_graph = Neo4jGraph( nodes=[Neo4jNode(id="1", label="Person"), Neo4jNode(id="2", label="Location")], From 81023d215ff4365147ee76d407b0fd92f6129333 Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 26 May 2025 12:02:01 +0200 Subject: [PATCH 11/25] Add e2e test for graph pruning component --- .../test_graph_pruning_component_e2e.py | 527 ++++++++++++++++++ 1 file changed, 527 insertions(+) create mode 100644 tests/e2e/experimental/test_graph_pruning_component_e2e.py diff --git a/tests/e2e/experimental/test_graph_pruning_component_e2e.py b/tests/e2e/experimental/test_graph_pruning_component_e2e.py new file mode 100644 index 000000000..27a7aa3f3 --- /dev/null +++ b/tests/e2e/experimental/test_graph_pruning_component_e2e.py @@ -0,0 +1,527 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any + +import pytest + +from neo4j_graphrag.experimental.components.graph_pruning import GraphPruning +from neo4j_graphrag.experimental.components.schema import GraphSchema +from neo4j_graphrag.experimental.components.types import ( + Neo4jGraph, + Neo4jNode, + Neo4jRelationship, +) + + +@pytest.fixture +def extracted_graph() -> Neo4jGraph: + """This is the graph to be pruned in all the below tests, + using different schema configuration. + """ + return Neo4jGraph( + nodes=[ + Neo4jNode( + id="1", + label="Person", + properties={ + "name": "John Doe", + }, + ), + Neo4jNode( + id="2", + label="Person", + properties={ + "height": 180, + }, + ), + Neo4jNode( + id="3", + label="Person", + properties={ + "name": "Jane Doe", + "weight": 90, + }, + ), + Neo4jNode( + id="10", + label="Organization", + properties={ + "name": "Azerty Inc.", + "created": 1999, + }, + ), + ], + relationships=[ + Neo4jRelationship( + start_node_id="1", + end_node_id="2", + type="KNOWS", + properties={"firstMetIn": 2025}, + ), + Neo4jRelationship( + start_node_id="1", + end_node_id="3", + type="KNOWS", + ), + Neo4jRelationship( + start_node_id="1", + end_node_id="2", + type="MANAGES", + ), + Neo4jRelationship( + start_node_id="1", + end_node_id="10", + type="MANAGES", + ), + Neo4jRelationship( + start_node_id="1", + end_node_id="10", + type="WORKS_FOR", + ), + ], + ) + + +async def _test( + extracted_graph: Neo4jGraph, schema_dict: dict[str, Any], expected_graph: Neo4jGraph +) -> None: + schema = GraphSchema.model_validate(schema_dict) + pruner = GraphPruning() + res = await pruner.run(extracted_graph, schema) + assert res.graph == expected_graph + + +@pytest.mark.asyncio +async def test_graph_pruning_loose(extracted_graph: Neo4jGraph) -> None: + """Loose schema: + - no required properties + - all additional* allowed + + => we keep everything from the extracted graph + """ + schema_dict = { + "node_types": [ + { + "label": "Person", + "properties": [ + {"name": "name", "type": "STRING"}, + {"name": "height", "type": "INTEGER"}, + ], + } + ], + "relationship_types": [ + { + "label": "KNOWS", + } + ], + "patterns": [ + ("Person", "KNOWS", "Person"), + ], + } + await _test(extracted_graph, schema_dict, extracted_graph) + + +@pytest.mark.asyncio +async def test_graph_pruning_missing_required_property( + extracted_graph: Neo4jGraph, +) -> None: + """Person node type has a required 'name' property: + - extracted nodes without this property are pruned + - any relationship tied to this node is also pruned + """ + schema_dict = { + "node_types": [ + { + "label": "Person", + "properties": [ + { + "name": "name", + "type": "STRING", + "required": True, + }, + {"name": "height", "type": "INTEGER"}, + ], + } + ], + "relationship_types": [ + { + "label": "KNOWS", + } + ], + "patterns": [ + ("Person", "KNOWS", "Person"), + ], + } + filtered_graph = Neo4jGraph( + nodes=[ + Neo4jNode( + id="1", + label="Person", + properties={ + "name": "John Doe", + }, + ), + # do not have the required "name" property + # Neo4jNode( + # id="2", + # label="Person", + # properties={ + # "height": 180, + # } + # ), + Neo4jNode( + id="3", + label="Person", + properties={ + "name": "Jane Doe", + "weight": 90, + }, + ), + Neo4jNode( + id="10", + label="Organization", + properties={ + "name": "Azerty Inc.", + "created": 1999, + }, + ), + ], + relationships=[ + # node "2" was pruned + # Neo4jRelationship( + # start_node_id="1", + # end_node_id="2", + # type="KNOWS", + # properties={"firstMetIn": 2025}, + # ), + Neo4jRelationship( + start_node_id="1", + end_node_id="3", + type="KNOWS", + ), + # node "2" was pruned + # Neo4jRelationship( + # start_node_id="1", + # end_node_id="2", + # type="MANAGES", + # ), + Neo4jRelationship( + start_node_id="1", + end_node_id="10", + type="MANAGES", + ), + Neo4jRelationship( + start_node_id="1", + end_node_id="10", + type="WORKS_FOR", + ), + ], + ) + await _test(extracted_graph, schema_dict, filtered_graph) + + +@pytest.mark.asyncio +async def test_graph_pruning_strict_properties_and_node_types( + extracted_graph: Neo4jGraph, +) -> None: + """Additional properties on Person nodes are not allowed. + Additional node types are not allowed. + + => we prune "Organization" nodes (not in schema) + and the "weight" property that was extracted for some persons. + """ + schema_dict = { + "node_types": [ + { + "label": "Person", + "properties": [ + { + "name": "name", + "type": "STRING", + }, + {"name": "height", "type": "INTEGER"}, + ], + "additional_properties": False, + } + ], + "relationship_types": [ + { + "label": "KNOWS", + } + ], + "patterns": [ + ("Person", "KNOWS", "Person"), + ], + "additional_node_types": False, + } + filtered_graph = Neo4jGraph( + nodes=[ + Neo4jNode( + id="1", + label="Person", + properties={ + "name": "John Doe", + }, + ), + Neo4jNode( + id="2", + label="Person", + properties={ + "height": 180, + }, + ), + Neo4jNode( + id="3", + label="Person", + properties={ + "name": "Jane Doe", + # weight not in listed properties + # "weight": 90, + }, + ), + # label "Organization" not in schema + # Neo4jNode( + # id="10", + # label="Organization", + # properties={ + # "name": "Azerty Inc.", + # "created": 1999, + # } + # ), + ], + relationships=[ + Neo4jRelationship( + start_node_id="1", + end_node_id="2", + type="KNOWS", + properties={"firstMetIn": 2025}, + ), + Neo4jRelationship( + start_node_id="1", + end_node_id="3", + type="KNOWS", + ), + Neo4jRelationship( + start_node_id="1", + end_node_id="2", + type="MANAGES", + ), + # node "10" was pruned (label not allowed) + # Neo4jRelationship( + # start_node_id="1", + # end_node_id="10", + # type="MANAGES", + # ), + # Neo4jRelationship( + # start_node_id="1", + # end_node_id="10", + # type="WORKS_FOR", + # ) + ], + ) + await _test(extracted_graph, schema_dict, filtered_graph) + + +@pytest.mark.asyncio +async def test_graph_pruning_strict_relationship_types(extracted_graph: Neo4jGraph): + """Additional relationship types not allowed + + => we prune all MANAGES and WORKS_FOR extracted relationships + """ + schema_dict = { + "node_types": [ + { + "label": "Person", + "properties": [ + { + "name": "name", + "type": "STRING", + }, + {"name": "height", "type": "INTEGER"}, + ], + } + ], + "relationship_types": [ + { + "label": "KNOWS", + } + ], + "patterns": [ + ("Person", "KNOWS", "Person"), + ], + "additional_relationship_types": False, + } + filtered_graph = Neo4jGraph( + nodes=[ + Neo4jNode( + id="1", + label="Person", + properties={ + "name": "John Doe", + }, + ), + Neo4jNode( + id="2", + label="Person", + properties={ + "height": 180, + }, + ), + Neo4jNode( + id="3", + label="Person", + properties={ + "name": "Jane Doe", + "weight": 90, + }, + ), + Neo4jNode( + id="10", + label="Organization", + properties={ + "name": "Azerty Inc.", + "created": 1999, + }, + ), + ], + relationships=[ + Neo4jRelationship( + start_node_id="1", + end_node_id="2", + type="KNOWS", + properties={"firstMetIn": 2025}, + ), + Neo4jRelationship( + start_node_id="1", + end_node_id="3", + type="KNOWS", + ), + # MANAGES not in allowed relationship types + # Neo4jRelationship( + # start_node_id="1", + # end_node_id="2", + # type="MANAGES", + # ), + # WORKS_FOR not in allowed relationship types + # Neo4jRelationship( + # start_node_id="1", + # end_node_id="10", + # type="WORKS_FOR", + # ) + ], + ) + await _test(extracted_graph, schema_dict, filtered_graph) + + +@pytest.mark.asyncio +async def test_graph_pruning_strict_patterns(extracted_graph: Neo4jGraph): + """Additional patterns not allowed: + + - MANAGES: it's a known relationship type but without any pattern, it's pruned + - WORKS_FOR: it's not a known relationship type, and additional_relationship_types is allowed + so we keep it. + """ + # - no additional patterns allowed + schema_dict = { + "node_types": [ + { + "label": "Person", + "properties": [ + { + "name": "name", + "type": "STRING", + }, + {"name": "height", "type": "INTEGER"}, + ], + }, + { + "label": "Organization", + }, + ], + "relationship_types": [ + { + "label": "KNOWS", + }, + { + "label": "MANAGES", + }, + ], + "patterns": ( + ("Person", "KNOWS", "Person"), + ("Person", "KNOWS", "Organization"), + ), + "additional_patterns": False, + } + filtered_graph = Neo4jGraph( + nodes=[ + Neo4jNode( + id="1", + label="Person", + properties={ + "name": "John Doe", + }, + ), + Neo4jNode( + id="2", + label="Person", + properties={ + "height": 180, + }, + ), + Neo4jNode( + id="3", + label="Person", + properties={ + "name": "Jane Doe", + "weight": 90, + }, + ), + Neo4jNode( + id="10", + label="Organization", + properties={ + "name": "Azerty Inc.", + "created": 1999, + }, + ), + ], + relationships=[ + Neo4jRelationship( + start_node_id="1", + end_node_id="2", + type="KNOWS", + properties={"firstMetIn": 2025}, + ), + Neo4jRelationship( + start_node_id="1", + end_node_id="3", + type="KNOWS", + ), + # invalid pattern (person, manages, person) + # Neo4jRelationship( + # start_node_id="1", + # end_node_id="2", + # type="MANAGES", + # ), + # not a valid pattern but WORKS_FOR + # not in relationship types, so we keep it + Neo4jRelationship( + start_node_id="1", + end_node_id="10", + type="WORKS_FOR", + ), + ], + ) + await _test(extracted_graph, schema_dict, filtered_graph) From 80bd0372802d7f809bf5a7bff556c1d35c90d6b9 Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 26 May 2025 12:06:30 +0200 Subject: [PATCH 12/25] Mypy --- tests/e2e/experimental/test_graph_pruning_component_e2e.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/e2e/experimental/test_graph_pruning_component_e2e.py b/tests/e2e/experimental/test_graph_pruning_component_e2e.py index 27a7aa3f3..74a9a84ae 100644 --- a/tests/e2e/experimental/test_graph_pruning_component_e2e.py +++ b/tests/e2e/experimental/test_graph_pruning_component_e2e.py @@ -335,7 +335,9 @@ async def test_graph_pruning_strict_properties_and_node_types( @pytest.mark.asyncio -async def test_graph_pruning_strict_relationship_types(extracted_graph: Neo4jGraph): +async def test_graph_pruning_strict_relationship_types( + extracted_graph: Neo4jGraph, +) -> None: """Additional relationship types not allowed => we prune all MANAGES and WORKS_FOR extracted relationships @@ -426,7 +428,7 @@ async def test_graph_pruning_strict_relationship_types(extracted_graph: Neo4jGra @pytest.mark.asyncio -async def test_graph_pruning_strict_patterns(extracted_graph: Neo4jGraph): +async def test_graph_pruning_strict_patterns(extracted_graph: Neo4jGraph) -> None: """Additional patterns not allowed: - MANAGES: it's a known relationship type but without any pattern, it's pruned From 6c666c0bc249f6bafed0326edc829af6606ef4b9 Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 26 May 2025 12:24:01 +0200 Subject: [PATCH 13/25] Update changelog and doc --- CHANGELOG.md | 2 +- docs/source/user_guide_kg_builder.rst | 73 +++++++++++++++------------ 2 files changed, 41 insertions(+), 34 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a8bad9486..f4bd19157 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,7 +15,7 @@ #### Strict mode -- Strict mode in `SimpleKGPipeline`: now properties and relationships are pruned only if they are defined in the input schema. +- Strict mode in `SimpleKGPipeline`: the `enforce_schema` option is removed and replaced by a schema-driven pruning. #### Schema definition diff --git a/docs/source/user_guide_kg_builder.rst b/docs/source/user_guide_kg_builder.rst index 6af5e8cb3..250231ce6 100644 --- a/docs/source/user_guide_kg_builder.rst +++ b/docs/source/user_guide_kg_builder.rst @@ -145,7 +145,6 @@ They are also accessible via the `SimpleKGPipeline` interface. # ... prompt_template="", lexical_graph_config=my_config, - enforce_schema="STRICT" on_error="RAISE", # ... ) @@ -878,38 +877,6 @@ It can be used in this way: The LLM to use can be customized, the only constraint is that it obeys the :ref:`LLMInterface `. -Schema Enforcement Behaviour ----------------------------- -.. _schema-enforcement-behaviour: - -By default, even if a schema is provided to guide the LLM in the entity and relation extraction, the LLM response is not validated against that schema. -This behaviour can be changed by using the `enforce_schema` flag in the `LLMEntityRelationExtractor` constructor: - -.. code:: python - - from neo4j_graphrag.experimental.components.entity_relation_extractor import LLMEntityRelationExtractor - from neo4j_graphrag.experimental.components.types import SchemaEnforcementMode - - extractor = LLMEntityRelationExtractor( - # ... - enforce_schema=SchemaEnforcementMode.STRICT, - ) - -In this scenario, any extracted node/relation/property that is not part of the provided schema will be pruned. -Any relation whose start node or end node does not conform to the provided tuple in `potential_schema` will be pruned. -If a relation start/end nodes are valid but the direction is incorrect, the latter will be inverted. -If a node is left with no properties, it will be also pruned. - -.. note:: - - If the input schema lacks a certain type of information, pruning is skipped. - For example, if an entity is defined only by a label and has no properties, - property pruning is not performed and all properties returned by the LLM are kept. - - -.. warning:: - - Note that if the schema enforcement mode is on but the schema is not provided, no schema enforcement will be applied. Error Behaviour --------------- @@ -1017,6 +984,46 @@ If more customization is needed, it is possible to subclass the `EntityRelationE See :ref:`entityrelationextractor`. +Schema Guidance and Graph Filtering +=================================== + +The provided schema serves as a guiding structure for the language model during graph construction. However, it does not impose strict constraints on the model's output. As a result, the model may generate additional node labels, relationship types, or properties that are not explicitly defined in the schema. + +By default, all extracted elements — including nodes, relationships, and properties — are retained in the constructed graph. This behavior can be configured using the following schema options: +(see :ref:`graphschema`) + + +Configuration Options +--------------------- + +- **Required Properties** + Required properties may be specified at the node or relationship type level. Any extracted node or relationship missing one or more of its required properties will be pruned from the graph. + +- **Additional Properties** *(default: False)* + This node- or relationship-level option determines whether extra properties not listed in the schema should be retained. + - If set to ``False`` (default), all extracted properties are retained. + - If set to ``True``, only the properties defined in the schema are preserved; all others are removed. + +- **Additional Node Types** *(default: True)* + This schema-level option specifies whether node types not defined in the schema are included in the graph. + - If set to ``True`` (default), such node types are retained. + - If set to ``False``, nodes with undefined types are removed. + +- **Additional Relationship Types** *(default: True)* + This schema-level option specifies whether relationship types not defined in the schema are included in the graph. + - If set to ``True`` (default), such relationships are retained. + - If set to ``False``, relationships with undefined types are removed. + +- **Additional Patterns** *(default: True)* + This schema-level option determines whether relationship patterns not explicitly listed in the schema are allowed. + - If set to ``True`` (default), all patterns are retained. + - If set to ``False``, only patterns defined in the schema are kept. + +.. note:: + + If ``additional_patterns`` is set to ``False`` but ``additional_relationships`` is ``True``, extra relationships are still retained as long as they are part of patterns included in the schema. + + .. _kg-writer-section: Knowledge Graph Writer From bddcd71934c5651af2d0466ccad2c1875539d9cc Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 26 May 2025 12:24:07 +0200 Subject: [PATCH 14/25] Mypy --- tests/unit/experimental/components/test_graph_pruning.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/experimental/components/test_graph_pruning.py b/tests/unit/experimental/components/test_graph_pruning.py index 61e6f692e..a6eedb4d5 100644 --- a/tests/unit/experimental/components/test_graph_pruning.py +++ b/tests/unit/experimental/components/test_graph_pruning.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations from typing import Any from unittest.mock import patch, Mock From 71893bd78ca48ca704315b483d6715bc26bf7d9b Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 26 May 2025 14:01:29 +0200 Subject: [PATCH 15/25] ChatGPT was wrong --- docs/source/user_guide_kg_builder.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/user_guide_kg_builder.rst b/docs/source/user_guide_kg_builder.rst index 250231ce6..2795a78cb 100644 --- a/docs/source/user_guide_kg_builder.rst +++ b/docs/source/user_guide_kg_builder.rst @@ -1021,7 +1021,7 @@ Configuration Options .. note:: - If ``additional_patterns`` is set to ``False`` but ``additional_relationships`` is ``True``, extra relationships are still retained as long as they are part of patterns included in the schema. + If ``additional_patterns`` is set to ``False`` but ``additional_relationships`` is ``True``, extra relationships are still retained even if not part of a pattern listed in the schema. .. _kg-writer-section: From 7d70aefe48c8393e972e681f0a629a8bb32c28a4 Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 27 May 2025 11:35:01 +0200 Subject: [PATCH 16/25] Change edge case behaviour --- docs/source/user_guide_kg_builder.rst | 32 +++---- .../experimental/components/schema.py | 28 +++++- .../test_graph_pruning_component_e2e.py | 93 ------------------- .../experimental/components/test_schema.py | 44 ++++++++- 4 files changed, 79 insertions(+), 118 deletions(-) diff --git a/docs/source/user_guide_kg_builder.rst b/docs/source/user_guide_kg_builder.rst index 2795a78cb..c5c2096db 100644 --- a/docs/source/user_guide_kg_builder.rst +++ b/docs/source/user_guide_kg_builder.rst @@ -74,9 +74,8 @@ Graph Schema ------------ It is possible to guide the LLM by supplying a list of node and relationship types, -and instructions on how to connect them (patterns). However, note that the extracted graph -may not fully adhere to these guidelines unless schema enforcement is enabled -(see :ref:`Schema Enforcement Behaviour`). Node and relationship types can be represented +and instructions on how to connect them (patterns). +Node and relationship types can be represented as either simple strings (for their labels) or dictionaries. If using a dictionary, it must include a label key and can optionally include description and properties keys, as shown below: @@ -90,7 +89,7 @@ as shown below: # such as a description: {"label": "House", "description": "Family the person belongs to"}, # or a list of properties the LLM will try to attach to the entity: - {"label": "Planet", "properties": [{"name": "weather", "type": "STRING"}]}, + {"label": "Planet", "properties": [{"name": "name", "type": "STRING", "required": True}, {"name": "weather", "type": "STRING"}]}, ] # same thing for relationships: RELATIONSHIP_TYPES = [ @@ -124,7 +123,8 @@ This schema information can be provided to the `SimpleKGBuilder` as demonstrated schema={ "node_types": NODE_TYPES, "relationship_types": RELATIONSHIP_TYPES, - "patterns": PATTERNS + "patterns": PATTERNS, + "additional_node_types": False, }, # ... ) @@ -1001,27 +1001,27 @@ Configuration Options - **Additional Properties** *(default: False)* This node- or relationship-level option determines whether extra properties not listed in the schema should be retained. - - If set to ``False`` (default), all extracted properties are retained. - - If set to ``True``, only the properties defined in the schema are preserved; all others are removed. + + - If set to ``False`` (default), all extracted properties are retained. + - If set to ``True``, only the properties defined in the schema are preserved; all others are removed. - **Additional Node Types** *(default: True)* This schema-level option specifies whether node types not defined in the schema are included in the graph. - - If set to ``True`` (default), such node types are retained. - - If set to ``False``, nodes with undefined types are removed. + + - If set to ``True`` (default), such node types are retained. + - If set to ``False``, nodes with undefined types are removed. - **Additional Relationship Types** *(default: True)* This schema-level option specifies whether relationship types not defined in the schema are included in the graph. - - If set to ``True`` (default), such relationships are retained. - - If set to ``False``, relationships with undefined types are removed. + + - If set to ``True`` (default), such relationships are retained. + - If set to ``False``, relationships with undefined types are removed. - **Additional Patterns** *(default: True)* This schema-level option determines whether relationship patterns not explicitly listed in the schema are allowed. - - If set to ``True`` (default), all patterns are retained. - - If set to ``False``, only patterns defined in the schema are kept. - -.. note:: - If ``additional_patterns`` is set to ``False`` but ``additional_relationships`` is ``True``, extra relationships are still retained even if not part of a pattern listed in the schema. + - If set to ``True`` (default), all patterns are retained. + - If set to ``False``, only patterns defined in the schema are kept. **Note** `additional_relationship_types` must also be `False`. .. _kg-writer-section: diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index 75cdc67b8..fd950db13 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -94,10 +94,9 @@ def validate_input_if_string(cls, data: EntityInputType) -> EntityInputType: @model_validator(mode="after") def validate_additional_properties(self) -> Self: if len(self.properties) == 0 and not self.additional_properties: - warnings.warn( + raise ValueError( "Using `additional_properties=False` with no defined " "properties will cause the model to be pruned during graph cleaning.", - UserWarning, ) return self @@ -122,15 +121,26 @@ def validate_input_if_string(cls, data: RelationInputType) -> RelationInputType: @model_validator(mode="after") def validate_additional_properties(self) -> Self: if len(self.properties) == 0 and not self.additional_properties: - warnings.warn( + raise ValueError( "Using `additional_properties=False` with no defined " "properties will cause the model to be pruned during graph cleaning.", - UserWarning, ) return self class GraphSchema(DataModel): + """This model represents the expected + node and relationship types in the graph. + + It is used both for guiding the LLM in the entity and relation + extraction component, and for cleaning the extracted graph in a + post-processing step. + + .. warning:: + + This model is immutable. + """ + node_types: Tuple[NodeType, ...] relationship_types: Tuple[RelationshipType, ...] = tuple() patterns: Tuple[Tuple[str, str, str], ...] = tuple() @@ -147,7 +157,7 @@ class GraphSchema(DataModel): ) @model_validator(mode="after") - def check_schema(self) -> Self: + def validate_patterns_against_node_and_rel_types(self) -> Self: self._node_type_index = {node.label: node for node in self.node_types} self._relationship_type_index = ( {r.label: r for r in self.relationship_types} @@ -179,6 +189,14 @@ def check_schema(self) -> Self: return self + @model_validator(mode="after") + def validate_additional_parameters(self) -> Self: + if self.additional_patterns and not self.additional_relationship_types: + raise ValueError( + "`additional_relationship_types` must be set to True when using `additional_patterns=True`" + ) + return self + def node_type_from_label(self, label: str) -> Optional[NodeType]: return self._node_type_index.get(label) diff --git a/tests/e2e/experimental/test_graph_pruning_component_e2e.py b/tests/e2e/experimental/test_graph_pruning_component_e2e.py index 74a9a84ae..73804770e 100644 --- a/tests/e2e/experimental/test_graph_pruning_component_e2e.py +++ b/tests/e2e/experimental/test_graph_pruning_component_e2e.py @@ -334,99 +334,6 @@ async def test_graph_pruning_strict_properties_and_node_types( await _test(extracted_graph, schema_dict, filtered_graph) -@pytest.mark.asyncio -async def test_graph_pruning_strict_relationship_types( - extracted_graph: Neo4jGraph, -) -> None: - """Additional relationship types not allowed - - => we prune all MANAGES and WORKS_FOR extracted relationships - """ - schema_dict = { - "node_types": [ - { - "label": "Person", - "properties": [ - { - "name": "name", - "type": "STRING", - }, - {"name": "height", "type": "INTEGER"}, - ], - } - ], - "relationship_types": [ - { - "label": "KNOWS", - } - ], - "patterns": [ - ("Person", "KNOWS", "Person"), - ], - "additional_relationship_types": False, - } - filtered_graph = Neo4jGraph( - nodes=[ - Neo4jNode( - id="1", - label="Person", - properties={ - "name": "John Doe", - }, - ), - Neo4jNode( - id="2", - label="Person", - properties={ - "height": 180, - }, - ), - Neo4jNode( - id="3", - label="Person", - properties={ - "name": "Jane Doe", - "weight": 90, - }, - ), - Neo4jNode( - id="10", - label="Organization", - properties={ - "name": "Azerty Inc.", - "created": 1999, - }, - ), - ], - relationships=[ - Neo4jRelationship( - start_node_id="1", - end_node_id="2", - type="KNOWS", - properties={"firstMetIn": 2025}, - ), - Neo4jRelationship( - start_node_id="1", - end_node_id="3", - type="KNOWS", - ), - # MANAGES not in allowed relationship types - # Neo4jRelationship( - # start_node_id="1", - # end_node_id="2", - # type="MANAGES", - # ), - # WORKS_FOR not in allowed relationship types - # Neo4jRelationship( - # start_node_id="1", - # end_node_id="10", - # type="WORKS_FOR", - # ) - ], - ) - await _test(extracted_graph, schema_dict, filtered_graph) - - @pytest.mark.asyncio async def test_graph_pruning_strict_patterns(extracted_graph: Neo4jGraph) -> None: """Additional patterns not allowed: diff --git a/tests/unit/experimental/components/test_schema.py b/tests/unit/experimental/components/test_schema.py index be960f826..10a8de137 100644 --- a/tests/unit/experimental/components/test_schema.py +++ b/tests/unit/experimental/components/test_schema.py @@ -19,6 +19,7 @@ from unittest.mock import AsyncMock, patch import pytest +from pydantic import ValidationError from neo4j_graphrag.exceptions import SchemaValidationError, SchemaExtractionError from neo4j_graphrag.experimental.components.schema import ( @@ -38,8 +39,8 @@ from neo4j_graphrag.utils.file_handler import FileFormat -def test_node_type_raise_warning_if_misconfigured() -> None: - with pytest.warns(UserWarning): +def test_node_type_raise_error_if_misconfigured() -> None: + with pytest.raises(ValidationError): NodeType( label="test", properties=[], @@ -47,8 +48,8 @@ def test_node_type_raise_warning_if_misconfigured() -> None: ) -def test_relationship_type_raise_warning_if_misconfigured() -> None: - with pytest.warns(UserWarning): +def test_relationship_type_raise_error_if_misconfigured() -> None: + with pytest.raises(ValidationError): RelationshipType( label="test", properties=[], @@ -56,6 +57,41 @@ def test_relationship_type_raise_warning_if_misconfigured() -> None: ) +def test_schema_additional_parameter_validation() -> None: + """Additional relationship types not allowed, but additional patterns allowed + + => raise Exception + """ + schema_dict = { + "node_types": [ + { + "label": "Person", + "properties": [ + { + "name": "name", + "type": "STRING", + }, + {"name": "height", "type": "INTEGER"}, + ], + } + ], + "relationship_types": [ + { + "label": "KNOWS", + } + ], + "patterns": [ + ("Person", "KNOWS", "Person"), + ], + "additional_relationship_types": False, + } + with pytest.raises( + ValidationError, + match="`additional_relationship_types` must be set to True when using `additional_patterns=True`", + ): + GraphSchema.model_validate(schema_dict) + + @pytest.fixture def valid_node_types() -> tuple[NodeType, ...]: return ( From da4bc45bb1346d71426334c806148a782d397b57 Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 9 Jun 2025 13:38:32 +0200 Subject: [PATCH 17/25] Fix doc --- docs/source/user_guide_kg_builder.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/user_guide_kg_builder.rst b/docs/source/user_guide_kg_builder.rst index c5c2096db..cb256caf7 100644 --- a/docs/source/user_guide_kg_builder.rst +++ b/docs/source/user_guide_kg_builder.rst @@ -999,11 +999,11 @@ Configuration Options - **Required Properties** Required properties may be specified at the node or relationship type level. Any extracted node or relationship missing one or more of its required properties will be pruned from the graph. -- **Additional Properties** *(default: False)* +- **Additional Properties** *(default: True)* This node- or relationship-level option determines whether extra properties not listed in the schema should be retained. - - If set to ``False`` (default), all extracted properties are retained. - - If set to ``True``, only the properties defined in the schema are preserved; all others are removed. + - If set to ``True`` (default), all extracted properties are retained. + - If set to ``False``, only the properties defined in the schema are preserved; all others are removed. - **Additional Node Types** *(default: True)* This schema-level option specifies whether node types not defined in the schema are included in the graph. From f7b024cdf6a56fec98420a9597e51eea5c78298b Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 9 Jun 2025 13:50:21 +0200 Subject: [PATCH 18/25] Update doc --- docs/source/user_guide_kg_builder.rst | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/docs/source/user_guide_kg_builder.rst b/docs/source/user_guide_kg_builder.rst index cb256caf7..76dd2bac5 100644 --- a/docs/source/user_guide_kg_builder.rst +++ b/docs/source/user_guide_kg_builder.rst @@ -73,7 +73,8 @@ Customizing the SimpleKGPipeline Graph Schema ------------ -It is possible to guide the LLM by supplying a list of node and relationship types, +It is possible to guide the LLM by supplying a list of node and relationship types ( +with, optionally, a list of their expected properties) and instructions on how to connect them (patterns). Node and relationship types can be represented as either simple strings (for their labels) or dictionaries. If using a dictionary, @@ -1005,6 +1006,12 @@ Configuration Options - If set to ``True`` (default), all extracted properties are retained. - If set to ``False``, only the properties defined in the schema are preserved; all others are removed. + +.. note:: Node pruning + + If, after property pruning using the above rule, a node is left without any property, it is removed from the graph. + + - **Additional Node Types** *(default: True)* This schema-level option specifies whether node types not defined in the schema are included in the graph. From a0d3e0a02e40352610798133b25d034f53f11443 Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 9 Jun 2025 13:55:52 +0200 Subject: [PATCH 19/25] Fix condition --- .../experimental/components/schema.py | 7 +++++-- .../test_graph_pruning_component_e2e.py | 14 +++++++------- tests/unit/experimental/components/test_schema.py | 4 ++-- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index fd950db13..8f686c298 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -191,9 +191,12 @@ def validate_patterns_against_node_and_rel_types(self) -> Self: @model_validator(mode="after") def validate_additional_parameters(self) -> Self: - if self.additional_patterns and not self.additional_relationship_types: + if ( + self.additional_patterns is False + and self.additional_relationship_types is True + ): raise ValueError( - "`additional_relationship_types` must be set to True when using `additional_patterns=True`" + "`additional_relationship_types` must be set to False when using `additional_patterns=False`" ) return self diff --git a/tests/e2e/experimental/test_graph_pruning_component_e2e.py b/tests/e2e/experimental/test_graph_pruning_component_e2e.py index 73804770e..333e74163 100644 --- a/tests/e2e/experimental/test_graph_pruning_component_e2e.py +++ b/tests/e2e/experimental/test_graph_pruning_component_e2e.py @@ -371,6 +371,7 @@ async def test_graph_pruning_strict_patterns(extracted_graph: Neo4jGraph) -> Non ("Person", "KNOWS", "Person"), ("Person", "KNOWS", "Organization"), ), + "additional_relationship_types": False, "additional_patterns": False, } filtered_graph = Neo4jGraph( @@ -424,13 +425,12 @@ async def test_graph_pruning_strict_patterns(extracted_graph: Neo4jGraph) -> Non # end_node_id="2", # type="MANAGES", # ), - # not a valid pattern but WORKS_FOR - # not in relationship types, so we keep it - Neo4jRelationship( - start_node_id="1", - end_node_id="10", - type="WORKS_FOR", - ), + # invalid pattern (person, works for, person) + # Neo4jRelationship( + # start_node_id="1", + # end_node_id="10", + # type="WORKS_FOR", + # ), ], ) await _test(extracted_graph, schema_dict, filtered_graph) diff --git a/tests/unit/experimental/components/test_schema.py b/tests/unit/experimental/components/test_schema.py index 10a8de137..cfc987af7 100644 --- a/tests/unit/experimental/components/test_schema.py +++ b/tests/unit/experimental/components/test_schema.py @@ -83,11 +83,11 @@ def test_schema_additional_parameter_validation() -> None: "patterns": [ ("Person", "KNOWS", "Person"), ], - "additional_relationship_types": False, + "additional_patterns": False, } with pytest.raises( ValidationError, - match="`additional_relationship_types` must be set to True when using `additional_patterns=True`", + match="`additional_relationship_types` must be set to False when using `additional_patterns=False`", ): GraphSchema.model_validate(schema_dict) From 91ed08852098b7eaed199db1495e59d58b0d3e0f Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 9 Jun 2025 13:58:55 +0200 Subject: [PATCH 20/25] Remove incomplete comments --- .../experimental/components/graph_pruning.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/graph_pruning.py b/src/neo4j_graphrag/experimental/components/graph_pruning.py index 89624634e..229f65f93 100644 --- a/src/neo4j_graphrag/experimental/components/graph_pruning.py +++ b/src/neo4j_graphrag/experimental/components/graph_pruning.py @@ -72,7 +72,6 @@ def _clean_graph( If an entity is removed, all of its relationships are also removed. If no valid properties remain for an entity, remove that entity. """ - # enforce nodes (remove invalid labels, strip invalid properties) filtered_nodes = self._enforce_nodes(graph.nodes, schema) if not filtered_nodes: logger.warning( @@ -80,9 +79,6 @@ def _clean_graph( ) return Neo4jGraph() - # enforce relationships (remove those referencing invalid nodes or with invalid - # types or with start/end nodes not conforming to the schema, and strip invalid - # properties) filtered_rels = self._enforce_relationships( graph.relationships, filtered_nodes, schema ) @@ -213,7 +209,8 @@ def _enforce_relationships( Keep only those whose types are in schema, start/end node conform to schema, and start/end nodes are in filtered nodes (i.e., kept after node enforcement). For each valid relationship, filter out properties not present in the schema. - If a relationship direct is incorrect, invert it. + + If a relationship direction is incorrect, invert it. """ valid_rels = [] From 27c7ba6de3c9654bc20d4f3cbae6145c8f5d4be3 Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 10 Jun 2025 14:53:04 +0200 Subject: [PATCH 21/25] More pruning stats --- examples/README.md | 1 + .../components/pruners/graph_pruner.py | 136 ++++++++++ .../experimental/components/graph_pruning.py | 237 +++++++++++++++--- .../experimental/components/types.py | 8 + .../components/test_graph_pruning.py | 37 ++- 5 files changed, 361 insertions(+), 58 deletions(-) create mode 100644 examples/customize/build_graph/components/pruners/graph_pruner.py diff --git a/examples/README.md b/examples/README.md index fa8bb945e..774739b32 100644 --- a/examples/README.md +++ b/examples/README.md @@ -128,6 +128,7 @@ are listed in [the last section of this file](#customize). - [LLM-based](./customize/build_graph/components/extractors/llm_entity_relation_extractor.py) - [LLM-based with custom prompt](./customize/build_graph/components/extractors/llm_entity_relation_extractor_with_custom_prompt.py) - [Custom](./customize/build_graph/components/extractors/custom_extractor.py) +- [Graph Pruner](./customize/build_graph/components/pruners/graph_pruner.py) - Knowledge Graph Writer: - [Neo4j writer](./customize/build_graph/components/writers/neo4j_writer.py) - [Custom](./customize/build_graph/components/writers/custom_writer.py) diff --git a/examples/customize/build_graph/components/pruners/graph_pruner.py b/examples/customize/build_graph/components/pruners/graph_pruner.py new file mode 100644 index 000000000..adf8694a1 --- /dev/null +++ b/examples/customize/build_graph/components/pruners/graph_pruner.py @@ -0,0 +1,136 @@ +"""This example demonstrates how to use the GraphPruner component.""" + +import asyncio + +from neo4j_graphrag.experimental.components.graph_pruning import GraphPruning +from neo4j_graphrag.experimental.components.schema import ( + GraphSchema, + NodeType, + PropertyType, + RelationshipType, +) +from neo4j_graphrag.experimental.components.types import ( + Neo4jGraph, + Neo4jNode, + Neo4jRelationship, +) + +graph = Neo4jGraph( + nodes=[ + Neo4jNode( + id="Person/John", + label="Person", + properties={ + "firstName": "John", + "lastName": "Doe", + "occupation": "employee", + }, + ), + Neo4jNode( + id="Person/Jane", + label="Person", + properties={ + "firstName": "Jane", + }, + ), + Neo4jNode( + id="Person/Jack", + label="Person", + properties={"firstName": "Jack", "lastName": "Dae"}, + ), + Neo4jNode( + id="Organization/Corp1", + label="Organization", + properties={"name": "CorpA"}, + ), + ], + relationships=[ + Neo4jRelationship( + start_node_id="Person/John", + end_node_id="Person/Jack", + type="KNOWS", + ), + Neo4jRelationship( + start_node_id="Organization/CorpA", + end_node_id="Person/Jack", + type="WORKS_FOR", + ), + Neo4jRelationship( + start_node_id="Person/John", + end_node_id="Person/Jack", + type="PARENT_OF", + ), + ], +) + +schema = GraphSchema( + node_types=( + NodeType( + label="Person", + properties=[ + PropertyType(name="firstName", type="STRING", required=True), + PropertyType(name="lastName", type="STRING", required=True), + PropertyType(name="age", type="INTEGER"), + ], + additional_properties=False, + ), + NodeType( + label="Organization", + properties=[ + PropertyType(name="name", type="STRING", required=True), + PropertyType(name="address", type="STRING"), + ], + ), + ), + relationship_types=( + RelationshipType( + label="WORKS_FOR", + properties=[PropertyType(name="since", type="LOCAL_DATETIME")], + ), + RelationshipType( + label="KNOWS", + ), + ), + patterns=( + ("Person", "KNOWS", "Person"), + ("Person", "WORKS_FOR", "Organization"), + ), + additional_node_types=False, + additional_relationship_types=False, + additional_patterns=False, +) + + +async def main() -> None: + pruner = GraphPruning() + res = await pruner.run(graph, schema) + print("=" * 20, "FINAL CLEANED GRAPH:", "=" * 20) + print(res.graph) + print("=" * 20, "PRUNED ITEM:", "=" * 20) + print(res.pruning_stats) + print("-" * 10, "PRUNED NODES:") + for node in res.pruning_stats.pruned_nodes: + print( + node.item.label, + "with properties", + node.item.properties, + "pruned because", + node.pruned_reason, + node.metadata, + ) + print("-" * 10, "PRUNED RELATIONSHIPS:") + for rel in res.pruning_stats.pruned_relationships: + print(rel.item.type, "pruned because", rel.pruned_reason) + print("-" * 10, "PRUNED PROPERTIES:") + for prop in res.pruning_stats.pruned_properties: + print( + prop.item, + "from node label", + prop.label, + "pruned because", + prop.pruned_reason, + ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/neo4j_graphrag/experimental/components/graph_pruning.py b/src/neo4j_graphrag/experimental/components/graph_pruning.py index 229f65f93..77fdb907e 100644 --- a/src/neo4j_graphrag/experimental/components/graph_pruning.py +++ b/src/neo4j_graphrag/experimental/components/graph_pruning.py @@ -12,10 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import enum import logging -from typing import Optional, Any +from typing import Optional, Any, TypeVar, Generic, Union -from pydantic import validate_call +from pydantic import validate_call, BaseModel from neo4j_graphrag.experimental.components.schema import ( GraphSchema, @@ -33,9 +34,99 @@ logger = logging.getLogger(__name__) +class PruningReason(str, enum.Enum): + NOT_IN_SCHEMA = "NOT_IN_SCHEMA" + MISSING_REQUIRED_PROPERTY = "MISSING_REQUIRED_PROPERTY" + NO_PROPERTY_LEFT = "NO_PROPERTY_LEFT" + INVALID_START_OR_END_NODE = "INVALID_START_OR_END_NODE" + INVALID_PATTERN = "INVALID_PATTERN" + + +ItemType = TypeVar("ItemType") + + +class PrunedItem(BaseModel, Generic[ItemType]): + label: str + item: ItemType + pruned_reason: PruningReason + metadata: dict[str, Any] = {} + + +class PruningStats(BaseModel): + pruned_nodes: list[PrunedItem[Neo4jNode]] = [] + pruned_relationships: list[PrunedItem[Neo4jRelationship]] = [] + pruned_properties: list[PrunedItem[str]] = [] + + @property + def number_of_pruned_nodes(self) -> int: + return len(self.pruned_nodes) + + @property + def number_of_pruned_relationships(self) -> int: + return len(self.pruned_relationships) + + @property + def number_of_pruned_properties(self) -> int: + return len(self.pruned_properties) + + def __str__(self): + return ( + f"PruningStats: nodes: {self.number_of_pruned_nodes}, " + f"relationships: {self.number_of_pruned_relationships}, " + f"properties: {self.number_of_pruned_properties}" + ) + + def add_pruned_node( + self, node: Neo4jNode, reason: PruningReason, **kwargs: Any + ) -> None: + self.pruned_nodes.append( + PrunedItem( + label=node.label, item=node, pruned_reason=reason, metadata=kwargs + ) + ) + + def add_pruned_relationship( + self, relationship: Neo4jRelationship, reason: PruningReason, **kwargs: Any + ) -> None: + self.pruned_relationships.append( + PrunedItem( + label=relationship.type, + item=relationship, + pruned_reason=reason, + metadata=kwargs, + ) + ) + + def add_pruned_property( + self, prop: str, label: str, reason: PruningReason, **kwargs: Any + ) -> None: + self.pruned_properties.append( + PrunedItem(label=label, item=prop, pruned_reason=reason, metadata=kwargs) + ) + + def add_pruned_item( + self, + item: Union[Neo4jNode, Neo4jRelationship], + reason: PruningReason, + **kwargs: Any, + ) -> None: + if isinstance(item, Neo4jNode): + self.add_pruned_node( + item, + reason=reason, + **kwargs, + ) + else: + self.add_pruned_relationship( + item, + reason=reason, + **kwargs, + ) + + class GraphPruningResult(DataModel): graph: Neo4jGraph - metadata: dict[str, Any] = {} + pruning_stats: PruningStats class GraphPruning(Component): @@ -46,25 +137,20 @@ async def run( schema: Optional[GraphSchema] = None, ) -> GraphPruningResult: if schema is not None: - new_graph = self._clean_graph(graph, schema) + new_graph, pruning_stats = self._clean_graph(graph, schema) else: new_graph = graph + pruning_stats = PruningStats() return GraphPruningResult( graph=new_graph, - metadata={ - "stats": { - "pruned_node_count": len(graph.nodes) - len(new_graph.nodes), - "pruned_relationship_count": len(graph.relationships) - - len(new_graph.relationships), - } - }, + pruning_stats=pruning_stats, ) def _clean_graph( self, graph: Neo4jGraph, schema: GraphSchema, - ) -> Neo4jGraph: + ) -> tuple[Neo4jGraph, PruningStats]: """ Verify that the graph conforms to the provided schema. @@ -72,22 +158,34 @@ def _clean_graph( If an entity is removed, all of its relationships are also removed. If no valid properties remain for an entity, remove that entity. """ - filtered_nodes = self._enforce_nodes(graph.nodes, schema) + pruning_stats = PruningStats() + filtered_nodes = self._enforce_nodes( + graph.nodes, + schema, + pruning_stats, + ) if not filtered_nodes: logger.warning( "PRUNING: all nodes were pruned, resulting graph will be empty. Check logs for details." ) - return Neo4jGraph() + return Neo4jGraph(), pruning_stats filtered_rels = self._enforce_relationships( - graph.relationships, filtered_nodes, schema + graph.relationships, + filtered_nodes, + schema, + pruning_stats, ) - return Neo4jGraph(nodes=filtered_nodes, relationships=filtered_rels) + return ( + Neo4jGraph(nodes=filtered_nodes, relationships=filtered_rels), + pruning_stats, + ) def _validate_node( self, node: Neo4jNode, + pruning_stats: PruningStats, schema_entity: Optional[NodeType] = None, additional_node_types: bool = True, ) -> Optional[Neo4jNode]: @@ -97,10 +195,13 @@ def _validate_node( # keep node as it is as we do not have any additional info return node # it's not in schema + pruning_stats.add_pruned_node(node, reason=PruningReason.NOT_IN_SCHEMA) return None - allowed_props = schema_entity.properties filtered_props = self._enforce_properties( - node.properties, allowed_props, schema_entity.additional_properties + node, + schema_entity, + pruning_stats, + prune_empty=True, ) if not filtered_props: return None @@ -112,7 +213,10 @@ def _validate_node( ) def _enforce_nodes( - self, extracted_nodes: list[Neo4jNode], schema: GraphSchema + self, + extracted_nodes: list[Neo4jNode], + schema: GraphSchema, + pruning_stats: PruningStats, ) -> list[Neo4jNode]: """ Filter extracted nodes to be conformant to the schema. @@ -127,6 +231,7 @@ def _enforce_nodes( schema_entity = schema.node_type_from_label(node.label) new_node = self._validate_node( node, + pruning_stats, schema_entity, additional_node_types=schema.additional_node_types, ) @@ -138,6 +243,7 @@ def _validate_relationship( self, rel: Neo4jRelationship, valid_nodes: dict[str, str], + pruning_stats: PruningStats, relationship_type: Optional[RelationshipType], additional_relationship_types: bool, patterns: tuple[tuple[str, str, str], ...], @@ -148,6 +254,9 @@ def _validate_relationship( logger.debug( f"PRUNING:: {rel} as one of {rel.start_node_id} or {rel.end_node_id} is not a valid node" ) + pruning_stats.add_pruned_relationship( + rel, reason=PruningReason.INVALID_START_OR_END_NODE + ) return None # validate relationship type @@ -156,6 +265,9 @@ def _validate_relationship( logger.debug( f"PRUNING:: {rel} as {rel.type} is not in the schema and `additional_relationship_types` is False" ) + pruning_stats.add_pruned_relationship( + rel, reason=PruningReason.NOT_IN_SCHEMA + ) return None # validate pattern @@ -178,13 +290,18 @@ def _validate_relationship( if not tuple_valid and not reverse_tuple_valid and not additional_patterns: logger.debug(f"PRUNING:: {rel} not in the allowed patterns") + pruning_stats.add_pruned_relationship( + rel, reason=PruningReason.INVALID_PATTERN + ) return None # filter properties if we can if relationship_type is not None: - allowed_props = relationship_type.properties filtered_props = self._enforce_properties( - rel.properties, allowed_props, relationship_type.additional_properties + rel, + relationship_type, + pruning_stats, + prune_empty=False, ) else: filtered_props = rel.properties @@ -202,6 +319,7 @@ def _enforce_relationships( extracted_relationships: list[Neo4jRelationship], filtered_nodes: list[Neo4jNode], schema: GraphSchema, + pruning_stats: PruningStats, ) -> list[Neo4jRelationship]: """ Filter extracted nodes to be conformant to the schema. @@ -220,6 +338,7 @@ def _enforce_relationships( new_rel = self._validate_relationship( rel, valid_nodes, + pruning_stats, schema_relation, schema.additional_relationship_types, schema.patterns, @@ -230,27 +349,73 @@ def _enforce_relationships( return valid_rels def _enforce_properties( + self, + item: Union[Neo4jNode, Neo4jRelationship], + schema_item: Union[NodeType, RelationshipType], + pruning_stats: PruningStats, + prune_empty: bool = False, + ) -> dict[str, Any]: + """ + Enforce properties: + - Filter out those that are not in schema (i.e., valid properties) if allowed properties is False. + - Check that all required properties are present and not null. + """ + filtered_properties = self._filter_properties( + item.properties, + schema_item.properties, + schema_item.additional_properties, + item.token, # label or type + pruning_stats, + ) + if not filtered_properties and prune_empty: + pruning_stats.add_pruned_item(item, reason=PruningReason.NO_PROPERTY_LEFT) + return filtered_properties + missing_required_properties = self._check_required_properties( + filtered_properties, + valid_properties=schema_item.properties, + ) + if missing_required_properties: + pruning_stats.add_pruned_item( + item, + reason=PruningReason.MISSING_REQUIRED_PROPERTY, + missing_required_properties=missing_required_properties, + ) + return {} + return filtered_properties + + def _filter_properties( self, properties: dict[str, Any], valid_properties: list[PropertyType], additional_properties: bool, + node_label: str, + pruning_stats: PruningStats, ) -> dict[str, Any]: - """ - Filter properties. - - Keep only those that exist in schema (i.e., valid properties). - - Check that all required properties are present - """ + """Filters out properties not in schema if additional_properties is False""" + if additional_properties: + # we do not need to filter any property, just return the initial properties + return properties valid_prop_names = {prop.name for prop in valid_properties} - filtered_properties = { - key: value - for key, value in properties.items() - if key in valid_prop_names or additional_properties - } + filtered_properties = {} + for prop_name, prop_value in properties.items(): + if prop_name not in valid_prop_names: + pruning_stats.add_pruned_property( + prop_name, + node_label, + reason=PruningReason.NOT_IN_SCHEMA, + value=prop_value, + ) + continue + filtered_properties[prop_name] = prop_value + return filtered_properties + + def _check_required_properties( + self, filtered_properties: dict[str, Any], valid_properties: list[PropertyType] + ) -> list[str]: + """Returns the list of missing required properties, if any.""" required_prop_names = {prop.name for prop in valid_properties if prop.required} + missing_required_properties = [] for req_prop in required_prop_names: if filtered_properties.get(req_prop) is None: - logger.info( - f"PRUNING:: {req_prop} is required but missing in {properties} - skipping node" - ) - return {} # node will be pruned - return filtered_properties + missing_required_properties.append(req_prop) + return missing_required_properties diff --git a/src/neo4j_graphrag/experimental/components/types.py b/src/neo4j_graphrag/experimental/components/types.py index 689e6b6ce..4d271c242 100644 --- a/src/neo4j_graphrag/experimental/components/types.py +++ b/src/neo4j_graphrag/experimental/components/types.py @@ -99,6 +99,10 @@ def check_for_id_properties( raise TypeError("'id' as a property name is not allowed") return v + @property + def token(self) -> str: + return self.label + class Neo4jRelationship(BaseModel): """Represents a Neo4j relationship. @@ -117,6 +121,10 @@ class Neo4jRelationship(BaseModel): properties: dict[str, Any] = {} embedding_properties: Optional[dict[str, list[float]]] = None + @property + def token(self) -> str: + return self.type + class Neo4jGraph(DataModel): """Represents a Neo4j graph. diff --git a/tests/unit/experimental/components/test_graph_pruning.py b/tests/unit/experimental/components/test_graph_pruning.py index a6eedb4d5..970e329d4 100644 --- a/tests/unit/experimental/components/test_graph_pruning.py +++ b/tests/unit/experimental/components/test_graph_pruning.py @@ -14,13 +14,14 @@ # limitations under the License. from __future__ import annotations from typing import Any -from unittest.mock import patch, Mock +from unittest.mock import patch, Mock, ANY import pytest from neo4j_graphrag.experimental.components.graph_pruning import ( GraphPruning, GraphPruningResult, + PruningStats, ) from neo4j_graphrag.experimental.components.schema import ( NodeType, @@ -73,32 +74,21 @@ "name": "John Does", }, ), - ( - # required missing - { - "age": 25, - }, - [ - PropertyType( - name="name", - type="STRING", - required=True, - ) - ], - True, - {}, - ), ], ) -def test_graph_pruning_enforce_properties( +def test_graph_pruning_filter_properties( properties: dict[str, Any], valid_properties: list[PropertyType], additional_properties: bool, expected_filtered_properties: dict[str, Any], ) -> None: prunner = GraphPruning() - filtered_properties = prunner._enforce_properties( - properties, valid_properties, additional_properties=additional_properties + filtered_properties = prunner._filter_properties( + properties, + valid_properties, + additional_properties=additional_properties, + node_label="Label", + pruning_stats=PruningStats(), ) assert filtered_properties == expected_filtered_properties @@ -162,7 +152,7 @@ def test_graph_pruning_validate_node( e = request.getfixturevalue(entity) if entity else None prunner = GraphPruning() - result = prunner._validate_node(node, e, additional_node_types) + result = prunner._validate_node(node, PruningStats(), e, additional_node_types) if expected_node is not None: assert result == expected_node else: @@ -331,6 +321,7 @@ def test_graph_pruning_validate_relationship( pruner._validate_relationship( relationship_obj, valid_nodes, + PruningStats(), relationship_type, additional_relationship_types, patterns, @@ -351,7 +342,7 @@ async def test_graph_pruning_run_happy_path( ) schema = GraphSchema(node_types=(node_type_required_name,)) cleaned_graph = Neo4jGraph(nodes=[Neo4jNode(id="1", label="Person")]) - mock_clean_graph.return_value = cleaned_graph + mock_clean_graph.return_value = (cleaned_graph, PruningStats()) pruner = GraphPruning() pruner_result = await pruner.run( graph=initial_graph, @@ -384,9 +375,11 @@ def test_graph_pruning_clean_graph( initial_graph = Neo4jGraph(nodes=[Neo4jNode(id="1", label="Person")]) schema = GraphSchema(node_types=()) pruner = GraphPruning() - cleaned_graph = pruner._clean_graph(initial_graph, schema) + cleaned_graph, pruning_stats = pruner._clean_graph(initial_graph, schema) assert cleaned_graph == Neo4jGraph() + assert isinstance(pruning_stats, PruningStats) mock_enforce_nodes.assert_called_once_with( [Neo4jNode(id="1", label="Person")], schema, + ANY, ) From dd52c44ae081b8d646bf53f52483a7f3c56baad6 Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 10 Jun 2025 14:53:30 +0200 Subject: [PATCH 22/25] Typo --- tests/unit/experimental/components/test_graph_pruning.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/unit/experimental/components/test_graph_pruning.py b/tests/unit/experimental/components/test_graph_pruning.py index 970e329d4..8c2d490d1 100644 --- a/tests/unit/experimental/components/test_graph_pruning.py +++ b/tests/unit/experimental/components/test_graph_pruning.py @@ -82,8 +82,8 @@ def test_graph_pruning_filter_properties( additional_properties: bool, expected_filtered_properties: dict[str, Any], ) -> None: - prunner = GraphPruning() - filtered_properties = prunner._filter_properties( + pruner = GraphPruning() + filtered_properties = pruner._filter_properties( properties, valid_properties, additional_properties=additional_properties, @@ -151,8 +151,8 @@ def test_graph_pruning_validate_node( ) -> None: e = request.getfixturevalue(entity) if entity else None - prunner = GraphPruning() - result = prunner._validate_node(node, PruningStats(), e, additional_node_types) + pruner = GraphPruning() + result = pruner._validate_node(node, PruningStats(), e, additional_node_types) if expected_node is not None: assert result == expected_node else: From 50da0e437310038ad5a086edd3cf5e6625317b3d Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 10 Jun 2025 15:49:51 +0200 Subject: [PATCH 23/25] Remove default value for consistency --- src/neo4j_graphrag/experimental/components/graph_pruning.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/graph_pruning.py b/src/neo4j_graphrag/experimental/components/graph_pruning.py index 77fdb907e..9e01f8df3 100644 --- a/src/neo4j_graphrag/experimental/components/graph_pruning.py +++ b/src/neo4j_graphrag/experimental/components/graph_pruning.py @@ -186,8 +186,8 @@ def _validate_node( self, node: Neo4jNode, pruning_stats: PruningStats, - schema_entity: Optional[NodeType] = None, - additional_node_types: bool = True, + schema_entity: Optional[NodeType], + additional_node_types: bool, ) -> Optional[Neo4jNode]: if not schema_entity: # node type not declared in the schema From ea1e29e77e3ae55f99bf1729b396d03bbbc15c43 Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 10 Jun 2025 17:40:32 +0200 Subject: [PATCH 24/25] Add a section to the doc --- docs/source/user_guide_kg_builder.rst | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/docs/source/user_guide_kg_builder.rst b/docs/source/user_guide_kg_builder.rst index 76dd2bac5..186a82099 100644 --- a/docs/source/user_guide_kg_builder.rst +++ b/docs/source/user_guide_kg_builder.rst @@ -1031,6 +1031,18 @@ Configuration Options - If set to ``False``, only patterns defined in the schema are kept. **Note** `additional_relationship_types` must also be `False`. + +Enforcement rules +_________________ + +In addition to the user-defined configuration options described above, +the `GraphPruning` component performs the following cleanup operations: + +- Nodes with missing required properties are pruned. +- Nodes with no remaining properties are pruned. +- Relationships with invalid source or target nodes (i.e., nodes no longer present in the graph) are pruned. +- Relationships with incorrect direction have their direction corrected. + .. _kg-writer-section: Knowledge Graph Writer From bc501d943ed5f20d4e9626f5171b22fffe15b6fc Mon Sep 17 00:00:00 2001 From: estelle Date: Wed, 11 Jun 2025 10:19:08 +0200 Subject: [PATCH 25/25] Mypy checks --- src/neo4j_graphrag/experimental/components/graph_pruning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/neo4j_graphrag/experimental/components/graph_pruning.py b/src/neo4j_graphrag/experimental/components/graph_pruning.py index 9e01f8df3..8c61810ec 100644 --- a/src/neo4j_graphrag/experimental/components/graph_pruning.py +++ b/src/neo4j_graphrag/experimental/components/graph_pruning.py @@ -69,7 +69,7 @@ def number_of_pruned_relationships(self) -> int: def number_of_pruned_properties(self) -> int: return len(self.pruned_properties) - def __str__(self): + def __str__(self) -> str: return ( f"PruningStats: nodes: {self.number_of_pruned_nodes}, " f"relationships: {self.number_of_pruned_relationships}, "