From 417e50e403f12bd245584fcb81292bbdae0eda06 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Thu, 6 Mar 2025 12:39:51 +0100 Subject: [PATCH 01/10] Add schema enforcement modes and strict mode behaviour --- .../components/entity_relation_extractor.py | 161 +++++++++++++++++- .../experimental/components/types.py | 8 + 2 files changed, 164 insertions(+), 5 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py index 1d6861232..e7439f598 100644 --- a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py +++ b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py @@ -19,7 +19,7 @@ import enum import json import logging -from typing import Any, List, Optional, Union, cast +from typing import Any, List, Optional, Union, cast, Dict import json_repair from pydantic import ValidationError, validate_call @@ -31,8 +31,11 @@ DocumentInfo, LexicalGraphConfig, Neo4jGraph, + Neo4jNode, + Neo4jRelationship, TextChunk, TextChunks, + SchemaEnforcementMode, ) from neo4j_graphrag.experimental.pipeline.component import Component from neo4j_graphrag.experimental.pipeline.exceptions import InvalidJSONError @@ -168,6 +171,7 @@ class LLMEntityRelationExtractor(EntityRelationExtractor): llm (LLMInterface): The language model to use for extraction. prompt_template (ERExtractionTemplate | str): A custom prompt template to use for extraction. create_lexical_graph (bool): Whether to include the text chunks in the graph in addition to the extracted entities and relations. Defaults to True. + enforce_schema (SchemaEnforcementMode): Whether to validate or not the extracted entities/rels against the provided schema. Defaults to None. on_error (OnError): What to do when an error occurs during extraction. Defaults to raising an error. max_concurrency (int): The maximum number of concurrent tasks which can be used to make requests to the LLM. @@ -192,11 +196,13 @@ def __init__( llm: LLMInterface, prompt_template: ERExtractionTemplate | str = ERExtractionTemplate(), create_lexical_graph: bool = True, + enforce_schema: SchemaEnforcementMode = SchemaEnforcementMode.NONE, on_error: OnError = OnError.RAISE, max_concurrency: int = 5, ) -> None: super().__init__(on_error=on_error, create_lexical_graph=create_lexical_graph) self.llm = llm # with response_format={ "type": "json_object" }, + self.enforce_schema = enforce_schema self.max_concurrency = max_concurrency if isinstance(prompt_template, str): template = PromptTemplate(prompt_template, expected_inputs=[]) @@ -275,15 +281,16 @@ async def run_for_chunk( examples: str, lexical_graph_builder: Optional[LexicalGraphBuilder] = None, ) -> Neo4jGraph: - """Run extraction and post processing for a single chunk""" + """Run extraction, validation and post processing for a single chunk""" async with sem: chunk_graph = await self.extract_for_chunk(schema, examples, chunk) + final_chunk_graph = self.validate_chunk(chunk_graph, schema) await self.post_process_chunk( - chunk_graph, + final_chunk_graph, chunk, lexical_graph_builder, ) - return chunk_graph + return final_chunk_graph @validate_call async def run( @@ -306,7 +313,7 @@ async def run( chunks (TextChunks): List of text chunks to extract entities and relations from. document_info (Optional[DocumentInfo], optional): Document the chunks are coming from. Used in the lexical graph creation step. lexical_graph_config (Optional[LexicalGraphConfig], optional): Lexical graph configuration to customize node labels and relationship types in the lexical graph. - schema (SchemaConfig | None): Definition of the schema to guide the LLM in its extraction. Caution: at the moment, there is no guarantee that the extracted entities and relations will strictly obey the schema. + schema (SchemaConfig | None): Definition of the schema to guide the LLM in its extraction. examples (str): Examples for few-shot learning in the prompt. """ lexical_graph_builder = None @@ -337,3 +344,147 @@ async def run( graph = self.combine_chunk_graphs(lexical_graph, chunk_graphs) logger.debug(f"Extracted graph: {prettify(graph)}") return graph + + def validate_chunk( + self, + chunk_graph: Neo4jGraph, + schema: SchemaConfig + ) -> Neo4jGraph: + """ + Perform validation after entity and relation extraction: + - Enforce schema if schema enforcement mode is on and schema is provided + """ + # if enforcing_schema is on and schema is provided, clean the graph + return ( + self._clean_graph(chunk_graph, schema) + if self.enforce_schema != SchemaEnforcementMode.NONE and schema.entities + else chunk_graph + ) + + def _clean_graph( + self, + graph: Neo4jGraph, + schema: SchemaConfig, + ) -> Neo4jGraph: + """ + Verify that the graph conforms to the provided schema. + + Remove invalid entities,relationships, and properties. + If an entity is removed, all of its relationships are also removed. + If no valid properties remain for an entity, remove that entity. + """ + # enforce nodes (remove invalid labels, strip invalid properties) + filtered_nodes = self._enforce_nodes(graph.nodes, schema) + + # enforce relationships (remove those referencing invalid nodes or with invalid + # types or with start/end nodes not conforming to the schema, and strip invalid + # properties) + filtered_rels = self._enforce_relationships( + graph.relationships, filtered_nodes, schema + ) + + return Neo4jGraph(nodes=filtered_nodes, relationships=filtered_rels) + + def _enforce_nodes( + self, + extracted_nodes: List[Neo4jNode], + schema: SchemaConfig + ) -> List[Neo4jNode]: + """ + Filter extracted nodes to be conformant to the schema. + + Keep only those whose label is in schema. + For each valid node, filter out properties not present in the schema. + Remove a node if it ends up with no valid properties. + """ + valid_nodes = [] + if self.enforce_schema == SchemaEnforcementMode.STRICT: + for node in extracted_nodes: + if node.label in schema.entities: + schema_entity = schema.entities[node.label] + filtered_props = self._enforce_properties(node.properties, + schema_entity) + if filtered_props: + # keep node only if it has at least one valid property + new_node = Neo4jNode( + id=node.id, + label=node.label, + properties=filtered_props, + embedding_properties=node.embedding_properties, + ) + valid_nodes.append(new_node) + # elif self.enforce_schema == SchemaEnforcementMode.OPEN: + # future logic + return valid_nodes + + def _enforce_relationships( + self, + extracted_relationships: List[Neo4jRelationship], + filtered_nodes: List[Neo4jNode], + schema: SchemaConfig + ) -> List[Neo4jRelationship]: + """ + Filter extracted nodes to be conformant to the schema. + + Keep only those whose types are in schema, start/end node conform to schema, + and start/end nodes are in filtered nodes (i.e., kept after node enforcement). + For each valid relationship, filter out properties not present in the schema. + """ + valid_rels = [] + if self.enforce_schema == SchemaEnforcementMode.STRICT: + valid_node_ids = {node.id for node in filtered_nodes} + for rel in extracted_relationships: + # keep relationship if it conforms with the schema + if rel.type in schema.relations: + if (rel.start_node_id in valid_node_ids and + rel.end_node_id in valid_node_ids): + start_node_label = self._get_node_label(rel.start_node_id, + filtered_nodes) + end_node_label = self._get_node_label(rel.end_node_id, + filtered_nodes) + if (not schema.potential_schema or + (start_node_label, rel.type, end_node_label) in + schema.potential_schema): + schema_relation = schema.relations[rel.type] + filtered_props = self._enforce_properties(rel.properties, + schema_relation) + new_rel = Neo4jRelationship( + start_node_id=rel.start_node_id, + end_node_id=rel.end_node_id, + type=rel.type, + properties=filtered_props, + embedding_properties=rel.embedding_properties, + ) + valid_rels.append(new_rel) + # elif self.enforce_schema == SchemaEnforcementMode.OPEN: + # future logic + return valid_rels + + def _enforce_properties( + self, + properties: Dict[str, Any], + valid_properties: Dict[str, Any] + ) -> Dict[str, Any]: + """ + Filter properties. + Keep only those that exist in schema (i.e., valid properties). + """ + return { + key: value + for key, value in properties.items() + if key in valid_properties + } + + def _get_node_label( + self, + node_id: str, + nodes: List[Neo4jNode] + ) -> str: + """ + Given a list of nodes, get the label of the node whose id matches the provided + node id. + """ + for node in nodes: + if node.id == node_id: + return node.label + return "" diff --git a/src/neo4j_graphrag/experimental/components/types.py b/src/neo4j_graphrag/experimental/components/types.py index 689e6b6ce..247ac24aa 100644 --- a/src/neo4j_graphrag/experimental/components/types.py +++ b/src/neo4j_graphrag/experimental/components/types.py @@ -15,6 +15,7 @@ from __future__ import annotations import uuid +from enum import Enum from typing import Any, Dict, Optional from pydantic import BaseModel, Field, field_validator @@ -170,3 +171,10 @@ def lexical_graph_node_labels(self) -> tuple[str, ...]: class GraphResult(DataModel): graph: Neo4jGraph config: LexicalGraphConfig + + +class SchemaEnforcementMode(str, Enum): + NONE = "none" + STRICT = "strict" + # future possibility: OPEN = "open" -> ensure conformance of nodes/props/rels that + # were listed in the schema but leave room for extras \ No newline at end of file From d1399ad2c35c5b8832e2fbcb40b9cbc95d07d526 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Thu, 6 Mar 2025 12:40:08 +0100 Subject: [PATCH 02/10] Add unit tests for schema enforcement modes --- .../test_entity_relation_extractor.py | 235 ++++++++++++++++++ 1 file changed, 235 insertions(+) diff --git a/tests/unit/experimental/components/test_entity_relation_extractor.py b/tests/unit/experimental/components/test_entity_relation_extractor.py index f76ab5c9c..a92390a7b 100644 --- a/tests/unit/experimental/components/test_entity_relation_extractor.py +++ b/tests/unit/experimental/components/test_entity_relation_extractor.py @@ -25,11 +25,13 @@ balance_curly_braces, fix_invalid_json, ) +from neo4j_graphrag.experimental.components.schema import SchemaConfig from neo4j_graphrag.experimental.components.types import ( DocumentInfo, Neo4jGraph, TextChunk, TextChunks, + SchemaEnforcementMode, ) from neo4j_graphrag.experimental.pipeline.exceptions import InvalidJSONError from neo4j_graphrag.llm import LLMInterface, LLMResponse @@ -229,6 +231,239 @@ async def test_extractor_custom_prompt() -> None: llm.ainvoke.assert_called_once_with("this is my prompt") +@pytest.mark.asyncio +async def test_extractor_no_schema_enforcement() -> None: + llm = MagicMock(spec=LLMInterface) + llm.ainvoke.return_value = LLMResponse( + content='{"nodes":[{"id":"0","label":"Alien","properties":{"foo":"bar"}}],' + '"relationships":[]}' + ) + + extractor = LLMEntityRelationExtractor(llm=llm, + create_lexical_graph=False, + enforce_schema=SchemaEnforcementMode.NONE) + + schema = SchemaConfig(entities={"Person": {"name": "STRING"}}, + relations={}, + potential_schema=[]) + + chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) + + result: Neo4jGraph = await extractor.run(chunks=chunks, schema=schema) + + assert len(result.nodes) == 1 + assert result.nodes[0].label == "Alien" + assert result.nodes[0].properties == {"chunk_index": 0, "foo": "bar"} + + +@pytest.mark.asyncio +async def test_extractor_schema_enforcement_when_no_schema_provided(): + llm = MagicMock(spec=LLMInterface) + llm.ainvoke.return_value = LLMResponse( + content='{"nodes":[{"id":"0","label":"Alien","properties":{"foo":"bar"}}],' + '"relationships":[]}' + ) + + extractor = LLMEntityRelationExtractor(llm=llm, + create_lexical_graph=False, + enforce_schema=SchemaEnforcementMode.STRICT) + + chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) + + result: Neo4jGraph = await extractor.run(chunks=chunks) + + assert len(result.nodes) == 1 + assert result.nodes[0].label == "Alien" + assert result.nodes[0].properties == {"chunk_index": 0, "foo": "bar"} + + +@pytest.mark.asyncio +async def test_extractor_schema_enforcement_invalid_nodes(): + llm = MagicMock(spec=LLMInterface) + llm.ainvoke.return_value = LLMResponse( + content='{"nodes":[{"id":"0","label":"Alien","properties":{"foo":"bar"}},' + '{"id":"1","label":"Person","properties":{"name":"Alice"}}],' + '"relationships":[]}' + ) + + extractor = LLMEntityRelationExtractor(llm=llm, + create_lexical_graph=False, + enforce_schema=SchemaEnforcementMode.STRICT) + + schema = SchemaConfig(entities={"Person": {"name": "STRING"}}, + relations={}, + potential_schema=[]) + + chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) + + result: Neo4jGraph = await extractor.run(chunks=chunks, schema=schema) + + assert len(result.nodes) == 1 + assert result.nodes[0].label == "Person" + assert result.nodes[0].properties == {"chunk_index": 0, "name": "Alice"} + + +@pytest.mark.asyncio +async def test_extraction_schema_enforcement_invalid_node_properties(): + llm = MagicMock(spec=LLMInterface) + llm.ainvoke.return_value = LLMResponse( + content='{"nodes":[{"id":"1","label":"Person","properties":' + '{"name":"Alice","age":30,"foo":"bar"}}],' + '"relationships":[]}' + ) + + extractor = LLMEntityRelationExtractor(llm=llm, + create_lexical_graph=False, + enforce_schema=SchemaEnforcementMode.STRICT) + + schema = SchemaConfig(entities={"Person": {"name": str, "age": int}}, + relations={}, + potential_schema=[]) + + chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) + + result: Neo4jGraph = await extractor.run(chunks, schema=schema) + + # "foo" is removed + assert len(result.nodes) == 1 + assert len(result.nodes[0].properties) == 3 + assert "foo" not in result.nodes[0].properties + + +@pytest.mark.asyncio +async def test_extractor_schema_enforcement_valid_nodes_with_empty_props(): + llm = MagicMock(spec=LLMInterface) + llm.ainvoke.return_value = LLMResponse( + content='{"nodes":[{"id":"1","label":"Person","properties":{"foo":"bar"}}],' + '"relationships":[]}' + ) + + extractor = LLMEntityRelationExtractor(llm=llm, + create_lexical_graph=False, + enforce_schema=SchemaEnforcementMode.STRICT) + + schema = SchemaConfig(entities={"Person": {}}, + relations={}, + potential_schema=[]) + + chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) + + result: Neo4jGraph = await extractor.run(chunks, schema=schema) + + assert len(result.nodes) == 0 + + +@pytest.mark.asyncio +async def test_extractor_schema_enforcement_invalid_relations_wrong_types(): + llm = MagicMock(spec=LLMInterface) + llm.ainvoke.return_value = LLMResponse( + content='{"nodes":[{"id":"1","label":"Person","properties":' + '{"name":"Alice"}},{"id":"2","label":"Person","properties":' + '{"name":"Bob"}}],' + '"relationships":[{"start_node_id":"1","end_node_id":"2",' + '"type":"FRIENDS_WITH","properties":{}}]}' + ) + + extractor = LLMEntityRelationExtractor(llm=llm, + create_lexical_graph=False, + enforce_schema=SchemaEnforcementMode.STRICT) + + schema = SchemaConfig(entities={"Person": {"name": str}}, + relations={"LIKES": {}}, + potential_schema=[]) + + chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) + + result: Neo4jGraph = await extractor.run(chunks, schema=schema) + + assert len(result.nodes) == 2 + assert len(result.relationships) == 0 + + +@pytest.mark.asyncio +async def test_extractor_schema_enforcement_invalid_relations_wrong_start_node(): + llm = MagicMock(spec=LLMInterface) + llm.ainvoke.return_value = LLMResponse( + content='{"nodes":[{"id":"1","label":"Person","properties":{"name":"Alice"}},' + '{"id":"2","label":"Person","properties":{"name":"Bob"}}, ' + '{"id":"3","label":"City","properties":{"name":"London"}}],' + '"relationships":[{"start_node_id":"1","end_node_id":"2",' + '"type":"LIVES_IN","properties":{}}]}' + ) + + extractor = LLMEntityRelationExtractor(llm=llm, + create_lexical_graph=False, + enforce_schema=SchemaEnforcementMode.STRICT) + + schema = SchemaConfig(entities={"Person": {"name": str}, "City": {"name": str}}, + relations={"LIVES_IN": {}}, + potential_schema=[("Person", "LIVES_IN", "City")]) + + chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) + + result: Neo4jGraph = await extractor.run(chunks, schema=schema) + + assert len(result.nodes) == 3 + assert len(result.relationships) == 0 + + +@pytest.mark.asyncio +async def test_extractor_schema_enforcement_invalid_relation_properties(): + llm = MagicMock(spec=LLMInterface) + llm.ainvoke.return_value = LLMResponse( + content='{"nodes":[{"id":"1","label":"Person","properties":{"name":"Alice"}},' + '{"id":"2","label":"Person","properties":{"name":"Bob"}}],' + '"relationships":[{"start_node_id":"1","end_node_id":"2",' + '"type":"LIKES","properties":{"strength":"high","foo":"bar"}}]}' + ) + + extractor = LLMEntityRelationExtractor(llm=llm, + create_lexical_graph=False, + enforce_schema=SchemaEnforcementMode.STRICT) + + schema = SchemaConfig( + entities={"Person": {"name": str}}, + relations={"LIKES": {"strength": str}}, + potential_schema=[] + ) + + chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) + + result: Neo4jGraph = await extractor.run(chunks, schema=schema) + + assert len(result.nodes) == 2 + assert len(result.relationships) == 1 + rel = result.relationships[0] + assert "foo" not in rel.properties + assert rel.properties["strength"] == "high" + + +@pytest.mark.asyncio +async def test_extractor_schema_enforcement_removed_relation_start_end_nodes(): + llm = MagicMock(spec=LLMInterface) + llm.ainvoke.return_value = LLMResponse( + content='{"nodes":[{"id":"1","label":"Alien","properties":{}},' + '{"id":"2","label":"Robot","properties":{}}],' + '"relationships":[{"start_node_id":"1","end_node_id":"2",' + '"type":"LIKES","properties":{}}]}' + ) + + extractor = LLMEntityRelationExtractor(llm=llm, + create_lexical_graph=False, + enforce_schema=SchemaEnforcementMode.STRICT) + + schema = SchemaConfig(entities={"Person": {"name": str}}, + relations={"LIKES": {}}, + potential_schema=[("Person", "LIKES", "Person")]) + + chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) + + result: Neo4jGraph = await extractor.run(chunks, schema=schema) + + assert len(result.nodes) == 0 + assert len(result.relationships) == 0 + + def test_fix_invalid_json_empty_result() -> None: json_string = "invalid json" From d2ba8626ff7e1d7fa4464bcaf04aea47dd7534d5 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Thu, 6 Mar 2025 12:40:41 +0100 Subject: [PATCH 03/10] Update change log and docs --- CHANGELOG.md | 2 ++ docs/source/user_guide_kg_builder.rst | 35 ++++++++++++++++++++++++++- 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 09d264afc..88ee296d5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,8 @@ ## Next ### Added + +- Added optional schema enforcement as a validation layer after entity and relation extraction. - Introduced SearchQueryParseError for handling invalid Lucene query strings in HybridRetriever and HybridCypherRetriever. ## 1.5.0 diff --git a/docs/source/user_guide_kg_builder.rst b/docs/source/user_guide_kg_builder.rst index edc89f609..8d9224548 100644 --- a/docs/source/user_guide_kg_builder.rst +++ b/docs/source/user_guide_kg_builder.rst @@ -125,7 +125,7 @@ This schema information can be provided to the `SimpleKGBuilder` as demonstrated # ... ) -Prompt Template, Lexical Graph Config and Error Behavior +Prompt Template, Lexical Graph Config, Schema Enforcement, and Error Behavior -------------------------------------------------------- These parameters are part of the `EntityAndRelationExtractor` component. @@ -138,6 +138,7 @@ They are also accessible via the `SimpleKGPipeline` interface. # ... prompt_template="", lexical_graph_config=my_config, + enforce_schema=SchemaEnforcementMode.Strict on_error="RAISE", # ... ) @@ -829,6 +830,38 @@ It can be used in this way: The LLM to use can be customized, the only constraint is that it obeys the :ref:`LLMInterface `. +Schema Enforcement Behaviour +--------------- +By default, even if a schema is provided to guide the LLM in the entity and relation extraction, the LLM response is not validated against that schema. +This behaviour can be changed by using the `enforce_schema` flag in the `LLMEntityRelationExtractor` constructor: + +.. code:: python + + from neo4j_graphrag.experimental.components.entity_relation_extractor import LLMEntityRelationExtractor + from neo4j_graphrag.experimental.components.types import SchemaEnforcementMode + + extractor = LLMEntityRelationExtractor( + # ... + enforce_schema=SchemaEnforcementMode.STRICT, + ) + schema = SchemaConfig( + entities={"Label": {"name": str}}, + relations={"REL_TYPE": {}}, + potential_schema=[("Label", "REL_TYPE", "Label")] + ) + + #.... + result = await extractor.run( + #... + schema=schema + ) + +In this scenario, any extracted node/relation/property that is not part of the provided schema will be pruned. +Any relation whose start node or end node does not conform to the provided tuple in `potential_schema` will be pruned. +If a node is left with no properties, it will be also pruned. + +Note that if the schema enforcement mode is on but the schema is not provided, no schema enforcement will be applied. + Error Behaviour --------------- From ca049fd591613e307bb4adb928d72a14f4ef0aca Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Thu, 6 Mar 2025 16:44:51 +0100 Subject: [PATCH 04/10] Fix documentation --- docs/source/user_guide_kg_builder.rst | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/docs/source/user_guide_kg_builder.rst b/docs/source/user_guide_kg_builder.rst index 8d9224548..e5fca03ac 100644 --- a/docs/source/user_guide_kg_builder.rst +++ b/docs/source/user_guide_kg_builder.rst @@ -831,7 +831,7 @@ It can be used in this way: The LLM to use can be customized, the only constraint is that it obeys the :ref:`LLMInterface `. Schema Enforcement Behaviour ---------------- +---------------------------- By default, even if a schema is provided to guide the LLM in the entity and relation extraction, the LLM response is not validated against that schema. This behaviour can be changed by using the `enforce_schema` flag in the `LLMEntityRelationExtractor` constructor: @@ -844,23 +844,14 @@ This behaviour can be changed by using the `enforce_schema` flag in the `LLMEnti # ... enforce_schema=SchemaEnforcementMode.STRICT, ) - schema = SchemaConfig( - entities={"Label": {"name": str}}, - relations={"REL_TYPE": {}}, - potential_schema=[("Label", "REL_TYPE", "Label")] - ) - - #.... - result = await extractor.run( - #... - schema=schema - ) In this scenario, any extracted node/relation/property that is not part of the provided schema will be pruned. Any relation whose start node or end node does not conform to the provided tuple in `potential_schema` will be pruned. If a node is left with no properties, it will be also pruned. -Note that if the schema enforcement mode is on but the schema is not provided, no schema enforcement will be applied. +.. warning:: + + Note that if the schema enforcement mode is on but the schema is not provided, no schema enforcement will be applied. Error Behaviour --------------- From 5a852c4bd475a56623380604773997872010740b Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Thu, 6 Mar 2025 17:25:55 +0100 Subject: [PATCH 05/10] Add warning when schema enforcement is on but schema not provided --- .../components/entity_relation_extractor.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py index e7439f598..e914daf88 100644 --- a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py +++ b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py @@ -354,12 +354,15 @@ def validate_chunk( Perform validation after entity and relation extraction: - Enforce schema if schema enforcement mode is on and schema is provided """ - # if enforcing_schema is on and schema is provided, clean the graph - return ( - self._clean_graph(chunk_graph, schema) - if self.enforce_schema != SchemaEnforcementMode.NONE and schema.entities - else chunk_graph - ) + if self.enforce_schema != SchemaEnforcementMode.NONE: + if not schema or not schema.entities: # schema is not provided + logger.warning( + "Schema enforcement is ON but the guiding schema is not provided." + ) + else: + # if enforcing_schema is on and schema is provided, clean the graph + return self._clean_graph(chunk_graph, schema) + return chunk_graph def _clean_graph( self, From ea5e5a981ae07155ff220ee014bde411fc586405 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Fri, 7 Mar 2025 12:52:15 +0100 Subject: [PATCH 06/10] Code cleanups --- .../components/entity_relation_extractor.py | 46 +++++---------- .../experimental/components/types.py | 2 - .../test_entity_relation_extractor.py | 59 ++++++++++++------- 3 files changed, 54 insertions(+), 53 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py index e914daf88..d3883ec5f 100644 --- a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py +++ b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py @@ -405,8 +405,9 @@ def _enforce_nodes( for node in extracted_nodes: if node.label in schema.entities: schema_entity = schema.entities[node.label] - filtered_props = self._enforce_properties(node.properties, - schema_entity) + filtered_props = self._enforce_properties( + node.properties, + schema_entity["properties"]) if filtered_props: # keep node only if it has at least one valid property new_node = Neo4jNode( @@ -416,8 +417,7 @@ def _enforce_nodes( embedding_properties=node.embedding_properties, ) valid_nodes.append(new_node) - # elif self.enforce_schema == SchemaEnforcementMode.OPEN: - # future logic + return valid_nodes def _enforce_relationships( @@ -435,22 +435,21 @@ def _enforce_relationships( """ valid_rels = [] if self.enforce_schema == SchemaEnforcementMode.STRICT: - valid_node_ids = {node.id for node in filtered_nodes} + valid_nodes = {node.id: node.label for node in filtered_nodes} for rel in extracted_relationships: # keep relationship if it conforms with the schema if rel.type in schema.relations: - if (rel.start_node_id in valid_node_ids and - rel.end_node_id in valid_node_ids): - start_node_label = self._get_node_label(rel.start_node_id, - filtered_nodes) - end_node_label = self._get_node_label(rel.end_node_id, - filtered_nodes) + if (rel.start_node_id in valid_nodes and + rel.end_node_id in valid_nodes): + start_node_label = valid_nodes[rel.start_node_id] + end_node_label = valid_nodes[rel.end_node_id] if (not schema.potential_schema or (start_node_label, rel.type, end_node_label) in schema.potential_schema): schema_relation = schema.relations[rel.type] - filtered_props = self._enforce_properties(rel.properties, - schema_relation) + filtered_props = self._enforce_properties( + rel.properties, + schema_relation["properties"]) new_rel = Neo4jRelationship( start_node_id=rel.start_node_id, end_node_id=rel.end_node_id, @@ -459,35 +458,22 @@ def _enforce_relationships( embedding_properties=rel.embedding_properties, ) valid_rels.append(new_rel) - # elif self.enforce_schema == SchemaEnforcementMode.OPEN: - # future logic + return valid_rels def _enforce_properties( self, properties: Dict[str, Any], - valid_properties: Dict[str, Any] + valid_properties: List[Dict[str, Any]] ) -> Dict[str, Any]: """ Filter properties. Keep only those that exist in schema (i.e., valid properties). """ + valid_prop_names = {prop["name"] for prop in valid_properties} return { key: value for key, value in properties.items() - if key in valid_properties + if key in valid_prop_names } - def _get_node_label( - self, - node_id: str, - nodes: List[Neo4jNode] - ) -> str: - """ - Given a list of nodes, get the label of the node whose id matches the provided - node id. - """ - for node in nodes: - if node.id == node_id: - return node.label - return "" diff --git a/src/neo4j_graphrag/experimental/components/types.py b/src/neo4j_graphrag/experimental/components/types.py index 247ac24aa..93075a7c7 100644 --- a/src/neo4j_graphrag/experimental/components/types.py +++ b/src/neo4j_graphrag/experimental/components/types.py @@ -176,5 +176,3 @@ class GraphResult(DataModel): class SchemaEnforcementMode(str, Enum): NONE = "none" STRICT = "strict" - # future possibility: OPEN = "open" -> ensure conformance of nodes/props/rels that - # were listed in the schema but leave room for extras \ No newline at end of file diff --git a/tests/unit/experimental/components/test_entity_relation_extractor.py b/tests/unit/experimental/components/test_entity_relation_extractor.py index a92390a7b..52f27b50b 100644 --- a/tests/unit/experimental/components/test_entity_relation_extractor.py +++ b/tests/unit/experimental/components/test_entity_relation_extractor.py @@ -243,9 +243,11 @@ async def test_extractor_no_schema_enforcement() -> None: create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.NONE) - schema = SchemaConfig(entities={"Person": {"name": "STRING"}}, - relations={}, - potential_schema=[]) + schema = SchemaConfig( + entities={"Person": {"label": "Person", + "properties": [{"name": "name", "type": "STRING"}]}}, + relations={}, + potential_schema=[]) chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) @@ -290,9 +292,11 @@ async def test_extractor_schema_enforcement_invalid_nodes(): create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT) - schema = SchemaConfig(entities={"Person": {"name": "STRING"}}, - relations={}, - potential_schema=[]) + schema = SchemaConfig( + entities={"Person": {"label": "Person", + "properties": [{"name": "name", "type": "STRING"}]}}, + relations={}, + potential_schema=[]) chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) @@ -316,9 +320,12 @@ async def test_extraction_schema_enforcement_invalid_node_properties(): create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT) - schema = SchemaConfig(entities={"Person": {"name": str, "age": int}}, - relations={}, - potential_schema=[]) + schema = SchemaConfig( + entities={"Person": {"label": "Person", + "properties": [{"name": "name", "type": "STRING"}, + {"name": "age", "type": "INTEGER"}]}}, + relations={}, + potential_schema=[]) chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) @@ -342,7 +349,7 @@ async def test_extractor_schema_enforcement_valid_nodes_with_empty_props(): create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT) - schema = SchemaConfig(entities={"Person": {}}, + schema = SchemaConfig(entities={"Person": {"label": "Person", "properties": []}}, relations={}, potential_schema=[]) @@ -368,9 +375,11 @@ async def test_extractor_schema_enforcement_invalid_relations_wrong_types(): create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT) - schema = SchemaConfig(entities={"Person": {"name": str}}, - relations={"LIKES": {}}, - potential_schema=[]) + schema = SchemaConfig( + entities={"Person": {"label": "Person", + "properties": [{"name": "name", "type": "STRING"}]}}, + relations={"LIKES": {"label": "LIKES", "properties": []}}, + potential_schema=[]) chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) @@ -395,9 +404,13 @@ async def test_extractor_schema_enforcement_invalid_relations_wrong_start_node() create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT) - schema = SchemaConfig(entities={"Person": {"name": str}, "City": {"name": str}}, - relations={"LIVES_IN": {}}, - potential_schema=[("Person", "LIVES_IN", "City")]) + schema = SchemaConfig( + entities={"Person": {"label": "Person", + "properties": [{"name": "name", "type": "STRING"}]}, + "City": {"label": "City", + "properties": [{"name": "name", "type": "STRING"}]}}, + relations={"LIVES_IN": {"label": "LIVES_IN", "properties": []}}, + potential_schema=[("Person", "LIVES_IN", "City")]) chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) @@ -422,8 +435,10 @@ async def test_extractor_schema_enforcement_invalid_relation_properties(): enforce_schema=SchemaEnforcementMode.STRICT) schema = SchemaConfig( - entities={"Person": {"name": str}}, - relations={"LIKES": {"strength": str}}, + entities={"Person": {"label": "Person", + "properties": [{"name": "name", "type": "STRING"}]}}, + relations={"LIKES": {"label": "LIKES", + "properties": [{"name": "strength", "type": "STRING"}]}}, potential_schema=[] ) @@ -452,9 +467,11 @@ async def test_extractor_schema_enforcement_removed_relation_start_end_nodes(): create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT) - schema = SchemaConfig(entities={"Person": {"name": str}}, - relations={"LIKES": {}}, - potential_schema=[("Person", "LIKES", "Person")]) + schema = SchemaConfig( + entities={"Person": {"label": "Person", + "properties": [{"name": "name", "type": "STRING"}]}}, + relations={"LIKES": {"label": "LIKES", "properties": []}}, + potential_schema=[("Person", "LIKES", "Person")]) chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) From d3dee7579fd26701c4e540a7835434a8c9085f6c Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Mon, 10 Mar 2025 12:10:32 +0100 Subject: [PATCH 07/10] Improve code for more clarity --- .../components/entity_relation_extractor.py | 95 +++++++++++-------- .../test_entity_relation_extractor.py | 8 +- 2 files changed, 59 insertions(+), 44 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py index d3883ec5f..3ae51238e 100644 --- a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py +++ b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py @@ -400,23 +400,26 @@ def _enforce_nodes( For each valid node, filter out properties not present in the schema. Remove a node if it ends up with no valid properties. """ + if self.enforce_schema != SchemaEnforcementMode.STRICT: + return extracted_nodes + valid_nodes = [] - if self.enforce_schema == SchemaEnforcementMode.STRICT: - for node in extracted_nodes: - if node.label in schema.entities: - schema_entity = schema.entities[node.label] - filtered_props = self._enforce_properties( - node.properties, - schema_entity["properties"]) - if filtered_props: - # keep node only if it has at least one valid property - new_node = Neo4jNode( - id=node.id, - label=node.label, - properties=filtered_props, - embedding_properties=node.embedding_properties, - ) - valid_nodes.append(new_node) + + for node in extracted_nodes: + schema_entity = schema.entities.get(node.label) + if not schema_entity: + continue + allowed_props = schema_entity.get("properties", {}) + filtered_props = self._enforce_properties(node.properties, allowed_props) + if filtered_props: + valid_nodes.append( + Neo4jNode( + id=node.id, + label=node.label, + properties=filtered_props, + embedding_properties=node.embedding_properties, + ) + ) return valid_nodes @@ -433,31 +436,43 @@ def _enforce_relationships( and start/end nodes are in filtered nodes (i.e., kept after node enforcement). For each valid relationship, filter out properties not present in the schema. """ + if self.enforce_schema != SchemaEnforcementMode.STRICT: + return extracted_relationships + valid_rels = [] - if self.enforce_schema == SchemaEnforcementMode.STRICT: - valid_nodes = {node.id: node.label for node in filtered_nodes} - for rel in extracted_relationships: - # keep relationship if it conforms with the schema - if rel.type in schema.relations: - if (rel.start_node_id in valid_nodes and - rel.end_node_id in valid_nodes): - start_node_label = valid_nodes[rel.start_node_id] - end_node_label = valid_nodes[rel.end_node_id] - if (not schema.potential_schema or - (start_node_label, rel.type, end_node_label) in - schema.potential_schema): - schema_relation = schema.relations[rel.type] - filtered_props = self._enforce_properties( - rel.properties, - schema_relation["properties"]) - new_rel = Neo4jRelationship( - start_node_id=rel.start_node_id, - end_node_id=rel.end_node_id, - type=rel.type, - properties=filtered_props, - embedding_properties=rel.embedding_properties, - ) - valid_rels.append(new_rel) + + valid_nodes = {node.id: node.label for node in filtered_nodes} + + potential_schema = schema.potential_schema + + for rel in extracted_relationships: + schema_relation = schema.relations.get(rel.type) + if not schema_relation: + continue + + if (rel.start_node_id not in valid_nodes or + rel.end_node_id not in valid_nodes): + continue + + start_label = valid_nodes[rel.start_node_id] + end_label = valid_nodes[rel.end_node_id] + + if (potential_schema and + (start_label, rel.type, end_label) not in potential_schema): + continue + + allowed_props = schema_relation.get("properties", {}) + filtered_props = self._enforce_properties(rel.properties, allowed_props) + + valid_rels.append( + Neo4jRelationship( + start_node_id=rel.start_node_id, + end_node_id=rel.end_node_id, + type=rel.type, + properties=filtered_props, + embedding_properties=rel.embedding_properties, + ) + ) return valid_rels diff --git a/tests/unit/experimental/components/test_entity_relation_extractor.py b/tests/unit/experimental/components/test_entity_relation_extractor.py index 52f27b50b..407a5d6b8 100644 --- a/tests/unit/experimental/components/test_entity_relation_extractor.py +++ b/tests/unit/experimental/components/test_entity_relation_extractor.py @@ -349,7 +349,7 @@ async def test_extractor_schema_enforcement_valid_nodes_with_empty_props(): create_lexical_graph=False, enforce_schema=SchemaEnforcementMode.STRICT) - schema = SchemaConfig(entities={"Person": {"label": "Person", "properties": []}}, + schema = SchemaConfig(entities={"Person": {"label": "Person"}}, relations={}, potential_schema=[]) @@ -378,7 +378,7 @@ async def test_extractor_schema_enforcement_invalid_relations_wrong_types(): schema = SchemaConfig( entities={"Person": {"label": "Person", "properties": [{"name": "name", "type": "STRING"}]}}, - relations={"LIKES": {"label": "LIKES", "properties": []}}, + relations={"LIKES": {"label": "LIKES"}}, potential_schema=[]) chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) @@ -409,7 +409,7 @@ async def test_extractor_schema_enforcement_invalid_relations_wrong_start_node() "properties": [{"name": "name", "type": "STRING"}]}, "City": {"label": "City", "properties": [{"name": "name", "type": "STRING"}]}}, - relations={"LIVES_IN": {"label": "LIVES_IN", "properties": []}}, + relations={"LIVES_IN": {"label": "LIVES_IN"}}, potential_schema=[("Person", "LIVES_IN", "City")]) chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) @@ -470,7 +470,7 @@ async def test_extractor_schema_enforcement_removed_relation_start_end_nodes(): schema = SchemaConfig( entities={"Person": {"label": "Person", "properties": [{"name": "name", "type": "STRING"}]}}, - relations={"LIKES": {"label": "LIKES", "properties": []}}, + relations={"LIKES": {"label": "LIKES"}}, potential_schema=[("Person", "LIKES", "Person")]) chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) From c891f3bc6374190123303304b62415113e0b6b05 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Tue, 11 Mar 2025 18:03:05 +0100 Subject: [PATCH 08/10] Apply changes requested by the PR review --- docs/source/user_guide_kg_builder.rst | 6 +++--- .../experimental/components/entity_relation_extractor.py | 4 ++-- src/neo4j_graphrag/experimental/components/types.py | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/source/user_guide_kg_builder.rst b/docs/source/user_guide_kg_builder.rst index e5fca03ac..3a6715bd8 100644 --- a/docs/source/user_guide_kg_builder.rst +++ b/docs/source/user_guide_kg_builder.rst @@ -125,8 +125,8 @@ This schema information can be provided to the `SimpleKGBuilder` as demonstrated # ... ) -Prompt Template, Lexical Graph Config, Schema Enforcement, and Error Behavior --------------------------------------------------------- +Extra configurations +-------------------- These parameters are part of the `EntityAndRelationExtractor` component. For detailed information, refer to the section on :ref:`Entity and Relation Extractor`. @@ -138,7 +138,7 @@ They are also accessible via the `SimpleKGPipeline` interface. # ... prompt_template="", lexical_graph_config=my_config, - enforce_schema=SchemaEnforcementMode.Strict + enforce_schema="STRICT" on_error="RAISE", # ... ) diff --git a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py index 3ae51238e..be2c3996a 100644 --- a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py +++ b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py @@ -409,7 +409,7 @@ def _enforce_nodes( schema_entity = schema.entities.get(node.label) if not schema_entity: continue - allowed_props = schema_entity.get("properties", {}) + allowed_props = schema_entity.get("properties", []) filtered_props = self._enforce_properties(node.properties, allowed_props) if filtered_props: valid_nodes.append( @@ -461,7 +461,7 @@ def _enforce_relationships( (start_label, rel.type, end_label) not in potential_schema): continue - allowed_props = schema_relation.get("properties", {}) + allowed_props = schema_relation.get("properties", []) filtered_props = self._enforce_properties(rel.properties, allowed_props) valid_rels.append( diff --git a/src/neo4j_graphrag/experimental/components/types.py b/src/neo4j_graphrag/experimental/components/types.py index 93075a7c7..b1ed46569 100644 --- a/src/neo4j_graphrag/experimental/components/types.py +++ b/src/neo4j_graphrag/experimental/components/types.py @@ -174,5 +174,5 @@ class GraphResult(DataModel): class SchemaEnforcementMode(str, Enum): - NONE = "none" - STRICT = "strict" + NONE = "NONE" + STRICT = "STRICT" From 48a2fbb41596563653f034ed826a37314c7bdf2b Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Tue, 11 Mar 2025 19:26:44 +0100 Subject: [PATCH 09/10] Invert rel direction --- docs/source/user_guide_kg_builder.rst | 1 + .../components/entity_relation_extractor.py | 16 +++++++--- .../test_entity_relation_extractor.py | 32 +++++++++++++++++++ 3 files changed, 44 insertions(+), 5 deletions(-) diff --git a/docs/source/user_guide_kg_builder.rst b/docs/source/user_guide_kg_builder.rst index 3a6715bd8..20a7db63f 100644 --- a/docs/source/user_guide_kg_builder.rst +++ b/docs/source/user_guide_kg_builder.rst @@ -847,6 +847,7 @@ This behaviour can be changed by using the `enforce_schema` flag in the `LLMEnti In this scenario, any extracted node/relation/property that is not part of the provided schema will be pruned. Any relation whose start node or end node does not conform to the provided tuple in `potential_schema` will be pruned. +If a relation start/end nodes are valid but the direction is incorrect, the latter will be inverted. If a node is left with no properties, it will be also pruned. .. warning:: diff --git a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py index be2c3996a..2137cfa27 100644 --- a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py +++ b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py @@ -435,6 +435,7 @@ def _enforce_relationships( Keep only those whose types are in schema, start/end node conform to schema, and start/end nodes are in filtered nodes (i.e., kept after node enforcement). For each valid relationship, filter out properties not present in the schema. + If a relationship direct is incorrect, invert it. """ if self.enforce_schema != SchemaEnforcementMode.STRICT: return extracted_relationships @@ -457,17 +458,22 @@ def _enforce_relationships( start_label = valid_nodes[rel.start_node_id] end_label = valid_nodes[rel.end_node_id] - if (potential_schema and - (start_label, rel.type, end_label) not in potential_schema): - continue + tuple_valid = True + if potential_schema: + tuple_valid = (start_label, rel.type, end_label) in potential_schema + reverse_tuple_valid = ((end_label, rel.type, start_label) in + potential_schema) + + if not tuple_valid and not reverse_tuple_valid: + continue allowed_props = schema_relation.get("properties", []) filtered_props = self._enforce_properties(rel.properties, allowed_props) valid_rels.append( Neo4jRelationship( - start_node_id=rel.start_node_id, - end_node_id=rel.end_node_id, + start_node_id=rel.start_node_id if tuple_valid else rel.end_node_id, + end_node_id=rel.end_node_id if tuple_valid else rel.start_node_id, type=rel.type, properties=filtered_props, embedding_properties=rel.embedding_properties, diff --git a/tests/unit/experimental/components/test_entity_relation_extractor.py b/tests/unit/experimental/components/test_entity_relation_extractor.py index 407a5d6b8..f117c1893 100644 --- a/tests/unit/experimental/components/test_entity_relation_extractor.py +++ b/tests/unit/experimental/components/test_entity_relation_extractor.py @@ -481,6 +481,38 @@ async def test_extractor_schema_enforcement_removed_relation_start_end_nodes(): assert len(result.relationships) == 0 +@pytest.mark.asyncio +async def test_extractor_schema_enforcement_inverted_relation_direction(): + llm = MagicMock(spec=LLMInterface) + llm.ainvoke.return_value = LLMResponse( + content='{"nodes":[{"id":"1","label":"Person","properties":{"name":"Alice"}},' + '{"id":"2","label":"City","properties":{"name":"London"}}],' + '"relationships":[{"start_node_id":"2","end_node_id":"1",' + '"type":"LIVES_IN","properties":{}}]}' + ) + + extractor = LLMEntityRelationExtractor(llm=llm, + create_lexical_graph=False, + enforce_schema=SchemaEnforcementMode.STRICT) + + schema = SchemaConfig( + entities={"Person": {"label": "Person", + "properties": [{"name": "name", "type": "STRING"}]}, + "City": {"label": "City", + "properties": [{"name": "name", "type": "STRING"}]}}, + relations={"LIVES_IN": {"label": "LIVES_IN"}}, + potential_schema=[("Person", "LIVES_IN", "City")]) + + chunks = TextChunks(chunks=[TextChunk(text="some text", index=0)]) + + result: Neo4jGraph = await extractor.run(chunks, schema=schema) + + assert len(result.nodes) == 2 + assert len(result.relationships) == 1 + assert result.relationships[0].start_node_id.split(":")[1] == "1" + assert result.relationships[0].end_node_id.split(":")[1] == "2" + + def test_fix_invalid_json_empty_result() -> None: json_string = "invalid json" From cf9fe8651aa7faf2786596319d1725da8b403bf7 Mon Sep 17 00:00:00 2001 From: nathaliecharbel Date: Wed, 12 Mar 2025 10:03:36 +0100 Subject: [PATCH 10/10] Adapt SimpleKGPipelineConfig --- .../pipeline/config/template_pipeline/simple_kg_builder.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py index 14ee112ab..2e45a2db1 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py @@ -37,7 +37,10 @@ from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import ( FixedSizeSplitter, ) -from neo4j_graphrag.experimental.components.types import LexicalGraphConfig +from neo4j_graphrag.experimental.components.types import ( + LexicalGraphConfig, + SchemaEnforcementMode +) from neo4j_graphrag.experimental.pipeline.config.object_config import ComponentType from neo4j_graphrag.experimental.pipeline.config.template_pipeline.base import ( TemplatePipelineConfig, @@ -71,6 +74,7 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig): entities: Sequence[EntityInputType] = [] relations: Sequence[RelationInputType] = [] potential_schema: Optional[list[tuple[str, str, str]]] = None + enforce_schema: SchemaEnforcementMode = SchemaEnforcementMode.NONE on_error: OnError = OnError.IGNORE prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate() perform_entity_resolution: bool = True @@ -124,6 +128,7 @@ def _get_extractor(self) -> EntityRelationExtractor: return LLMEntityRelationExtractor( llm=self.get_default_llm(), prompt_template=self.prompt_template, + enforce_schema=self.enforce_schema, on_error=self.on_error, )