Skip to content

Commit 42ab19a

Browse files
authored
Pipeline streaming (#304)
* Expose context from the orchestrator to the components
* mypy
* Changes so that the `run` method is not required anymore
* Update documentation
* Improve documentation of future changes
* Undo make notifier private, not needed
* Pipeline streaming
* Add tests
* Fix "Task destroyed but is pending" warning
* Fix rebase
* Update CHANGELOG
* Add PIPELINE_FAILED event and option not to raise exception
* Mypy
1 parent bb8fe0a commit 42ab19a

File tree

8 files changed

+292
-12
lines changed

8 files changed

+292
-12
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
### Added
66

77
- Added support for multi-vector collection in Qdrant driver.
8+
- Added a `Pipeline.stream` method to stream pipeline progress.
89

910
### Changed
1011

docs/source/user_guide_pipeline.rst

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ their own by following these steps:
2222

2323
1. Create a subclass of the Pydantic `neo4j_graphrag.experimental.pipeline.DataModel` to represent the data being returned by the component
2424
2. Create a subclass of `neo4j_graphrag.experimental.pipeline.Component`
25-
3. Create a `run_with_context` method in this new class and specify the required inputs and output model using the just created `DataModel`
25+
3. Create a `run` method in this new class and specify the required inputs and output model using the just created `DataModel`
2626
4. Implement the run method: it's an `async` method, allowing tasks to be parallelized and awaited within this method.
2727

2828
An example is given below, where a `ComponentAdd` is created to add two numbers together and return
@@ -31,13 +31,12 @@ the resulting sum:
3131
.. code:: python
3232
3333
from neo4j_graphrag.experimental.pipeline import Component, DataModel
34-
from neo4j_graphrag.experimental.pipeline.types.context import RunContext
3534
3635
class IntResultModel(DataModel):
3736
result: int
3837
3938
class ComponentAdd(Component):
40-
async def run_with_context(self, context_: RunContext, number1: int, number2: int = 1) -> IntResultModel:
39+
async def run(self, number1: int, number2: int = 1) -> IntResultModel:
4140
return IntResultModel(result = number1 + number2)
4241
4342
Read more about :ref:`components-section` in the API Documentation.
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import asyncio
2+
3+
from neo4j_graphrag.experimental.pipeline.component import Component, DataModel
4+
from neo4j_graphrag.experimental.pipeline.pipeline import Pipeline
5+
from neo4j_graphrag.experimental.pipeline.notification import EventType, Event
6+
from neo4j_graphrag.experimental.pipeline.types.context import RunContext
7+
8+
9+
# Define some example components with progress notifications
class OutputModel(DataModel):
    # Single integer produced by each example component.
    result: int
12+
13+
14+
class SlowAdder(Component):
    """A component that slowly adds numbers and reports progress"""

    def __init__(self, number: int) -> None:
        # Constant added to the incoming value.
        self.number = number

    async def run_with_context(self, context_: RunContext, value: int) -> OutputModel:
        # Simulate a long-running job: sleep between steps and emit a
        # progress notification after each completed step.
        for step in range(1, value + 1):
            await asyncio.sleep(0.5)  # simulated work
            await context_.notify(
                message=f"Added {step}/{value}",
                data={"current": step, "total": value},
            )
        return OutputModel(result=value + self.number)
28+
29+
30+
class SlowMultiplier(Component):
    """A component that slowly multiplies numbers and reports progress"""

    def __init__(self, multiplier: int) -> None:
        # Factor applied to the incoming value.
        self.multiplier = multiplier

    async def run_with_context(self, context_: RunContext, value: int) -> OutputModel:
        # Always perform three fixed work steps, notifying after each one.
        for step in range(1, 4):
            await asyncio.sleep(0.7)  # simulated work
            await context_.notify(
                message=f"Multiplication step {step}/3",
                data={"step": step, "total": 3},
            )
        return OutputModel(result=value * self.multiplier)
44+
45+
46+
async def callback(event: Event) -> None:
    """Example observer callback: simulates a small amount of async work per event."""
    await asyncio.sleep(0.1)
48+
49+
50+
async def main() -> None:
    """Build a two-component pipeline and print its event stream as it runs."""
    # Create pipeline with an additional (slow) observer callback
    pipeline = Pipeline(callback=callback)

    # Register components
    pipeline.add_component(SlowAdder(number=3), "adder")
    pipeline.add_component(SlowMultiplier(multiplier=2), "multiplier")

    # Wire the adder's output into the multiplier's input
    pipeline.connect("adder", "multiplier", {"value": "adder.result"})

    print("\n=== Running pipeline with streaming ===")

    # One formatter per event type; events without a formatter are ignored.
    formatters = {
        EventType.PIPELINE_STARTED: lambda e: "Stream: Pipeline started!",
        EventType.PIPELINE_FINISHED: lambda e: f"Stream: Pipeline finished! Final results: {e.payload}",
        EventType.PIPELINE_FAILED: lambda e: f"Stream: Pipeline failed with message: {e.message}",
        EventType.TASK_STARTED: lambda e: f"Stream: Task {e.task_name} started with inputs: {e.payload}",  # type: ignore
        EventType.TASK_PROGRESS: lambda e: f"Stream: Task {e.task_name} progress - {e.message}",  # type: ignore
        EventType.TASK_FINISHED: lambda e: f"Stream: Task {e.task_name} finished with result: {e.payload}",  # type: ignore
    }

    # Run pipeline with streaming - see events as they happen
    async for event in pipeline.stream(
        {"adder": {"value": 2}},
        raise_exception=False,  # default is True
    ):
        formatter = formatters.get(event.event_type)
        if formatter is not None:
            print(formatter(event))


if __name__ == "__main__":
    asyncio.run(main())

src/neo4j_graphrag/experimental/pipeline/notification.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515
from __future__ import annotations
1616

17+
import asyncio
1718
import datetime
1819
import enum
1920
from collections.abc import Awaitable
@@ -31,10 +32,15 @@ class EventType(enum.Enum):
3132
TASK_PROGRESS = "TASK_PROGRESS"
3233
TASK_FINISHED = "TASK_FINISHED"
3334
PIPELINE_FINISHED = "PIPELINE_FINISHED"
35+
PIPELINE_FAILED = "PIPELINE_FAILED"
3436

3537
@property
3638
def is_pipeline_event(self) -> bool:
37-
return self in [EventType.PIPELINE_STARTED, EventType.PIPELINE_FINISHED]
39+
return self in [
40+
EventType.PIPELINE_STARTED,
41+
EventType.PIPELINE_FINISHED,
42+
EventType.PIPELINE_FAILED,
43+
]
3844

3945
@property
4046
def is_task_event(self) -> bool:
@@ -72,12 +78,14 @@ def __call__(self, event: Event) -> Awaitable[None]: ...
7278

7379

7480
class EventNotifier:
75-
def __init__(self, callback: EventCallbackProtocol | None) -> None:
76-
self.callback = callback
81+
def __init__(self, callbacks: list[EventCallbackProtocol]) -> None:
82+
self.callbacks = callbacks
7783

7884
async def notify(self, event: Event) -> None:
79-
if self.callback:
80-
await self.callback(event)
85+
await asyncio.gather(
86+
*[c(event) for c in self.callbacks],
87+
return_exceptions=True,
88+
)
8189

8290
async def notify_pipeline_started(
8391
self, run_id: str, input_data: Optional[dict[str, Any]] = None

src/neo4j_graphrag/experimental/pipeline/orchestrator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ class Orchestrator:
5454

5555
def __init__(self, pipeline: Pipeline):
5656
self.pipeline = pipeline
57-
self.event_notifier = EventNotifier(pipeline.callback)
57+
self.event_notifier = EventNotifier(pipeline.callbacks)
5858
self.run_id = str(uuid.uuid4())
5959

6060
async def run_task(self, task: TaskPipelineNode, data: dict[str, Any]) -> None:

src/neo4j_graphrag/experimental/pipeline/pipeline.py

Lines changed: 79 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
import warnings
1919
from collections import defaultdict
2020
from timeit import default_timer
21-
from typing import Any, Optional
21+
from typing import Any, Optional, AsyncGenerator
22+
import asyncio
2223

2324
from neo4j_graphrag.utils.logging import prettify
2425

@@ -47,7 +48,12 @@
4748
)
4849
from neo4j_graphrag.experimental.pipeline.types.orchestration import RunResult
4950
from neo4j_graphrag.experimental.pipeline.types.context import RunContext
50-
from neo4j_graphrag.experimental.pipeline.notification import EventCallbackProtocol
51+
from neo4j_graphrag.experimental.pipeline.notification import (
52+
EventCallbackProtocol,
53+
Event,
54+
PipelineEvent,
55+
EventType,
56+
)
5157

5258

5359
logger = logging.getLogger(__name__)
@@ -117,7 +123,7 @@ def __init__(
117123
) -> None:
118124
super().__init__()
119125
self.store = store or InMemoryStore()
120-
self.callback = callback
126+
self.callbacks = [callback] if callback else []
121127
self.final_results = InMemoryStore()
122128
self.is_validated = False
123129
self.param_mapping: dict[str, dict[str, dict[str, str]]] = defaultdict(dict)
@@ -412,6 +418,76 @@ def validate_parameter_mapping_for_task(self, task: TaskPipelineNode) -> bool:
412418
async def get_final_results(self, run_id: str) -> dict[str, Any]:
413419
return await self.final_results.get(run_id) # type: ignore[no-any-return]
414420

421+
async def stream(
    self, data: dict[str, Any], raise_exception: bool = True
) -> AsyncGenerator[Event, None]:
    """Run the pipeline and stream events for task progress.

    Args:
        data (dict): Input data for the pipeline components
        raise_exception (bool): set to False to prevent this task from propagating
            Pipeline exceptions.

    Yields:
        Event: Pipeline and task events including start, progress, and completion
    """
    # Queue through which the temporary callback forwards events to this generator.
    event_queue: asyncio.Queue[Event] = asyncio.Queue()
    run_id = None

    async def event_stream(event: Event) -> None:
        # Put event in queue for streaming
        await event_queue.put(event)

    # Register the forwarding callback for the duration of this run only
    # (removed again in the `finally` block below).
    self.callbacks.append(event_stream)

    event_queue_getter_task = None
    try:
        # Start pipeline execution in a background task
        run_task = asyncio.create_task(self.run(data))

        # Loop until the run task is done AND the event queue is drained.
        is_run_task_running = True
        is_queue_empty = False
        while is_run_task_running or not is_queue_empty:
            # Reuse a still-pending getter task instead of creating a new one:
            # an abandoned pending `queue.get()` keeps running and would consume
            # (and silently drop) an event this generator never yields.
            if event_queue_getter_task is None or event_queue_getter_task.done():
                event_queue_getter_task = asyncio.create_task(event_queue.get())

            # Wait for the next event or pipeline completion, whichever is first.
            done, _pending = await asyncio.wait(
                [run_task, event_queue_getter_task],
                return_when=asyncio.FIRST_COMPLETED,
            )

            is_run_task_running = run_task not in done
            is_queue_empty = event_queue.empty()

            if event_queue_getter_task in done:
                # we are sure to get an Event here, since this is the only
                # thing we put in the queue, but mypy still complains
                event = event_queue_getter_task.result()
                run_id = getattr(event, "run_id", None)
                yield event  # type: ignore

        if exc := run_task.exception():
            yield PipelineEvent(
                event_type=EventType.PIPELINE_FAILED,
                # run_id is null if pipeline fails before even starting
                # ie during pipeline validation
                run_id=run_id or "",
                message=str(exc),
            )
            if raise_exception:
                raise exc

    finally:
        # Restore original callbacks and cancel a still-pending queue getter
        # so no "Task was destroyed but it is pending" warning is emitted.
        self.callbacks.remove(event_stream)
        if event_queue_getter_task and not event_queue_getter_task.done():
            event_queue_getter_task.cancel()
490+
415491
async def run(self, data: dict[str, Any]) -> PipelineResult:
416492
logger.debug("PIPELINE START")
417493
start_time = default_timer()

tests/unit/experimental/pipeline/components.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15+
import asyncio
16+
1517
from neo4j_graphrag.experimental.pipeline import Component, DataModel
1618
from neo4j_graphrag.experimental.pipeline.types.context import RunContext
1719

@@ -52,3 +54,12 @@ async def run_with_context(
5254
message="my message", data={"number1": number1, "number2": number2}
5355
)
5456
return IntResultModel(result=number1 * number2)
57+
58+
59+
class SlowComponentMultiply(Component):
    # Test component: multiplies its two inputs after an artificial delay,
    # useful for exercising streaming/progress behaviour.
    def __init__(self, sleep: float = 1.0) -> None:
        # Seconds to sleep before producing the result.
        self.sleep = sleep

    async def run(self, number1: int, number2: int = 2) -> IntResultModel:
        await asyncio.sleep(self.sleep)
        return IntResultModel(result=number1 * number2)

0 commit comments

Comments
 (0)