28
28
SchemaBuilder ,
29
29
SchemaEntity ,
30
30
SchemaRelation ,
31
+ SchemaFromTextExtractor ,
31
32
)
32
33
from neo4j_graphrag .experimental .components .text_splitters .fixed_size_splitter import (
33
34
FixedSizeSplitter ,
@@ -116,8 +117,21 @@ def test_simple_kg_pipeline_config_chunk_embedder(
116
117
assert chunk_embedder ._embedder == embedder
117
118
118
119
119
- def test_simple_kg_pipeline_config_schema () -> None :
120
+ @patch (
121
+ "neo4j_graphrag.experimental.pipeline.config.template_pipeline.simple_kg_builder.SimpleKGPipelineConfig.get_default_llm"
122
+ )
123
+ def test_simple_kg_pipeline_config_automatic_schema (
124
+ mock_llm : Mock , llm : LLMInterface
125
+ ) -> None :
126
+ mock_llm .return_value = llm
120
127
config = SimpleKGPipelineConfig ()
128
+ schema = config ._get_schema ()
129
+ assert isinstance (schema , SchemaFromTextExtractor )
130
+ assert schema ._llm == llm
131
+
132
+
133
+ def test_simple_kg_pipeline_config_manual_schema () -> None :
134
+ config = SimpleKGPipelineConfig (entities = ["Person" ])
121
135
assert isinstance (config ._get_schema (), SchemaBuilder )
122
136
123
137
@@ -205,9 +219,10 @@ def test_simple_kg_pipeline_config_connections_from_pdf() -> None:
205
219
perform_entity_resolution = False ,
206
220
)
207
221
connections = config ._get_connections ()
208
- assert len (connections ) == 5
222
+ assert len (connections ) == 6
209
223
expected_connections = [
210
224
("pdf_loader" , "splitter" ),
225
+ ("pdf_loader" , "schema" ),
211
226
("schema" , "extractor" ),
212
227
("splitter" , "chunk_embedder" ),
213
228
("chunk_embedder" , "extractor" ),
@@ -240,9 +255,10 @@ def test_simple_kg_pipeline_config_connections_with_er() -> None:
240
255
perform_entity_resolution = True ,
241
256
)
242
257
connections = config ._get_connections ()
243
- assert len (connections ) == 6
258
+ assert len (connections ) == 7
244
259
expected_connections = [
245
260
("pdf_loader" , "splitter" ),
261
+ ("pdf_loader" , "schema" ),
246
262
("schema" , "extractor" ),
247
263
("splitter" , "chunk_embedder" ),
248
264
("chunk_embedder" , "extractor" ),
@@ -263,7 +279,8 @@ def test_simple_kg_pipeline_config_run_params_from_pdf_file_path() -> None:
263
279
def test_simple_kg_pipeline_config_run_params_from_text_text () -> None :
264
280
config = SimpleKGPipelineConfig (from_pdf = False )
265
281
assert config .get_run_params ({"text" : "my text" }) == {
266
- "splitter" : {"text" : "my text" }
282
+ "splitter" : {"text" : "my text" },
283
+ "schema" : {"text" : "my text" },
267
284
}
268
285
269
286
0 commit comments