Skip to content

Enhance tests around Orchestrator #293

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/neo4j_graphrag/experimental/pipeline/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,11 @@ async def set_task_status(self, task_name: str, status: RunStatus) -> None:
async with asyncio.Lock():
current_status = await self.get_status_for_component(task_name)
if status == current_status:
raise PipelineStatusUpdateError(f"Status is already '{status}'")
if status == RunStatus.RUNNING and current_status == RunStatus.DONE:
raise PipelineStatusUpdateError("Can't go from DONE to RUNNING")
raise PipelineStatusUpdateError(f"Status is already {status}")
if status not in current_status.possible_next_status():
raise PipelineStatusUpdateError(
f"Can't go from {current_status} to {status}"
)
return await self.pipeline.store.add_status_for_component(
self.run_id, task_name, status.value
)
Expand Down
9 changes: 9 additions & 0 deletions src/neo4j_graphrag/experimental/pipeline/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,15 @@ class RunStatus(enum.Enum):
RUNNING = "RUNNING"
DONE = "DONE"

def possible_next_status(self) -> list[RunStatus]:
    """Return the statuses that may directly follow this one.

    UNKNOWN may only move to RUNNING, RUNNING may only move to DONE, and
    DONE has no successor, so an empty list is returned for it (and for
    any status not listed in the table).
    """
    # The table is rebuilt on every call so that each caller receives
    # its own list objects.
    transitions = {
        RunStatus.UNKNOWN: [RunStatus.RUNNING],
        RunStatus.RUNNING: [RunStatus.DONE],
        RunStatus.DONE: [],
    }
    return transitions.get(self, [])


class RunResult(BaseModel):
status: RunStatus = RunStatus.DONE
Expand Down
195 changes: 162 additions & 33 deletions tests/unit/experimental/pipeline/test_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@
Component,
Pipeline,
)
from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError
from neo4j_graphrag.experimental.pipeline.exceptions import (
PipelineDefinitionError,
PipelineMissingDependencyError,
PipelineStatusUpdateError,
)
from neo4j_graphrag.experimental.pipeline.orchestrator import Orchestrator
from neo4j_graphrag.experimental.pipeline.types import RunStatus

Expand All @@ -34,8 +38,9 @@ def test_orchestrator_get_input_config_for_task_pipeline_not_validated() -> None
pipe.add_component(ComponentPassThrough(), "a")
pipe.add_component(ComponentPassThrough(), "b")
orchestrator = Orchestrator(pipe)
with pytest.raises(PipelineDefinitionError):
with pytest.raises(PipelineDefinitionError) as exc:
orchestrator.get_input_config_for_task(pipe.get_node_by_name("a"))
assert "You must validate the pipeline input config first" in str(exc.value)


@pytest.mark.asyncio
Expand All @@ -59,10 +64,10 @@ async def test_orchestrator_get_component_inputs_from_user_only() -> None:
"neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_results_for_component"
)
@pytest.mark.asyncio
async def test_pipeline_get_component_inputs_from_parent_specific(
async def test_orchestrator_get_component_inputs_from_parent_specific(
mock_result: Mock,
) -> None:
"""Propagate one specific output field from 'a' to the next component."""
"""Propagate one specific output field from parent to a child component."""
pipe = Pipeline()
pipe.add_component(ComponentPassThrough(), "a")
pipe.add_component(ComponentPassThrough(), "b")
Expand Down Expand Up @@ -164,6 +169,56 @@ async def test_orchestrator_get_component_inputs_ignore_user_input_if_input_def_
)


@pytest.mark.asyncio
@patch(
    "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component"
)
@pytest.mark.parametrize(
    "old_status, new_status, result",
    [
        # Normal path: from UNKNOWN to RUNNING to DONE
        (RunStatus.UNKNOWN, RunStatus.RUNNING, "ok"),
        (RunStatus.RUNNING, RunStatus.DONE, "ok"),
        # Error: status is already set to this value
        (RunStatus.RUNNING, RunStatus.RUNNING, "Status is already RunStatus.RUNNING"),
        (RunStatus.DONE, RunStatus.DONE, "Status is already RunStatus.DONE"),
        # Error: can't go back in time
        (
            RunStatus.DONE,
            RunStatus.RUNNING,
            "Can't go from RunStatus.DONE to RunStatus.RUNNING",
        ),
        (
            RunStatus.RUNNING,
            RunStatus.UNKNOWN,
            "Can't go from RunStatus.RUNNING to RunStatus.UNKNOWN",
        ),
        (
            RunStatus.DONE,
            RunStatus.UNKNOWN,
            "Can't go from RunStatus.DONE to RunStatus.UNKNOWN",
        ),
    ],
)
async def test_orchestrator_set_component_status(
    mock_status: Mock,
    old_status: RunStatus,
    new_status: RunStatus,
    result: str,
) -> None:
    """Check which status transitions ``set_task_status`` accepts or rejects.

    ``result`` is either ``"ok"`` (the transition must succeed) or the text
    expected in the ``PipelineStatusUpdateError`` message.
    """
    pipe = Pipeline()
    orchestrator = Orchestrator(pipeline=pipe)
    # The patched status lookup reports ``old_status`` as the current status.
    mock_status.side_effect = [
        old_status,
    ]
    if result == "ok":
        await orchestrator.set_task_status("task_name", new_status)
    else:
        with pytest.raises(PipelineStatusUpdateError) as exc:
            await orchestrator.set_task_status("task_name", new_status)
        # Inspect the raised exception itself (exc.value): str() of the
        # ExceptionInfo wrapper is a debugging repr, not the error message.
        assert result in str(exc.value)


@pytest.fixture(scope="function")
def pipeline_branch() -> Pipeline:
pipe = Pipeline()
Expand All @@ -190,21 +245,45 @@ def pipeline_aggregation() -> Pipeline:
@patch(
"neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component"
)
async def test_orchestrator_branch(
async def test_orchestrator_check_dependency_complete(
mock_status: Mock, pipeline_branch: Pipeline
) -> None:
"""a -> b, c"""
orchestrator = Orchestrator(pipeline=pipeline_branch)
node_a = pipeline_branch.get_node_by_name("a")
await orchestrator.check_dependencies_complete(node_a)
node_b = pipeline_branch.get_node_by_name("b")
# dependency is DONE:
mock_status.side_effect = [RunStatus.DONE]
await orchestrator.check_dependencies_complete(node_b)
# dependency is not DONE:
mock_status.side_effect = [RunStatus.RUNNING]
with pytest.raises(PipelineMissingDependencyError):
await orchestrator.check_dependencies_complete(node_b)


@pytest.mark.asyncio
@patch(
"neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component"
)
@patch(
"neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.check_dependencies_complete",
)
async def test_orchestrator_next_task_branch_no_missing_dependencies(
mock_dep: Mock, mock_status: Mock, pipeline_branch: Pipeline
) -> None:
"""a -> b, c"""
orchestrator = Orchestrator(pipeline=pipeline_branch)
node_a = pipeline_branch.get_node_by_name("a")
mock_status.side_effect = [
# next b
# next "b"
RunStatus.UNKNOWN,
# dep of b = a
RunStatus.DONE,
# next c
# next "c"
RunStatus.UNKNOWN,
# dep of c = a
RunStatus.DONE,
]
mock_dep.side_effect = [
None, # "b" has no missing dependencies
None, # "c" has no missing dependencies
]
next_tasks = [n async for n in orchestrator.next(node_a)]
next_task_names = [n.name for n in next_tasks]
Expand All @@ -215,31 +294,48 @@ async def test_orchestrator_branch(
@patch(
"neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component"
)
async def test_orchestrator_aggregation(
mock_status: Mock, pipeline_aggregation: Pipeline
@patch(
"neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.check_dependencies_complete",
)
async def test_orchestrator_next_task_branch_missing_dependencies(
mock_dep: Mock, mock_status: Mock, pipeline_branch: Pipeline
) -> None:
"""a, b -> c"""
orchestrator = Orchestrator(pipeline=pipeline_aggregation)
node_a = pipeline_aggregation.get_node_by_name("a")
"""a -> b, c"""
orchestrator = Orchestrator(pipeline=pipeline_branch)
node_a = pipeline_branch.get_node_by_name("a")
mock_status.side_effect = [
# next c:
# next "b"
RunStatus.UNKNOWN,
# dep of c = a
RunStatus.DONE,
# dep of c = b
# next "c"
RunStatus.UNKNOWN,
]
next_task_names = [n.name async for n in orchestrator.next(node_a)]
# "c" dependencies not ready yet
assert next_task_names == []
# set "b" to DONE
mock_dep.side_effect = [
PipelineMissingDependencyError, # "b" has missing dependencies
None, # "c" has no missing dependencies
]
next_tasks = [n async for n in orchestrator.next(node_a)]
next_task_names = [n.name for n in next_tasks]
assert next_task_names == ["c"]


@pytest.mark.asyncio
@patch(
"neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component"
)
@patch(
"neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.check_dependencies_complete",
)
async def test_orchestrator_next_task_aggregation_no_missing_dependencies(
mock_dep: Mock, mock_status: Mock, pipeline_aggregation: Pipeline
) -> None:
"""a, b -> c"""
orchestrator = Orchestrator(pipeline=pipeline_aggregation)
node_a = pipeline_aggregation.get_node_by_name("a")
mock_status.side_effect = [
# next c:
RunStatus.UNKNOWN,
# dep of c = a
RunStatus.DONE,
# dep of c = b
RunStatus.DONE,
RunStatus.UNKNOWN, # status for "c", not started
]
mock_dep.side_effect = [
None, # no missing deps
]
# then "c" can start
next_tasks = [n async for n in orchestrator.next(node_a)]
Expand All @@ -248,8 +344,41 @@ async def test_orchestrator_aggregation(


@pytest.mark.asyncio
async def test_orchestrator_aggregation_waiting(pipeline_aggregation: Pipeline) -> None:
@patch(
"neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component"
)
@patch(
"neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.check_dependencies_complete",
)
async def test_orchestrator_next_task_aggregation_missing_dependency(
mock_dep: Mock, mock_status: Mock, pipeline_aggregation: Pipeline
) -> None:
"""a, b -> c"""
orchestrator = Orchestrator(pipeline=pipeline_aggregation)
node_a = pipeline_aggregation.get_node_by_name("a")
next_tasks = [n async for n in orchestrator.next(node_a)]
assert next_tasks == []
mock_status.side_effect = [
RunStatus.UNKNOWN, # status for "c" is unknown, it's a possible next
]
mock_dep.side_effect = [
PipelineMissingDependencyError, # some dependencies are not done yet
]
next_task_names = [n.name async for n in orchestrator.next(node_a)]
# "c" dependencies not ready yet
assert next_task_names == []


@pytest.mark.asyncio
@patch(
    "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component"
)
async def test_orchestrator_next_task_aggregation_next_already_started(
    mock_status: Mock, pipeline_aggregation: Pipeline
) -> None:
    """a, b -> c: nothing is yielded after "a" when "c" is already RUNNING."""
    orchestrator = Orchestrator(pipeline=pipeline_aggregation)
    start_node = pipeline_aggregation.get_node_by_name("a")
    # The only status lookup reports "c" as already running, so it must not
    # be started a second time.
    mock_status.side_effect = [RunStatus.RUNNING]
    scheduled_names = []
    async for candidate in orchestrator.next(start_node):
        scheduled_names.append(candidate.name)
    assert scheduled_names == []