Skip to content

Commit b4cde39

Browse files
add custom prompt option to text2cypher, tested (#84)
* Update CHANGELOG.md * Update CHANGELOG.md * add custom prompt option to text2cypher, tested * add test for RetrieverInitializationError * Update tests/unit/retrievers/test_text2cypher.py Co-authored-by: Alex Thomas <alexthomas93@users.noreply.github.com> * Update tests/unit/retrievers/test_text2cypher.py Co-authored-by: Alex Thomas <alexthomas93@users.noreply.github.com> * Update tests/unit/retrievers/test_text2cypher.py Co-authored-by: Alex Thomas <alexthomas93@users.noreply.github.com> * Update tests/unit/retrievers/test_text2cypher.py Co-authored-by: Alex Thomas <alexthomas93@users.noreply.github.com> * Update tests/unit/retrievers/test_text2cypher.py Co-authored-by: Alex Thomas <alexthomas93@users.noreply.github.com> * Pre-commit fix --------- Co-authored-by: Alex Thomas <alexthomas93@users.noreply.github.com>
1 parent 4eec1fa commit b4cde39

File tree

4 files changed

+105
-17
lines changed

4 files changed

+105
-17
lines changed

CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
## Next
44

5+
### Added
6+
- Add optional custom_prompt arg to the Text2CypherRetriever class.
7+
58
## 0.3.1
69

710
### Fixed

src/neo4j_genai/retrievers/text2cypher.py

+23-11
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ class Text2CypherRetriever(Retriever):
5555
llm (neo4j_genai.generation.llm.LLMInterface): LLM object to generate the Cypher query.
5656
neo4j_schema (Optional[str]): Neo4j schema used to generate the Cypher query.
5757
examples (Optional[list[str], optional): Optional user input/query pairs for the LLM to use as examples.
58+
custom_prompt (Optional[str]): Optional custom prompt to use instead of auto generated prompt. Will not include the neo4j_schema or examples args, if provided.
5859
5960
Raises:
6061
RetrieverInitializationError: If validation of the input arguments fail.
@@ -69,6 +70,7 @@ def __init__(
6970
result_formatter: Optional[
7071
Callable[[neo4j.Record], RetrieverResultItem]
7172
] = None,
73+
custom_prompt: Optional[str] = None,
7274
) -> None:
7375
try:
7476
driver_model = Neo4jDriverModel(driver=driver)
@@ -82,6 +84,7 @@ def __init__(
8284
neo4j_schema_model=neo4j_schema_model,
8385
examples=examples,
8486
result_formatter=result_formatter,
87+
custom_prompt=custom_prompt,
8588
)
8689
except ValidationError as e:
8790
raise RetrieverInitializationError(e.errors()) from e
@@ -90,12 +93,17 @@ def __init__(
9093
self.llm = validated_data.llm_model.llm
9194
self.examples = validated_data.examples
9295
self.result_formatter = validated_data.result_formatter
96+
self.custom_prompt = validated_data.custom_prompt
9397
try:
94-
self.neo4j_schema = (
95-
validated_data.neo4j_schema_model.neo4j_schema
96-
if validated_data.neo4j_schema_model
97-
else get_schema(validated_data.driver_model.driver)
98-
)
98+
if (
99+
not validated_data.custom_prompt
100+
): # don't need schema for a custom prompt
101+
self.neo4j_schema = (
102+
validated_data.neo4j_schema_model.neo4j_schema
103+
if validated_data.neo4j_schema_model
104+
else get_schema(validated_data.driver_model.driver)
105+
)
106+
99107
except (Neo4jError, DriverError) as e:
100108
error_message = getattr(e, "message", str(e))
101109
raise SchemaFetchError(
@@ -124,12 +132,16 @@ def get_search_results(
124132
except ValidationError as e:
125133
raise SearchValidationError(e.errors()) from e
126134

127-
prompt_template = Text2CypherTemplate()
128-
prompt = prompt_template.format(
129-
schema=self.neo4j_schema,
130-
examples="\n".join(self.examples) if self.examples else "",
131-
query=validated_data.query_text,
132-
)
135+
if not self.custom_prompt:
136+
prompt_template = Text2CypherTemplate()
137+
prompt = prompt_template.format(
138+
schema=self.neo4j_schema,
139+
examples="\n".join(self.examples) if self.examples else "",
140+
query=validated_data.query_text,
141+
)
142+
else:
143+
prompt = self.custom_prompt
144+
133145
logger.debug("Text2CypherRetriever prompt: %s", prompt)
134146

135147
try:

src/neo4j_genai/types.py

+1
Original file line numberDiff line numberDiff line change
@@ -240,3 +240,4 @@ class Text2CypherRetrieverModel(BaseModel):
240240
neo4j_schema_model: Optional[Neo4jSchemaModel] = None
241241
examples: Optional[list[str]] = None
242242
result_formatter: Optional[Callable[[neo4j.Record], RetrieverResultItem]] = None
243+
custom_prompt: Optional[str] = None

tests/unit/retrievers/test_text2cypher.py

+78-6
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15+
import logging
1516
from unittest.mock import MagicMock, patch
1617

1718
import pytest
@@ -63,7 +64,11 @@ def test_t2c_retriever_invalid_neo4j_schema(
6364
_verify_version_mock: MagicMock, driver: MagicMock, llm: MagicMock
6465
) -> None:
6566
with pytest.raises(RetrieverInitializationError) as exc_info:
66-
Text2CypherRetriever(driver=driver, llm=llm, neo4j_schema=42) # type: ignore
67+
Text2CypherRetriever(
68+
driver=driver,
69+
llm=llm,
70+
neo4j_schema=42, # type: ignore[arg-type, unused-ignore]
71+
)
6772

6873
assert "neo4j_schema" in str(exc_info.value)
6974
assert "Input should be a valid string" in str(exc_info.value)
@@ -92,7 +97,7 @@ def test_t2c_retriever_invalid_search_examples(
9297
driver=driver,
9398
llm=llm,
9499
neo4j_schema="dummy-text",
95-
examples=42, # type: ignore
100+
examples=42, # type: ignore[arg-type, unused-ignore]
96101
)
97102

98103
assert "examples" in str(exc_info.value)
@@ -113,8 +118,8 @@ def test_t2c_retriever_happy_path(
113118
retriever = Text2CypherRetriever(
114119
driver=driver, llm=llm, neo4j_schema=neo4j_schema, examples=examples
115120
)
116-
retriever.llm.invoke.return_value = LLMResponse(content=t2c_query)
117-
retriever.driver.execute_query.return_value = ( # type: ignore
121+
llm.invoke.return_value = LLMResponse(content=t2c_query)
122+
driver.execute_query.return_value = (
118123
[neo4j_record],
119124
None,
120125
None,
@@ -126,8 +131,8 @@ def test_t2c_retriever_happy_path(
126131
query=query_text,
127132
)
128133
retriever.search(query_text=query_text)
129-
retriever.llm.invoke.assert_called_once_with(prompt)
130-
retriever.driver.execute_query.assert_called_once_with(query_=t2c_query) # type: ignore
134+
llm.invoke.assert_called_once_with(prompt)
135+
driver.execute_query.assert_called_once_with(query_=t2c_query)
131136

132137

133138
@patch("neo4j_genai.retrievers.Text2CypherRetriever._verify_version")
@@ -178,3 +183,70 @@ def test_t2c_retriever_with_result_format_function(
178183
],
179184
metadata={"cypher": t2c_query, "__retriever": "Text2CypherRetriever"},
180185
)
186+
187+
188+
@pytest.mark.usefixtures("caplog")
189+
@patch("neo4j_genai.retrievers.base.Retriever._verify_version")
190+
def test_t2c_retriever_initialization_with_custom_prompt(
191+
_verify_version_mock: MagicMock,
192+
driver: MagicMock,
193+
llm: MagicMock,
194+
neo4j_record: MagicMock,
195+
caplog: pytest.LogCaptureFixture,
196+
) -> None:
197+
prompt = "This is a custom prompt."
198+
with caplog.at_level(logging.DEBUG):
199+
retriever = Text2CypherRetriever(driver=driver, llm=llm, custom_prompt=prompt)
200+
driver.execute_query.return_value = (
201+
[neo4j_record],
202+
None,
203+
None,
204+
)
205+
retriever.search(query_text="test")
206+
207+
assert f"Text2CypherRetriever prompt: {prompt}" in caplog.text
208+
209+
210+
@pytest.mark.usefixtures("caplog")
211+
@patch("neo4j_genai.retrievers.base.Retriever._verify_version")
212+
def test_t2c_retriever_initialization_with_custom_prompt_and_schema_and_examples(
213+
_verify_version_mock: MagicMock,
214+
driver: MagicMock,
215+
llm: MagicMock,
216+
neo4j_record: MagicMock,
217+
caplog: pytest.LogCaptureFixture,
218+
) -> None:
219+
prompt = "This is another custom prompt."
220+
neo4j_schema = "dummy-schema"
221+
examples = ["example-1", "example-2"]
222+
with caplog.at_level(logging.DEBUG):
223+
retriever = Text2CypherRetriever(
224+
driver=driver,
225+
llm=llm,
226+
custom_prompt=prompt,
227+
neo4j_schema=neo4j_schema,
228+
examples=examples,
229+
)
230+
231+
driver.execute_query.return_value = (
232+
[neo4j_record],
233+
None,
234+
None,
235+
)
236+
retriever.search(query_text="test")
237+
238+
assert f"Text2CypherRetriever prompt: {prompt}" in caplog.text
239+
240+
241+
@patch("neo4j_genai.retrievers.Text2CypherRetriever._verify_version")
242+
def test_t2c_retriever_invalid_custom_prompt_type(
243+
_verify_version_mock: MagicMock, driver: MagicMock, llm: MagicMock
244+
) -> None:
245+
with pytest.raises(RetrieverInitializationError) as exc_info:
246+
Text2CypherRetriever(
247+
driver=driver,
248+
llm=llm,
249+
custom_prompt=42, # type: ignore[arg-type, unused-ignore]
250+
)
251+
252+
assert "Input should be a valid string" in str(exc_info.value)

0 commit comments

Comments
 (0)