Skip to content

Enhance tests around Orchestrator #293

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 4, 2025
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 149 additions & 33 deletions tests/unit/experimental/pipeline/test_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@
Component,
Pipeline,
)
from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError
from neo4j_graphrag.experimental.pipeline.exceptions import (
PipelineDefinitionError,
PipelineMissingDependencyError,
PipelineStatusUpdateError,
)
from neo4j_graphrag.experimental.pipeline.orchestrator import Orchestrator
from neo4j_graphrag.experimental.pipeline.types import RunStatus

Expand All @@ -34,8 +38,9 @@ def test_orchestrator_get_input_config_for_task_pipeline_not_validated() -> None
pipe.add_component(ComponentPassThrough(), "a")
pipe.add_component(ComponentPassThrough(), "b")
orchestrator = Orchestrator(pipe)
with pytest.raises(PipelineDefinitionError):
with pytest.raises(PipelineDefinitionError) as exc:
orchestrator.get_input_config_for_task(pipe.get_node_by_name("a"))
assert "You must validate the pipeline input config first" in str(exc.value)


@pytest.mark.asyncio
Expand All @@ -59,10 +64,10 @@ async def test_orchestrator_get_component_inputs_from_user_only() -> None:
"neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_results_for_component"
)
@pytest.mark.asyncio
async def test_pipeline_get_component_inputs_from_parent_specific(
async def test_orchestrator_get_component_inputs_from_parent_specific(
mock_result: Mock,
) -> None:
"""Propagate one specific output field from 'a' to the next component."""
"""Propagate one specific output field from parent to a child component."""
pipe = Pipeline()
pipe.add_component(ComponentPassThrough(), "a")
pipe.add_component(ComponentPassThrough(), "b")
Expand Down Expand Up @@ -164,6 +169,43 @@ async def test_orchestrator_get_component_inputs_ignore_user_input_if_input_def_
)


@pytest.mark.asyncio
@patch(
"neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component"
)
async def test_orchestrator_set_component_status(mock_status: Mock) -> None:
pipe = Pipeline()
orchestrator = Orchestrator(pipeline=pipe)
# Normal status update: UNKNOWN -> RUNNING -> DONE
# UNKNOWN -> RUNNING
mock_status.side_effect = [
RunStatus.UNKNOWN,
]
await orchestrator.set_task_status("task_name", RunStatus.RUNNING)
# RUNNING -> DONE
mock_status.side_effect = [
RunStatus.RUNNING,
]
await orchestrator.set_task_status("task_name", RunStatus.DONE)
# Error path, raising PipelineStatusUpdateError
# Same status
# RUNNING -> RUNNING
mock_status.side_effect = [
RunStatus.RUNNING,
]
with pytest.raises(PipelineStatusUpdateError) as exc:
await orchestrator.set_task_status("task_name", RunStatus.RUNNING)
assert "Status is already" in str(exc)
# Going back to RUNNING after the task is DONE
# DONE -> RUNNING
mock_status.side_effect = [
RunStatus.DONE,
]
with pytest.raises(PipelineStatusUpdateError) as exc:
await orchestrator.set_task_status("task_name", RunStatus.RUNNING)
assert "Can't go from DONE to RUNNING" in str(exc)


@pytest.fixture(scope="function")
def pipeline_branch() -> Pipeline:
pipe = Pipeline()
Expand All @@ -190,21 +232,45 @@ def pipeline_aggregation() -> Pipeline:
@patch(
"neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component"
)
async def test_orchestrator_branch(
async def test_orchestrator_check_dependency_complete(
mock_status: Mock, pipeline_branch: Pipeline
) -> None:
"""a -> b, c"""
orchestrator = Orchestrator(pipeline=pipeline_branch)
node_a = pipeline_branch.get_node_by_name("a")
await orchestrator.check_dependencies_complete(node_a)
node_b = pipeline_branch.get_node_by_name("b")
# dependency is DONE:
mock_status.side_effect = [RunStatus.DONE]
await orchestrator.check_dependencies_complete(node_b)
# dependency is not DONE:
mock_status.side_effect = [RunStatus.RUNNING]
with pytest.raises(PipelineMissingDependencyError):
await orchestrator.check_dependencies_complete(node_b)


@pytest.mark.asyncio
@patch(
"neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component"
)
@patch(
"neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.check_dependencies_complete",
)
async def test_orchestrator_next_task_branch_no_missing_dependencies(
mock_dep: Mock, mock_status: Mock, pipeline_branch: Pipeline
) -> None:
"""a -> b, c"""
orchestrator = Orchestrator(pipeline=pipeline_branch)
node_a = pipeline_branch.get_node_by_name("a")
mock_status.side_effect = [
# next b
# next "b"
RunStatus.UNKNOWN,
# dep of b = a
RunStatus.DONE,
# next c
# next "c"
RunStatus.UNKNOWN,
# dep of c = a
RunStatus.DONE,
]
mock_dep.side_effect = [
None, # "b" has no missing dependencies
None, # "c" has no missing dependencies
]
next_tasks = [n async for n in orchestrator.next(node_a)]
next_task_names = [n.name for n in next_tasks]
Expand All @@ -215,31 +281,48 @@ async def test_orchestrator_branch(
@patch(
"neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component"
)
async def test_orchestrator_aggregation(
mock_status: Mock, pipeline_aggregation: Pipeline
@patch(
"neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.check_dependencies_complete",
)
async def test_orchestrator_next_task_branch_missing_dependencies(
mock_dep: Mock, mock_status: Mock, pipeline_branch: Pipeline
) -> None:
"""a, b -> c"""
orchestrator = Orchestrator(pipeline=pipeline_aggregation)
node_a = pipeline_aggregation.get_node_by_name("a")
"""a -> b, c"""
orchestrator = Orchestrator(pipeline=pipeline_branch)
node_a = pipeline_branch.get_node_by_name("a")
mock_status.side_effect = [
# next c:
# next "b"
RunStatus.UNKNOWN,
# dep of c = a
RunStatus.DONE,
# dep of c = b
# next "c"
RunStatus.UNKNOWN,
]
next_task_names = [n.name async for n in orchestrator.next(node_a)]
# "c" dependencies not ready yet
assert next_task_names == []
# set "b" to DONE
mock_dep.side_effect = [
PipelineMissingDependencyError, # "b" has missing dependencies
None, # "c" has no missing dependencies
]
next_tasks = [n async for n in orchestrator.next(node_a)]
next_task_names = [n.name for n in next_tasks]
assert next_task_names == ["c"]


@pytest.mark.asyncio
@patch(
"neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component"
)
@patch(
"neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.check_dependencies_complete",
)
async def test_orchestrator_next_task_aggregation_no_missing_dependencies(
mock_dep: Mock, mock_status: Mock, pipeline_aggregation: Pipeline
) -> None:
"""a, b -> c"""
orchestrator = Orchestrator(pipeline=pipeline_aggregation)
node_a = pipeline_aggregation.get_node_by_name("a")
mock_status.side_effect = [
# next c:
RunStatus.UNKNOWN,
# dep of c = a
RunStatus.DONE,
# dep of c = b
RunStatus.DONE,
RunStatus.UNKNOWN, # status for "c", not started
]
mock_dep.side_effect = [
None, # no missing deps
]
# then "c" can start
next_tasks = [n async for n in orchestrator.next(node_a)]
Expand All @@ -248,8 +331,41 @@ async def test_orchestrator_aggregation(


@pytest.mark.asyncio
async def test_orchestrator_aggregation_waiting(pipeline_aggregation: Pipeline) -> None:
@patch(
"neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component"
)
@patch(
"neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.check_dependencies_complete",
)
async def test_orchestrator_next_task_aggregation_missing_dependency(
mock_dep: Mock, mock_status: Mock, pipeline_aggregation: Pipeline
) -> None:
"""a, b -> c"""
orchestrator = Orchestrator(pipeline=pipeline_aggregation)
node_a = pipeline_aggregation.get_node_by_name("a")
next_tasks = [n async for n in orchestrator.next(node_a)]
assert next_tasks == []
mock_status.side_effect = [
RunStatus.UNKNOWN, # status for "c" is unknown, it's a possible next
]
mock_dep.side_effect = [
PipelineMissingDependencyError, # some dependencies are not done yet
]
next_task_names = [n.name async for n in orchestrator.next(node_a)]
# "c" dependencies not ready yet
assert next_task_names == []


@pytest.mark.asyncio
@patch(
"neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component"
)
async def test_orchestrator_next_task_aggregation_next_already_started(
mock_status: Mock, pipeline_aggregation: Pipeline
) -> None:
"""a, b -> c"""
orchestrator = Orchestrator(pipeline=pipeline_aggregation)
node_a = pipeline_aggregation.get_node_by_name("a")
mock_status.side_effect = [
RunStatus.RUNNING, # status for "c" is already running, do not start it again
]
next_task_names = [n.name async for n in orchestrator.next(node_a)]
assert next_task_names == []