19
19
import enum
20
20
import json
21
21
import logging
22
- from typing import Any , List , Optional , Union , cast
22
+ from typing import Any , List , Optional , Union , cast , Dict
23
23
24
24
import json_repair
25
25
from pydantic import ValidationError , validate_call
31
31
DocumentInfo ,
32
32
LexicalGraphConfig ,
33
33
Neo4jGraph ,
34
+ Neo4jNode ,
35
+ Neo4jRelationship ,
34
36
TextChunk ,
35
37
TextChunks ,
38
+ SchemaEnforcementMode ,
36
39
)
37
40
from neo4j_graphrag .experimental .pipeline .component import Component
38
41
from neo4j_graphrag .experimental .pipeline .exceptions import InvalidJSONError
@@ -168,6 +171,7 @@ class LLMEntityRelationExtractor(EntityRelationExtractor):
168
171
llm (LLMInterface): The language model to use for extraction.
169
172
prompt_template (ERExtractionTemplate | str): A custom prompt template to use for extraction.
170
173
create_lexical_graph (bool): Whether to include the text chunks in the graph in addition to the extracted entities and relations. Defaults to True.
174
+ enforce_schema (SchemaEnforcementMode): Whether to validate the extracted entities and relationships against the provided schema. Defaults to SchemaEnforcementMode.NONE (no enforcement).
171
175
on_error (OnError): What to do when an error occurs during extraction. Defaults to raising an error.
172
176
max_concurrency (int): The maximum number of concurrent tasks which can be used to make requests to the LLM.
173
177
@@ -192,11 +196,13 @@ def __init__(
192
196
llm : LLMInterface ,
193
197
prompt_template : ERExtractionTemplate | str = ERExtractionTemplate (),
194
198
create_lexical_graph : bool = True ,
199
+ enforce_schema : SchemaEnforcementMode = SchemaEnforcementMode .NONE ,
195
200
on_error : OnError = OnError .RAISE ,
196
201
max_concurrency : int = 5 ,
197
202
) -> None :
198
203
super ().__init__ (on_error = on_error , create_lexical_graph = create_lexical_graph )
199
204
self .llm = llm # with response_format={ "type": "json_object" },
205
+ self .enforce_schema = enforce_schema
200
206
self .max_concurrency = max_concurrency
201
207
if isinstance (prompt_template , str ):
202
208
template = PromptTemplate (prompt_template , expected_inputs = [])
@@ -275,15 +281,16 @@ async def run_for_chunk(
275
281
examples : str ,
276
282
lexical_graph_builder : Optional [LexicalGraphBuilder ] = None ,
277
283
) -> Neo4jGraph :
278
- """Run extraction and post processing for a single chunk"""
284
+ """Run extraction, validation and post processing for a single chunk"""
279
285
async with sem :
280
286
chunk_graph = await self .extract_for_chunk (schema , examples , chunk )
287
+ final_chunk_graph = self .validate_chunk (chunk_graph , schema )
281
288
await self .post_process_chunk (
282
- chunk_graph ,
289
+ final_chunk_graph ,
283
290
chunk ,
284
291
lexical_graph_builder ,
285
292
)
286
- return chunk_graph
293
+ return final_chunk_graph
287
294
288
295
@validate_call
289
296
async def run (
@@ -306,7 +313,7 @@ async def run(
306
313
chunks (TextChunks): List of text chunks to extract entities and relations from.
307
314
document_info (Optional[DocumentInfo], optional): Document the chunks are coming from. Used in the lexical graph creation step.
308
315
lexical_graph_config (Optional[LexicalGraphConfig], optional): Lexical graph configuration to customize node labels and relationship types in the lexical graph.
309
- schema (SchemaConfig | None): Definition of the schema to guide the LLM in its extraction. Caution: at the moment, there is no guarantee that the extracted entities and relations will strictly obey the schema.
316
+ schema (SchemaConfig | None): Definition of the schema to guide the LLM in its extraction.
310
317
examples (str): Examples for few-shot learning in the prompt.
311
318
"""
312
319
lexical_graph_builder = None
@@ -337,3 +344,147 @@ async def run(
337
344
graph = self .combine_chunk_graphs (lexical_graph , chunk_graphs )
338
345
logger .debug (f"Extracted graph: { prettify (graph )} " )
339
346
return graph
347
+
348
def validate_chunk(
    self,
    chunk_graph: Neo4jGraph,
    schema: SchemaConfig,
) -> Neo4jGraph:
    """Validate the graph extracted from a single chunk.

    The only validation step is schema enforcement, and it is applied
    only when an enforcement mode other than ``NONE`` is configured and
    the schema declares at least one entity.

    Args:
        chunk_graph (Neo4jGraph): Graph produced by the extraction step.
        schema (SchemaConfig): Schema the graph is checked against.

    Returns:
        Neo4jGraph: The cleaned graph when enforcement applies, otherwise
        ``chunk_graph`` itself, untouched.
    """
    enforcement_off = self.enforce_schema == SchemaEnforcementMode.NONE
    # Nothing to enforce against when the schema has no entities.
    if enforcement_off or not schema.entities:
        return chunk_graph
    return self._clean_graph(chunk_graph, schema)
363
+
364
def _clean_graph(
    self,
    graph: Neo4jGraph,
    schema: SchemaConfig,
) -> Neo4jGraph:
    """Build a copy of ``graph`` restricted to what the schema allows.

    Nodes are filtered first; the surviving nodes are then handed to the
    relationship filter so that relationships attached to a removed node
    can be discarded as well.

    Args:
        graph (Neo4jGraph): Graph to check against the schema.
        schema (SchemaConfig): Schema used for the enforcement.

    Returns:
        Neo4jGraph: A new graph holding only the nodes and relationships
        kept by ``_enforce_nodes`` and ``_enforce_relationships``.
    """
    # Node filtering must run first: its result drives which
    # relationships are allowed to stay.
    kept_nodes = self._enforce_nodes(graph.nodes, schema)
    kept_relationships = self._enforce_relationships(
        graph.relationships,
        kept_nodes,
        schema,
    )
    return Neo4jGraph(nodes=kept_nodes, relationships=kept_relationships)
387
+
388
def _enforce_nodes(
    self,
    extracted_nodes: List[Neo4jNode],
    schema: SchemaConfig,
) -> List[Neo4jNode]:
    """Return the extracted nodes that conform to the schema.

    In ``STRICT`` mode a node is kept only if its label is a key of
    ``schema.entities`` and at least one of its properties survives
    ``_enforce_properties`` for that schema entry. Kept nodes are rebuilt
    with the filtered properties; id, label and embedding properties are
    carried over unchanged.

    Args:
        extracted_nodes (List[Neo4jNode]): Nodes produced by extraction.
        schema (SchemaConfig): Schema the nodes are checked against.

    Returns:
        List[Neo4jNode]: The conforming nodes, in their original order.
    """
    if self.enforce_schema != SchemaEnforcementMode.STRICT:
        # NOTE: only STRICT is implemented; every other mode currently
        # yields an empty list (an OPEN mode is planned as future logic).
        return []

    kept = []
    for candidate in extracted_nodes:
        if candidate.label not in schema.entities:
            continue
        allowed = schema.entities[candidate.label]
        surviving_props = self._enforce_properties(candidate.properties, allowed)
        if not surviving_props:
            # A node left without any valid property is dropped.
            continue
        kept.append(
            Neo4jNode(
                id=candidate.id,
                label=candidate.label,
                properties=surviving_props,
                embedding_properties=candidate.embedding_properties,
            )
        )
    return kept
419
+
420
def _enforce_relationships(
    self,
    extracted_relationships: List[Neo4jRelationship],
    filtered_nodes: List[Neo4jNode],
    schema: SchemaConfig,
) -> List[Neo4jRelationship]:
    """Return the extracted relationships that conform to the schema.

    In ``STRICT`` mode a relationship is kept only if all of the
    following hold:

    - its type is a key of ``schema.relations``;
    - both its start and end node ids belong to ``filtered_nodes``
      (i.e. the nodes kept after node enforcement);
    - ``schema.potential_schema`` is empty, or it contains the triple
      ``(start node label, relationship type, end node label)``.

    Kept relationships are rebuilt with their properties filtered by
    ``_enforce_properties`` against the schema entry for their type.
    Unlike nodes, a relationship is kept even when no property survives.

    Args:
        extracted_relationships (List[Neo4jRelationship]): Relationships
            produced by extraction.
        filtered_nodes (List[Neo4jNode]): Nodes kept by node enforcement.
        schema (SchemaConfig): Schema the relationships are checked
            against.

    Returns:
        List[Neo4jRelationship]: The conforming relationships, in their
        original order.
    """
    if self.enforce_schema != SchemaEnforcementMode.STRICT:
        # NOTE: only STRICT is implemented; every other mode currently
        # yields an empty list (an OPEN mode is planned as future logic).
        return []

    # Build the id -> label index once instead of scanning the node list
    # twice per relationship. setdefault keeps the first node seen for an
    # id, matching a front-to-back linear search.
    label_by_id: Dict[str, str] = {}
    for node in filtered_nodes:
        label_by_id.setdefault(node.id, node.label)

    valid_rels = []
    for rel in extracted_relationships:
        if rel.type not in schema.relations:
            continue
        if (
            rel.start_node_id not in label_by_id
            or rel.end_node_id not in label_by_id
        ):
            # One endpoint was removed during node enforcement.
            continue

        pattern = (
            label_by_id[rel.start_node_id],
            rel.type,
            label_by_id[rel.end_node_id],
        )
        # An empty potential schema places no constraint on endpoints.
        if schema.potential_schema and pattern not in schema.potential_schema:
            continue

        filtered_props = self._enforce_properties(
            rel.properties, schema.relations[rel.type]
        )
        valid_rels.append(
            Neo4jRelationship(
                start_node_id=rel.start_node_id,
                end_node_id=rel.end_node_id,
                type=rel.type,
                properties=filtered_props,
                embedding_properties=rel.embedding_properties,
            )
        )
    return valid_rels
462
+
463
+ def _enforce_properties (
464
+ self ,
465
+ properties : Dict [str , Any ],
466
+ valid_properties : Dict [str , Any ]
467
+ ) -> Dict [str , Any ]:
468
+ """
469
+ Filter properties.
470
+ Keep only those that exist in schema (i.e., valid properties).
471
+ """
472
+ return {
473
+ key : value
474
+ for key , value in properties .items ()
475
+ if key in valid_properties
476
+ }
477
+
478
+ def _get_node_label (
479
+ self ,
480
+ node_id : str ,
481
+ nodes : List [Neo4jNode ]
482
+ ) -> str :
483
+ """
484
+ Given a list of nodes, get the label of the node whose id matches the provided
485
+ node id.
486
+ """
487
+ for node in nodes :
488
+ if node .id == node_id :
489
+ return node .label
490
+ return ""
0 commit comments