Skip to content

Commit d5d230e

Browse files
committed
Extract required properties from existing constraints
1 parent f4b94ff commit d5d230e

File tree

1 file changed

+75
-4
lines changed
  • src/neo4j_graphrag/experimental/components

1 file changed

+75
-4
lines changed

src/neo4j_graphrag/experimental/components/schema.py

Lines changed: 75 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -473,13 +473,78 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema
473473

474474

475475
class SchemaFromExistingGraphExtractor(BaseSchemaBuilder):
476-
"""A class to build a GraphSchema object from an existing graph."""
476+
"""A class to build a GraphSchema object from an existing graph.
477477
478-
def __init__(self, driver: neo4j.Driver) -> None:
478+
Uses the get_structured_schema function to extract existing node labels,
479+
relationship types, properties and existence constraints.
480+
481+
By default, the built schema does not allow any additional item (property,
482+
node label, relationship type or pattern).
483+
484+
Args:
485+
driver (neo4j.Driver): connection to the neo4j database.
486+
additional_properties (bool, default False): see GraphSchema
487+
additional_node_types (bool, default False): see GraphSchema
488+
additional_relationship_types (bool, default False): see GraphSchema
489+
additional_patterns (bool, default False): see GraphSchema
490+
neo4j_database (Optional[str]): name of the neo4j database to use
491+
"""
492+
493+
def __init__(
    self,
    driver: neo4j.Driver,
    additional_properties: bool = False,
    additional_node_types: bool = False,
    additional_relationship_types: bool = False,
    additional_patterns: bool = False,
    neo4j_database: Optional[str] = None,
) -> None:
    """Keep the connection settings and the ``additional_*`` flags on the instance.

    Args:
        driver: connection to the neo4j database.
        additional_properties: stored as ``self.additional_properties``.
        additional_node_types: stored as ``self.additional_node_types``.
        additional_relationship_types: stored as
            ``self.additional_relationship_types``.
        additional_patterns: stored as ``self.additional_patterns``.
        neo4j_database: name of the database to query, stored as
            ``self.database``.
    """
    # Where to read the existing graph from.
    self.driver, self.database = driver, neo4j_database

    # Whether the built schema accepts items not present in the graph.
    (
        self.additional_properties,
        self.additional_node_types,
        self.additional_relationship_types,
        self.additional_patterns,
    ) = (
        additional_properties,
        additional_node_types,
        additional_relationship_types,
        additional_patterns,
    )
509+
510+
@staticmethod
def _extract_required_properties(
    structured_schema: dict[str, Any],
) -> list[tuple[str, str]]:
    """Extract the (node label or relationship type, property name) pairs for
    which an "EXISTENCE" or "KEY" constraint is defined in the DB.

    Args:
        structured_schema (dict[str, Any]): the result of the
            `get_structured_schema()` function. Constraints are read from
            ``structured_schema["metadata"]["constraints"]``; a missing or
            empty entry yields an empty result.

    Returns:
        list of tuples of (node label (or rel type), property name), in the
        order the constraints (and their properties) are listed.
    """
    required_constraint_types = (
        "NODE_PROPERTY_EXISTENCE",
        "NODE_KEY",
        "RELATIONSHIP_PROPERTY_EXISTENCE",
        "RELATIONSHIP_KEY",
    )
    schema_metadata = structured_schema.get("metadata") or {}
    required_properties: list[tuple[str, str]] = []
    for constraint in schema_metadata.get("constraints") or []:
        if constraint.get("type") not in required_constraint_types:
            continue
        labels = constraint.get("labelsOrTypes") or []
        properties = constraint.get("properties") or []
        # An existence constraint targets a single property, but a KEY
        # constraint may be composite: every property it lists is mandatory,
        # so all of them are reported, not only the first one.
        required_properties.extend(
            (label, prop) for label in labels for prop in properties
        )
    return required_properties
543+
544+
async def run(self) -> GraphSchema:
545+
structured_schema = get_structured_schema(self.driver, database=self.database)
546+
existence_constraint = self._extract_required_properties(structured_schema)
480547

481-
async def run(self, **kwargs: Any) -> GraphSchema:
482-
structured_schema = get_structured_schema(self.driver)
483548
node_labels = set(structured_schema["node_props"].keys())
484549
node_types = [
485550
{
@@ -488,9 +553,11 @@ async def run(self, **kwargs: Any) -> GraphSchema:
488553
{
489554
"name": p["property"],
490555
"type": p["type"],
556+
"required": (key, p["property"]) in existence_constraint,
491557
}
492558
for p in properties
493559
],
560+
"additional_properties": self.additional_properties,
494561
}
495562
for key, properties in structured_schema["node_props"].items()
496563
]
@@ -502,6 +569,7 @@ async def run(self, **kwargs: Any) -> GraphSchema:
502569
{
503570
"name": p["property"],
504571
"type": p["type"],
572+
"required": (key, p["property"]) in existence_constraint,
505573
}
506574
for p in properties
507575
],
@@ -540,5 +608,8 @@ async def run(self, **kwargs: Any) -> GraphSchema:
540608
"node_types": node_types,
541609
"relationship_types": relationship_types,
542610
"patterns": patterns,
611+
"additional_node_types": self.additional_node_types,
612+
"additional_relationship_types": self.additional_relationship_types,
613+
"additional_patterns": self.additional_patterns,
543614
}
544615
)

0 commit comments

Comments
 (0)