
Pipeline streaming #304

Merged: 13 commits (Apr 2, 2025)
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -5,6 +5,7 @@
### Added

- Added support for multi-vector collection in Qdrant driver.
- Added a `Pipeline.stream` method to stream pipeline progress.

### Changed

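In practice, the new `Pipeline.stream` method returns an async iterator over pipeline events. A minimal sketch of the call shape, assuming `pipeline` is an already-configured `Pipeline` (the full runnable example is added under `examples/customize/build_graph/pipeline/pipeline_streaming.py` below):

    async for event in pipeline.stream({"adder": {"value": 2}}):
        print(event.event_type, event.message)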
5 changes: 2 additions & 3 deletions docs/source/user_guide_pipeline.rst
@@ -22,7 +22,7 @@ their own by following these steps:

1. Create a subclass of the Pydantic `neo4j_graphrag.experimental.pipeline.DataModel` to represent the data being returned by the component
2. Create a subclass of `neo4j_graphrag.experimental.pipeline.Component`
3. Create a `run_with_context` method in this new class and specify the required inputs and output model using the just created `DataModel`
3. Create a `run` method in this new class and specify the required inputs and output model using the just created `DataModel`
4. Implement the run method: it's an `async` method, allowing tasks to be parallelized and awaited within this method.

An example is given below, where a `ComponentAdd` is created to add two numbers together and return
@@ -31,13 +31,12 @@ the resulting sum:
.. code:: python

    from neo4j_graphrag.experimental.pipeline import Component, DataModel
    from neo4j_graphrag.experimental.pipeline.types.context import RunContext

    class IntResultModel(DataModel):
        result: int

    class ComponentAdd(Component):
        async def run_with_context(self, context_: RunContext, number1: int, number2: int = 1) -> IntResultModel:
        async def run(self, number1: int, number2: int = 1) -> IntResultModel:
            return IntResultModel(result = number1 + number2)

Read more about :ref:`components-section` in the API Documentation.
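As a usage note: once defined, such a component can be exercised on its own, since `run` is a coroutine. A minimal sketch, assuming the `ComponentAdd` class from the snippet above is in scope:

    import asyncio

    # `run` is async, so drive it with asyncio and read the typed result.
    result = asyncio.run(ComponentAdd().run(number1=2, number2=3))
    print(result.result)  # 5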
86 changes: 86 additions & 0 deletions examples/customize/build_graph/pipeline/pipeline_streaming.py
@@ -0,0 +1,86 @@
import asyncio

from neo4j_graphrag.experimental.pipeline.component import Component, DataModel
from neo4j_graphrag.experimental.pipeline.pipeline import Pipeline
from neo4j_graphrag.experimental.pipeline.notification import EventType, Event
from neo4j_graphrag.experimental.pipeline.types.context import RunContext


# Define some example components with progress notifications
class OutputModel(DataModel):
    result: int


class SlowAdder(Component):
    """A component that slowly adds numbers and reports progress"""

    def __init__(self, number: int) -> None:
        self.number = number

    async def run_with_context(self, context_: RunContext, value: int) -> OutputModel:
        # Simulate work with progress updates
        for i in range(value):
            await asyncio.sleep(0.5)  # Simulate work
            await context_.notify(
                message=f"Added {i+1}/{value}", data={"current": i + 1, "total": value}
            )
        return OutputModel(result=value + self.number)


class SlowMultiplier(Component):
    """A component that slowly multiplies numbers and reports progress"""

    def __init__(self, multiplier: int) -> None:
        self.multiplier = multiplier

    async def run_with_context(self, context_: RunContext, value: int) -> OutputModel:
        # Simulate work with progress updates
        for i in range(3):  # Always do 3 steps
            await asyncio.sleep(0.7)  # Simulate work
            await context_.notify(
                message=f"Multiplication step {i+1}/3", data={"step": i + 1, "total": 3}
            )
        return OutputModel(result=value * self.multiplier)


async def callback(event: Event) -> None:
    await asyncio.sleep(0.1)


async def main() -> None:
    # Create pipeline
    pipeline = Pipeline(callback=callback)

    # Add components
    pipeline.add_component(SlowAdder(number=3), "adder")
    pipeline.add_component(SlowMultiplier(multiplier=2), "multiplier")

    # Connect components
    pipeline.connect("adder", "multiplier", {"value": "adder.result"})

    print("\n=== Running pipeline with streaming ===")
    # Run pipeline with streaming - see events as they happen
    async for event in pipeline.stream(
        {"adder": {"value": 2}},
        raise_exception=False,  # default is True
    ):
        if event.event_type == EventType.PIPELINE_STARTED:
            print("Stream: Pipeline started!")
        elif event.event_type == EventType.PIPELINE_FINISHED:
            print(f"Stream: Pipeline finished! Final results: {event.payload}")
        elif event.event_type == EventType.PIPELINE_FAILED:
            print(f"Stream: Pipeline failed with message: {event.message}")
        elif event.event_type == EventType.TASK_STARTED:
            print(
                f"Stream: Task {event.task_name} started with inputs: {event.payload}"  # type: ignore
            )
        elif event.event_type == EventType.TASK_PROGRESS:
            print(f"Stream: Task {event.task_name} progress - {event.message}")  # type: ignore
        elif event.event_type == EventType.TASK_FINISHED:
            print(
                f"Stream: Task {event.task_name} finished with result: {event.payload}"  # type: ignore
            )


if __name__ == "__main__":
    asyncio.run(main())
18 changes: 13 additions & 5 deletions src/neo4j_graphrag/experimental/pipeline/notification.py
@@ -14,6 +14,7 @@
# limitations under the License.
from __future__ import annotations

import asyncio
import datetime
import enum
from collections.abc import Awaitable
@@ -31,10 +32,15 @@ class EventType(enum.Enum):
    TASK_PROGRESS = "TASK_PROGRESS"
    TASK_FINISHED = "TASK_FINISHED"
    PIPELINE_FINISHED = "PIPELINE_FINISHED"
    PIPELINE_FAILED = "PIPELINE_FAILED"

    @property
    def is_pipeline_event(self) -> bool:
        return self in [EventType.PIPELINE_STARTED, EventType.PIPELINE_FINISHED]
        return self in [
            EventType.PIPELINE_STARTED,
            EventType.PIPELINE_FINISHED,
            EventType.PIPELINE_FAILED,
        ]

    @property
    def is_task_event(self) -> bool:
@@ -72,12 +78,14 @@ def __call__(self, event: Event) -> Awaitable[None]: ...


class EventNotifier:
    def __init__(self, callback: EventCallbackProtocol | None) -> None:
        self.callback = callback
    def __init__(self, callbacks: list[EventCallbackProtocol]) -> None:
        self.callbacks = callbacks

    async def notify(self, event: Event) -> None:
        if self.callback:
            await self.callback(event)
        await asyncio.gather(
            *[c(event) for c in self.callbacks],
            return_exceptions=True,
        )

    async def notify_pipeline_started(
        self, run_id: str, input_data: Optional[dict[str, Any]] = None
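Note on the `EventNotifier` change above: callbacks are now fanned out with `asyncio.gather(..., return_exceptions=True)`, so every callback runs concurrently and a failing callback can neither crash the pipeline nor starve the other callbacks. A self-contained sketch of that behavior (stdlib asyncio only; the event is simplified to a plain string here):

    import asyncio

    async def failing_callback(event: str) -> None:
        raise RuntimeError("boom")

    async def logging_callback(event: str) -> None:
        print(f"received: {event}")

    async def notify(callbacks, event: str) -> None:
        # Mirrors EventNotifier.notify: exceptions are gathered, not raised.
        await asyncio.gather(*[c(event) for c in callbacks], return_exceptions=True)

    asyncio.run(notify([failing_callback, logging_callback], "TASK_STARTED"))
    # Prints "received: TASK_STARTED"; the RuntimeError is collected and discarded.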
2 changes: 1 addition & 1 deletion src/neo4j_graphrag/experimental/pipeline/orchestrator.py
@@ -54,7 +54,7 @@ class Orchestrator:

    def __init__(self, pipeline: Pipeline):
        self.pipeline = pipeline
        self.event_notifier = EventNotifier(pipeline.callback)
        self.event_notifier = EventNotifier(pipeline.callbacks)
        self.run_id = str(uuid.uuid4())

    async def run_task(self, task: TaskPipelineNode, data: dict[str, Any]) -> None:
82 changes: 79 additions & 3 deletions src/neo4j_graphrag/experimental/pipeline/pipeline.py
@@ -18,7 +18,8 @@
import warnings
from collections import defaultdict
from timeit import default_timer
from typing import Any, Optional
from typing import Any, Optional, AsyncGenerator
import asyncio

from neo4j_graphrag.utils.logging import prettify

@@ -47,7 +48,12 @@
)
from neo4j_graphrag.experimental.pipeline.types.orchestration import RunResult
from neo4j_graphrag.experimental.pipeline.types.context import RunContext
from neo4j_graphrag.experimental.pipeline.notification import EventCallbackProtocol
from neo4j_graphrag.experimental.pipeline.notification import (
    EventCallbackProtocol,
    Event,
    PipelineEvent,
    EventType,
)


logger = logging.getLogger(__name__)
@@ -117,7 +123,7 @@ def __init__(
    ) -> None:
        super().__init__()
        self.store = store or InMemoryStore()
        self.callback = callback
        self.callbacks = [callback] if callback else []
        self.final_results = InMemoryStore()
        self.is_validated = False
        self.param_mapping: dict[str, dict[str, dict[str, str]]] = defaultdict(dict)
@@ -412,6 +418,76 @@ def validate_parameter_mapping_for_task(self, task: TaskPipelineNode) -> bool:
    async def get_final_results(self, run_id: str) -> dict[str, Any]:
        return await self.final_results.get(run_id)  # type: ignore[no-any-return]

    async def stream(
        self, data: dict[str, Any], raise_exception: bool = True
    ) -> AsyncGenerator[Event, None]:
        """Run the pipeline and stream events for task progress.

        Args:
            data (dict): Input data for the pipeline components
            raise_exception (bool): set to False to prevent pipeline exceptions from
                propagating to the caller (a PIPELINE_FAILED event is still yielded).
                Defaults to True.

        Yields:
            Event: Pipeline and task events including start, progress, and completion
        """
        # Create queue for events
        event_queue: asyncio.Queue[Event] = asyncio.Queue()
        run_id = None

        async def event_stream(event: Event) -> None:
            # Put event in queue for streaming
            await event_queue.put(event)

        # Add event streaming callback
        self.callbacks.append(event_stream)

        event_queue_getter_task = None
        try:
            # Start pipeline execution in background task
            run_task = asyncio.create_task(self.run(data))

            # loop until the run task is done, and we do not have
            # any more pending tasks in queue
            is_run_task_running = True
            is_queue_empty = False
            while is_run_task_running or not is_queue_empty:
                # Wait for next event or pipeline completion
                event_queue_getter_task = asyncio.create_task(event_queue.get())
                done, pending = await asyncio.wait(
                    [run_task, event_queue_getter_task],
                    return_when=asyncio.FIRST_COMPLETED,
                )

                is_run_task_running = run_task not in done
                is_queue_empty = event_queue.empty()

                for event_future in done:
                    if event_future == run_task:
                        continue
                    # we are sure to get an Event here, since this is the only
                    # thing we put in the queue, but mypy still complains
                    event = event_future.result()
                    run_id = getattr(event, "run_id", None)
                    yield event  # type: ignore

            if exc := run_task.exception():
                yield PipelineEvent(
                    event_type=EventType.PIPELINE_FAILED,
                    # run_id is null if pipeline fails before even starting
                    # ie during pipeline validation
                    run_id=run_id or "",
                    message=str(exc),
                )
                if raise_exception:
                    raise exc

        finally:
            # Remove the temporary event-streaming callback
            self.callbacks.remove(event_stream)
            if event_queue_getter_task and not event_queue_getter_task.done():
                event_queue_getter_task.cancel()

    async def run(self, data: dict[str, Any]) -> PipelineResult:
        logger.debug("PIPELINE START")
        start_time = default_timer()
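Design note on `stream`: it bridges a producer and a consumer. The pipeline run is started as a background task and pushes events into an `asyncio.Queue` through a temporary callback, while the generator drains the queue until the run task is done and the queue is empty; `asyncio.wait(..., return_when=FIRST_COMPLETED)` lets it wake up on either a new event or pipeline completion. A standalone sketch of this pattern (stdlib asyncio only, independent of the Pipeline API):

    import asyncio

    async def producer(queue: asyncio.Queue) -> None:
        for i in range(3):
            await queue.put(f"event {i}")
            await asyncio.sleep(0.1)

    async def consume() -> None:
        queue: asyncio.Queue = asyncio.Queue()
        work = asyncio.create_task(producer(queue))
        # Drain until the producer is finished AND the queue is empty.
        while not work.done() or not queue.empty():
            getter = asyncio.create_task(queue.get())
            done, _ = await asyncio.wait(
                {work, getter}, return_when=asyncio.FIRST_COMPLETED
            )
            if getter in done:
                print(getter.result())
            else:
                getter.cancel()  # woke up because the producer finished first

    asyncio.run(consume())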
11 changes: 11 additions & 0 deletions tests/unit/experimental/pipeline/components.py
@@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio

from neo4j_graphrag.experimental.pipeline import Component, DataModel
from neo4j_graphrag.experimental.pipeline.types.context import RunContext

@@ -52,3 +54,12 @@ async def run_with_context(
            message="my message", data={"number1": number1, "number2": number2}
        )
        return IntResultModel(result=number1 * number2)


class SlowComponentMultiply(Component):
    def __init__(self, sleep: float = 1.0) -> None:
        self.sleep = sleep

    async def run(self, number1: int, number2: int = 2) -> IntResultModel:
        await asyncio.sleep(self.sleep)
        return IntResultModel(result=number1 * number2)
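Usage note: `SlowComponentMultiply` gives the unit tests a component whose `run` takes a controllable amount of wall-clock time, which streaming tests need in order to observe events while work is still in flight. A minimal sketch of driving it through the new API (test wiring assumed; only calls that appear elsewhere in this diff are used):

    import asyncio

    from neo4j_graphrag.experimental.pipeline.pipeline import Pipeline

    async def demo() -> None:
        pipe = Pipeline()
        pipe.add_component(SlowComponentMultiply(sleep=0.05), "multiply")
        # Events arrive while the component is still sleeping.
        async for event in pipe.stream({"multiply": {"number1": 3, "number2": 4}}):
            print(event.event_type)

    asyncio.run(demo())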