
Add optional schema enforcement for KG builder as a validation layer after entity and relation extraction #296


Merged: 10 commits, Mar 12, 2025
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -3,6 +3,8 @@
## Next

### Added

- Added optional schema enforcement as a validation layer after entity and relation extraction.
- Introduced SearchQueryParseError for handling invalid Lucene query strings in HybridRetriever and HybridCypherRetriever.

## 1.5.0
29 changes: 27 additions & 2 deletions docs/source/user_guide_kg_builder.rst
@@ -125,8 +125,8 @@ This schema information can be provided to the `SimpleKGBuilder` as demonstrated
# ...
)
Prompt Template, Lexical Graph Config and Error Behavior
--------------------------------------------------------
Extra configurations
--------------------

These parameters are part of the `EntityAndRelationExtractor` component.
For detailed information, refer to the section on :ref:`Entity and Relation Extractor`.
@@ -138,6 +138,7 @@ They are also accessible via the `SimpleKGPipeline` interface.
# ...
prompt_template="",
lexical_graph_config=my_config,
enforce_schema="STRICT",
on_error="RAISE",
# ...
)
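Because `enforce_schema` is backed by the `SchemaEnforcementMode` enum, which subclasses `str`, the plain string `"STRICT"` shown above should be interchangeable with the enum member. A minimal sketch (same `SimpleKGPipeline` constructor as above, other arguments elided):

.. code:: python

    from neo4j_graphrag.experimental.components.types import SchemaEnforcementMode

    kg_builder = SimpleKGPipeline(
        # ...
        enforce_schema=SchemaEnforcementMode.STRICT,  # equivalent to "STRICT"
        # ...
    )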
@@ -829,6 +830,30 @@ It can be used in this way:

The LLM to use can be customized; the only constraint is that it obeys the :ref:`LLMInterface <llminterface>`.

Schema Enforcement Behaviour
----------------------------
By default, even if a schema is provided to guide the LLM during entity and relation extraction, the LLM response is not validated against that schema.
This behaviour can be changed by using the `enforce_schema` flag in the `LLMEntityRelationExtractor` constructor:

.. code:: python
from neo4j_graphrag.experimental.components.entity_relation_extractor import LLMEntityRelationExtractor
from neo4j_graphrag.experimental.components.types import SchemaEnforcementMode
extractor = LLMEntityRelationExtractor(
# ...
enforce_schema=SchemaEnforcementMode.STRICT,
)
In this scenario, any extracted node, relationship or property that is not part of the provided schema will be pruned.
Any relationship whose start or end node does not conform to the tuples provided in `potential_schema` will be pruned.
If a relationship's start and end nodes are valid but its direction is incorrect, the direction will be inverted.
If a node is left with no properties, it will also be pruned.
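
For illustration, here is a minimal sketch of the pruning behaviour on a toy graph. The schema, graph values and `my_llm` are hypothetical, and it assumes `SchemaConfig` can be built directly from plain dictionaries (the shape the validation code reads):

.. code:: python

    from neo4j_graphrag.experimental.components.entity_relation_extractor import LLMEntityRelationExtractor
    from neo4j_graphrag.experimental.components.schema import SchemaConfig
    from neo4j_graphrag.experimental.components.types import (
        Neo4jGraph,
        Neo4jNode,
        Neo4jRelationship,
        SchemaEnforcementMode,
    )

    extractor = LLMEntityRelationExtractor(
        llm=my_llm,  # any LLMInterface implementation
        enforce_schema=SchemaEnforcementMode.STRICT,
    )

    schema = SchemaConfig(
        entities={"Person": {"label": "Person", "properties": [{"name": "name", "type": "STRING"}]}},
        relations={"KNOWS": {"label": "KNOWS", "properties": []}},
        potential_schema=[("Person", "KNOWS", "Person")],
    )

    chunk_graph = Neo4jGraph(
        nodes=[
            Neo4jNode(id="0", label="Person", properties={"name": "Alice", "age": 42}),
            Neo4jNode(id="1", label="Alien", properties={"name": "Zork"}),  # label not in schema
        ],
        relationships=[
            Neo4jRelationship(start_node_id="0", end_node_id="1", type="KNOWS", properties={}),
        ],
    )

    pruned = extractor.validate_chunk(chunk_graph, schema)
    # The "Alien" node is dropped, the unknown "age" property is stripped from "Alice",
    # and the KNOWS relationship is removed because its end node was pruned.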

.. warning::

Note that if the schema enforcement mode is on but the schema is not provided, no schema enforcement will be applied.

Error Behaviour
---------------

@@ -19,7 +19,7 @@
import enum
import json
import logging
from typing import Any, List, Optional, Union, cast
from typing import Any, List, Optional, Union, cast, Dict

import json_repair
from pydantic import ValidationError, validate_call
@@ -31,8 +31,11 @@
DocumentInfo,
LexicalGraphConfig,
Neo4jGraph,
Neo4jNode,
Neo4jRelationship,
TextChunk,
TextChunks,
SchemaEnforcementMode,
)
from neo4j_graphrag.experimental.pipeline.component import Component
from neo4j_graphrag.experimental.pipeline.exceptions import InvalidJSONError
@@ -168,6 +171,7 @@ class LLMEntityRelationExtractor(EntityRelationExtractor):
llm (LLMInterface): The language model to use for extraction.
prompt_template (ERExtractionTemplate | str): A custom prompt template to use for extraction.
create_lexical_graph (bool): Whether to include the text chunks in the graph in addition to the extracted entities and relations. Defaults to True.
enforce_schema (SchemaEnforcementMode): Whether to validate the extracted entities and relations against the provided schema. Defaults to SchemaEnforcementMode.NONE.
on_error (OnError): What to do when an error occurs during extraction. Defaults to raising an error.
max_concurrency (int): The maximum number of concurrent tasks which can be used to make requests to the LLM.

@@ -192,11 +196,13 @@ def __init__(
llm: LLMInterface,
prompt_template: ERExtractionTemplate | str = ERExtractionTemplate(),
create_lexical_graph: bool = True,
enforce_schema: SchemaEnforcementMode = SchemaEnforcementMode.NONE,
on_error: OnError = OnError.RAISE,
max_concurrency: int = 5,
) -> None:
super().__init__(on_error=on_error, create_lexical_graph=create_lexical_graph)
self.llm = llm # with response_format={ "type": "json_object" },
self.enforce_schema = enforce_schema
self.max_concurrency = max_concurrency
if isinstance(prompt_template, str):
template = PromptTemplate(prompt_template, expected_inputs=[])
@@ -275,15 +281,16 @@ async def run_for_chunk(
examples: str,
lexical_graph_builder: Optional[LexicalGraphBuilder] = None,
) -> Neo4jGraph:
"""Run extraction and post processing for a single chunk"""
"""Run extraction, validation and post processing for a single chunk"""
async with sem:
chunk_graph = await self.extract_for_chunk(schema, examples, chunk)
final_chunk_graph = self.validate_chunk(chunk_graph, schema)
await self.post_process_chunk(
chunk_graph,
final_chunk_graph,
chunk,
lexical_graph_builder,
)
return chunk_graph
return final_chunk_graph

@validate_call
async def run(
@@ -306,7 +313,7 @@ async def run(
chunks (TextChunks): List of text chunks to extract entities and relations from.
document_info (Optional[DocumentInfo], optional): Document the chunks are coming from. Used in the lexical graph creation step.
lexical_graph_config (Optional[LexicalGraphConfig], optional): Lexical graph configuration to customize node labels and relationship types in the lexical graph.
schema (SchemaConfig | None): Definition of the schema to guide the LLM in its extraction. Caution: at the moment, there is no guarantee that the extracted entities and relations will strictly obey the schema.
schema (SchemaConfig | None): Definition of the schema to guide the LLM in its extraction.
examples (str): Examples for few-shot learning in the prompt.
"""
lexical_graph_builder = None
@@ -337,3 +344,157 @@ async def run(
graph = self.combine_chunk_graphs(lexical_graph, chunk_graphs)
logger.debug(f"Extracted graph: {prettify(graph)}")
return graph

def validate_chunk(
self,
chunk_graph: Neo4jGraph,
schema: SchemaConfig
) -> Neo4jGraph:
"""
Perform validation after entity and relation extraction:
- Enforce schema if schema enforcement mode is on and schema is provided
"""
if self.enforce_schema != SchemaEnforcementMode.NONE:
if not schema or not schema.entities: # schema is not provided
logger.warning(
"Schema enforcement is ON but the guiding schema is not provided."
)
else:
# if schema enforcement is on and a schema is provided, clean the graph
return self._clean_graph(chunk_graph, schema)
return chunk_graph

def _clean_graph(
self,
graph: Neo4jGraph,
schema: SchemaConfig,
) -> Neo4jGraph:
"""
Verify that the graph conforms to the provided schema.

Remove invalid entities, relationships, and properties.
If an entity is removed, all of its relationships are also removed.
If no valid properties remain for an entity, remove that entity.
"""
# enforce nodes (remove invalid labels, strip invalid properties)
filtered_nodes = self._enforce_nodes(graph.nodes, schema)

# enforce relationships (remove those referencing invalid nodes or with invalid
# types or with start/end nodes not conforming to the schema, and strip invalid
# properties)
filtered_rels = self._enforce_relationships(
graph.relationships, filtered_nodes, schema
)

return Neo4jGraph(nodes=filtered_nodes, relationships=filtered_rels)

def _enforce_nodes(
self,
extracted_nodes: List[Neo4jNode],
schema: SchemaConfig
) -> List[Neo4jNode]:
"""
Filter extracted nodes so they conform to the schema.

Keep only those whose label is defined in the schema.
For each valid node, filter out properties not present in the schema.
Remove a node if it ends up with no valid properties.
"""
if self.enforce_schema != SchemaEnforcementMode.STRICT:
return extracted_nodes

valid_nodes = []

for node in extracted_nodes:
schema_entity = schema.entities.get(node.label)
if not schema_entity:
continue
allowed_props = schema_entity.get("properties", [])
filtered_props = self._enforce_properties(node.properties, allowed_props)
if filtered_props:
valid_nodes.append(
Neo4jNode(
id=node.id,
label=node.label,
properties=filtered_props,
embedding_properties=node.embedding_properties,
)
)

return valid_nodes

def _enforce_relationships(
self,
extracted_relationships: List[Neo4jRelationship],
filtered_nodes: List[Neo4jNode],
schema: SchemaConfig
) -> List[Neo4jRelationship]:
"""
Filter extracted relationships so they conform to the schema.

Keep only those whose type is in the schema, whose start/end node labels conform to
`potential_schema`, and whose start/end nodes were kept after node enforcement.
For each valid relationship, filter out properties not present in the schema.
If a relationship's direction is incorrect, invert it.
"""
if self.enforce_schema != SchemaEnforcementMode.STRICT:
return extracted_relationships

valid_rels = []

valid_nodes = {node.id: node.label for node in filtered_nodes}

potential_schema = schema.potential_schema

for rel in extracted_relationships:
schema_relation = schema.relations.get(rel.type)
if not schema_relation:
continue

if (rel.start_node_id not in valid_nodes or
rel.end_node_id not in valid_nodes):
continue

start_label = valid_nodes[rel.start_node_id]
end_label = valid_nodes[rel.end_node_id]

tuple_valid = True
if potential_schema:
tuple_valid = (start_label, rel.type, end_label) in potential_schema
reverse_tuple_valid = ((end_label, rel.type, start_label) in
potential_schema)

if not tuple_valid and not reverse_tuple_valid:
continue

allowed_props = schema_relation.get("properties", [])
filtered_props = self._enforce_properties(rel.properties, allowed_props)

valid_rels.append(
Neo4jRelationship(
start_node_id=rel.start_node_id if tuple_valid else rel.end_node_id,
end_node_id=rel.end_node_id if tuple_valid else rel.start_node_id,
type=rel.type,
properties=filtered_props,
embedding_properties=rel.embedding_properties,
)
)

return valid_rels

def _enforce_properties(
self,
properties: Dict[str, Any],
valid_properties: List[Dict[str, Any]]
) -> Dict[str, Any]:
"""
Filter properties.
Keep only those that exist in schema (i.e., valid properties).
"""
valid_prop_names = {prop["name"] for prop in valid_properties}
return {
key: value
for key, value in properties.items()
if key in valid_prop_names
}
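
A quick illustration of this helper (hypothetical values, not part of the change):
# With valid_properties = [{"name": "name", "type": "STRING"}] and
# properties = {"name": "Alice", "age": 33}, _enforce_properties returns
# {"name": "Alice"}: the unknown "age" key is dropped.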

6 changes: 6 additions & 0 deletions src/neo4j_graphrag/experimental/components/types.py
Original file line number Diff line number Diff line change
@@ -15,6 +15,7 @@
from __future__ import annotations

import uuid
from enum import Enum
from typing import Any, Dict, Optional

from pydantic import BaseModel, Field, field_validator
@@ -170,3 +171,8 @@ def lexical_graph_node_labels(self) -> tuple[str, ...]:
class GraphResult(DataModel):
graph: Neo4jGraph
config: LexicalGraphConfig


class SchemaEnforcementMode(str, Enum):
NONE = "NONE"
STRICT = "STRICT"
@@ -37,7 +37,10 @@
from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import (
FixedSizeSplitter,
)
from neo4j_graphrag.experimental.components.types import LexicalGraphConfig
from neo4j_graphrag.experimental.components.types import (
LexicalGraphConfig,
SchemaEnforcementMode
)
from neo4j_graphrag.experimental.pipeline.config.object_config import ComponentType
from neo4j_graphrag.experimental.pipeline.config.template_pipeline.base import (
TemplatePipelineConfig,
@@ -71,6 +74,7 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig):
entities: Sequence[EntityInputType] = []
relations: Sequence[RelationInputType] = []
potential_schema: Optional[list[tuple[str, str, str]]] = None
enforce_schema: SchemaEnforcementMode = SchemaEnforcementMode.NONE
on_error: OnError = OnError.IGNORE
prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate()
perform_entity_resolution: bool = True
@@ -124,6 +128,7 @@ def _get_extractor(self) -> EntityRelationExtractor:
return LLMEntityRelationExtractor(
llm=self.get_default_llm(),
prompt_template=self.prompt_template,
enforce_schema=self.enforce_schema,
on_error=self.on_error,
)
