Skip to content

Commit 522d232

Browse files
Fix invalid extracted schema (neo4j#375)
* Update default prompt for schema extraction * Filter invalid patterns from extracted schema * Add unit tests * Improve documentation * Change warning to info * Correct other logging level * Filter out nodes and rels without labels * Add more unit tests
1 parent a07cd9c commit 522d232

File tree

4 files changed

+392
-10
lines changed

4 files changed

+392
-10
lines changed

docs/source/user_guide_kg_builder.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -837,7 +837,7 @@ Instead of manually defining the schema, you can use the `SchemaFromTextExtracto
837837
# Extract the schema from the text
838838
extracted_schema = await schema_extractor.run(text="Some text")
839839
840-
The `SchemaFromTextExtractor` component analyzes the text and identifies entity types, relationship types, and their property types. It creates a complete `GraphSchema` object that can be used in the same way as a manually defined schema.
840+
The `SchemaFromTextExtractor` component analyzes the text and identifies node types, relationship types, their property types, and the patterns connecting them. It creates a complete `GraphSchema` object that can be used in the same way as a manually defined schema.
841841

842842
You can also save and reload the extracted schema:
843843

src/neo4j_graphrag/experimental/components/schema.py

Lines changed: 129 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,118 @@ def __init__(
430430
)
431431
self._llm_params: dict[str, Any] = llm_params or {}
432432

433+
def _filter_invalid_patterns(
    self,
    patterns: List[Tuple[str, str, str]],
    node_types: List[Dict[str, Any]],
    relationship_types: Optional[List[Dict[str, Any]]] = None,
) -> List[Tuple[str, str, str]]:
    """
    Filter out patterns that reference undefined node types or relationship types.

    The inputs come from LLM output, so they are treated defensively:
    definitions that are not dictionaries or carry no string label are
    ignored when collecting valid labels, and patterns that are not
    3-item sequences of known labels are dropped (and logged) instead
    of raising.

    Args:
        patterns: List of (source label, relationship label, target label)
            patterns to filter.
        node_types: List of node type definitions.
        relationship_types: Optional list of relationship type definitions.

    Returns:
        Filtered list of patterns containing only valid references. The
        kept patterns are returned unchanged and in their original order.
    """
    # Early returns for missing required types
    if not node_types:
        logging.info(
            "Filtering out all patterns because no node types are defined. "
            "Patterns reference node types that must be defined."
        )
        return []

    if not relationship_types:
        logging.info(
            "Filtering out all patterns because no relationship types are defined. "
            "GraphSchema validation requires relationship_types when patterns are provided."
        )
        return []

    # Collect valid labels. Using .get() plus type checks (rather than
    # item["label"]) keeps a single malformed definition from aborting
    # the whole filtering step with a KeyError/TypeError.
    valid_node_labels = {
        node_type.get("label")
        for node_type in node_types
        if isinstance(node_type, dict) and isinstance(node_type.get("label"), str)
    }
    valid_relationship_labels = {
        rel_type.get("label")
        for rel_type in relationship_types
        if isinstance(rel_type, dict) and isinstance(rel_type.get("label"), str)
    }

    filtered_patterns = []
    for pattern in patterns:
        if not (isinstance(pattern, (list, tuple)) and len(pattern) == 3):
            # Log the drop so it is as visible as the other filtering decisions.
            logging.info(
                f"Filtering out malformed pattern: {pattern}. "
                "Expected a sequence of exactly 3 items."
            )
            continue

        entity1, relation, entity2 = pattern

        # The isinstance guard runs first so that unhashable components
        # (e.g. a list) never reach the set membership test.
        entity1_valid = isinstance(entity1, str) and entity1 in valid_node_labels
        entity2_valid = isinstance(entity2, str) and entity2 in valid_node_labels
        relation_valid = (
            isinstance(relation, str) and relation in valid_relationship_labels
        )

        if entity1_valid and entity2_valid and relation_valid:
            filtered_patterns.append(pattern)
            continue

        # Log invalid pattern with validation details
        logging.info(
            f"Filtering out invalid pattern: {pattern}. "
            f"Entity1 '{entity1}' valid: {entity1_valid}, "
            f"Entity2 '{entity2}' valid: {entity2_valid}, "
            f"Relation '{relation}' valid: {relation_valid}"
        )

    return filtered_patterns
500+
501+
def _filter_nodes_without_labels(
    self, node_types: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """
    Filter out node types that have no labels.

    A node type is kept only when it is a dictionary whose "label" value
    is truthy. Entries that are not dictionaries are treated as having no
    label and are dropped (and logged) rather than raising an
    AttributeError on ``.get``.

    Args:
        node_types: List of node type definitions.

    Returns:
        Filtered list of node types containing only those with valid labels,
        in their original order.
    """
    filtered_nodes = []
    for node_type in node_types:
        # isinstance guard: the list is built from LLM output, so an entry
        # is not guaranteed to be a dict.
        if isinstance(node_type, dict) and node_type.get("label"):
            filtered_nodes.append(node_type)
        else:
            logging.info(f"Filtering out node type with missing label: {node_type}")

    return filtered_nodes
521+
522+
def _filter_relationships_without_labels(
    self, relationship_types: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """
    Filter out relationship types that have no labels.

    A relationship type is kept only when it is a dictionary whose "label"
    value is truthy. Entries that are not dictionaries are treated as
    having no label and are dropped (and logged) rather than raising an
    AttributeError on ``.get``.

    Args:
        relationship_types: List of relationship type definitions.

    Returns:
        Filtered list of relationship types containing only those with valid
        labels, in their original order.
    """
    filtered_relationships = []
    for rel_type in relationship_types:
        # isinstance guard: the list is built from LLM output, so an entry
        # is not guaranteed to be a dict.
        if isinstance(rel_type, dict) and rel_type.get("label"):
            filtered_relationships.append(rel_type)
        else:
            logging.info(
                f"Filtering out relationship type with missing label: {rel_type}"
            )

    return filtered_relationships
544+
433545
@validate_call
434546
async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema:
435547
"""
@@ -459,13 +571,13 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema
459571
pass # Keep as is
460572
# handle list
461573
elif isinstance(extracted_schema, list):
462-
if len(extracted_schema) > 0 and isinstance(extracted_schema[0], dict):
463-
extracted_schema = extracted_schema[0]
464-
elif len(extracted_schema) == 0:
465-
logging.warning(
574+
if len(extracted_schema) == 0:
575+
logging.info(
466576
"LLM returned an empty list for schema. Falling back to empty schema."
467577
)
468578
extracted_schema = {}
579+
elif isinstance(extracted_schema[0], dict):
580+
extracted_schema = extracted_schema[0]
469581
else:
470582
raise SchemaExtractionError(
471583
f"Expected a dictionary or list of dictionaries, but got list containing: {type(extracted_schema[0])}"
@@ -488,6 +600,19 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema
488600
"patterns"
489601
)
490602

603+
# Filter out nodes and relationships without labels
604+
extracted_node_types = self._filter_nodes_without_labels(extracted_node_types)
605+
if extracted_relationship_types:
606+
extracted_relationship_types = self._filter_relationships_without_labels(
607+
extracted_relationship_types
608+
)
609+
610+
# Filter out invalid patterns before validation
611+
if extracted_patterns:
612+
extracted_patterns = self._filter_invalid_patterns(
613+
extracted_patterns, extracted_node_types, extracted_relationship_types
614+
)
615+
491616
return GraphSchema.model_validate(
492617
{
493618
"node_types": extracted_node_types,

src/neo4j_graphrag/generation/prompts.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -207,16 +207,16 @@ class SchemaExtractionTemplate(PromptTemplate):
207207
You are a top-tier algorithm designed for extracting a labeled property graph schema in
208208
structured formats.
209209
210-
Generate a generalized graph schema based on the input text. Identify key entity types,
210+
Generate a generalized graph schema based on the input text. Identify key node types,
211211
their relationship types, and property types.
212212
213213
IMPORTANT RULES:
214214
1. Return only abstract schema information, not concrete instances.
215-
2. Use singular PascalCase labels for entity types (e.g., Person, Company, Product).
216-
3. Use UPPER_SNAKE_CASE for relationship types (e.g., WORKS_FOR, MANAGES).
215+
2. Use singular PascalCase labels for node types (e.g., Person, Company, Product).
216+
3. Use UPPER_SNAKE_CASE labels for relationship types (e.g., WORKS_FOR, MANAGES).
217217
4. Include property definitions only when the type can be confidently inferred, otherwise omit them.
218-
5. When defining potential_schema, ensure that every entity and relation mentioned exists in your entities and relations lists.
219-
6. Do not create entity types that aren't clearly mentioned in the text.
218+
5. When defining patterns, ensure that every node label and relationship label mentioned exists in your lists of node types and relationship types.
219+
6. Do not create node types that aren't clearly mentioned in the text.
220220
7. Keep your schema minimal and focused on clearly identifiable patterns in the text.
221221
222222
Accepted property types are: BOOLEAN, DATE, DURATION, FLOAT, INTEGER, LIST,

0 commit comments

Comments
 (0)