Skip to content

Fix invalid extracted schema #375

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/user_guide_kg_builder.rst
Original file line number Diff line number Diff line change
Expand Up @@ -837,7 +837,7 @@ Instead of manually defining the schema, you can use the `SchemaFromTextExtracto
# Extract the schema from the text
extracted_schema = await schema_extractor.run(text="Some text")

The `SchemaFromTextExtractor` component analyzes the text and identifies entity types, relationship types, and their property types. It creates a complete `GraphSchema` object that can be used in the same way as a manually defined schema.
The `SchemaFromTextExtractor` component analyzes the text and identifies node types, relationship types, their property types, and the patterns connecting them. It creates a complete `GraphSchema` object that can be used in the same way as a manually defined schema.

You can also save and reload the extracted schema:

Expand Down
133 changes: 129 additions & 4 deletions src/neo4j_graphrag/experimental/components/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,118 @@ def __init__(
)
self._llm_params: dict[str, Any] = llm_params or {}

def _filter_invalid_patterns(
self,
patterns: List[Tuple[str, str, str]],
node_types: List[Dict[str, Any]],
relationship_types: Optional[List[Dict[str, Any]]] = None,
) -> List[Tuple[str, str, str]]:
"""
Filter out patterns that reference undefined node types or relationship types.

Args:
patterns: List of patterns to filter.
node_types: List of node type definitions.
relationship_types: Optional list of relationship type definitions.

Returns:
Filtered list of patterns containing only valid references.
"""
# Early returns for missing required types
if not node_types:
logging.info(
"Filtering out all patterns because no node types are defined. "
"Patterns reference node types that must be defined."
)
return []

if not relationship_types:
logging.info(
"Filtering out all patterns because no relationship types are defined. "
"GraphSchema validation requires relationship_types when patterns are provided."
)
return []

# Create sets of valid labels
valid_node_labels = {node_type["label"] for node_type in node_types}
valid_relationship_labels = {
rel_type["label"] for rel_type in relationship_types
}

# Filter patterns
filtered_patterns = []
for pattern in patterns:
if not (isinstance(pattern, (list, tuple)) and len(pattern) == 3):
continue

entity1, relation, entity2 = pattern

# Check if all components are valid
if (
entity1 in valid_node_labels
and entity2 in valid_node_labels
and relation in valid_relationship_labels
):
filtered_patterns.append(pattern)
else:
# Log invalid pattern with validation details
entity1_valid = entity1 in valid_node_labels
entity2_valid = entity2 in valid_node_labels
relation_valid = relation in valid_relationship_labels

logging.info(
f"Filtering out invalid pattern: {pattern}. "
f"Entity1 '{entity1}' valid: {entity1_valid}, "
f"Entity2 '{entity2}' valid: {entity2_valid}, "
f"Relation '{relation}' valid: {relation_valid}"
)

return filtered_patterns

def _filter_nodes_without_labels(
self, node_types: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""
Filter out node types that have no labels.

Args:
node_types: List of node type definitions.

Returns:
Filtered list of node types containing only those with valid labels.
"""
filtered_nodes = []
for node_type in node_types:
if node_type.get("label"):
filtered_nodes.append(node_type)
else:
logging.info(f"Filtering out node type with missing label: {node_type}")

return filtered_nodes

def _filter_relationships_without_labels(
self, relationship_types: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""
Filter out relationship types that have no labels.

Args:
relationship_types: List of relationship type definitions.

Returns:
Filtered list of relationship types containing only those with valid labels.
"""
filtered_relationships = []
for rel_type in relationship_types:
if rel_type.get("label"):
filtered_relationships.append(rel_type)
else:
logging.info(
f"Filtering out relationship type with missing label: {rel_type}"
)

return filtered_relationships

@validate_call
async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema:
"""
Expand Down Expand Up @@ -459,13 +571,13 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema
pass # Keep as is
# handle list
elif isinstance(extracted_schema, list):
if len(extracted_schema) > 0 and isinstance(extracted_schema[0], dict):
extracted_schema = extracted_schema[0]
elif len(extracted_schema) == 0:
logging.warning(
if len(extracted_schema) == 0:
logging.info(
"LLM returned an empty list for schema. Falling back to empty schema."
)
extracted_schema = {}
elif isinstance(extracted_schema[0], dict):
extracted_schema = extracted_schema[0]
else:
raise SchemaExtractionError(
f"Expected a dictionary or list of dictionaries, but got list containing: {type(extracted_schema[0])}"
Expand All @@ -488,6 +600,19 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema
"patterns"
)

# Filter out nodes and relationships without labels
extracted_node_types = self._filter_nodes_without_labels(extracted_node_types)
if extracted_relationship_types:
extracted_relationship_types = self._filter_relationships_without_labels(
extracted_relationship_types
)

# Filter out invalid patterns before validation
if extracted_patterns:
extracted_patterns = self._filter_invalid_patterns(
extracted_patterns, extracted_node_types, extracted_relationship_types
)

return GraphSchema.model_validate(
{
"node_types": extracted_node_types,
Expand Down
10 changes: 5 additions & 5 deletions src/neo4j_graphrag/generation/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,16 +207,16 @@ class SchemaExtractionTemplate(PromptTemplate):
You are a top-tier algorithm designed for extracting a labeled property graph schema in
structured formats.

Generate a generalized graph schema based on the input text. Identify key entity types,
Generate a generalized graph schema based on the input text. Identify key node types,
their relationship types, and property types.

IMPORTANT RULES:
1. Return only abstract schema information, not concrete instances.
2. Use singular PascalCase labels for entity types (e.g., Person, Company, Product).
3. Use UPPER_SNAKE_CASE for relationship types (e.g., WORKS_FOR, MANAGES).
2. Use singular PascalCase labels for node types (e.g., Person, Company, Product).
3. Use UPPER_SNAKE_CASE labels for relationship types (e.g., WORKS_FOR, MANAGES).
4. Include property definitions only when the type can be confidently inferred, otherwise omit them.
5. When defining potential_schema, ensure that every entity and relation mentioned exists in your entities and relations lists.
6. Do not create entity types that aren't clearly mentioned in the text.
5. When defining patterns, ensure that every node label and relationship label mentioned exists in your lists of node types and relationship types.
6. Do not create node types that aren't clearly mentioned in the text.
7. Keep your schema minimal and focused on clearly identifiable patterns in the text.

Accepted property types are: BOOLEAN, DATE, DURATION, FLOAT, INTEGER, LIST,
Expand Down
Loading