diff --git a/CHANGELOG.md b/CHANGELOG.md index 65962d2b6..a58d76e6c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,9 @@ - Added the `run_with_context` method to `Component`. This method includes a `context_` parameter, which provides information about the pipeline from which the component is executed (e.g., the `run_id`). It also enables the component to send events to the pipeline's callback function. +### Fixed + +- Added `enforce_schema` parameter to `SimpleKGPipeline` for optional schema enforcement. ## 1.6.0 diff --git a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py index 6a809b766..ec5ed5218 100644 --- a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py @@ -25,7 +25,10 @@ from neo4j_graphrag.experimental.components.kg_writer import KGWriter from neo4j_graphrag.experimental.components.pdf_loader import DataLoader from neo4j_graphrag.experimental.components.text_splitters.base import TextSplitter -from neo4j_graphrag.experimental.components.types import LexicalGraphConfig +from neo4j_graphrag.experimental.components.types import ( + LexicalGraphConfig, + SchemaEnforcementMode, +) from neo4j_graphrag.experimental.pipeline.config.object_config import ComponentType from neo4j_graphrag.experimental.pipeline.config.runner import PipelineRunner from neo4j_graphrag.experimental.pipeline.config.template_pipeline import ( @@ -61,6 +64,7 @@ class SimpleKGPipeline: - dict: following the SchemaRelation schema, ie with label, description and properties keys potential_schema (Optional[List[tuple]]): A list of potential schema relationships. + enforce_schema (str): Validation of the extracted entities/rels against the provided schema. Defaults to "NONE", where schema enforcement will be ignored even if the schema is provided. Possible values "None" or "STRICT". from_pdf (bool): Determines whether to include the PdfLoader in the pipeline. If True, expects `file_path` input in `run` methods. If False, expects `text` input in `run` methods. @@ -81,6 +85,7 @@ def __init__( entities: Optional[Sequence[EntityInputType]] = None, relations: Optional[Sequence[RelationInputType]] = None, potential_schema: Optional[List[tuple[str, str, str]]] = None, + enforce_schema: str = "NONE", from_pdf: bool = True, text_splitter: Optional[TextSplitter] = None, pdf_loader: Optional[DataLoader] = None, @@ -100,6 +105,7 @@ def __init__( entities=entities or [], relations=relations or [], potential_schema=potential_schema, + enforce_schema=SchemaEnforcementMode(enforce_schema), from_pdf=from_pdf, pdf_loader=ComponentType(pdf_loader) if pdf_loader else None, kg_writer=ComponentType(kg_writer) if kg_writer else None, diff --git a/tests/unit/experimental/pipeline/test_kg_builder.py b/tests/unit/experimental/pipeline/test_kg_builder.py index d95abeed3..a6d3d4c42 100644 --- a/tests/unit/experimental/pipeline/test_kg_builder.py +++ b/tests/unit/experimental/pipeline/test_kg_builder.py @@ -151,6 +151,20 @@ def test_simple_kg_pipeline_on_error_invalid_value() -> None: ) +def test_simple_kg_pipeline_enforce_schema_invalid_value() -> None: + llm = MagicMock(spec=LLMInterface) + driver = MagicMock(spec=neo4j.Driver) + embedder = MagicMock(spec=Embedder) + + with pytest.raises(PipelineDefinitionError): + SimpleKGPipeline( + llm=llm, + driver=driver, + embedder=embedder, + enforce_schema="INVALID_VALUE", + ) + + @mock.patch( "neo4j_graphrag.experimental.components.kg_writer.get_version", return_value=((5, 23, 0), False, False),