diff --git a/src/neo4j_graphrag/experimental/pipeline/orchestrator.py b/src/neo4j_graphrag/experimental/pipeline/orchestrator.py index d5a434dde..6c218784e 100644 --- a/src/neo4j_graphrag/experimental/pipeline/orchestrator.py +++ b/src/neo4j_graphrag/experimental/pipeline/orchestrator.py @@ -95,9 +95,11 @@ async def set_task_status(self, task_name: str, status: RunStatus) -> None: async with asyncio.Lock(): current_status = await self.get_status_for_component(task_name) if status == current_status: - raise PipelineStatusUpdateError(f"Status is already '{status}'") - if status == RunStatus.RUNNING and current_status == RunStatus.DONE: - raise PipelineStatusUpdateError("Can't go from DONE to RUNNING") + raise PipelineStatusUpdateError(f"Status is already {status}") + if status not in current_status.possible_next_status(): + raise PipelineStatusUpdateError( + f"Can't go from {current_status} to {status}" + ) return await self.pipeline.store.add_status_for_component( self.run_id, task_name, status.value ) diff --git a/src/neo4j_graphrag/experimental/pipeline/types.py b/src/neo4j_graphrag/experimental/pipeline/types.py index d689ac11c..f4f8267c7 100644 --- a/src/neo4j_graphrag/experimental/pipeline/types.py +++ b/src/neo4j_graphrag/experimental/pipeline/types.py @@ -54,6 +54,15 @@ class RunStatus(enum.Enum): RUNNING = "RUNNING" DONE = "DONE" + def possible_next_status(self) -> list[RunStatus]: + if self == RunStatus.UNKNOWN: + return [RunStatus.RUNNING] + if self == RunStatus.RUNNING: + return [RunStatus.DONE] + if self == RunStatus.DONE: + return [] + return [] + class RunResult(BaseModel): status: RunStatus = RunStatus.DONE diff --git a/tests/unit/experimental/pipeline/test_orchestrator.py b/tests/unit/experimental/pipeline/test_orchestrator.py index 5149077b0..a253abcdf 100644 --- a/tests/unit/experimental/pipeline/test_orchestrator.py +++ b/tests/unit/experimental/pipeline/test_orchestrator.py @@ -19,7 +19,11 @@ Component, Pipeline, ) -from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError +from neo4j_graphrag.experimental.pipeline.exceptions import ( + PipelineDefinitionError, + PipelineMissingDependencyError, + PipelineStatusUpdateError, +) from neo4j_graphrag.experimental.pipeline.orchestrator import Orchestrator from neo4j_graphrag.experimental.pipeline.types import RunStatus @@ -34,8 +38,9 @@ def test_orchestrator_get_input_config_for_task_pipeline_not_validated() -> None pipe.add_component(ComponentPassThrough(), "a") pipe.add_component(ComponentPassThrough(), "b") orchestrator = Orchestrator(pipe) - with pytest.raises(PipelineDefinitionError): + with pytest.raises(PipelineDefinitionError) as exc: orchestrator.get_input_config_for_task(pipe.get_node_by_name("a")) + assert "You must validate the pipeline input config first" in str(exc.value) @pytest.mark.asyncio @@ -59,10 +64,10 @@ async def test_orchestrator_get_component_inputs_from_user_only() -> None: "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_results_for_component" ) @pytest.mark.asyncio -async def test_pipeline_get_component_inputs_from_parent_specific( +async def test_orchestrator_get_component_inputs_from_parent_specific( mock_result: Mock, ) -> None: - """Propagate one specific output field from 'a' to the next component.""" + """Propagate one specific output field from parent to a child component.""" pipe = Pipeline() pipe.add_component(ComponentPassThrough(), "a") pipe.add_component(ComponentPassThrough(), "b") @@ -164,6 +169,56 @@ async def test_orchestrator_get_component_inputs_ignore_user_input_if_input_def_ ) +@pytest.mark.asyncio +@patch( + "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component" +) +@pytest.mark.parametrize( + "old_status, new_status, result", + [ + # Normal path: from UNKNOWN to RUNNING to DONE + (RunStatus.UNKNOWN, RunStatus.RUNNING, "ok"), + (RunStatus.RUNNING, RunStatus.DONE, "ok"), + # Error: status is already set to this value + (RunStatus.RUNNING, RunStatus.RUNNING, "Status is already RunStatus.RUNNING"), + (RunStatus.DONE, RunStatus.DONE, "Status is already RunStatus.DONE"), + # Error: can't go back in time + ( + RunStatus.DONE, + RunStatus.RUNNING, + "Can't go from RunStatus.DONE to RunStatus.RUNNING", + ), + ( + RunStatus.RUNNING, + RunStatus.UNKNOWN, + "Can't go from RunStatus.RUNNING to RunStatus.UNKNOWN", + ), + ( + RunStatus.DONE, + RunStatus.UNKNOWN, + "Can't go from RunStatus.DONE to RunStatus.UNKNOWN", + ), + ], +) +async def test_orchestrator_set_component_status( + mock_status: Mock, + old_status: RunStatus, + new_status: RunStatus, + result: str, +) -> None: + pipe = Pipeline() + orchestrator = Orchestrator(pipeline=pipe) + mock_status.side_effect = [ + old_status, + ] + if result == "ok": + await orchestrator.set_task_status("task_name", new_status) + else: + with pytest.raises(PipelineStatusUpdateError) as exc: + await orchestrator.set_task_status("task_name", new_status) + assert result in str(exc) + + @pytest.fixture(scope="function") def pipeline_branch() -> Pipeline: pipe = Pipeline() @@ -190,21 +245,45 @@ def pipeline_aggregation() -> Pipeline: @patch( "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component" ) -async def test_orchestrator_branch( +async def test_orchestrator_check_dependency_complete( mock_status: Mock, pipeline_branch: Pipeline +) -> None: + """a -> b, c""" + orchestrator = Orchestrator(pipeline=pipeline_branch) + node_a = pipeline_branch.get_node_by_name("a") + await orchestrator.check_dependencies_complete(node_a) + node_b = pipeline_branch.get_node_by_name("b") + # dependency is DONE: + mock_status.side_effect = [RunStatus.DONE] + await orchestrator.check_dependencies_complete(node_b) + # dependency is not DONE: + mock_status.side_effect = [RunStatus.RUNNING] + with pytest.raises(PipelineMissingDependencyError): + await orchestrator.check_dependencies_complete(node_b) + + +@pytest.mark.asyncio +@patch( + "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component" +) +@patch( + "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.check_dependencies_complete", +) +async def test_orchestrator_next_task_branch_no_missing_dependencies( + mock_dep: Mock, mock_status: Mock, pipeline_branch: Pipeline ) -> None: """a -> b, c""" orchestrator = Orchestrator(pipeline=pipeline_branch) node_a = pipeline_branch.get_node_by_name("a") mock_status.side_effect = [ - # next b + # next "b" RunStatus.UNKNOWN, - # dep of b = a - RunStatus.DONE, - # next c + # next "c" RunStatus.UNKNOWN, - # dep of c = a - RunStatus.DONE, + ] + mock_dep.side_effect = [ + None, # "b" has no missing dependencies + None, # "c" has no missing dependencies ] next_tasks = [n async for n in orchestrator.next(node_a)] next_task_names = [n.name for n in next_tasks] @@ -215,31 +294,48 @@ async def test_orchestrator_branch( @patch( "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component" ) -async def test_orchestrator_aggregation( - mock_status: Mock, pipeline_aggregation: Pipeline +@patch( + "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.check_dependencies_complete", +) +async def test_orchestrator_next_task_branch_missing_dependencies( + mock_dep: Mock, mock_status: Mock, pipeline_branch: Pipeline ) -> None: - """a, b -> c""" - orchestrator = Orchestrator(pipeline=pipeline_aggregation) - node_a = pipeline_aggregation.get_node_by_name("a") + """a -> b, c""" + orchestrator = Orchestrator(pipeline=pipeline_branch) + node_a = pipeline_branch.get_node_by_name("a") mock_status.side_effect = [ - # next c: + # next "b" RunStatus.UNKNOWN, - # dep of c = a - RunStatus.DONE, - # dep of c = b + # next "c" RunStatus.UNKNOWN, ] - next_task_names = [n.name async for n in orchestrator.next(node_a)] - # "c" dependencies not ready yet - assert next_task_names == [] - # set "b" to DONE + mock_dep.side_effect = [ + PipelineMissingDependencyError, # "b" has missing dependencies + None, # "c" has no missing dependencies + ] + next_tasks = [n async for n in orchestrator.next(node_a)] + next_task_names = [n.name for n in next_tasks] + assert next_task_names == ["c"] + + +@pytest.mark.asyncio +@patch( + "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component" +) +@patch( + "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.check_dependencies_complete", +) +async def test_orchestrator_next_task_aggregation_no_missing_dependencies( + mock_dep: Mock, mock_status: Mock, pipeline_aggregation: Pipeline +) -> None: + """a, b -> c""" + orchestrator = Orchestrator(pipeline=pipeline_aggregation) + node_a = pipeline_aggregation.get_node_by_name("a") mock_status.side_effect = [ - # next c: - RunStatus.UNKNOWN, - # dep of c = a - RunStatus.DONE, - # dep of c = b - RunStatus.DONE, + RunStatus.UNKNOWN, # status for "c", not started + ] + mock_dep.side_effect = [ + None, # no missing deps ] # then "c" can start next_tasks = [n async for n in orchestrator.next(node_a)] @@ -248,8 +344,41 @@ async def test_orchestrator_aggregation( @pytest.mark.asyncio -async def test_orchestrator_aggregation_waiting(pipeline_aggregation: Pipeline) -> None: +@patch( + "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component" +) +@patch( + "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.check_dependencies_complete", +) +async def test_orchestrator_next_task_aggregation_missing_dependency( + mock_dep: Mock, mock_status: Mock, pipeline_aggregation: Pipeline +) -> None: + """a, b -> c""" orchestrator = Orchestrator(pipeline=pipeline_aggregation) node_a = pipeline_aggregation.get_node_by_name("a") - next_tasks = [n async for n in orchestrator.next(node_a)] - assert next_tasks == [] + mock_status.side_effect = [ + RunStatus.UNKNOWN, # status for "c" is unknown, it's a possible next + ] + mock_dep.side_effect = [ + PipelineMissingDependencyError, # some dependencies are not done yet + ] + next_task_names = [n.name async for n in orchestrator.next(node_a)] + # "c" dependencies not ready yet + assert next_task_names == [] + + +@pytest.mark.asyncio +@patch( + "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component" +) +async def test_orchestrator_next_task_aggregation_next_already_started( + mock_status: Mock, pipeline_aggregation: Pipeline +) -> None: + """a, b -> c""" + orchestrator = Orchestrator(pipeline=pipeline_aggregation) + node_a = pipeline_aggregation.get_node_by_name("a") + mock_status.side_effect = [ + RunStatus.RUNNING, # status for "c" is already running, do not start it again + ] + next_task_names = [n.name async for n in orchestrator.next(node_a)] + assert next_task_names == []