Skip to content

Commit b307fa0

Browse files
committed
mypy on tests
1 parent d550888 commit b307fa0

File tree

2 files changed

+23
-49
lines changed

2 files changed

+23
-49
lines changed

tests/unit/test_core_orchestrator.py

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,58 +3,53 @@
33

44

55
@pytest.fixture(scope="function")
6-
def component():
7-
return Component()
8-
9-
10-
@pytest.fixture(scope="function")
11-
def pipeline_branch(component):
6+
def pipeline_branch() -> Pipeline:
127
pipe = Pipeline()
13-
pipe.add_component("a", component)
14-
pipe.add_component("b", component)
15-
pipe.add_component("c", component)
8+
pipe.add_component("a", Component())
9+
pipe.add_component("b", Component())
10+
pipe.add_component("c", Component())
1611
pipe.connect("a", "b")
1712
pipe.connect("a", "c")
1813
return pipe
1914

2015

2116
@pytest.fixture(scope="function")
22-
def pipeline_aggregation(component):
17+
def pipeline_aggregation() -> Pipeline:
2318
pipe = Pipeline()
24-
pipe.add_component("a", component)
25-
pipe.add_component("b", component)
26-
pipe.add_component("c", component)
19+
pipe.add_component("a", Component())
20+
pipe.add_component("b", Component())
21+
pipe.add_component("c", Component())
2722
pipe.connect("a", "b")
2823
pipe.connect("a", "c")
2924
return pipe
3025

3126

32-
def test_orchestrator_branch(pipeline_branch):
27+
async def test_orchestrator_branch(pipeline_branch: Pipeline) -> None:
3328
orchestrator = Orchestrator(pipeline=pipeline_branch)
3429
node_a = pipeline_branch.get_node_by_name("a")
35-
node_a.status = RunStatus.DONE
36-
next_tasks = orchestrator.next(node_a)
30+
node_a.status = RunStatus.DONE # type: ignore
31+
next_tasks = [n async for n in orchestrator.next(node_a)] # type: ignore
3732
next_task_names = [n.name for n in next_tasks]
3833
assert next_task_names == ["b", "c"]
3934

4035

41-
def test_orchestrator_aggregation(pipeline_aggregation):
36+
async def test_orchestrator_aggregation(pipeline_aggregation: Pipeline) -> None:
4237
orchestrator = Orchestrator(pipeline=pipeline_aggregation)
4338
node_a = pipeline_aggregation.get_node_by_name("a")
44-
node_a.status = RunStatus.DONE
39+
node_a.status = RunStatus.DONE # type: ignore
4540
node_b = pipeline_aggregation.get_node_by_name("b")
46-
node_b.status = RunStatus.DONE
47-
next_tasks = orchestrator.next(node_a)
41+
node_b.status = RunStatus.DONE # type: ignore
42+
next_tasks = [n async for n in orchestrator.next(node_a)] # type: ignore
4843
next_task_names = [n.name for n in next_tasks]
4944
assert next_task_names == ["c"]
5045

5146

52-
def test_orchestrator_aggregation_waiting(pipeline_aggregation):
47+
async def test_orchestrator_aggregation_waiting(pipeline_aggregation: Pipeline) -> None:
5348
orchestrator = Orchestrator(pipeline=pipeline_aggregation)
5449
node_a = pipeline_aggregation.get_node_by_name("a")
55-
node_a.status = RunStatus.DONE
50+
node_a.status = RunStatus.DONE # type: ignore
5651
node_b = pipeline_aggregation.get_node_by_name("a")
57-
node_b.status = RunStatus.UNKNOWN
58-
next_tasks = orchestrator.next(node_a)
52+
node_b.status = RunStatus.UNKNOWN # type: ignore
53+
next_tasks = [n async for n in orchestrator.next(node_a)] # type: ignore
5954
next_task_names = [n.name for n in next_tasks]
6055
assert next_task_names == []

tests/unit/test_core_pipeline.py

Lines changed: 4 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,29 +6,8 @@
66
from neo4j_genai.core.pipeline import Component, Pipeline
77

88

9-
@pytest.fixture(scope="function")
10-
def component_multiply():
11-
class ComponentMultiply(Component):
12-
def __init__(self, r: float = 2.0) -> None:
13-
self.r = r
14-
15-
async def run(self, number: float):
16-
return {"product": number * self.r}
17-
18-
return ComponentMultiply()
19-
20-
21-
@pytest.fixture(scope="function")
22-
def component_add():
23-
class ComponentAdd(Component):
24-
async def run(self, number1: float, number2: float):
25-
return {"sum": number1 + number2}
26-
27-
return ComponentAdd()
28-
29-
309
@pytest.mark.asyncio
31-
async def test_simple_pipeline_two_components():
10+
async def test_simple_pipeline_two_components() -> None:
3211
pipe = Pipeline()
3312
component_a = AsyncMock(spec=Component)
3413
component_a.run = AsyncMock(return_value={})
@@ -45,7 +24,7 @@ async def test_simple_pipeline_two_components():
4524

4625

4726
@pytest.mark.asyncio
48-
async def test_simple_pipeline_two_components_parameter_propagation():
27+
async def test_simple_pipeline_two_components_parameter_propagation() -> None:
4928
pipe = Pipeline()
5029
component_a = AsyncMock(spec=Component)
5130
component_a.run = AsyncMock(return_value={"product": 20})
@@ -68,7 +47,7 @@ async def test_simple_pipeline_two_components_parameter_propagation():
6847

6948

7049
@pytest.mark.asyncio
71-
async def test_pipeline_branches():
50+
async def test_pipeline_branches() -> None:
7251
pipe = Pipeline()
7352
component_a = AsyncMock(spec=Component)
7453
component_a.run = AsyncMock(return_value={})
@@ -88,7 +67,7 @@ async def test_pipeline_branches():
8867

8968

9069
@pytest.mark.asyncio
91-
async def test_pipeline_aggregation():
70+
async def test_pipeline_aggregation() -> None:
9271
pipe = Pipeline()
9372
component_a = AsyncMock(spec=Component)
9473
component_a.run = AsyncMock(return_value={})

0 commit comments

Comments
 (0)