@@ -430,6 +430,118 @@ def __init__(
430
430
)
431
431
self ._llm_params : dict [str , Any ] = llm_params or {}
432
432
433
+ def _filter_invalid_patterns (
434
+ self ,
435
+ patterns : List [Tuple [str , str , str ]],
436
+ node_types : List [Dict [str , Any ]],
437
+ relationship_types : Optional [List [Dict [str , Any ]]] = None ,
438
+ ) -> List [Tuple [str , str , str ]]:
439
+ """
440
+ Filter out patterns that reference undefined node types or relationship types.
441
+
442
+ Args:
443
+ patterns: List of patterns to filter.
444
+ node_types: List of node type definitions.
445
+ relationship_types: Optional list of relationship type definitions.
446
+
447
+ Returns:
448
+ Filtered list of patterns containing only valid references.
449
+ """
450
+ # Early returns for missing required types
451
+ if not node_types :
452
+ logging .info (
453
+ "Filtering out all patterns because no node types are defined. "
454
+ "Patterns reference node types that must be defined."
455
+ )
456
+ return []
457
+
458
+ if not relationship_types :
459
+ logging .info (
460
+ "Filtering out all patterns because no relationship types are defined. "
461
+ "GraphSchema validation requires relationship_types when patterns are provided."
462
+ )
463
+ return []
464
+
465
+ # Create sets of valid labels
466
+ valid_node_labels = {node_type ["label" ] for node_type in node_types }
467
+ valid_relationship_labels = {
468
+ rel_type ["label" ] for rel_type in relationship_types
469
+ }
470
+
471
+ # Filter patterns
472
+ filtered_patterns = []
473
+ for pattern in patterns :
474
+ if not (isinstance (pattern , (list , tuple )) and len (pattern ) == 3 ):
475
+ continue
476
+
477
+ entity1 , relation , entity2 = pattern
478
+
479
+ # Check if all components are valid
480
+ if (
481
+ entity1 in valid_node_labels
482
+ and entity2 in valid_node_labels
483
+ and relation in valid_relationship_labels
484
+ ):
485
+ filtered_patterns .append (pattern )
486
+ else :
487
+ # Log invalid pattern with validation details
488
+ entity1_valid = entity1 in valid_node_labels
489
+ entity2_valid = entity2 in valid_node_labels
490
+ relation_valid = relation in valid_relationship_labels
491
+
492
+ logging .info (
493
+ f"Filtering out invalid pattern: { pattern } . "
494
+ f"Entity1 '{ entity1 } ' valid: { entity1_valid } , "
495
+ f"Entity2 '{ entity2 } ' valid: { entity2_valid } , "
496
+ f"Relation '{ relation } ' valid: { relation_valid } "
497
+ )
498
+
499
+ return filtered_patterns
500
+
501
+ def _filter_nodes_without_labels (
502
+ self , node_types : List [Dict [str , Any ]]
503
+ ) -> List [Dict [str , Any ]]:
504
+ """
505
+ Filter out node types that have no labels.
506
+
507
+ Args:
508
+ node_types: List of node type definitions.
509
+
510
+ Returns:
511
+ Filtered list of node types containing only those with valid labels.
512
+ """
513
+ filtered_nodes = []
514
+ for node_type in node_types :
515
+ if node_type .get ("label" ):
516
+ filtered_nodes .append (node_type )
517
+ else :
518
+ logging .info (f"Filtering out node type with missing label: { node_type } " )
519
+
520
+ return filtered_nodes
521
+
522
+ def _filter_relationships_without_labels (
523
+ self , relationship_types : List [Dict [str , Any ]]
524
+ ) -> List [Dict [str , Any ]]:
525
+ """
526
+ Filter out relationship types that have no labels.
527
+
528
+ Args:
529
+ relationship_types: List of relationship type definitions.
530
+
531
+ Returns:
532
+ Filtered list of relationship types containing only those with valid labels.
533
+ """
534
+ filtered_relationships = []
535
+ for rel_type in relationship_types :
536
+ if rel_type .get ("label" ):
537
+ filtered_relationships .append (rel_type )
538
+ else :
539
+ logging .info (
540
+ f"Filtering out relationship type with missing label: { rel_type } "
541
+ )
542
+
543
+ return filtered_relationships
544
+
433
545
@validate_call
434
546
async def run (self , text : str , examples : str = "" , ** kwargs : Any ) -> GraphSchema :
435
547
"""
@@ -459,13 +571,13 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema
459
571
pass # Keep as is
460
572
# handle list
461
573
elif isinstance (extracted_schema , list ):
462
- if len (extracted_schema ) > 0 and isinstance (extracted_schema [0 ], dict ):
463
- extracted_schema = extracted_schema [0 ]
464
- elif len (extracted_schema ) == 0 :
465
- logging .warning (
574
+ if len (extracted_schema ) == 0 :
575
+ logging .info (
466
576
"LLM returned an empty list for schema. Falling back to empty schema."
467
577
)
468
578
extracted_schema = {}
579
+ elif isinstance (extracted_schema [0 ], dict ):
580
+ extracted_schema = extracted_schema [0 ]
469
581
else :
470
582
raise SchemaExtractionError (
471
583
f"Expected a dictionary or list of dictionaries, but got list containing: { type (extracted_schema [0 ])} "
@@ -488,6 +600,19 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema
488
600
"patterns"
489
601
)
490
602
603
+ # Filter out nodes and relationships without labels
604
+ extracted_node_types = self ._filter_nodes_without_labels (extracted_node_types )
605
+ if extracted_relationship_types :
606
+ extracted_relationship_types = self ._filter_relationships_without_labels (
607
+ extracted_relationship_types
608
+ )
609
+
610
+ # Filter out invalid patterns before validation
611
+ if extracted_patterns :
612
+ extracted_patterns = self ._filter_invalid_patterns (
613
+ extracted_patterns , extracted_node_types , extracted_relationship_types
614
+ )
615
+
491
616
return GraphSchema .model_validate (
492
617
{
493
618
"node_types" : extracted_node_types ,
0 commit comments