Skip to content

Commit ab54d7b

Browse files
authored
Move few-shot examples in Text2CypherRetriever to constructor (neo4j#62)
* Moved the few-shot examples in the Text2CypherRetriever to the constructor * Updated CHANGELOG
1 parent 491c1ff commit ab54d7b

File tree

5 files changed

+42
-27
lines changed

5 files changed

+42
-27
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
### Fixed
1313

1414
- Removed Pinecone and Weaviate retrievers from **init**.py to prevent ImportError when optional dependencies are not installed.
15+
- Moved few-shot examples in `Text2CypherRetriever` to the constructor for better initialization and usage. Updated unit tests and example script accordingly.
1516

1617
## 0.2.0a5
1718

examples/text2cypher_search.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,19 @@
2828
(:Person)-[:REVIEWED]->(:Movie)
2929
"""
3030

31-
# Initialize the retriever
32-
retriever = Text2CypherRetriever(driver=driver, llm=llm, neo4j_schema=neo4j_schema) # type: ignore
33-
3431
# (Optional) Provide user input/query pairs for the LLM to use as examples
3532
examples = [
3633
"USER INPUT: 'Which actors starred in the Matrix?' QUERY: MATCH (p:Person)-[:ACTED_IN]->(m:Movie) WHERE m.title = 'The Matrix' RETURN p.name"
3734
]
3835

36+
# Initialize the retriever
37+
retriever = Text2CypherRetriever(
38+
driver=driver,
39+
llm=llm, # type: ignore
40+
neo4j_schema=neo4j_schema,
41+
examples=examples,
42+
)
43+
3944
# Generate a Cypher query using the LLM, send it to the Neo4j database, and return the results
4045
query_text = "Which movies did Hugo Weaving star in?"
41-
print(retriever.search(query_text=query_text, examples=examples))
46+
print(retriever.search(query_text=query_text))

src/neo4j_genai/retrievers/text2cypher.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121

2222
from neo4j_genai.exceptions import (
2323
RetrieverInitializationError,
24+
SchemaFetchError,
2425
SearchValidationError,
2526
Text2CypherRetrievalError,
26-
SchemaFetchError,
2727
)
2828
from neo4j_genai.llm import LLM
2929
from neo4j_genai.prompts import TEXT2CYPHER_PROMPT
@@ -33,9 +33,9 @@
3333
LLMModel,
3434
Neo4jDriverModel,
3535
Neo4jSchemaModel,
36+
RawSearchResult,
3637
Text2CypherRetrieverModel,
3738
Text2CypherSearchModel,
38-
RawSearchResult,
3939
)
4040

4141
logger = logging.getLogger(__name__)
@@ -51,13 +51,18 @@ class Text2CypherRetriever(Retriever):
5151
driver (neo4j.driver): The Neo4j Python driver.
5252
llm (neo4j_genai.llm.LLM): LLM object to generate the Cypher query.
5353
neo4j_schema (Optional[str]): Neo4j schema used to generate the Cypher query.
54+
examples (Optional[list[str], optional): Optional user input/query pairs for the LLM to use as examples.
5455
5556
Raises:
5657
RetrieverInitializationError: If validation of the input arguments fail.
5758
"""
5859

5960
def __init__(
60-
self, driver: neo4j.Driver, llm: LLM, neo4j_schema: Optional[str] = None
61+
self,
62+
driver: neo4j.Driver,
63+
llm: LLM,
64+
neo4j_schema: Optional[str] = None,
65+
examples: Optional[list[str]] = None,
6166
) -> None:
6267
try:
6368
driver_model = Neo4jDriverModel(driver=driver)
@@ -69,12 +74,14 @@ def __init__(
6974
driver_model=driver_model,
7075
llm_model=llm_model,
7176
neo4j_schema_model=neo4j_schema_model,
77+
examples=examples,
7278
)
7379
except ValidationError as e:
7480
raise RetrieverInitializationError(e.errors())
7581

7682
super().__init__(validated_data.driver_model.driver)
7783
self.llm = validated_data.llm_model.llm
84+
self.examples = validated_data.examples
7885
try:
7986
self.neo4j_schema = (
8087
validated_data.neo4j_schema_model.neo4j_schema
@@ -88,14 +95,14 @@ def __init__(
8895
)
8996

9097
def _get_search_results(
91-
self, query_text: str, examples: Optional[list[str]] = None
98+
self,
99+
query_text: str,
92100
) -> RawSearchResult:
93101
"""Converts query_text to a Cypher query using an LLM.
94102
Retrieve records from a Neo4j database using the generated Cypher query.
95103
96104
Args:
97105
query_text (str): The natural language query used to search the Neo4j database.
98-
examples (Optional[list[str], optional): Optional user input/query pairs for the LLM to use as examples.
99106
100107
Raises:
101108
SearchValidationError: If validation of the input arguments fail.
@@ -105,17 +112,13 @@ def _get_search_results(
105112
RawSearchResult: The results of the search query as a list of neo4j.Record and an optional metadata dict
106113
"""
107114
try:
108-
validated_data = Text2CypherSearchModel(
109-
query_text=query_text, examples=examples
110-
)
115+
validated_data = Text2CypherSearchModel(query_text=query_text)
111116
except ValidationError as e:
112117
raise SearchValidationError(e.errors())
113118

114119
prompt = TEXT2CYPHER_PROMPT.format(
115120
schema=self.neo4j_schema,
116-
examples="\n".join(validated_data.examples)
117-
if validated_data.examples
118-
else "",
121+
examples="\n".join(self.examples) if self.examples else "",
119122
input=validated_data.query_text,
120123
)
121124
logger.debug("Text2CypherRetriever prompt: %s", prompt)

src/neo4j_genai/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,6 @@ class HybridCypherSearchModel(HybridSearchModel):
148148

149149
class Text2CypherSearchModel(BaseModel):
150150
query_text: str
151-
examples: Optional[list[str]] = None
152151

153152

154153
class SearchType(str, Enum):
@@ -231,3 +230,4 @@ class Text2CypherRetrieverModel(BaseModel):
231230
driver_model: Neo4jDriverModel
232231
llm_model: LLMModel
233232
neo4j_schema_model: Optional[Neo4jSchemaModel] = None
233+
examples: Optional[list[str]] = None

tests/unit/retrievers/test_text2cypher.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,14 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
from unittest.mock import patch, MagicMock
15+
from unittest.mock import MagicMock, patch
1616

1717
import pytest
1818
from neo4j.exceptions import CypherSyntaxError, Neo4jError
1919
from neo4j_genai import Text2CypherRetriever
2020
from neo4j_genai.exceptions import (
21-
SearchValidationError,
2221
RetrieverInitializationError,
22+
SearchValidationError,
2323
Text2CypherRetrievalError,
2424
)
2525
from neo4j_genai.prompts import TEXT2CYPHER_PROMPT
@@ -85,14 +85,16 @@ def test_t2c_retriever_invalid_search_query(
8585
def test_t2c_retriever_invalid_search_examples(
8686
_verify_version_mock: MagicMock, driver: MagicMock, llm: MagicMock
8787
) -> None:
88-
with pytest.raises(SearchValidationError) as exc_info:
89-
retriever = Text2CypherRetriever(
90-
driver=driver, llm=llm, neo4j_schema="dummy-text"
88+
with pytest.raises(RetrieverInitializationError) as exc_info:
89+
Text2CypherRetriever(
90+
driver=driver,
91+
llm=llm,
92+
neo4j_schema="dummy-text",
93+
examples=42, # type: ignore
9194
)
92-
retriever.search(query_text="dummy-text", examples=42)
9395

9496
assert "examples" in str(exc_info.value)
95-
assert "Input should be a valid list" in str(exc_info.value)
97+
assert "Initialization failed" in str(exc_info.value)
9698

9799

98100
@patch("neo4j_genai.Text2CypherRetriever._verify_version")
@@ -106,7 +108,9 @@ def test_t2c_retriever_happy_path(
106108
query_text = "may thy knife chip and shatter"
107109
neo4j_schema = "dummy-schema"
108110
examples = ["example-1", "example-2"]
109-
retriever = Text2CypherRetriever(driver=driver, llm=llm, neo4j_schema=neo4j_schema)
111+
retriever = Text2CypherRetriever(
112+
driver=driver, llm=llm, neo4j_schema=neo4j_schema, examples=examples
113+
)
110114
retriever.llm.invoke.return_value = t2c_query
111115
retriever.driver.execute_query.return_value = ( # type: ignore
112116
[neo4j_record],
@@ -118,7 +122,7 @@ def test_t2c_retriever_happy_path(
118122
examples="\n".join(examples),
119123
input=query_text,
120124
)
121-
retriever.search(query_text=query_text, examples=examples)
125+
retriever.search(query_text=query_text)
122126
retriever.llm.invoke.assert_called_once_with(prompt)
123127
retriever.driver.execute_query.assert_called_once_with(query_=t2c_query) # type: ignore
124128

@@ -130,10 +134,12 @@ def test_t2c_retriever_cypher_error(
130134
t2c_query = "this is not a cypher query"
131135
neo4j_schema = "dummy-schema"
132136
examples = ["example-1", "example-2"]
133-
retriever = Text2CypherRetriever(driver=driver, llm=llm, neo4j_schema=neo4j_schema)
137+
retriever = Text2CypherRetriever(
138+
driver=driver, llm=llm, neo4j_schema=neo4j_schema, examples=examples
139+
)
134140
retriever.llm.invoke.return_value = t2c_query
135141
query_text = "may thy knife chip and shatter"
136142
driver.execute_query.side_effect = CypherSyntaxError
137143
with pytest.raises(Text2CypherRetrievalError) as e:
138-
retriever.search(query_text=query_text, examples=examples)
144+
retriever.search(query_text=query_text)
139145
assert "Failed to get search result" in str(e)

0 commit comments

Comments
 (0)