Skip to content

Commit 5b868aa

Browse files
Add optional schema enforcement for KG builder as a validation layer after entity and relation extraction (neo4j#296)
* Add schema enforcement modes and strict mode behaviour * Add unit tests for schema enforcement modes * Update change log and docs * Fix documentation * Add warning when schema enforcement is on but schema not provided * Code cleanups * Improve code for more clarity * Apply changes requested by the PR review * Invert rel direction * Adapt SimpleKGPipelineConfig
1 parent c96af5c commit 5b868aa

File tree

6 files changed

+491
-8
lines changed

6 files changed

+491
-8
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
## Next
44

55
### Added
6+
7+
- Added optional schema enforcement as a validation layer after entity and relation extraction.
68
- Introduced SearchQueryParseError for handling invalid Lucene query strings in HybridRetriever and HybridCypherRetriever.
79

810
## 1.5.0

docs/source/user_guide_kg_builder.rst

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,8 @@ This schema information can be provided to the `SimpleKGBuilder` as demonstrated
125125
# ...
126126
)
127127
128-
Prompt Template, Lexical Graph Config and Error Behavior
129-
--------------------------------------------------------
128+
Extra configurations
129+
--------------------
130130

131131
These parameters are part of the `EntityAndRelationExtractor` component.
132132
For detailed information, refer to the section on :ref:`Entity and Relation Extractor`.
@@ -138,6 +138,7 @@ They are also accessible via the `SimpleKGPipeline` interface.
138138
# ...
139139
prompt_template="",
140140
lexical_graph_config=my_config,
141+
enforce_schema="STRICT",
141142
on_error="RAISE",
142143
# ...
143144
)
@@ -829,6 +830,30 @@ It can be used in this way:
829830

830831
The LLM to use can be customized, the only constraint is that it obeys the :ref:`LLMInterface <llminterface>`.
831832

833+
Schema Enforcement Behaviour
834+
----------------------------
835+
By default, even if a schema is provided to guide the LLM in the entity and relation extraction, the LLM response is not validated against that schema.
836+
This behaviour can be changed by using the `enforce_schema` parameter in the `LLMEntityRelationExtractor` constructor:
837+
838+
.. code:: python
839+
840+
from neo4j_graphrag.experimental.components.entity_relation_extractor import LLMEntityRelationExtractor
841+
from neo4j_graphrag.experimental.components.types import SchemaEnforcementMode
842+
843+
extractor = LLMEntityRelationExtractor(
844+
# ...
845+
enforce_schema=SchemaEnforcementMode.STRICT,
846+
)
847+
848+
In this scenario, any extracted node/relation/property that is not part of the provided schema will be pruned.
849+
Any relation whose start node or end node does not conform to the provided tuple in `potential_schema` will be pruned.
850+
If a relation's start and end nodes are valid but its direction is incorrect, the direction will be inverted.
851+
If a node is left with no properties, it will also be pruned.
852+
853+
.. warning::
854+
855+
Note that if the schema enforcement mode is on but the schema is not provided, no schema enforcement will be applied.
856+
832857
Error Behaviour
833858
---------------
834859

src/neo4j_graphrag/experimental/components/entity_relation_extractor.py

Lines changed: 166 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import enum
2020
import json
2121
import logging
22-
from typing import Any, List, Optional, Union, cast
22+
from typing import Any, List, Optional, Union, cast, Dict
2323

2424
import json_repair
2525
from pydantic import ValidationError, validate_call
@@ -31,8 +31,11 @@
3131
DocumentInfo,
3232
LexicalGraphConfig,
3333
Neo4jGraph,
34+
Neo4jNode,
35+
Neo4jRelationship,
3436
TextChunk,
3537
TextChunks,
38+
SchemaEnforcementMode,
3639
)
3740
from neo4j_graphrag.experimental.pipeline.component import Component
3841
from neo4j_graphrag.experimental.pipeline.exceptions import InvalidJSONError
@@ -168,6 +171,7 @@ class LLMEntityRelationExtractor(EntityRelationExtractor):
168171
llm (LLMInterface): The language model to use for extraction.
169172
prompt_template (ERExtractionTemplate | str): A custom prompt template to use for extraction.
170173
create_lexical_graph (bool): Whether to include the text chunks in the graph in addition to the extracted entities and relations. Defaults to True.
174+
enforce_schema (SchemaEnforcementMode): Whether to validate the extracted entities/relations against the provided schema. Defaults to SchemaEnforcementMode.NONE (no validation).
171175
on_error (OnError): What to do when an error occurs during extraction. Defaults to raising an error.
172176
max_concurrency (int): The maximum number of concurrent tasks which can be used to make requests to the LLM.
173177
@@ -192,11 +196,13 @@ def __init__(
192196
llm: LLMInterface,
193197
prompt_template: ERExtractionTemplate | str = ERExtractionTemplate(),
194198
create_lexical_graph: bool = True,
199+
enforce_schema: SchemaEnforcementMode = SchemaEnforcementMode.NONE,
195200
on_error: OnError = OnError.RAISE,
196201
max_concurrency: int = 5,
197202
) -> None:
198203
super().__init__(on_error=on_error, create_lexical_graph=create_lexical_graph)
199204
self.llm = llm # with response_format={ "type": "json_object" },
205+
self.enforce_schema = enforce_schema
200206
self.max_concurrency = max_concurrency
201207
if isinstance(prompt_template, str):
202208
template = PromptTemplate(prompt_template, expected_inputs=[])
@@ -275,15 +281,16 @@ async def run_for_chunk(
275281
examples: str,
276282
lexical_graph_builder: Optional[LexicalGraphBuilder] = None,
277283
) -> Neo4jGraph:
278-
"""Run extraction and post processing for a single chunk"""
284+
"""Run extraction, validation and post processing for a single chunk"""
279285
async with sem:
280286
chunk_graph = await self.extract_for_chunk(schema, examples, chunk)
287+
final_chunk_graph = self.validate_chunk(chunk_graph, schema)
281288
await self.post_process_chunk(
282-
chunk_graph,
289+
final_chunk_graph,
283290
chunk,
284291
lexical_graph_builder,
285292
)
286-
return chunk_graph
293+
return final_chunk_graph
287294

288295
@validate_call
289296
async def run(
@@ -306,7 +313,7 @@ async def run(
306313
chunks (TextChunks): List of text chunks to extract entities and relations from.
307314
document_info (Optional[DocumentInfo], optional): Document the chunks are coming from. Used in the lexical graph creation step.
308315
lexical_graph_config (Optional[LexicalGraphConfig], optional): Lexical graph configuration to customize node labels and relationship types in the lexical graph.
309-
schema (SchemaConfig | None): Definition of the schema to guide the LLM in its extraction. Caution: at the moment, there is no guarantee that the extracted entities and relations will strictly obey the schema.
316+
schema (SchemaConfig | None): Definition of the schema to guide the LLM in its extraction.
310317
examples (str): Examples for few-shot learning in the prompt.
311318
"""
312319
lexical_graph_builder = None
@@ -337,3 +344,157 @@ async def run(
337344
graph = self.combine_chunk_graphs(lexical_graph, chunk_graphs)
338345
logger.debug(f"Extracted graph: {prettify(graph)}")
339346
return graph
347+
348+
def validate_chunk(
    self,
    chunk_graph: Neo4jGraph,
    schema: SchemaConfig,
) -> Neo4jGraph:
    """Validate a chunk graph after entity and relation extraction.

    The only validation step is schema enforcement. It is applied when
    ``self.enforce_schema`` is not ``SchemaEnforcementMode.NONE`` and
    ``schema`` is truthy with non-empty ``entities``.

    Args:
        chunk_graph: Graph extracted from a single chunk.
        schema: Schema used to guide the extraction.

    Returns:
        The result of ``self._clean_graph`` when enforcement applies,
        otherwise ``chunk_graph`` itself.
    """
    # Enforcement switched off: nothing to validate.
    if self.enforce_schema == SchemaEnforcementMode.NONE:
        return chunk_graph

    # Enforcement requested and a usable schema is available.
    if schema and schema.entities:
        return self._clean_graph(chunk_graph, schema)

    # Enforcement requested without a schema: warn and leave the graph as is.
    logger.warning(
        "Schema enforcement is ON but the guiding schema is not provided."
    )
    return chunk_graph
366+
367+
def _clean_graph(
    self,
    graph: Neo4jGraph,
    schema: SchemaConfig,
) -> Neo4jGraph:
    """Build a new graph holding only what node/relationship enforcement keeps.

    Nodes are enforced first. The surviving nodes are then passed to the
    relationship enforcement step, so that relationships can be checked
    against the nodes that were actually kept.

    Args:
        graph: Graph to check against the schema.
        schema: Schema the graph is checked against.

    Returns:
        A new ``Neo4jGraph`` made of the kept nodes and relationships.
    """
    kept_nodes = self._enforce_nodes(graph.nodes, schema)
    kept_relationships = self._enforce_relationships(
        graph.relationships,
        kept_nodes,
        schema,
    )
    return Neo4jGraph(nodes=kept_nodes, relationships=kept_relationships)
390+
391+
def _enforce_nodes(
    self,
    extracted_nodes: List[Neo4jNode],
    schema: SchemaConfig,
) -> List[Neo4jNode]:
    """Restrict extracted nodes to those allowed by the schema.

    Unless ``self.enforce_schema`` is ``SchemaEnforcementMode.STRICT`` the
    input list is returned unchanged. In strict mode a node is kept only
    when its label has a truthy entry in ``schema.entities`` and at least
    one of its properties is declared for that entity; kept nodes are
    rebuilt with the undeclared properties removed.

    Args:
        extracted_nodes: Nodes produced by the extraction step.
        schema: Schema whose ``entities`` mapping defines allowed labels
            and properties.

    Returns:
        The nodes that conform to the schema.
    """
    if self.enforce_schema != SchemaEnforcementMode.STRICT:
        return extracted_nodes

    kept: List[Neo4jNode] = []
    for candidate in extracted_nodes:
        entity_definition = schema.entities.get(candidate.label)
        if entity_definition:
            remaining_props = self._enforce_properties(
                candidate.properties,
                entity_definition.get("properties", []),
            )
            # A node stripped of every property is discarded.
            if remaining_props:
                cleaned_node = Neo4jNode(
                    id=candidate.id,
                    label=candidate.label,
                    properties=remaining_props,
                    embedding_properties=candidate.embedding_properties,
                )
                kept.append(cleaned_node)
    return kept
425+
426+
def _enforce_relationships(
    self,
    extracted_relationships: List[Neo4jRelationship],
    filtered_nodes: List[Neo4jNode],
    schema: SchemaConfig,
) -> List[Neo4jRelationship]:
    """Filter extracted relationships so that they conform to the schema.

    Unless ``self.enforce_schema`` is ``SchemaEnforcementMode.STRICT`` the
    input list is returned unchanged. In strict mode a relationship is kept
    only when all of the following hold:

    - its type has a truthy entry in ``schema.relations``;
    - both its start and end node ids belong to ``filtered_nodes`` (i.e. the
      nodes kept after node enforcement);
    - if ``schema.potential_schema`` is non-empty, either
      ``(start_label, type, end_label)`` or the reversed triple is listed.

    Kept relationships are rebuilt with properties not declared in the
    schema removed. If only the reversed triple is listed, the relationship
    direction is inverted (start and end node ids are swapped).

    Args:
        extracted_relationships: Relationships produced by the extraction.
        filtered_nodes: Nodes that survived node enforcement.
        schema: Schema providing ``relations`` and ``potential_schema``.

    Returns:
        The relationships that conform to the schema.
    """
    if self.enforce_schema != SchemaEnforcementMode.STRICT:
        return extracted_relationships

    # NOTE(review): presumably ``relations`` may be None on the schema model;
    # falling back to an empty mapping drops every relationship instead of
    # raising AttributeError - confirm against SchemaConfig.
    schema_relations = schema.relations or {}
    potential_schema = schema.potential_schema
    label_by_node_id = {node.id: node.label for node in filtered_nodes}

    valid_rels: List[Neo4jRelationship] = []

    for rel in extracted_relationships:
        schema_relation = schema_relations.get(rel.type)
        if not schema_relation:
            continue

        start_id, end_id = rel.start_node_id, rel.end_node_id
        if start_id not in label_by_node_id or end_id not in label_by_node_id:
            continue

        if potential_schema:
            start_label = label_by_node_id[start_id]
            end_label = label_by_node_id[end_id]
            if (start_label, rel.type, end_label) not in potential_schema:
                if (end_label, rel.type, start_label) not in potential_schema:
                    # Neither direction is allowed between these labels.
                    continue
                # Only the reversed triple is allowed: invert the direction.
                start_id, end_id = end_id, start_id

        allowed_props = schema_relation.get("properties", [])
        filtered_props = self._enforce_properties(rel.properties, allowed_props)

        valid_rels.append(
            Neo4jRelationship(
                start_node_id=start_id,
                end_node_id=end_id,
                type=rel.type,
                properties=filtered_props,
                embedding_properties=rel.embedding_properties,
            )
        )

    return valid_rels
484+
485+
def _enforce_properties(
486+
self,
487+
properties: Dict[str, Any],
488+
valid_properties: List[Dict[str, Any]]
489+
) -> Dict[str, Any]:
490+
"""
491+
Filter properties.
492+
Keep only those that exist in schema (i.e., valid properties).
493+
"""
494+
valid_prop_names = {prop["name"] for prop in valid_properties}
495+
return {
496+
key: value
497+
for key, value in properties.items()
498+
if key in valid_prop_names
499+
}
500+

src/neo4j_graphrag/experimental/components/types.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
import uuid
18+
from enum import Enum
1819
from typing import Any, Dict, Optional
1920

2021
from pydantic import BaseModel, Field, field_validator
@@ -170,3 +171,8 @@ def lexical_graph_node_labels(self) -> tuple[str, ...]:
170171
class GraphResult(DataModel):
171172
graph: Neo4jGraph
172173
config: LexicalGraphConfig
174+
175+
176+
class SchemaEnforcementMode(str, Enum):
    """Schema enforcement modes, each a string-valued enum member."""

    # Named for "no enforcement"; its value is the string "NONE".
    NONE = "NONE"
    # Named for strict enforcement; its value is the string "STRICT".
    STRICT = "STRICT"

src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,10 @@
3737
from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import (
3838
FixedSizeSplitter,
3939
)
40-
from neo4j_graphrag.experimental.components.types import LexicalGraphConfig
40+
from neo4j_graphrag.experimental.components.types import (
41+
LexicalGraphConfig,
42+
SchemaEnforcementMode
43+
)
4144
from neo4j_graphrag.experimental.pipeline.config.object_config import ComponentType
4245
from neo4j_graphrag.experimental.pipeline.config.template_pipeline.base import (
4346
TemplatePipelineConfig,
@@ -71,6 +74,7 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig):
7174
entities: Sequence[EntityInputType] = []
7275
relations: Sequence[RelationInputType] = []
7376
potential_schema: Optional[list[tuple[str, str, str]]] = None
77+
enforce_schema: SchemaEnforcementMode = SchemaEnforcementMode.NONE
7478
on_error: OnError = OnError.IGNORE
7579
prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate()
7680
perform_entity_resolution: bool = True
@@ -124,6 +128,7 @@ def _get_extractor(self) -> EntityRelationExtractor:
124128
return LLMEntityRelationExtractor(
125129
llm=self.get_default_llm(),
126130
prompt_template=self.prompt_template,
131+
enforce_schema=self.enforce_schema,
127132
on_error=self.on_error,
128133
)
129134

0 commit comments

Comments
 (0)