From cdcd2fb725e0e5998387d9165dfefacde4920601 Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 11 Mar 2025 14:31:12 +0100 Subject: [PATCH 01/12] Expose context from the orchestrator to the components --- .../simple_kg_builder_from_text.py | 2 +- .../pipeline_with_component_notifications.py | 108 ++++++++++++++++ .../pipeline/pipeline_with_notifications.py | 2 +- .../experimental/components/schema.py | 2 +- .../experimental/pipeline/component.py | 64 +++++---- .../pipeline/config/pipeline_config.py | 2 +- .../experimental/pipeline/config/runner.py | 2 +- .../pipeline/config/template_pipeline/base.py | 2 +- .../template_pipeline/simple_kg_builder.py | 4 +- .../experimental/pipeline/kg_builder.py | 2 +- .../experimental/pipeline/notification.py | 79 ++++++++++-- .../experimental/pipeline/orchestrator.py | 20 ++- .../experimental/pipeline/pipeline.py | 22 +++- .../experimental/pipeline/types.py | 122 ------------------ .../experimental/pipeline/types/__init__.py | 0 .../experimental/pipeline/types/context.py | 37 ++++++ .../pipeline/types/definitions.py | 46 +++++++ .../pipeline/types/orchestration.py | 46 +++++++ .../experimental/pipeline/types/schema.py | 26 ++++ .../pipeline/config/test_pipeline_config.py | 2 +- .../pipeline/config/test_runner.py | 2 +- .../pipeline/test_orchestrator.py | 2 +- .../experimental/pipeline/test_pipeline.py | 16 +-- 23 files changed, 422 insertions(+), 188 deletions(-) create mode 100644 examples/customize/build_graph/pipeline/pipeline_with_component_notifications.py delete mode 100644 src/neo4j_graphrag/experimental/pipeline/types.py create mode 100644 src/neo4j_graphrag/experimental/pipeline/types/__init__.py create mode 100644 src/neo4j_graphrag/experimental/pipeline/types/context.py create mode 100644 src/neo4j_graphrag/experimental/pipeline/types/definitions.py create mode 100644 src/neo4j_graphrag/experimental/pipeline/types/orchestration.py create mode 100644 src/neo4j_graphrag/experimental/pipeline/types/schema.py diff --git a/examples/build_graph/simple_kg_builder_from_text.py b/examples/build_graph/simple_kg_builder_from_text.py index 29a5cfc5b..79b8c8791 100644 --- a/examples/build_graph/simple_kg_builder_from_text.py +++ b/examples/build_graph/simple_kg_builder_from_text.py @@ -14,7 +14,7 @@ from neo4j_graphrag.embeddings import OpenAIEmbeddings from neo4j_graphrag.experimental.pipeline.kg_builder import SimpleKGPipeline from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult -from neo4j_graphrag.experimental.pipeline.types import ( +from neo4j_graphrag.experimental.pipeline.types.schema import ( EntityInputType, RelationInputType, ) diff --git a/examples/customize/build_graph/pipeline/pipeline_with_component_notifications.py b/examples/customize/build_graph/pipeline/pipeline_with_component_notifications.py new file mode 100644 index 000000000..8b75d45fc --- /dev/null +++ b/examples/customize/build_graph/pipeline/pipeline_with_component_notifications.py @@ -0,0 +1,108 @@ +"""This example demonstrates how to use event callback to receive notifications +about the component progress. 
+""" + +from __future__ import annotations + +import asyncio +import logging +from typing import Any + +from neo4j_graphrag.experimental.pipeline import Pipeline, Component, DataModel +from neo4j_graphrag.experimental.pipeline.notification import Event, EventType +from neo4j_graphrag.experimental.pipeline.types.context import RunContext + +logger = logging.getLogger(__name__) +logging.basicConfig() +logger.setLevel(logging.INFO) + + +class BatchComponentResult(DataModel): + result: list[int] + + +class MultiplicationComponent(Component): + def __init__(self, f: int) -> None: + self.f = f + + async def run(self, numbers: list[int]) -> BatchComponentResult: + return BatchComponentResult(result=[]) + + async def multiply_number( + self, context_: RunContext, number: int, + ) -> int: + await context_.notify( + message=f"Processing number {number}", + data={"number_processed": number}, + ) + return self.f * number + + async def run_with_context( + self, + context_: RunContext, + numbers: list[int], + **kwargs: Any, + ) -> BatchComponentResult: + result = await asyncio.gather( + *[ + self.multiply_number( + context_, + number, + ) + for number in numbers + ] + ) + return BatchComponentResult(result=result) + + +async def event_handler(event: Event) -> None: + """Function can do anything about the event, + here we're just logging it if it's a pipeline-level event. + """ + if event.event_type == EventType.TASK_PROGRESS: + logger.warning(event) + else: + logger.info(event) + + +async def main() -> None: + """ """ + pipe = Pipeline( + callback=event_handler, + ) + # define the components + pipe.add_component( + MultiplicationComponent(f=2), + "multiply_by_2", + ) + pipe.add_component( + MultiplicationComponent(f=10), + "multiply_by_10", + ) + # define the execution order of component + # and how the output of previous components must be used + pipe.connect( + "multiply_by_2", + "multiply_by_10", + input_config={"numbers": "multiply_by_2.result"}, + ) + # user input: + pipe_inputs_1 = { + "multiply_by_2": { + "numbers": [1, 2, 5, 4], + }, + } + pipe_inputs_2 = { + "multiply_by_2": { + "numbers": [3, 10, 1], + } + } + # run the pipeline + await asyncio.gather( + pipe.run(pipe_inputs_1), + pipe.run(pipe_inputs_2), + ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/customize/build_graph/pipeline/pipeline_with_notifications.py b/examples/customize/build_graph/pipeline/pipeline_with_notifications.py index db30ddd61..88939d2ab 100644 --- a/examples/customize/build_graph/pipeline/pipeline_with_notifications.py +++ b/examples/customize/build_graph/pipeline/pipeline_with_notifications.py @@ -15,7 +15,7 @@ ) from neo4j_graphrag.experimental.pipeline import Pipeline from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult -from neo4j_graphrag.experimental.pipeline.types import Event +from neo4j_graphrag.experimental.pipeline.notification import Event logger = logging.getLogger(__name__) logging.basicConfig() diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index 0f1185457..96d7d466b 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -21,7 +21,7 @@ from neo4j_graphrag.exceptions import SchemaValidationError from neo4j_graphrag.experimental.pipeline.component import Component, DataModel -from neo4j_graphrag.experimental.pipeline.types import ( +from neo4j_graphrag.experimental.pipeline.types.schema import ( 
EntityInputType, RelationInputType, ) diff --git a/src/neo4j_graphrag/experimental/pipeline/component.py b/src/neo4j_graphrag/experimental/pipeline/component.py index 90877738f..ff3953c3a 100644 --- a/src/neo4j_graphrag/experimental/pipeline/component.py +++ b/src/neo4j_graphrag/experimental/pipeline/component.py @@ -20,6 +20,7 @@ from pydantic import BaseModel +from neo4j_graphrag.experimental.pipeline.types.context import RunContext from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError from neo4j_graphrag.utils.validation import issubclass_safe @@ -36,34 +37,39 @@ def __new__( ) -> type: # extract required inputs and outputs from the run method signature run_method = attrs.get("run") - if run_method is not None: - sig = inspect.signature(run_method) - attrs["component_inputs"] = { - param.name: { - "has_default": param.default != inspect.Parameter.empty, - "annotation": param.annotation, - } - for param in sig.parameters.values() - if param.name not in ("self", "kwargs") + run_context_method = attrs.get("run_with_context") + run = run_context_method or run_method + if run is None: + raise RuntimeError( + f"Either 'run' or 'run_with_context' must be implemented in component: '{name}'" + ) + sig = inspect.signature(run) + attrs["component_inputs"] = { + param.name: { + "has_default": param.default != inspect.Parameter.empty, + "annotation": param.annotation, } - # extract returned fields from the run method return type hint - return_model = get_type_hints(run_method).get("return") - if return_model is None: - raise PipelineDefinitionError( - f"The run method return type must be annotated in {name}" - ) - # the type hint must be a subclass of DataModel - if not issubclass_safe(return_model, DataModel): - raise PipelineDefinitionError( - f"The run method must return a subclass of DataModel in {name}" - ) - attrs["component_outputs"] = { - f: { - "has_default": field.is_required(), - "annotation": field.annotation, - } - for f, field in return_model.model_fields.items() + for param in sig.parameters.values() + if param.name not in ("self", "kwargs", "context_") + } + # extract returned fields from the run method return type hint + return_model = get_type_hints(run).get("return") + if return_model is None: + raise PipelineDefinitionError( + f"The run method return type must be annotated in {name}" + ) + # the type hint must be a subclass of DataModel + if not issubclass_safe(return_model, DataModel): + raise PipelineDefinitionError( + f"The run method must return a subclass of DataModel in {name}" + ) + attrs["component_outputs"] = { + f: { + "has_default": field.is_required(), + "annotation": field.annotation, } + for f, field in return_model.model_fields.items() + } return type.__new__(meta, name, bases, attrs) @@ -81,3 +87,9 @@ class Component(abc.ABC, metaclass=ComponentMeta): @abc.abstractmethod async def run(self, *args: Any, **kwargs: Any) -> DataModel: pass + + async def run_with_context( + self, context_: RunContext, *args: Any, **kwargs: Any + ) -> DataModel: + # default behavior to prevent a breaking change + return await self.run(*args, **kwargs) diff --git a/src/neo4j_graphrag/experimental/pipeline/config/pipeline_config.py b/src/neo4j_graphrag/experimental/pipeline/config/pipeline_config.py index 2f393f6e9..92f9968f1 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/pipeline_config.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/pipeline_config.py @@ -31,7 +31,7 @@ ParamConfig, ) from 
neo4j_graphrag.experimental.pipeline.config.types import PipelineType -from neo4j_graphrag.experimental.pipeline.types import ( +from neo4j_graphrag.experimental.pipeline.types.definitions import ( ComponentDefinition, ConnectionDefinition, PipelineDefinition, diff --git a/src/neo4j_graphrag/experimental/pipeline/config/runner.py b/src/neo4j_graphrag/experimental/pipeline/config/runner.py index c1bef9fee..82e8175a9 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/runner.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/runner.py @@ -47,7 +47,7 @@ ) from neo4j_graphrag.experimental.pipeline.config.types import PipelineType from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult -from neo4j_graphrag.experimental.pipeline.types import PipelineDefinition +from neo4j_graphrag.experimental.pipeline.types.definitions import PipelineDefinition from neo4j_graphrag.utils.logging import prettify logger = logging.getLogger(__name__) diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/base.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/base.py index 69fbc7511..fa0c5f632 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/base.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/base.py @@ -18,7 +18,7 @@ from neo4j_graphrag.experimental.pipeline.config.pipeline_config import ( AbstractPipelineConfig, ) -from neo4j_graphrag.experimental.pipeline.types import ComponentDefinition +from neo4j_graphrag.experimental.pipeline.types.definitions import ComponentDefinition logger = logging.getLogger(__name__) diff --git a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py index 02266eded..306c4eb32 100644 --- a/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py @@ -47,8 +47,8 @@ ) from neo4j_graphrag.experimental.pipeline.config.types import PipelineType from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError -from neo4j_graphrag.experimental.pipeline.types import ( - ConnectionDefinition, +from neo4j_graphrag.experimental.pipeline.types.definitions import ConnectionDefinition +from neo4j_graphrag.experimental.pipeline.types.schema import ( EntityInputType, RelationInputType, ) diff --git a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py index 8a3e15041..6a809b766 100644 --- a/src/neo4j_graphrag/experimental/pipeline/kg_builder.py +++ b/src/neo4j_graphrag/experimental/pipeline/kg_builder.py @@ -33,7 +33,7 @@ ) from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError from neo4j_graphrag.experimental.pipeline.pipeline import PipelineResult -from neo4j_graphrag.experimental.pipeline.types import ( +from neo4j_graphrag.experimental.pipeline.types.schema import ( EntityInputType, RelationInputType, ) diff --git a/src/neo4j_graphrag/experimental/pipeline/notification.py b/src/neo4j_graphrag/experimental/pipeline/notification.py index f0693cabd..e9cb63cc6 100644 --- a/src/neo4j_graphrag/experimental/pipeline/notification.py +++ b/src/neo4j_graphrag/experimental/pipeline/notification.py @@ -14,16 +14,61 @@ # limitations under the License. 
from __future__ import annotations -from typing import Any, Optional +import datetime +import enum +from collections.abc import Awaitable +from pydantic import BaseModel, Field -from neo4j_graphrag.experimental.pipeline.types import ( - Event, - EventCallbackProtocol, - EventType, - PipelineEvent, - RunResult, - TaskEvent, -) +from typing import Any, Optional, Protocol, TYPE_CHECKING + +if TYPE_CHECKING: + from neo4j_graphrag.experimental.pipeline.types.orchestration import RunResult + + +class EventType(enum.Enum): + PIPELINE_STARTED = "PIPELINE_STARTED" + TASK_STARTED = "TASK_STARTED" + TASK_PROGRESS = "TASK_PROGRESS" + TASK_FINISHED = "TASK_FINISHED" + PIPELINE_FINISHED = "PIPELINE_FINISHED" + + @property + def is_pipeline_event(self) -> bool: + return self in [EventType.PIPELINE_STARTED, EventType.PIPELINE_FINISHED] + + @property + def is_task_event(self) -> bool: + return self in [ + EventType.TASK_STARTED, + EventType.TASK_PROGRESS, + EventType.TASK_FINISHED, + ] + + +class Event(BaseModel): + event_type: EventType + run_id: str + """Pipeline unique run_id, same as the one returned in PipelineResult after pipeline.run""" + timestamp: datetime.datetime = Field( + default_factory=lambda: datetime.datetime.now(datetime.timezone.utc) + ) + message: Optional[str] = None + """Optional information about the status""" + payload: Optional[dict[str, Any]] = None + """Input or output data depending on the type of event""" + + +class PipelineEvent(Event): + pass + + +class TaskEvent(Event): + task_name: str + """Name of the task as defined in pipeline.add_component""" + + +class EventCallbackProtocol(Protocol): + def __call__(self, event: Event) -> Awaitable[None]: ... class EventNotifier: @@ -87,3 +132,19 @@ async def notify_task_finished( else None, ) await self.notify(event) + + async def notify_task_progress( + self, + run_id: str, + task_name: str, + message: str, + data: dict[str, Any], + ) -> None: + event = TaskEvent( + event_type=EventType.TASK_PROGRESS, + run_id=run_id, + task_name=task_name, + message=message, + payload=data, + ) + await self.notify(event) diff --git a/src/neo4j_graphrag/experimental/pipeline/orchestrator.py b/src/neo4j_graphrag/experimental/pipeline/orchestrator.py index 6c218784e..fd933470a 100644 --- a/src/neo4j_graphrag/experimental/pipeline/orchestrator.py +++ b/src/neo4j_graphrag/experimental/pipeline/orchestrator.py @@ -18,20 +18,24 @@ import logging import uuid import warnings +from functools import partial from typing import TYPE_CHECKING, Any, AsyncGenerator +from neo4j_graphrag.experimental.pipeline.types.context import RunContext from neo4j_graphrag.experimental.pipeline.exceptions import ( PipelineDefinitionError, PipelineMissingDependencyError, PipelineStatusUpdateError, ) from neo4j_graphrag.experimental.pipeline.notification import EventNotifier -from neo4j_graphrag.experimental.pipeline.types import RunResult, RunStatus +from neo4j_graphrag.experimental.pipeline.types.orchestration import ( + RunResult, + RunStatus, +) if TYPE_CHECKING: from neo4j_graphrag.experimental.pipeline.pipeline import Pipeline, TaskPipelineNode - logger = logging.getLogger(__name__) @@ -48,7 +52,7 @@ class Orchestrator: (checking that all dependencies are met), and run them. 
""" - def __init__(self, pipeline: "Pipeline"): + def __init__(self, pipeline: Pipeline): self.pipeline = pipeline self.event_notifier = EventNotifier(pipeline.callback) self.run_id = str(uuid.uuid4()) @@ -74,7 +78,15 @@ async def run_task(self, task: TaskPipelineNode, data: dict[str, Any]) -> None: ) return None await self.event_notifier.notify_task_started(self.run_id, task.name, inputs) - res = await task.run(inputs) + # create the notifier function for the component, with fixed + # run_id, task_name and event type: + notifier = partial( + self.event_notifier.notify_task_progress, + run_id=self.run_id, + task_name=task.name, + ) + context = RunContext(run_id=self.run_id, task_name=task.name, notifier=notifier) + res = await task.run(context, inputs) await self.set_task_status(task.name, RunStatus.DONE) await self.event_notifier.notify_task_finished(self.run_id, task.name, res) if res: diff --git a/src/neo4j_graphrag/experimental/pipeline/pipeline.py b/src/neo4j_graphrag/experimental/pipeline/pipeline.py index 77526785f..64b74ddc6 100644 --- a/src/neo4j_graphrag/experimental/pipeline/pipeline.py +++ b/src/neo4j_graphrag/experimental/pipeline/pipeline.py @@ -40,13 +40,15 @@ PipelineNode, ) from neo4j_graphrag.experimental.pipeline.stores import InMemoryStore, ResultStore -from neo4j_graphrag.experimental.pipeline.types import ( +from neo4j_graphrag.experimental.pipeline.types.definitions import ( ComponentDefinition, ConnectionDefinition, - EventCallbackProtocol, PipelineDefinition, - RunResult, ) +from neo4j_graphrag.experimental.pipeline.types.orchestration import RunResult +from neo4j_graphrag.experimental.pipeline.types.context import RunContext +from neo4j_graphrag.experimental.pipeline.notification import EventCallbackProtocol + logger = logging.getLogger(__name__) @@ -67,7 +69,9 @@ def __init__(self, name: str, component: Component): super().__init__(name, {}) self.component = component - async def execute(self, **kwargs: Any) -> RunResult | None: + async def execute( + self, context: RunContext, inputs: dict[str, Any] + ) -> RunResult | None: """Execute the task Returns: @@ -75,17 +79,21 @@ async def execute(self, **kwargs: Any) -> RunResult | None: if the task run successfully, None if the status update was unsuccessful. """ - component_result = await self.component.run(**kwargs) + component_result = await self.component.run_with_context( + context_=context, **inputs + ) run_result = RunResult( result=component_result, ) return run_result - async def run(self, inputs: dict[str, Any]) -> RunResult | None: + async def run( + self, context: RunContext, inputs: dict[str, Any] + ) -> RunResult | None: """Main method to execute the task.""" logger.debug(f"TASK START {self.name=} input={prettify(inputs)}") start_time = default_timer() - res = await self.execute(**inputs) + res = await self.execute(context, inputs) end_time = default_timer() logger.debug( f"TASK FINISHED {self.name} in {end_time - start_time} res={prettify(res)}" diff --git a/src/neo4j_graphrag/experimental/pipeline/types.py b/src/neo4j_graphrag/experimental/pipeline/types.py deleted file mode 100644 index f4f8267c7..000000000 --- a/src/neo4j_graphrag/experimental/pipeline/types.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright (c) "Neo4j" -# Neo4j Sweden AB [https://neo4j.com] -# # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# # -# https://www.apache.org/licenses/LICENSE-2.0 -# # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from __future__ import annotations - -import datetime -import enum -from collections import defaultdict -from collections.abc import Awaitable -from typing import Any, Optional, Protocol, Union - -from pydantic import BaseModel, ConfigDict, Field - -from neo4j_graphrag.experimental.pipeline.component import Component, DataModel - - -class ComponentDefinition(BaseModel): - name: str - component: Component - run_params: dict[str, Any] = {} - - model_config = ConfigDict(arbitrary_types_allowed=True) - - -class ConnectionDefinition(BaseModel): - start: str - end: str - input_config: dict[str, str] - - -class PipelineDefinition(BaseModel): - components: list[ComponentDefinition] - connections: list[ConnectionDefinition] - - def get_run_params(self) -> defaultdict[str, dict[str, Any]]: - return defaultdict( - dict, {c.name: c.run_params for c in self.components if c.run_params} - ) - - -class RunStatus(enum.Enum): - UNKNOWN = "UNKNOWN" - RUNNING = "RUNNING" - DONE = "DONE" - - def possible_next_status(self) -> list[RunStatus]: - if self == RunStatus.UNKNOWN: - return [RunStatus.RUNNING] - if self == RunStatus.RUNNING: - return [RunStatus.DONE] - if self == RunStatus.DONE: - return [] - return [] - - -class RunResult(BaseModel): - status: RunStatus = RunStatus.DONE - result: Optional[DataModel] = None - timestamp: datetime.datetime = Field( - default_factory=lambda: datetime.datetime.now(datetime.timezone.utc) - ) - - -class EventType(enum.Enum): - PIPELINE_STARTED = "PIPELINE_STARTED" - TASK_STARTED = "TASK_STARTED" - TASK_FINISHED = "TASK_FINISHED" - PIPELINE_FINISHED = "PIPELINE_FINISHED" - - @property - def is_pipeline_event(self) -> bool: - return self in [EventType.PIPELINE_STARTED, EventType.PIPELINE_FINISHED] - - @property - def is_task_event(self) -> bool: - return self in [EventType.TASK_STARTED, EventType.TASK_FINISHED] - - -class Event(BaseModel): - event_type: EventType - run_id: str - """Pipeline unique run_id, same as the one returned in PipelineResult after pipeline.run""" - timestamp: datetime.datetime = Field( - default_factory=lambda: datetime.datetime.now(datetime.timezone.utc) - ) - message: Optional[str] = None - """Optional information about the status""" - payload: Optional[dict[str, Any]] = None - """Input or output data depending on the type of event""" - - -class PipelineEvent(Event): - pass - - -class TaskEvent(Event): - task_name: str - """Name of the task as defined in pipeline.add_component""" - - -class EventCallbackProtocol(Protocol): - def __call__(self, event: Event) -> Awaitable[None]: ... 
- - -EntityInputType = Union[str, dict[str, Union[str, list[dict[str, str]]]]] -RelationInputType = Union[str, dict[str, Union[str, list[dict[str, str]]]]] -"""Types derived from the SchemaEntity and SchemaRelation types, - so the possible types for dict values are: -- str (for label and description) -- list[dict[str, str]] (for properties) -""" diff --git a/src/neo4j_graphrag/experimental/pipeline/types/__init__.py b/src/neo4j_graphrag/experimental/pipeline/types/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/neo4j_graphrag/experimental/pipeline/types/context.py b/src/neo4j_graphrag/experimental/pipeline/types/context.py new file mode 100644 index 000000000..f0b4caf97 --- /dev/null +++ b/src/neo4j_graphrag/experimental/pipeline/types/context.py @@ -0,0 +1,37 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from pydantic import BaseModel, ConfigDict +from collections.abc import Awaitable + +from typing import Any, Optional, Protocol, runtime_checkable + + +@runtime_checkable +class TaskProgressCallbackProtocol(Protocol): + def __call__(self, message: str, data: dict[str, Any]) -> Awaitable[None]: ... + + +class RunContext(BaseModel): + """Context passed to the component""" + + run_id: str + task_name: str + notifier: Optional[TaskProgressCallbackProtocol] = None + + model_config = ConfigDict(arbitrary_types_allowed=True) + + async def notify(self, message: str, data: dict[str, Any]) -> None: + if self.notifier: + await self.notifier(message=message, data=data) diff --git a/src/neo4j_graphrag/experimental/pipeline/types/definitions.py b/src/neo4j_graphrag/experimental/pipeline/types/definitions.py new file mode 100644 index 000000000..68ae403b4 --- /dev/null +++ b/src/neo4j_graphrag/experimental/pipeline/types/definitions.py @@ -0,0 +1,46 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations + +from collections import defaultdict +from typing import Any + +from pydantic import BaseModel, ConfigDict + +from neo4j_graphrag.experimental.pipeline.component import Component + + +class ComponentDefinition(BaseModel): + name: str + component: Component + run_params: dict[str, Any] = {} + + model_config = ConfigDict(arbitrary_types_allowed=True) + + +class ConnectionDefinition(BaseModel): + start: str + end: str + input_config: dict[str, str] + + +class PipelineDefinition(BaseModel): + components: list[ComponentDefinition] + connections: list[ConnectionDefinition] + + def get_run_params(self) -> defaultdict[str, dict[str, Any]]: + return defaultdict( + dict, {c.name: c.run_params for c in self.components if c.run_params} + ) diff --git a/src/neo4j_graphrag/experimental/pipeline/types/orchestration.py b/src/neo4j_graphrag/experimental/pipeline/types/orchestration.py new file mode 100644 index 000000000..1e1497ae3 --- /dev/null +++ b/src/neo4j_graphrag/experimental/pipeline/types/orchestration.py @@ -0,0 +1,46 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import datetime +import enum +from typing import Optional + +from pydantic import BaseModel, Field + +from neo4j_graphrag.experimental.pipeline.component import DataModel + + +class RunStatus(enum.Enum): + UNKNOWN = "UNKNOWN" + RUNNING = "RUNNING" + DONE = "DONE" + + def possible_next_status(self) -> list[RunStatus]: + if self == RunStatus.UNKNOWN: + return [RunStatus.RUNNING] + if self == RunStatus.RUNNING: + return [RunStatus.DONE] + if self == RunStatus.DONE: + return [] + return [] + + +class RunResult(BaseModel): + status: RunStatus = RunStatus.DONE + result: Optional[DataModel] = None + timestamp: datetime.datetime = Field( + default_factory=lambda: datetime.datetime.now(datetime.timezone.utc) + ) diff --git a/src/neo4j_graphrag/experimental/pipeline/types/schema.py b/src/neo4j_graphrag/experimental/pipeline/types/schema.py new file mode 100644 index 000000000..626c99841 --- /dev/null +++ b/src/neo4j_graphrag/experimental/pipeline/types/schema.py @@ -0,0 +1,26 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# # +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# # +# https://www.apache.org/licenses/LICENSE-2.0 +# # +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations + +from typing import Union + + +EntityInputType = Union[str, dict[str, Union[str, list[dict[str, str]]]]] +RelationInputType = Union[str, dict[str, Union[str, list[dict[str, str]]]]] +"""Types derived from the SchemaEntity and SchemaRelation types, + so the possible types for dict values are: +- str (for label and description) +- list[dict[str, str]] (for properties) +""" diff --git a/tests/unit/experimental/pipeline/config/test_pipeline_config.py b/tests/unit/experimental/pipeline/config/test_pipeline_config.py index 4de5874bc..7ec24fc3b 100644 --- a/tests/unit/experimental/pipeline/config/test_pipeline_config.py +++ b/tests/unit/experimental/pipeline/config/test_pipeline_config.py @@ -30,7 +30,7 @@ from neo4j_graphrag.experimental.pipeline.config.pipeline_config import ( AbstractPipelineConfig, ) -from neo4j_graphrag.experimental.pipeline.types import ComponentDefinition +from neo4j_graphrag.experimental.pipeline.types.definitions import ComponentDefinition from neo4j_graphrag.llm import LLMInterface diff --git a/tests/unit/experimental/pipeline/config/test_runner.py b/tests/unit/experimental/pipeline/config/test_runner.py index 327b5221b..620796cb1 100644 --- a/tests/unit/experimental/pipeline/config/test_runner.py +++ b/tests/unit/experimental/pipeline/config/test_runner.py @@ -17,7 +17,7 @@ from neo4j_graphrag.experimental.pipeline import Pipeline from neo4j_graphrag.experimental.pipeline.config.pipeline_config import PipelineConfig from neo4j_graphrag.experimental.pipeline.config.runner import PipelineRunner -from neo4j_graphrag.experimental.pipeline.types import PipelineDefinition +from neo4j_graphrag.experimental.pipeline.types.definitions import PipelineDefinition @patch("neo4j_graphrag.experimental.pipeline.pipeline.Pipeline.from_definition") diff --git a/tests/unit/experimental/pipeline/test_orchestrator.py b/tests/unit/experimental/pipeline/test_orchestrator.py index a253abcdf..bb611da96 100644 --- a/tests/unit/experimental/pipeline/test_orchestrator.py +++ b/tests/unit/experimental/pipeline/test_orchestrator.py @@ -25,7 +25,7 @@ PipelineStatusUpdateError, ) from neo4j_graphrag.experimental.pipeline.orchestrator import Orchestrator -from neo4j_graphrag.experimental.pipeline.types import RunStatus +from neo4j_graphrag.experimental.pipeline.types.orchestration import RunStatus from tests.unit.experimental.pipeline.components import ( ComponentNoParam, diff --git a/tests/unit/experimental/pipeline/test_pipeline.py b/tests/unit/experimental/pipeline/test_pipeline.py index 0e3377547..16d32ba9b 100644 --- a/tests/unit/experimental/pipeline/test_pipeline.py +++ b/tests/unit/experimental/pipeline/test_pipeline.py @@ -24,14 +24,14 @@ import pytest from neo4j_graphrag.experimental.pipeline import Component, Pipeline from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError -from neo4j_graphrag.experimental.pipeline.types import ( +from neo4j_graphrag.experimental.pipeline.notification import ( EventCallbackProtocol, EventType, PipelineEvent, - RunResult, TaskEvent, Event, ) +from neo4j_graphrag.experimental.pipeline.types.orchestration import RunResult from .components import ( ComponentAdd, @@ -251,11 +251,11 @@ def test_pipeline_parameter_validation_full_missing_inputs() -> None: async def test_pipeline_branches() -> None: pipe = Pipeline() component_a = AsyncMock(spec=Component) - component_a.run = AsyncMock(return_value={}) + component_a.run_with_context = AsyncMock(return_value={}) component_b = AsyncMock(spec=Component) 
- component_b.run = AsyncMock(return_value={}) + component_b.run_with_context = AsyncMock(return_value={}) component_c = AsyncMock(spec=Component) - component_c.run = AsyncMock(return_value={}) + component_c.run_with_context = AsyncMock(return_value={}) pipe.add_component(component_a, "a") pipe.add_component(component_b, "b") @@ -272,11 +272,11 @@ async def test_pipeline_branches() -> None: async def test_pipeline_aggregation() -> None: pipe = Pipeline() component_a = AsyncMock(spec=Component) - component_a.run = AsyncMock(return_value={}) + component_a.run_with_context = AsyncMock(return_value={}) component_b = AsyncMock(spec=Component) - component_b.run = AsyncMock(return_value={}) + component_b.run_with_context = AsyncMock(return_value={}) component_c = AsyncMock(spec=Component) - component_c.run = AsyncMock(return_value={}) + component_c.run_with_context = AsyncMock(return_value={}) pipe.add_component( component_a, From 69b99dc9e15e7ef8321f1fa56e85c8be5c6afd60 Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 11 Mar 2025 14:34:03 +0100 Subject: [PATCH 02/12] Rename to match component name + ruff --- .../pipeline_with_component_notifications.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/examples/customize/build_graph/pipeline/pipeline_with_component_notifications.py b/examples/customize/build_graph/pipeline/pipeline_with_component_notifications.py index 8b75d45fc..6d6fe4d8e 100644 --- a/examples/customize/build_graph/pipeline/pipeline_with_component_notifications.py +++ b/examples/customize/build_graph/pipeline/pipeline_with_component_notifications.py @@ -17,7 +17,7 @@ logger.setLevel(logging.INFO) -class BatchComponentResult(DataModel): +class MultiplyComponentResult(DataModel): result: list[int] @@ -25,11 +25,13 @@ class MultiplicationComponent(Component): def __init__(self, f: int) -> None: self.f = f - async def run(self, numbers: list[int]) -> BatchComponentResult: - return BatchComponentResult(result=[]) + async def run(self, numbers: list[int]) -> MultiplyComponentResult: + return MultiplyComponentResult(result=[]) async def multiply_number( - self, context_: RunContext, number: int, + self, + context_: RunContext, + number: int, ) -> int: await context_.notify( message=f"Processing number {number}", @@ -42,7 +44,7 @@ async def run_with_context( context_: RunContext, numbers: list[int], **kwargs: Any, - ) -> BatchComponentResult: + ) -> MultiplyComponentResult: result = await asyncio.gather( *[ self.multiply_number( @@ -52,7 +54,7 @@ async def run_with_context( for number in numbers ] ) - return BatchComponentResult(result=result) + return MultiplyComponentResult(result=result) async def event_handler(event: Event) -> None: From 7cb54880036bffa0157714793fa6c3df7d67afca Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 11 Mar 2025 16:07:03 +0100 Subject: [PATCH 03/12] Add tests --- .../experimental/pipeline/component.py | 2 +- .../unit/experimental/pipeline/components.py | 11 +++++ .../experimental/pipeline/test_component.py | 45 ++++++++++++++++++- 3 files changed, 56 insertions(+), 2 deletions(-) diff --git a/src/neo4j_graphrag/experimental/pipeline/component.py b/src/neo4j_graphrag/experimental/pipeline/component.py index ff3953c3a..28ea42598 100644 --- a/src/neo4j_graphrag/experimental/pipeline/component.py +++ b/src/neo4j_graphrag/experimental/pipeline/component.py @@ -82,7 +82,7 @@ class Component(abc.ABC, metaclass=ComponentMeta): # added here for the type checker # DO NOT CHANGE component_inputs: dict[str, dict[str, str | bool]] - 
component_outputs: dict[str, dict[str, str | bool]] + component_outputs: dict[str, dict[str, str | bool | type]] @abc.abstractmethod async def run(self, *args: Any, **kwargs: Any) -> DataModel: diff --git a/tests/unit/experimental/pipeline/components.py b/tests/unit/experimental/pipeline/components.py index 5c3a54276..fc33e840d 100644 --- a/tests/unit/experimental/pipeline/components.py +++ b/tests/unit/experimental/pipeline/components.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from neo4j_graphrag.experimental.pipeline import Component, DataModel +from neo4j_graphrag.experimental.pipeline.types.context import RunContext class StringResultModel(DataModel): @@ -41,3 +42,13 @@ async def run(self, number1: int, number2: int) -> IntResultModel: class ComponentMultiply(Component): async def run(self, number1: int, number2: int = 2) -> IntResultModel: return IntResultModel(result=number1 * number2) + + +class ComponentMultiplyWithContext(Component): + async def run_with_context( + self, context_: RunContext, number1: int, number2: int = 2 + ) -> IntResultModel: + await context_.notify( + message="my message", data={"number1": number1, "number2": number2} + ) + return IntResultModel(result=number1 * number2) diff --git a/tests/unit/experimental/pipeline/test_component.py b/tests/unit/experimental/pipeline/test_component.py index 701554d79..3a178e110 100644 --- a/tests/unit/experimental/pipeline/test_component.py +++ b/tests/unit/experimental/pipeline/test_component.py @@ -12,7 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from .components import ComponentMultiply +from unittest.mock import MagicMock, AsyncMock + +import pytest + +from neo4j_graphrag.experimental.pipeline.types.context import RunContext +from .components import ComponentMultiply, ComponentMultiplyWithContext, IntResultModel def test_component_inputs() -> None: @@ -26,3 +31,41 @@ def test_component_inputs() -> None: def test_component_outputs() -> None: outputs = ComponentMultiply.component_outputs assert "result" in outputs + assert outputs["result"]["has_default"] is True + assert outputs["result"]["annotation"] == int + + +@pytest.mark.asyncio +async def test_component_run() -> None: + c = ComponentMultiply() + result = await c.run(number1=1, number2=2) + assert isinstance(result, IntResultModel) + assert isinstance( + result.result, ComponentMultiply.component_outputs["result"]["annotation"] + ) + + +@pytest.mark.asyncio +async def test_component_run_with_context_default_implementation() -> None: + c = ComponentMultiply() + result = await c.run_with_context( + # context can not be null in the function signature, + # but it's ignored in this case + None, # type: ignore + number1=1, + number2=2, + ) + assert result.result == 2 + + +@pytest.mark.asyncio +async def test_component_run_with_context() -> None: + c = ComponentMultiplyWithContext() + notifier_mock = AsyncMock() + result = await c.run_with_context( + RunContext(run_id="run_id", task_name="task_name", notifier=notifier_mock), + number1=1, + number2=2, + ) + assert result.result == 2 + notifier_mock.assert_awaited_once() From 14bff0725df32393a796b1debd05d96329d12c4d Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 11 Mar 2025 16:09:46 +0100 Subject: [PATCH 04/12] Update changelog and example readme --- CHANGELOG.md | 1 + examples/README.md | 1 + 2 files 
changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 390b32e26..f03dc1e0f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ - Added optional schema enforcement as a validation layer after entity and relation extraction. - Introduced a linear hybrid search ranker for HybridRetriever and HybridCypherRetriever, allowing customizable ranking with an `alpha` parameter. - Introduced SearchQueryParseError for handling invalid Lucene query strings in HybridRetriever and HybridCypherRetriever. +- Components can now be called with the `run_with_context` method that gets an extra `context_` argument containing information about the pipeline it's run from: the `run_id`, `task_name` and a `notify` function that can be used to send `TASK_PROGRESS` events to the same callback as the pipeline events. ### Fixed diff --git a/examples/README.md b/examples/README.md index 3dd6ee65c..3a54589f5 100644 --- a/examples/README.md +++ b/examples/README.md @@ -103,6 +103,7 @@ are listed in [the last section of this file](#customize). - [Export lexical graph creation into another pipeline](./customize/build_graph/pipeline/text_to_lexical_graph_to_entity_graph_two_pipelines.py) - [Build pipeline from config file](customize/build_graph/pipeline/from_config_files/pipeline_from_config_file.py) - [Add event listener to get notification about Pipeline progress](./customize/build_graph/pipeline/pipeline_with_notifications.py) +- [Use component context to send notifications about Component progress](./customize/build_graph/pipeline/pipeline_with_component_notifications.py) #### Components From 056d699660a0a352476c6a242e6d296d605545fd Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 11 Mar 2025 16:15:15 +0100 Subject: [PATCH 05/12] unused import --- tests/unit/experimental/pipeline/test_component.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/experimental/pipeline/test_component.py b/tests/unit/experimental/pipeline/test_component.py index 3a178e110..80ca67166 100644 --- a/tests/unit/experimental/pipeline/test_component.py +++ b/tests/unit/experimental/pipeline/test_component.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from unittest.mock import MagicMock, AsyncMock +from unittest.mock import AsyncMock import pytest From 8d9c97ba225884c93beff2584bf723a257a070ab Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 11 Mar 2025 16:21:26 +0100 Subject: [PATCH 06/12] mypy --- tests/unit/experimental/pipeline/components.py | 3 +++ tests/unit/experimental/pipeline/test_component.py | 8 ++++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/unit/experimental/pipeline/components.py b/tests/unit/experimental/pipeline/components.py index fc33e840d..146b1fe52 100644 --- a/tests/unit/experimental/pipeline/components.py +++ b/tests/unit/experimental/pipeline/components.py @@ -45,6 +45,9 @@ async def run(self, number1: int, number2: int = 2) -> IntResultModel: class ComponentMultiplyWithContext(Component): + async def run(self, number1: int, number2: int) -> IntResultModel: + return IntResultModel(result=number1 * number2) + async def run_with_context( self, context_: RunContext, number1: int, number2: int = 2 ) -> IntResultModel: diff --git a/tests/unit/experimental/pipeline/test_component.py b/tests/unit/experimental/pipeline/test_component.py index 80ca67166..3abea74ab 100644 --- a/tests/unit/experimental/pipeline/test_component.py +++ b/tests/unit/experimental/pipeline/test_component.py @@ -41,7 +41,9 @@ async def test_component_run() -> None: result = await c.run(number1=1, number2=2) assert isinstance(result, IntResultModel) assert isinstance( - result.result, ComponentMultiply.component_outputs["result"]["annotation"] + result.result, + # we know this is a type and not a bool or str: + ComponentMultiply.component_outputs["result"]["annotation"], # type: ignore ) @@ -55,7 +57,9 @@ async def test_component_run_with_context_default_implementation() -> None: number1=1, number2=2, ) - assert result.result == 2 + # the type checker doesn't know about the type + # because the method is not re-declared + assert result.result == 2 # type: ignore @pytest.mark.asyncio From 6a55af4f2f68122e9605da76c4bbd3b26ccb0379 Mon Sep 17 00:00:00 2001 From: estelle Date: Thu, 13 Mar 2025 14:01:37 +0100 Subject: [PATCH 07/12] Changes so that the `run` method is not required anymore --- .../components/entity_relation_extractor.py | 6 ++-- .../experimental/components/resolver.py | 6 ++-- .../experimental/pipeline/component.py | 28 +++++++++++++------ .../unit/experimental/pipeline/components.py | 3 -- 4 files changed, 23 insertions(+), 20 deletions(-) diff --git a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py index 17d7832fc..f21a22fed 100644 --- a/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py +++ b/src/neo4j_graphrag/experimental/components/entity_relation_extractor.py @@ -14,7 +14,6 @@ # limitations under the License. from __future__ import annotations -import abc import asyncio import enum import json @@ -115,7 +114,7 @@ def fix_invalid_json(raw_json: str) -> str: return repaired_json -class EntityRelationExtractor(Component, abc.ABC): +class EntityRelationExtractor(Component): """Abstract class for entity relation extraction components. 
Args: @@ -133,7 +132,6 @@ def __init__( self.on_error = on_error self.create_lexical_graph = create_lexical_graph - @abc.abstractmethod async def run( self, chunks: TextChunks, @@ -141,7 +139,7 @@ async def run( lexical_graph_config: Optional[LexicalGraphConfig] = None, **kwargs: Any, ) -> Neo4jGraph: - pass + raise NotImplementedError() def update_ids( self, diff --git a/src/neo4j_graphrag/experimental/components/resolver.py b/src/neo4j_graphrag/experimental/components/resolver.py index a050ea35e..25a7d5e89 100644 --- a/src/neo4j_graphrag/experimental/components/resolver.py +++ b/src/neo4j_graphrag/experimental/components/resolver.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import abc from typing import Any, Optional import neo4j @@ -22,7 +21,7 @@ from neo4j_graphrag.utils import driver_config -class EntityResolver(Component, abc.ABC): +class EntityResolver(Component): """Entity resolution base class Args: @@ -38,9 +37,8 @@ def __init__( self.driver = driver_config.override_user_agent(driver) self.filter_query = filter_query - @abc.abstractmethod async def run(self, *args: Any, **kwargs: Any) -> ResolutionStats: - pass + raise NotImplementedError() class SinglePropertyExactMatchResolver(EntityResolver): diff --git a/src/neo4j_graphrag/experimental/pipeline/component.py b/src/neo4j_graphrag/experimental/pipeline/component.py index 28ea42598..02fdc34f8 100644 --- a/src/neo4j_graphrag/experimental/pipeline/component.py +++ b/src/neo4j_graphrag/experimental/pipeline/component.py @@ -14,7 +14,6 @@ # limitations under the License. from __future__ import annotations -import abc import inspect from typing import Any, get_type_hints @@ -31,7 +30,7 @@ class DataModel(BaseModel): pass -class ComponentMeta(abc.ABCMeta): +class ComponentMeta(type): def __new__( meta, name: str, bases: tuple[type, ...], attrs: dict[str, Any] ) -> type: @@ -39,10 +38,6 @@ def __new__( run_method = attrs.get("run") run_context_method = attrs.get("run_with_context") run = run_context_method or run_method - if run is None: - raise RuntimeError( - f"Either 'run' or 'run_with_context' must be implemented in component: '{name}'" - ) sig = inspect.signature(run) attrs["component_inputs"] = { param.name: { @@ -73,7 +68,7 @@ def __new__( return type.__new__(meta, name, bases, attrs) -class Component(abc.ABC, metaclass=ComponentMeta): +class Component(metaclass=ComponentMeta): """Interface that needs to be implemented by all components. """ @@ -84,12 +79,27 @@ class Component(abc.ABC, metaclass=ComponentMeta): component_inputs: dict[str, dict[str, str | bool]] component_outputs: dict[str, dict[str, str | bool | type]] - @abc.abstractmethod async def run(self, *args: Any, **kwargs: Any) -> DataModel: - pass + """This function is planned for deprecation in a future release. + + Note: if `run_with_context` is implemented, this method will not be used. + """ + raise NotImplementedError( + "You must implement the `run` or `run_with_context` method. " + "`run` method will be marked for deprecation in a future release." + ) async def run_with_context( self, context_: RunContext, *args: Any, **kwargs: Any ) -> DataModel: + """This method is called by the pipeline orchestrator. + The `context_` parameter contains information about + the pipeline run: the `run_id` and a `notify` function + that can be used to send events from the component to + the pipeline callback. 
+ + For now, it defaults to calling the `run` method, but it + is meant to replace the `run` method in a future release. + """ # default behavior to prevent a breaking change return await self.run(*args, **kwargs) diff --git a/tests/unit/experimental/pipeline/components.py b/tests/unit/experimental/pipeline/components.py index 146b1fe52..fc33e840d 100644 --- a/tests/unit/experimental/pipeline/components.py +++ b/tests/unit/experimental/pipeline/components.py @@ -45,9 +45,6 @@ async def run(self, number1: int, number2: int = 2) -> IntResultModel: class ComponentMultiplyWithContext(Component): - async def run(self, number1: int, number2: int) -> IntResultModel: - return IntResultModel(result=number1 * number2) - async def run_with_context( self, context_: RunContext, number1: int, number2: int = 2 ) -> IntResultModel: From b2a750a9a0810e273f0b2bec2e8f02a83fc943ff Mon Sep 17 00:00:00 2001 From: estelle Date: Thu, 13 Mar 2025 14:07:26 +0100 Subject: [PATCH 08/12] Update documentation --- docs/source/api.rst | 8 ++++++-- docs/source/types.rst | 8 ++++---- docs/source/user_guide_pipeline.rst | 32 +++++++++++++++++++++++++++-- 3 files changed, 40 insertions(+), 8 deletions(-) diff --git a/docs/source/api.rst b/docs/source/api.rst index f27bf3af7..2ca19d9b2 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -9,13 +9,18 @@ API Documentation Components ********** +Component +========= + +.. autoclass:: neo4j_graphrag.experimental.pipeline.component.Component + :members: run, run_with_context + DataLoader ========== .. autoclass:: neo4j_graphrag.experimental.components.pdf_loader.DataLoader :members: run, get_document_metadata - PdfLoader ========= @@ -59,7 +64,6 @@ LexicalGraphBuilder :members: :exclude-members: component_inputs, component_outputs - Neo4jChunkReader ================ diff --git a/docs/source/types.rst b/docs/source/types.rst index 192bc4a0a..4b4ae7155 100644 --- a/docs/source/types.rst +++ b/docs/source/types.rst @@ -158,22 +158,22 @@ ParamFromEnvConfig EventType ========= -.. autoenum:: neo4j_graphrag.experimental.pipeline.types.EventType +.. autoenum:: neo4j_graphrag.experimental.pipeline.notification.EventType PipelineEvent ============== -.. autoclass:: neo4j_graphrag.experimental.pipeline.types.PipelineEvent +.. autoclass:: neo4j_graphrag.experimental.pipeline.notification.PipelineEvent TaskEvent ============== -.. autoclass:: neo4j_graphrag.experimental.pipeline.types.TaskEvent +.. autoclass:: neo4j_graphrag.experimental.pipeline.notification.TaskEvent EventCallbackProtocol ===================== -.. autoclass:: neo4j_graphrag.experimental.pipeline.types.EventCallbackProtocol +.. autoclass:: neo4j_graphrag.experimental.pipeline.notification.EventCallbackProtocol :members: __call__ diff --git a/docs/source/user_guide_pipeline.rst b/docs/source/user_guide_pipeline.rst index 5c9538a65..af66fde25 100644 --- a/docs/source/user_guide_pipeline.rst +++ b/docs/source/user_guide_pipeline.rst @@ -22,7 +22,7 @@ their own by following these steps: 1. Create a subclass of the Pydantic `neo4j_graphrag.experimental.pipeline.DataModel` to represent the data being returned by the component 2. Create a subclass of `neo4j_graphrag.experimental.pipeline.Component` -3. Create a run method in this new class and specify the required inputs and output model using the just created `DataModel` +3. Create a `run_with_context` method in this new class and specify the required inputs and output model using the just created `DataModel` 4. 
Implement the run method: it's an `async` method, allowing tasks to be parallelized and awaited within this method. An example is given below, where a `ComponentAdd` is created to add two numbers together and return @@ -31,12 +31,13 @@ the resulting sum: .. code:: python from neo4j_graphrag.experimental.pipeline import Component, DataModel + from neo4j_graphrag.experimental.pipeline.types.context import RunContext class IntResultModel(DataModel): result: int class ComponentAdd(Component): - async def run(self, number1: int, number2: int = 1) -> IntResultModel: + async def run_with_context(self, context_: RunContext, number1: int, number2: int = 1) -> IntResultModel: return IntResultModel(result = number1 + number2) Read more about :ref:`components-section` in the API Documentation. @@ -141,6 +142,7 @@ It is possible to add a callback to receive notification about pipeline progress - `PIPELINE_STARTED`, when pipeline starts - `PIPELINE_FINISHED`, when pipeline ends - `TASK_STARTED`, when a task starts +- `TASK_PROGRESS`, sent by each component (depends on component's implementation, see below) - `TASK_FINISHED`, when a task ends @@ -172,3 +174,29 @@ See :ref:`pipelineevent` and :ref:`taskevent` to see what is sent in each event # ... add components, connect them as usual await pipeline.run(...) + + +Send Events from Components +=========================== + +Components can send notifications about their progress using the `notify` function from +the `context_`: + +.. code:: python + + from neo4j_graphrag.experimental.pipeline import Component, DataModel + from neo4j_graphrag.experimental.pipeline.types.context import RunContext + + class IntResultModel(DataModel): + result: int + + class ComponentAdd(Component): + async def run_with_context(self, context_: RunContext, number1: int, number2: int = 1) -> IntResultModel: + for fake_iteration in range(10): + await context_.notify( + message=f"Starting iteration {fake_iteration} out of 10", + data={"iteration": fake_iteration, "total": 10} + ) + return IntResultModel(result = number1 + number2) + +This will send an `TASK_PROGRESS` event to the pipeline callback. 
From 0e349f88ebabeabfaeb83ea548e7e6b00ff4492e Mon Sep 17 00:00:00 2001 From: estelle Date: Thu, 13 Mar 2025 14:21:28 +0100 Subject: [PATCH 09/12] We still need to raise error if components do not have at least one of the two methods implemented --- .../experimental/pipeline/component.py | 6 +++++- tests/unit/experimental/pipeline/test_component.py | 14 ++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/src/neo4j_graphrag/experimental/pipeline/component.py b/src/neo4j_graphrag/experimental/pipeline/component.py index 02fdc34f8..b59b60b19 100644 --- a/src/neo4j_graphrag/experimental/pipeline/component.py +++ b/src/neo4j_graphrag/experimental/pipeline/component.py @@ -37,7 +37,11 @@ def __new__( # extract required inputs and outputs from the run method signature run_method = attrs.get("run") run_context_method = attrs.get("run_with_context") - run = run_context_method or run_method + run = run_context_method if run_context_method is not None else run_method + if run is None: + raise RuntimeError( + f"You must implement either `run` or `run_with_context` in Component '{name}'" + ) sig = inspect.signature(run) attrs["component_inputs"] = { param.name: { diff --git a/tests/unit/experimental/pipeline/test_component.py b/tests/unit/experimental/pipeline/test_component.py index 3abea74ab..ba39aebaf 100644 --- a/tests/unit/experimental/pipeline/test_component.py +++ b/tests/unit/experimental/pipeline/test_component.py @@ -16,6 +16,7 @@ import pytest +from neo4j_graphrag.experimental.pipeline import Component from neo4j_graphrag.experimental.pipeline.types.context import RunContext from .components import ComponentMultiply, ComponentMultiplyWithContext, IntResultModel @@ -73,3 +74,16 @@ async def test_component_run_with_context() -> None: ) assert result.result == 2 notifier_mock.assert_awaited_once() + + +def test_component_missing_method() -> None: + with pytest.raises(RuntimeError) as e: + + class WrongComponent(Component): + # we must have either run or run_with_context + pass + + assert ( + "You must implement either `run` or `run_with_context` in Component 'WrongComponent'" + in str(e) + ) From e7a0ff004e3e90c5c6a590cc16efa38568541d19 Mon Sep 17 00:00:00 2001 From: estelle Date: Thu, 13 Mar 2025 15:29:34 +0100 Subject: [PATCH 10/12] Update CHANGELOG.md --- CHANGELOG.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f03dc1e0f..b27ca4930 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,11 @@ ## Next +### Added + +- Added a `run_with_context` method to `Component`. This method has a `context_` parameter that contains information from the pipeline the component is being run from (e.g. the `run_id`) + + ## 1.6.0 ### Added From e41a738c22f72c532c4e02e8dae3fa720f9d8bf6 Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 18 Mar 2025 10:19:59 +0100 Subject: [PATCH 11/12] Improve documentation of future changes --- CHANGELOG.md | 2 +- docs/source/user_guide_pipeline.rst | 8 ++++++-- .../pipeline/pipeline_with_component_notifications.py | 5 ++--- src/neo4j_graphrag/experimental/pipeline/component.py | 9 +++++---- src/neo4j_graphrag/experimental/pipeline/orchestrator.py | 4 +++- .../experimental/pipeline/types/context.py | 6 +++--- tests/unit/experimental/pipeline/test_component.py | 2 +- 7 files changed, 21 insertions(+), 15 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b27ca4930..65962d2b6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,7 @@ ### Added -- Added a `run_with_context` method to `Component`. 
This method has a `context_` parameter that contains information from the pipeline the component is being run from (e.g. the `run_id`)
+- Added the `run_with_context` method to `Component`. This method includes a `context_` parameter, which provides information about the pipeline from which the component is executed (e.g., the `run_id`). It also enables the component to send events to the pipeline's callback function.
 
 
 ## 1.6.0
diff --git a/docs/source/user_guide_pipeline.rst b/docs/source/user_guide_pipeline.rst
index af66fde25..5550d1ea9 100644
--- a/docs/source/user_guide_pipeline.rst
+++ b/docs/source/user_guide_pipeline.rst
@@ -179,8 +179,8 @@ See :ref:`pipelineevent` and :ref:`taskevent` to see what is sent in each event
 Send Events from Components
 ===========================
 
-Components can send notifications about their progress using the `notify` function from
-the `context_`:
+Components can send progress notifications using the `notify` function from
+`context_` by implementing the `run_with_context` method:
 
 .. code:: python
 
@@ -200,3 +200,7 @@ the `context_`:
             return IntResultModel(result = number1 + number2)
 
 This will send an `TASK_PROGRESS` event to the pipeline callback.
+
+.. note::
+
+    In a future release, the `context_` parameter will be added to the `run` method.
diff --git a/examples/customize/build_graph/pipeline/pipeline_with_component_notifications.py b/examples/customize/build_graph/pipeline/pipeline_with_component_notifications.py
index 6d6fe4d8e..3fe5bf16e 100644
--- a/examples/customize/build_graph/pipeline/pipeline_with_component_notifications.py
+++ b/examples/customize/build_graph/pipeline/pipeline_with_component_notifications.py
@@ -25,9 +25,6 @@ class MultiplicationComponent(Component):
     def __init__(self, f: int) -> None:
         self.f = f
 
-    async def run(self, numbers: list[int]) -> MultiplyComponentResult:
-        return MultiplyComponentResult(result=[])
-
     async def multiply_number(
         self,
         context_: RunContext,
@@ -39,6 +36,8 @@ async def multiply_number(
         )
         return self.f * number
 
+    # implementing `run_with_context` to get access to
+    # the pipeline's RunContext:
     async def run_with_context(
         self,
         context_: RunContext,
diff --git a/src/neo4j_graphrag/experimental/pipeline/component.py b/src/neo4j_graphrag/experimental/pipeline/component.py
index b59b60b19..39a2816ef 100644
--- a/src/neo4j_graphrag/experimental/pipeline/component.py
+++ b/src/neo4j_graphrag/experimental/pipeline/component.py
@@ -84,13 +84,12 @@ class Component(metaclass=ComponentMeta):
     component_outputs: dict[str, dict[str, str | bool | type]]
 
     async def run(self, *args: Any, **kwargs: Any) -> DataModel:
-        """This function is planned for deprecation in a future release.
+        """Run the component and return its result.
 
         Note: if `run_with_context` is implemented, this method will not be used.
         """
         raise NotImplementedError(
             "You must implement the `run` or `run_with_context` method. "
-            "`run` method will be marked for deprecation in a future release."
         )
 
     async def run_with_context(
@@ -102,8 +101,10 @@ async def run_with_context(
         that can be used to send events from the component to
         the pipeline callback.
 
-        For now, it defaults to calling the `run` method, but it
-        is meant to replace the `run` method in a future release.
+        This feature will be moved to the `run` method in a future
+        release.
+
+        It defaults to calling the `run` method to prevent any breaking change. 
""" # default behavior to prevent a breaking change return await self.run(*args, **kwargs) diff --git a/src/neo4j_graphrag/experimental/pipeline/orchestrator.py b/src/neo4j_graphrag/experimental/pipeline/orchestrator.py index fd933470a..d2426db39 100644 --- a/src/neo4j_graphrag/experimental/pipeline/orchestrator.py +++ b/src/neo4j_graphrag/experimental/pipeline/orchestrator.py @@ -85,7 +85,9 @@ async def run_task(self, task: TaskPipelineNode, data: dict[str, Any]) -> None: run_id=self.run_id, task_name=task.name, ) - context = RunContext(run_id=self.run_id, task_name=task.name, notifier=notifier) + context = RunContext( + run_id=self.run_id, task_name=task.name, _notifier=notifier + ) res = await task.run(context, inputs) await self.set_task_status(task.name, RunStatus.DONE) await self.event_notifier.notify_task_finished(self.run_id, task.name, res) diff --git a/src/neo4j_graphrag/experimental/pipeline/types/context.py b/src/neo4j_graphrag/experimental/pipeline/types/context.py index f0b4caf97..0e51d9036 100644 --- a/src/neo4j_graphrag/experimental/pipeline/types/context.py +++ b/src/neo4j_graphrag/experimental/pipeline/types/context.py @@ -28,10 +28,10 @@ class RunContext(BaseModel): run_id: str task_name: str - notifier: Optional[TaskProgressCallbackProtocol] = None + _notifier: Optional[TaskProgressCallbackProtocol] = None model_config = ConfigDict(arbitrary_types_allowed=True) async def notify(self, message: str, data: dict[str, Any]) -> None: - if self.notifier: - await self.notifier(message=message, data=data) + if self._notifier: + await self._notifier(message=message, data=data) diff --git a/tests/unit/experimental/pipeline/test_component.py b/tests/unit/experimental/pipeline/test_component.py index ba39aebaf..867cad68a 100644 --- a/tests/unit/experimental/pipeline/test_component.py +++ b/tests/unit/experimental/pipeline/test_component.py @@ -68,7 +68,7 @@ async def test_component_run_with_context() -> None: c = ComponentMultiplyWithContext() notifier_mock = AsyncMock() result = await c.run_with_context( - RunContext(run_id="run_id", task_name="task_name", notifier=notifier_mock), + RunContext(run_id="run_id", task_name="task_name", _notifier=notifier_mock), number1=1, number2=2, ) From fe0fe02e271e006560d764fe1bf2d0c6d38a16df Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 18 Mar 2025 13:59:14 +0100 Subject: [PATCH 12/12] Undo make notifier private, not needed --- src/neo4j_graphrag/experimental/pipeline/orchestrator.py | 4 +--- src/neo4j_graphrag/experimental/pipeline/types/context.py | 6 +++--- tests/unit/experimental/pipeline/test_component.py | 2 +- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/neo4j_graphrag/experimental/pipeline/orchestrator.py b/src/neo4j_graphrag/experimental/pipeline/orchestrator.py index d2426db39..fd933470a 100644 --- a/src/neo4j_graphrag/experimental/pipeline/orchestrator.py +++ b/src/neo4j_graphrag/experimental/pipeline/orchestrator.py @@ -85,9 +85,7 @@ async def run_task(self, task: TaskPipelineNode, data: dict[str, Any]) -> None: run_id=self.run_id, task_name=task.name, ) - context = RunContext( - run_id=self.run_id, task_name=task.name, _notifier=notifier - ) + context = RunContext(run_id=self.run_id, task_name=task.name, notifier=notifier) res = await task.run(context, inputs) await self.set_task_status(task.name, RunStatus.DONE) await self.event_notifier.notify_task_finished(self.run_id, task.name, res) diff --git a/src/neo4j_graphrag/experimental/pipeline/types/context.py 
b/src/neo4j_graphrag/experimental/pipeline/types/context.py index 0e51d9036..f0b4caf97 100644 --- a/src/neo4j_graphrag/experimental/pipeline/types/context.py +++ b/src/neo4j_graphrag/experimental/pipeline/types/context.py @@ -28,10 +28,10 @@ class RunContext(BaseModel): run_id: str task_name: str - _notifier: Optional[TaskProgressCallbackProtocol] = None + notifier: Optional[TaskProgressCallbackProtocol] = None model_config = ConfigDict(arbitrary_types_allowed=True) async def notify(self, message: str, data: dict[str, Any]) -> None: - if self._notifier: - await self._notifier(message=message, data=data) + if self.notifier: + await self.notifier(message=message, data=data) diff --git a/tests/unit/experimental/pipeline/test_component.py b/tests/unit/experimental/pipeline/test_component.py index 867cad68a..ba39aebaf 100644 --- a/tests/unit/experimental/pipeline/test_component.py +++ b/tests/unit/experimental/pipeline/test_component.py @@ -68,7 +68,7 @@ async def test_component_run_with_context() -> None: c = ComponentMultiplyWithContext() notifier_mock = AsyncMock() result = await c.run_with_context( - RunContext(run_id="run_id", task_name="task_name", _notifier=notifier_mock), + RunContext(run_id="run_id", task_name="task_name", notifier=notifier_mock), number1=1, number2=2, )
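Because `notifier` on `RunContext` remains optional and defaults to `None` (as restored above), `run_with_context` can also be exercised directly, outside any pipeline: `notify` then simply does nothing. The sketch below illustrates this under that assumption, reusing the `ComponentAdd` and `IntResultModel` shapes from the user guide; the `run_id` and `task_name` values are placeholders.

.. code:: python

    import asyncio

    from neo4j_graphrag.experimental.pipeline import Component, DataModel
    from neo4j_graphrag.experimental.pipeline.types.context import RunContext


    class IntResultModel(DataModel):
        result: int


    class ComponentAdd(Component):
        async def run_with_context(
            self, context_: RunContext, number1: int, number2: int = 1
        ) -> IntResultModel:
            # inside a pipeline this call is forwarded to the callback as a
            # TASK_PROGRESS event; here `notifier` is None, so it is a no-op
            await context_.notify(
                message="adding numbers",
                data={"number1": number1, "number2": number2},
            )
            return IntResultModel(result=number1 + number2)


    async def main() -> None:
        # build a RunContext by hand; `notifier` is left at its None default
        context = RunContext(run_id="local-run", task_name="adder")
        result = await ComponentAdd().run_with_context(context, number1=1, number2=2)
        print(result.result)  # -> 3


    asyncio.run(main())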