Skip to content

Commit 5183439

Browse files
Fix unit tests
1 parent b412a05 commit 5183439

File tree

1 file changed

+21
-4
lines changed

1 file changed

+21
-4
lines changed

tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
SchemaBuilder,
2929
SchemaEntity,
3030
SchemaRelation,
31+
SchemaFromTextExtractor,
3132
)
3233
from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import (
3334
FixedSizeSplitter,
@@ -116,8 +117,21 @@ def test_simple_kg_pipeline_config_chunk_embedder(
116117
assert chunk_embedder._embedder == embedder
117118

118119

119-
def test_simple_kg_pipeline_config_schema() -> None:
120+
@patch(
121+
"neo4j_graphrag.experimental.pipeline.config.template_pipeline.simple_kg_builder.SimpleKGPipelineConfig.get_default_llm"
122+
)
123+
def test_simple_kg_pipeline_config_automatic_schema(
124+
mock_llm: Mock, llm: LLMInterface
125+
) -> None:
126+
mock_llm.return_value = llm
120127
config = SimpleKGPipelineConfig()
128+
schema = config._get_schema()
129+
assert isinstance(schema, SchemaFromTextExtractor)
130+
assert schema._llm == llm
131+
132+
133+
def test_simple_kg_pipeline_config_manual_schema() -> None:
134+
config = SimpleKGPipelineConfig(entities=["Person"])
121135
assert isinstance(config._get_schema(), SchemaBuilder)
122136

123137

@@ -205,9 +219,10 @@ def test_simple_kg_pipeline_config_connections_from_pdf() -> None:
205219
perform_entity_resolution=False,
206220
)
207221
connections = config._get_connections()
208-
assert len(connections) == 5
222+
assert len(connections) == 6
209223
expected_connections = [
210224
("pdf_loader", "splitter"),
225+
("pdf_loader", "schema"),
211226
("schema", "extractor"),
212227
("splitter", "chunk_embedder"),
213228
("chunk_embedder", "extractor"),
@@ -240,9 +255,10 @@ def test_simple_kg_pipeline_config_connections_with_er() -> None:
240255
perform_entity_resolution=True,
241256
)
242257
connections = config._get_connections()
243-
assert len(connections) == 6
258+
assert len(connections) == 7
244259
expected_connections = [
245260
("pdf_loader", "splitter"),
261+
("pdf_loader", "schema"),
246262
("schema", "extractor"),
247263
("splitter", "chunk_embedder"),
248264
("chunk_embedder", "extractor"),
@@ -263,7 +279,8 @@ def test_simple_kg_pipeline_config_run_params_from_pdf_file_path() -> None:
263279
def test_simple_kg_pipeline_config_run_params_from_text_text() -> None:
264280
config = SimpleKGPipelineConfig(from_pdf=False)
265281
assert config.get_run_params({"text": "my text"}) == {
266-
"splitter": {"text": "my text"}
282+
"splitter": {"text": "my text"},
283+
"schema": {"text": "my text"},
267284
}
268285

269286

0 commit comments

Comments
 (0)