From 761d35e7b825bb57a208be7ce9d23020e39b684d Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 4 Mar 2025 11:43:07 +0100 Subject: [PATCH 1/5] Update tests --- .../pipeline/test_orchestrator.py | 190 ++++++++++++++---- 1 file changed, 156 insertions(+), 34 deletions(-) diff --git a/tests/unit/experimental/pipeline/test_orchestrator.py b/tests/unit/experimental/pipeline/test_orchestrator.py index 5149077b0..d98fd1516 100644 --- a/tests/unit/experimental/pipeline/test_orchestrator.py +++ b/tests/unit/experimental/pipeline/test_orchestrator.py @@ -19,7 +19,8 @@ Component, Pipeline, ) -from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError +from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError, \ + PipelineMissingDependencyError, PipelineStatusUpdateError from neo4j_graphrag.experimental.pipeline.orchestrator import Orchestrator from neo4j_graphrag.experimental.pipeline.types import RunStatus @@ -34,8 +35,9 @@ def test_orchestrator_get_input_config_for_task_pipeline_not_validated() -> None pipe.add_component(ComponentPassThrough(), "a") pipe.add_component(ComponentPassThrough(), "b") orchestrator = Orchestrator(pipe) - with pytest.raises(PipelineDefinitionError): + with pytest.raises(PipelineDefinitionError) as exc: orchestrator.get_input_config_for_task(pipe.get_node_by_name("a")) + assert "You must validate the pipeline input config first" in str(exc.value) @pytest.mark.asyncio @@ -59,10 +61,10 @@ async def test_orchestrator_get_component_inputs_from_user_only() -> None: "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_results_for_component" ) @pytest.mark.asyncio -async def test_pipeline_get_component_inputs_from_parent_specific( +async def test_orchestrator_get_component_inputs_from_parent_specific( mock_result: Mock, ) -> None: - """Propagate one specific output field from 'a' to the next component.""" + """Propagate one specific output field from parent to a child component.""" pipe = Pipeline() pipe.add_component(ComponentPassThrough(), "a") pipe.add_component(ComponentPassThrough(), "b") @@ -164,6 +166,43 @@ async def test_orchestrator_get_component_inputs_ignore_user_input_if_input_def_ ) +@pytest.mark.asyncio +@patch( + "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component" +) +async def test_orchestrator_set_component_status(mock_status: Mock) -> None: + pipe = Pipeline() + orchestrator = Orchestrator(pipeline=pipe) + # Normal status update: UNKNOWN -> RUNNING -> DONE + # UNKNOWN -> RUNNING + mock_status.side_effect = [ + RunStatus.UNKNOWN, + ] + assert await orchestrator.set_task_status("task_name", RunStatus.RUNNING) is None + # RUNNING -> DONE + mock_status.side_effect = [ + RunStatus.RUNNING, + ] + assert await orchestrator.set_task_status("task_name", RunStatus.DONE) is None + # Error path, raising PipelineStatusUpdateError + # Same status + # RUNNING -> RUNNING + mock_status.side_effect = [ + RunStatus.RUNNING, + ] + with pytest.raises(PipelineStatusUpdateError) as exc: + await orchestrator.set_task_status("task_name", RunStatus.RUNNING) + assert "Status is already" in str(exc) + # Going back to RUNNING after the task is DONE + # DONE -> RUNNING + mock_status.side_effect = [ + RunStatus.DONE, + ] + with pytest.raises(PipelineStatusUpdateError) as exc: + await orchestrator.set_task_status("task_name", RunStatus.RUNNING) + assert "Can't go from DONE to RUNNING" in str(exc) + + @pytest.fixture(scope="function") def pipeline_branch() -> Pipeline: pipe = Pipeline() @@ -190,21 +229,54 @@ def pipeline_aggregation() -> Pipeline: @patch( "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component" ) -async def test_orchestrator_branch( - mock_status: Mock, pipeline_branch: Pipeline +async def test_orchestrator_check_dependency_complete(mock_status: Mock, pipeline_branch: Pipeline) -> None: + """a -> b, c""" + orchestrator = Orchestrator(pipeline=pipeline_branch) + node_a = pipeline_branch.get_node_by_name("a") + assert await orchestrator.check_dependencies_complete(node_a) is None + node_b = pipeline_branch.get_node_by_name("b") + # dependency is DONE: + mock_status.side_effect = [RunStatus.DONE] + assert await orchestrator.check_dependencies_complete(node_b) is None + # dependency is not DONE: + mock_status.side_effect = [RunStatus.RUNNING] + with pytest.raises(PipelineMissingDependencyError): + await orchestrator.check_dependencies_complete(node_b) + + +@pytest.mark.asyncio +@patch( + "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component" +) +async def test_orchestrator_check_dependency_complete(mock_status: Mock, pipeline_branch: Pipeline) -> None: + """a -> b, c""" + orchestrator = Orchestrator(pipeline=pipeline_branch) + node_a = pipeline_branch.get_node_by_name("a") + assert await orchestrator.check_dependencies_complete(node_a) is None + + +@pytest.mark.asyncio +@patch( + "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component" +) +@patch( + "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.check_dependencies_complete", +) +async def test_orchestrator_next_task_branch_no_missing_dependencies( + mock_dep: Mock, mock_status: Mock, pipeline_branch: Pipeline ) -> None: """a -> b, c""" orchestrator = Orchestrator(pipeline=pipeline_branch) node_a = pipeline_branch.get_node_by_name("a") mock_status.side_effect = [ - # next b + # next "b" RunStatus.UNKNOWN, - # dep of b = a - RunStatus.DONE, - # next c + # next "c" RunStatus.UNKNOWN, - # dep of c = a - RunStatus.DONE, + ] + mock_dep.side_effect = [ + None, # "b" has no missing dependencies + None, # "c" has no missing dependencies ] next_tasks = [n async for n in orchestrator.next(node_a)] next_task_names = [n.name for n in next_tasks] @@ -215,31 +287,48 @@ async def test_orchestrator_branch( @patch( "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component" ) -async def test_orchestrator_aggregation( - mock_status: Mock, pipeline_aggregation: Pipeline +@patch( + "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.check_dependencies_complete", +) +async def test_orchestrator_next_task_branch_missing_dependencies( + mock_dep: Mock, mock_status: Mock, pipeline_branch: Pipeline ) -> None: - """a, b -> c""" - orchestrator = Orchestrator(pipeline=pipeline_aggregation) - node_a = pipeline_aggregation.get_node_by_name("a") + """a -> b, c""" + orchestrator = Orchestrator(pipeline=pipeline_branch) + node_a = pipeline_branch.get_node_by_name("a") mock_status.side_effect = [ - # next c: + # next "b" RunStatus.UNKNOWN, - # dep of c = a - RunStatus.DONE, - # dep of c = b + # next "c" RunStatus.UNKNOWN, ] - next_task_names = [n.name async for n in orchestrator.next(node_a)] - # "c" dependencies not ready yet - assert next_task_names == [] - # set "b" to DONE + mock_dep.side_effect = [ + PipelineMissingDependencyError, # "b" has missing dependencies + None, # "c" has no missing dependencies + ] + next_tasks = [n async for n in orchestrator.next(node_a)] + next_task_names = [n.name for n in next_tasks] + assert next_task_names == ["c"] + + +@pytest.mark.asyncio +@patch( + "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component" +) +@patch( + "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.check_dependencies_complete", +) +async def test_orchestrator_next_task_aggregation_no_missing_dependencies( + mock_dep: Mock, mock_status: Mock, pipeline_aggregation: Pipeline +) -> None: + """a, b -> c""" + orchestrator = Orchestrator(pipeline=pipeline_aggregation) + node_a = pipeline_aggregation.get_node_by_name("a") mock_status.side_effect = [ - # next c: - RunStatus.UNKNOWN, - # dep of c = a - RunStatus.DONE, - # dep of c = b - RunStatus.DONE, + RunStatus.UNKNOWN, # status for "c", not started + ] + mock_dep.side_effect = [ + None, # no missing deps ] # then "c" can start next_tasks = [n async for n in orchestrator.next(node_a)] @@ -248,8 +337,41 @@ async def test_orchestrator_aggregation( @pytest.mark.asyncio -async def test_orchestrator_aggregation_waiting(pipeline_aggregation: Pipeline) -> None: +@patch( + "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component" +) +@patch( + "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.check_dependencies_complete", +) +async def test_orchestrator_next_task_aggregation_missing_dependency( + mock_dep: Mock, mock_status: Mock, pipeline_aggregation: Pipeline +) -> None: + """a, b -> c""" orchestrator = Orchestrator(pipeline=pipeline_aggregation) node_a = pipeline_aggregation.get_node_by_name("a") - next_tasks = [n async for n in orchestrator.next(node_a)] - assert next_tasks == [] + mock_status.side_effect = [ + RunStatus.UNKNOWN, # status for "c" is unknown, it's a possible next + ] + mock_dep.side_effect = [ + PipelineMissingDependencyError, # some dependencies are not done yet + ] + next_task_names = [n.name async for n in orchestrator.next(node_a)] + # "c" dependencies not ready yet + assert next_task_names == [] + + +@pytest.mark.asyncio +@patch( + "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component" +) +async def test_orchestrator_next_task_aggregation_next_already_started( + mock_status: Mock, pipeline_aggregation: Pipeline +) -> None: + """a, b -> c""" + orchestrator = Orchestrator(pipeline=pipeline_aggregation) + node_a = pipeline_aggregation.get_node_by_name("a") + mock_status.side_effect = [ + RunStatus.RUNNING, # status for "c" is already running, do not start it again + ] + next_task_names = [n.name async for n in orchestrator.next(node_a)] + assert next_task_names == [] From 67a24816ae878c89cae4cd3791baf297e6060451 Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 4 Mar 2025 11:43:20 +0100 Subject: [PATCH 2/5] Ruff ruff --- .../experimental/pipeline/test_orchestrator.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/tests/unit/experimental/pipeline/test_orchestrator.py b/tests/unit/experimental/pipeline/test_orchestrator.py index d98fd1516..86dd8525d 100644 --- a/tests/unit/experimental/pipeline/test_orchestrator.py +++ b/tests/unit/experimental/pipeline/test_orchestrator.py @@ -19,8 +19,11 @@ Component, Pipeline, ) -from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError, \ - PipelineMissingDependencyError, PipelineStatusUpdateError +from neo4j_graphrag.experimental.pipeline.exceptions import ( + PipelineDefinitionError, + PipelineMissingDependencyError, + PipelineStatusUpdateError, +) from neo4j_graphrag.experimental.pipeline.orchestrator import Orchestrator from neo4j_graphrag.experimental.pipeline.types import RunStatus @@ -229,7 +232,9 @@ def pipeline_aggregation() -> Pipeline: @patch( "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component" ) -async def test_orchestrator_check_dependency_complete(mock_status: Mock, pipeline_branch: Pipeline) -> None: +async def test_orchestrator_check_dependency_complete( + mock_status: Mock, pipeline_branch: Pipeline +) -> None: """a -> b, c""" orchestrator = Orchestrator(pipeline=pipeline_branch) node_a = pipeline_branch.get_node_by_name("a") @@ -248,7 +253,9 @@ async def test_orchestrator_check_dependency_complete(mock_status: Mock, pipelin @patch( "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component" ) -async def test_orchestrator_check_dependency_complete(mock_status: Mock, pipeline_branch: Pipeline) -> None: +async def test_orchestrator_check_dependency_complete( + mock_status: Mock, pipeline_branch: Pipeline +) -> None: """a -> b, c""" orchestrator = Orchestrator(pipeline=pipeline_branch) node_a = pipeline_branch.get_node_by_name("a") From b8222a13875cea8e608dd5ebcdeb8d2a0d53502a Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 4 Mar 2025 11:45:10 +0100 Subject: [PATCH 3/5] Mypy --- .../pipeline/test_orchestrator.py | 21 ++++--------------- 1 file changed, 4 insertions(+), 17 deletions(-) diff --git a/tests/unit/experimental/pipeline/test_orchestrator.py b/tests/unit/experimental/pipeline/test_orchestrator.py index 86dd8525d..8ba5fe220 100644 --- a/tests/unit/experimental/pipeline/test_orchestrator.py +++ b/tests/unit/experimental/pipeline/test_orchestrator.py @@ -181,12 +181,12 @@ async def test_orchestrator_set_component_status(mock_status: Mock) -> None: mock_status.side_effect = [ RunStatus.UNKNOWN, ] - assert await orchestrator.set_task_status("task_name", RunStatus.RUNNING) is None + await orchestrator.set_task_status("task_name", RunStatus.RUNNING) # RUNNING -> DONE mock_status.side_effect = [ RunStatus.RUNNING, ] - assert await orchestrator.set_task_status("task_name", RunStatus.DONE) is None + await orchestrator.set_task_status("task_name", RunStatus.DONE) # Error path, raising PipelineStatusUpdateError # Same status # RUNNING -> RUNNING @@ -238,30 +238,17 @@ async def test_orchestrator_check_dependency_complete( """a -> b, c""" orchestrator = Orchestrator(pipeline=pipeline_branch) node_a = pipeline_branch.get_node_by_name("a") - assert await orchestrator.check_dependencies_complete(node_a) is None + await orchestrator.check_dependencies_complete(node_a) node_b = pipeline_branch.get_node_by_name("b") # dependency is DONE: mock_status.side_effect = [RunStatus.DONE] - assert await orchestrator.check_dependencies_complete(node_b) is None + await orchestrator.check_dependencies_complete(node_b) # dependency is not DONE: mock_status.side_effect = [RunStatus.RUNNING] with pytest.raises(PipelineMissingDependencyError): await orchestrator.check_dependencies_complete(node_b) -@pytest.mark.asyncio -@patch( - "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component" -) -async def test_orchestrator_check_dependency_complete( - mock_status: Mock, pipeline_branch: Pipeline -) -> None: - """a -> b, c""" - orchestrator = Orchestrator(pipeline=pipeline_branch) - node_a = pipeline_branch.get_node_by_name("a") - assert await orchestrator.check_dependencies_complete(node_a) is None - - @pytest.mark.asyncio @patch( "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component" From 3fdc4004956c26a935c8c5eb8246eb7bdeec7525 Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 4 Mar 2025 14:05:33 +0100 Subject: [PATCH 4/5] Test other status transitions --- .../experimental/pipeline/orchestrator.py | 8 ++- .../experimental/pipeline/types.py | 11 +++ .../pipeline/test_orchestrator.py | 67 +++++++++++-------- 3 files changed, 56 insertions(+), 30 deletions(-) diff --git a/src/neo4j_graphrag/experimental/pipeline/orchestrator.py b/src/neo4j_graphrag/experimental/pipeline/orchestrator.py index d5a434dde..6c218784e 100644 --- a/src/neo4j_graphrag/experimental/pipeline/orchestrator.py +++ b/src/neo4j_graphrag/experimental/pipeline/orchestrator.py @@ -95,9 +95,11 @@ async def set_task_status(self, task_name: str, status: RunStatus) -> None: async with asyncio.Lock(): current_status = await self.get_status_for_component(task_name) if status == current_status: - raise PipelineStatusUpdateError(f"Status is already '{status}'") - if status == RunStatus.RUNNING and current_status == RunStatus.DONE: - raise PipelineStatusUpdateError("Can't go from DONE to RUNNING") + raise PipelineStatusUpdateError(f"Status is already {status}") + if status not in current_status.possible_next_status(): + raise PipelineStatusUpdateError( + f"Can't go from {current_status} to {status}" + ) return await self.pipeline.store.add_status_for_component( self.run_id, task_name, status.value ) diff --git a/src/neo4j_graphrag/experimental/pipeline/types.py b/src/neo4j_graphrag/experimental/pipeline/types.py index d689ac11c..6c65d756e 100644 --- a/src/neo4j_graphrag/experimental/pipeline/types.py +++ b/src/neo4j_graphrag/experimental/pipeline/types.py @@ -54,6 +54,17 @@ class RunStatus(enum.Enum): RUNNING = "RUNNING" DONE = "DONE" + def possible_next_status(self) -> list[RunStatus]: + match self: + case RunStatus.UNKNOWN: + return [RunStatus.RUNNING] + case RunStatus.RUNNING: + return [RunStatus.DONE] + case RunStatus.DONE: + return [] + case _: + return [] + class RunResult(BaseModel): status: RunStatus = RunStatus.DONE diff --git a/tests/unit/experimental/pipeline/test_orchestrator.py b/tests/unit/experimental/pipeline/test_orchestrator.py index 8ba5fe220..a253abcdf 100644 --- a/tests/unit/experimental/pipeline/test_orchestrator.py +++ b/tests/unit/experimental/pipeline/test_orchestrator.py @@ -173,37 +173,50 @@ async def test_orchestrator_get_component_inputs_ignore_user_input_if_input_def_ @patch( "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component" ) -async def test_orchestrator_set_component_status(mock_status: Mock) -> None: +@pytest.mark.parametrize( + "old_status, new_status, result", + [ + # Normal path: from UNKNOWN to RUNNING to DONE + (RunStatus.UNKNOWN, RunStatus.RUNNING, "ok"), + (RunStatus.RUNNING, RunStatus.DONE, "ok"), + # Error: status is already set to this value + (RunStatus.RUNNING, RunStatus.RUNNING, "Status is already RunStatus.RUNNING"), + (RunStatus.DONE, RunStatus.DONE, "Status is already RunStatus.DONE"), + # Error: can't go back in time + ( + RunStatus.DONE, + RunStatus.RUNNING, + "Can't go from RunStatus.DONE to RunStatus.RUNNING", + ), + ( + RunStatus.RUNNING, + RunStatus.UNKNOWN, + "Can't go from RunStatus.RUNNING to RunStatus.UNKNOWN", + ), + ( + RunStatus.DONE, + RunStatus.UNKNOWN, + "Can't go from RunStatus.DONE to RunStatus.UNKNOWN", + ), + ], +) +async def test_orchestrator_set_component_status( + mock_status: Mock, + old_status: RunStatus, + new_status: RunStatus, + result: str, +) -> None: pipe = Pipeline() orchestrator = Orchestrator(pipeline=pipe) - # Normal status update: UNKNOWN -> RUNNING -> DONE - # UNKNOWN -> RUNNING - mock_status.side_effect = [ - RunStatus.UNKNOWN, - ] - await orchestrator.set_task_status("task_name", RunStatus.RUNNING) - # RUNNING -> DONE - mock_status.side_effect = [ - RunStatus.RUNNING, - ] - await orchestrator.set_task_status("task_name", RunStatus.DONE) - # Error path, raising PipelineStatusUpdateError - # Same status - # RUNNING -> RUNNING - mock_status.side_effect = [ - RunStatus.RUNNING, - ] - with pytest.raises(PipelineStatusUpdateError) as exc: - await orchestrator.set_task_status("task_name", RunStatus.RUNNING) - assert "Status is already" in str(exc) - # Going back to RUNNING after the task is DONE - # DONE -> RUNNING mock_status.side_effect = [ - RunStatus.DONE, + old_status, ] - with pytest.raises(PipelineStatusUpdateError) as exc: - await orchestrator.set_task_status("task_name", RunStatus.RUNNING) - assert "Can't go from DONE to RUNNING" in str(exc) + if result == "ok": + await orchestrator.set_task_status("task_name", new_status) + else: + with pytest.raises(PipelineStatusUpdateError) as exc: + await orchestrator.set_task_status("task_name", new_status) + assert result in str(exc) @pytest.fixture(scope="function") From c7c93d8951986d46b1895f230a57cde07e7260dd Mon Sep 17 00:00:00 2001 From: estelle Date: Tue, 4 Mar 2025 14:13:42 +0100 Subject: [PATCH 5/5] Match statement was only introduced in Python 3.10 --- .../experimental/pipeline/types.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/src/neo4j_graphrag/experimental/pipeline/types.py b/src/neo4j_graphrag/experimental/pipeline/types.py index 6c65d756e..f4f8267c7 100644 --- a/src/neo4j_graphrag/experimental/pipeline/types.py +++ b/src/neo4j_graphrag/experimental/pipeline/types.py @@ -55,15 +55,13 @@ class RunStatus(enum.Enum): DONE = "DONE" def possible_next_status(self) -> list[RunStatus]: - match self: - case RunStatus.UNKNOWN: - return [RunStatus.RUNNING] - case RunStatus.RUNNING: - return [RunStatus.DONE] - case RunStatus.DONE: - return [] - case _: - return [] + if self == RunStatus.UNKNOWN: + return [RunStatus.RUNNING] + if self == RunStatus.RUNNING: + return [RunStatus.DONE] + if self == RunStatus.DONE: + return [] + return [] class RunResult(BaseModel):