Skip to content

Commit c5b3767

Browse files
committed
Add unit tests
1 parent a72c314 commit c5b3767

File tree

2 files changed

+67
-9
lines changed

2 files changed

+67
-9
lines changed

tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py renamed to tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_pipeline_config.py

Lines changed: 63 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424
)
2525
from neo4j_graphrag.experimental.components.kg_writer import Neo4jWriter
2626
from neo4j_graphrag.experimental.components.pdf_loader import PdfLoader
27+
from neo4j_graphrag.experimental.components.resolver import (
28+
SinglePropertyExactMatchResolver,
29+
)
2730
from neo4j_graphrag.experimental.components.schema import (
2831
SchemaBuilder,
2932
SchemaEntity,
@@ -152,6 +155,28 @@ def test_simple_kg_pipeline_config_extractor(mock_llm: Mock, llm: LLMInterface)
152155
assert extractor.prompt_template.template == "my template {text}"
153156

154157

158+
@patch(
159+
"neo4j_graphrag.experimental.pipeline.config.template_pipeline.simple_kg_builder.SimpleKGPipelineConfig.get_default_llm"
160+
)
161+
@patch("neo4j_graphrag.experimental.pipeline.config.object_config.ComponentType.parse")
162+
def test_simple_kg_pipeline_config_extractor_overwrite(
163+
mock_component_parse: Mock, mock_llm: Mock
164+
) -> None:
165+
my_extractor = LLMEntityRelationExtractor(llm=mock_llm)
166+
mock_component_parse.return_value = my_extractor
167+
config = SimpleKGPipelineConfig(
168+
on_error="IGNORE", # type: ignore
169+
prompt_template=ERExtractionTemplate(template="my template {text}"),
170+
extractor={}, # type: ignore
171+
)
172+
extractor = config._get_extractor()
173+
assert isinstance(extractor, LLMEntityRelationExtractor)
174+
assert extractor.llm == mock_llm
175+
# default values are not overwritten by the parameters:
176+
assert extractor.on_error == OnError.RAISE
177+
assert extractor.prompt_template.template == ERExtractionTemplate.DEFAULT_TEMPLATE
178+
179+
155180
@patch(
156181
"neo4j_graphrag.experimental.components.kg_writer.get_version",
157182
return_value=((5, 23, 0), False, False),
@@ -184,13 +209,10 @@ def test_simple_kg_pipeline_config_writer_overwrite(
184209
_: Mock,
185210
driver: neo4j.Driver,
186211
) -> None:
187-
my_writer_config: ComponentConfig[Neo4jWriter] = ComponentConfig(
188-
class_="",
189-
)
190212
my_writer = Neo4jWriter(driver, neo4j_database="my_db")
191213
mock_component_parse.return_value = my_writer
192214
config = SimpleKGPipelineConfig(
193-
kg_writer=my_writer_config, # type: ignore
215+
kg_writer={}, # type: ignore
194216
neo4j_database="my_other_db",
195217
)
196218
writer: Neo4jWriter = config._get_writer() # type: ignore
@@ -199,6 +221,42 @@ def test_simple_kg_pipeline_config_writer_overwrite(
199221
assert writer.neo4j_database == "my_db"
200222

201223

224+
@patch(
225+
"neo4j_graphrag.experimental.pipeline.config.template_pipeline.simple_kg_builder.SimpleKGPipelineConfig.get_default_llm"
226+
)
227+
@patch("neo4j_graphrag.experimental.pipeline.config.object_config.ComponentType.parse")
228+
def test_simple_kg_pipeline_config_resolver_overwrite(
229+
mock_component_parse: Mock, driver: neo4j.Driver
230+
) -> None:
231+
my_resolver = SinglePropertyExactMatchResolver(driver, resolve_property="full_name")
232+
mock_component_parse.return_value = my_resolver
233+
config = SimpleKGPipelineConfig(
234+
perform_entity_resolution=True,
235+
resolver={}, # type: ignore
236+
)
237+
resolver = config._get_resolver()
238+
assert isinstance(resolver, SinglePropertyExactMatchResolver)
239+
assert resolver.driver == driver
240+
assert resolver.resolve_property == "full_name"
241+
242+
243+
@patch(
244+
"neo4j_graphrag.experimental.pipeline.config.template_pipeline.simple_kg_builder.SimpleKGPipelineConfig.get_default_llm"
245+
)
246+
@patch("neo4j_graphrag.experimental.pipeline.config.object_config.ComponentType.parse")
247+
def test_simple_kg_pipeline_config_resolver_overwrite_but_disabled(
248+
mock_component_parse: Mock, driver: neo4j.Driver
249+
) -> None:
250+
my_resolver = SinglePropertyExactMatchResolver(driver, resolve_property="full_name")
251+
mock_component_parse.return_value = my_resolver
252+
config = SimpleKGPipelineConfig(
253+
perform_entity_resolution=False,
254+
resolver={}, # type: ignore
255+
)
256+
resolver = config._get_resolver()
257+
assert resolver is None
258+
259+
202260
def test_simple_kg_pipeline_config_connections_from_pdf() -> None:
203261
config = SimpleKGPipelineConfig(
204262
from_pdf=True,
@@ -234,7 +292,7 @@ def test_simple_kg_pipeline_config_connections_from_text() -> None:
234292
assert (actual.start, actual.end) == expected
235293

236294

237-
def test_simple_kg_pipeline_config_connections_with_er() -> None:
295+
def test_simple_kg_pipeline_config_connections_with_entity_resolution() -> None:
238296
config = SimpleKGPipelineConfig(
239297
from_pdf=True,
240298
perform_entity_resolution=True,

tests/unit/experimental/pipeline/test_kg_builder.py renamed to tests/unit/experimental/pipeline/test_simple_kg_pipeline.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
return_value=((5, 23, 0), False, False),
3535
)
3636
@pytest.mark.asyncio
37-
async def test_knowledge_graph_builder_document_info_with_file(_: Mock) -> None:
37+
async def test_simple_kg_pipeline_document_info_with_file(_: Mock) -> None:
3838
llm = MagicMock(spec=LLMInterface)
3939
driver = MagicMock(spec=neo4j.Driver)
4040
embedder = MagicMock(spec=Embedder)
@@ -66,7 +66,7 @@ async def test_knowledge_graph_builder_document_info_with_file(_: Mock) -> None:
6666
return_value=((5, 23, 0), False, False),
6767
)
6868
@pytest.mark.asyncio
69-
async def test_knowledge_graph_builder_document_info_with_text(_: Mock) -> None:
69+
async def test_simple_kg_pipeline_document_info_with_text(_: Mock) -> None:
7070
llm = MagicMock(spec=LLMInterface)
7171
driver = MagicMock(spec=neo4j.Driver)
7272
embedder = MagicMock(spec=Embedder)
@@ -97,7 +97,7 @@ async def test_knowledge_graph_builder_document_info_with_text(_: Mock) -> None:
9797
return_value=((5, 23, 0), False, False),
9898
)
9999
@pytest.mark.asyncio
100-
async def test_knowledge_graph_builder_with_entities_and_file(_: Mock) -> None:
100+
async def test_simple_kg_pipeline_with_entities_and_file(_: Mock) -> None:
101101
llm = MagicMock(spec=LLMInterface)
102102
driver = MagicMock(spec=neo4j.Driver)
103103
embedder = MagicMock(spec=Embedder)
@@ -156,7 +156,7 @@ def test_simple_kg_pipeline_on_error_invalid_value() -> None:
156156
return_value=((5, 23, 0), False, False),
157157
)
158158
@pytest.mark.asyncio
159-
async def test_knowledge_graph_builder_with_lexical_graph_config(_: Mock) -> None:
159+
async def test_simple_kg_pipeline_with_lexical_graph_config(_: Mock) -> None:
160160
llm = MagicMock(spec=LLMInterface)
161161
driver = MagicMock(spec=neo4j.Driver)
162162
embedder = MagicMock(spec=Embedder)

0 commit comments

Comments
 (0)