@@ -473,13 +473,78 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema
473
473
474
474
475
475
class SchemaFromExistingGraphExtractor (BaseSchemaBuilder ):
476
- """A class to build a GraphSchema object from an existing graph."""
476
+ """A class to build a GraphSchema object from an existing graph.
477
477
478
- def __init__ (self , driver : neo4j .Driver ) -> None :
478
+ Uses the get_structured_schema function to extract existing node labels,
479
+ relationship types, properties and existence constraints.
480
+
481
+ By default, the built schema does not allow any additional item (property,
482
+ node label, relationship type or pattern).
483
+
484
+ Args:
485
+ driver (neo4j.Driver): connection to the neo4j database.
486
+ additional_properties (bool, default False): see GraphSchema
487
+ additional_node_types (bool, default False): see GraphSchema
488
+ additional_relationship_types (bool, default False): see GraphSchema
489
+ additional_patterns (bool, default False): see GraphSchema
490
+ neo4j_database (Optional[str], default None): name of the neo4j database to use
491
+ """
492
+
493
def __init__(
    self,
    driver: neo4j.Driver,
    additional_properties: bool = False,
    additional_node_types: bool = False,
    additional_relationship_types: bool = False,
    additional_patterns: bool = False,
    neo4j_database: Optional[str] = None,
) -> None:
    """Store the database connection and the schema strictness flags.

    Args:
        driver (neo4j.Driver): connection to the neo4j database.
        additional_properties (bool, default False): kept on the instance
            and copied into each node type built by `run()`.
        additional_node_types (bool, default False): kept on the instance
            and copied into the schema built by `run()`.
        additional_relationship_types (bool, default False): kept on the
            instance and copied into the schema built by `run()`.
        additional_patterns (bool, default False): kept on the instance
            and copied into the schema built by `run()`.
        neo4j_database (Optional[str], default None): name of the neo4j
            database, passed to `get_structured_schema()` by `run()`.
    """
    # Strictness flags: nothing is validated here, values are stored as given.
    self.additional_patterns = additional_patterns
    self.additional_relationship_types = additional_relationship_types
    self.additional_node_types = additional_node_types
    self.additional_properties = additional_properties

    # Connection settings used when the schema is read from the database.
    self.database = neo4j_database
    self.driver = driver
509
+
510
+ @staticmethod
511
+ def _extract_required_properties (
512
+ structured_schema : dict [str , Any ],
513
+ ) -> list [tuple [str , str ]]:
514
+ """Extract a list of (node label (or rel type), property name) for which
515
+ an "EXISTENCE" or "KEY" constraint is defined in the DB.
516
+
517
+ Args:
518
+
519
+ structured_schema (dict[str, Any]): the result of the `get_structured_schema()` function.
520
+
521
+ Returns:
522
+
523
+ list of tuples of (node label (or rel type), property name)
524
+
525
+ """
526
+ schema_metadata = structured_schema .get ("metadata" , {})
527
+ existence_constraint = [] # list of (node label, property name)
528
+ for constraint in schema_metadata .get ("constraints" , []):
529
+ if constraint ["type" ] in (
530
+ "NODE_PROPERTY_EXISTENCE" ,
531
+ "NODE_KEY" ,
532
+ "RELATIONSHIP_PROPERTY_EXISTENCE" ,
533
+ "RELATIONSHIP_KEY" ,
534
+ ):
535
+ properties = constraint ["properties" ]
536
+ labels = constraint ["labelsOrTypes" ]
537
+ # note: existence constraint only apply to a single property
538
+ # and a single label
539
+ prop = properties [0 ]
540
+ lab = labels [0 ]
541
+ existence_constraint .append ((lab , prop ))
542
+ return existence_constraint
543
+
544
+ async def run (self ) -> GraphSchema :
545
+ structured_schema = get_structured_schema (self .driver , database = self .database )
546
+ existence_constraint = self ._extract_required_properties (structured_schema )
480
547
481
- async def run (self , ** kwargs : Any ) -> GraphSchema :
482
- structured_schema = get_structured_schema (self .driver )
483
548
node_labels = set (structured_schema ["node_props" ].keys ())
484
549
node_types = [
485
550
{
@@ -488,9 +553,11 @@ async def run(self, **kwargs: Any) -> GraphSchema:
488
553
{
489
554
"name" : p ["property" ],
490
555
"type" : p ["type" ],
556
+ "required" : (key , p ["property" ]) in existence_constraint ,
491
557
}
492
558
for p in properties
493
559
],
560
+ "additional_properties" : self .additional_properties ,
494
561
}
495
562
for key , properties in structured_schema ["node_props" ].items ()
496
563
]
@@ -502,6 +569,7 @@ async def run(self, **kwargs: Any) -> GraphSchema:
502
569
{
503
570
"name" : p ["property" ],
504
571
"type" : p ["type" ],
572
+ "required" : (key , p ["property" ]) in existence_constraint ,
505
573
}
506
574
for p in properties
507
575
],
@@ -540,5 +608,8 @@ async def run(self, **kwargs: Any) -> GraphSchema:
540
608
"node_types" : node_types ,
541
609
"relationship_types" : relationship_types ,
542
610
"patterns" : patterns ,
611
+ "additional_node_types" : self .additional_node_types ,
612
+ "additional_relationship_types" : self .additional_relationship_types ,
613
+ "additional_patterns" : self .additional_patterns ,
543
614
}
544
615
)
0 commit comments