Skip to content

Commit 439b4b9

Browse files
authored
Added extract_cypher function (#277)
* Added extract_cypher function * Updated CHANGELOG
1 parent 4dae4eb commit 439b4b9

File tree

3 files changed

+129
-3
lines changed

3 files changed

+129
-3
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
- Utility functions to retrieve metadata for vector and full-text indexes.
88
- Support for effective_search_ratio parameter in vector and hybrid searches.
99
- Introduced upsert_vectors utility function for batch upserting embeddings to vector indexes.
10+
- Introduced `extract_cypher` function to enhance Cypher query extraction and formatting in `Text2CypherRetriever`.
1011

1112
### Changed
1213

src/neo4j_graphrag/retrievers/text2cypher.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
import logging
18+
import re
1819
from typing import Any, Callable, Dict, Optional
1920

2021
import neo4j
@@ -44,6 +45,48 @@
4445
logger = logging.getLogger(__name__)
4546

4647

48+
def extract_cypher(text: str) -> str:
49+
"""Extract and format Cypher query from text, handling code blocks and special characters.
50+
51+
This function performs two main operations:
52+
1. Extracts Cypher code from within triple backticks (```), if present
53+
2. Automatically adds backtick quotes around multi-word identifiers:
54+
- Node labels (e.g., ":Data Science" becomes ":`Data Science`")
55+
- Property keys (e.g., "first name:" becomes "`first name`:")
56+
- Relationship types (e.g., "[:WORKS WITH]" becomes "[:`WORKS WITH`]")
57+
58+
Args:
59+
text (str): Raw text that may contain Cypher code, either within triple
60+
backticks or as plain text.
61+
62+
Returns:
63+
str: Properly formatted Cypher query with correct backtick quoting.
64+
"""
65+
# Extract Cypher code enclosed in triple backticks
66+
pattern = r"```(.*?)```"
67+
matches = re.findall(pattern, text, re.DOTALL)
68+
cypher_query = matches[0] if matches else text
69+
# Quote node labels in backticks if they contain spaces and are not already quoted
70+
cypher_query = re.sub(
71+
r":\s*(?!`\s*)(\s*)([a-zA-Z0-9_]+(?:\s+[a-zA-Z0-9_]+)+)(?!\s*`)(\s*)",
72+
r":`\2`",
73+
cypher_query,
74+
)
75+
# Quote property keys in backticks if they contain spaces and are not already quoted
76+
cypher_query = re.sub(
77+
r"([,{]\s*)(?!`)([a-zA-Z0-9_]+(?:\s+[a-zA-Z0-9_]+)+)(?!`)(\s*:)",
78+
r"\1`\2`\3",
79+
cypher_query,
80+
)
81+
# Quote relationship types in backticks if they contain spaces and are not already quoted
82+
cypher_query = re.sub(
83+
r"(\[\s*[a-zA-Z0-9_]*\s*:\s*)(?!`)([a-zA-Z0-9_]+(?:\s+[a-zA-Z0-9_]+)+)(?!`)(\s*(?:\]|-))",
84+
r"\1`\2`\3",
85+
cypher_query,
86+
)
87+
return cypher_query
88+
89+
4790
class Text2CypherRetriever(Retriever):
4891
"""
4992
Allows for the retrieval of records from a Neo4j database using natural language.
@@ -168,7 +211,7 @@ def get_search_results(
168211

169212
try:
170213
llm_result = self.llm.invoke(prompt)
171-
t2c_query = llm_result.content
214+
t2c_query = extract_cypher(llm_result.content)
172215
logger.debug("Text2CypherRetriever Cypher query: %s", t2c_query)
173216
records, _, _ = self.driver.execute_query(
174217
query_=t2c_query,

tests/unit/retrievers/test_text2cypher.py

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
)
2626
from neo4j_graphrag.generation.prompts import Text2CypherTemplate
2727
from neo4j_graphrag.llm import LLMResponse
28-
from neo4j_graphrag.retrievers import Text2CypherRetriever
28+
from neo4j_graphrag.retrievers.text2cypher import Text2CypherRetriever, extract_cypher
2929
from neo4j_graphrag.types import RetrieverResult, RetrieverResultItem
3030

3131

@@ -204,9 +204,11 @@ def test_t2c_retriever_with_result_format_function(
204204
)
205205

206206

207+
@patch("neo4j_graphrag.retrievers.text2cypher.extract_cypher")
207208
@patch("neo4j_graphrag.retrievers.base.get_version")
208209
def test_t2c_retriever_initialization_with_custom_prompt(
209210
mock_get_version: MagicMock,
211+
mock_extract_cypher: MagicMock,
210212
driver: MagicMock,
211213
llm: MagicMock,
212214
neo4j_record: MagicMock,
@@ -224,9 +226,11 @@ def test_t2c_retriever_initialization_with_custom_prompt(
224226
llm.invoke.assert_called_once_with("This is a custom prompt. test")
225227

226228

229+
@patch("neo4j_graphrag.retrievers.text2cypher.extract_cypher")
227230
@patch("neo4j_graphrag.retrievers.base.get_version")
228231
def test_t2c_retriever_initialization_with_custom_prompt_and_schema_and_examples(
229232
mock_get_version: MagicMock,
233+
mock_extract_cypher: MagicMock,
230234
driver: MagicMock,
231235
llm: MagicMock,
232236
neo4j_record: MagicMock,
@@ -254,9 +258,11 @@ def test_t2c_retriever_initialization_with_custom_prompt_and_schema_and_examples
254258
llm.invoke.assert_called_once_with("This is a custom prompt. test")
255259

256260

261+
@patch("neo4j_graphrag.retrievers.text2cypher.extract_cypher")
257262
@patch("neo4j_graphrag.retrievers.base.get_version")
258263
def test_t2c_retriever_initialization_with_custom_prompt_and_schema_and_examples_for_prompt_params(
259264
mock_get_version: MagicMock,
265+
mock_extract_cypher: MagicMock,
260266
driver: MagicMock,
261267
llm: MagicMock,
262268
neo4j_record: MagicMock,
@@ -286,9 +292,11 @@ def test_t2c_retriever_initialization_with_custom_prompt_and_schema_and_examples
286292
)
287293

288294

295+
@patch("neo4j_graphrag.retrievers.text2cypher.extract_cypher")
289296
@patch("neo4j_graphrag.retrievers.base.get_version")
290297
def test_t2c_retriever_initialization_with_custom_prompt_and_unused_schema_and_examples(
291298
mock_get_version: MagicMock,
299+
mock_extract_cypher: MagicMock,
292300
driver: MagicMock,
293301
llm: MagicMock,
294302
neo4j_record: MagicMock,
@@ -321,9 +329,13 @@ def test_t2c_retriever_initialization_with_custom_prompt_and_unused_schema_and_e
321329
)
322330

323331

332+
@patch("neo4j_graphrag.retrievers.text2cypher.extract_cypher")
324333
@patch("neo4j_graphrag.retrievers.base.get_version")
325334
def test_t2c_retriever_invalid_custom_prompt_type(
326-
mock_get_version: MagicMock, driver: MagicMock, llm: MagicMock
335+
mock_get_version: MagicMock,
336+
mock_extract_cypher: MagicMock,
337+
driver: MagicMock,
338+
llm: MagicMock,
327339
) -> None:
328340
mock_get_version.return_value = ((5, 23, 0), False, False)
329341
with pytest.raises(RetrieverInitializationError) as exc_info:
@@ -336,9 +348,11 @@ def test_t2c_retriever_invalid_custom_prompt_type(
336348
assert "Input should be a valid string" in str(exc_info.value)
337349

338350

351+
@patch("neo4j_graphrag.retrievers.text2cypher.extract_cypher")
339352
@patch("neo4j_graphrag.retrievers.base.get_version")
340353
def test_t2c_retriever_with_custom_prompt_prompt_params(
341354
mock_get_version: MagicMock,
355+
mock_extract_cypher: MagicMock,
342356
driver: MagicMock,
343357
llm: MagicMock,
344358
neo4j_record: MagicMock,
@@ -361,9 +375,11 @@ def test_t2c_retriever_with_custom_prompt_prompt_params(
361375
)
362376

363377

378+
@patch("neo4j_graphrag.retrievers.text2cypher.extract_cypher")
364379
@patch("neo4j_graphrag.retrievers.base.get_version")
365380
def test_t2c_retriever_with_custom_prompt_bad_prompt_params(
366381
mock_get_version: MagicMock,
382+
mock_extract_cypher: MagicMock,
367383
driver: MagicMock,
368384
llm: MagicMock,
369385
neo4j_record: MagicMock,
@@ -392,11 +408,13 @@ def test_t2c_retriever_with_custom_prompt_bad_prompt_params(
392408
)
393409

394410

411+
@patch("neo4j_graphrag.retrievers.text2cypher.extract_cypher")
395412
@patch("neo4j_graphrag.retrievers.base.get_version")
396413
@patch("neo4j_graphrag.retrievers.text2cypher.get_schema")
397414
def test_t2c_retriever_with_custom_prompt_and_schema(
398415
get_schema_mock: MagicMock,
399416
mock_get_version: MagicMock,
417+
mock_extract_cypher: MagicMock,
400418
driver: MagicMock,
401419
llm: MagicMock,
402420
neo4j_record: MagicMock,
@@ -419,3 +437,67 @@ def test_t2c_retriever_with_custom_prompt_and_schema(
419437

420438
get_schema_mock.assert_not_called()
421439
llm.invoke.assert_called_once_with("""This is a custom prompt. test """)
440+
441+
442+
@pytest.mark.parametrize(
443+
"description, cypher_query, expected_output",
444+
[
445+
("No changes", "MATCH (n) RETURN n;", "MATCH (n) RETURN n;"),
446+
(
447+
"Surrounded by backticks",
448+
"Cypher query: ```MATCH (n) RETURN n;```",
449+
"MATCH (n) RETURN n;",
450+
),
451+
(
452+
"Spaces in label",
453+
"Cypher query: ```MATCH (n: Label With Spaces ) RETURN n;```",
454+
"MATCH (n:`Label With Spaces`) RETURN n;",
455+
),
456+
(
457+
"No spaces in label",
458+
"Cypher query: ```MATCH (n: LabelWithNoSpaces ) RETURN n;```",
459+
"MATCH (n: LabelWithNoSpaces ) RETURN n;",
460+
),
461+
(
462+
"Backticks in label",
463+
"Cypher query: ```MATCH (n: `LabelWithBackticks` ) RETURN n;```",
464+
"MATCH (n: `LabelWithBackticks` ) RETURN n;",
465+
),
466+
(
467+
"Spaces in property key",
468+
"Cypher query: ```MATCH (n: { prop 1: 1, prop 2: 2 }) RETURN n;```",
469+
"MATCH (n: { `prop 1`: 1, `prop 2`: 2 }) RETURN n;",
470+
),
471+
(
472+
"No spaces in property key",
473+
"Cypher query: ```MATCH (n: { prop1: 1, prop2: 2 }) RETURN n;```",
474+
"MATCH (n: { prop1: 1, prop2: 2 }) RETURN n;",
475+
),
476+
(
477+
"Backticks in property key",
478+
"Cypher query: ```MATCH (n: { `prop 1`: 1, `prop 2`: 2 }) RETURN n;```",
479+
"MATCH (n: { `prop 1`: 1, `prop 2`: 2 }) RETURN n;",
480+
),
481+
(
482+
"Spaces in relationship type",
483+
"Cypher query: ```MATCH (n)-[: Relationship With Spaces ]->(m) RETURN n, m;```",
484+
"MATCH (n)-[:`Relationship With Spaces`]->(m) RETURN n, m;",
485+
),
486+
(
487+
"No spaces in relationship type",
488+
"Cypher query: ```MATCH (n)-[ : RelationshipWithNoSpaces ]->(m) RETURN n, m;```",
489+
"MATCH (n)-[ : RelationshipWithNoSpaces ]->(m) RETURN n, m;",
490+
),
491+
(
492+
"Backticks in relationship type",
493+
"Cypher query: ```MATCH (n)-[ : `RelationshipWithBackticks` ]->(m) RETURN n, m;```",
494+
"MATCH (n)-[ : `RelationshipWithBackticks` ]->(m) RETURN n, m;",
495+
),
496+
],
497+
)
498+
def test_extract_cypher(
499+
description: str, cypher_query: str, expected_output: str
500+
) -> None:
501+
assert (
502+
extract_cypher(cypher_query) == expected_output
503+
), f"Failed test case: {description}"

0 commit comments

Comments
 (0)