Skip to content

Fix invalid extracted schema #375

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/user_guide_kg_builder.rst
Original file line number Diff line number Diff line change
Expand Up @@ -837,7 +837,7 @@ Instead of manually defining the schema, you can use the `SchemaFromTextExtracto
# Extract the schema from the text
extracted_schema = await schema_extractor.run(text="Some text")

The `SchemaFromTextExtractor` component analyzes the text and identifies entity types, relationship types, and their property types. It creates a complete `GraphSchema` object that can be used in the same way as a manually defined schema.
The `SchemaFromTextExtractor` component analyzes the text and identifies node types, relationship types, their property types, and the patterns connecting them. It creates a complete `GraphSchema` object that can be used in the same way as a manually defined schema.

You can also save and reload the extracted schema:

Expand Down
85 changes: 82 additions & 3 deletions src/neo4j_graphrag/experimental/components/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,79 @@ def __init__(
)
self._llm_params: dict[str, Any] = llm_params or {}

def _filter_invalid_patterns(
self,
patterns: List[Tuple[str, str, str]],
node_types: List[Dict[str, Any]],
relationship_types: Optional[List[Dict[str, Any]]] = None,
) -> List[Tuple[str, str, str]]:
"""
Filter out patterns that reference undefined node types or relationship types.

Args:
patterns: List of patterns to filter.
node_types: List of node type definitions.
relationship_types: Optional list of relationship type definitions.

Returns:
Filtered list of patterns containing only valid references.
"""
# Early returns for missing required types
if not node_types:
logging.warning(
"Filtering out all patterns because no node types are defined. "
"Patterns reference node types that must be defined."
)
return []

if not relationship_types:
logging.warning(
"Filtering out all patterns because no relationship types are defined. "
"GraphSchema validation requires relationship_types when patterns are provided."
)
return []

# Create sets of valid labels
valid_node_labels = {
node_type["label"] for node_type in node_types if node_type.get("label")
}

valid_relationship_labels = {
rel_type["label"]
for rel_type in relationship_types
if rel_type.get("label")
}

# Filter patterns
filtered_patterns = []
for pattern in patterns:
if not (isinstance(pattern, (list, tuple)) and len(pattern) == 3):
continue

entity1, relation, entity2 = pattern

# Check if all components are valid
if (
entity1 in valid_node_labels
and entity2 in valid_node_labels
and relation in valid_relationship_labels
):
filtered_patterns.append(pattern)
else:
# Log invalid pattern with validation details
entity1_valid = entity1 in valid_node_labels
entity2_valid = entity2 in valid_node_labels
relation_valid = relation in valid_relationship_labels

logging.warning(
f"Filtering out invalid pattern: {pattern}. "
f"Entity1 '{entity1}' valid: {entity1_valid}, "
f"Entity2 '{entity2}' valid: {entity2_valid}, "
f"Relation '{relation}' valid: {relation_valid}"
)

return filtered_patterns

@validate_call
async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema:
"""
Expand Down Expand Up @@ -459,13 +532,13 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema
pass # Keep as is
# handle list
elif isinstance(extracted_schema, list):
if len(extracted_schema) > 0 and isinstance(extracted_schema[0], dict):
extracted_schema = extracted_schema[0]
elif len(extracted_schema) == 0:
if len(extracted_schema) == 0:
logging.warning(
"LLM returned an empty list for schema. Falling back to empty schema."
)
extracted_schema = {}
elif isinstance(extracted_schema[0], dict):
extracted_schema = extracted_schema[0]
else:
raise SchemaExtractionError(
f"Expected a dictionary or list of dictionaries, but got list containing: {type(extracted_schema[0])}"
Expand All @@ -488,6 +561,12 @@ async def run(self, text: str, examples: str = "", **kwargs: Any) -> GraphSchema
"patterns"
)

# Filter out invalid patterns before validation
if extracted_patterns:
extracted_patterns = self._filter_invalid_patterns(
extracted_patterns, extracted_node_types, extracted_relationship_types
)

return GraphSchema.model_validate(
{
"node_types": extracted_node_types,
Expand Down
10 changes: 5 additions & 5 deletions src/neo4j_graphrag/generation/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,16 +207,16 @@ class SchemaExtractionTemplate(PromptTemplate):
You are a top-tier algorithm designed for extracting a labeled property graph schema in
structured formats.

Generate a generalized graph schema based on the input text. Identify key entity types,
Generate a generalized graph schema based on the input text. Identify key node types,
their relationship types, and property types.

IMPORTANT RULES:
1. Return only abstract schema information, not concrete instances.
2. Use singular PascalCase labels for entity types (e.g., Person, Company, Product).
3. Use UPPER_SNAKE_CASE for relationship types (e.g., WORKS_FOR, MANAGES).
2. Use singular PascalCase labels for node types (e.g., Person, Company, Product).
3. Use UPPER_SNAKE_CASE labels for relationship types (e.g., WORKS_FOR, MANAGES).
4. Include property definitions only when the type can be confidently inferred, otherwise omit them.
5. When defining potential_schema, ensure that every entity and relation mentioned exists in your entities and relations lists.
6. Do not create entity types that aren't clearly mentioned in the text.
5. When defining patterns, ensure that every node label and relationship label mentioned exists in your lists of node types and relationship types.
6. Do not create node types that aren't clearly mentioned in the text.
7. Keep your schema minimal and focused on clearly identifiable patterns in the text.

Accepted property types are: BOOLEAN, DATE, DURATION, FLOAT, INTEGER, LIST,
Expand Down
110 changes: 110 additions & 0 deletions tests/unit/experimental/components/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,3 +700,113 @@ async def test_schema_from_text_run_valid_json_array(
assert schema.patterns is not None
assert len(schema.patterns) == 1
assert schema.patterns[0] == ("Person", "WORKS_FOR", "Organization")


@pytest.fixture
def schema_json_with_invalid_node_patterns() -> str:
return """
{
"node_types": [
{
"label": "Person",
"properties": [
{"name": "name", "type": "STRING"}
]
},
{
"label": "Organization",
"properties": [
{"name": "name", "type": "STRING"}
]
}
],
"relationship_types": [
{
"label": "WORKS_FOR",
"properties": [
{"name": "since", "type": "DATE"}
]
}
],
"patterns": [
["Person", "WORKS_FOR", "Organization"],
["Person", "WORKS_FOR", "UndefinedNode"],
["UndefinedNode", "WORKS_FOR", "Organization"]
]
}
"""


@pytest.fixture
def schema_json_with_invalid_relationship_patterns() -> str:
return """
{
"node_types": [
{
"label": "Person",
"properties": [
{"name": "name", "type": "STRING"}
]
},
{
"label": "Organization",
"properties": [
{"name": "name", "type": "STRING"}
]
}
],
"relationship_types": [
{
"label": "WORKS_FOR",
"properties": [
{"name": "since", "type": "DATE"}
]
}
],
"patterns": [
["Person", "WORKS_FOR", "Organization"],
["Person", "UNDEFINED_RELATION", "Organization"],
["Organization", "ANOTHER_UNDEFINED_RELATION", "Person"]
]
}
"""


@pytest.mark.asyncio
async def test_schema_from_text_filters_invalid_node_patterns(
schema_from_text: SchemaFromTextExtractor,
mock_llm: AsyncMock,
schema_json_with_invalid_node_patterns: str,
) -> None:
# configure the mock LLM to return schema with invalid node patterns
mock_llm.ainvoke.return_value = LLMResponse(
content=schema_json_with_invalid_node_patterns
)

# run the schema extraction
schema = await schema_from_text.run(text="Sample text for extraction")

# verify that invalid node patterns were filtered out (2 out of 3 patterns should be removed)
assert schema.patterns is not None
assert len(schema.patterns) == 1
assert schema.patterns[0] == ("Person", "WORKS_FOR", "Organization")


@pytest.mark.asyncio
async def test_schema_from_text_filters_invalid_relationship_patterns(
schema_from_text: SchemaFromTextExtractor,
mock_llm: AsyncMock,
schema_json_with_invalid_relationship_patterns: str,
) -> None:
# configure the mock LLM to return schema with invalid relationship patterns
mock_llm.ainvoke.return_value = LLMResponse(
content=schema_json_with_invalid_relationship_patterns
)

# run the schema extraction
schema = await schema_from_text.run(text="Sample text for extraction")

# verify that invalid relationship patterns were filtered out (2 out of 3 patterns should be removed)
assert schema.patterns is not None
assert len(schema.patterns) == 1
assert schema.patterns[0] == ("Person", "WORKS_FOR", "Organization")