diff --git a/CHANGELOG.md b/CHANGELOG.md index a8bad9486..f4bd19157 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,7 +15,7 @@ #### Strict mode -- Strict mode in `SimpleKGPipeline`: now properties and relationships are pruned only if they are defined in the input schema. +- Strict mode in `SimpleKGPipeline`: the `enforce_schema` option is removed and replaced by a schema-driven pruning. #### Schema definition diff --git a/docs/source/user_guide_kg_builder.rst b/docs/source/user_guide_kg_builder.rst index 6af5e8cb3..186a82099 100644 --- a/docs/source/user_guide_kg_builder.rst +++ b/docs/source/user_guide_kg_builder.rst @@ -73,10 +73,10 @@ Customizing the SimpleKGPipeline Graph Schema ------------ -It is possible to guide the LLM by supplying a list of node and relationship types, -and instructions on how to connect them (patterns). However, note that the extracted graph -may not fully adhere to these guidelines unless schema enforcement is enabled -(see :ref:`Schema Enforcement Behaviour`). Node and relationship types can be represented +It is possible to guide the LLM by supplying a list of node and relationship types ( +with, optionally, a list of their expected properties) +and instructions on how to connect them (patterns). +Node and relationship types can be represented as either simple strings (for their labels) or dictionaries. If using a dictionary, it must include a label key and can optionally include description and properties keys, as shown below: @@ -90,7 +90,7 @@ as shown below: # such as a description: {"label": "House", "description": "Family the person belongs to"}, # or a list of properties the LLM will try to attach to the entity: - {"label": "Planet", "properties": [{"name": "weather", "type": "STRING"}]}, + {"label": "Planet", "properties": [{"name": "name", "type": "STRING", "required": True}, {"name": "weather", "type": "STRING"}]}, ] # same thing for relationships: RELATIONSHIP_TYPES = [ @@ -124,7 +124,8 @@ This schema information can be provided to the `SimpleKGBuilder` as demonstrated schema={ "node_types": NODE_TYPES, "relationship_types": RELATIONSHIP_TYPES, - "patterns": PATTERNS + "patterns": PATTERNS, + "additional_node_types": False, }, # ... ) @@ -145,7 +146,6 @@ They are also accessible via the `SimpleKGPipeline` interface. # ... prompt_template="", lexical_graph_config=my_config, - enforce_schema="STRICT" on_error="RAISE", # ... ) @@ -878,38 +878,6 @@ It can be used in this way: The LLM to use can be customized, the only constraint is that it obeys the :ref:`LLMInterface `. -Schema Enforcement Behaviour ----------------------------- -.. _schema-enforcement-behaviour: - -By default, even if a schema is provided to guide the LLM in the entity and relation extraction, the LLM response is not validated against that schema. -This behaviour can be changed by using the `enforce_schema` flag in the `LLMEntityRelationExtractor` constructor: - -.. code:: python - - from neo4j_graphrag.experimental.components.entity_relation_extractor import LLMEntityRelationExtractor - from neo4j_graphrag.experimental.components.types import SchemaEnforcementMode - - extractor = LLMEntityRelationExtractor( - # ... - enforce_schema=SchemaEnforcementMode.STRICT, - ) - -In this scenario, any extracted node/relation/property that is not part of the provided schema will be pruned. -Any relation whose start node or end node does not conform to the provided tuple in `potential_schema` will be pruned. -If a relation start/end nodes are valid but the direction is incorrect, the latter will be inverted. -If a node is left with no properties, it will be also pruned. - -.. note:: - - If the input schema lacks a certain type of information, pruning is skipped. - For example, if an entity is defined only by a label and has no properties, - property pruning is not performed and all properties returned by the LLM are kept. - - -.. warning:: - - Note that if the schema enforcement mode is on but the schema is not provided, no schema enforcement will be applied. Error Behaviour --------------- @@ -1017,6 +985,64 @@ If more customization is needed, it is possible to subclass the `EntityRelationE See :ref:`entityrelationextractor`. +Schema Guidance and Graph Filtering +=================================== + +The provided schema serves as a guiding structure for the language model during graph construction. However, it does not impose strict constraints on the model's output. As a result, the model may generate additional node labels, relationship types, or properties that are not explicitly defined in the schema. + +By default, all extracted elements — including nodes, relationships, and properties — are retained in the constructed graph. This behavior can be configured using the following schema options: +(see :ref:`graphschema`) + + +Configuration Options +--------------------- + +- **Required Properties** + Required properties may be specified at the node or relationship type level. Any extracted node or relationship missing one or more of its required properties will be pruned from the graph. + +- **Additional Properties** *(default: True)* + This node- or relationship-level option determines whether extra properties not listed in the schema should be retained. + + - If set to ``True`` (default), all extracted properties are retained. + - If set to ``False``, only the properties defined in the schema are preserved; all others are removed. + + +.. note:: Node pruning + + If, after property pruning using the above rule, a node is left without any property, it is removed from the graph. + + +- **Additional Node Types** *(default: True)* + This schema-level option specifies whether node types not defined in the schema are included in the graph. + + - If set to ``True`` (default), such node types are retained. + - If set to ``False``, nodes with undefined types are removed. + +- **Additional Relationship Types** *(default: True)* + This schema-level option specifies whether relationship types not defined in the schema are included in the graph. + + - If set to ``True`` (default), such relationships are retained. + - If set to ``False``, relationships with undefined types are removed. + +- **Additional Patterns** *(default: True)* + This schema-level option determines whether relationship patterns not explicitly listed in the schema are allowed. + + - If set to ``True`` (default), all patterns are retained. + - If set to ``False``, only patterns defined in the schema are kept. **Note** `additional_relationship_types` must also be `False`. + + + +Enforcement rules +_________________ + +In addition to the user-defined configuration options described above, +the `GraphPruning` component performs the following cleanup operations: + +- Nodes with missing required properties are pruned. +- Nodes with no remaining properties are pruned. +- Relationships with invalid source or target nodes (i.e., nodes no longer present in the graph) are pruned. +- Relationships with incorrect direction have their direction corrected. + .. _kg-writer-section: Knowledge Graph Writer diff --git a/examples/README.md b/examples/README.md index fa8bb945e..774739b32 100644 --- a/examples/README.md +++ b/examples/README.md @@ -128,6 +128,7 @@ are listed in [the last section of this file](#customize). - [LLM-based](./customize/build_graph/components/extractors/llm_entity_relation_extractor.py) - [LLM-based with custom prompt](./customize/build_graph/components/extractors/llm_entity_relation_extractor_with_custom_prompt.py) - [Custom](./customize/build_graph/components/extractors/custom_extractor.py) +- [Graph Pruner](./customize/build_graph/components/pruners/graph_pruner.py) - Knowledge Graph Writer: - [Neo4j writer](./customize/build_graph/components/writers/neo4j_writer.py) - [Custom](./customize/build_graph/components/writers/custom_writer.py) diff --git a/examples/customize/build_graph/components/pruners/graph_pruner.py b/examples/customize/build_graph/components/pruners/graph_pruner.py new file mode 100644 index 000000000..adf8694a1 --- /dev/null +++ b/examples/customize/build_graph/components/pruners/graph_pruner.py @@ -0,0 +1,136 @@ +"""This example demonstrates how to use the GraphPruner component.""" + +import asyncio + +from neo4j_graphrag.experimental.components.graph_pruning import GraphPruning +from neo4j_graphrag.experimental.components.schema import ( + GraphSchema, + NodeType, + PropertyType, + RelationshipType, +) +from neo4j_graphrag.experimental.components.types import ( + Neo4jGraph, + Neo4jNode, + Neo4jRelationship, +) + +graph = Neo4jGraph( + nodes=[ + Neo4jNode( + id="Person/John", + label="Person", + properties={ + "firstName": "John", + "lastName": "Doe", + "occupation": "employee", + }, + ), + Neo4jNode( + id="Person/Jane", + label="Person", + properties={ + "firstName": "Jane", + }, + ), + Neo4jNode( + id="Person/Jack", + label="Person", + properties={"firstName": "Jack", "lastName": "Dae"}, + ), + Neo4jNode( + id="Organization/Corp1", + label="Organization", + properties={"name": "CorpA"}, + ), + ], + relationships=[ + Neo4jRelationship( + start_node_id="Person/John", + end_node_id="Person/Jack", + type="KNOWS", + ), + Neo4jRelationship( + start_node_id="Organization/CorpA", + end_node_id="Person/Jack", + type="WORKS_FOR", + ), + Neo4jRelationship( + start_node_id="Person/John", + end_node_id="Person/Jack", + type="PARENT_OF", + ), + ], +) + +schema = GraphSchema( + node_types=( + NodeType( + label="Person", + properties=[ + PropertyType(name="firstName", type="STRING", required=True), + PropertyType(name="lastName", type="STRING", required=True), + PropertyType(name="age", type="INTEGER"), + ], + additional_properties=False, + ), + NodeType( + label="Organization", + properties=[ + PropertyType(name="name", type="STRING", required=True), + PropertyType(name="address", type="STRING"), + ], + ), + ), + relationship_types=( + RelationshipType( + label="WORKS_FOR", + properties=[PropertyType(name="since", type="LOCAL_DATETIME")], + ), + RelationshipType( + label="KNOWS", + ), + ), + patterns=( + ("Person", "KNOWS", "Person"), + ("Person", "WORKS_FOR", "Organization"), + ), + additional_node_types=False, + additional_relationship_types=False, + additional_patterns=False, +) + + +async def main() -> None: + pruner = GraphPruning() + res = await pruner.run(graph, schema) + print("=" * 20, "FINAL CLEANED GRAPH:", "=" * 20) + print(res.graph) + print("=" * 20, "PRUNED ITEM:", "=" * 20) + print(res.pruning_stats) + print("-" * 10, "PRUNED NODES:") + for node in res.pruning_stats.pruned_nodes: + print( + node.item.label, + "with properties", + node.item.properties, + "pruned because", + node.pruned_reason, + node.metadata, + ) + print("-" * 10, "PRUNED RELATIONSHIPS:") + for rel in res.pruning_stats.pruned_relationships: + print(rel.item.type, "pruned because", rel.pruned_reason) + print("-" * 10, "PRUNED PROPERTIES:") + for prop in res.pruning_stats.pruned_properties: + print( + prop.item, + "from node label", + prop.label, + "pruned because", + prop.pruned_reason, + ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py index 32d65276e..b1ed9bbb1 100644 --- a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py +++ b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py @@ -18,23 +18,20 @@ import enum import json import logging -from typing import Any, List, Optional, Union, cast, Dict +from typing import Any, List, Optional, Union, cast import json_repair from pydantic import ValidationError, validate_call from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.experimental.components.lexical_graph import LexicalGraphBuilder -from neo4j_graphrag.experimental.components.schema import GraphSchema, PropertyType +from neo4j_graphrag.experimental.components.schema import GraphSchema from neo4j_graphrag.experimental.components.types import ( DocumentInfo, LexicalGraphConfig, Neo4jGraph, - Neo4jNode, - Neo4jRelationship, TextChunk, TextChunks, - SchemaEnforcementMode, ) from neo4j_graphrag.experimental.pipeline.component import Component from neo4j_graphrag.experimental.pipeline.exceptions import InvalidJSONError @@ -169,7 +166,6 @@ class LLMEntityRelationExtractor(EntityRelationExtractor): llm (LLMInterface): The language model to use for extraction. prompt_template (ERExtractionTemplate | str): A custom prompt template to use for extraction. create_lexical_graph (bool): Whether to include the text chunks in the graph in addition to the extracted entities and relations. Defaults to True. - enforce_schema (SchemaEnforcementMode): Whether to validate or not the extracted entities/rels against the provided schema. Defaults to None. on_error (OnError): What to do when an error occurs during extraction. Defaults to raising an error. max_concurrency (int): The maximum number of concurrent tasks which can be used to make requests to the LLM. @@ -194,13 +190,11 @@ def __init__( llm: LLMInterface, prompt_template: ERExtractionTemplate | str = ERExtractionTemplate(), create_lexical_graph: bool = True, - enforce_schema: SchemaEnforcementMode = SchemaEnforcementMode.NONE, on_error: OnError = OnError.RAISE, max_concurrency: int = 5, ) -> None: super().__init__(on_error=on_error, create_lexical_graph=create_lexical_graph) self.llm = llm # with response_format={ "type": "json_object" }, - self.enforce_schema = enforce_schema self.max_concurrency = max_concurrency if isinstance(prompt_template, str): template = PromptTemplate(prompt_template, expected_inputs=[]) @@ -284,13 +278,13 @@ async def run_for_chunk( """Run extraction, validation and post processing for a single chunk""" async with sem: chunk_graph = await self.extract_for_chunk(schema, examples, chunk) - final_chunk_graph = self.validate_chunk(chunk_graph, schema) + # final_chunk_graph = self.validate_chunk(chunk_graph, schema) await self.post_process_chunk( - final_chunk_graph, + chunk_graph, chunk, lexical_graph_builder, ) - return final_chunk_graph + return chunk_graph @validate_call async def run( @@ -328,7 +322,7 @@ async def run( elif lexical_graph_config: lexical_graph_builder = LexicalGraphBuilder(config=lexical_graph_config) schema = schema or GraphSchema( - node_types=(), relationship_types=None, patterns=None + node_types=(), ) examples = examples or "" sem = asyncio.Semaphore(self.max_concurrency) @@ -346,169 +340,3 @@ async def run( graph = self.combine_chunk_graphs(lexical_graph, chunk_graphs) logger.debug(f"Extracted graph: {prettify(graph)}") return graph - - def validate_chunk( - self, chunk_graph: Neo4jGraph, schema: GraphSchema - ) -> Neo4jGraph: - """ - Perform validation after entity and relation extraction: - - Enforce schema if schema enforcement mode is on and schema is provided - """ - if self.enforce_schema != SchemaEnforcementMode.NONE: - if not schema or not schema.node_types: # schema is not provided - logger.warning( - "Schema enforcement is ON but the guiding schema is not provided." - ) - else: - # if enforcing_schema is on and schema is provided, clean the graph - return self._clean_graph(chunk_graph, schema) - return chunk_graph - - def _clean_graph( - self, - graph: Neo4jGraph, - schema: GraphSchema, - ) -> Neo4jGraph: - """ - Verify that the graph conforms to the provided schema. - - Remove invalid entities,relationships, and properties. - If an entity is removed, all of its relationships are also removed. - If no valid properties remain for an entity, remove that entity. - """ - # enforce nodes (remove invalid labels, strip invalid properties) - filtered_nodes = self._enforce_nodes(graph.nodes, schema) - - # enforce relationships (remove those referencing invalid nodes or with invalid - # types or with start/end nodes not conforming to the schema, and strip invalid - # properties) - filtered_rels = self._enforce_relationships( - graph.relationships, filtered_nodes, schema - ) - - return Neo4jGraph(nodes=filtered_nodes, relationships=filtered_rels) - - def _enforce_nodes( - self, extracted_nodes: List[Neo4jNode], schema: GraphSchema - ) -> List[Neo4jNode]: - """ - Filter extracted nodes to be conformant to the schema. - - Keep only those whose label is in schema. - For each valid node, filter out properties not present in the schema. - Remove a node if it ends up with no valid properties. - """ - if self.enforce_schema != SchemaEnforcementMode.STRICT: - return extracted_nodes - - valid_nodes = [] - - for node in extracted_nodes: - schema_entity = schema.node_type_from_label(node.label) - if not schema_entity: - continue - allowed_props = schema_entity.properties or [] - if allowed_props: - filtered_props = self._enforce_properties( - node.properties, allowed_props - ) - else: - filtered_props = node.properties - if filtered_props: - valid_nodes.append( - Neo4jNode( - id=node.id, - label=node.label, - properties=filtered_props, - embedding_properties=node.embedding_properties, - ) - ) - - return valid_nodes - - def _enforce_relationships( - self, - extracted_relationships: List[Neo4jRelationship], - filtered_nodes: List[Neo4jNode], - schema: GraphSchema, - ) -> List[Neo4jRelationship]: - """ - Filter extracted nodes to be conformant to the schema. - - Keep only those whose types are in schema, start/end node conform to schema, - and start/end nodes are in filtered nodes (i.e., kept after node enforcement). - For each valid relationship, filter out properties not present in the schema. - If a relationship direct is incorrect, invert it. - """ - if self.enforce_schema != SchemaEnforcementMode.STRICT: - return extracted_relationships - - if schema.relationship_types is None: - return extracted_relationships - - valid_rels = [] - - valid_nodes = {node.id: node.label for node in filtered_nodes} - - patterns = schema.patterns - - for rel in extracted_relationships: - schema_relation = schema.relationship_type_from_label(rel.type) - if not schema_relation: - logger.debug(f"PRUNING:: {rel} as {rel.type} is not in the schema") - continue - - if ( - rel.start_node_id not in valid_nodes - or rel.end_node_id not in valid_nodes - ): - logger.debug( - f"PRUNING:: {rel} as one of {rel.start_node_id} or {rel.end_node_id} is not in the graph" - ) - continue - - start_label = valid_nodes[rel.start_node_id] - end_label = valid_nodes[rel.end_node_id] - - tuple_valid = True - if patterns: - tuple_valid = (start_label, rel.type, end_label) in patterns - reverse_tuple_valid = ( - end_label, - rel.type, - start_label, - ) in patterns - - if not tuple_valid and not reverse_tuple_valid: - logger.debug(f"PRUNING:: {rel} not in the potential schema") - continue - - allowed_props = schema_relation.properties or [] - if allowed_props: - filtered_props = self._enforce_properties(rel.properties, allowed_props) - else: - filtered_props = rel.properties - - valid_rels.append( - Neo4jRelationship( - start_node_id=rel.start_node_id if tuple_valid else rel.end_node_id, - end_node_id=rel.end_node_id if tuple_valid else rel.start_node_id, - type=rel.type, - properties=filtered_props, - embedding_properties=rel.embedding_properties, - ) - ) - - return valid_rels - - def _enforce_properties( - self, properties: Dict[str, Any], valid_properties: List[PropertyType] - ) -> Dict[str, Any]: - """ - Filter properties. - Keep only those that exist in schema (i.e., valid properties). - """ - valid_prop_names = {prop.name for prop in valid_properties} - return { - key: value for key, value in properties.items() if key in valid_prop_names - } diff --git a/src/neo4j_graphrag/experimental/components/graph_pruning.py b/src/neo4j_graphrag/experimental/components/graph_pruning.py new file mode 100644 index 000000000..8c61810ec --- /dev/null +++ b/src/neo4j_graphrag/experimental/components/graph_pruning.py @@ -0,0 +1,421 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import enum +import logging +from typing import Optional, Any, TypeVar, Generic, Union + +from pydantic import validate_call, BaseModel + +from neo4j_graphrag.experimental.components.schema import ( + GraphSchema, + PropertyType, + NodeType, + RelationshipType, +) +from neo4j_graphrag.experimental.components.types import ( + Neo4jGraph, + Neo4jNode, + Neo4jRelationship, +) +from neo4j_graphrag.experimental.pipeline import Component, DataModel + +logger = logging.getLogger(__name__) + + +class PruningReason(str, enum.Enum): + NOT_IN_SCHEMA = "NOT_IN_SCHEMA" + MISSING_REQUIRED_PROPERTY = "MISSING_REQUIRED_PROPERTY" + NO_PROPERTY_LEFT = "NO_PROPERTY_LEFT" + INVALID_START_OR_END_NODE = "INVALID_START_OR_END_NODE" + INVALID_PATTERN = "INVALID_PATTERN" + + +ItemType = TypeVar("ItemType") + + +class PrunedItem(BaseModel, Generic[ItemType]): + label: str + item: ItemType + pruned_reason: PruningReason + metadata: dict[str, Any] = {} + + +class PruningStats(BaseModel): + pruned_nodes: list[PrunedItem[Neo4jNode]] = [] + pruned_relationships: list[PrunedItem[Neo4jRelationship]] = [] + pruned_properties: list[PrunedItem[str]] = [] + + @property + def number_of_pruned_nodes(self) -> int: + return len(self.pruned_nodes) + + @property + def number_of_pruned_relationships(self) -> int: + return len(self.pruned_relationships) + + @property + def number_of_pruned_properties(self) -> int: + return len(self.pruned_properties) + + def __str__(self) -> str: + return ( + f"PruningStats: nodes: {self.number_of_pruned_nodes}, " + f"relationships: {self.number_of_pruned_relationships}, " + f"properties: {self.number_of_pruned_properties}" + ) + + def add_pruned_node( + self, node: Neo4jNode, reason: PruningReason, **kwargs: Any + ) -> None: + self.pruned_nodes.append( + PrunedItem( + label=node.label, item=node, pruned_reason=reason, metadata=kwargs + ) + ) + + def add_pruned_relationship( + self, relationship: Neo4jRelationship, reason: PruningReason, **kwargs: Any + ) -> None: + self.pruned_relationships.append( + PrunedItem( + label=relationship.type, + item=relationship, + pruned_reason=reason, + metadata=kwargs, + ) + ) + + def add_pruned_property( + self, prop: str, label: str, reason: PruningReason, **kwargs: Any + ) -> None: + self.pruned_properties.append( + PrunedItem(label=label, item=prop, pruned_reason=reason, metadata=kwargs) + ) + + def add_pruned_item( + self, + item: Union[Neo4jNode, Neo4jRelationship], + reason: PruningReason, + **kwargs: Any, + ) -> None: + if isinstance(item, Neo4jNode): + self.add_pruned_node( + item, + reason=reason, + **kwargs, + ) + else: + self.add_pruned_relationship( + item, + reason=reason, + **kwargs, + ) + + +class GraphPruningResult(DataModel): + graph: Neo4jGraph + pruning_stats: PruningStats + + +class GraphPruning(Component): + @validate_call + async def run( + self, + graph: Neo4jGraph, + schema: Optional[GraphSchema] = None, + ) -> GraphPruningResult: + if schema is not None: + new_graph, pruning_stats = self._clean_graph(graph, schema) + else: + new_graph = graph + pruning_stats = PruningStats() + return GraphPruningResult( + graph=new_graph, + pruning_stats=pruning_stats, + ) + + def _clean_graph( + self, + graph: Neo4jGraph, + schema: GraphSchema, + ) -> tuple[Neo4jGraph, PruningStats]: + """ + Verify that the graph conforms to the provided schema. + + Remove invalid entities,relationships, and properties. + If an entity is removed, all of its relationships are also removed. + If no valid properties remain for an entity, remove that entity. + """ + pruning_stats = PruningStats() + filtered_nodes = self._enforce_nodes( + graph.nodes, + schema, + pruning_stats, + ) + if not filtered_nodes: + logger.warning( + "PRUNING: all nodes were pruned, resulting graph will be empty. Check logs for details." + ) + return Neo4jGraph(), pruning_stats + + filtered_rels = self._enforce_relationships( + graph.relationships, + filtered_nodes, + schema, + pruning_stats, + ) + + return ( + Neo4jGraph(nodes=filtered_nodes, relationships=filtered_rels), + pruning_stats, + ) + + def _validate_node( + self, + node: Neo4jNode, + pruning_stats: PruningStats, + schema_entity: Optional[NodeType], + additional_node_types: bool, + ) -> Optional[Neo4jNode]: + if not schema_entity: + # node type not declared in the schema + if additional_node_types: + # keep node as it is as we do not have any additional info + return node + # it's not in schema + pruning_stats.add_pruned_node(node, reason=PruningReason.NOT_IN_SCHEMA) + return None + filtered_props = self._enforce_properties( + node, + schema_entity, + pruning_stats, + prune_empty=True, + ) + if not filtered_props: + return None + return Neo4jNode( + id=node.id, + label=node.label, + properties=filtered_props, + embedding_properties=node.embedding_properties, + ) + + def _enforce_nodes( + self, + extracted_nodes: list[Neo4jNode], + schema: GraphSchema, + pruning_stats: PruningStats, + ) -> list[Neo4jNode]: + """ + Filter extracted nodes to be conformant to the schema. + + Keep only those whose label is in schema + (unless schema has additional_node_types=True, default value) + For each valid node, validate properties. If a node is left without + properties, prune it. + """ + valid_nodes = [] + for node in extracted_nodes: + schema_entity = schema.node_type_from_label(node.label) + new_node = self._validate_node( + node, + pruning_stats, + schema_entity, + additional_node_types=schema.additional_node_types, + ) + if new_node: + valid_nodes.append(new_node) + return valid_nodes + + def _validate_relationship( + self, + rel: Neo4jRelationship, + valid_nodes: dict[str, str], + pruning_stats: PruningStats, + relationship_type: Optional[RelationshipType], + additional_relationship_types: bool, + patterns: tuple[tuple[str, str, str], ...], + additional_patterns: bool, + ) -> Optional[Neo4jRelationship]: + # validate start/end node IDs are valid nodes + if rel.start_node_id not in valid_nodes or rel.end_node_id not in valid_nodes: + logger.debug( + f"PRUNING:: {rel} as one of {rel.start_node_id} or {rel.end_node_id} is not a valid node" + ) + pruning_stats.add_pruned_relationship( + rel, reason=PruningReason.INVALID_START_OR_END_NODE + ) + return None + + # validate relationship type + if relationship_type is None: + if not additional_relationship_types: + logger.debug( + f"PRUNING:: {rel} as {rel.type} is not in the schema and `additional_relationship_types` is False" + ) + pruning_stats.add_pruned_relationship( + rel, reason=PruningReason.NOT_IN_SCHEMA + ) + return None + + # validate pattern + tuple_valid = True + reverse_tuple_valid = False + if patterns and relationship_type: + start_label = valid_nodes[rel.start_node_id] + end_label = valid_nodes[rel.end_node_id] + tuple_valid = (start_label, rel.type, end_label) in patterns + # try to reverse relationship only if initial order is not valid + reverse_tuple_valid = ( + not tuple_valid + and ( + end_label, + rel.type, + start_label, + ) + in patterns + ) + + if not tuple_valid and not reverse_tuple_valid and not additional_patterns: + logger.debug(f"PRUNING:: {rel} not in the allowed patterns") + pruning_stats.add_pruned_relationship( + rel, reason=PruningReason.INVALID_PATTERN + ) + return None + + # filter properties if we can + if relationship_type is not None: + filtered_props = self._enforce_properties( + rel, + relationship_type, + pruning_stats, + prune_empty=False, + ) + else: + filtered_props = rel.properties + + return Neo4jRelationship( + start_node_id=rel.end_node_id if reverse_tuple_valid else rel.start_node_id, + end_node_id=rel.start_node_id if reverse_tuple_valid else rel.end_node_id, + type=rel.type, + properties=filtered_props, + embedding_properties=rel.embedding_properties, + ) + + def _enforce_relationships( + self, + extracted_relationships: list[Neo4jRelationship], + filtered_nodes: list[Neo4jNode], + schema: GraphSchema, + pruning_stats: PruningStats, + ) -> list[Neo4jRelationship]: + """ + Filter extracted nodes to be conformant to the schema. + + Keep only those whose types are in schema, start/end node conform to schema, + and start/end nodes are in filtered nodes (i.e., kept after node enforcement). + For each valid relationship, filter out properties not present in the schema. + + If a relationship direction is incorrect, invert it. + """ + + valid_rels = [] + valid_nodes = {node.id: node.label for node in filtered_nodes} + for rel in extracted_relationships: + schema_relation = schema.relationship_type_from_label(rel.type) + new_rel = self._validate_relationship( + rel, + valid_nodes, + pruning_stats, + schema_relation, + schema.additional_relationship_types, + schema.patterns, + schema.additional_patterns, + ) + if new_rel: + valid_rels.append(new_rel) + return valid_rels + + def _enforce_properties( + self, + item: Union[Neo4jNode, Neo4jRelationship], + schema_item: Union[NodeType, RelationshipType], + pruning_stats: PruningStats, + prune_empty: bool = False, + ) -> dict[str, Any]: + """ + Enforce properties: + - Filter out those that are not in schema (i.e., valid properties) if allowed properties is False. + - Check that all required properties are present and not null. + """ + filtered_properties = self._filter_properties( + item.properties, + schema_item.properties, + schema_item.additional_properties, + item.token, # label or type + pruning_stats, + ) + if not filtered_properties and prune_empty: + pruning_stats.add_pruned_item(item, reason=PruningReason.NO_PROPERTY_LEFT) + return filtered_properties + missing_required_properties = self._check_required_properties( + filtered_properties, + valid_properties=schema_item.properties, + ) + if missing_required_properties: + pruning_stats.add_pruned_item( + item, + reason=PruningReason.MISSING_REQUIRED_PROPERTY, + missing_required_properties=missing_required_properties, + ) + return {} + return filtered_properties + + def _filter_properties( + self, + properties: dict[str, Any], + valid_properties: list[PropertyType], + additional_properties: bool, + node_label: str, + pruning_stats: PruningStats, + ) -> dict[str, Any]: + """Filters out properties not in schema if additional_properties is False""" + if additional_properties: + # we do not need to filter any property, just return the initial properties + return properties + valid_prop_names = {prop.name for prop in valid_properties} + filtered_properties = {} + for prop_name, prop_value in properties.items(): + if prop_name not in valid_prop_names: + pruning_stats.add_pruned_property( + prop_name, + node_label, + reason=PruningReason.NOT_IN_SCHEMA, + value=prop_value, + ) + continue + filtered_properties[prop_name] = prop_value + return filtered_properties + + def _check_required_properties( + self, filtered_properties: dict[str, Any], valid_properties: list[PropertyType] + ) -> list[str]: + """Returns the list of missing required properties, if any.""" + required_prop_names = {prop.name for prop in valid_properties if prop.required} + missing_required_properties = [] + for req_prop in required_prop_names: + if filtered_properties.get(req_prop) is None: + missing_required_properties.append(req_prop) + return missing_required_properties diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index b2007ccb4..8f686c298 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -23,10 +23,10 @@ from pydantic import ( BaseModel, PrivateAttr, - ValidationError, model_validator, validate_call, ConfigDict, + ValidationError, ) from typing_extensions import Self @@ -67,6 +67,7 @@ class PropertyType(BaseModel): "ZONED_TIME", ] description: str = "" + required: bool = False model_config = ConfigDict( frozen=True, @@ -81,6 +82,7 @@ class NodeType(BaseModel): label: str description: str = "" properties: list[PropertyType] = [] + additional_properties: bool = True @model_validator(mode="before") @classmethod @@ -89,6 +91,15 @@ def validate_input_if_string(cls, data: EntityInputType) -> EntityInputType: return {"label": data} return data + @model_validator(mode="after") + def validate_additional_properties(self) -> Self: + if len(self.properties) == 0 and not self.additional_properties: + raise ValueError( + "Using `additional_properties=False` with no defined " + "properties will cause the model to be pruned during graph cleaning.", + ) + return self + class RelationshipType(BaseModel): """ @@ -98,6 +109,7 @@ class RelationshipType(BaseModel): label: str description: str = "" properties: list[PropertyType] = [] + additional_properties: bool = True @model_validator(mode="before") @classmethod @@ -106,11 +118,36 @@ def validate_input_if_string(cls, data: RelationInputType) -> RelationInputType: return {"label": data} return data + @model_validator(mode="after") + def validate_additional_properties(self) -> Self: + if len(self.properties) == 0 and not self.additional_properties: + raise ValueError( + "Using `additional_properties=False` with no defined " + "properties will cause the model to be pruned during graph cleaning.", + ) + return self + class GraphSchema(DataModel): + """This model represents the expected + node and relationship types in the graph. + + It is used both for guiding the LLM in the entity and relation + extraction component, and for cleaning the extracted graph in a + post-processing step. + + .. warning:: + + This model is immutable. + """ + node_types: Tuple[NodeType, ...] - relationship_types: Optional[Tuple[RelationshipType, ...]] = None - patterns: Optional[Tuple[Tuple[str, str, str], ...]] = None + relationship_types: Tuple[RelationshipType, ...] = tuple() + patterns: Tuple[Tuple[str, str, str], ...] = tuple() + + additional_node_types: bool = True + additional_relationship_types: bool = True + additional_patterns: bool = True _node_type_index: dict[str, NodeType] = PrivateAttr() _relationship_type_index: dict[str, RelationshipType] = PrivateAttr() @@ -120,7 +157,7 @@ class GraphSchema(DataModel): ) @model_validator(mode="after") - def check_schema(self) -> Self: + def validate_patterns_against_node_and_rel_types(self) -> Self: self._node_type_index = {node.label: node for node in self.node_types} self._relationship_type_index = ( {r.label: r for r in self.relationship_types} @@ -128,30 +165,41 @@ def check_schema(self) -> Self: else {} ) - relationship_types = self.relationship_types or tuple() - patterns = self.patterns or tuple() + relationship_types = self.relationship_types + patterns = self.patterns if patterns: if not relationship_types: raise SchemaValidationError( - "Relations must also be provided when using a potential schema." + "Relationship types must also be provided when using patterns." ) for entity1, relation, entity2 in patterns: if entity1 not in self._node_type_index: raise SchemaValidationError( - f"Entity '{entity1}' is not defined in the provided entities." + f"Node type '{entity1}' is not defined in the provided node_types." ) if relation not in self._relationship_type_index: raise SchemaValidationError( - f"Relation '{relation}' is not defined in the provided relations." + f"Relationship type '{relation}' is not defined in the provided relationship_types." ) if entity2 not in self._node_type_index: - raise SchemaValidationError( - f"Entity '{entity2}' is not defined in the provided entities." + raise ValueError( + f"Node type '{entity2}' is not defined in the provided node_types." ) return self + @model_validator(mode="after") + def validate_additional_parameters(self) -> Self: + if ( + self.additional_patterns is False + and self.additional_relationship_types is True + ): + raise ValueError( + "`additional_relationship_types` must be set to False when using `additional_patterns=False`" + ) + return self + def node_type_from_label(self, label: str) -> Optional[NodeType]: return self._node_type_index.get(label) @@ -303,12 +351,12 @@ def create_schema_model( return GraphSchema.model_validate( dict( node_types=node_types, - relationship_types=relationship_types, - patterns=patterns, + relationship_types=relationship_types or (), + patterns=patterns or (), ) ) except (ValidationError, SchemaValidationError) as e: - raise SchemaValidationError(e) + raise SchemaValidationError(e) from e @validate_call async def run( diff --git a/src/neo4j_graphrag/experimental/components/types.py b/src/neo4j_graphrag/experimental/components/types.py index b1ed46569..4d271c242 100644 --- a/src/neo4j_graphrag/experimental/components/types.py +++ b/src/neo4j_graphrag/experimental/components/types.py @@ -15,7 +15,6 @@ from __future__ import annotations import uuid -from enum import Enum from typing import Any, Dict, Optional from pydantic import BaseModel, Field, field_validator @@ -100,6 +99,10 @@ def check_for_id_properties( raise TypeError("'id' as a property name is not allowed") return v + @property + def token(self) -> str: + return self.label + class Neo4jRelationship(BaseModel): """Represents a Neo4j relationship. @@ -118,6 +121,10 @@ class Neo4jRelationship(BaseModel): properties: dict[str, Any] = {} embedding_properties: Optional[dict[str, list[float]]] = None + @property + def token(self) -> str: + return self.type + class Neo4jGraph(DataModel): """Represents a Neo4j graph. @@ -171,8 +178,3 @@ def lexical_graph_node_labels(self) -> tuple[str, ...]: class GraphResult(DataModel): graph: Neo4jGraph config: LexicalGraphConfig - - -class SchemaEnforcementMode(str, Enum): - NONE = "NONE" - STRICT = "STRICT" diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py index 4aee855bb..e1d3af5a0 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py @@ -35,6 +35,7 @@ LLMEntityRelationExtractor, OnError, ) +from neo4j_graphrag.experimental.components.graph_pruning import GraphPruning from neo4j_graphrag.experimental.components.kg_writer import KGWriter, Neo4jWriter from neo4j_graphrag.experimental.components.pdf_loader import PdfLoader from neo4j_graphrag.experimental.components.resolver import ( @@ -54,7 +55,6 @@ ) from neo4j_graphrag.experimental.components.types import ( LexicalGraphConfig, - SchemaEnforcementMode, ) from neo4j_graphrag.experimental.pipeline.config.object_config import ComponentType from neo4j_graphrag.experimental.pipeline.config.template_pipeline.base import ( @@ -79,6 +79,7 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig): "chunk_embedder", "schema", "extractor", + "pruner", "writer", "resolver", ] @@ -92,7 +93,6 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig): relations: Sequence[RelationInputType] = [] potential_schema: Optional[list[tuple[str, str, str]]] = None schema_: Optional[GraphSchema] = Field(default=None, alias="schema") - enforce_schema: SchemaEnforcementMode = SchemaEnforcementMode.NONE on_error: OnError = OnError.IGNORE prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate() perform_entity_resolution: bool = True @@ -223,7 +223,9 @@ def _process_schema_with_precedence( if self.relations is not None else None ) - patterns = tuple(self.potential_schema) if self.potential_schema else None + patterns = ( + tuple(self.potential_schema) if self.potential_schema else tuple() + ) return node_types, relationship_types, patterns @@ -247,10 +249,12 @@ def _get_extractor(self) -> EntityRelationExtractor: return LLMEntityRelationExtractor( llm=self.get_default_llm(), prompt_template=self.prompt_template, - enforce_schema=self.enforce_schema, on_error=self.on_error, ) + def _get_pruner(self) -> GraphPruning: + return GraphPruning() + def _get_writer(self) -> KGWriter: if self.kg_writer: return self.kg_writer.parse(self._global_data) # type: ignore @@ -332,9 +336,19 @@ def _get_connections(self) -> list[ConnectionDefinition]: connections.append( ConnectionDefinition( start="extractor", - end="writer", + end="pruner", input_config={ "graph": "extractor", + "schema": "schema", + }, + ) + ) + connections.append( + ConnectionDefinition( + start="pruner", + end="writer", + input_config={ + "graph": "pruner.graph", }, ) ) diff --git a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py index b0231e50f..ba81b042a 100644 --- a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py @@ -28,7 +28,6 @@ from neo4j_graphrag.experimental.components.text_splitters.base import TextSplitter from neo4j_graphrag.experimental.components.types import ( LexicalGraphConfig, - SchemaEnforcementMode, ) from neo4j_graphrag.experimental.pipeline.config.object_config import ComponentType from neo4j_graphrag.experimental.pipeline.config.runner import PipelineRunner @@ -71,7 +70,6 @@ class SimpleKGPipeline: - dict: following the RelationshipType schema, ie with label, description and properties keys potential_schema (Optional[List[tuple]]): DEPRECATED. A list of potential schema relationships. - enforce_schema (str): Validation of the extracted entities/rels against the provided schema. Defaults to "NONE", where schema enforcement will be ignored even if the schema is provided. Possible values "None" or "STRICT". from_pdf (bool): Determines whether to include the PdfLoader in the pipeline. If True, expects `file_path` input in `run` methods. If False, expects `text` input in `run` methods. @@ -93,7 +91,6 @@ def __init__( relations: Optional[Sequence[RelationInputType]] = None, potential_schema: Optional[List[tuple[str, str, str]]] = None, schema: Optional[Union[GraphSchema, dict[str, list[Any]]]] = None, - enforce_schema: str = "NONE", from_pdf: bool = True, text_splitter: Optional[TextSplitter] = None, pdf_loader: Optional[DataLoader] = None, @@ -114,7 +111,6 @@ def __init__( relations=relations or [], potential_schema=potential_schema, schema=schema, - enforce_schema=SchemaEnforcementMode(enforce_schema), from_pdf=from_pdf, pdf_loader=ComponentType(pdf_loader) if pdf_loader else None, kg_writer=ComponentType(kg_writer) if kg_writer else None, diff --git a/tests/e2e/experimental/test_graph_pruning_component_e2e.py b/tests/e2e/experimental/test_graph_pruning_component_e2e.py new file mode 100644 index 000000000..333e74163 --- /dev/null +++ b/tests/e2e/experimental/test_graph_pruning_component_e2e.py @@ -0,0 +1,436 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any + +import pytest + +from neo4j_graphrag.experimental.components.graph_pruning import GraphPruning +from neo4j_graphrag.experimental.components.schema import GraphSchema +from neo4j_graphrag.experimental.components.types import ( + Neo4jGraph, + Neo4jNode, + Neo4jRelationship, +) + + +@pytest.fixture +def extracted_graph() -> Neo4jGraph: + """This is the graph to be pruned in all the below tests, + using different schema configuration. + """ + return Neo4jGraph( + nodes=[ + Neo4jNode( + id="1", + label="Person", + properties={ + "name": "John Doe", + }, + ), + Neo4jNode( + id="2", + label="Person", + properties={ + "height": 180, + }, + ), + Neo4jNode( + id="3", + label="Person", + properties={ + "name": "Jane Doe", + "weight": 90, + }, + ), + Neo4jNode( + id="10", + label="Organization", + properties={ + "name": "Azerty Inc.", + "created": 1999, + }, + ), + ], + relationships=[ + Neo4jRelationship( + start_node_id="1", + end_node_id="2", + type="KNOWS", + properties={"firstMetIn": 2025}, + ), + Neo4jRelationship( + start_node_id="1", + end_node_id="3", + type="KNOWS", + ), + Neo4jRelationship( + start_node_id="1", + end_node_id="2", + type="MANAGES", + ), + Neo4jRelationship( + start_node_id="1", + end_node_id="10", + type="MANAGES", + ), + Neo4jRelationship( + start_node_id="1", + end_node_id="10", + type="WORKS_FOR", + ), + ], + ) + + +async def _test( + extracted_graph: Neo4jGraph, schema_dict: dict[str, Any], expected_graph: Neo4jGraph +) -> None: + schema = GraphSchema.model_validate(schema_dict) + pruner = GraphPruning() + res = await pruner.run(extracted_graph, schema) + assert res.graph == expected_graph + + +@pytest.mark.asyncio +async def test_graph_pruning_loose(extracted_graph: Neo4jGraph) -> None: + """Loose schema: + - no required properties + - all additional* allowed + + => we keep everything from the extracted graph + """ + schema_dict = { + "node_types": [ + { + "label": "Person", + "properties": [ + {"name": "name", "type": "STRING"}, + {"name": "height", "type": "INTEGER"}, + ], + } + ], + "relationship_types": [ + { + "label": "KNOWS", + } + ], + "patterns": [ + ("Person", "KNOWS", "Person"), + ], + } + await _test(extracted_graph, schema_dict, extracted_graph) + + +@pytest.mark.asyncio +async def test_graph_pruning_missing_required_property( + extracted_graph: Neo4jGraph, +) -> None: + """Person node type has a required 'name' property: + - extracted nodes without this property are pruned + - any relationship tied to this node is also pruned + """ + schema_dict = { + "node_types": [ + { + "label": "Person", + "properties": [ + { + "name": "name", + "type": "STRING", + "required": True, + }, + {"name": "height", "type": "INTEGER"}, + ], + } + ], + "relationship_types": [ + { + "label": "KNOWS", + } + ], + "patterns": [ + ("Person", "KNOWS", "Person"), + ], + } + filtered_graph = Neo4jGraph( + nodes=[ + Neo4jNode( + id="1", + label="Person", + properties={ + "name": "John Doe", + }, + ), + # do not have the required "name" property + # Neo4jNode( + # id="2", + # label="Person", + # properties={ + # "height": 180, + # } + # ), + Neo4jNode( + id="3", + label="Person", + properties={ + "name": "Jane Doe", + "weight": 90, + }, + ), + Neo4jNode( + id="10", + label="Organization", + properties={ + "name": "Azerty Inc.", + "created": 1999, + }, + ), + ], + relationships=[ + # node "2" was pruned + # Neo4jRelationship( + # start_node_id="1", + # end_node_id="2", + # type="KNOWS", + # properties={"firstMetIn": 2025}, + # ), + Neo4jRelationship( + start_node_id="1", + end_node_id="3", + type="KNOWS", + ), + # node "2" was pruned + # Neo4jRelationship( + # start_node_id="1", + # end_node_id="2", + # type="MANAGES", + # ), + Neo4jRelationship( + start_node_id="1", + end_node_id="10", + type="MANAGES", + ), + Neo4jRelationship( + start_node_id="1", + end_node_id="10", + type="WORKS_FOR", + ), + ], + ) + await _test(extracted_graph, schema_dict, filtered_graph) + + +@pytest.mark.asyncio +async def test_graph_pruning_strict_properties_and_node_types( + extracted_graph: Neo4jGraph, +) -> None: + """Additional properties on Person nodes are not allowed. + Additional node types are not allowed. + + => we prune "Organization" nodes (not in schema) + and the "weight" property that was extracted for some persons. + """ + schema_dict = { + "node_types": [ + { + "label": "Person", + "properties": [ + { + "name": "name", + "type": "STRING", + }, + {"name": "height", "type": "INTEGER"}, + ], + "additional_properties": False, + } + ], + "relationship_types": [ + { + "label": "KNOWS", + } + ], + "patterns": [ + ("Person", "KNOWS", "Person"), + ], + "additional_node_types": False, + } + filtered_graph = Neo4jGraph( + nodes=[ + Neo4jNode( + id="1", + label="Person", + properties={ + "name": "John Doe", + }, + ), + Neo4jNode( + id="2", + label="Person", + properties={ + "height": 180, + }, + ), + Neo4jNode( + id="3", + label="Person", + properties={ + "name": "Jane Doe", + # weight not in listed properties + # "weight": 90, + }, + ), + # label "Organization" not in schema + # Neo4jNode( + # id="10", + # label="Organization", + # properties={ + # "name": "Azerty Inc.", + # "created": 1999, + # } + # ), + ], + relationships=[ + Neo4jRelationship( + start_node_id="1", + end_node_id="2", + type="KNOWS", + properties={"firstMetIn": 2025}, + ), + Neo4jRelationship( + start_node_id="1", + end_node_id="3", + type="KNOWS", + ), + Neo4jRelationship( + start_node_id="1", + end_node_id="2", + type="MANAGES", + ), + # node "10" was pruned (label not allowed) + # Neo4jRelationship( + # start_node_id="1", + # end_node_id="10", + # type="MANAGES", + # ), + # Neo4jRelationship( + # start_node_id="1", + # end_node_id="10", + # type="WORKS_FOR", + # ) + ], + ) + await _test(extracted_graph, schema_dict, filtered_graph) + + +@pytest.mark.asyncio +async def test_graph_pruning_strict_patterns(extracted_graph: Neo4jGraph) -> None: + """Additional patterns not allowed: + + - MANAGES: it's a known relationship type but without any pattern, it's pruned + - WORKS_FOR: it's not a known relationship type, and additional_relationship_types is allowed + so we keep it. + """ + # - no additional patterns allowed + schema_dict = { + "node_types": [ + { + "label": "Person", + "properties": [ + { + "name": "name", + "type": "STRING", + }, + {"name": "height", "type": "INTEGER"}, + ], + }, + { + "label": "Organization", + }, + ], + "relationship_types": [ + { + "label": "KNOWS", + }, + { + "label": "MANAGES", + }, + ], + "patterns": ( + ("Person", "KNOWS", "Person"), + ("Person", "KNOWS", "Organization"), + ), + "additional_relationship_types": False, + "additional_patterns": False, + } + filtered_graph = Neo4jGraph( + nodes=[ + Neo4jNode( + id="1", + label="Person", + properties={ + "name": "John Doe", + }, + ), + Neo4jNode( + id="2", + label="Person", + properties={ + "height": 180, + }, + ), + Neo4jNode( + id="3", + label="Person", + properties={ + "name": "Jane Doe", + "weight": 90, + }, + ), + Neo4jNode( + id="10", + label="Organization", + properties={ + "name": "Azerty Inc.", + "created": 1999, + }, + ), + ], + relationships=[ + Neo4jRelationship( + start_node_id="1", + end_node_id="2", + type="KNOWS", + properties={"firstMetIn": 2025}, + ), + Neo4jRelationship( + start_node_id="1", + end_node_id="3", + type="KNOWS", + ), + # invalid pattern (person, manages, person) + # Neo4jRelationship( + # start_node_id="1", + # end_node_id="2", + # type="MANAGES", + # ), + # invalid pattern (person, works for, person) + # Neo4jRelationship( + # start_node_id="1", + # end_node_id="10", + # type="WORKS_FOR", + # ), + ], + ) + await _test(extracted_graph, schema_dict, filtered_graph) diff --git a/tests/unit/experimental/components/test_entity_relation_extractor.py b/tests/unit/experimental/components/test_entity_relation_extractor.py index 64ac0d42d..f76ab5c9c 100644 --- a/tests/unit/experimental/components/test_entity_relation_extractor.py +++ b/tests/unit/experimental/components/test_entity_relation_extractor.py @@ -25,13 +25,11 @@ balance_curly_braces, fix_invalid_json, ) -from neo4j_graphrag.experimental.components.schema import GraphSchema from neo4j_graphrag.experimental.components.types import ( DocumentInfo, Neo4jGraph, TextChunk, TextChunks, - SchemaEnforcementMode, ) from neo4j_graphrag.experimental.pipeline.exceptions import InvalidJSONError from neo4j_graphrag.llm import LLMInterface, LLMResponse @@ -231,435 +229,6 @@ async def test_extractor_custom_prompt() -> None: llm.ainvoke.assert_called_once_with("this is my prompt") -@pytest.mark.asyncio -async def test_extractor_no_schema_enforcement() -> None: - llm = MagicMock(spec=LLMInterface) - llm.ainvoke.return_value = LLMResponse( - content='{"nodes":[{"id":"0","label":"Alien","properties":{"foo":"bar"}}],' - '"relationships":[]}' - ) - - extractor = LLMEntityRelationExtractor( - llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.NONE - ) - - schema = GraphSchema.model_validate( - { - "node_types": [ - { - "label": "Person", - "properties": [{"name": "name", "type": "STRING"}], - } - ], - "relationship_types": [], - "patterns": [], - }, - ) - - chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) - - result: Neo4jGraph = await extractor.run(chunks=chunks, schema=schema) - - assert len(result.nodes) == 1 - assert result.nodes[0].label == "Alien" - assert result.nodes[0].properties == {"chunk_index": 0, "foo": "bar"} - - -@pytest.mark.asyncio -async def test_extractor_schema_enforcement_when_no_schema_provided() -> None: - llm = MagicMock(spec=LLMInterface) - llm.ainvoke.return_value = LLMResponse( - content='{"nodes":[{"id":"0","label":"Alien","properties":{"foo":"bar"}}],' - '"relationships":[]}' - ) - - extractor = LLMEntityRelationExtractor( - llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT - ) - - chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) - - result: Neo4jGraph = await extractor.run(chunks=chunks) - - assert len(result.nodes) == 1 - assert result.nodes[0].label == "Alien" - assert result.nodes[0].properties == {"chunk_index": 0, "foo": "bar"} - - -@pytest.mark.asyncio -async def test_extractor_schema_enforcement_invalid_nodes() -> None: - llm = MagicMock(spec=LLMInterface) - llm.ainvoke.return_value = LLMResponse( - content='{"nodes":[{"id":"0","label":"Alien","properties":{"foo":"bar"}},' - '{"id":"1","label":"Person","properties":{"name":"Alice"}}],' - '"relationships":[]}' - ) - - extractor = LLMEntityRelationExtractor( - llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT - ) - - schema = GraphSchema.model_validate( - { - "node_types": [ - { - "label": "Person", - "properties": [{"name": "name", "type": "STRING"}], - } - ], - "relationship_types": [], - "patterns": [], - }, - ) - - chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) - - result: Neo4jGraph = await extractor.run(chunks=chunks, schema=schema) - - assert len(result.nodes) == 1 - assert result.nodes[0].label == "Person" - assert result.nodes[0].properties == {"chunk_index": 0, "name": "Alice"} - - -@pytest.mark.asyncio -async def test_extraction_schema_enforcement_invalid_node_properties() -> None: - llm = MagicMock(spec=LLMInterface) - llm.ainvoke.return_value = LLMResponse( - content='{"nodes":[{"id":"1","label":"Person","properties":' - '{"name":"Alice","age":30,"foo":"bar"}}],' - '"relationships":[]}' - ) - - extractor = LLMEntityRelationExtractor( - llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT - ) - - schema = GraphSchema.model_validate( - { - "node_types": [ - { - "label": "Person", - "properties": [ - {"name": "name", "type": "STRING"}, - {"name": "age", "type": "STRING"}, - ], - } - ], - "relationship_types": [], - "patterns": [], - }, - ) - - chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) - - result: Neo4jGraph = await extractor.run(chunks, schema=schema) - - # "foo" is removed - assert len(result.nodes) == 1 - assert len(result.nodes[0].properties) == 3 - assert "foo" not in result.nodes[0].properties - - -@pytest.mark.asyncio -async def test_extractor_schema_enforcement_valid_nodes_with_empty_props() -> None: - llm = MagicMock(spec=LLMInterface) - llm.ainvoke.return_value = LLMResponse( - content='{"nodes":[{"id":"1","label":"Person","properties":{"foo":"bar"}}],' - '"relationships":[]}' - ) - - extractor = LLMEntityRelationExtractor( - llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT - ) - - schema = GraphSchema.model_validate( - { - "node_types": [ - { - "label": "Person", - } - ], - } - ) - chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) - - result: Neo4jGraph = await extractor.run(chunks, schema=schema) - - assert len(result.nodes) == 1 - - -@pytest.mark.asyncio -async def test_extractor_schema_enforcement_invalid_relations_wrong_types() -> None: - llm = MagicMock(spec=LLMInterface) - llm.ainvoke.return_value = LLMResponse( - content='{"nodes":[{"id":"1","label":"Person","properties":' - '{"name":"Alice"}},{"id":"2","label":"Person","properties":' - '{"name":"Bob"}}],' - '"relationships":[{"start_node_id":"1","end_node_id":"2",' - '"type":"FRIENDS_WITH","properties":{}}]}' - ) - - extractor = LLMEntityRelationExtractor( - llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT - ) - - schema = GraphSchema.model_validate( - { - "node_types": [ - { - "label": "Person", - "properties": [ - {"name": "name", "type": "STRING"}, - {"name": "age", "type": "STRING"}, - ], - } - ], - "relationship_types": [{"label": "LIKES"}], - "patterns": [], - } - ) - - chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) - - result: Neo4jGraph = await extractor.run(chunks, schema=schema) - - assert len(result.nodes) == 2 - assert len(result.relationships) == 0 - - -@pytest.mark.asyncio -async def test_extractor_schema_enforcement_invalid_relations_wrong_start_node() -> ( - None -): - llm = MagicMock(spec=LLMInterface) - llm.ainvoke.return_value = LLMResponse( - content='{"nodes":[{"id":"1","label":"Person","properties":{"name":"Alice"}},' - '{"id":"2","label":"Person","properties":{"name":"Bob"}}, ' - '{"id":"3","label":"City","properties":{"name":"London"}}],' - '"relationships":[{"start_node_id":"1","end_node_id":"2",' - '"type":"LIVES_IN","properties":{}}]}' - ) - - extractor = LLMEntityRelationExtractor( - llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT - ) - - schema = GraphSchema.model_validate( - { - "node_types": [ - { - "label": "Person", - "properties": [{"name": "name", "type": "STRING"}], - }, - { - "label": "City", - "properties": [{"name": "name", "type": "STRING"}], - }, - ], - "relationship_types": [{"label": "LIVES_IN"}], - "patterns": [("Person", "LIVES_IN", "City")], - } - ) - - chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) - - result: Neo4jGraph = await extractor.run(chunks, schema=schema) - - assert len(result.nodes) == 3 - assert len(result.relationships) == 0 - - -@pytest.mark.asyncio -async def test_extractor_schema_enforcement_invalid_relation_properties() -> None: - llm = MagicMock(spec=LLMInterface) - llm.ainvoke.return_value = LLMResponse( - content='{"nodes":[{"id":"1","label":"Person","properties":{"name":"Alice"}},' - '{"id":"2","label":"Person","properties":{"name":"Bob"}}],' - '"relationships":[{"start_node_id":"1","end_node_id":"2",' - '"type":"LIKES","properties":{"strength":"high","foo":"bar"}}]}' - ) - - extractor = LLMEntityRelationExtractor( - llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT - ) - - schema = GraphSchema.model_validate( - { - "node_types": [ - { - "label": "Person", - "properties": [{"name": "name", "type": "STRING"}], - } - ], - "relationship_types": [ - { - "label": "LIKES", - "properties": [{"name": "strength", "type": "STRING"}], - } - ], - "patterns": [], - } - ) - - chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) - - result: Neo4jGraph = await extractor.run(chunks, schema=schema) - - assert len(result.nodes) == 2 - assert len(result.relationships) == 1 - rel = result.relationships[0] - assert "foo" not in rel.properties - assert rel.properties["strength"] == "high" - - -@pytest.mark.asyncio -async def test_extractor_schema_enforcement_removed_relation_start_end_nodes() -> None: - llm = MagicMock(spec=LLMInterface) - llm.ainvoke.return_value = LLMResponse( - content='{"nodes":[{"id":"1","label":"Alien","properties":{}},' - '{"id":"2","label":"Robot","properties":{}}],' - '"relationships":[{"start_node_id":"1","end_node_id":"2",' - '"type":"LIKES","properties":{}}]}' - ) - - extractor = LLMEntityRelationExtractor( - llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT - ) - - schema = GraphSchema.model_validate( - { - "node_types": [ - { - "label": "Person", - "properties": [{"name": "name", "type": "STRING"}], - } - ], - "relationship_types": [{"label": "LIKES"}], - "patterns": [("Person", "LIKES", "Person")], - } - ) - - chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) - - result: Neo4jGraph = await extractor.run(chunks, schema=schema) - - assert len(result.nodes) == 0 - assert len(result.relationships) == 0 - - -@pytest.mark.asyncio -async def test_extractor_schema_enforcement_inverted_relation_direction() -> None: - llm = MagicMock(spec=LLMInterface) - llm.ainvoke.return_value = LLMResponse( - content='{"nodes":[{"id":"1","label":"Person","properties":{"name":"Alice"}},' - '{"id":"2","label":"City","properties":{"name":"London"}}],' - '"relationships":[{"start_node_id":"2","end_node_id":"1",' - '"type":"LIVES_IN","properties":{}}]}' - ) - - extractor = LLMEntityRelationExtractor( - llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT - ) - - schema = GraphSchema.model_validate( - { - "node_types": [ - { - "label": "Person", - "properties": [{"name": "name", "type": "STRING"}], - }, - { - "label": "City", - "properties": [{"name": "name", "type": "STRING"}], - }, - ], - "relationship_types": [{"label": "LIVES_IN"}], - "patterns": [("Person", "LIVES_IN", "City")], - } - ) - - chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) - - result: Neo4jGraph = await extractor.run(chunks, schema=schema) - - assert len(result.nodes) == 2 - assert len(result.relationships) == 1 - assert result.relationships[0].start_node_id.split(":")[1] == "1" - assert result.relationships[0].end_node_id.split(":")[1] == "2" - - -@pytest.mark.asyncio -async def test_extractor_schema_enforcement_none_relationships_in_schema() -> None: - llm = MagicMock(spec=LLMInterface) - llm.ainvoke.return_value = LLMResponse( - content='{"nodes":[{"id":"1","label":"Person","properties":' - '{"name":"Alice"}},{"id":"2","label":"Person","properties":' - '{"name":"Bob"}}],' - '"relationships":[{"start_node_id":"1","end_node_id":"2",' - '"type":"FRIENDS_WITH","properties":{}}]}' - ) - - extractor = LLMEntityRelationExtractor( - llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT - ) - - schema = GraphSchema.model_validate( - dict( - node_types=[ - { - "label": "Person", - "properties": [{"name": "name", "type": "STRING"}], - } - ], - relationship_types=None, - patterns=None, - ) - ) - - chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) - - result: Neo4jGraph = await extractor.run(chunks, schema=schema) - - assert len(result.nodes) == 2 - assert len(result.relationships) == 1 - assert result.relationships[0].type == "FRIENDS_WITH" - - -@pytest.mark.asyncio -async def test_extractor_schema_enforcement_empty_relationships_in_schema() -> None: - llm = MagicMock(spec=LLMInterface) - llm.ainvoke.return_value = LLMResponse( - content='{"nodes":[{"id":"1","label":"Person","properties":' - '{"name":"Alice"}},{"id":"2","label":"Person","properties":' - '{"name":"Bob"}}],' - '"relationships":[{"start_node_id":"1","end_node_id":"2",' - '"type":"FRIENDS_WITH","properties":{}}]}' - ) - - extractor = LLMEntityRelationExtractor( - llm=llm, create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT - ) - - schema = GraphSchema.model_validate( - dict( - node_types=[ - { - "label": "Person", - "properties": [{"name": "name", "type": "STRING"}], - } - ], - relationship_types=[], - patterns=None, - ) - ) - - chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) - - result: Neo4jGraph = await extractor.run(chunks, schema=schema) - - assert len(result.relationships) == 0 - - def test_fix_invalid_json_empty_result() -> None: json_string = "invalid json" diff --git a/tests/unit/experimental/components/test_graph_pruning.py b/tests/unit/experimental/components/test_graph_pruning.py new file mode 100644 index 000000000..8c2d490d1 --- /dev/null +++ b/tests/unit/experimental/components/test_graph_pruning.py @@ -0,0 +1,385 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations +from typing import Any +from unittest.mock import patch, Mock, ANY + +import pytest + +from neo4j_graphrag.experimental.components.graph_pruning import ( + GraphPruning, + GraphPruningResult, + PruningStats, +) +from neo4j_graphrag.experimental.components.schema import ( + NodeType, + PropertyType, + RelationshipType, + GraphSchema, +) +from neo4j_graphrag.experimental.components.types import ( + Neo4jNode, + Neo4jRelationship, + Neo4jGraph, +) + + +@pytest.mark.parametrize( + "properties, valid_properties, additional_properties, expected_filtered_properties", + [ + ( + # no required, additional allowed + { + "name": "John Does", + "age": 25, + }, + [ + PropertyType( + name="name", + type="STRING", + ) + ], + True, + { + "name": "John Does", + "age": 25, + }, + ), + ( + # no required, additional not allowed + { + "name": "John Does", + "age": 25, + }, + [ + PropertyType( + name="name", + type="STRING", + ) + ], + False, + { + "name": "John Does", + }, + ), + ], +) +def test_graph_pruning_filter_properties( + properties: dict[str, Any], + valid_properties: list[PropertyType], + additional_properties: bool, + expected_filtered_properties: dict[str, Any], +) -> None: + pruner = GraphPruning() + filtered_properties = pruner._filter_properties( + properties, + valid_properties, + additional_properties=additional_properties, + node_label="Label", + pruning_stats=PruningStats(), + ) + assert filtered_properties == expected_filtered_properties + + +@pytest.fixture(scope="module") +def node_type_no_properties() -> NodeType: + return NodeType(label="Person") + + +@pytest.fixture(scope="module") +def node_type_required_name() -> NodeType: + return NodeType( + label="Person", + properties=[ + PropertyType(name="name", type="STRING", required=True), + PropertyType(name="age", type="INTEGER"), + ], + ) + + +@pytest.mark.parametrize( + "node, entity, additional_node_types, expected_node", + [ + # all good, with default values + ( + Neo4jNode(id="1", label="Person", properties={"name": "John Doe"}), + "node_type_no_properties", + True, + Neo4jNode(id="1", label="Person", properties={"name": "John Doe"}), + ), + # properties empty (missing default) + ( + Neo4jNode(id="1", label="Person", properties={"age": 45}), + "node_type_required_name", + True, + None, + ), + # node label not is schema, additional not allowed + ( + Neo4jNode(id="1", label="Location", properties={"name": "New York"}), + None, + False, + None, + ), + # node label not is schema, additional allowed + ( + Neo4jNode(id="1", label="Location", properties={"name": "New York"}), + None, + True, + Neo4jNode(id="1", label="Location", properties={"name": "New York"}), + ), + ], +) +def test_graph_pruning_validate_node( + node: Neo4jNode, + entity: str, + additional_node_types: bool, + expected_node: Neo4jNode, + request: pytest.FixtureRequest, +) -> None: + e = request.getfixturevalue(entity) if entity else None + + pruner = GraphPruning() + result = pruner._validate_node(node, PruningStats(), e, additional_node_types) + if expected_node is not None: + assert result == expected_node + else: + assert result is None + + +@pytest.fixture +def neo4j_relationship() -> Neo4jRelationship: + return Neo4jRelationship( + start_node_id="1", + end_node_id="2", + type="REL", + properties={}, + ) + + +@pytest.fixture +def neo4j_reversed_relationship( + neo4j_relationship: Neo4jRelationship, +) -> Neo4jRelationship: + return Neo4jRelationship( + start_node_id=neo4j_relationship.end_node_id, + end_node_id=neo4j_relationship.start_node_id, + type=neo4j_relationship.type, + properties=neo4j_relationship.properties, + ) + + +@pytest.mark.parametrize( + "relationship, valid_nodes, relationship_type, additional_relationship_types, patterns, additional_patterns, expected_relationship", + [ + # all good + ( + "neo4j_relationship", # relationship, + { # valid_nodes + "1": "Person", + "2": "Location", + }, + RelationshipType( # relationship_type + label="REL", + ), + True, # additional_relationship_types + (("Person", "REL", "Location"),), # patterns + True, # additional_patterns + "neo4j_relationship", # expected_relationship + ), + # reverse relationship + ( + "neo4j_reversed_relationship", + { + "1": "Person", + "2": "Location", + }, + RelationshipType( + label="REL", + ), + True, # additional_relationship_types + (("Person", "REL", "Location"),), + True, # additional_patterns + "neo4j_relationship", + ), + # invalid start node ID + ( + "neo4j_reversed_relationship", + { + "10": "Person", + "2": "Location", + }, + RelationshipType( + label="REL", + ), + True, # additional_relationship_types + (("Person", "REL", "Location"),), + True, # additional_patterns + None, + ), + # invalid type addition allowed + ( + "neo4j_relationship", + { + "1": "Person", + "2": "Location", + }, + None, # relationship_type + True, # additional_relationship_types + (("Person", "REL", "Location"),), + True, # additional_patterns + "neo4j_relationship", + ), + # invalid type addition allowed but invalid node ID + ( + "neo4j_relationship", + { + "1": "Person", + }, + None, # relationship_type + True, # additional_relationship_types + (("Person", "REL", "Location"),), + True, # additional_patterns + None, + ), + # invalid_type_addition_not_allowed + ( + "neo4j_relationship", + { + "1": "Person", + "2": "Location", + }, + None, # relationship_type + False, # additional_relationship_types + (("Person", "REL", "Location"),), + True, # additional_patterns + None, + ), + # invalid pattern, addition allowed + ( + "neo4j_relationship", + { + "1": "Person", + "2": "Location", + }, + RelationshipType( + label="REL", + ), + True, # additional_relationship_types + (("Person", "REL", "Person"),), + True, # additional_patterns + "neo4j_relationship", + ), + # invalid pattern, addition not allowed + ( + "neo4j_relationship", + { + "1": "Person", + "2": "Location", + }, + RelationshipType( + label="REL", + ), + True, # additional_relationship_types + (("Person", "REL", "Person"),), + False, # additional_patterns + None, + ), + ], +) +def test_graph_pruning_validate_relationship( + relationship: str, + valid_nodes: dict[str, str], + relationship_type: RelationshipType, + additional_relationship_types: bool, + patterns: tuple[tuple[str, str, str], ...], + additional_patterns: bool, + expected_relationship: str | None, + request: pytest.FixtureRequest, +) -> None: + relationship_obj = request.getfixturevalue(relationship) + expected_relationship_obj = ( + request.getfixturevalue(expected_relationship) + if expected_relationship + else None + ) + + pruner = GraphPruning() + assert ( + pruner._validate_relationship( + relationship_obj, + valid_nodes, + PruningStats(), + relationship_type, + additional_relationship_types, + patterns, + additional_patterns, + ) + == expected_relationship_obj + ) + + +@patch("neo4j_graphrag.experimental.components.graph_pruning.GraphPruning._clean_graph") +@pytest.mark.asyncio +async def test_graph_pruning_run_happy_path( + mock_clean_graph: Mock, + node_type_required_name: NodeType, +) -> None: + initial_graph = Neo4jGraph( + nodes=[Neo4jNode(id="1", label="Person"), Neo4jNode(id="2", label="Location")], + ) + schema = GraphSchema(node_types=(node_type_required_name,)) + cleaned_graph = Neo4jGraph(nodes=[Neo4jNode(id="1", label="Person")]) + mock_clean_graph.return_value = (cleaned_graph, PruningStats()) + pruner = GraphPruning() + pruner_result = await pruner.run( + graph=initial_graph, + schema=schema, + ) + assert isinstance(pruner_result, GraphPruningResult) + assert pruner_result.graph == cleaned_graph + mock_clean_graph.assert_called_once_with(initial_graph, schema) + + +@pytest.mark.asyncio +async def test_graph_pruning_run_no_schema() -> None: + initial_graph = Neo4jGraph(nodes=[Neo4jNode(id="1", label="Person")]) + pruner = GraphPruning() + pruner_result = await pruner.run( + graph=initial_graph, + schema=None, + ) + assert isinstance(pruner_result, GraphPruningResult) + assert pruner_result.graph == initial_graph + + +@patch( + "neo4j_graphrag.experimental.components.graph_pruning.GraphPruning._enforce_nodes" +) +def test_graph_pruning_clean_graph( + mock_enforce_nodes: Mock, +) -> None: + mock_enforce_nodes.return_value = [] + initial_graph = Neo4jGraph(nodes=[Neo4jNode(id="1", label="Person")]) + schema = GraphSchema(node_types=()) + pruner = GraphPruning() + cleaned_graph, pruning_stats = pruner._clean_graph(initial_graph, schema) + assert cleaned_graph == Neo4jGraph() + assert isinstance(pruning_stats, PruningStats) + mock_enforce_nodes.assert_called_once_with( + [Neo4jNode(id="1", label="Person")], + schema, + ANY, + ) diff --git a/tests/unit/experimental/components/test_schema.py b/tests/unit/experimental/components/test_schema.py index 1aec51261..cfc987af7 100644 --- a/tests/unit/experimental/components/test_schema.py +++ b/tests/unit/experimental/components/test_schema.py @@ -16,9 +16,10 @@ import json from typing import Tuple -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, patch import pytest +from pydantic import ValidationError from neo4j_graphrag.exceptions import SchemaValidationError, SchemaExtractionError from neo4j_graphrag.experimental.components.schema import ( @@ -38,6 +39,59 @@ from neo4j_graphrag.utils.file_handler import FileFormat +def test_node_type_raise_error_if_misconfigured() -> None: + with pytest.raises(ValidationError): + NodeType( + label="test", + properties=[], + additional_properties=False, + ) + + +def test_relationship_type_raise_error_if_misconfigured() -> None: + with pytest.raises(ValidationError): + RelationshipType( + label="test", + properties=[], + additional_properties=False, + ) + + +def test_schema_additional_parameter_validation() -> None: + """Additional relationship types not allowed, but additional patterns allowed + + => raise Exception + """ + schema_dict = { + "node_types": [ + { + "label": "Person", + "properties": [ + { + "name": "name", + "type": "STRING", + }, + {"name": "height", "type": "INTEGER"}, + ], + } + ], + "relationship_types": [ + { + "label": "KNOWS", + } + ], + "patterns": [ + ("Person", "KNOWS", "Person"), + ], + "additional_patterns": False, + } + with pytest.raises( + ValidationError, + match="`additional_relationship_types` must be set to False when using `additional_patterns=False`", + ): + GraphSchema.model_validate(schema_dict) + + @pytest.fixture def valid_node_types() -> tuple[NodeType, ...]: return ( @@ -46,8 +100,9 @@ def valid_node_types() -> tuple[NodeType, ...]: description="An individual human being.", properties=[ PropertyType(name="birth date", type="ZONED_DATETIME"), - PropertyType(name="name", type="STRING"), + PropertyType(name="name", type="STRING", required=True), ], + additional_properties=False, ), NodeType( label="ORGANIZATION", @@ -64,9 +119,10 @@ def valid_relationship_types() -> tuple[RelationshipType, ...]: label="EMPLOYED_BY", description="Indicates employment relationship.", properties=[ - PropertyType(name="start_time", type="LOCAL_DATETIME"), + PropertyType(name="start_time", type="LOCAL_DATETIME", required=True), PropertyType(name="end_time", type="LOCAL_DATETIME"), ], + additional_properties=False, ), RelationshipType( label="ORGANIZED_BY", @@ -122,13 +178,16 @@ def test_create_schema_model_valid_data( valid_relationship_types: Tuple[RelationshipType, ...], valid_patterns: Tuple[Tuple[str, str, str], ...], ) -> None: - schema_instance = schema_builder.create_schema_model( + schema = schema_builder.create_schema_model( list(valid_node_types), list(valid_relationship_types), list(valid_patterns) ) - assert schema_instance.node_types == valid_node_types - assert schema_instance.relationship_types == valid_relationship_types - assert schema_instance.patterns == valid_patterns + assert schema.node_types == valid_node_types + assert schema.relationship_types == valid_relationship_types + assert schema.patterns == valid_patterns + assert schema.additional_node_types is True + assert schema.additional_relationship_types is True + assert schema.additional_patterns is True @pytest.mark.asyncio @@ -138,13 +197,25 @@ async def test_run_method( valid_relationship_types: Tuple[RelationshipType, ...], valid_patterns: Tuple[Tuple[str, str, str], ...], ) -> None: - schema = await schema_builder.run( - list(valid_node_types), list(valid_relationship_types), list(valid_patterns) - ) + with patch.object( + schema_builder, + "create_schema_model", + return_value=GraphSchema( + node_types=valid_node_types, + relationship_types=valid_relationship_types, + patterns=valid_patterns, + ), + ): + schema = await schema_builder.run( + list(valid_node_types), list(valid_relationship_types), list(valid_patterns) + ) assert schema.node_types == valid_node_types assert schema.relationship_types == valid_relationship_types assert schema.patterns == valid_patterns + assert schema.additional_node_types is True + assert schema.additional_relationship_types is True + assert schema.additional_patterns is True def test_create_schema_model_invalid_entity( @@ -159,7 +230,7 @@ def test_create_schema_model_invalid_entity( list(valid_relationship_types), list(patterns_with_invalid_entity), ) - assert "Entity 'NON_EXISTENT_ENTITY' is not defined" in str( + assert "Node type 'NON_EXISTENT_ENTITY' is not defined" in str( exc_info.value ), "Should fail due to non-existent entity" @@ -176,7 +247,7 @@ def test_create_schema_model_invalid_relation( list(valid_relationship_types), list(patterns_with_invalid_relation), ) - assert "Relation 'NON_EXISTENT_RELATION' is not defined" in str( + assert "Relationship type 'NON_EXISTENT_RELATION' is not defined" in str( exc_info.value ), "Should fail due to non-existent relation" @@ -191,7 +262,7 @@ def test_create_schema_model_no_potential_schema( ) assert schema_instance.node_types == valid_node_types assert schema_instance.relationship_types == valid_relationship_types - assert schema_instance.patterns is None + assert schema_instance.patterns == () def test_create_schema_model_no_relations_or_potential_schema( @@ -206,14 +277,17 @@ def test_create_schema_model_no_relations_or_potential_schema( assert person is not None assert person.description == "An individual human being." assert len(person.properties) == 2 + assert person.additional_properties is False org = schema_instance.node_type_from_label("ORGANIZATION") assert org is not None assert org.description == "A structured group of people with a common purpose." + assert org.additional_properties is True age = schema_instance.node_type_from_label("AGE") assert age is not None assert age.description == "Age of a person in years." + assert age.additional_properties is True def test_create_schema_model_missing_relations( @@ -225,7 +299,7 @@ def test_create_schema_model_missing_relations( schema_builder.create_schema_model( node_types=valid_node_types, patterns=valid_patterns ) - assert "Relations must also be provided when using a potential schema." in str( + assert "Relationship types must also be provided when using patterns." in str( exc_info.value ), "Should fail due to missing relations" diff --git a/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py b/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py index 73bb50eeb..6560fda41 100644 --- a/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py +++ b/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py @@ -222,14 +222,15 @@ def test_simple_kg_pipeline_config_connections_from_pdf() -> None: perform_entity_resolution=False, ) connections = config._get_connections() - assert len(connections) == 6 + assert len(connections) == 7 expected_connections = [ ("pdf_loader", "splitter"), ("pdf_loader", "schema"), ("schema", "extractor"), ("splitter", "chunk_embedder"), ("chunk_embedder", "extractor"), - ("extractor", "writer"), + ("extractor", "pruner"), + ("pruner", "writer"), ] for actual, expected in zip(connections, expected_connections): assert (actual.start, actual.end) == expected @@ -241,12 +242,13 @@ def test_simple_kg_pipeline_config_connections_from_text() -> None: perform_entity_resolution=False, ) connections = config._get_connections() - assert len(connections) == 4 + assert len(connections) == 5 expected_connections = [ ("schema", "extractor"), ("splitter", "chunk_embedder"), ("chunk_embedder", "extractor"), - ("extractor", "writer"), + ("extractor", "pruner"), + ("pruner", "writer"), ] for actual, expected in zip(connections, expected_connections): assert (actual.start, actual.end) == expected @@ -258,14 +260,15 @@ def test_simple_kg_pipeline_config_connections_with_er() -> None: perform_entity_resolution=True, ) connections = config._get_connections() - assert len(connections) == 7 + assert len(connections) == 8 expected_connections = [ ("pdf_loader", "splitter"), ("pdf_loader", "schema"), ("schema", "extractor"), ("splitter", "chunk_embedder"), ("chunk_embedder", "extractor"), - ("extractor", "writer"), + ("extractor", "pruner"), + ("pruner", "writer"), ("writer", "resolver"), ] for actual, expected in zip(connections, expected_connections): diff --git a/tests/unit/experimental/pipeline/test_kg_builder.py b/tests/unit/experimental/pipeline/test_kg_builder.py index b4ece857a..13d789cb2 100644 --- a/tests/unit/experimental/pipeline/test_kg_builder.py +++ b/tests/unit/experimental/pipeline/test_kg_builder.py @@ -149,20 +149,6 @@ def test_simple_kg_pipeline_on_error_invalid_value() -> None: ) -def test_simple_kg_pipeline_enforce_schema_invalid_value() -> None: - llm = MagicMock(spec=LLMInterface) - driver = MagicMock(spec=neo4j.Driver) - embedder = MagicMock(spec=Embedder) - - with pytest.raises(PipelineDefinitionError): - SimpleKGPipeline( - llm=llm, - driver=driver, - embedder=embedder, - enforce_schema="INVALID_VALUE", - ) - - @mock.patch( "neo4j_graphrag.experimental.components.kg_writer.get_version", return_value=((5, 23, 0), False, False),