Skip to content

Switch default values for additional_* flags #369

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions docs/source/user_guide_kg_builder.rst
Original file line number Diff line number Diff line change
Expand Up @@ -998,31 +998,40 @@ By default, all extracted elements — including nodes, relationships, and prope
Configuration Options
---------------------

- **Required Properties**
- **Required Properties** (default: ``False``)
Required properties may be specified at the node or relationship type level. Any extracted node or relationship missing one or more of its required properties will be pruned from the graph.

- **Additional Properties** *(default: True)*
- **Additional Properties**
This node- or relationship-level option determines whether extra properties not listed in the schema should be retained.

- If set to ``True`` (default), all extracted properties are retained.
- If set to ``True``, all extracted properties are retained.
- If set to ``False``, only the properties defined in the schema are preserved; all others are removed.

.. note:: Default behavior

By default, this flag is set to ``False`` if at least one property is defined, ``True`` otherwise.

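The same rule applies to ``additional_node_types``, ``additional_relationship_types`` and ``additional_patterns``, described below: each defaults to ``False`` if at least one corresponding item (node type, relationship type or pattern) is defined in the schema, and to ``True`` otherwise.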

.. warning::

Defining a node or relationship type with no properties and ``additional_properties=False`` will raise a ``ValidationError``.

.. note:: Node pruning

If, after property pruning using the above rule, a node is left without any property, it is removed from the graph.


- **Additional Node Types** *(default: True)*
- **Additional Node Types**
This schema-level option specifies whether node types not defined in the schema are included in the graph.

- If set to ``True`` (default), such node types are retained.
- If set to ``True``, such node types are retained.
- If set to ``False``, nodes with undefined types are removed.

- **Additional Relationship Types** *(default: True)*
- **Additional Relationship Types**
This schema-level option specifies whether relationship types not defined in the schema are included in the graph.

- If set to ``True`` (default), such relationships are retained.
- If set to ``True``, such relationships are retained.
- If set to ``False``, relationships with undefined types are removed.

- **Additional Patterns** *(default: True)*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
Neo4jNode(
id="Organization/Corp1",
label="Organization",
properties={"name": "CorpA"},
properties={"name": "Corp1"},
),
],
relationships=[
Expand All @@ -51,7 +51,7 @@
type="KNOWS",
),
Neo4jRelationship(
start_node_id="Organization/CorpA",
start_node_id="Organization/Corp2",
end_node_id="Person/Jack",
type="WORKS_FOR",
),
Expand Down Expand Up @@ -80,12 +80,14 @@
PropertyType(name="name", type="STRING", required=True),
PropertyType(name="address", type="STRING"),
],
additional_properties=True,
),
),
relationship_types=(
RelationshipType(
label="WORKS_FOR",
properties=[PropertyType(name="since", type="LOCAL_DATETIME")],
additional_properties=True,
),
RelationshipType(
label="KNOWS",
Expand Down
36 changes: 28 additions & 8 deletions src/neo4j_graphrag/experimental/components/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import json
import logging
import warnings
from typing import Any, Dict, List, Literal, Optional, Tuple, Union, Sequence
from typing import Any, Dict, List, Literal, Optional, Tuple, Union, Sequence, Callable
from pathlib import Path

from pydantic import (
Expand All @@ -27,6 +27,7 @@
validate_call,
ConfigDict,
ValidationError,
Field,
)
from typing_extensions import Self

Expand Down Expand Up @@ -74,6 +75,13 @@ class PropertyType(BaseModel):
)


def default_additional_item(key: str) -> Callable[[dict[str, Any]], bool]:
    """Build a ``default_factory`` callable for an ``additional_*`` flag.

    The returned callable is meant to be passed to ``Field(default_factory=...)``
    and is called with the mapping of already-validated field values.

    Args:
        key: Name of the sibling field whose content drives the default
            (e.g. ``"properties"``, ``"node_types"``, ``"patterns"``).

    Returns:
        A callable that returns ``True`` when nothing is stored under ``key``
        (missing key, ``None`` or an empty collection), ``False`` otherwise.
    """

    def wrapper(validated_data: dict[str, Any]) -> bool:
        # Truthiness covers a missing key, ``None`` and any empty container,
        # so additional items are allowed only when none were declared.
        return not validated_data.get(key)

    return wrapper


class NodeType(BaseModel):
"""
Represents a possible node in the graph.
Expand All @@ -82,7 +90,9 @@ class NodeType(BaseModel):
label: str
description: str = ""
properties: list[PropertyType] = []
additional_properties: bool = True
additional_properties: bool = Field(
default_factory=default_additional_item("properties")
)

@model_validator(mode="before")
@classmethod
Expand All @@ -96,7 +106,8 @@ def validate_additional_properties(self) -> Self:
if len(self.properties) == 0 and not self.additional_properties:
raise ValueError(
"Using `additional_properties=False` with no defined "
"properties will cause the model to be pruned during graph cleaning.",
"properties will cause the model to be pruned during graph cleaning. "
f"Define some properties or remove this NodeType: {self}"
)
return self

Expand All @@ -109,7 +120,9 @@ class RelationshipType(BaseModel):
label: str
description: str = ""
properties: list[PropertyType] = []
additional_properties: bool = True
additional_properties: bool = Field(
default_factory=default_additional_item("properties")
)

@model_validator(mode="before")
@classmethod
Expand All @@ -123,7 +136,8 @@ def validate_additional_properties(self) -> Self:
if len(self.properties) == 0 and not self.additional_properties:
raise ValueError(
"Using `additional_properties=False` with no defined "
"properties will cause the model to be pruned during graph cleaning.",
"properties will cause the model to be pruned during graph cleaning. "
f"Define some properties or remove this RelationshipType: {self}"
)
return self

Expand All @@ -145,9 +159,15 @@ class GraphSchema(DataModel):
relationship_types: Tuple[RelationshipType, ...] = tuple()
patterns: Tuple[Tuple[str, str, str], ...] = tuple()

additional_node_types: bool = True
additional_relationship_types: bool = True
additional_patterns: bool = True
additional_node_types: bool = Field(
default_factory=default_additional_item("node_types")
)
additional_relationship_types: bool = Field(
default_factory=default_additional_item("relationship_types")
)
additional_patterns: bool = Field(
default_factory=default_additional_item("patterns")
)

_node_type_index: dict[str, NodeType] = PrivateAttr()
_relationship_type_index: dict[str, RelationshipType] = PrivateAttr()
Expand Down
16 changes: 14 additions & 2 deletions tests/e2e/experimental/test_graph_pruning_component_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ async def test_graph_pruning_loose(extracted_graph: Neo4jGraph) -> None:
{"name": "name", "type": "STRING"},
{"name": "height", "type": "INTEGER"},
],
"additional_properties": True,
}
],
"relationship_types": [
Expand All @@ -129,6 +130,9 @@ async def test_graph_pruning_loose(extracted_graph: Neo4jGraph) -> None:
"patterns": [
("Person", "KNOWS", "Person"),
],
"additional_node_types": True,
"additional_relationship_types": True,
"additional_patterns": True,
}
await _test(extracted_graph, schema_dict, extracted_graph)

Expand All @@ -153,6 +157,7 @@ async def test_graph_pruning_missing_required_property(
},
{"name": "height", "type": "INTEGER"},
],
"additional_properties": True,
}
],
"relationship_types": [
Expand All @@ -163,6 +168,9 @@ async def test_graph_pruning_missing_required_property(
"patterns": [
("Person", "KNOWS", "Person"),
],
"additional_node_types": True,
"additional_relationship_types": True,
"additional_patterns": True,
}
filtered_graph = Neo4jGraph(
nodes=[
Expand Down Expand Up @@ -253,7 +261,7 @@ async def test_graph_pruning_strict_properties_and_node_types(
},
{"name": "height", "type": "INTEGER"},
],
"additional_properties": False,
# "additional_properties": False, # default value
}
],
"relationship_types": [
Expand All @@ -264,7 +272,9 @@ async def test_graph_pruning_strict_properties_and_node_types(
"patterns": [
("Person", "KNOWS", "Person"),
],
"additional_node_types": False,
# "additional_node_types": False, # default value
"additional_relationship_types": True,
"additional_patterns": True,
}
filtered_graph = Neo4jGraph(
nodes=[
Expand Down Expand Up @@ -354,6 +364,7 @@ async def test_graph_pruning_strict_patterns(extracted_graph: Neo4jGraph) -> Non
},
{"name": "height", "type": "INTEGER"},
],
"additional_properties": True,
},
{
"label": "Organization",
Expand All @@ -371,6 +382,7 @@ async def test_graph_pruning_strict_patterns(extracted_graph: Neo4jGraph) -> Non
("Person", "KNOWS", "Person"),
("Person", "KNOWS", "Organization"),
),
"additional_node_types": True,
"additional_relationship_types": False,
"additional_patterns": False,
}
Expand Down
10 changes: 5 additions & 5 deletions tests/e2e/experimental/test_simplekgpipeline_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,13 @@ async def test_pipeline_builder_happy_path_legacy_schema(
]

# Instantiate Entity and Relation objects
entities = ["PERSON", "ORGANIZATION", "HORCRUX", "LOCATION"]
entities = ["Person", "Organization", "Horcrux", "Location"]
relations = ["SITUATED_AT", "INTERACTS", "OWNS", "LED_BY"]
potential_schema = [
("PERSON", "SITUATED_AT", "LOCATION"),
("PERSON", "INTERACTS", "PERSON"),
("PERSON", "OWNS", "HORCRUX"),
("ORGANIZATION", "LED_BY", "PERSON"),
("Person", "SITUATED_AT", "Location"),
("Person", "INTERACTS", "Person"),
("Person", "OWNS", "Horcrux"),
("Organization", "LED_BY", "Person"),
]

# Additional arguments
Expand Down
Loading