24
24
)
25
25
from neo4j_graphrag .experimental .components .kg_writer import Neo4jWriter
26
26
from neo4j_graphrag .experimental .components .pdf_loader import PdfLoader
27
+ from neo4j_graphrag .experimental .components .resolver import (
28
+ SinglePropertyExactMatchResolver ,
29
+ )
27
30
from neo4j_graphrag .experimental .components .schema import (
28
31
SchemaBuilder ,
29
32
SchemaEntity ,
@@ -152,6 +155,28 @@ def test_simple_kg_pipeline_config_extractor(mock_llm: Mock, llm: LLMInterface)
152
155
assert extractor .prompt_template .template == "my template {text}"
153
156
154
157
158
+ @patch (
159
+ "neo4j_graphrag.experimental.pipeline.config.template_pipeline.simple_kg_builder.SimpleKGPipelineConfig.get_default_llm"
160
+ )
161
+ @patch ("neo4j_graphrag.experimental.pipeline.config.object_config.ComponentType.parse" )
162
+ def test_simple_kg_pipeline_config_extractor_overwrite (
163
+ mock_component_parse : Mock , mock_llm : Mock
164
+ ) -> None :
165
+ my_extractor = LLMEntityRelationExtractor (llm = mock_llm )
166
+ mock_component_parse .return_value = my_extractor
167
+ config = SimpleKGPipelineConfig (
168
+ on_error = "IGNORE" , # type: ignore
169
+ prompt_template = ERExtractionTemplate (template = "my template {text}" ),
170
+ extractor = {}, # type: ignore
171
+ )
172
+ extractor = config ._get_extractor ()
173
+ assert isinstance (extractor , LLMEntityRelationExtractor )
174
+ assert extractor .llm == mock_llm
175
+ # default values are not overwritten by the parameters:
176
+ assert extractor .on_error == OnError .RAISE
177
+ assert extractor .prompt_template .template == ERExtractionTemplate .DEFAULT_TEMPLATE
178
+
179
+
155
180
@patch (
156
181
"neo4j_graphrag.experimental.components.kg_writer.get_version" ,
157
182
return_value = ((5 , 23 , 0 ), False , False ),
@@ -184,13 +209,10 @@ def test_simple_kg_pipeline_config_writer_overwrite(
184
209
_ : Mock ,
185
210
driver : neo4j .Driver ,
186
211
) -> None :
187
- my_writer_config : ComponentConfig [Neo4jWriter ] = ComponentConfig (
188
- class_ = "" ,
189
- )
190
212
my_writer = Neo4jWriter (driver , neo4j_database = "my_db" )
191
213
mock_component_parse .return_value = my_writer
192
214
config = SimpleKGPipelineConfig (
193
- kg_writer = my_writer_config , # type: ignore
215
+ kg_writer = {} , # type: ignore
194
216
neo4j_database = "my_other_db" ,
195
217
)
196
218
writer : Neo4jWriter = config ._get_writer () # type: ignore
@@ -199,6 +221,42 @@ def test_simple_kg_pipeline_config_writer_overwrite(
199
221
assert writer .neo4j_database == "my_db"
200
222
201
223
224
+ @patch (
225
+ "neo4j_graphrag.experimental.pipeline.config.template_pipeline.simple_kg_builder.SimpleKGPipelineConfig.get_default_llm"
226
+ )
227
+ @patch ("neo4j_graphrag.experimental.pipeline.config.object_config.ComponentType.parse" )
228
+ def test_simple_kg_pipeline_config_resolver_overwrite (
229
+ mock_component_parse : Mock , driver : neo4j .Driver
230
+ ) -> None :
231
+ my_resolver = SinglePropertyExactMatchResolver (driver , resolve_property = "full_name" )
232
+ mock_component_parse .return_value = my_resolver
233
+ config = SimpleKGPipelineConfig (
234
+ perform_entity_resolution = True ,
235
+ resolver = {}, # type: ignore
236
+ )
237
+ resolver = config ._get_resolver ()
238
+ assert isinstance (resolver , SinglePropertyExactMatchResolver )
239
+ assert resolver .driver == driver
240
+ assert resolver .resolve_property == "full_name"
241
+
242
+
243
+ @patch (
244
+ "neo4j_graphrag.experimental.pipeline.config.template_pipeline.simple_kg_builder.SimpleKGPipelineConfig.get_default_llm"
245
+ )
246
+ @patch ("neo4j_graphrag.experimental.pipeline.config.object_config.ComponentType.parse" )
247
+ def test_simple_kg_pipeline_config_resolver_overwrite_but_disabled (
248
+ mock_component_parse : Mock , driver : neo4j .Driver
249
+ ) -> None :
250
+ my_resolver = SinglePropertyExactMatchResolver (driver , resolve_property = "full_name" )
251
+ mock_component_parse .return_value = my_resolver
252
+ config = SimpleKGPipelineConfig (
253
+ perform_entity_resolution = False ,
254
+ resolver = {}, # type: ignore
255
+ )
256
+ resolver = config ._get_resolver ()
257
+ assert resolver is None
258
+
259
+
202
260
def test_simple_kg_pipeline_config_connections_from_pdf () -> None :
203
261
config = SimpleKGPipelineConfig (
204
262
from_pdf = True ,
@@ -234,7 +292,7 @@ def test_simple_kg_pipeline_config_connections_from_text() -> None:
234
292
assert (actual .start , actual .end ) == expected
235
293
236
294
237
- def test_simple_kg_pipeline_config_connections_with_er () -> None :
295
+ def test_simple_kg_pipeline_config_connections_with_entity_resolution () -> None :
238
296
config = SimpleKGPipelineConfig (
239
297
from_pdf = True ,
240
298
perform_entity_resolution = True ,
0 commit comments