19
19
Component ,
20
20
Pipeline ,
21
21
)
22
- from neo4j_graphrag .experimental .pipeline .exceptions import PipelineDefinitionError
22
+ from neo4j_graphrag .experimental .pipeline .exceptions import (
23
+ PipelineDefinitionError ,
24
+ PipelineMissingDependencyError ,
25
+ PipelineStatusUpdateError ,
26
+ )
23
27
from neo4j_graphrag .experimental .pipeline .orchestrator import Orchestrator
24
28
from neo4j_graphrag .experimental .pipeline .types import RunStatus
25
29
@@ -34,8 +38,9 @@ def test_orchestrator_get_input_config_for_task_pipeline_not_validated() -> None
34
38
pipe .add_component (ComponentPassThrough (), "a" )
35
39
pipe .add_component (ComponentPassThrough (), "b" )
36
40
orchestrator = Orchestrator (pipe )
37
- with pytest .raises (PipelineDefinitionError ):
41
+ with pytest .raises (PipelineDefinitionError ) as exc :
38
42
orchestrator .get_input_config_for_task (pipe .get_node_by_name ("a" ))
43
+ assert "You must validate the pipeline input config first" in str (exc .value )
39
44
40
45
41
46
@pytest .mark .asyncio
@@ -59,10 +64,10 @@ async def test_orchestrator_get_component_inputs_from_user_only() -> None:
59
64
"neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_results_for_component"
60
65
)
61
66
@pytest .mark .asyncio
62
- async def test_pipeline_get_component_inputs_from_parent_specific (
67
+ async def test_orchestrator_get_component_inputs_from_parent_specific (
63
68
mock_result : Mock ,
64
69
) -> None :
65
- """Propagate one specific output field from 'a' to the next component."""
70
+ """Propagate one specific output field from parent to a child component."""
66
71
pipe = Pipeline ()
67
72
pipe .add_component (ComponentPassThrough (), "a" )
68
73
pipe .add_component (ComponentPassThrough (), "b" )
@@ -164,6 +169,56 @@ async def test_orchestrator_get_component_inputs_ignore_user_input_if_input_def_
164
169
)
165
170
166
171
172
@pytest.mark.asyncio
@patch(
    "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component"
)
@pytest.mark.parametrize(
    "old_status, new_status, result",
    [
        # Normal path: from UNKNOWN to RUNNING to DONE
        (RunStatus.UNKNOWN, RunStatus.RUNNING, "ok"),
        (RunStatus.RUNNING, RunStatus.DONE, "ok"),
        # Error: status is already set to this value
        (RunStatus.RUNNING, RunStatus.RUNNING, "Status is already RunStatus.RUNNING"),
        (RunStatus.DONE, RunStatus.DONE, "Status is already RunStatus.DONE"),
        # Error: can't go back in time
        (
            RunStatus.DONE,
            RunStatus.RUNNING,
            "Can't go from RunStatus.DONE to RunStatus.RUNNING",
        ),
        (
            RunStatus.RUNNING,
            RunStatus.UNKNOWN,
            "Can't go from RunStatus.RUNNING to RunStatus.UNKNOWN",
        ),
        (
            RunStatus.DONE,
            RunStatus.UNKNOWN,
            "Can't go from RunStatus.DONE to RunStatus.UNKNOWN",
        ),
    ],
)
async def test_orchestrator_set_component_status(
    mock_status: Mock,
    old_status: RunStatus,
    new_status: RunStatus,
    result: str,
) -> None:
    """Check which status transitions ``set_task_status`` accepts or rejects.

    ``get_status_for_component`` is patched so that it yields ``old_status``
    exactly once. The orchestrator is then asked to move the task to
    ``new_status``.

    Args:
        mock_status: Patched ``Orchestrator.get_status_for_component``.
        old_status: Status reported for the task before the update.
        new_status: Status passed to ``set_task_status``.
        result: ``"ok"`` when the update must succeed; otherwise the text
            expected inside the ``PipelineStatusUpdateError`` message.
    """
    pipe = Pipeline()
    orchestrator = Orchestrator(pipeline=pipe)
    mock_status.side_effect = [
        old_status,
    ]
    if result == "ok":
        await orchestrator.set_task_status("task_name", new_status)
    else:
        with pytest.raises(PipelineStatusUpdateError) as exc:
            await orchestrator.set_task_status("task_name", new_status)
        # Inspect the raised exception itself (exc.value), not the
        # ExceptionInfo wrapper, so the check is on the real error message.
        assert result in str(exc.value)
220
+
221
+
167
222
@pytest .fixture (scope = "function" )
168
223
def pipeline_branch () -> Pipeline :
169
224
pipe = Pipeline ()
@@ -190,21 +245,45 @@ def pipeline_aggregation() -> Pipeline:
190
245
@patch(
    "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component"
)
async def test_orchestrator_check_dependency_complete(
    mock_status: Mock, pipeline_branch: Pipeline
) -> None:
    """Exercise ``check_dependencies_complete`` on the branch pipeline.

    Topology (per the fixture name and the original docstring): a -> b, c.
    """
    orch = Orchestrator(pipeline=pipeline_branch)
    root = pipeline_branch.get_node_by_name("a")
    child = pipeline_branch.get_node_by_name("b")

    # "a" is checked before any status is configured on the mock.
    await orch.check_dependencies_complete(root)

    # Reported status is DONE: the check on "b" must pass silently.
    mock_status.side_effect = [RunStatus.DONE]
    await orch.check_dependencies_complete(child)

    # Reported status is RUNNING (not DONE): the check on "b" must raise.
    mock_status.side_effect = [RunStatus.RUNNING]
    with pytest.raises(PipelineMissingDependencyError):
        await orch.check_dependencies_complete(child)
263
+
264
+
265
+ @pytest .mark .asyncio
266
+ @patch (
267
+ "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component"
268
+ )
269
+ @patch (
270
+ "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.check_dependencies_complete" ,
271
+ )
272
+ async def test_orchestrator_next_task_branch_no_missing_dependencies (
273
+ mock_dep : Mock , mock_status : Mock , pipeline_branch : Pipeline
195
274
) -> None :
196
275
"""a -> b, c"""
197
276
orchestrator = Orchestrator (pipeline = pipeline_branch )
198
277
node_a = pipeline_branch .get_node_by_name ("a" )
199
278
mock_status .side_effect = [
200
- # next b
279
+ # next "b"
201
280
RunStatus .UNKNOWN ,
202
- # dep of b = a
203
- RunStatus .DONE ,
204
- # next c
281
+ # next "c"
205
282
RunStatus .UNKNOWN ,
206
- # dep of c = a
207
- RunStatus .DONE ,
283
+ ]
284
+ mock_dep .side_effect = [
285
+ None , # "b" has no missing dependencies
286
+ None , # "c" has no missing dependencies
208
287
]
209
288
next_tasks = [n async for n in orchestrator .next (node_a )]
210
289
next_task_names = [n .name for n in next_tasks ]
@@ -215,31 +294,48 @@ async def test_orchestrator_branch(
215
294
@patch(
    "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component"
)
@patch(
    "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.check_dependencies_complete",
)
async def test_orchestrator_next_task_branch_missing_dependencies(
    mock_dep: Mock, mock_status: Mock, pipeline_branch: Pipeline
) -> None:
    """Only tasks whose dependency check passes are yielded by ``next``.

    Topology (per the fixture name and the original docstring): a -> b, c.
    """
    orch = Orchestrator(pipeline=pipeline_branch)
    start_node = pipeline_branch.get_node_by_name("a")

    # Both candidate tasks ("b", then "c") report UNKNOWN, i.e. not started.
    mock_status.side_effect = [RunStatus.UNKNOWN, RunStatus.UNKNOWN]
    # The dependency check fails for "b" and succeeds for "c".
    mock_dep.side_effect = [PipelineMissingDependencyError, None]

    yielded_names = [task.name async for task in orch.next(start_node)]
    assert yielded_names == ["c"]
319
+
320
+
321
+ @pytest .mark .asyncio
322
+ @patch (
323
+ "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component"
324
+ )
325
+ @patch (
326
+ "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.check_dependencies_complete" ,
327
+ )
328
+ async def test_orchestrator_next_task_aggregation_no_missing_dependencies (
329
+ mock_dep : Mock , mock_status : Mock , pipeline_aggregation : Pipeline
330
+ ) -> None :
331
+ """a, b -> c"""
332
+ orchestrator = Orchestrator (pipeline = pipeline_aggregation )
333
+ node_a = pipeline_aggregation .get_node_by_name ("a" )
236
334
mock_status .side_effect = [
237
- # next c:
238
- RunStatus .UNKNOWN ,
239
- # dep of c = a
240
- RunStatus .DONE ,
241
- # dep of c = b
242
- RunStatus .DONE ,
335
+ RunStatus .UNKNOWN , # status for "c", not started
336
+ ]
337
+ mock_dep .side_effect = [
338
+ None , # no missing deps
243
339
]
244
340
# then "c" can start
245
341
next_tasks = [n async for n in orchestrator .next (node_a )]
@@ -248,8 +344,41 @@ async def test_orchestrator_aggregation(
248
344
249
345
250
346
@pytest.mark.asyncio
@patch(
    "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component"
)
@patch(
    "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.check_dependencies_complete",
)
async def test_orchestrator_next_task_aggregation_missing_dependency(
    mock_dep: Mock, mock_status: Mock, pipeline_aggregation: Pipeline
) -> None:
    """No task is yielded when the dependency check raises.

    Topology (per the fixture name and the original docstring): a, b -> c.
    """
    orch = Orchestrator(pipeline=pipeline_aggregation)
    start_node = pipeline_aggregation.get_node_by_name("a")

    # "c" reports UNKNOWN, so it is a possible next task...
    mock_status.side_effect = [RunStatus.UNKNOWN]
    # ...but its dependency check reports a dependency that is not done yet.
    mock_dep.side_effect = [PipelineMissingDependencyError]

    yielded = [task async for task in orch.next(start_node)]
    yielded_names = [task.name for task in yielded]
    # "c" dependencies are not ready yet, so nothing may be scheduled.
    assert yielded_names == []
368
+
369
+
370
@pytest.mark.asyncio
@patch(
    "neo4j_graphrag.experimental.pipeline.pipeline.Orchestrator.get_status_for_component"
)
async def test_orchestrator_next_task_aggregation_next_already_started(
    mock_status: Mock, pipeline_aggregation: Pipeline
) -> None:
    """A task that already reports RUNNING is not yielded again by ``next``.

    Topology (per the fixture name and the original docstring): a, b -> c.
    """
    orch = Orchestrator(pipeline=pipeline_aggregation)
    start_node = pipeline_aggregation.get_node_by_name("a")

    # "c" is already running: it must not be started a second time.
    mock_status.side_effect = [RunStatus.RUNNING]

    yielded = [task async for task in orch.next(start_node)]
    yielded_names = [task.name for task in yielded]
    assert yielded_names == []
0 commit comments