Skip to content

Commit 3d3f988

Browse files
committed
Move components test in the example folder - add some tests
1 parent a38f895 commit 3d3f988

File tree

10 files changed

+275
-11
lines changed

10 files changed

+275
-11
lines changed

examples/pipeline/__init__.py

Whitespace-only changes.

src/neo4j_genai/components/kg_builder.py renamed to examples/pipeline/kg_builder.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
# #
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
# #
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
# #
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
115
from __future__ import annotations
216

317
import asyncio

src/neo4j_genai/components/rag.py renamed to examples/pipeline/rag.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
# #
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
# #
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
# #
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
115
from __future__ import annotations
216

317
import asyncio

poetry.lock

Lines changed: 19 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ scipy = [
6161
{version = "^1", python = "<3.12"},
6262
{version = "^1.7.0", python = ">=3.12"}
6363
]
64+
pytest-asyncio = "^0.23.8"
6465

6566
[tool.poetry.extras]
6667
external_clients = ["weaviate-client", "pinecone-client"]

src/neo4j_genai/core/graph.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
# #
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
# #
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
# #
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
115
"""
216
Basic graph structure for Pipeline
317
"""

src/neo4j_genai/core/pipeline.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
# #
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
# #
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
# #
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
115
"""
216
Pipeline implementation.
317
@@ -118,26 +132,26 @@ def save_results(self, node: Node, res: RunResult) -> None:
118132
node.name, res.result or {}, is_final=node.is_leaf()
119133
)
120134

121-
async def run_node(self, node: TaskNode, data: dict[str, Any]) -> None:
122-
await node.run(data, callback=self.on_task_complete)
135+
async def run_task(self, task: TaskNode, data: dict[str, Any]) -> None:
136+
await task.run(data, callback=self.on_task_complete)
123137

124138
async def on_task_complete(
125-
self, node: TaskNode, data: dict[str, Any], res: RunResult
139+
self, task: TaskNode, data: dict[str, Any], res: RunResult
126140
) -> None:
127-
self.save_results(node, res)
128-
await asyncio.gather(*[self.run_node(n, data) for n in self.next(node)])
141+
self.pipeline.on_task_complete(task, res)
142+
await asyncio.gather(*[self.run_task(n, data) for n in self.next(task)])
129143

130-
def check_dependencies_complete(self, node: TaskNode) -> None:
131-
dependencies = self.pipeline.previous_edges(node)
144+
def check_dependencies_complete(self, task: TaskNode) -> None:
145+
dependencies = self.pipeline.previous_edges(task)
132146
for d in dependencies:
133147
if d.start.status != RunStatus.DONE: # type: ignore
134148
logger.warning(
135-
f"Missing dependency {d.start.name} for {node.name} (status: {d.start.status})" # type: ignore
149+
f"Missing dependency {d.start.name} for {task.name} (status: {d.start.status})" # type: ignore
136150
)
137151
raise MissingDependencyError()
138152

139-
def next(self, node: TaskNode) -> Generator[TaskNode, None, None]:
140-
possible_nexts = self.pipeline.next_edges(node)
153+
def next(self, task: TaskNode) -> Generator[TaskNode, None, None]:
154+
possible_nexts = self.pipeline.next_edges(task)
141155
for next_edge in possible_nexts:
142156
next_node = next_edge.end
143157
# check status
@@ -185,6 +199,13 @@ def connect( # type: ignore
185199
if self.is_cyclic():
186200
raise Exception("Cyclic graph")
187201

202+
def on_task_complete(self, node, result) -> None:
203+
self.add_result_for_component(
204+
node.name,
205+
result.result,
206+
is_final=node.is_leaf()
207+
)
208+
188209
def add_result_for_component(
189210
self, name: str, result: dict[str, Any], is_final: bool = False
190211
) -> None:

src/neo4j_genai/core/stores.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,20 @@
1+
# Copyright (c) "Neo4j"
2+
# Neo4j Sweden AB [https://neo4j.com]
3+
# #
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
# #
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
# #
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""Result store interface
16+
and in-memory store implementation.
17+
"""
118
from __future__ import annotations
219

320
import abc

tests/unit/test_core_orchestrator.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import pytest
2+
3+
from neo4j_genai.core.graph import Node, Graph
4+
from neo4j_genai.core.pipeline import Orchestrator, Component, Pipeline, RunStatus
5+
6+
7+
@pytest.fixture(scope="function")
8+
def component():
9+
return Component()
10+
11+
12+
@pytest.fixture(scope="function")
13+
def pipeline_branch(component):
14+
pipe = Pipeline()
15+
pipe.add_component("a", component)
16+
pipe.add_component("b", component)
17+
pipe.add_component("c", component)
18+
pipe.connect("a", "b")
19+
pipe.connect("a", "c")
20+
return pipe
21+
22+
23+
@pytest.fixture(scope="function")
24+
def pipeline_aggregation(component):
25+
pipe = Pipeline()
26+
pipe.add_component("a", component)
27+
pipe.add_component("b", component)
28+
pipe.add_component("c", component)
29+
pipe.connect("a", "b")
30+
pipe.connect("a", "c")
31+
return pipe
32+
33+
34+
def test_orchestrator_branch(pipeline_branch):
35+
orchestrator = Orchestrator(pipeline=pipeline_branch)
36+
node_a = pipeline_branch.get_node_by_name("a")
37+
node_a.status = RunStatus.DONE
38+
next_tasks = orchestrator.next(node_a)
39+
next_task_names = [n.name for n in next_tasks]
40+
assert next_task_names == ["b", "c"]
41+
42+
43+
def test_orchestrator_aggregation(pipeline_aggregation):
44+
orchestrator = Orchestrator(pipeline=pipeline_aggregation)
45+
node_a = pipeline_aggregation.get_node_by_name("a")
46+
node_a.status = RunStatus.DONE
47+
node_b = pipeline_aggregation.get_node_by_name("b")
48+
node_b.status = RunStatus.DONE
49+
next_tasks = orchestrator.next(node_a)
50+
next_task_names = [n.name for n in next_tasks]
51+
assert next_task_names == ["c"]
52+
53+
54+
def test_orchestrator_aggregation_waiting(pipeline_aggregation):
55+
orchestrator = Orchestrator(pipeline=pipeline_aggregation)
56+
node_a = pipeline_aggregation.get_node_by_name("a")
57+
node_a.status = RunStatus.DONE
58+
node_b = pipeline_aggregation.get_node_by_name("a")
59+
node_b.status = RunStatus.UNKNOWN
60+
next_tasks = orchestrator.next(node_a)
61+
next_task_names = [n.name for n in next_tasks]
62+
assert next_task_names == []

tests/unit/test_core_pipeline.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
from __future__ import annotations
2+
from typing import Any
3+
from unittest.mock import AsyncMock
4+
5+
import pytest
6+
7+
from neo4j_genai.core.pipeline import Component, Pipeline
8+
9+
10+
@pytest.fixture(scope="function")
11+
def component_multiply():
12+
class ComponentMultiply(Component):
13+
def __init__(self, r: float = 2.0) -> None:
14+
self.r = r
15+
16+
async def run(self, number: float):
17+
return {"product": number * self.r}
18+
19+
return ComponentMultiply()
20+
21+
22+
@pytest.fixture(scope="function")
23+
def component_add():
24+
class ComponentAdd(Component):
25+
async def run(self, number1: float, number2: float):
26+
return {"sum": number1 + number2}
27+
28+
return ComponentAdd()
29+
30+
31+
@pytest.mark.asyncio
32+
async def test_simple_pipeline_two_components():
33+
pipe = Pipeline()
34+
component_a = AsyncMock(spec=Component)
35+
component_a.run = AsyncMock(return_value={})
36+
component_b = AsyncMock(spec=Component)
37+
component_b.run = AsyncMock(return_value={"result": "result"})
38+
pipe.add_component("a", component_a)
39+
pipe.add_component("b", component_b)
40+
pipe.connect("a", "b", {})
41+
res = await pipe.run({})
42+
assert component_a.run.called_one_with({})
43+
assert component_b.run.called_one_with({})
44+
assert "b" in res
45+
assert res["b"] == {"result": "result"}
46+
47+
48+
@pytest.mark.asyncio
49+
async def test_simple_pipeline_two_components_parameter_propagation():
50+
pipe = Pipeline()
51+
component_a = AsyncMock(spec=Component)
52+
component_a.run = AsyncMock(return_value={"product": 20})
53+
component_b = AsyncMock(spec=Component)
54+
component_b.run = AsyncMock(return_value={"sum": 54})
55+
pipe.add_component("a", component_a)
56+
pipe.add_component("b", component_b)
57+
# first component output product goes to second component input number1
58+
pipe.connect("a", "b", {
59+
"number1": "a.product",
60+
})
61+
res = await pipe.run({"a": {}, "b": {"number2": 1}})
62+
assert component_a.run.called_one_with({})
63+
assert component_b.run.called_one_with({"number1": 20, "number2": 1})
64+
assert res == {"b": {"sum": 54}}
65+
66+
67+
@pytest.mark.asyncio
68+
async def test_pipeline_branches():
69+
pipe = Pipeline()
70+
component_a = AsyncMock(spec=Component)
71+
component_a.run = AsyncMock(return_value={})
72+
component_b = AsyncMock(spec=Component)
73+
component_b.run = AsyncMock(return_value={})
74+
component_c = AsyncMock(spec=Component)
75+
component_c.run = AsyncMock(return_value={})
76+
77+
pipe.add_component("a", component_a)
78+
pipe.add_component("b", component_b)
79+
pipe.add_component("c", component_c)
80+
pipe.connect("a", "b")
81+
pipe.connect("a", "c")
82+
res = await pipe.run({})
83+
assert "b" in res
84+
assert "c" in res
85+
86+
87+
@pytest.mark.asyncio
88+
async def test_pipeline_aggregation():
89+
pipe = Pipeline()
90+
component_a = AsyncMock(spec=Component)
91+
component_a.run = AsyncMock(return_value={})
92+
component_b = AsyncMock(spec=Component)
93+
component_b.run = AsyncMock(return_value={})
94+
component_c = AsyncMock(spec=Component)
95+
component_c.run = AsyncMock(return_value={})
96+
97+
pipe.add_component("a", component_a)
98+
pipe.add_component("b", component_b)
99+
pipe.add_component("c", component_c)
100+
pipe.connect("a", "c")
101+
pipe.connect("b", "c")
102+
res = await pipe.run({})
103+
assert "c" in res

0 commit comments

Comments
 (0)