Skip to content

Commit 7ec3304

Browse files
committed
Add tests
1 parent 1ea08f6 commit 7ec3304

File tree

3 files changed

+56
-2
lines changed

3 files changed

+56
-2
lines changed

src/neo4j_graphrag/experimental/pipeline/component.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ class Component(abc.ABC, metaclass=ComponentMeta):
8181
# added here for the type checker
8282
# DO NOT CHANGE
8383
component_inputs: dict[str, dict[str, str | bool]]
84-
component_outputs: dict[str, dict[str, str | bool]]
84+
component_outputs: dict[str, dict[str, str | bool | type]]
8585

8686
@abc.abstractmethod
8787
async def run(self, *args: Any, **kwargs: Any) -> DataModel:

tests/unit/experimental/pipeline/components.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
from neo4j_graphrag.experimental.pipeline import Component, DataModel
16+
from neo4j_graphrag.experimental.pipeline.types.context import RunContext
1617

1718

1819
class StringResultModel(DataModel):
@@ -41,3 +42,13 @@ async def run(self, number1: int, number2: int) -> IntResultModel:
4142
class ComponentMultiply(Component):
4243
async def run(self, number1: int, number2: int = 2) -> IntResultModel:
4344
return IntResultModel(result=number1 * number2)
45+
46+
47+
class ComponentMultiplyWithContext(Component):
48+
async def run_with_context(
49+
self, context_: RunContext, number1: int, number2: int = 2
50+
) -> IntResultModel:
51+
await context_.notify(
52+
message="my message", data={"number1": number1, "number2": number2}
53+
)
54+
return IntResultModel(result=number1 * number2)

tests/unit/experimental/pipeline/test_component.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,12 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
from .components import ComponentMultiply
15+
from unittest.mock import MagicMock, AsyncMock
16+
17+
import pytest
18+
19+
from neo4j_graphrag.experimental.pipeline.types.context import RunContext
20+
from .components import ComponentMultiply, ComponentMultiplyWithContext, IntResultModel
1621

1722

1823
def test_component_inputs() -> None:
@@ -26,3 +31,41 @@ def test_component_inputs() -> None:
2631
def test_component_outputs() -> None:
2732
outputs = ComponentMultiply.component_outputs
2833
assert "result" in outputs
34+
assert outputs["result"]["has_default"] is True
35+
assert outputs["result"]["annotation"] == int
36+
37+
38+
@pytest.mark.asyncio
39+
async def test_component_run() -> None:
40+
c = ComponentMultiply()
41+
result = await c.run(number1=1, number2=2)
42+
assert isinstance(result, IntResultModel)
43+
assert isinstance(
44+
result.result, ComponentMultiply.component_outputs["result"]["annotation"]
45+
)
46+
47+
48+
@pytest.mark.asyncio
49+
async def test_component_run_with_context_default_implementation() -> None:
50+
c = ComponentMultiply()
51+
result = await c.run_with_context(
52+
# context can not be null in the function signature,
53+
# but it's ignored in this case
54+
None, # type: ignore
55+
number1=1,
56+
number2=2,
57+
)
58+
assert result.result == 2
59+
60+
61+
@pytest.mark.asyncio
62+
async def test_component_run_with_context() -> None:
63+
c = ComponentMultiplyWithContext()
64+
notifier_mock = AsyncMock()
65+
result = await c.run_with_context(
66+
RunContext(run_id="run_id", task_name="task_name", notifier=notifier_mock),
67+
number1=1,
68+
number2=2,
69+
)
70+
assert result.result == 2
71+
notifier_mock.assert_awaited_once()

0 commit comments

Comments
 (0)