|
3 | 3 |
|
4 | 4 |
|
5 | 5 | @pytest.fixture(scope="function")
|
6 |
| -def component(): |
7 |
| - return Component() |
8 |
| - |
9 |
| - |
10 |
| -@pytest.fixture(scope="function") |
11 |
| -def pipeline_branch(component): |
| 6 | +def pipeline_branch() -> Pipeline: |
12 | 7 | pipe = Pipeline()
|
13 |
| - pipe.add_component("a", component) |
14 |
| - pipe.add_component("b", component) |
15 |
| - pipe.add_component("c", component) |
| 8 | + pipe.add_component("a", Component()) |
| 9 | + pipe.add_component("b", Component()) |
| 10 | + pipe.add_component("c", Component()) |
16 | 11 | pipe.connect("a", "b")
|
17 | 12 | pipe.connect("a", "c")
|
18 | 13 | return pipe
|
19 | 14 |
|
20 | 15 |
|
21 | 16 | @pytest.fixture(scope="function")
|
22 |
| -def pipeline_aggregation(component): |
| 17 | +def pipeline_aggregation() -> Pipeline: |
23 | 18 | pipe = Pipeline()
|
24 |
| - pipe.add_component("a", component) |
25 |
| - pipe.add_component("b", component) |
26 |
| - pipe.add_component("c", component) |
| 19 | + pipe.add_component("a", Component()) |
| 20 | + pipe.add_component("b", Component()) |
| 21 | + pipe.add_component("c", Component()) |
27 | 22 | pipe.connect("a", "b")
|
28 | 23 | pipe.connect("a", "c")
|
29 | 24 | return pipe
|
30 | 25 |
|
31 | 26 |
|
32 |
| -def test_orchestrator_branch(pipeline_branch): |
| 27 | +async def test_orchestrator_branch(pipeline_branch: Pipeline) -> None: |
33 | 28 | orchestrator = Orchestrator(pipeline=pipeline_branch)
|
34 | 29 | node_a = pipeline_branch.get_node_by_name("a")
|
35 |
| - node_a.status = RunStatus.DONE |
36 |
| - next_tasks = orchestrator.next(node_a) |
| 30 | + node_a.status = RunStatus.DONE # type: ignore |
| 31 | + next_tasks = [n async for n in orchestrator.next(node_a)] # type: ignore |
37 | 32 | next_task_names = [n.name for n in next_tasks]
|
38 | 33 | assert next_task_names == ["b", "c"]
|
39 | 34 |
|
40 | 35 |
|
41 |
| -def test_orchestrator_aggregation(pipeline_aggregation): |
| 36 | +async def test_orchestrator_aggregation(pipeline_aggregation: Pipeline) -> None: |
42 | 37 | orchestrator = Orchestrator(pipeline=pipeline_aggregation)
|
43 | 38 | node_a = pipeline_aggregation.get_node_by_name("a")
|
44 |
| - node_a.status = RunStatus.DONE |
| 39 | + node_a.status = RunStatus.DONE # type: ignore |
45 | 40 | node_b = pipeline_aggregation.get_node_by_name("b")
|
46 |
| - node_b.status = RunStatus.DONE |
47 |
| - next_tasks = orchestrator.next(node_a) |
| 41 | + node_b.status = RunStatus.DONE # type: ignore |
| 42 | + next_tasks = [n async for n in orchestrator.next(node_a)] # type: ignore |
48 | 43 | next_task_names = [n.name for n in next_tasks]
|
49 | 44 | assert next_task_names == ["c"]
|
50 | 45 |
|
51 | 46 |
|
52 |
| -def test_orchestrator_aggregation_waiting(pipeline_aggregation): |
| 47 | +async def test_orchestrator_aggregation_waiting(pipeline_aggregation: Pipeline) -> None: |
53 | 48 | orchestrator = Orchestrator(pipeline=pipeline_aggregation)
|
54 | 49 | node_a = pipeline_aggregation.get_node_by_name("a")
|
55 |
| - node_a.status = RunStatus.DONE |
| 50 | + node_a.status = RunStatus.DONE # type: ignore |
56 | 51 | node_b = pipeline_aggregation.get_node_by_name("a")
|
57 |
| - node_b.status = RunStatus.UNKNOWN |
58 |
| - next_tasks = orchestrator.next(node_a) |
| 52 | + node_b.status = RunStatus.UNKNOWN # type: ignore |
| 53 | + next_tasks = [n async for n in orchestrator.next(node_a)] # type: ignore |
59 | 54 | next_task_names = [n.name for n in next_tasks]
|
60 | 55 | assert next_task_names == []
|
0 commit comments