diff --git a/docs/source/user_guide_kg_builder.rst b/docs/source/user_guide_kg_builder.rst index 0a299a0a3..f810d9ed2 100644 --- a/docs/source/user_guide_kg_builder.rst +++ b/docs/source/user_guide_kg_builder.rst @@ -998,31 +998,40 @@ By default, all extracted elements — including nodes, relationships, and prope Configuration Options --------------------- -- **Required Properties** +- **Required Properties** (default: ``False``) Required properties may be specified at the node or relationship type level. Any extracted node or relationship missing one or more of its required properties will be pruned from the graph. -- **Additional Properties** *(default: True)* +- **Additional Properties** This node- or relationship-level option determines whether extra properties not listed in the schema should be retained. - - If set to ``True`` (default), all extracted properties are retained. + - If set to ``True``, all extracted properties are retained. - If set to ``False``, only the properties defined in the schema are preserved; all others are removed. +.. note:: Default behavior + + By default, this flag is set to ``False`` if at least one property is defined, ``True`` otherwise. + + The same rule applies to ``additional_node_types``, ``additional_relationship_types`` and ``additional_patterns`` described below. + +.. warning:: + + Defining a node or relationship type with no properties and ``additional_properties=False`` will raise a ValidationError. .. note:: Node pruning If, after property pruning using the above rule, a node is left without any property, it is removed from the graph. -- **Additional Node Types** *(default: True)* +- **Additional Node Types** This schema-level option specifies whether node types not defined in the schema are included in the graph. - - If set to ``True`` (default), such node types are retained. + - If set to ``True``, such node types are retained. - If set to ``False``, nodes with undefined types are removed. 
-- **Additional Relationship Types** *(default: True)* +- **Additional Relationship Types** This schema-level option specifies whether relationship types not defined in the schema are included in the graph. - - If set to ``True`` (default), such relationships are retained. + - If set to ``True``, such relationships are retained. - If set to ``False``, relationships with undefined types are removed. - **Additional Patterns** *(default: True)* diff --git a/examples/customize/build_graph/components/pruners/graph_pruner.py b/examples/customize/build_graph/components/pruners/graph_pruner.py index adf8694a1..7273188ad 100644 --- a/examples/customize/build_graph/components/pruners/graph_pruner.py +++ b/examples/customize/build_graph/components/pruners/graph_pruner.py @@ -41,7 +41,7 @@ Neo4jNode( id="Organization/Corp1", label="Organization", - properties={"name": "CorpA"}, + properties={"name": "Corp1"}, ), ], relationships=[ @@ -51,7 +51,7 @@ type="KNOWS", ), Neo4jRelationship( - start_node_id="Organization/CorpA", + start_node_id="Organization/Corp2", end_node_id="Person/Jack", type="WORKS_FOR", ), @@ -80,12 +80,14 @@ PropertyType(name="name", type="STRING", required=True), PropertyType(name="address", type="STRING"), ], + additional_properties=True, ), ), relationship_types=( RelationshipType( label="WORKS_FOR", properties=[PropertyType(name="since", type="LOCAL_DATETIME")], + additional_properties=True, ), RelationshipType( label="KNOWS", diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index 7e00eb586..6357ac56a 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -17,7 +17,7 @@ import json import logging import warnings -from typing import Any, Dict, List, Literal, Optional, Tuple, Union, Sequence +from typing import Any, Dict, List, Literal, Optional, Tuple, Union, Sequence, Callable from pathlib import Path from 
pydantic import ( @@ -27,6 +27,7 @@ validate_call, ConfigDict, ValidationError, + Field, ) from typing_extensions import Self @@ -74,6 +75,13 @@ class PropertyType(BaseModel): ) +def default_additional_item(key: str) -> Callable[[dict[str, Any]], bool]: + def wrapper(validated_data: dict[str, Any]) -> bool: + return len(validated_data.get(key, [])) == 0 + + return wrapper + + class NodeType(BaseModel): """ Represents a possible node in the graph. @@ -82,7 +90,9 @@ class NodeType(BaseModel): label: str description: str = "" properties: list[PropertyType] = [] - additional_properties: bool = True + additional_properties: bool = Field( + default_factory=default_additional_item("properties") + ) @model_validator(mode="before") @classmethod @@ -96,7 +106,8 @@ def validate_additional_properties(self) -> Self: if len(self.properties) == 0 and not self.additional_properties: raise ValueError( "Using `additional_properties=False` with no defined " - "properties will cause the model to be pruned during graph cleaning.", + "properties will cause the model to be pruned during graph cleaning. " + f"Define some properties or remove this NodeType: {self}" ) return self @@ -109,7 +120,9 @@ class RelationshipType(BaseModel): label: str description: str = "" properties: list[PropertyType] = [] - additional_properties: bool = True + additional_properties: bool = Field( + default_factory=default_additional_item("properties") + ) @model_validator(mode="before") @classmethod @@ -123,7 +136,8 @@ def validate_additional_properties(self) -> Self: if len(self.properties) == 0 and not self.additional_properties: raise ValueError( "Using `additional_properties=False` with no defined " - "properties will cause the model to be pruned during graph cleaning.", + "properties will cause the model to be pruned during graph cleaning. 
" + f"Define some properties or remove this RelationshipType: {self}" ) return self @@ -145,9 +159,15 @@ class GraphSchema(DataModel): relationship_types: Tuple[RelationshipType, ...] = tuple() patterns: Tuple[Tuple[str, str, str], ...] = tuple() - additional_node_types: bool = True - additional_relationship_types: bool = True - additional_patterns: bool = True + additional_node_types: bool = Field( + default_factory=default_additional_item("node_types") + ) + additional_relationship_types: bool = Field( + default_factory=default_additional_item("relationship_types") + ) + additional_patterns: bool = Field( + default_factory=default_additional_item("patterns") + ) _node_type_index: dict[str, NodeType] = PrivateAttr() _relationship_type_index: dict[str, RelationshipType] = PrivateAttr() diff --git a/tests/e2e/experimental/test_graph_pruning_component_e2e.py b/tests/e2e/experimental/test_graph_pruning_component_e2e.py index 333e74163..85a215793 100644 --- a/tests/e2e/experimental/test_graph_pruning_component_e2e.py +++ b/tests/e2e/experimental/test_graph_pruning_component_e2e.py @@ -119,6 +119,7 @@ async def test_graph_pruning_loose(extracted_graph: Neo4jGraph) -> None: {"name": "name", "type": "STRING"}, {"name": "height", "type": "INTEGER"}, ], + "additional_properties": True, } ], "relationship_types": [ @@ -129,6 +130,9 @@ async def test_graph_pruning_loose(extracted_graph: Neo4jGraph) -> None: "patterns": [ ("Person", "KNOWS", "Person"), ], + "additional_node_types": True, + "additional_relationship_types": True, + "additional_patterns": True, } await _test(extracted_graph, schema_dict, extracted_graph) @@ -153,6 +157,7 @@ async def test_graph_pruning_missing_required_property( }, {"name": "height", "type": "INTEGER"}, ], + "additional_properties": True, } ], "relationship_types": [ @@ -163,6 +168,9 @@ async def test_graph_pruning_missing_required_property( "patterns": [ ("Person", "KNOWS", "Person"), ], + "additional_node_types": True, + 
"additional_relationship_types": True, + "additional_patterns": True, } filtered_graph = Neo4jGraph( nodes=[ @@ -253,7 +261,7 @@ async def test_graph_pruning_strict_properties_and_node_types( }, {"name": "height", "type": "INTEGER"}, ], - "additional_properties": False, + # "additional_properties": False, # default value } ], "relationship_types": [ @@ -264,7 +272,9 @@ async def test_graph_pruning_strict_properties_and_node_types( "patterns": [ ("Person", "KNOWS", "Person"), ], - "additional_node_types": False, + # "additional_node_types": False, # default value + "additional_relationship_types": True, + "additional_patterns": True, } filtered_graph = Neo4jGraph( nodes=[ @@ -354,6 +364,7 @@ async def test_graph_pruning_strict_patterns(extracted_graph: Neo4jGraph) -> Non }, {"name": "height", "type": "INTEGER"}, ], + "additional_properties": True, }, { "label": "Organization", @@ -371,6 +382,7 @@ async def test_graph_pruning_strict_patterns(extracted_graph: Neo4jGraph) -> Non ("Person", "KNOWS", "Person"), ("Person", "KNOWS", "Organization"), ), + "additional_node_types": True, "additional_relationship_types": False, "additional_patterns": False, } diff --git a/tests/e2e/experimental/test_simplekgpipeline_e2e.py b/tests/e2e/experimental/test_simplekgpipeline_e2e.py index be1c04cf9..fb42ef40e 100644 --- a/tests/e2e/experimental/test_simplekgpipeline_e2e.py +++ b/tests/e2e/experimental/test_simplekgpipeline_e2e.py @@ -96,13 +96,13 @@ async def test_pipeline_builder_happy_path_legacy_schema( ] # Instantiate Entity and Relation objects - entities = ["PERSON", "ORGANIZATION", "HORCRUX", "LOCATION"] + entities = ["Person", "Organization", "Horcrux", "Location"] relations = ["SITUATED_AT", "INTERACTS", "OWNS", "LED_BY"] potential_schema = [ - ("PERSON", "SITUATED_AT", "LOCATION"), - ("PERSON", "INTERACTS", "PERSON"), - ("PERSON", "OWNS", "HORCRUX"), - ("ORGANIZATION", "LED_BY", "PERSON"), + ("Person", "SITUATED_AT", "Location"), + ("Person", "INTERACTS", "Person"), + 
("Person", "OWNS", "Horcrux"), + ("Organization", "LED_BY", "Person"), ] # Additional arguments diff --git a/tests/unit/experimental/components/test_schema.py b/tests/unit/experimental/components/test_schema.py index f6c53016e..79e7712d9 100644 --- a/tests/unit/experimental/components/test_schema.py +++ b/tests/unit/experimental/components/test_schema.py @@ -15,7 +15,7 @@ from __future__ import annotations import json -from typing import Tuple +from typing import Tuple, Any from unittest.mock import AsyncMock, patch import pytest @@ -46,14 +46,36 @@ def test_node_type_initialization_from_string() -> None: assert node_type.properties == [] -def test_node_type_raise_error_if_misconfigured() -> None: +def test_node_type_additional_properties_default() -> None: + # default behavior: + node_type = NodeType.model_validate({"label": "Label"}) + assert node_type.additional_properties is True + node_type = NodeType.model_validate({"label": "Label", "properties": []}) + assert node_type.additional_properties is True + node_type = NodeType.model_validate( + {"label": "Label", "properties": [{"name": "name", "type": "STRING"}]} + ) + assert node_type.additional_properties is False + + # manually changing the default value + # impossible cases: no properties and no additional with pytest.raises(ValidationError): - NodeType( - label="test", - properties=[], - additional_properties=False, + NodeType.model_validate({"label": "Label", "additional_properties": False}) + with pytest.raises(ValidationError): + NodeType.model_validate( + {"label": "Label", "properties": [], "additional_properties": False} ) + # working case: properties and additional allowed + node_type = NodeType.model_validate( + { + "label": "Label", + "properties": [{"name": "name", "type": "STRING"}], + "additional_properties": True, + } + ) + assert node_type.additional_properties is True + def test_relationship_type_initialization_from_string() -> None: relationship_type = RelationshipType.model_validate("REL") 
@@ -62,14 +84,84 @@ def test_relationship_type_initialization_from_string() -> None: assert relationship_type.properties == [] -def test_relationship_type_raise_error_if_misconfigured() -> None: +def test_relationship_type_additional_properties_default() -> None: + relationship_type = RelationshipType.model_validate({"label": "REL"}) + assert relationship_type.additional_properties is True + relationship_type = RelationshipType.model_validate( + {"label": "REL", "properties": []} + ) + assert relationship_type.additional_properties is True + relationship_type = RelationshipType.model_validate( + {"label": "REL", "properties": [{"name": "name", "type": "STRING"}]} + ) + assert relationship_type.additional_properties is False + + # manually changing the default value + # impossible cases: no properties and no additional with pytest.raises(ValidationError): - RelationshipType( - label="test", - properties=[], - additional_properties=False, + RelationshipType.model_validate( + {"label": "REL", "additional_properties": False} + ) + with pytest.raises(ValidationError): + RelationshipType.model_validate( + {"label": "REL", "properties": [], "additional_properties": False} ) + # working case: properties and additional allowed + relationship_type = RelationshipType.model_validate( + { + "label": "REL", + "properties": [{"name": "name", "type": "STRING"}], + "additional_properties": True, + } + ) + assert relationship_type.additional_properties is True + + +def test_schema_additional_node_types_default() -> None: + schema_dict: dict[str, Any] = { + "node_types": [], + } + schema = GraphSchema.model_validate(schema_dict) + assert schema.additional_node_types is True + + schema_dict = { + "node_types": ["Person"], + } + schema = GraphSchema.model_validate(schema_dict) + assert schema.additional_node_types is False + + +def test_schema_additional_relationship_types_default() -> None: + schema_dict: dict[str, Any] = { + "node_types": [], + } + schema = 
GraphSchema.model_validate(schema_dict) + assert schema.additional_relationship_types is True + + schema_dict = { + "node_types": [], + "relationship_types": ["REL"], + } + schema = GraphSchema.model_validate(schema_dict) + assert schema.additional_relationship_types is False + + +def test_schema_additional_patterns_default() -> None: + schema_dict: dict[str, Any] = { + "node_types": [], + } + schema = GraphSchema.model_validate(schema_dict) + assert schema.additional_patterns is True + + schema_dict = { + "node_types": ["Person"], + "relationship_types": ["REL"], + "patterns": [("Person", "REL", "Person")], + } + schema = GraphSchema.model_validate(schema_dict) + assert schema.additional_patterns is False + def test_schema_additional_parameter_validation() -> None: """Additional relationship types not allowed, but additional patterns allowed @@ -97,6 +189,7 @@ def test_schema_additional_parameter_validation() -> None: "patterns": [ ("Person", "KNOWS", "Person"), ], + "additional_relationship_types": True, "additional_patterns": False, } with pytest.raises( @@ -199,9 +292,9 @@ def test_create_schema_model_valid_data( assert schema.node_types == valid_node_types assert schema.relationship_types == valid_relationship_types assert schema.patterns == valid_patterns - assert schema.additional_node_types is True - assert schema.additional_relationship_types is True - assert schema.additional_patterns is True + assert schema.additional_node_types is False + assert schema.additional_relationship_types is False + assert schema.additional_patterns is False @pytest.mark.asyncio @@ -227,9 +320,9 @@ async def test_run_method( assert schema.node_types == valid_node_types assert schema.relationship_types == valid_relationship_types assert schema.patterns == valid_patterns - assert schema.additional_node_types is True - assert schema.additional_relationship_types is True - assert schema.additional_patterns is True + assert schema.additional_node_types is False + assert 
schema.additional_relationship_types is False + assert schema.additional_patterns is False def test_create_schema_model_invalid_entity(