Skip to content

Pipeline streaming #304

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Apr 2, 2025
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
### Added

- Added support for multi-vector collection in Qdrant driver.
- Added a `Pipeline.stream` method to stream pipeline progress.

### Changed

Expand Down
5 changes: 2 additions & 3 deletions docs/source/user_guide_pipeline.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ their own by following these steps:

1. Create a subclass of the Pydantic `neo4j_graphrag.experimental.pipeline.DataModel` to represent the data being returned by the component
2. Create a subclass of `neo4j_graphrag.experimental.pipeline.Component`
3. Create a `run_with_context` method in this new class and specify the required inputs and output model using the just created `DataModel`
3. Create a `run` method in this new class and specify the required inputs and output model using the just created `DataModel`
4. Implement the run method: it's an `async` method, allowing tasks to be parallelized and awaited within this method.

An example is given below, where a `ComponentAdd` is created to add two numbers together and return
Expand All @@ -31,13 +31,12 @@ the resulting sum:
.. code:: python

from neo4j_graphrag.experimental.pipeline import Component, DataModel
from neo4j_graphrag.experimental.pipeline.types.context import RunContext

class IntResultModel(DataModel):
result: int

class ComponentAdd(Component):
async def run_with_context(self, context_: RunContext, number1: int, number2: int = 1) -> IntResultModel:
async def run(self, number1: int, number2: int = 1) -> IntResultModel:
return IntResultModel(result = number1 + number2)

Read more about :ref:`components-section` in the API Documentation.
Expand Down
81 changes: 81 additions & 0 deletions examples/customize/build_graph/pipeline/pipeline_streaming.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import asyncio

from neo4j_graphrag.experimental.pipeline.component import Component, DataModel
from neo4j_graphrag.experimental.pipeline.pipeline import Pipeline
from neo4j_graphrag.experimental.pipeline.notification import EventType, Event
from neo4j_graphrag.experimental.pipeline.types.context import RunContext


# Define some example components with progress notifications
class OutputModel(DataModel):
result: int


class SlowAdder(Component):
"""A component that slowly adds numbers and reports progress"""

def __init__(self, number: int) -> None:
self.number = number

async def run_with_context(self, context_: RunContext, value: int) -> OutputModel:
# Simulate work with progress updates
for i in range(value):
await asyncio.sleep(0.5) # Simulate work
await context_.notify(
message=f"Added {i+1}/{value}", data={"current": i + 1, "total": value}
)
return OutputModel(result=value + self.number)


class SlowMultiplier(Component):
"""A component that slowly multiplies numbers and reports progress"""

def __init__(self, multiplier: int) -> None:
self.multiplier = multiplier

async def run_with_context(self, context_: RunContext, value: int) -> OutputModel:
# Simulate work with progress updates
for i in range(3): # Always do 3 steps
await asyncio.sleep(0.7) # Simulate work
await context_.notify(
message=f"Multiplication step {i+1}/3", data={"step": i + 1, "total": 3}
)
return OutputModel(result=value * self.multiplier)


async def callback(event: Event) -> None:
await asyncio.sleep(0.1)


async def main() -> None:
# Create pipeline
pipeline = Pipeline(callback=callback)

# Add components
pipeline.add_component(SlowAdder(number=3), "adder")
pipeline.add_component(SlowMultiplier(multiplier=2), "multiplier")

# Connect components
pipeline.connect("adder", "multiplier", {"value": "adder.result"})

print("\n=== Running pipeline with streaming ===")
# Run pipeline with streaming - see events as they happen
async for event in pipeline.stream({"adder": {"value": 2}}):
if event.event_type == EventType.PIPELINE_STARTED:
print("Stream: Pipeline started!")
elif event.event_type == EventType.PIPELINE_FINISHED:
print(f"Stream: Pipeline finished! Final results: {event.payload}")
elif event.event_type == EventType.TASK_STARTED:
print(
f"Stream: Task {event.task_name} started with inputs: {event.payload}" # type: ignore
)
elif event.event_type == EventType.TASK_PROGRESS:
print(f"Stream: Task {event.task_name} progress - {event.message}") # type: ignore
elif event.event_type == EventType.TASK_FINISHED:
print(
f"Stream: Task {event.task_name} finished with result: {event.payload}" # type: ignore
)


if __name__ == "__main__":
asyncio.run(main())
11 changes: 7 additions & 4 deletions src/neo4j_graphrag/experimental/pipeline/notification.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.
from __future__ import annotations

import asyncio
import datetime
import enum
from collections.abc import Awaitable
Expand Down Expand Up @@ -72,12 +73,14 @@ def __call__(self, event: Event) -> Awaitable[None]: ...


class EventNotifier:
def __init__(self, callback: EventCallbackProtocol | None) -> None:
self.callback = callback
def __init__(self, callbacks: list[EventCallbackProtocol]) -> None:
self.callbacks = callbacks

async def notify(self, event: Event) -> None:
if self.callback:
await self.callback(event)
await asyncio.gather(
*[c(event) for c in self.callbacks],
return_exceptions=True,
)

async def notify_pipeline_started(
self, run_id: str, input_data: Optional[dict[str, Any]] = None
Expand Down
2 changes: 1 addition & 1 deletion src/neo4j_graphrag/experimental/pipeline/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class Orchestrator:

def __init__(self, pipeline: Pipeline):
self.pipeline = pipeline
self.event_notifier = EventNotifier(pipeline.callback)
self.event_notifier = EventNotifier(pipeline.callbacks)
self.run_id = str(uuid.uuid4())

async def run_task(self, task: TaskPipelineNode, data: dict[str, Any]) -> None:
Expand Down
63 changes: 60 additions & 3 deletions src/neo4j_graphrag/experimental/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
import warnings
from collections import defaultdict
from timeit import default_timer
from typing import Any, Optional
from typing import Any, Optional, AsyncGenerator
import asyncio

from neo4j_graphrag.utils.logging import prettify

Expand Down Expand Up @@ -47,7 +48,10 @@
)
from neo4j_graphrag.experimental.pipeline.types.orchestration import RunResult
from neo4j_graphrag.experimental.pipeline.types.context import RunContext
from neo4j_graphrag.experimental.pipeline.notification import EventCallbackProtocol
from neo4j_graphrag.experimental.pipeline.notification import (
EventCallbackProtocol,
Event,
)


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -117,7 +121,7 @@ def __init__(
) -> None:
super().__init__()
self.store = store or InMemoryStore()
self.callback = callback
self.callbacks = [callback] if callback else []
self.final_results = InMemoryStore()
self.is_validated = False
self.param_mapping: dict[str, dict[str, dict[str, str]]] = defaultdict(dict)
Expand Down Expand Up @@ -412,6 +416,59 @@ def validate_parameter_mapping_for_task(self, task: TaskPipelineNode) -> bool:
async def get_final_results(self, run_id: str) -> dict[str, Any]:
return await self.final_results.get(run_id) # type: ignore[no-any-return]

async def stream(self, data: dict[str, Any]) -> AsyncGenerator[Event, None]:
"""Run the pipeline and stream events for task progress.

Args:
data: Input data for the pipeline components

Yields:
Event: Pipeline and task events including start, progress, and completion
"""
# Create queue for events
event_queue: asyncio.Queue[Event] = asyncio.Queue()

async def event_stream(event: Event) -> None:
# Put event in queue for streaming
await event_queue.put(event)

# Add event streaming callback
self.callbacks.append(event_stream)

event_queue_getter_task = None
try:
# Start pipeline execution in background task
run_task = asyncio.create_task(self.run(data))

# loop until the run task is done, and we do not have
# any more pending tasks in queue
is_run_task_running = True
is_queue_empty = False
while is_run_task_running or not is_queue_empty:
# Wait for next event or pipeline completion
event_queue_getter_task = asyncio.create_task(event_queue.get())
done, pending = await asyncio.wait(
[run_task, event_queue_getter_task],
return_when=asyncio.FIRST_COMPLETED,
)

is_run_task_running = run_task not in done
is_queue_empty = event_queue.empty()

for event_future in done:
if event_future == run_task:
continue
yield event_future.result() # type: ignore

if run_task.exception():
raise run_task.exception() # type: ignore

finally:
# Restore original callback
self.callbacks.remove(event_stream)
if event_queue_getter_task and not event_queue_getter_task.done():
event_queue_getter_task.cancel()

async def run(self, data: dict[str, Any]) -> PipelineResult:
logger.debug("PIPELINE START")
start_time = default_timer()
Expand Down
11 changes: 11 additions & 0 deletions tests/unit/experimental/pipeline/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio

from neo4j_graphrag.experimental.pipeline import Component, DataModel
from neo4j_graphrag.experimental.pipeline.types.context import RunContext

Expand Down Expand Up @@ -52,3 +54,12 @@ async def run_with_context(
message="my message", data={"number1": number1, "number2": number2}
)
return IntResultModel(result=number1 * number2)


class SlowComponentMultiply(Component):
def __init__(self, sleep: float = 1.0) -> None:
self.sleep = sleep

async def run(self, number1: int, number2: int = 2) -> IntResultModel:
await asyncio.sleep(self.sleep)
return IntResultModel(result=number1 * number2)
95 changes: 95 additions & 0 deletions tests/unit/experimental/pipeline/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
ComponentNoParam,
ComponentPassThrough,
StringResultModel,
SlowComponentMultiply,
)


Expand Down Expand Up @@ -491,3 +492,97 @@ def test_event_model_no_warning(recwarn: Sized) -> None:
)
assert event.timestamp is not None
assert len(recwarn) == 0


@pytest.mark.asyncio
async def test_pipeline_streaming_no_user_callback_happy_path() -> None:
pipe = Pipeline()
events = []
async for e in pipe.stream({}):
events.append(e)
assert len(events) == 2
assert events[0].event_type == EventType.PIPELINE_STARTED
assert events[1].event_type == EventType.PIPELINE_FINISHED
assert len(pipe.callbacks) == 0


@pytest.mark.asyncio
async def test_pipeline_streaming_with_user_callback_happy_path() -> None:
callback = AsyncMock()
pipe = Pipeline(callback=callback)
events = []
async for e in pipe.stream({}):
events.append(e)
assert len(events) == 2
assert len(callback.call_args_list) == 2
assert len(pipe.callbacks) == 1


@pytest.mark.asyncio
async def test_pipeline_streaming_very_long_running_user_callback() -> None:
async def callback(event: Event) -> None:
await asyncio.sleep(2)

pipe = Pipeline(callback=callback)
events = []
async for e in pipe.stream({}):
events.append(e)
assert len(events) == 2
assert len(pipe.callbacks) == 1


@pytest.mark.asyncio
async def test_pipeline_streaming_very_long_running_pipeline() -> None:
slow_component = SlowComponentMultiply()
pipe = Pipeline()
pipe.add_component(slow_component, "slow_component")
events = []
async for e in pipe.stream({"slow_component": {"number1": 1, "number2": 2}}):
events.append(e)
assert len(events) == 4
last_event = events[-1]
assert last_event.event_type == EventType.PIPELINE_FINISHED
assert last_event.payload == {"slow_component": {"result": 2}}


@pytest.mark.asyncio
async def test_pipeline_streaming_error_in_pipeline_definition() -> None:
pipe = Pipeline()
component_a = ComponentAdd()
component_b = ComponentAdd()
pipe.add_component(component_a, "a")
pipe.add_component(component_b, "b")
pipe.connect("a", "b", {"number1": "a.result"})
events = []
with pytest.raises(PipelineDefinitionError):
async for e in pipe.stream({"a": {"number1": 1, "number2": 2}}):
events.append(e)
# validation happens before pipeline run actually starts
assert len(events) == 0


@pytest.mark.asyncio
async def test_pipeline_streaming_error_in_component() -> None:
component = ComponentMultiply()
pipe = Pipeline()
pipe.add_component(component, "component")
events = []
with pytest.raises(TypeError):
async for e in pipe.stream({"component": {"number1": None, "number2": 2}}):
events.append(e)
assert len(events) == 2
assert events[0].event_type == EventType.PIPELINE_STARTED
assert events[1].event_type == EventType.TASK_STARTED


@pytest.mark.asyncio
async def test_pipeline_streaming_error_in_user_callback() -> None:
async def callback(event: Event) -> None:
raise Exception("error in callback")

pipe = Pipeline(callback=callback)
events = []
async for e in pipe.stream({}):
events.append(e)
assert len(events) == 2
assert len(pipe.callbacks) == 1