25
25
)
26
26
from neo4j_graphrag .generation .prompts import Text2CypherTemplate
27
27
from neo4j_graphrag .llm import LLMResponse
28
- from neo4j_graphrag .retrievers import Text2CypherRetriever
28
+ from neo4j_graphrag .retrievers . text2cypher import Text2CypherRetriever , extract_cypher
29
29
from neo4j_graphrag .types import RetrieverResult , RetrieverResultItem
30
30
31
31
@@ -204,9 +204,11 @@ def test_t2c_retriever_with_result_format_function(
204
204
)
205
205
206
206
207
+ @patch ("neo4j_graphrag.retrievers.text2cypher.extract_cypher" )
207
208
@patch ("neo4j_graphrag.retrievers.base.get_version" )
208
209
def test_t2c_retriever_initialization_with_custom_prompt (
209
210
mock_get_version : MagicMock ,
211
+ mock_extract_cypher : MagicMock ,
210
212
driver : MagicMock ,
211
213
llm : MagicMock ,
212
214
neo4j_record : MagicMock ,
@@ -224,9 +226,11 @@ def test_t2c_retriever_initialization_with_custom_prompt(
224
226
llm .invoke .assert_called_once_with ("This is a custom prompt. test" )
225
227
226
228
229
+ @patch ("neo4j_graphrag.retrievers.text2cypher.extract_cypher" )
227
230
@patch ("neo4j_graphrag.retrievers.base.get_version" )
228
231
def test_t2c_retriever_initialization_with_custom_prompt_and_schema_and_examples (
229
232
mock_get_version : MagicMock ,
233
+ mock_extract_cypher : MagicMock ,
230
234
driver : MagicMock ,
231
235
llm : MagicMock ,
232
236
neo4j_record : MagicMock ,
@@ -254,9 +258,11 @@ def test_t2c_retriever_initialization_with_custom_prompt_and_schema_and_examples
254
258
llm .invoke .assert_called_once_with ("This is a custom prompt. test" )
255
259
256
260
261
+ @patch ("neo4j_graphrag.retrievers.text2cypher.extract_cypher" )
257
262
@patch ("neo4j_graphrag.retrievers.base.get_version" )
258
263
def test_t2c_retriever_initialization_with_custom_prompt_and_schema_and_examples_for_prompt_params (
259
264
mock_get_version : MagicMock ,
265
+ mock_extract_cypher : MagicMock ,
260
266
driver : MagicMock ,
261
267
llm : MagicMock ,
262
268
neo4j_record : MagicMock ,
@@ -286,9 +292,11 @@ def test_t2c_retriever_initialization_with_custom_prompt_and_schema_and_examples
286
292
)
287
293
288
294
295
+ @patch ("neo4j_graphrag.retrievers.text2cypher.extract_cypher" )
289
296
@patch ("neo4j_graphrag.retrievers.base.get_version" )
290
297
def test_t2c_retriever_initialization_with_custom_prompt_and_unused_schema_and_examples (
291
298
mock_get_version : MagicMock ,
299
+ mock_extract_cypher : MagicMock ,
292
300
driver : MagicMock ,
293
301
llm : MagicMock ,
294
302
neo4j_record : MagicMock ,
@@ -321,9 +329,13 @@ def test_t2c_retriever_initialization_with_custom_prompt_and_unused_schema_and_e
321
329
)
322
330
323
331
332
+ @patch ("neo4j_graphrag.retrievers.text2cypher.extract_cypher" )
324
333
@patch ("neo4j_graphrag.retrievers.base.get_version" )
325
334
def test_t2c_retriever_invalid_custom_prompt_type (
326
- mock_get_version : MagicMock , driver : MagicMock , llm : MagicMock
335
+ mock_get_version : MagicMock ,
336
+ mock_extract_cypher : MagicMock ,
337
+ driver : MagicMock ,
338
+ llm : MagicMock ,
327
339
) -> None :
328
340
mock_get_version .return_value = ((5 , 23 , 0 ), False , False )
329
341
with pytest .raises (RetrieverInitializationError ) as exc_info :
@@ -336,9 +348,11 @@ def test_t2c_retriever_invalid_custom_prompt_type(
336
348
assert "Input should be a valid string" in str (exc_info .value )
337
349
338
350
351
+ @patch ("neo4j_graphrag.retrievers.text2cypher.extract_cypher" )
339
352
@patch ("neo4j_graphrag.retrievers.base.get_version" )
340
353
def test_t2c_retriever_with_custom_prompt_prompt_params (
341
354
mock_get_version : MagicMock ,
355
+ mock_extract_cypher : MagicMock ,
342
356
driver : MagicMock ,
343
357
llm : MagicMock ,
344
358
neo4j_record : MagicMock ,
@@ -361,9 +375,11 @@ def test_t2c_retriever_with_custom_prompt_prompt_params(
361
375
)
362
376
363
377
378
+ @patch ("neo4j_graphrag.retrievers.text2cypher.extract_cypher" )
364
379
@patch ("neo4j_graphrag.retrievers.base.get_version" )
365
380
def test_t2c_retriever_with_custom_prompt_bad_prompt_params (
366
381
mock_get_version : MagicMock ,
382
+ mock_extract_cypher : MagicMock ,
367
383
driver : MagicMock ,
368
384
llm : MagicMock ,
369
385
neo4j_record : MagicMock ,
@@ -392,11 +408,13 @@ def test_t2c_retriever_with_custom_prompt_bad_prompt_params(
392
408
)
393
409
394
410
411
+ @patch ("neo4j_graphrag.retrievers.text2cypher.extract_cypher" )
395
412
@patch ("neo4j_graphrag.retrievers.base.get_version" )
396
413
@patch ("neo4j_graphrag.retrievers.text2cypher.get_schema" )
397
414
def test_t2c_retriever_with_custom_prompt_and_schema (
398
415
get_schema_mock : MagicMock ,
399
416
mock_get_version : MagicMock ,
417
+ mock_extract_cypher : MagicMock ,
400
418
driver : MagicMock ,
401
419
llm : MagicMock ,
402
420
neo4j_record : MagicMock ,
@@ -419,3 +437,67 @@ def test_t2c_retriever_with_custom_prompt_and_schema(
419
437
420
438
get_schema_mock .assert_not_called ()
421
439
llm .invoke .assert_called_once_with ("""This is a custom prompt. test """ )
440
+
441
+
442
+ @pytest .mark .parametrize (
443
+ "description, cypher_query, expected_output" ,
444
+ [
445
+ ("No changes" , "MATCH (n) RETURN n;" , "MATCH (n) RETURN n;" ),
446
+ (
447
+ "Surrounded by backticks" ,
448
+ "Cypher query: ```MATCH (n) RETURN n;```" ,
449
+ "MATCH (n) RETURN n;" ,
450
+ ),
451
+ (
452
+ "Spaces in label" ,
453
+ "Cypher query: ```MATCH (n: Label With Spaces ) RETURN n;```" ,
454
+ "MATCH (n:`Label With Spaces`) RETURN n;" ,
455
+ ),
456
+ (
457
+ "No spaces in label" ,
458
+ "Cypher query: ```MATCH (n: LabelWithNoSpaces ) RETURN n;```" ,
459
+ "MATCH (n: LabelWithNoSpaces ) RETURN n;" ,
460
+ ),
461
+ (
462
+ "Backticks in label" ,
463
+ "Cypher query: ```MATCH (n: `LabelWithBackticks` ) RETURN n;```" ,
464
+ "MATCH (n: `LabelWithBackticks` ) RETURN n;" ,
465
+ ),
466
+ (
467
+ "Spaces in property key" ,
468
+ "Cypher query: ```MATCH (n: { prop 1: 1, prop 2: 2 }) RETURN n;```" ,
469
+ "MATCH (n: { `prop 1`: 1, `prop 2`: 2 }) RETURN n;" ,
470
+ ),
471
+ (
472
+ "No spaces in property key" ,
473
+ "Cypher query: ```MATCH (n: { prop1: 1, prop2: 2 }) RETURN n;```" ,
474
+ "MATCH (n: { prop1: 1, prop2: 2 }) RETURN n;" ,
475
+ ),
476
+ (
477
+ "Backticks in property key" ,
478
+ "Cypher query: ```MATCH (n: { `prop 1`: 1, `prop 2`: 2 }) RETURN n;```" ,
479
+ "MATCH (n: { `prop 1`: 1, `prop 2`: 2 }) RETURN n;" ,
480
+ ),
481
+ (
482
+ "Spaces in relationship type" ,
483
+ "Cypher query: ```MATCH (n)-[: Relationship With Spaces ]->(m) RETURN n, m;```" ,
484
+ "MATCH (n)-[:`Relationship With Spaces`]->(m) RETURN n, m;" ,
485
+ ),
486
+ (
487
+ "No spaces in relationship type" ,
488
+ "Cypher query: ```MATCH (n)-[ : RelationshipWithNoSpaces ]->(m) RETURN n, m;```" ,
489
+ "MATCH (n)-[ : RelationshipWithNoSpaces ]->(m) RETURN n, m;" ,
490
+ ),
491
+ (
492
+ "Backticks in relationship type" ,
493
+ "Cypher query: ```MATCH (n)-[ : `RelationshipWithBackticks` ]->(m) RETURN n, m;```" ,
494
+ "MATCH (n)-[ : `RelationshipWithBackticks` ]->(m) RETURN n, m;" ,
495
+ ),
496
+ ],
497
+ )
498
+ def test_extract_cypher (
499
+ description : str , cypher_query : str , expected_output : str
500
+ ) -> None :
501
+ assert (
502
+ extract_cypher (cypher_query ) == expected_output
503
+ ), f"Failed test case: { description } "
0 commit comments