diff --git a/CHANGELOG.md b/CHANGELOG.md index 65962d2b6..acd1919f9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ ### Added - Added the `run_with_context` method to `Component`. This method includes a `context_` parameter, which provides information about the pipeline from which the component is executed (e.g., the `run_id`). It also enables the component to send events to the pipeline's callback function. +- Exposed `schema_builder`, `chunk_embedder`, `extractor` and `resolver` in the `SimpleKGPipeline` constructor so that they can be customized. ## 1.6.0 diff --git a/docs/source/api.rst b/docs/source/api.rst index 2ca19d9b2..e70614139 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -102,6 +102,14 @@ Neo4jWriter .. autoclass:: neo4j_graphrag.experimental.components.kg_writer.Neo4jWriter :members: run + +EntityResolver +============== + +.. autoclass:: neo4j_graphrag.experimental.components.resolver.EntityResolver + :members: run + + SinglePropertyExactMatchResolver ================================ diff --git a/docs/source/user_guide_kg_builder.rst b/docs/source/user_guide_kg_builder.rst index 20a7db63f..49b178234 100644 --- a/docs/source/user_guide_kg_builder.rst +++ b/docs/source/user_guide_kg_builder.rst @@ -178,9 +178,13 @@ For advanced customization or when using a custom implementation, you can pass instances of specific components to the `SimpleKGPipeline`. 
The components that can customized at the moment are: -- `text_splitter`: must be an instance of :ref:`TextSplitter` - `pdf_loader`: must be an instance of :ref:`PdfLoader` +- `schema_builder`: must be an instance of :ref:`SchemaBuilder` +- `text_splitter`: must be an instance of :ref:`TextSplitter` +- `chunk_embedder`: must be an instance of :ref:`TextChunkEmbedder` +- `extractor`: must be an instance of :ref:`EntityRelationExtractor` - `kg_writer`: must be an instance of :ref:`KGWriter` +- `resolver`: must be an instance of :ref:`EntityResolver` For instance, the following code can be used to customize the chunk size and chunk overlap in the text splitter component: @@ -200,6 +204,24 @@ chunk overlap in the text splitter component: ) +.. warning:: + + When providing a custom component, all other related parameters in the SimpleKGPipeline constructor are ignored. For instance, in the following example: + + .. code:: python + + kg_builder = SimpleKGPipeline( + # ... + kg_writer=Neo4jWriter(driver, neo4j_database="db_1"), + neo4j_database="db_2", + # ... + ) + + + The graph will be saved to the **db_1** database. + + + Using a Config file =================== diff --git a/src/neo4j_graphrag/experimental/pipeline/config/object_config.py b/src/neo4j_graphrag/experimental/pipeline/config/object_config.py index 95d69888d..792af7440 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/object_config.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/object_config.py @@ -244,7 +244,10 @@ def parse(self, resolved_data: dict[str, Any] | None = None) -> Embedder: return self.root.parse(resolved_data) -class ComponentConfig(ObjectConfig[Component]): +ComponentGeneric = TypeVar("ComponentGeneric") + + +class ComponentConfig(ObjectConfig[ComponentGeneric], Generic[ComponentGeneric]): """A config model for all components. 
In addition to the object config, components can have pre-defined parameters @@ -256,22 +259,27 @@ class ComponentConfig(ObjectConfig[Component]): DEFAULT_MODULE = "neo4j_graphrag.experimental.components" INTERFACE = Component + model_config = ConfigDict(arbitrary_types_allowed=True) + def get_run_params(self, resolved_data: dict[str, Any]) -> dict[str, Any]: self._global_data = resolved_data return self.resolve_params(self.run_params_) -class ComponentType(RootModel): # type: ignore[type-arg] - root: Union[Component, ComponentConfig] +class ComponentType( + RootModel[Union[ComponentGeneric, ComponentConfig[ComponentGeneric]]], + Generic[ComponentGeneric], +): + root: Union[ComponentGeneric, ComponentConfig[ComponentGeneric]] model_config = ConfigDict(arbitrary_types_allowed=True) - def parse(self, resolved_data: dict[str, Any] | None = None) -> Component: - if isinstance(self.root, Component): - return self.root - return self.root.parse(resolved_data) + def parse(self, resolved_data: dict[str, Any] | None = None) -> ComponentGeneric: + if isinstance(self.root, ComponentConfig): + return self.root.parse(resolved_data) + return self.root def get_run_params(self, resolved_data: dict[str, Any]) -> dict[str, Any]: - if isinstance(self.root, Component): - return {} - return self.root.get_run_params(resolved_data) + if isinstance(self.root, ComponentConfig): + return self.root.get_run_params(resolved_data) + return {} diff --git a/src/neo4j_graphrag/experimental/pipeline/config/pipeline_config.py b/src/neo4j_graphrag/experimental/pipeline/config/pipeline_config.py index 92f9968f1..a75598743 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/pipeline_config.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/pipeline_config.py @@ -20,6 +20,7 @@ from pydantic import field_validator from neo4j_graphrag.embeddings import Embedder +from neo4j_graphrag.experimental.pipeline import Component from neo4j_graphrag.experimental.pipeline.config.base import 
AbstractConfig from neo4j_graphrag.experimental.pipeline.config.object_config import ( ComponentType, @@ -83,7 +84,7 @@ def validate_embedders( return embedders def _resolve_component_definition( - self, name: str, config: ComponentType + self, name: str, config: ComponentType[Component] ) -> ComponentDefinition: component = config.parse(self._global_data) if hasattr(config.root, "run_params_"): @@ -188,7 +189,7 @@ class PipelineConfig(AbstractPipelineConfig): """Configuration class for raw pipelines. Config must contain all components and connections.""" - component_config: dict[str, ComponentType] + component_config: dict[str, ComponentType[Component]] connection_config: list[ConnectionDefinition] template_: Literal[PipelineType.NONE] = PipelineType.NONE diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py index 306c4eb32..ac34c2403 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py @@ -23,7 +23,7 @@ OnError, ) from neo4j_graphrag.experimental.components.kg_writer import KGWriter, Neo4jWriter -from neo4j_graphrag.experimental.components.pdf_loader import PdfLoader +from neo4j_graphrag.experimental.components.pdf_loader import PdfLoader, DataLoader from neo4j_graphrag.experimental.components.resolver import ( EntityResolver, SinglePropertyExactMatchResolver, @@ -81,17 +81,21 @@ class SimpleKGPipelineConfig(TemplatePipelineConfig): lexical_graph_config: Optional[LexicalGraphConfig] = None neo4j_database: Optional[str] = None - pdf_loader: Optional[ComponentType] = None - kg_writer: Optional[ComponentType] = None - text_splitter: Optional[ComponentType] = None + pdf_loader: Optional[ComponentType[DataLoader]] = None + schema_builder: Optional[ComponentType[SchemaBuilder]] = None + 
text_splitter: Optional[ComponentType[TextSplitter]] = None + chunk_embedder: Optional[ComponentType[TextChunkEmbedder]] = None + extractor: Optional[ComponentType[EntityRelationExtractor]] = None + kg_writer: Optional[ComponentType[KGWriter]] = None + resolver: Optional[ComponentType[EntityResolver]] = None model_config = ConfigDict(arbitrary_types_allowed=True) - def _get_pdf_loader(self) -> Optional[PdfLoader]: + def _get_pdf_loader(self) -> Optional[DataLoader]: if not self.from_pdf: return None if self.pdf_loader: - return self.pdf_loader.parse(self._global_data) # type: ignore + return self.pdf_loader.parse(self._global_data) return PdfLoader() def _get_run_params_for_pdf_loader(self) -> dict[str, Any]: @@ -103,7 +107,7 @@ def _get_run_params_for_pdf_loader(self) -> dict[str, Any]: def _get_splitter(self) -> TextSplitter: if self.text_splitter: - return self.text_splitter.parse(self._global_data) # type: ignore + return self.text_splitter.parse(self._global_data) return FixedSizeSplitter() def _get_run_params_for_splitter(self) -> dict[str, Any]: @@ -112,9 +116,13 @@ def _get_run_params_for_splitter(self) -> dict[str, Any]: return {} def _get_chunk_embedder(self) -> TextChunkEmbedder: + if self.chunk_embedder: + return self.chunk_embedder.parse(self._global_data) return TextChunkEmbedder(embedder=self.get_default_embedder()) def _get_schema(self) -> SchemaBuilder: + if self.schema_builder: + return self.schema_builder.parse(self._global_data) return SchemaBuilder() def _get_run_params_for_schema(self) -> dict[str, Any]: @@ -125,6 +133,8 @@ def _get_run_params_for_schema(self) -> dict[str, Any]: } def _get_extractor(self) -> EntityRelationExtractor: + if self.extractor: + return self.extractor.parse(self._global_data) return LLMEntityRelationExtractor( llm=self.get_default_llm(), prompt_template=self.prompt_template, @@ -134,7 +144,7 @@ def _get_extractor(self) -> EntityRelationExtractor: def _get_writer(self) -> KGWriter: if self.kg_writer: - return 
self.kg_writer.parse(self._global_data) # type: ignore + return self.kg_writer.parse(self._global_data) return Neo4jWriter( driver=self.get_default_neo4j_driver(), neo4j_database=self.neo4j_database, @@ -148,6 +158,8 @@ def _get_run_params_for_writer(self) -> dict[str, Any]: def _get_resolver(self) -> Optional[EntityResolver]: if not self.perform_entity_resolution: return None + if self.resolver: + return self.resolver.parse(self._global_data) return SinglePropertyExactMatchResolver( driver=self.get_default_neo4j_driver(), neo4j_database=self.neo4j_database, diff --git a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py index 6a809b766..a75f5314e 100644 --- a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py @@ -21,9 +21,15 @@ from pydantic import ValidationError from neo4j_graphrag.embeddings import Embedder -from neo4j_graphrag.experimental.components.entity_relation_extractor import OnError +from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder +from neo4j_graphrag.experimental.components.entity_relation_extractor import ( + OnError, + EntityRelationExtractor, +) from neo4j_graphrag.experimental.components.kg_writer import KGWriter from neo4j_graphrag.experimental.components.pdf_loader import DataLoader +from neo4j_graphrag.experimental.components.resolver import EntityResolver +from neo4j_graphrag.experimental.components.schema import SchemaBuilder from neo4j_graphrag.experimental.components.text_splitters.base import TextSplitter from neo4j_graphrag.experimental.components.types import LexicalGraphConfig from neo4j_graphrag.experimental.pipeline.config.object_config import ComponentType @@ -82,9 +88,13 @@ def __init__( relations: Optional[Sequence[RelationInputType]] = None, potential_schema: Optional[List[tuple[str, str, str]]] = None, from_pdf: bool = True, - text_splitter: Optional[TextSplitter] = None, 
pdf_loader: Optional[DataLoader] = None, + schema_builder: Optional[SchemaBuilder] = None, + text_splitter: Optional[TextSplitter] = None, + chunk_embedder: Optional[TextChunkEmbedder] = None, + extractor: Optional[EntityRelationExtractor] = None, kg_writer: Optional[KGWriter] = None, + resolver: Optional[EntityResolver] = None, on_error: str = "IGNORE", prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate(), perform_entity_resolution: bool = True, @@ -102,8 +112,16 @@ def __init__( potential_schema=potential_schema, from_pdf=from_pdf, pdf_loader=ComponentType(pdf_loader) if pdf_loader else None, - kg_writer=ComponentType(kg_writer) if kg_writer else None, + schema_builder=ComponentType(schema_builder) + if schema_builder + else None, text_splitter=ComponentType(text_splitter) if text_splitter else None, + chunk_embedder=ComponentType(chunk_embedder) + if chunk_embedder + else None, + extractor=ComponentType(extractor) if extractor else None, + kg_writer=ComponentType(kg_writer) if kg_writer else None, + resolver=ComponentType(resolver) if resolver else None, on_error=OnError(on_error), prompt_template=prompt_template, perform_entity_resolution=perform_entity_resolution, diff --git a/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py b/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_pipeline_config.py similarity index 78% rename from tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py rename to tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_pipeline_config.py index ef0365849..3cfc8e113 100644 --- a/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py +++ b/tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_pipeline_config.py @@ -24,6 +24,9 @@ ) from neo4j_graphrag.experimental.components.kg_writer import Neo4jWriter from neo4j_graphrag.experimental.components.pdf_loader import 
PdfLoader +from neo4j_graphrag.experimental.components.resolver import ( + SinglePropertyExactMatchResolver, +) from neo4j_graphrag.experimental.components.schema import ( SchemaBuilder, SchemaEntity, @@ -71,7 +74,7 @@ def test_simple_kg_pipeline_config_pdf_loader_class_overwrite_but_from_pdf_is_fa def test_simple_kg_pipeline_config_pdf_loader_from_pdf_is_true_class_overwrite_from_config( mock_component_parse: Mock, ) -> None: - my_pdf_loader_config = ComponentConfig( + my_pdf_loader_config: ComponentConfig[PdfLoader] = ComponentConfig( class_="", ) my_pdf_loader = PdfLoader() @@ -92,7 +95,7 @@ def test_simple_kg_pipeline_config_text_splitter() -> None: def test_simple_kg_pipeline_config_text_splitter_overwrite( mock_component_parse: Mock, ) -> None: - my_text_splitter_config = ComponentConfig( + my_text_splitter_config: ComponentConfig[FixedSizeSplitter] = ComponentConfig( class_="", ) my_text_splitter = FixedSizeSplitter() @@ -152,6 +155,28 @@ def test_simple_kg_pipeline_config_extractor(mock_llm: Mock, llm: LLMInterface) assert extractor.prompt_template.template == "my template {text}" +@patch( + "neo4j_graphrag.experimental.pipeline.config.template_pipeline.simple_kg_builder.SimpleKGPipelineConfig.get_default_llm" +) +@patch("neo4j_graphrag.experimental.pipeline.config.object_config.ComponentType.parse") +def test_simple_kg_pipeline_config_extractor_overwrite( + mock_component_parse: Mock, mock_llm: Mock +) -> None: + my_extractor = LLMEntityRelationExtractor(llm=mock_llm) + mock_component_parse.return_value = my_extractor + config = SimpleKGPipelineConfig( + on_error="IGNORE", # type: ignore + prompt_template=ERExtractionTemplate(template="my template {text}"), + extractor={}, # type: ignore + ) + extractor = config._get_extractor() + assert isinstance(extractor, LLMEntityRelationExtractor) + assert extractor.llm == mock_llm + # default values are not overwritten by the parameters: + assert extractor.on_error == OnError.RAISE + assert 
extractor.prompt_template.template == ERExtractionTemplate.DEFAULT_TEMPLATE + + @patch( "neo4j_graphrag.experimental.components.kg_writer.get_version", return_value=((5, 23, 0), False, False), @@ -184,13 +209,10 @@ def test_simple_kg_pipeline_config_writer_overwrite( _: Mock, driver: neo4j.Driver, ) -> None: - my_writer_config = ComponentConfig( - class_="", - ) my_writer = Neo4jWriter(driver, neo4j_database="my_db") mock_component_parse.return_value = my_writer config = SimpleKGPipelineConfig( - kg_writer=my_writer_config, # type: ignore + kg_writer={}, # type: ignore neo4j_database="my_other_db", ) writer: Neo4jWriter = config._get_writer() # type: ignore @@ -199,6 +221,42 @@ def test_simple_kg_pipeline_config_writer_overwrite( assert writer.neo4j_database == "my_db" +@patch( + "neo4j_graphrag.experimental.pipeline.config.template_pipeline.simple_kg_builder.SimpleKGPipelineConfig.get_default_llm" +) +@patch("neo4j_graphrag.experimental.pipeline.config.object_config.ComponentType.parse") +def test_simple_kg_pipeline_config_resolver_overwrite( + mock_component_parse: Mock, driver: neo4j.Driver +) -> None: + my_resolver = SinglePropertyExactMatchResolver(driver, resolve_property="full_name") + mock_component_parse.return_value = my_resolver + config = SimpleKGPipelineConfig( + perform_entity_resolution=True, + resolver={}, # type: ignore + ) + resolver = config._get_resolver() + assert isinstance(resolver, SinglePropertyExactMatchResolver) + assert resolver.driver == driver + assert resolver.resolve_property == "full_name" + + +@patch( + "neo4j_graphrag.experimental.pipeline.config.template_pipeline.simple_kg_builder.SimpleKGPipelineConfig.get_default_llm" +) +@patch("neo4j_graphrag.experimental.pipeline.config.object_config.ComponentType.parse") +def test_simple_kg_pipeline_config_resolver_overwrite_but_disabled( + mock_component_parse: Mock, driver: neo4j.Driver +) -> None: + my_resolver = SinglePropertyExactMatchResolver(driver, resolve_property="full_name") + 
mock_component_parse.return_value = my_resolver + config = SimpleKGPipelineConfig( + perform_entity_resolution=False, + resolver={}, # type: ignore + ) + resolver = config._get_resolver() + assert resolver is None + + def test_simple_kg_pipeline_config_connections_from_pdf() -> None: config = SimpleKGPipelineConfig( from_pdf=True, @@ -234,7 +292,7 @@ def test_simple_kg_pipeline_config_connections_from_text() -> None: assert (actual.start, actual.end) == expected -def test_simple_kg_pipeline_config_connections_with_er() -> None: +def test_simple_kg_pipeline_config_connections_with_entity_resolution() -> None: config = SimpleKGPipelineConfig( from_pdf=True, perform_entity_resolution=True, diff --git a/tests/unit/experimental/pipeline/config/test_pipeline_config.py b/tests/unit/experimental/pipeline/config/test_pipeline_config.py index 7ec24fc3b..6867a15cc 100644 --- a/tests/unit/experimental/pipeline/config/test_pipeline_config.py +++ b/tests/unit/experimental/pipeline/config/test_pipeline_config.py @@ -345,7 +345,7 @@ def test_abstract_pipeline_config_resolve_component_definition_no_run_params( ) -> None: mock_component_parse.return_value = component config = AbstractPipelineConfig() - component_type = ComponentType(component) + component_type: ComponentType[Component] = ComponentType(component) component_definition = config._resolve_component_definition("name", component_type) assert isinstance(component_definition, ComponentDefinition) mock_component_parse.assert_called_once_with({}) @@ -366,7 +366,7 @@ def test_abstract_pipeline_config_resolve_component_definition_with_run_params( mock_component_parse.return_value = component mock_resolve_params.return_value = {"param": "resolver param result"} config = AbstractPipelineConfig() - component_type = ComponentType( + component_type: ComponentType[Component] = ComponentType( ComponentConfig(class_="", params_={}, run_params_={"param1": "value1"}) ) component_definition = config._resolve_component_definition("name", 
component_type) diff --git a/tests/unit/experimental/pipeline/test_kg_builder.py b/tests/unit/experimental/pipeline/test_simple_kg_pipeline.py similarity index 94% rename from tests/unit/experimental/pipeline/test_kg_builder.py rename to tests/unit/experimental/pipeline/test_simple_kg_pipeline.py index d95abeed3..87ae59308 100644 --- a/tests/unit/experimental/pipeline/test_kg_builder.py +++ b/tests/unit/experimental/pipeline/test_simple_kg_pipeline.py @@ -34,7 +34,7 @@ return_value=((5, 23, 0), False, False), ) @pytest.mark.asyncio -async def test_knowledge_graph_builder_document_info_with_file(_: Mock) -> None: +async def test_simple_kg_pipeline_document_info_with_file(_: Mock) -> None: llm = MagicMock(spec=LLMInterface) driver = MagicMock(spec=neo4j.Driver) embedder = MagicMock(spec=Embedder) @@ -66,7 +66,7 @@ async def test_knowledge_graph_builder_document_info_with_file(_: Mock) -> None: return_value=((5, 23, 0), False, False), ) @pytest.mark.asyncio -async def test_knowledge_graph_builder_document_info_with_text(_: Mock) -> None: +async def test_simple_kg_pipeline_document_info_with_text(_: Mock) -> None: llm = MagicMock(spec=LLMInterface) driver = MagicMock(spec=neo4j.Driver) embedder = MagicMock(spec=Embedder) @@ -97,7 +97,7 @@ async def test_knowledge_graph_builder_document_info_with_text(_: Mock) -> None: return_value=((5, 23, 0), False, False), ) @pytest.mark.asyncio -async def test_knowledge_graph_builder_with_entities_and_file(_: Mock) -> None: +async def test_simple_kg_pipeline_with_entities_and_file(_: Mock) -> None: llm = MagicMock(spec=LLMInterface) driver = MagicMock(spec=neo4j.Driver) embedder = MagicMock(spec=Embedder) @@ -156,7 +156,7 @@ def test_simple_kg_pipeline_on_error_invalid_value() -> None: return_value=((5, 23, 0), False, False), ) @pytest.mark.asyncio -async def test_knowledge_graph_builder_with_lexical_graph_config(_: Mock) -> None: +async def test_simple_kg_pipeline_with_lexical_graph_config(_: Mock) -> None: llm = 
MagicMock(spec=LLMInterface) driver = MagicMock(spec=neo4j.Driver) embedder = MagicMock(spec=Embedder)