Skip to content

Commit 1921dd5

Browse files
authored
Fix missing inclusion of neo4j_schema in custom prompt generation (#256)
* Fix missing inclusion of neo4j_schema in custom prompt generation * Add documentation for schema and examples in prompt_params
1 parent d7d6674 commit 1921dd5

File tree

2 files changed

+76
-3
lines changed

2 files changed

+76
-3
lines changed

src/neo4j_graphrag/retrievers/text2cypher.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ class Text2CypherRetriever(Retriever):
5555
llm (neo4j_graphrag.generation.llm.LLMInterface): LLM object to generate the Cypher query.
5656
neo4j_schema (Optional[str]): Neo4j schema used to generate the Cypher query.
5757
examples (Optional[list[str], optional): Optional user input/query pairs for the LLM to use as examples.
58-
custom_prompt (Optional[str]): Optional custom prompt to use instead of auto generated prompt. Will not include the neo4j_schema or examples args, if provided.
58+
custom_prompt (Optional[str]): Optional custom prompt to use instead of auto generated prompt. Will include the neo4j_schema for schema and examples for examples prompt parameters, if they are provided.
5959
6060
Raises:
6161
RetrieverInitializationError: If validation of the input arguments fail.
@@ -99,7 +99,13 @@ def __init__(
9999
self.result_formatter = validated_data.result_formatter
100100
self.custom_prompt = validated_data.custom_prompt
101101
if validated_data.custom_prompt:
102-
neo4j_schema = ""
102+
if (
103+
validated_data.neo4j_schema_model
104+
and validated_data.neo4j_schema_model.neo4j_schema
105+
):
106+
neo4j_schema = validated_data.neo4j_schema_model.neo4j_schema
107+
else:
108+
neo4j_schema = ""
103109
else:
104110
if (
105111
validated_data.neo4j_schema_model
@@ -124,7 +130,7 @@ def get_search_results(
124130
125131
Args:
126132
query_text (str): The natural language query used to search the Neo4j database.
127-
prompt_params (Dict[str, Any]): additional values to inject into the custom prompt, if it is provided. Example: {'schema': 'this is the graph schema'}
133+
prompt_params (Dict[str, Any]): additional values to inject into the custom prompt, if it is provided. If the schema or examples parameter is specified, it will overwrite the corresponding value passed during initialization. Example: {'schema': 'this is the graph schema'}
128134
129135
Raises:
130136
SearchValidationError: If validation of the input arguments fail.

tests/unit/retrievers/test_text2cypher.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,73 @@ def test_t2c_retriever_initialization_with_custom_prompt_and_schema_and_examples
254254
llm.invoke.assert_called_once_with("This is a custom prompt. test")
255255

256256

257+
@patch("neo4j_graphrag.retrievers.base.get_version")
258+
def test_t2c_retriever_initialization_with_custom_prompt_and_schema_and_examples_for_prompt_params(
259+
mock_get_version: MagicMock,
260+
driver: MagicMock,
261+
llm: MagicMock,
262+
neo4j_record: MagicMock,
263+
) -> None:
264+
mock_get_version.return_value = ((5, 23, 0), False, False)
265+
prompt = "This is a custom prompt. {query_text} {schema} {examples}"
266+
neo4j_schema = "dummy-schema"
267+
examples = ["example-1", "example-2"]
268+
269+
retriever = Text2CypherRetriever(
270+
driver=driver,
271+
llm=llm,
272+
custom_prompt=prompt,
273+
neo4j_schema=neo4j_schema,
274+
examples=examples,
275+
)
276+
277+
driver.execute_query.return_value = (
278+
[neo4j_record],
279+
None,
280+
None,
281+
)
282+
retriever.search(query_text="test")
283+
284+
llm.invoke.assert_called_once_with(
285+
"This is a custom prompt. test dummy-schema example-1\nexample-2"
286+
)
287+
288+
289+
@patch("neo4j_graphrag.retrievers.base.get_version")
290+
def test_t2c_retriever_initialization_with_custom_prompt_and_unused_schema_and_examples(
291+
mock_get_version: MagicMock,
292+
driver: MagicMock,
293+
llm: MagicMock,
294+
neo4j_record: MagicMock,
295+
) -> None:
296+
mock_get_version.return_value = ((5, 23, 0), False, False)
297+
prompt = "This is a custom prompt. {query_text} {schema} {examples}"
298+
neo4j_schema = "dummy-schema"
299+
examples = ["example-1", "example-2"]
300+
301+
retriever = Text2CypherRetriever(
302+
driver=driver,
303+
llm=llm,
304+
custom_prompt=prompt,
305+
neo4j_schema=neo4j_schema,
306+
examples=examples,
307+
)
308+
309+
driver.execute_query.return_value = (
310+
[neo4j_record],
311+
None,
312+
None,
313+
)
314+
retriever.search(
315+
query_text="test",
316+
prompt_params={"schema": "another-dummy-schema", "examples": "another-example"},
317+
)
318+
319+
llm.invoke.assert_called_once_with(
320+
"This is a custom prompt. test another-dummy-schema another-example"
321+
)
322+
323+
257324
@patch("neo4j_graphrag.retrievers.base.get_version")
258325
def test_t2c_retriever_invalid_custom_prompt_type(
259326
mock_get_version: MagicMock, driver: MagicMock, llm: MagicMock

0 commit comments

Comments
 (0)