Skip to content

Commit a9f47d7

Browse files
Add schema enforcement modes and strict mode behaviour
1 parent 75ef7e4 commit a9f47d7

File tree

2 files changed

+164
-5
lines changed

2 files changed

+164
-5
lines changed

src/neo4j_graphrag/experimental/components/entity_relation_extractor.py

Lines changed: 156 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import enum
2020
import json
2121
import logging
22-
from typing import Any, List, Optional, Union, cast
22+
from typing import Any, List, Optional, Union, cast, Dict
2323

2424
import json_repair
2525
from pydantic import ValidationError, validate_call
@@ -31,8 +31,11 @@
3131
DocumentInfo,
3232
LexicalGraphConfig,
3333
Neo4jGraph,
34+
Neo4jNode,
35+
Neo4jRelationship,
3436
TextChunk,
3537
TextChunks,
38+
SchemaEnforcementMode,
3639
)
3740
from neo4j_graphrag.experimental.pipeline.component import Component
3841
from neo4j_graphrag.experimental.pipeline.exceptions import InvalidJSONError
@@ -168,6 +171,7 @@ class LLMEntityRelationExtractor(EntityRelationExtractor):
168171
llm (LLMInterface): The language model to use for extraction.
169172
prompt_template (ERExtractionTemplate | str): A custom prompt template to use for extraction.
170173
create_lexical_graph (bool): Whether to include the text chunks in the graph in addition to the extracted entities and relations. Defaults to True.
174+
enforce_schema (SchemaEnforcementMode): Controls whether the extracted entities and relationships are validated against the provided schema. Defaults to SchemaEnforcementMode.NONE (no validation).
171175
on_error (OnError): What to do when an error occurs during extraction. Defaults to raising an error.
172176
max_concurrency (int): The maximum number of concurrent tasks which can be used to make requests to the LLM.
173177
@@ -192,11 +196,13 @@ def __init__(
192196
llm: LLMInterface,
193197
prompt_template: ERExtractionTemplate | str = ERExtractionTemplate(),
194198
create_lexical_graph: bool = True,
199+
enforce_schema: SchemaEnforcementMode = SchemaEnforcementMode.NONE,
195200
on_error: OnError = OnError.RAISE,
196201
max_concurrency: int = 5,
197202
) -> None:
198203
super().__init__(on_error=on_error, create_lexical_graph=create_lexical_graph)
199204
self.llm = llm # with response_format={ "type": "json_object" },
205+
self.enforce_schema = enforce_schema
200206
self.max_concurrency = max_concurrency
201207
if isinstance(prompt_template, str):
202208
template = PromptTemplate(prompt_template, expected_inputs=[])
@@ -275,15 +281,16 @@ async def run_for_chunk(
275281
examples: str,
276282
lexical_graph_builder: Optional[LexicalGraphBuilder] = None,
277283
) -> Neo4jGraph:
278-
"""Run extraction and post processing for a single chunk"""
284+
"""Run extraction, validation and post processing for a single chunk"""
279285
async with sem:
280286
chunk_graph = await self.extract_for_chunk(schema, examples, chunk)
287+
final_chunk_graph = self.validate_chunk(chunk_graph, schema)
281288
await self.post_process_chunk(
282-
chunk_graph,
289+
final_chunk_graph,
283290
chunk,
284291
lexical_graph_builder,
285292
)
286-
return chunk_graph
293+
return final_chunk_graph
287294

288295
@validate_call
289296
async def run(
@@ -306,7 +313,7 @@ async def run(
306313
chunks (TextChunks): List of text chunks to extract entities and relations from.
307314
document_info (Optional[DocumentInfo], optional): Document the chunks are coming from. Used in the lexical graph creation step.
308315
lexical_graph_config (Optional[LexicalGraphConfig], optional): Lexical graph configuration to customize node labels and relationship types in the lexical graph.
309-
schema (SchemaConfig | None): Definition of the schema to guide the LLM in its extraction. Caution: at the moment, there is no guarantee that the extracted entities and relations will strictly obey the schema.
316+
schema (SchemaConfig | None): Definition of the schema to guide the LLM in its extraction.
310317
examples (str): Examples for few-shot learning in the prompt.
311318
"""
312319
lexical_graph_builder = None
@@ -337,3 +344,147 @@ async def run(
337344
graph = self.combine_chunk_graphs(lexical_graph, chunk_graphs)
338345
logger.debug(f"Extracted graph: {prettify(graph)}")
339346
return graph
347+
348+
def validate_chunk(
    self,
    chunk_graph: Neo4jGraph,
    schema: SchemaConfig,
) -> Neo4jGraph:
    """Validate the graph extracted from a single chunk.

    The graph is handed back untouched when schema enforcement is switched
    off (``SchemaEnforcementMode.NONE``) or when ``schema.entities`` is
    falsy (e.g. no entities were declared). In every other case the graph
    is passed through ``_clean_graph``.

    Args:
        chunk_graph (Neo4jGraph): Graph extracted from one chunk.
        schema (SchemaConfig): Schema the graph is checked against.

    Returns:
        Neo4jGraph: ``chunk_graph`` itself, or the result of ``_clean_graph``.
    """
    enforcement_disabled = self.enforce_schema == SchemaEnforcementMode.NONE
    # Nothing to enforce without a mode or without declared entities.
    if enforcement_disabled or not schema.entities:
        return chunk_graph
    return self._clean_graph(chunk_graph, schema)
363+
364+
def _clean_graph(
    self,
    graph: Neo4jGraph,
    schema: SchemaConfig,
) -> Neo4jGraph:
    """Build a new graph holding only schema-conformant content.

    Nodes are filtered first via ``_enforce_nodes``. Relationships are then
    filtered via ``_enforce_relationships`` against the surviving nodes, so
    a relationship attached to a removed node does not survive either.

    Args:
        graph (Neo4jGraph): Graph to check.
        schema (SchemaConfig): Schema the graph is checked against.

    Returns:
        Neo4jGraph: A new graph made of the kept nodes and relationships.
    """
    # Node filtering must run first: relationship filtering depends on it.
    kept_nodes = self._enforce_nodes(graph.nodes, schema)
    kept_relationships = self._enforce_relationships(
        graph.relationships,
        kept_nodes,
        schema,
    )
    return Neo4jGraph(nodes=kept_nodes, relationships=kept_relationships)
387+
388+
def _enforce_nodes(
    self,
    extracted_nodes: List[Neo4jNode],
    schema: SchemaConfig,
) -> List[Neo4jNode]:
    """Return the extracted nodes that conform to the schema.

    In ``STRICT`` mode a node is kept only if its label is a key of
    ``schema.entities`` and at least one of its properties survives
    ``_enforce_properties``. Kept nodes are rebuilt with the filtered
    properties. In any other mode an empty list is returned.

    Args:
        extracted_nodes (List[Neo4jNode]): Nodes produced by the extraction.
        schema (SchemaConfig): Schema the nodes are checked against.

    Returns:
        List[Neo4jNode]: New node objects for the nodes that were kept.
    """
    kept: List[Neo4jNode] = []
    if self.enforce_schema != SchemaEnforcementMode.STRICT:
        # Only STRICT is implemented so far (an OPEN mode is a future idea).
        return kept

    for candidate in extracted_nodes:
        if candidate.label not in schema.entities:
            continue
        # NOTE(review): the per-label schema entry is used directly as the
        # collection of allowed property names - confirm this matches the
        # actual shape of SchemaConfig.entities values.
        allowed = schema.entities[candidate.label]
        surviving_props = self._enforce_properties(candidate.properties, allowed)
        if not surviving_props:
            # A node left without any valid property is dropped entirely.
            continue
        kept.append(
            Neo4jNode(
                id=candidate.id,
                label=candidate.label,
                properties=surviving_props,
                embedding_properties=candidate.embedding_properties,
            )
        )
    return kept
419+
420+
def _enforce_relationships(
    self,
    extracted_relationships: List[Neo4jRelationship],
    filtered_nodes: List[Neo4jNode],
    schema: SchemaConfig,
) -> List[Neo4jRelationship]:
    """Return the extracted relationships that conform to the schema.

    In ``STRICT`` mode a relationship is kept only if all of these hold:

    - its type is a key of ``schema.relations``;
    - both its start and end node ids belong to ``filtered_nodes`` (i.e. the
      nodes kept after node enforcement);
    - ``schema.potential_schema`` is falsy, or it contains the triple
      ``(start node label, relationship type, end node label)``.

    Kept relationships are rebuilt with their properties filtered through
    ``_enforce_properties``. Unlike nodes, a relationship is kept even when
    no property survives. In any other mode an empty list is returned.

    Args:
        extracted_relationships (List[Neo4jRelationship]): Relationships
            produced by the extraction.
        filtered_nodes (List[Neo4jNode]): Nodes kept by node enforcement.
        schema (SchemaConfig): Schema the relationships are checked against.

    Returns:
        List[Neo4jRelationship]: New relationship objects for those kept.
    """
    valid_rels: List[Neo4jRelationship] = []
    if self.enforce_schema != SchemaEnforcementMode.STRICT:
        # Only STRICT is implemented so far (an OPEN mode is a future idea).
        return valid_rels

    # Build the id -> label lookup once instead of scanning the node list
    # for every relationship. setdefault keeps the first node seen for a
    # duplicated id, which is what a first-match linear scan would return.
    label_by_id: Dict[str, str] = {}
    for node in filtered_nodes:
        label_by_id.setdefault(node.id, node.label)

    for rel in extracted_relationships:
        if rel.type not in schema.relations:
            continue
        if (
            rel.start_node_id not in label_by_id
            or rel.end_node_id not in label_by_id
        ):
            # At least one endpoint was removed by node enforcement.
            continue
        pattern = (
            label_by_id[rel.start_node_id],
            rel.type,
            label_by_id[rel.end_node_id],
        )
        # An empty/absent potential schema places no constraint on patterns.
        if schema.potential_schema and pattern not in schema.potential_schema:
            continue
        valid_rels.append(
            Neo4jRelationship(
                start_node_id=rel.start_node_id,
                end_node_id=rel.end_node_id,
                type=rel.type,
                properties=self._enforce_properties(
                    rel.properties, schema.relations[rel.type]
                ),
                embedding_properties=rel.embedding_properties,
            )
        )
    return valid_rels
462+
463+
def _enforce_properties(
464+
self,
465+
properties: Dict[str, Any],
466+
valid_properties: Dict[str, Any]
467+
) -> Dict[str, Any]:
468+
"""
469+
Filter properties.
470+
Keep only those that exist in schema (i.e., valid properties).
471+
"""
472+
return {
473+
key: value
474+
for key, value in properties.items()
475+
if key in valid_properties
476+
}
477+
478+
def _get_node_label(
479+
self,
480+
node_id: str,
481+
nodes: List[Neo4jNode]
482+
) -> str:
483+
"""
484+
Given a list of nodes, get the label of the node whose id matches the provided
485+
node id.
486+
"""
487+
for node in nodes:
488+
if node.id == node_id:
489+
return node.label
490+
return ""

src/neo4j_graphrag/experimental/components/types.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
import uuid
18+
from enum import Enum
1819
from typing import Any, Dict, Optional
1920

2021
from pydantic import BaseModel, Field, field_validator
@@ -170,3 +171,10 @@ def lexical_graph_node_labels(self) -> tuple[str, ...]:
170171
class GraphResult(DataModel):
171172
graph: Neo4jGraph
172173
config: LexicalGraphConfig
174+
175+
176+
class SchemaEnforcementMode(str, Enum):
    """Modes controlling how extracted graph content is checked against a schema.

    Members are also ``str`` instances, so they compare equal to their
    plain string values.
    """

    # No schema enforcement.
    NONE = "none"
    # Strict schema enforcement.
    STRICT = "strict"
    # Future possibility: OPEN = "open" -> ensure conformance of the
    # nodes/properties/relationships that were listed in the schema, but
    # leave room for extras.

0 commit comments

Comments
 (0)