diff --git a/CHANGELOG.md b/CHANGELOG.md
index c3e0ac312..603a9a10f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,7 @@
 ### Added
 
 - Added support for multi-vector collection in Qdrant driver.
+- Added a `Pipeline.stream` method to stream pipeline progress.
 
 ### Changed
 
diff --git a/docs/source/user_guide_pipeline.rst b/docs/source/user_guide_pipeline.rst
index 5550d1ea9..703e48aba 100644
--- a/docs/source/user_guide_pipeline.rst
+++ b/docs/source/user_guide_pipeline.rst
@@ -22,7 +22,7 @@ their own by following these steps:
 
 1. Create a subclass of the Pydantic `neo4j_graphrag.experimental.pipeline.DataModel` to represent the data being returned by the component
 2. Create a subclass of `neo4j_graphrag.experimental.pipeline.Component`
-3. Create a `run_with_context` method in this new class and specify the required inputs and output model using the just created `DataModel`
+3. Create a `run` method in this new class and specify the required inputs and output model using the just created `DataModel`
 4. Implement the run method: it's an `async` method, allowing tasks to be parallelized and awaited within this method.
 
 An example is given below, where a `ComponentAdd` is created to add two numbers together and return
@@ -31,13 +31,12 @@ the resulting sum:
 .. code:: python
 
     from neo4j_graphrag.experimental.pipeline import Component, DataModel
-    from neo4j_graphrag.experimental.pipeline.types.context import RunContext
 
     class IntResultModel(DataModel):
         result: int
 
     class ComponentAdd(Component):
-        async def run_with_context(self, context_: RunContext, number1: int, number2: int = 1) -> IntResultModel:
+        async def run(self, number1: int, number2: int = 1) -> IntResultModel:
             return IntResultModel(result = number1 + number2)
 
 Read more about :ref:`components-section` in the API Documentation.
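For reviewers who want to sanity-check the updated guide, the documented steps run end to end as below. This is a minimal sketch assuming the import paths used elsewhere in this diff (`Pipeline` from `neo4j_graphrag.experimental.pipeline.pipeline`); the printed result shape is illustrative only.

```python
import asyncio

from neo4j_graphrag.experimental.pipeline import Component, DataModel
from neo4j_graphrag.experimental.pipeline.pipeline import Pipeline


class IntResultModel(DataModel):
    result: int


class ComponentAdd(Component):
    # Steps 3-4 from the guide: an async `run` method returning the DataModel
    async def run(self, number1: int, number2: int = 1) -> IntResultModel:
        return IntResultModel(result=number1 + number2)


async def main() -> None:
    pipeline = Pipeline()
    pipeline.add_component(ComponentAdd(), "adder")
    # Inputs are keyed by component name, then by run() parameter name
    result = await pipeline.run({"adder": {"number1": 1, "number2": 2}})
    print(result)  # final results should include {"adder": {"result": 3}}


if __name__ == "__main__":
    asyncio.run(main())
```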
diff --git a/examples/customize/build_graph/pipeline/pipeline_streaming.py b/examples/customize/build_graph/pipeline/pipeline_streaming.py
new file mode 100644
index 000000000..13a262a9c
--- /dev/null
+++ b/examples/customize/build_graph/pipeline/pipeline_streaming.py
@@ -0,0 +1,86 @@
+import asyncio
+
+from neo4j_graphrag.experimental.pipeline.component import Component, DataModel
+from neo4j_graphrag.experimental.pipeline.pipeline import Pipeline
+from neo4j_graphrag.experimental.pipeline.notification import EventType, Event
+from neo4j_graphrag.experimental.pipeline.types.context import RunContext
+
+
+# Define some example components with progress notifications
+class OutputModel(DataModel):
+    result: int
+
+
+class SlowAdder(Component):
+    """A component that slowly adds numbers and reports progress"""
+
+    def __init__(self, number: int) -> None:
+        self.number = number
+
+    async def run_with_context(self, context_: RunContext, value: int) -> OutputModel:
+        # Simulate work with progress updates
+        for i in range(value):
+            await asyncio.sleep(0.5)  # Simulate work
+            await context_.notify(
+                message=f"Added {i+1}/{value}", data={"current": i + 1, "total": value}
+            )
+        return OutputModel(result=value + self.number)
+
+
+class SlowMultiplier(Component):
+    """A component that slowly multiplies numbers and reports progress"""
+
+    def __init__(self, multiplier: int) -> None:
+        self.multiplier = multiplier
+
+    async def run_with_context(self, context_: RunContext, value: int) -> OutputModel:
+        # Simulate work with progress updates
+        for i in range(3):  # Always do 3 steps
+            await asyncio.sleep(0.7)  # Simulate work
+            await context_.notify(
+                message=f"Multiplication step {i+1}/3", data={"step": i + 1, "total": 3}
+            )
+        return OutputModel(result=value * self.multiplier)
+
+
+async def callback(event: Event) -> None:
+    await asyncio.sleep(0.1)
+
+
+async def main() -> None:
+    # Create pipeline
+    pipeline = Pipeline(callback=callback)
+
+    # Add components
+    pipeline.add_component(SlowAdder(number=3), "adder")
+    pipeline.add_component(SlowMultiplier(multiplier=2), "multiplier")
+
+    # Connect components
+    pipeline.connect("adder", "multiplier", {"value": "adder.result"})
+
+    print("\n=== Running pipeline with streaming ===")
+    # Run pipeline with streaming - see events as they happen
+    async for event in pipeline.stream(
+        {"adder": {"value": 2}},
+        raise_exception=False,  # default is True
+    ):
+        if event.event_type == EventType.PIPELINE_STARTED:
+            print("Stream: Pipeline started!")
+        elif event.event_type == EventType.PIPELINE_FINISHED:
+            print(f"Stream: Pipeline finished! Final results: {event.payload}")
+        elif event.event_type == EventType.PIPELINE_FAILED:
+            print(f"Stream: Pipeline failed with message: {event.message}")
+        elif event.event_type == EventType.TASK_STARTED:
+            print(
+                f"Stream: Task {event.task_name} started with inputs: {event.payload}"  # type: ignore
+            )
+        elif event.event_type == EventType.TASK_PROGRESS:
+            print(f"Stream: Task {event.task_name} progress - {event.message}")  # type: ignore
+        elif event.event_type == EventType.TASK_FINISHED:
+            print(
+                f"Stream: Task {event.task_name} finished with result: {event.payload}"  # type: ignore
+            )
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/src/neo4j_graphrag/experimental/pipeline/notification.py b/src/neo4j_graphrag/experimental/pipeline/notification.py
index e9cb63cc6..cc69a69e2 100644
--- a/src/neo4j_graphrag/experimental/pipeline/notification.py
+++ b/src/neo4j_graphrag/experimental/pipeline/notification.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 from __future__ import annotations
 
+import asyncio
 import datetime
 import enum
 from collections.abc import Awaitable
@@ -31,10 +32,15 @@ class EventType(enum.Enum):
     TASK_PROGRESS = "TASK_PROGRESS"
     TASK_FINISHED = "TASK_FINISHED"
     PIPELINE_FINISHED = "PIPELINE_FINISHED"
+    PIPELINE_FAILED = "PIPELINE_FAILED"
 
     @property
     def is_pipeline_event(self) -> bool:
-        return self in [EventType.PIPELINE_STARTED, EventType.PIPELINE_FINISHED]
+        return self in [
+            EventType.PIPELINE_STARTED,
+            EventType.PIPELINE_FINISHED,
+            EventType.PIPELINE_FAILED,
+        ]
 
     @property
     def is_task_event(self) -> bool:
@@ -72,12 +78,14 @@ def __call__(self, event: Event) -> Awaitable[None]: ...
 
 
 class EventNotifier:
-    def __init__(self, callback: EventCallbackProtocol | None) -> None:
-        self.callback = callback
+    def __init__(self, callbacks: list[EventCallbackProtocol]) -> None:
+        self.callbacks = callbacks
 
     async def notify(self, event: Event) -> None:
-        if self.callback:
-            await self.callback(event)
+        await asyncio.gather(
+            *[c(event) for c in self.callbacks],
+            return_exceptions=True,
+        )
 
     async def notify_pipeline_started(
         self, run_id: str, input_data: Optional[dict[str, Any]] = None
diff --git a/src/neo4j_graphrag/experimental/pipeline/orchestrator.py b/src/neo4j_graphrag/experimental/pipeline/orchestrator.py
index fd933470a..de9468bf7 100644
--- a/src/neo4j_graphrag/experimental/pipeline/orchestrator.py
+++ b/src/neo4j_graphrag/experimental/pipeline/orchestrator.py
@@ -54,7 +54,7 @@ class Orchestrator:
 
     def __init__(self, pipeline: Pipeline):
         self.pipeline = pipeline
-        self.event_notifier = EventNotifier(pipeline.callback)
+        self.event_notifier = EventNotifier(pipeline.callbacks)
         self.run_id = str(uuid.uuid4())
 
     async def run_task(self, task: TaskPipelineNode, data: dict[str, Any]) -> None:
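A note on the `EventNotifier` change above: `notify` now fans out to all registered callbacks concurrently via `asyncio.gather(..., return_exceptions=True)`, so one slow or failing callback no longer blocks or breaks the others (the `error_in_user_callback` test below depends on this). A standalone sketch of that semantics, with hypothetical callback names and no library dependency:

```python
import asyncio
from typing import Any, Awaitable, Callable


async def noisy(event: Any) -> None:
    raise RuntimeError("error in callback")


async def quiet(event: Any) -> None:
    print(f"received: {event}")


async def notify_all(
    callbacks: list[Callable[[Any], Awaitable[None]]], event: Any
) -> None:
    # return_exceptions=True: failures come back as values instead of
    # propagating and cancelling the sibling callbacks
    results = await asyncio.gather(
        *[c(event) for c in callbacks], return_exceptions=True
    )
    for r in results:
        if isinstance(r, Exception):
            print(f"callback failed: {r!r}")


asyncio.run(notify_all([noisy, quiet], {"event_type": "TASK_PROGRESS"}))
```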
diff --git a/src/neo4j_graphrag/experimental/pipeline/pipeline.py b/src/neo4j_graphrag/experimental/pipeline/pipeline.py
index 64b74ddc6..4d12d8d81 100644
--- a/src/neo4j_graphrag/experimental/pipeline/pipeline.py
+++ b/src/neo4j_graphrag/experimental/pipeline/pipeline.py
@@ -18,7 +18,8 @@
 import warnings
 from collections import defaultdict
 from timeit import default_timer
-from typing import Any, Optional
+from typing import Any, Optional, AsyncGenerator
+import asyncio
 
 from neo4j_graphrag.utils.logging import prettify
 
@@ -47,7 +48,12 @@
 )
 from neo4j_graphrag.experimental.pipeline.types.orchestration import RunResult
 from neo4j_graphrag.experimental.pipeline.types.context import RunContext
-from neo4j_graphrag.experimental.pipeline.notification import EventCallbackProtocol
+from neo4j_graphrag.experimental.pipeline.notification import (
+    EventCallbackProtocol,
+    Event,
+    PipelineEvent,
+    EventType,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -117,7 +123,7 @@ def __init__(
     ) -> None:
         super().__init__()
         self.store = store or InMemoryStore()
-        self.callback = callback
+        self.callbacks = [callback] if callback else []
         self.final_results = InMemoryStore()
         self.is_validated = False
         self.param_mapping: dict[str, dict[str, dict[str, str]]] = defaultdict(dict)
@@ -412,6 +418,76 @@ def validate_parameter_mapping_for_task(self, task: TaskPipelineNode) -> bool:
     async def get_final_results(self, run_id: str) -> dict[str, Any]:
         return await self.final_results.get(run_id)  # type: ignore[no-any-return]
 
+    async def stream(
+        self, data: dict[str, Any], raise_exception: bool = True
+    ) -> AsyncGenerator[Event, None]:
+        """Run the pipeline and stream events for task progress.
+
+        Args:
+            data (dict): Input data for the pipeline components
+            raise_exception (bool): set to False to prevent this task from propagating
+                Pipeline exceptions.
+
+        Yields:
+            Event: Pipeline and task events including start, progress, and completion
+        """
+        # Create queue for events
+        event_queue: asyncio.Queue[Event] = asyncio.Queue()
+        run_id = None
+
+        async def event_stream(event: Event) -> None:
+            # Put event in queue for streaming
+            await event_queue.put(event)
+
+        # Add event streaming callback
+        self.callbacks.append(event_stream)
+
+        event_queue_getter_task = None
+        try:
+            # Start pipeline execution in background task
+            run_task = asyncio.create_task(self.run(data))
+
+            # loop until the run task is done, and we do not have
+            # any more pending tasks in queue
+            is_run_task_running = True
+            is_queue_empty = False
+            while is_run_task_running or not is_queue_empty:
+                # Wait for next event or pipeline completion
+                event_queue_getter_task = asyncio.create_task(event_queue.get())
+                done, pending = await asyncio.wait(
+                    [run_task, event_queue_getter_task],
+                    return_when=asyncio.FIRST_COMPLETED,
+                )
+
+                is_run_task_running = run_task not in done
+                is_queue_empty = event_queue.empty()
+
+                for event_future in done:
+                    if event_future == run_task:
+                        continue
+                    # we are sure to get an Event here, since this is the only
+                    # thing we put in the queue, but mypy still complains
+                    event = event_future.result()
+                    run_id = getattr(event, "run_id", None)
+                    yield event  # type: ignore
+
+            if exc := run_task.exception():
+                yield PipelineEvent(
+                    event_type=EventType.PIPELINE_FAILED,
+                    # run_id is null if pipeline fails before even starting
+                    # ie during pipeline validation
+                    run_id=run_id or "",
+                    message=str(exc),
+                )
+                if raise_exception:
+                    raise exc
+
+        finally:
+            # Restore original callback
+            self.callbacks.remove(event_stream)
+            if event_queue_getter_task and not event_queue_getter_task.done():
+                event_queue_getter_task.cancel()
+
     async def run(self, data: dict[str, Any]) -> PipelineResult:
         logger.debug("PIPELINE START")
         start_time = default_timer()
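The `stream` method above boils down to a producer/consumer loop: a temporary callback feeds an `asyncio.Queue` while `asyncio.wait(..., return_when=FIRST_COMPLETED)` races the queue getter against the background `run` task, until the run is done and the queue is drained. A self-contained reduction of that pattern (illustrative names only, no library dependency):

```python
import asyncio
from typing import AsyncGenerator


async def produce(queue: "asyncio.Queue[str]") -> None:
    # Stands in for Pipeline.run publishing events through the callback
    for i in range(3):
        await asyncio.sleep(0.1)  # simulate work between events
        await queue.put(f"event {i}")


async def stream_events() -> AsyncGenerator[str, None]:
    queue: asyncio.Queue[str] = asyncio.Queue()
    run_task = asyncio.create_task(produce(queue))
    running, queue_empty = True, False
    while running or not queue_empty:
        getter = asyncio.create_task(queue.get())
        done, _ = await asyncio.wait(
            [run_task, getter], return_when=asyncio.FIRST_COMPLETED
        )
        running = run_task not in done
        queue_empty = queue.empty()
        if getter in done:
            yield getter.result()
        else:
            getter.cancel()  # nothing arrived; retry (or exit) on next pass


async def main() -> None:
    async for event in stream_events():
        print(event)


asyncio.run(main())
```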
diff --git a/tests/unit/experimental/pipeline/components.py b/tests/unit/experimental/pipeline/components.py
index fc33e840d..e1b2e1845 100644
--- a/tests/unit/experimental/pipeline/components.py
+++ b/tests/unit/experimental/pipeline/components.py
@@ -12,6 +12,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import asyncio
+
 from neo4j_graphrag.experimental.pipeline import Component, DataModel
 from neo4j_graphrag.experimental.pipeline.types.context import RunContext
 
@@ -52,3 +54,12 @@ async def run_with_context(
             message="my message", data={"number1": number1, "number2": number2}
         )
         return IntResultModel(result=number1 * number2)
+
+
+class SlowComponentMultiply(Component):
+    def __init__(self, sleep: float = 1.0) -> None:
+        self.sleep = sleep
+
+    async def run(self, number1: int, number2: int = 2) -> IntResultModel:
+        await asyncio.sleep(self.sleep)
+        return IntResultModel(result=number1 * number2)
diff --git a/tests/unit/experimental/pipeline/test_pipeline.py b/tests/unit/experimental/pipeline/test_pipeline.py
index 16d32ba9b..868aaae86 100644
--- a/tests/unit/experimental/pipeline/test_pipeline.py
+++ b/tests/unit/experimental/pipeline/test_pipeline.py
@@ -39,6 +39,7 @@
     ComponentNoParam,
     ComponentPassThrough,
     StringResultModel,
+    SlowComponentMultiply,
 )
 
 
@@ -491,3 +492,101 @@ def test_event_model_no_warning(recwarn: Sized) -> None:
     )
     assert event.timestamp is not None
     assert len(recwarn) == 0
+
+
+@pytest.mark.asyncio
+async def test_pipeline_streaming_no_user_callback_happy_path() -> None:
+    pipe = Pipeline()
+    events = []
+    async for e in pipe.stream({}):
+        events.append(e)
+    assert len(events) == 2
+    assert events[0].event_type == EventType.PIPELINE_STARTED
+    assert events[1].event_type == EventType.PIPELINE_FINISHED
+    assert len(pipe.callbacks) == 0
+
+
+@pytest.mark.asyncio
+async def test_pipeline_streaming_with_user_callback_happy_path() -> None:
+    callback = AsyncMock()
+    pipe = Pipeline(callback=callback)
+    events = []
+    async for e in pipe.stream({}):
+        events.append(e)
+    assert len(events) == 2
+    assert len(callback.call_args_list) == 2
+    assert len(pipe.callbacks) == 1
+
+
+@pytest.mark.asyncio
+async def test_pipeline_streaming_very_long_running_user_callback() -> None:
+    async def callback(event: Event) -> None:
+        await asyncio.sleep(2)
+
+    pipe = Pipeline(callback=callback)
+    events = []
+    async for e in pipe.stream({}):
+        events.append(e)
+    assert len(events) == 2
+    assert len(pipe.callbacks) == 1
+
+
+@pytest.mark.asyncio
+async def test_pipeline_streaming_very_long_running_pipeline() -> None:
+    slow_component = SlowComponentMultiply()
+    pipe = Pipeline()
+    pipe.add_component(slow_component, "slow_component")
+    events = []
+    async for e in pipe.stream({"slow_component": {"number1": 1, "number2": 2}}):
+        events.append(e)
+    assert len(events) == 4
+    last_event = events[-1]
+    assert last_event.event_type == EventType.PIPELINE_FINISHED
+    assert last_event.payload == {"slow_component": {"result": 2}}
+
+
+@pytest.mark.asyncio
+async def test_pipeline_streaming_error_in_pipeline_definition() -> None:
+    pipe = Pipeline()
+    component_a = ComponentAdd()
+    component_b = ComponentAdd()
+    pipe.add_component(component_a, "a")
+    pipe.add_component(component_b, "b")
+    pipe.connect("a", "b", {"number1": "a.result"})
+    events = []
+    with pytest.raises(PipelineDefinitionError):
+        async for e in pipe.stream({"a": {"number1": 1, "number2": 2}}):
+            events.append(e)
+    # validation happens before pipeline run actually starts
+    # but we have the PIPELINE_FAILED event
+    assert len(events) == 1
+    assert events[0].event_type == EventType.PIPELINE_FAILED
+    assert events[0].run_id == ""
+
+
+@pytest.mark.asyncio
+async def test_pipeline_streaming_error_in_component() -> None:
+    component = ComponentMultiply()
+    pipe = Pipeline()
+    pipe.add_component(component, "component")
+    events = []
+    with pytest.raises(TypeError):
+        async for e in pipe.stream({"component": {"number1": None, "number2": 2}}):
+            events.append(e)
+    assert len(events) == 3
+    assert events[0].event_type == EventType.PIPELINE_STARTED
+    assert events[1].event_type == EventType.TASK_STARTED
+    assert events[2].event_type == EventType.PIPELINE_FAILED
+
+
+@pytest.mark.asyncio
+async def test_pipeline_streaming_error_in_user_callback() -> None:
+    async def callback(event: Event) -> None:
+        raise Exception("error in callback")
+
+    pipe = Pipeline(callback=callback)
+    events = []
+    async for e in pipe.stream({}):
+        events.append(e)
+    assert len(events) == 2
+    assert len(pipe.callbacks) == 1
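One behavior worth calling out for the docs: with the default `raise_exception=True`, a component failure re-raises out of the `async for` (as `test_pipeline_streaming_error_in_component` asserts above), whereas `raise_exception=False` turns it into a terminal `PIPELINE_FAILED` event and the loop ends cleanly. A consumer sketch of that failure path; the `Failing` component here is made up for illustration:

```python
import asyncio

from neo4j_graphrag.experimental.pipeline import Component, DataModel
from neo4j_graphrag.experimental.pipeline.notification import EventType
from neo4j_graphrag.experimental.pipeline.pipeline import Pipeline


class OutputModel(DataModel):
    result: int


class Failing(Component):
    async def run(self, value: int) -> OutputModel:
        raise ValueError("boom")


async def main() -> None:
    pipe = Pipeline()
    pipe.add_component(Failing(), "failing")
    # raise_exception=False: the error is surfaced as an event, not raised,
    # so this loop always runs to completion
    async for event in pipe.stream({"failing": {"value": 1}}, raise_exception=False):
        if event.event_type == EventType.PIPELINE_FAILED:
            print(f"pipeline failed: {event.message}")


if __name__ == "__main__":
    asyncio.run(main())
```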