Skip to content

Commit ad22b8d

Browse files
committed
Pipeline streaming
1 parent fe0fe02 commit ad22b8d

File tree

2 files changed

+138
-2
lines changed

2 files changed

+138
-2
lines changed
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import asyncio
2+
3+
from neo4j_graphrag.experimental.pipeline.component import Component, DataModel
4+
from neo4j_graphrag.experimental.pipeline.pipeline import Pipeline
5+
from neo4j_graphrag.experimental.pipeline.notification import EventType
6+
from neo4j_graphrag.experimental.pipeline.types.context import RunContext
7+
8+
9+
# Define some example components with progress notifications
10+
class OutputModel(DataModel):
    """Output schema shared by the example components.

    DataModel is declared elsewhere in the project; presumably a
    pydantic-style base — TODO confirm.
    """

    # Integer produced by a component's run.
    result: int
12+
13+
14+
class SlowAdder(Component):
    """Example component that slowly adds a fixed number to its input,
    emitting one progress notification per simulated unit of work."""

    def __init__(self, number: int) -> None:
        # Constant added to every incoming value.
        self.number = number

    async def run_with_context(self, context_: RunContext, value: int) -> OutputModel:
        # One notification per step; steps are reported 1-based.
        for step in range(1, value + 1):
            await asyncio.sleep(0.5)  # simulate slow work
            await context_.notify(
                message=f"Added {step}/{value}",
                data={"current": step, "total": value},
            )
        return OutputModel(result=value + self.number)
28+
29+
30+
class SlowMultiplier(Component):
    """Example component that slowly multiplies its input, always
    reporting progress over three fixed steps."""

    def __init__(self, multiplier: int) -> None:
        # Factor applied to every incoming value.
        self.multiplier = multiplier

    async def run_with_context(self, context_: RunContext, value: int) -> OutputModel:
        # The simulated work is always split into exactly three steps.
        for step in range(1, 4):
            await asyncio.sleep(0.7)  # simulate slow work
            await context_.notify(
                message=f"Multiplication step {step}/3",
                data={"step": step, "total": 3},
            )
        return OutputModel(result=value * self.multiplier)
44+
45+
46+
async def main():
    """Wire up a two-stage demo pipeline and print every streamed event."""
    pipeline = Pipeline()

    # Register the two slow components under explicit names.
    pipeline.add_component(SlowAdder(number=3), "adder")
    pipeline.add_component(SlowMultiplier(multiplier=2), "multiplier")

    # Feed the adder's result into the multiplier's "value" input.
    pipeline.connect(
        "adder",
        "multiplier",
        {"value": "adder.result"},
    )

    # Map each event type to the line printed for it; unknown types print nothing.
    formatters = {
        EventType.PIPELINE_STARTED: lambda e: "Stream: Pipeline started!",
        EventType.PIPELINE_FINISHED: lambda e: f"Stream: Pipeline finished! Final results: {e.payload}",
        EventType.TASK_STARTED: lambda e: f"Stream: Task {e.task_name} started with inputs: {e.payload}",
        EventType.TASK_PROGRESS: lambda e: f"Stream: Task {e.task_name} progress - {e.message}",
        EventType.TASK_FINISHED: lambda e: f"Stream: Task {e.task_name} finished with result: {e.payload}",
    }

    print("\n=== Running pipeline with streaming ===")
    # Consume events as they happen instead of waiting for the final result.
    async for event in pipeline.stream({"adder": {"value": 2}}):
        formatter = formatters.get(event.event_type)
        if formatter is not None:
            print(formatter(event))


if __name__ == "__main__":
    asyncio.run(main())

src/neo4j_graphrag/experimental/pipeline/pipeline.py

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
import warnings
1919
from collections import defaultdict
2020
from timeit import default_timer
21-
from typing import Any, Optional
21+
from typing import Any, Optional, AsyncGenerator, Callable, List
22+
import asyncio
2223

2324
from neo4j_graphrag.utils.logging import prettify
2425

@@ -47,7 +48,7 @@
4748
)
4849
from neo4j_graphrag.experimental.pipeline.types.orchestration import RunResult
4950
from neo4j_graphrag.experimental.pipeline.types.context import RunContext
50-
from neo4j_graphrag.experimental.pipeline.notification import EventCallbackProtocol
51+
from neo4j_graphrag.experimental.pipeline.notification import EventCallbackProtocol, Event
5152

5253

5354
logger = logging.getLogger(__name__)
@@ -412,6 +413,64 @@ def validate_parameter_mapping_for_task(self, task: TaskPipelineNode) -> bool:
412413
async def get_final_results(self, run_id: str) -> dict[str, Any]:
413414
return await self.final_results.get(run_id) # type: ignore[no-any-return]
414415

416+
async def stream(self, data: dict[str, Any]) -> AsyncGenerator[Event, None]:
    """Run the pipeline and stream events for task progress.

    Args:
        data: Input data for the pipeline components

    Yields:
        Event: Pipeline and task events including start, progress, and completion

    Raises:
        Exception: Re-raises any exception the underlying pipeline run failed with,
            after all already-queued events have been yielded.
    """
    # Queue through which the callback hands events to this generator.
    event_queue: asyncio.Queue[Event] = asyncio.Queue()

    # Store original callback so it still fires and can be restored afterwards.
    original_callback = self.callback

    async def callback_and_event_stream(event: Event) -> None:
        # Put event in queue for streaming, then forward to the
        # pre-existing callback, if any.
        await event_queue.put(event)
        if original_callback:
            await original_callback(event)

    self.callback = callback_and_event_stream

    try:
        # Start pipeline execution in a background task.
        run_task = asyncio.create_task(self.run(data))
        # Pending queue-getter task, reused across loop iterations.
        # asyncio.wait() requires tasks/futures; passing the bare
        # event_queue.get() coroutine was deprecated in Python 3.8 and
        # raises TypeError since 3.11.
        getter: Optional[asyncio.Task[Event]] = None
        try:
            while True:
                if getter is None:
                    getter = asyncio.create_task(event_queue.get())

                done, _pending = await asyncio.wait(
                    (run_task, getter),
                    return_when=asyncio.FIRST_COMPLETED,
                )

                # Yield the event first so it is not lost when the run
                # finishes in the same wait round.
                if getter in done:
                    yield getter.result()
                    getter = None

                if run_task in done:
                    # Drain events queued before the run completed.
                    while not event_queue.empty():
                        yield event_queue.get_nowait()
                    # Surface any pipeline failure to the consumer.
                    exception = run_task.exception()
                    if exception is not None:
                        raise exception
                    break
        finally:
            # Never leave a dangling getter task behind.
            if getter is not None:
                getter.cancel()
    finally:
        # Restore original callback.
        self.callback = original_callback
473+
415474
async def run(self, data: dict[str, Any]) -> PipelineResult:
416475
logger.debug("PIPELINE START")
417476
start_time = default_timer()

0 commit comments

Comments
 (0)