From 545f1dd24bac4ed1ef71c628f70a9321b4fc1010 Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 11 Mar 2025 14:31:12 +0100 Subject: [PATCH 01/13] Expose context from the orchestrator to the components --- src/neo4j_graphrag/experimental/pipeline/component.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/neo4j_graphrag/experimental/pipeline/component.py b/src/neo4j_graphrag/experimental/pipeline/component.py index 39a2816ef..ece8cfc58 100644 --- a/src/neo4j_graphrag/experimental/pipeline/component.py +++ b/src/neo4j_graphrag/experimental/pipeline/component.py @@ -40,7 +40,7 @@ def __new__( run = run_context_method if run_context_method is not None else run_method if run is None: raise RuntimeError( - f"You must implement either `run` or `run_with_context` in Component '{name}'" + f"Either 'run' or 'run_with_context' must be implemented in component: '{name}'" ) sig = inspect.signature(run) attrs["component_inputs"] = { From 89968892f7ee3fc185de422153f695f6fa4198d2 Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 11 Mar 2025 16:21:26 +0100 Subject: [PATCH 02/13] mypy --- tests/unit/experimental/pipeline/components.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/unit/experimental/pipeline/components.py b/tests/unit/experimental/pipeline/components.py index fc33e840d..146b1fe52 100644 --- a/tests/unit/experimental/pipeline/components.py +++ b/tests/unit/experimental/pipeline/components.py @@ -45,6 +45,9 @@ async def run(self, number1: int, number2: int = 2) -> IntResultModel: class ComponentMultiplyWithContext(Component): + async def run(self, number1: int, number2: int) -> IntResultModel: + return IntResultModel(result=number1 * number2) + async def run_with_context( self, context_: RunContext, number1: int, number2: int = 2 ) -> IntResultModel: From d146adf89fac3ceeddcd0e2d3726c69310d38c34 Mon Sep 17 00:00:00 2001 From: estelle Date: Thu, 13 Mar 2025 14:01:37 +0100 Subject: [PATCH 03/13] Changes so that the `run` method is not required anymore --- tests/unit/experimental/pipeline/components.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/unit/experimental/pipeline/components.py b/tests/unit/experimental/pipeline/components.py index 146b1fe52..fc33e840d 100644 --- a/tests/unit/experimental/pipeline/components.py +++ b/tests/unit/experimental/pipeline/components.py @@ -45,9 +45,6 @@ async def run(self, number1: int, number2: int = 2) -> IntResultModel: class ComponentMultiplyWithContext(Component): - async def run(self, number1: int, number2: int) -> IntResultModel: - return IntResultModel(result=number1 * number2) - async def run_with_context( self, context_: RunContext, number1: int, number2: int = 2 ) -> IntResultModel: From 567a387be5fdc699ca64f8ea244421ef884cbc19 Mon Sep 17 00:00:00 2001 From: estelle Date: Thu, 13 Mar 2025 14:07:26 +0100 Subject: [PATCH 04/13] Update documentation --- docs/source/user_guide_pipeline.rst | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/docs/source/user_guide_pipeline.rst b/docs/source/user_guide_pipeline.rst index 5550d1ea9..703e48aba 100644 --- a/docs/source/user_guide_pipeline.rst +++ b/docs/source/user_guide_pipeline.rst @@ -22,7 +22,7 @@ their own by following these steps: 1. Create a subclass of the Pydantic `neo4j_graphrag.experimental.pipeline.DataModel` to represent the data being returned by the component 2. Create a subclass of `neo4j_graphrag.experimental.pipeline.Component` -3. 
Create a `run_with_context` method in this new class and specify the required inputs and output model using the just created `DataModel` +3. Create a `run` method in this new class and specify the required inputs and output model using the just created `DataModel` 4. Implement the run method: it's an `async` method, allowing tasks to be parallelized and awaited within this method. An example is given below, where a `ComponentAdd` is created to add two numbers together and return @@ -31,13 +31,12 @@ the resulting sum: .. code:: python from neo4j_graphrag.experimental.pipeline import Component, DataModel - from neo4j_graphrag.experimental.pipeline.types.context import RunContext class IntResultModel(DataModel): result: int class ComponentAdd(Component): - async def run_with_context(self, context_: RunContext, number1: int, number2: int = 1) -> IntResultModel: + async def run(self, number1: int, number2: int = 1) -> IntResultModel: return IntResultModel(result = number1 + number2) Read more about :ref:`components-section` in the API Documentation. From a4b043fa2ac558c8bfd66ac23623461f2302865a Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 18 Mar 2025 10:19:59 +0100 Subject: [PATCH 05/13] Improve documentation of future changes --- src/neo4j_graphrag/experimental/pipeline/orchestrator.py | 4 +++- src/neo4j_graphrag/experimental/pipeline/types/context.py | 6 +++--- tests/unit/experimental/pipeline/test_component.py | 2 +- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/neo4j_graphrag/experimental/pipeline/orchestrator.py b/src/neo4j_graphrag/experimental/pipeline/orchestrator.py index fd933470a..d2426db39 100644 --- a/src/neo4j_graphrag/experimental/pipeline/orchestrator.py +++ b/src/neo4j_graphrag/experimental/pipeline/orchestrator.py @@ -85,7 +85,9 @@ async def run_task(self, task: TaskPipelineNode, data: dict[str, Any]) -> None: run_id=self.run_id, task_name=task.name, ) - context = RunContext(run_id=self.run_id, task_name=task.name, notifier=notifier) + context = RunContext( + run_id=self.run_id, task_name=task.name, _notifier=notifier + ) res = await task.run(context, inputs) await self.set_task_status(task.name, RunStatus.DONE) await self.event_notifier.notify_task_finished(self.run_id, task.name, res) diff --git a/src/neo4j_graphrag/experimental/pipeline/types/context.py b/src/neo4j_graphrag/experimental/pipeline/types/context.py index f0b4caf97..0e51d9036 100644 --- a/src/neo4j_graphrag/experimental/pipeline/types/context.py +++ b/src/neo4j_graphrag/experimental/pipeline/types/context.py @@ -28,10 +28,10 @@ class RunContext(BaseModel): run_id: str task_name: str - notifier: Optional[TaskProgressCallbackProtocol] = None + _notifier: Optional[TaskProgressCallbackProtocol] = None model_config = ConfigDict(arbitrary_types_allowed=True) async def notify(self, message: str, data: dict[str, Any]) -> None: - if self.notifier: - await self.notifier(message=message, data=data) + if self._notifier: + await self._notifier(message=message, data=data) diff --git a/tests/unit/experimental/pipeline/test_component.py b/tests/unit/experimental/pipeline/test_component.py index ba39aebaf..867cad68a 100644 --- a/tests/unit/experimental/pipeline/test_component.py +++ b/tests/unit/experimental/pipeline/test_component.py @@ -68,7 +68,7 @@ async def test_component_run_with_context() -> None: c = ComponentMultiplyWithContext() notifier_mock = AsyncMock() result = await c.run_with_context( - RunContext(run_id="run_id", task_name="task_name", notifier=notifier_mock), + 
RunContext(run_id="run_id", task_name="task_name", _notifier=notifier_mock), number1=1, number2=2, ) From 19d032d0def27dd45b1a59bf6292b8128d03cadb Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 18 Mar 2025 13:59:14 +0100 Subject: [PATCH 06/13] Undo make notifier private, not needed --- src/neo4j_graphrag/experimental/pipeline/orchestrator.py | 4 +--- src/neo4j_graphrag/experimental/pipeline/types/context.py | 6 +++--- tests/unit/experimental/pipeline/test_component.py | 2 +- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/neo4j_graphrag/experimental/pipeline/orchestrator.py b/src/neo4j_graphrag/experimental/pipeline/orchestrator.py index d2426db39..fd933470a 100644 --- a/src/neo4j_graphrag/experimental/pipeline/orchestrator.py +++ b/src/neo4j_graphrag/experimental/pipeline/orchestrator.py @@ -85,9 +85,7 @@ async def run_task(self, task: TaskPipelineNode, data: dict[str, Any]) -> None: run_id=self.run_id, task_name=task.name, ) - context = RunContext( - run_id=self.run_id, task_name=task.name, _notifier=notifier - ) + context = RunContext(run_id=self.run_id, task_name=task.name, notifier=notifier) res = await task.run(context, inputs) await self.set_task_status(task.name, RunStatus.DONE) await self.event_notifier.notify_task_finished(self.run_id, task.name, res) diff --git a/src/neo4j_graphrag/experimental/pipeline/types/context.py b/src/neo4j_graphrag/experimental/pipeline/types/context.py index 0e51d9036..f0b4caf97 100644 --- a/src/neo4j_graphrag/experimental/pipeline/types/context.py +++ b/src/neo4j_graphrag/experimental/pipeline/types/context.py @@ -28,10 +28,10 @@ class RunContext(BaseModel): run_id: str task_name: str - _notifier: Optional[TaskProgressCallbackProtocol] = None + notifier: Optional[TaskProgressCallbackProtocol] = None model_config = ConfigDict(arbitrary_types_allowed=True) async def notify(self, message: str, data: dict[str, Any]) -> None: - if self._notifier: - await self._notifier(message=message, data=data) + if self.notifier: + await self.notifier(message=message, data=data) diff --git a/tests/unit/experimental/pipeline/test_component.py b/tests/unit/experimental/pipeline/test_component.py index 867cad68a..ba39aebaf 100644 --- a/tests/unit/experimental/pipeline/test_component.py +++ b/tests/unit/experimental/pipeline/test_component.py @@ -68,7 +68,7 @@ async def test_component_run_with_context() -> None: c = ComponentMultiplyWithContext() notifier_mock = AsyncMock() result = await c.run_with_context( - RunContext(run_id="run_id", task_name="task_name", _notifier=notifier_mock), + RunContext(run_id="run_id", task_name="task_name", notifier=notifier_mock), number1=1, number2=2, ) From 6e4261bba315168a24acbd36b8ed8c4a588e7aaf Mon Sep 17 00:00:00 2001 From: estelle Date: Thu, 13 Mar 2025 11:31:36 +0100 Subject: [PATCH 07/13] Pipeline streaming --- .../pipeline/pipeline_streaming.py | 77 +++++++++++++++++++ .../experimental/pipeline/pipeline.py | 63 ++++++++++++++- 2 files changed, 138 insertions(+), 2 deletions(-) create mode 100644 examples/customize/build_graph/pipeline/pipeline_streaming.py diff --git a/examples/customize/build_graph/pipeline/pipeline_streaming.py b/examples/customize/build_graph/pipeline/pipeline_streaming.py new file mode 100644 index 000000000..15e70d4a9 --- /dev/null +++ b/examples/customize/build_graph/pipeline/pipeline_streaming.py @@ -0,0 +1,77 @@ +import asyncio + +from neo4j_graphrag.experimental.pipeline.component import Component, DataModel +from neo4j_graphrag.experimental.pipeline.pipeline import Pipeline +from 
neo4j_graphrag.experimental.pipeline.notification import EventType +from neo4j_graphrag.experimental.pipeline.types.context import RunContext + + +# Define some example components with progress notifications +class OutputModel(DataModel): + result: int + + +class SlowAdder(Component): + """A component that slowly adds numbers and reports progress""" + def __init__(self, number: int) -> None: + self.number = number + + async def run_with_context(self, context_: RunContext, value: int) -> OutputModel: + # Simulate work with progress updates + for i in range(value): + await asyncio.sleep(0.5) # Simulate work + await context_.notify( + message=f"Added {i+1}/{value}", + data={"current": i+1, "total": value} + ) + return OutputModel(result=value + self.number) + + +class SlowMultiplier(Component): + """A component that slowly multiplies numbers and reports progress""" + def __init__(self, multiplier: int) -> None: + self.multiplier = multiplier + + async def run_with_context(self, context_: RunContext, value: int) -> OutputModel: + # Simulate work with progress updates + for i in range(3): # Always do 3 steps + await asyncio.sleep(0.7) # Simulate work + await context_.notify( + message=f"Multiplication step {i+1}/3", + data={"step": i+1, "total": 3} + ) + return OutputModel(result=value * self.multiplier) + + +async def main(): + # Create pipeline + pipeline = Pipeline() + + # Add components + pipeline.add_component(SlowAdder(number=3), "adder") + pipeline.add_component(SlowMultiplier(multiplier=2), "multiplier") + + # Connect components + pipeline.connect( + "adder", + "multiplier", + {"value": "adder.result"} + ) + + print("\n=== Running pipeline with streaming ===") + # Run pipeline with streaming - see events as they happen + async for event in pipeline.stream({"adder": {"value": 2}}): + if event.event_type == EventType.PIPELINE_STARTED: + print("Stream: Pipeline started!") + elif event.event_type == EventType.PIPELINE_FINISHED: + print(f"Stream: Pipeline finished! 
Final results: {event.payload}") + elif event.event_type == EventType.TASK_STARTED: + print(f"Stream: Task {event.task_name} started with inputs: {event.payload}") + elif event.event_type == EventType.TASK_PROGRESS: + print(f"Stream: Task {event.task_name} progress - {event.message}") + elif event.event_type == EventType.TASK_FINISHED: + print(f"Stream: Task {event.task_name} finished with result: {event.payload}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/neo4j_graphrag/experimental/pipeline/pipeline.py b/src/neo4j_graphrag/experimental/pipeline/pipeline.py index 64b74ddc6..5d10e0ecb 100644 --- a/src/neo4j_graphrag/experimental/pipeline/pipeline.py +++ b/src/neo4j_graphrag/experimental/pipeline/pipeline.py @@ -18,7 +18,8 @@ import warnings from collections import defaultdict from timeit import default_timer -from typing import Any, Optional +from typing import Any, Optional, AsyncGenerator, Callable, List +import asyncio from neo4j_graphrag.utils.logging import prettify @@ -47,7 +48,7 @@ ) from neo4j_graphrag.experimental.pipeline.types.orchestration import RunResult from neo4j_graphrag.experimental.pipeline.types.context import RunContext -from neo4j_graphrag.experimental.pipeline.notification import EventCallbackProtocol +from neo4j_graphrag.experimental.pipeline.notification import EventCallbackProtocol, Event logger = logging.getLogger(__name__) @@ -412,6 +413,64 @@ def validate_parameter_mapping_for_task(self, task: TaskPipelineNode) -> bool: async def get_final_results(self, run_id: str) -> dict[str, Any]: return await self.final_results.get(run_id) # type: ignore[no-any-return] + async def stream(self, data: dict[str, Any]) -> AsyncGenerator[Event, None]: + """Run the pipeline and stream events for task progress. 
+ + Args: + data: Input data for the pipeline components + + Yields: + Event: Pipeline and task events including start, progress, and completion + """ + # Create queue for events + event_queue: asyncio.Queue[Event] = asyncio.Queue() + + # Store original callback + original_callback = self.callback + + async def callback_and_event_stream(event: Event) -> None: + # Put event in queue for streaming + await event_queue.put(event) + # Call original callback if it exists + if original_callback: + await original_callback(event) + + # Set up event callback + self.callback = callback_and_event_stream + + try: + # Start pipeline execution in background task + run_task = asyncio.create_task(self.run(data)) + + while True: + # Wait for next event or pipeline completion + done, pending = await asyncio.wait( + [run_task, event_queue.get()], + return_when=asyncio.FIRST_COMPLETED + ) + + # Pipeline finished + if run_task in done: + if run_task.exception(): + raise run_task.exception() + # Drain any remaining events + while not event_queue.empty(): + yield await event_queue.get() + break + + # Got an event from queue + event_future = next(f for f in done if f != run_task) + try: + event = event_future.result() + yield event + except Exception as e: + logger.error(f"Error processing event: {e}") + raise + + finally: + # Restore original callback + self.callback = original_callback + async def run(self, data: dict[str, Any]) -> PipelineResult: logger.debug("PIPELINE START") start_time = default_timer() From 6a7c1685ed17f491557d0ff6d0f6d80d37f3d4f4 Mon Sep 17 00:00:00 2001 From: estelle Date: Mon, 17 Mar 2025 15:46:04 +0100 Subject: [PATCH 08/13] Add tests --- .../pipeline/pipeline_streaming.py | 40 ++++---- .../experimental/pipeline/notification.py | 11 ++- .../experimental/pipeline/orchestrator.py | 2 +- .../experimental/pipeline/pipeline.py | 84 ++++++++-------- .../unit/experimental/pipeline/components.py | 11 +++ .../experimental/pipeline/test_pipeline.py | 95 +++++++++++++++++++ 6 files changed, 179 insertions(+), 64 deletions(-) diff --git a/examples/customize/build_graph/pipeline/pipeline_streaming.py b/examples/customize/build_graph/pipeline/pipeline_streaming.py index 15e70d4a9..824e9a361 100644 --- a/examples/customize/build_graph/pipeline/pipeline_streaming.py +++ b/examples/customize/build_graph/pipeline/pipeline_streaming.py @@ -2,7 +2,7 @@ from neo4j_graphrag.experimental.pipeline.component import Component, DataModel from neo4j_graphrag.experimental.pipeline.pipeline import Pipeline -from neo4j_graphrag.experimental.pipeline.notification import EventType +from neo4j_graphrag.experimental.pipeline.notification import EventType, Event from neo4j_graphrag.experimental.pipeline.types.context import RunContext @@ -13,6 +13,7 @@ class OutputModel(DataModel): class SlowAdder(Component): """A component that slowly adds numbers and reports progress""" + def __init__(self, number: int) -> None: self.number = number @@ -21,14 +22,14 @@ async def run_with_context(self, context_: RunContext, value: int) -> OutputMode for i in range(value): await asyncio.sleep(0.5) # Simulate work await context_.notify( - message=f"Added {i+1}/{value}", - data={"current": i+1, "total": value} + message=f"Added {i+1}/{value}", data={"current": i + 1, "total": value} ) return OutputModel(result=value + self.number) class SlowMultiplier(Component): """A component that slowly multiplies numbers and reports progress""" + def __init__(self, multiplier: int) -> None: self.multiplier = multiplier @@ -37,26 +38,25 @@ async def 
run_with_context(self, context_: RunContext, value: int) -> OutputMode for i in range(3): # Always do 3 steps await asyncio.sleep(0.7) # Simulate work await context_.notify( - message=f"Multiplication step {i+1}/3", - data={"step": i+1, "total": 3} + message=f"Multiplication step {i+1}/3", data={"step": i + 1, "total": 3} ) return OutputModel(result=value * self.multiplier) -async def main(): +async def callback(event: Event) -> None: + await asyncio.sleep(0.1) + + +async def main() -> None: # Create pipeline - pipeline = Pipeline() - + pipeline = Pipeline(callback=callback) + # Add components pipeline.add_component(SlowAdder(number=3), "adder") pipeline.add_component(SlowMultiplier(multiplier=2), "multiplier") - + # Connect components - pipeline.connect( - "adder", - "multiplier", - {"value": "adder.result"} - ) + pipeline.connect("adder", "multiplier", {"value": "adder.result"}) print("\n=== Running pipeline with streaming ===") # Run pipeline with streaming - see events as they happen @@ -66,12 +66,16 @@ async def main(): elif event.event_type == EventType.PIPELINE_FINISHED: print(f"Stream: Pipeline finished! Final results: {event.payload}") elif event.event_type == EventType.TASK_STARTED: - print(f"Stream: Task {event.task_name} started with inputs: {event.payload}") + print( + f"Stream: Task {event.task_name} started with inputs: {event.payload}" # type: ignore + ) elif event.event_type == EventType.TASK_PROGRESS: - print(f"Stream: Task {event.task_name} progress - {event.message}") + print(f"Stream: Task {event.task_name} progress - {event.message}") # type: ignore elif event.event_type == EventType.TASK_FINISHED: - print(f"Stream: Task {event.task_name} finished with result: {event.payload}") + print( + f"Stream: Task {event.task_name} finished with result: {event.payload}" # type: ignore + ) if __name__ == "__main__": - asyncio.run(main()) + asyncio.run(main()) diff --git a/src/neo4j_graphrag/experimental/pipeline/notification.py b/src/neo4j_graphrag/experimental/pipeline/notification.py index e9cb63cc6..0062e4690 100644 --- a/src/neo4j_graphrag/experimental/pipeline/notification.py +++ b/src/neo4j_graphrag/experimental/pipeline/notification.py @@ -14,6 +14,7 @@ # limitations under the License. from __future__ import annotations +import asyncio import datetime import enum from collections.abc import Awaitable @@ -72,12 +73,14 @@ def __call__(self, event: Event) -> Awaitable[None]: ... 
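# EventNotifier fans each event out to every registered callback. Awaiting
# them with asyncio.gather(..., return_exceptions=True) means an exception in
# one callback (e.g. a user-supplied one) is captured instead of propagated,
# so the remaining callbacks - including the queue writer that Pipeline.stream
# registers - are still notified.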
class EventNotifier: - def __init__(self, callback: EventCallbackProtocol | None) -> None: - self.callback = callback + def __init__(self, callbacks: list[EventCallbackProtocol]) -> None: + self.callbacks = callbacks async def notify(self, event: Event) -> None: - if self.callback: - await self.callback(event) + await asyncio.gather( + *[c(event) for c in self.callbacks], + return_exceptions=True, + ) async def notify_pipeline_started( self, run_id: str, input_data: Optional[dict[str, Any]] = None diff --git a/src/neo4j_graphrag/experimental/pipeline/orchestrator.py b/src/neo4j_graphrag/experimental/pipeline/orchestrator.py index fd933470a..de9468bf7 100644 --- a/src/neo4j_graphrag/experimental/pipeline/orchestrator.py +++ b/src/neo4j_graphrag/experimental/pipeline/orchestrator.py @@ -54,7 +54,7 @@ class Orchestrator: def __init__(self, pipeline: Pipeline): self.pipeline = pipeline - self.event_notifier = EventNotifier(pipeline.callback) + self.event_notifier = EventNotifier(pipeline.callbacks) self.run_id = str(uuid.uuid4()) async def run_task(self, task: TaskPipelineNode, data: dict[str, Any]) -> None: diff --git a/src/neo4j_graphrag/experimental/pipeline/pipeline.py b/src/neo4j_graphrag/experimental/pipeline/pipeline.py index 5d10e0ecb..0e8d0f6e5 100644 --- a/src/neo4j_graphrag/experimental/pipeline/pipeline.py +++ b/src/neo4j_graphrag/experimental/pipeline/pipeline.py @@ -18,7 +18,7 @@ import warnings from collections import defaultdict from timeit import default_timer -from typing import Any, Optional, AsyncGenerator, Callable, List +from typing import Any, Optional, AsyncGenerator import asyncio from neo4j_graphrag.utils.logging import prettify @@ -48,7 +48,10 @@ ) from neo4j_graphrag.experimental.pipeline.types.orchestration import RunResult from neo4j_graphrag.experimental.pipeline.types.context import RunContext -from neo4j_graphrag.experimental.pipeline.notification import EventCallbackProtocol, Event +from neo4j_graphrag.experimental.pipeline.notification import ( + EventCallbackProtocol, + Event, +) logger = logging.getLogger(__name__) @@ -118,7 +121,7 @@ def __init__( ) -> None: super().__init__() self.store = store or InMemoryStore() - self.callback = callback + self.callbacks = [callback] if callback else [] self.final_results = InMemoryStore() self.is_validated = False self.param_mapping: dict[str, dict[str, dict[str, str]]] = defaultdict(dict) @@ -415,61 +418,60 @@ async def get_final_results(self, run_id: str) -> dict[str, Any]: async def stream(self, data: dict[str, Any]) -> AsyncGenerator[Event, None]: """Run the pipeline and stream events for task progress. 
- + Args: data: Input data for the pipeline components - + Yields: Event: Pipeline and task events including start, progress, and completion """ # Create queue for events event_queue: asyncio.Queue[Event] = asyncio.Queue() - - # Store original callback - original_callback = self.callback - - async def callback_and_event_stream(event: Event) -> None: + + async def event_stream(event: Event) -> None: # Put event in queue for streaming await event_queue.put(event) - # Call original callback if it exists - if original_callback: - await original_callback(event) - - # Set up event callback - self.callback = callback_and_event_stream - + + # Add event streaming callback + self.callbacks.append(event_stream) + try: # Start pipeline execution in background task run_task = asyncio.create_task(self.run(data)) - - while True: + + # loop until the run task is done, and we do not have + # any more pending tasks in queue + is_run_task_running = True + is_queue_empty = False + while is_run_task_running or not is_queue_empty: # Wait for next event or pipeline completion + event_queue_getter_task = asyncio.create_task(event_queue.get()) done, pending = await asyncio.wait( - [run_task, event_queue.get()], - return_when=asyncio.FIRST_COMPLETED + [run_task, event_queue_getter_task], + return_when=asyncio.FIRST_COMPLETED, ) - - # Pipeline finished - if run_task in done: - if run_task.exception(): - raise run_task.exception() - # Drain any remaining events - while not event_queue.empty(): - yield await event_queue.get() - break - - # Got an event from queue - event_future = next(f for f in done if f != run_task) - try: - event = event_future.result() - yield event - except Exception as e: - logger.error(f"Error processing event: {e}") - raise - + + is_run_task_running = run_task not in done + is_queue_empty = event_queue.empty() + + for event_future in done: + if event_future == run_task: + continue + yield event_future.result() # type: ignore + + # cancel remaining task + event_queue_getter_task.cancel() + + # # Drain any remaining events + # while not event_queue.empty(): + # yield await event_queue.get() + # Pipeline finished + if run_task.exception(): + raise run_task.exception() # type: ignore + finally: # Restore original callback - self.callback = original_callback + self.callbacks.remove(event_stream) async def run(self, data: dict[str, Any]) -> PipelineResult: logger.debug("PIPELINE START") diff --git a/tests/unit/experimental/pipeline/components.py b/tests/unit/experimental/pipeline/components.py index fc33e840d..e1b2e1845 100644 --- a/tests/unit/experimental/pipeline/components.py +++ b/tests/unit/experimental/pipeline/components.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import asyncio + from neo4j_graphrag.experimental.pipeline import Component, DataModel from neo4j_graphrag.experimental.pipeline.types.context import RunContext @@ -52,3 +54,12 @@ async def run_with_context( message="my message", data={"number1": number1, "number2": number2} ) return IntResultModel(result=number1 * number2) + + +class SlowComponentMultiply(Component): + def __init__(self, sleep: float = 1.0) -> None: + self.sleep = sleep + + async def run(self, number1: int, number2: int = 2) -> IntResultModel: + await asyncio.sleep(self.sleep) + return IntResultModel(result=number1 * number2) diff --git a/tests/unit/experimental/pipeline/test_pipeline.py b/tests/unit/experimental/pipeline/test_pipeline.py index 16d32ba9b..ef7e81227 100644 --- a/tests/unit/experimental/pipeline/test_pipeline.py +++ b/tests/unit/experimental/pipeline/test_pipeline.py @@ -39,6 +39,7 @@ ComponentNoParam, ComponentPassThrough, StringResultModel, + SlowComponentMultiply, ) @@ -491,3 +492,97 @@ def test_event_model_no_warning(recwarn: Sized) -> None: ) assert event.timestamp is not None assert len(recwarn) == 0 + + +@pytest.mark.asyncio +async def test_pipeline_streaming_no_user_callback_happy_path() -> None: + pipe = Pipeline() + events = [] + async for e in pipe.stream({}): + events.append(e) + assert len(events) == 2 + assert events[0].event_type == EventType.PIPELINE_STARTED + assert events[1].event_type == EventType.PIPELINE_FINISHED + assert len(pipe.callbacks) == 0 + + +@pytest.mark.asyncio +async def test_pipeline_streaming_with_user_callback_happy_path() -> None: + callback = AsyncMock() + pipe = Pipeline(callback=callback) + events = [] + async for e in pipe.stream({}): + events.append(e) + assert len(events) == 2 + assert len(callback.call_args_list) == 2 + assert len(pipe.callbacks) == 1 + + +@pytest.mark.asyncio +async def test_pipeline_streaming_very_long_running_user_callback() -> None: + async def callback(event: Event) -> None: + await asyncio.sleep(2) + + pipe = Pipeline(callback=callback) + events = [] + async for e in pipe.stream({}): + events.append(e) + assert len(events) == 2 + assert len(pipe.callbacks) == 1 + + +@pytest.mark.asyncio +async def test_pipeline_streaming_very_long_running_pipeline() -> None: + slow_component = SlowComponentMultiply() + pipe = Pipeline() + pipe.add_component(slow_component, "slow_component") + events = [] + async for e in pipe.stream({"slow_component": {"number1": 1, "number2": 2}}): + events.append(e) + assert len(events) == 4 + last_event = events[-1] + assert last_event.event_type == EventType.PIPELINE_FINISHED + assert last_event.payload == {"slow_component": {"result": 2}} + + +@pytest.mark.asyncio +async def test_pipeline_streaming_error_in_pipeline_definition() -> None: + pipe = Pipeline() + component_a = ComponentAdd() + component_b = ComponentAdd() + pipe.add_component(component_a, "a") + pipe.add_component(component_b, "b") + pipe.connect("a", "b", {"number1": "a.result"}) + events = [] + with pytest.raises(PipelineDefinitionError): + async for e in pipe.stream({"a": {"number1": 1, "number2": 2}}): + events.append(e) + # validation happens before pipeline run actually starts + assert len(events) == 0 + + +@pytest.mark.asyncio +async def test_pipeline_streaming_error_in_component() -> None: + component = ComponentMultiply() + pipe = Pipeline() + pipe.add_component(component, "component") + events = [] + with pytest.raises(TypeError): + async for e in pipe.stream({"component": {"number1": None, "number2": 2}}): + events.append(e) + assert len(events) 
== 2 + assert events[0].event_type == EventType.PIPELINE_STARTED + assert events[1].event_type == EventType.TASK_STARTED + + +@pytest.mark.asyncio +async def test_pipeline_streaming_error_in_user_callback() -> None: + async def callback(event: Event) -> None: + raise Exception("error in callback") + + pipe = Pipeline(callback=callback) + events = [] + async for e in pipe.stream({}): + events.append(e) + assert len(events) == 2 + assert len(pipe.callbacks) == 1 From bc70e56cc79c6ddebf998f5ffc2d3431c64a3e08 Mon Sep 17 00:00:00 2001 From: estelle Date: Thu, 20 Mar 2025 12:08:33 +0100 Subject: [PATCH 09/13] Fix "Task destroyed but is pending" warning --- src/neo4j_graphrag/experimental/pipeline/pipeline.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/neo4j_graphrag/experimental/pipeline/pipeline.py b/src/neo4j_graphrag/experimental/pipeline/pipeline.py index 0e8d0f6e5..2a83aac94 100644 --- a/src/neo4j_graphrag/experimental/pipeline/pipeline.py +++ b/src/neo4j_graphrag/experimental/pipeline/pipeline.py @@ -435,6 +435,7 @@ async def event_stream(event: Event) -> None: # Add event streaming callback self.callbacks.append(event_stream) + event_queue_getter_task = None try: # Start pipeline execution in background task run_task = asyncio.create_task(self.run(data)) @@ -459,19 +460,14 @@ async def event_stream(event: Event) -> None: continue yield event_future.result() # type: ignore - # cancel remaining task - event_queue_getter_task.cancel() - - # # Drain any remaining events - # while not event_queue.empty(): - # yield await event_queue.get() - # Pipeline finished if run_task.exception(): raise run_task.exception() # type: ignore finally: # Restore original callback self.callbacks.remove(event_stream) + if event_queue_getter_task and not event_queue_getter_task.done(): + event_queue_getter_task.cancel() async def run(self, data: dict[str, Any]) -> PipelineResult: logger.debug("PIPELINE START") From 8cafb31950be2cea9d2cfa143a05d3c59bba1d3c Mon Sep 17 00:00:00 2001 From: estelle Date: Thu, 20 Mar 2025 12:15:38 +0100 Subject: [PATCH 10/13] Fix rebase --- src/neo4j_graphrag/experimental/pipeline/component.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/neo4j_graphrag/experimental/pipeline/component.py b/src/neo4j_graphrag/experimental/pipeline/component.py index ece8cfc58..39a2816ef 100644 --- a/src/neo4j_graphrag/experimental/pipeline/component.py +++ b/src/neo4j_graphrag/experimental/pipeline/component.py @@ -40,7 +40,7 @@ def __new__( run = run_context_method if run_context_method is not None else run_method if run is None: raise RuntimeError( - f"Either 'run' or 'run_with_context' must be implemented in component: '{name}'" + f"You must implement either `run` or `run_with_context` in Component '{name}'" ) sig = inspect.signature(run) attrs["component_inputs"] = { From 2645692188da7c83b6ac69a408d33927976c6e4c Mon Sep 17 00:00:00 2001 From: estelle Date: Fri, 28 Mar 2025 13:30:55 +0100 Subject: [PATCH 11/13] Update CHANGELOG --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c3e0ac312..603a9a10f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ ### Added - Added support for multi-vector collection in Qdrant driver. +- Added a `Pipeline.stream` method to stream pipeline progress. 
### Changed From 6eebfab9b93b4333c495d7eb2b13d550db4a128d Mon Sep 17 00:00:00 2001 From: estelle Date: Wed, 2 Apr 2025 12:17:48 +0200 Subject: [PATCH 12/13] Add PIPELINE_FAILED event and option not to raise exception --- .../pipeline/pipeline_streaming.py | 7 ++++- .../experimental/pipeline/notification.py | 7 ++++- .../experimental/pipeline/pipeline.py | 29 +++++++++++++++---- .../experimental/pipeline/test_pipeline.py | 8 +++-- 4 files changed, 41 insertions(+), 10 deletions(-) diff --git a/examples/customize/build_graph/pipeline/pipeline_streaming.py b/examples/customize/build_graph/pipeline/pipeline_streaming.py index 824e9a361..13a262a9c 100644 --- a/examples/customize/build_graph/pipeline/pipeline_streaming.py +++ b/examples/customize/build_graph/pipeline/pipeline_streaming.py @@ -60,11 +60,16 @@ async def main() -> None: print("\n=== Running pipeline with streaming ===") # Run pipeline with streaming - see events as they happen - async for event in pipeline.stream({"adder": {"value": 2}}): + async for event in pipeline.stream( + {"adder": {"value": 2}}, + raise_exception=False, # default is True + ): if event.event_type == EventType.PIPELINE_STARTED: print("Stream: Pipeline started!") elif event.event_type == EventType.PIPELINE_FINISHED: print(f"Stream: Pipeline finished! Final results: {event.payload}") + elif event.event_type == EventType.PIPELINE_FAILED: + print(f"Stream: Pipeline failed with message: {event.message}") elif event.event_type == EventType.TASK_STARTED: print( f"Stream: Task {event.task_name} started with inputs: {event.payload}" # type: ignore diff --git a/src/neo4j_graphrag/experimental/pipeline/notification.py b/src/neo4j_graphrag/experimental/pipeline/notification.py index 0062e4690..cc69a69e2 100644 --- a/src/neo4j_graphrag/experimental/pipeline/notification.py +++ b/src/neo4j_graphrag/experimental/pipeline/notification.py @@ -32,10 +32,15 @@ class EventType(enum.Enum): TASK_PROGRESS = "TASK_PROGRESS" TASK_FINISHED = "TASK_FINISHED" PIPELINE_FINISHED = "PIPELINE_FINISHED" + PIPELINE_FAILED = "PIPELINE_FAILED" @property def is_pipeline_event(self) -> bool: - return self in [EventType.PIPELINE_STARTED, EventType.PIPELINE_FINISHED] + return self in [ + EventType.PIPELINE_STARTED, + EventType.PIPELINE_FINISHED, + EventType.PIPELINE_FAILED, + ] @property def is_task_event(self) -> bool: diff --git a/src/neo4j_graphrag/experimental/pipeline/pipeline.py b/src/neo4j_graphrag/experimental/pipeline/pipeline.py index 2a83aac94..bfec89290 100644 --- a/src/neo4j_graphrag/experimental/pipeline/pipeline.py +++ b/src/neo4j_graphrag/experimental/pipeline/pipeline.py @@ -51,6 +51,8 @@ from neo4j_graphrag.experimental.pipeline.notification import ( EventCallbackProtocol, Event, + PipelineEvent, + EventType, ) @@ -416,17 +418,22 @@ def validate_parameter_mapping_for_task(self, task: TaskPipelineNode) -> bool: async def get_final_results(self, run_id: str) -> dict[str, Any]: return await self.final_results.get(run_id) # type: ignore[no-any-return] - async def stream(self, data: dict[str, Any]) -> AsyncGenerator[Event, None]: + async def stream( + self, data: dict[str, Any], raise_exception: bool = True + ) -> AsyncGenerator[Event, None]: """Run the pipeline and stream events for task progress. Args: - data: Input data for the pipeline components + data (dict): Input data for the pipeline components + raise_exception (bool): set to False to prevent this task from propagating + Pipeline exceptions. 
Yields: Event: Pipeline and task events including start, progress, and completion """ # Create queue for events event_queue: asyncio.Queue[Event] = asyncio.Queue() + run_id = None async def event_stream(event: Event) -> None: # Put event in queue for streaming @@ -458,10 +465,20 @@ async def event_stream(event: Event) -> None: for event_future in done: if event_future == run_task: continue - yield event_future.result() # type: ignore - - if run_task.exception(): - raise run_task.exception() # type: ignore + event = event_future.result() + run_id = event.run_id + yield event # type: ignore + + if exc := run_task.exception(): + yield PipelineEvent( + event_type=EventType.PIPELINE_FAILED, + # run_id is null if pipeline fails before even starting + # ie during pipeline validation + run_id=run_id or "", + message=str(exc), + ) + if raise_exception: + raise exc # type: ignore finally: # Restore original callback diff --git a/tests/unit/experimental/pipeline/test_pipeline.py b/tests/unit/experimental/pipeline/test_pipeline.py index ef7e81227..868aaae86 100644 --- a/tests/unit/experimental/pipeline/test_pipeline.py +++ b/tests/unit/experimental/pipeline/test_pipeline.py @@ -558,7 +558,10 @@ async def test_pipeline_streaming_error_in_pipeline_definition() -> None: async for e in pipe.stream({"a": {"number1": 1, "number2": 2}}): events.append(e) # validation happens before pipeline run actually starts - assert len(events) == 0 + # but we have the PIPELINE_FAILED event + assert len(events) == 1 + assert events[0].event_type == EventType.PIPELINE_FAILED + assert events[0].run_id == "" @pytest.mark.asyncio @@ -570,9 +573,10 @@ async def test_pipeline_streaming_error_in_component() -> None: with pytest.raises(TypeError): async for e in pipe.stream({"component": {"number1": None, "number2": 2}}): events.append(e) - assert len(events) == 2 + assert len(events) == 3 assert events[0].event_type == EventType.PIPELINE_STARTED assert events[1].event_type == EventType.TASK_STARTED + assert events[2].event_type == EventType.PIPELINE_FAILED @pytest.mark.asyncio From 81aebf7008f636fa4ae11ce2ca0e08744315c6e6 Mon Sep 17 00:00:00 2001 From: estelle Date: Wed, 2 Apr 2025 13:53:14 +0200 Subject: [PATCH 13/13] Mypy --- src/neo4j_graphrag/experimental/pipeline/pipeline.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/neo4j_graphrag/experimental/pipeline/pipeline.py b/src/neo4j_graphrag/experimental/pipeline/pipeline.py index bfec89290..4d12d8d81 100644 --- a/src/neo4j_graphrag/experimental/pipeline/pipeline.py +++ b/src/neo4j_graphrag/experimental/pipeline/pipeline.py @@ -465,8 +465,10 @@ async def event_stream(event: Event) -> None: for event_future in done: if event_future == run_task: continue + # we are sure to get an Event here, since this is the only + # thing we put in the queue, but mypy still complains event = event_future.result() - run_id = event.run_id + run_id = getattr(event, "run_id", None) yield event # type: ignore if exc := run_task.exception(): @@ -478,7 +480,7 @@ async def event_stream(event: Event) -> None: message=str(exc), ) if raise_exception: - raise exc # type: ignore + raise exc finally: # Restore original callback
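
To close, a minimal end-to-end sketch of the feature this series adds. It is illustrative only and not part of the patches above: it assumes the module paths used throughout the series, the `stream(..., raise_exception=...)` signature introduced in patch 12, and an `Echo` component invented purely for this example.

import asyncio

from neo4j_graphrag.experimental.pipeline.component import Component, DataModel
from neo4j_graphrag.experimental.pipeline.notification import EventType
from neo4j_graphrag.experimental.pipeline.pipeline import Pipeline


class EchoResult(DataModel):
    result: int


class Echo(Component):
    # Hypothetical component for this sketch. A plain `run` still suffices;
    # `run_with_context` is only needed to emit TASK_PROGRESS notifications.
    async def run(self, value: int) -> EchoResult:
        return EchoResult(result=value)


async def main() -> None:
    pipeline = Pipeline()  # no up-front callback is needed in order to stream
    pipeline.add_component(Echo(), "echo")

    # With raise_exception=False, failures surface as a PIPELINE_FAILED event
    # instead of being re-raised by the generator (patch 12).
    async for event in pipeline.stream({"echo": {"value": 1}}, raise_exception=False):
        if event.event_type == EventType.PIPELINE_FAILED:
            print(f"pipeline failed: {event.message}")
        elif event.event_type == EventType.PIPELINE_FINISHED:
            print(f"final results: {event.payload}")


if __name__ == "__main__":
    asyncio.run(main())

Because `stream` appends its queue-writing callback to `pipeline.callbacks` only for the duration of the run and removes it in the `finally` block, the same `Pipeline` instance can be reused afterwards with its original callbacks intact.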