Skip to content

Commit c7234d3

Browse files
authored
Enhance tests around Orchestrator (#293)
* Update tests * Ruff ruff * Mypy * Test other status transitions * Match statement was only introduced in Python 3.10
1 parent 75ef7e4 commit c7234d3

File tree

3 files changed

+176
-36
lines changed

3 files changed

+176
-36
lines changed

src/neo4j_graphrag/experimental/pipeline/orchestrator.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,11 @@ async def set_task_status(self, task_name: str, status: RunStatus) -> None:
9595
async with asyncio.Lock():
9696
current_status = await self.get_status_for_component(task_name)
9797
if status == current_status:
98-
raise PipelineStatusUpdateError(f"Status is already '{status}'")
99-
if status == RunStatus.RUNNING and current_status == RunStatus.DONE:
100-
raise PipelineStatusUpdateError("Can't go from DONE to RUNNING")
98+
raise PipelineStatusUpdateError(f"Status is already {status}")
99+
if status not in current_status.possible_next_status():
100+
raise PipelineStatusUpdateError(
101+
f"Can't go from {current_status} to {status}"
102+
)
101103
return await self.pipeline.store.add_status_for_component(
102104
self.run_id, task_name, status.value
103105
)

src/neo4j_graphrag/experimental/pipeline/types.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,15 @@ class RunStatus(enum.Enum):
5454
RUNNING = "RUNNING"
5555
DONE = "DONE"
5656

57+
def possible_next_status(self) -> list[RunStatus]:
58+
if self == RunStatus.UNKNOWN:
59+
return [RunStatus.RUNNING]
60+
if self == RunStatus.RUNNING:
61+
return [RunStatus.DONE]
62+
if self == RunStatus.DONE:
63+
return []
64+
return []
65+
5766

5867
class RunResult(BaseModel):
5968
status: RunStatus = RunStatus.DONE

tests/unit/experimental/pipeline/test_orchestrator.py

Lines changed: 162 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@
1919
Component,
2020
Pipeline,
2121
)
22-
from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError
22+
from neo4j_graphrag.experimental.pipeline.exceptions import (
23+
PipelineDefinitionError,
24+
PipelineMissingDependencyError,
25+
PipelineStatusUpdateError,
26+
)
2327
from neo4j_graphrag.experimental.pipeline.orchestrator import Orchestrator
2428
from neo4j_graphrag.experimental.pipeline.types import RunStatus
2529

@@ -34,8 +38,9 @@ def test_orchestrator_get_input_config_for_task_pipeline_not_validated() -> None
3438
pipe.add_component(ComponentPassThrough(), "a")
3539
pipe.add_component(ComponentPassThrough(), "b")
3640
orchestrator = Orchestrator(pipe)
37-
with pytest.raises(PipelineDefinitionError):
41+
with pytest.raises(PipelineDefinitionError) as exc:
3842
orchestrator.get_input_config_for_task(pipe.get_node_by_name("a"))
43+
assert "You must validate the pipeline input config first" in str(exc.value)
3944

4045

4146
@pytest.mark.asyncio
@@ -59,10 +64,10 @@ async def test_orchestrator_get_component_inputs_from_user_only() -> None:
5964
"neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_results_for_component"
6065
)
6166
@pytest.mark.asyncio
62-
async def test_pipeline_get_component_inputs_from_parent_specific(
67+
async def test_orchestrator_get_component_inputs_from_parent_specific(
6368
mock_result: Mock,
6469
) -> None:
65-
"""Propagate one specific output field from 'a' to the next component."""
70+
"""Propagate one specific output field from parent to a child component."""
6671
pipe = Pipeline()
6772
pipe.add_component(ComponentPassThrough(), "a")
6873
pipe.add_component(ComponentPassThrough(), "b")
@@ -164,6 +169,56 @@ async def test_orchestrator_get_component_inputs_ignore_user_input_if_input_def_
164169
)
165170

166171

172+
@pytest.mark.asyncio
173+
@patch(
174+
"neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component"
175+
)
176+
@pytest.mark.parametrize(
177+
"old_status, new_status, result",
178+
[
179+
# Normal path: from UNKNOWN to RUNNING to DONE
180+
(RunStatus.UNKNOWN, RunStatus.RUNNING, "ok"),
181+
(RunStatus.RUNNING, RunStatus.DONE, "ok"),
182+
# Error: status is already set to this value
183+
(RunStatus.RUNNING, RunStatus.RUNNING, "Status is already RunStatus.RUNNING"),
184+
(RunStatus.DONE, RunStatus.DONE, "Status is already RunStatus.DONE"),
185+
# Error: can't go back in time
186+
(
187+
RunStatus.DONE,
188+
RunStatus.RUNNING,
189+
"Can't go from RunStatus.DONE to RunStatus.RUNNING",
190+
),
191+
(
192+
RunStatus.RUNNING,
193+
RunStatus.UNKNOWN,
194+
"Can't go from RunStatus.RUNNING to RunStatus.UNKNOWN",
195+
),
196+
(
197+
RunStatus.DONE,
198+
RunStatus.UNKNOWN,
199+
"Can't go from RunStatus.DONE to RunStatus.UNKNOWN",
200+
),
201+
],
202+
)
203+
async def test_orchestrator_set_component_status(
204+
mock_status: Mock,
205+
old_status: RunStatus,
206+
new_status: RunStatus,
207+
result: str,
208+
) -> None:
209+
pipe = Pipeline()
210+
orchestrator = Orchestrator(pipeline=pipe)
211+
mock_status.side_effect = [
212+
old_status,
213+
]
214+
if result == "ok":
215+
await orchestrator.set_task_status("task_name", new_status)
216+
else:
217+
with pytest.raises(PipelineStatusUpdateError) as exc:
218+
await orchestrator.set_task_status("task_name", new_status)
219+
assert result in str(exc)
220+
221+
167222
@pytest.fixture(scope="function")
168223
def pipeline_branch() -> Pipeline:
169224
pipe = Pipeline()
@@ -190,21 +245,45 @@ def pipeline_aggregation() -> Pipeline:
190245
@patch(
191246
"neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component"
192247
)
193-
async def test_orchestrator_branch(
248+
async def test_orchestrator_check_dependency_complete(
194249
mock_status: Mock, pipeline_branch: Pipeline
250+
) -> None:
251+
"""a -> b, c"""
252+
orchestrator = Orchestrator(pipeline=pipeline_branch)
253+
node_a = pipeline_branch.get_node_by_name("a")
254+
await orchestrator.check_dependencies_complete(node_a)
255+
node_b = pipeline_branch.get_node_by_name("b")
256+
# dependency is DONE:
257+
mock_status.side_effect = [RunStatus.DONE]
258+
await orchestrator.check_dependencies_complete(node_b)
259+
# dependency is not DONE:
260+
mock_status.side_effect = [RunStatus.RUNNING]
261+
with pytest.raises(PipelineMissingDependencyError):
262+
await orchestrator.check_dependencies_complete(node_b)
263+
264+
265+
@pytest.mark.asyncio
266+
@patch(
267+
"neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component"
268+
)
269+
@patch(
270+
"neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.check_dependencies_complete",
271+
)
272+
async def test_orchestrator_next_task_branch_no_missing_dependencies(
273+
mock_dep: Mock, mock_status: Mock, pipeline_branch: Pipeline
195274
) -> None:
196275
"""a -> b, c"""
197276
orchestrator = Orchestrator(pipeline=pipeline_branch)
198277
node_a = pipeline_branch.get_node_by_name("a")
199278
mock_status.side_effect = [
200-
# next b
279+
# next "b"
201280
RunStatus.UNKNOWN,
202-
# dep of b = a
203-
RunStatus.DONE,
204-
# next c
281+
# next "c"
205282
RunStatus.UNKNOWN,
206-
# dep of c = a
207-
RunStatus.DONE,
283+
]
284+
mock_dep.side_effect = [
285+
None, # "b" has no missing dependencies
286+
None, # "c" has no missing dependencies
208287
]
209288
next_tasks = [n async for n in orchestrator.next(node_a)]
210289
next_task_names = [n.name for n in next_tasks]
@@ -215,31 +294,48 @@ async def test_orchestrator_branch(
215294
@patch(
216295
"neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component"
217296
)
218-
async def test_orchestrator_aggregation(
219-
mock_status: Mock, pipeline_aggregation: Pipeline
297+
@patch(
298+
"neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.check_dependencies_complete",
299+
)
300+
async def test_orchestrator_next_task_branch_missing_dependencies(
301+
mock_dep: Mock, mock_status: Mock, pipeline_branch: Pipeline
220302
) -> None:
221-
"""a, b -> c"""
222-
orchestrator = Orchestrator(pipeline=pipeline_aggregation)
223-
node_a = pipeline_aggregation.get_node_by_name("a")
303+
"""a -> b, c"""
304+
orchestrator = Orchestrator(pipeline=pipeline_branch)
305+
node_a = pipeline_branch.get_node_by_name("a")
224306
mock_status.side_effect = [
225-
# next c:
307+
# next "b"
226308
RunStatus.UNKNOWN,
227-
# dep of c = a
228-
RunStatus.DONE,
229-
# dep of c = b
309+
# next "c"
230310
RunStatus.UNKNOWN,
231311
]
232-
next_task_names = [n.name async for n in orchestrator.next(node_a)]
233-
# "c" dependencies not ready yet
234-
assert next_task_names == []
235-
# set "b" to DONE
312+
mock_dep.side_effect = [
313+
PipelineMissingDependencyError, # "b" has missing dependencies
314+
None, # "c" has no missing dependencies
315+
]
316+
next_tasks = [n async for n in orchestrator.next(node_a)]
317+
next_task_names = [n.name for n in next_tasks]
318+
assert next_task_names == ["c"]
319+
320+
321+
@pytest.mark.asyncio
322+
@patch(
323+
"neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component"
324+
)
325+
@patch(
326+
"neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.check_dependencies_complete",
327+
)
328+
async def test_orchestrator_next_task_aggregation_no_missing_dependencies(
329+
mock_dep: Mock, mock_status: Mock, pipeline_aggregation: Pipeline
330+
) -> None:
331+
"""a, b -> c"""
332+
orchestrator = Orchestrator(pipeline=pipeline_aggregation)
333+
node_a = pipeline_aggregation.get_node_by_name("a")
236334
mock_status.side_effect = [
237-
# next c:
238-
RunStatus.UNKNOWN,
239-
# dep of c = a
240-
RunStatus.DONE,
241-
# dep of c = b
242-
RunStatus.DONE,
335+
RunStatus.UNKNOWN, # status for "c", not started
336+
]
337+
mock_dep.side_effect = [
338+
None, # no missing deps
243339
]
244340
# then "c" can start
245341
next_tasks = [n async for n in orchestrator.next(node_a)]
@@ -248,8 +344,41 @@ async def test_orchestrator_aggregation(
248344

249345

250346
@pytest.mark.asyncio
251-
async def test_orchestrator_aggregation_waiting(pipeline_aggregation: Pipeline) -> None:
347+
@patch(
348+
"neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component"
349+
)
350+
@patch(
351+
"neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.check_dependencies_complete",
352+
)
353+
async def test_orchestrator_next_task_aggregation_missing_dependency(
354+
mock_dep: Mock, mock_status: Mock, pipeline_aggregation: Pipeline
355+
) -> None:
356+
"""a, b -> c"""
252357
orchestrator = Orchestrator(pipeline=pipeline_aggregation)
253358
node_a = pipeline_aggregation.get_node_by_name("a")
254-
next_tasks = [n async for n in orchestrator.next(node_a)]
255-
assert next_tasks == []
359+
mock_status.side_effect = [
360+
RunStatus.UNKNOWN, # status for "c" is unknown, it's a possible next
361+
]
362+
mock_dep.side_effect = [
363+
PipelineMissingDependencyError, # some dependencies are not done yet
364+
]
365+
next_task_names = [n.name async for n in orchestrator.next(node_a)]
366+
# "c" dependencies not ready yet
367+
assert next_task_names == []
368+
369+
370+
@pytest.mark.asyncio
371+
@patch(
372+
"neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component"
373+
)
374+
async def test_orchestrator_next_task_aggregation_next_already_started(
375+
mock_status: Mock, pipeline_aggregation: Pipeline
376+
) -> None:
377+
"""a, b -> c"""
378+
orchestrator = Orchestrator(pipeline=pipeline_aggregation)
379+
node_a = pipeline_aggregation.get_node_by_name("a")
380+
mock_status.side_effect = [
381+
RunStatus.RUNNING, # status for "c" is already running, do not start it again
382+
]
383+
next_task_names = [n.name async for n in orchestrator.next(node_a)]
384+
assert next_task_names == []

0 commit comments

Comments
 (0)