Commit 6a7c168

Add tests
1 parent 6e4261b commit 6a7c168

File tree: 6 files changed (+179, -64 lines changed)

examples/customize/build_graph/pipeline/pipeline_streaming.py

Lines changed: 22 additions & 18 deletions

@@ -2,7 +2,7 @@
 
 from neo4j_graphrag.experimental.pipeline.component import Component, DataModel
 from neo4j_graphrag.experimental.pipeline.pipeline import Pipeline
-from neo4j_graphrag.experimental.pipeline.notification import EventType
+from neo4j_graphrag.experimental.pipeline.notification import EventType, Event
 from neo4j_graphrag.experimental.pipeline.types.context import RunContext
 
 
@@ -13,6 +13,7 @@ class OutputModel(DataModel):
 
 class SlowAdder(Component):
     """A component that slowly adds numbers and reports progress"""
+
     def __init__(self, number: int) -> None:
         self.number = number
 
@@ -21,14 +22,14 @@ async def run_with_context(self, context_: RunContext, value: int) -> OutputModel:
         for i in range(value):
             await asyncio.sleep(0.5)  # Simulate work
             await context_.notify(
-                message=f"Added {i+1}/{value}",
-                data={"current": i+1, "total": value}
+                message=f"Added {i+1}/{value}", data={"current": i + 1, "total": value}
             )
         return OutputModel(result=value + self.number)
 
 
 class SlowMultiplier(Component):
     """A component that slowly multiplies numbers and reports progress"""
+
     def __init__(self, multiplier: int) -> None:
         self.multiplier = multiplier
 
@@ -37,26 +38,25 @@ async def run_with_context(self, context_: RunContext, value: int) -> OutputModel:
         for i in range(3):  # Always do 3 steps
             await asyncio.sleep(0.7)  # Simulate work
             await context_.notify(
-                message=f"Multiplication step {i+1}/3",
-                data={"step": i+1, "total": 3}
+                message=f"Multiplication step {i+1}/3", data={"step": i + 1, "total": 3}
             )
         return OutputModel(result=value * self.multiplier)
 
 
-async def main():
+async def callback(event: Event) -> None:
+    await asyncio.sleep(0.1)
+
+
+async def main() -> None:
     # Create pipeline
-    pipeline = Pipeline()
-
+    pipeline = Pipeline(callback=callback)
+
     # Add components
     pipeline.add_component(SlowAdder(number=3), "adder")
     pipeline.add_component(SlowMultiplier(multiplier=2), "multiplier")
-
+
     # Connect components
-    pipeline.connect(
-        "adder",
-        "multiplier",
-        {"value": "adder.result"}
-    )
+    pipeline.connect("adder", "multiplier", {"value": "adder.result"})
 
     print("\n=== Running pipeline with streaming ===")
     # Run pipeline with streaming - see events as they happen
@@ -66,12 +66,16 @@ async def main():
         elif event.event_type == EventType.PIPELINE_FINISHED:
             print(f"Stream: Pipeline finished! Final results: {event.payload}")
         elif event.event_type == EventType.TASK_STARTED:
-            print(f"Stream: Task {event.task_name} started with inputs: {event.payload}")
+            print(
+                f"Stream: Task {event.task_name} started with inputs: {event.payload}"  # type: ignore
+            )
        elif event.event_type == EventType.TASK_PROGRESS:
-            print(f"Stream: Task {event.task_name} progress - {event.message}")
+            print(f"Stream: Task {event.task_name} progress - {event.message}")  # type: ignore
         elif event.event_type == EventType.TASK_FINISHED:
-            print(f"Stream: Task {event.task_name} finished with result: {event.payload}")
+            print(
+                f"Stream: Task {event.task_name} finished with result: {event.payload}"  # type: ignore
+            )
 
 
 if __name__ == "__main__":
-    asyncio.run(main())
+    asyncio.run(main())
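
With the callbacks list introduced in pipeline.py below, the example's constructor callback and the stream() loop now both receive every event, instead of streaming replacing the user callback. A minimal sketch of that combined usage (not part of the commit; the empty-pipeline behaviour follows the new unit tests, where stream({}) yields exactly PIPELINE_STARTED and PIPELINE_FINISHED):

# Sketch only: a user callback registered at construction time keeps firing
# while stream() consumes the same events.
import asyncio

from neo4j_graphrag.experimental.pipeline.pipeline import Pipeline
from neo4j_graphrag.experimental.pipeline.notification import Event


async def on_event(event: Event) -> None:
    # awaited by the EventNotifier for every pipeline/task event
    print("callback saw:", event.event_type)


async def demo() -> None:
    pipeline = Pipeline(callback=on_event)
    # even with no components, streaming yields PIPELINE_STARTED then
    # PIPELINE_FINISHED, and on_event is awaited for both
    async for event in pipeline.stream({}):
        print("stream yielded:", event.event_type)


if __name__ == "__main__":
    asyncio.run(demo())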

src/neo4j_graphrag/experimental/pipeline/notification.py

Lines changed: 7 additions & 4 deletions

@@ -14,6 +14,7 @@
 # limitations under the License.
 from __future__ import annotations
 
+import asyncio
 import datetime
 import enum
 from collections.abc import Awaitable
@@ -72,12 +73,14 @@ def __call__(self, event: Event) -> Awaitable[None]: ...
 
 
 class EventNotifier:
-    def __init__(self, callback: EventCallbackProtocol | None) -> None:
-        self.callback = callback
+    def __init__(self, callbacks: list[EventCallbackProtocol]) -> None:
+        self.callbacks = callbacks
 
     async def notify(self, event: Event) -> None:
-        if self.callback:
-            await self.callback(event)
+        await asyncio.gather(
+            *[c(event) for c in self.callbacks],
+            return_exceptions=True,
+        )
 
     async def notify_pipeline_started(
         self, run_id: str, input_data: Optional[dict[str, Any]] = None
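
Since notify now fans out with asyncio.gather(..., return_exceptions=True), callbacks run concurrently and an exception in one callback is captured as a result rather than propagated, so the pipeline and the other callbacks keep going (the new test_pipeline_streaming_error_in_user_callback relies on this). A standalone sketch of the same pattern, with placeholder callback names:

# Standalone sketch of the fan-out used in EventNotifier.notify (placeholder
# names; not part of the commit).
import asyncio


async def ok_callback(event: str) -> None:
    print("handled", event)


async def failing_callback(event: str) -> None:
    raise RuntimeError("error in callback")


async def notify(event: str) -> None:
    # exceptions come back in the results list instead of being raised,
    # so ok_callback still runs even though failing_callback blows up
    results = await asyncio.gather(
        *[cb(event) for cb in (ok_callback, failing_callback)],
        return_exceptions=True,
    )
    print(results)  # [None, RuntimeError('error in callback')]


if __name__ == "__main__":
    asyncio.run(notify("TASK_PROGRESS"))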

src/neo4j_graphrag/experimental/pipeline/orchestrator.py

Lines changed: 1 addition & 1 deletion

@@ -54,7 +54,7 @@ class Orchestrator:
 
     def __init__(self, pipeline: Pipeline):
         self.pipeline = pipeline
-        self.event_notifier = EventNotifier(pipeline.callback)
+        self.event_notifier = EventNotifier(pipeline.callbacks)
         self.run_id = str(uuid.uuid4())
 
     async def run_task(self, task: TaskPipelineNode, data: dict[str, Any]) -> None:

src/neo4j_graphrag/experimental/pipeline/pipeline.py

Lines changed: 43 additions & 41 deletions

@@ -18,7 +18,7 @@
 import warnings
 from collections import defaultdict
 from timeit import default_timer
-from typing import Any, Optional, AsyncGenerator, Callable, List
+from typing import Any, Optional, AsyncGenerator
 import asyncio
 
 from neo4j_graphrag.utils.logging import prettify
@@ -48,7 +48,10 @@
 )
 from neo4j_graphrag.experimental.pipeline.types.orchestration import RunResult
 from neo4j_graphrag.experimental.pipeline.types.context import RunContext
-from neo4j_graphrag.experimental.pipeline.notification import EventCallbackProtocol, Event
+from neo4j_graphrag.experimental.pipeline.notification import (
+    EventCallbackProtocol,
+    Event,
+)
 
 
 logger = logging.getLogger(__name__)
@@ -118,7 +121,7 @@ def __init__(
     ) -> None:
         super().__init__()
         self.store = store or InMemoryStore()
-        self.callback = callback
+        self.callbacks = [callback] if callback else []
         self.final_results = InMemoryStore()
         self.is_validated = False
         self.param_mapping: dict[str, dict[str, dict[str, str]]] = defaultdict(dict)
@@ -415,61 +418,60 @@ async def get_final_results(self, run_id: str) -> dict[str, Any]:
 
     async def stream(self, data: dict[str, Any]) -> AsyncGenerator[Event, None]:
         """Run the pipeline and stream events for task progress.
-
+
         Args:
             data: Input data for the pipeline components
-
+
         Yields:
             Event: Pipeline and task events including start, progress, and completion
         """
         # Create queue for events
         event_queue: asyncio.Queue[Event] = asyncio.Queue()
-
-        # Store original callback
-        original_callback = self.callback
-
-        async def callback_and_event_stream(event: Event) -> None:
+
+        async def event_stream(event: Event) -> None:
             # Put event in queue for streaming
             await event_queue.put(event)
-            # Call original callback if it exists
-            if original_callback:
-                await original_callback(event)
-
-        # Set up event callback
-        self.callback = callback_and_event_stream
-
+
+        # Add event streaming callback
+        self.callbacks.append(event_stream)
+
        try:
             # Start pipeline execution in background task
             run_task = asyncio.create_task(self.run(data))
-
-            while True:
+
+            # loop until the run task is done, and we do not have
+            # any more pending tasks in queue
+            is_run_task_running = True
+            is_queue_empty = False
+            while is_run_task_running or not is_queue_empty:
                 # Wait for next event or pipeline completion
+                event_queue_getter_task = asyncio.create_task(event_queue.get())
                 done, pending = await asyncio.wait(
-                    [run_task, event_queue.get()],
-                    return_when=asyncio.FIRST_COMPLETED
+                    [run_task, event_queue_getter_task],
+                    return_when=asyncio.FIRST_COMPLETED,
                 )
-
-                # Pipeline finished
-                if run_task in done:
-                    if run_task.exception():
-                        raise run_task.exception()
-                    # Drain any remaining events
-                    while not event_queue.empty():
-                        yield await event_queue.get()
-                    break
-
-                # Got an event from queue
-                event_future = next(f for f in done if f != run_task)
-                try:
-                    event = event_future.result()
-                    yield event
-                except Exception as e:
-                    logger.error(f"Error processing event: {e}")
-                    raise
-
+
+                is_run_task_running = run_task not in done
+                is_queue_empty = event_queue.empty()
+
+                for event_future in done:
+                    if event_future == run_task:
+                        continue
+                    yield event_future.result()  # type: ignore
+
+                # cancel remaining task
+                event_queue_getter_task.cancel()
+
+            # # Drain any remaining events
+            # while not event_queue.empty():
+            #     yield await event_queue.get()
+            # Pipeline finished
+            if run_task.exception():
+                raise run_task.exception()  # type: ignore
+
         finally:
             # Restore original callback
-            self.callback = original_callback
+            self.callbacks.remove(event_stream)
 
     async def run(self, data: dict[str, Any]) -> PipelineResult:
         logger.debug("PIPELINE START")
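
The rewritten stream() wraps event_queue.get() in an explicit task, so asyncio.wait only ever receives tasks (bare coroutines are no longer accepted on recent Python versions) and the pending getter can be cancelled at the end of each iteration; the loop keeps running until the run task has finished and the queue is empty, which is what lets late events still be yielded. A reduced sketch of that wait loop outside the Pipeline class (hypothetical worker/consume names, not part of the commit):

# Reduced sketch of the producer/consumer wait loop used in Pipeline.stream,
# with a plain worker coroutine standing in for the pipeline run.
import asyncio


async def worker(queue: asyncio.Queue) -> str:
    for i in range(3):
        await queue.put(f"event {i}")
        await asyncio.sleep(0.1)
    return "done"


async def consume() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    run_task = asyncio.create_task(worker(queue))

    running, queue_empty = True, False
    while running or not queue_empty:
        # wrap the queue getter in a task so it can be awaited next to run_task
        getter = asyncio.create_task(queue.get())
        done, _ = await asyncio.wait(
            [run_task, getter], return_when=asyncio.FIRST_COMPLETED
        )
        running = run_task not in done
        queue_empty = queue.empty()
        for fut in done:
            if fut is not run_task:
                print("yielded:", fut.result())
        getter.cancel()  # drop the pending getter before the next iteration

    print("run result:", run_task.result())


if __name__ == "__main__":
    asyncio.run(consume())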

tests/unit/experimental/pipeline/components.py

Lines changed: 11 additions & 0 deletions

@@ -12,6 +12,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import asyncio
+
 from neo4j_graphrag.experimental.pipeline import Component, DataModel
 from neo4j_graphrag.experimental.pipeline.types.context import RunContext
 
@@ -52,3 +54,12 @@ async def run_with_context(
             message="my message", data={"number1": number1, "number2": number2}
         )
         return IntResultModel(result=number1 * number2)
+
+
+class SlowComponentMultiply(Component):
+    def __init__(self, sleep: float = 1.0) -> None:
+        self.sleep = sleep
+
+    async def run(self, number1: int, number2: int = 2) -> IntResultModel:
+        await asyncio.sleep(self.sleep)
+        return IntResultModel(result=number1 * number2)

tests/unit/experimental/pipeline/test_pipeline.py

Lines changed: 95 additions & 0 deletions

@@ -39,6 +39,7 @@
     ComponentNoParam,
     ComponentPassThrough,
     StringResultModel,
+    SlowComponentMultiply,
 )
 
 
@@ -491,3 +492,97 @@ def test_event_model_no_warning(recwarn: Sized) -> None:
     )
     assert event.timestamp is not None
     assert len(recwarn) == 0
+
+
+@pytest.mark.asyncio
+async def test_pipeline_streaming_no_user_callback_happy_path() -> None:
+    pipe = Pipeline()
+    events = []
+    async for e in pipe.stream({}):
+        events.append(e)
+    assert len(events) == 2
+    assert events[0].event_type == EventType.PIPELINE_STARTED
+    assert events[1].event_type == EventType.PIPELINE_FINISHED
+    assert len(pipe.callbacks) == 0
+
+
+@pytest.mark.asyncio
+async def test_pipeline_streaming_with_user_callback_happy_path() -> None:
+    callback = AsyncMock()
+    pipe = Pipeline(callback=callback)
+    events = []
+    async for e in pipe.stream({}):
+        events.append(e)
+    assert len(events) == 2
+    assert len(callback.call_args_list) == 2
+    assert len(pipe.callbacks) == 1
+
+
+@pytest.mark.asyncio
+async def test_pipeline_streaming_very_long_running_user_callback() -> None:
+    async def callback(event: Event) -> None:
+        await asyncio.sleep(2)
+
+    pipe = Pipeline(callback=callback)
+    events = []
+    async for e in pipe.stream({}):
+        events.append(e)
+    assert len(events) == 2
+    assert len(pipe.callbacks) == 1
+
+
+@pytest.mark.asyncio
+async def test_pipeline_streaming_very_long_running_pipeline() -> None:
+    slow_component = SlowComponentMultiply()
+    pipe = Pipeline()
+    pipe.add_component(slow_component, "slow_component")
+    events = []
+    async for e in pipe.stream({"slow_component": {"number1": 1, "number2": 2}}):
+        events.append(e)
+    assert len(events) == 4
+    last_event = events[-1]
+    assert last_event.event_type == EventType.PIPELINE_FINISHED
+    assert last_event.payload == {"slow_component": {"result": 2}}
+
+
+@pytest.mark.asyncio
+async def test_pipeline_streaming_error_in_pipeline_definition() -> None:
+    pipe = Pipeline()
+    component_a = ComponentAdd()
+    component_b = ComponentAdd()
+    pipe.add_component(component_a, "a")
+    pipe.add_component(component_b, "b")
+    pipe.connect("a", "b", {"number1": "a.result"})
+    events = []
+    with pytest.raises(PipelineDefinitionError):
+        async for e in pipe.stream({"a": {"number1": 1, "number2": 2}}):
+            events.append(e)
+    # validation happens before pipeline run actually starts
+    assert len(events) == 0
+
+
+@pytest.mark.asyncio
+async def test_pipeline_streaming_error_in_component() -> None:
+    component = ComponentMultiply()
+    pipe = Pipeline()
+    pipe.add_component(component, "component")
+    events = []
+    with pytest.raises(TypeError):
+        async for e in pipe.stream({"component": {"number1": None, "number2": 2}}):
+            events.append(e)
+    assert len(events) == 2
+    assert events[0].event_type == EventType.PIPELINE_STARTED
+    assert events[1].event_type == EventType.TASK_STARTED
+
+
+@pytest.mark.asyncio
+async def test_pipeline_streaming_error_in_user_callback() -> None:
+    async def callback(event: Event) -> None:
+        raise Exception("error in callback")
+
+    pipe = Pipeline(callback=callback)
+    events = []
+    async for e in pipe.stream({}):
+        events.append(e)
+    assert len(events) == 2
+    assert len(pipe.callbacks) == 1
