Skip to content

Commit d91e15d

Browse files
committed
get_search_query refactoring
1 parent bf9c328 commit d91e15d

File tree

8 files changed

+123
-63
lines changed

8 files changed

+123
-63
lines changed

src/neo4j_graphrag/neo4j_queries.py

Lines changed: 99 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,22 @@
1414
# limitations under the License.
1515
from __future__ import annotations
1616

17+
import warnings
1718
from typing import Any, Optional
1819

1920
from neo4j_graphrag.filters import get_metadata_filter
2021
from neo4j_graphrag.types import IndexType, SearchType
2122

2223
VECTOR_EXACT_QUERY = (
2324
"WITH node, "
24-
"vector.similarity.cosine(node.`{embedding_node_property}`, $query_vector) AS score "
25+
"vector.similarity.cosine(node.`{embedding_property}`, $query_vector) AS score "
2526
"ORDER BY score DESC LIMIT $top_k"
2627
)
2728

2829
BASE_VECTOR_EXACT_QUERY = (
2930
"MATCH (node:`{node_label}`) "
30-
"WHERE node.`{embedding_node_property}` IS NOT NULL "
31-
"AND size(node.`{embedding_node_property}`) = toInteger($embedding_dimension)"
31+
"WHERE node.`{embedding_property}` IS NOT NULL "
32+
"AND size(node.`{embedding_property}`) = toInteger($embedding_dimension)"
3233
)
3334

3435

@@ -151,7 +152,7 @@ def _get_hybrid_query(
151152
def _get_filtered_vector_query(
152153
filters: dict[str, Any],
153154
node_label: str,
154-
embedding_node_property: str,
155+
embedding_property: str,
155156
embedding_dimension: int,
156157
) -> tuple[str, dict[str, Any]]:
157158
"""Build Cypher query for vector search with filters
@@ -160,7 +161,7 @@ def _get_filtered_vector_query(
160161
Args:
161162
filters (dict[str, Any]): filters used to pre-filter the nodes before vector search
162163
node_label (str): node label we want to search for
163-
embedding_node_property (str): the name of the property holding the embeddings
164+
embedding_property (str): the name of the property holding the embeddings
164165
embedding_dimension (int): the dimension of the embeddings
165166
166167
Returns:
@@ -169,77 +170,129 @@ def _get_filtered_vector_query(
169170
where_filters, query_params = get_metadata_filter(filters, node_alias="node")
170171
base_query = BASE_VECTOR_EXACT_QUERY.format(
171172
node_label=node_label,
172-
embedding_node_property=embedding_node_property,
173+
embedding_property=embedding_property,
173174
)
174175
vector_query = VECTOR_EXACT_QUERY.format(
175-
embedding_node_property=embedding_node_property,
176+
embedding_property=embedding_property,
176177
)
177178
query_params["embedding_dimension"] = embedding_dimension
178179
return f"{base_query} AND ({where_filters}) {vector_query}", query_params
179180

180181

181182
def get_search_query(
182183
search_type: SearchType,
184+
index_type: IndexType = IndexType.NODE,
183185
return_properties: Optional[list[str]] = None,
184186
retrieval_query: Optional[str] = None,
185187
node_label: Optional[str] = None,
186188
embedding_node_property: Optional[str] = None,
189+
embedding_property: Optional[str] = None,
187190
embedding_dimension: Optional[int] = None,
188191
filters: Optional[dict[str, Any]] = None,
189192
neo4j_version_is_5_23_or_above: bool = False,
190193
) -> tuple[str, dict[str, Any]]:
191-
"""Build the search query, including pre-filtering if needed, and return clause.
194+
"""
195+
Constructs a search query for vector or hybrid search, including optional pre-filtering
196+
and return clause.
192197
193-
Args
194-
search_type: Search type we want to search for:
195-
return_properties (list[str]): list of property names to return.
196-
It can't be provided together with retrieval_query.
197-
retrieval_query (str): the query to use to retrieve the search results
198-
It can't be provided together with return_properties.
199-
node_label (str): node label we want to search for
200-
embedding_node_property (str): the name of the property holding the embeddings
201-
embedding_dimension (int): the dimension of the embeddings
202-
filters (dict[str, Any]): filters used to pre-filter the nodes before vector search
198+
Args:
199+
search_type (SearchType): Specifies whether to perform a vector or hybrid search.
200+
index_type (Optional[IndexType]): Specifies whether to search over node or
201+
relationship indexes. Defaults to 'node'.
202+
return_properties (Optional[list[str]]): List of property names to return.
203+
Cannot be provided alongside `retrieval_query`.
204+
retrieval_query (Optional[str]): Query used to retrieve search results.
205+
Cannot be provided alongside `return_properties`.
206+
node_label (Optional[str]): Label of the nodes to search.
207+
embedding_property (Optional[str])): Name of the property containing the embeddings.
208+
embedding_dimension (Optional[int]): Dimension of the embeddings.
209+
filters (Optional[dict[str, Any]]): Filters to pre-filter nodes before vector search.
210+
neo4j_version_is_5_23_or_above (Optional[bool]): Whether the Neo4j version is 5.23 or above.
203211
204212
Returns:
205-
tuple[str, dict[str, Any]]: query and parameters
213+
tuple[str, dict[str, Any]]: A tuple containing the constructed query string and
214+
a dictionary of query parameters.
206215
216+
Raises:
217+
Exception: If filters are used with Hybrid Search.
218+
Exception: If Vector Search with filters is missing required parameters.
219+
ValueError: If an unsupported search type is provided.
207220
"""
208-
if search_type == SearchType.HYBRID:
209-
if filters:
210-
raise Exception("Filters are not supported with Hybrid Search")
211-
query = _get_hybrid_query(neo4j_version_is_5_23_or_above)
212-
params: dict[str, Any] = {}
213-
elif search_type == SearchType.VECTOR:
214-
if filters:
215-
if (
216-
node_label is not None
217-
and embedding_node_property is not None
218-
and embedding_dimension is not None
219-
):
220-
query, params = _get_filtered_vector_query(
221-
filters, node_label, embedding_node_property, embedding_dimension
222-
)
221+
warnings.warn(
222+
"embedding_node_property is deprecated, use embedding_property instead",
223+
DeprecationWarning,
224+
stacklevel=2,
225+
)
226+
if embedding_node_property:
227+
if embedding_property:
228+
warnings.warn(
229+
"Both embedding_node_property and embedding_property provided, using embedding_property",
230+
UserWarning,
231+
stacklevel=2,
232+
)
233+
else:
234+
embedding_property = embedding_node_property
235+
236+
if index_type == IndexType.NODE:
237+
if search_type == SearchType.HYBRID:
238+
if filters:
239+
raise Exception("Filters are not supported with Hybrid Search")
240+
query = _get_hybrid_query(neo4j_version_is_5_23_or_above)
241+
params: dict[str, Any] = {}
242+
elif search_type == SearchType.VECTOR:
243+
if filters:
244+
if (
245+
node_label is not None
246+
and embedding_property is not None
247+
and embedding_dimension is not None
248+
):
249+
query, params = _get_filtered_vector_query(
250+
filters,
251+
node_label,
252+
embedding_property,
253+
embedding_dimension,
254+
)
255+
else:
256+
raise Exception(
257+
"Vector Search with filters requires: node_label, embedding_property, embedding_dimension"
258+
)
223259
else:
224-
raise Exception(
225-
"Vector Search with filters requires: node_label, embedding_node_property, embedding_dimension"
226-
)
260+
query, params = _get_vector_search_query(index_type=index_type), {}
227261
else:
228-
query, params = _get_vector_search_query(), {}
262+
raise ValueError(f"Search type is not supported: {search_type}")
263+
fallback_return = (
264+
f"RETURN node {{ .*, `{embedding_property}`: null }} AS node, "
265+
"labels(node) AS nodeLabels, elementId(node) AS elementId, score"
266+
)
267+
elif index_type == IndexType.RELATIONSHIP:
268+
if search_type == SearchType.HYBRID:
269+
raise Exception("Hybrid search is not support for relationship indexes")
270+
elif search_type == SearchType.VECTOR:
271+
query, params = _get_vector_search_query(index_type=index_type), {}
272+
else:
273+
raise ValueError(f"Search type is not supported: {search_type}")
274+
fallback_return = (
275+
f"RETURN relationship {{ .*, `{embedding_property}`: null }} AS relationship, "
276+
"elementId(relationship) AS elementId, score"
277+
)
229278
else:
230-
raise ValueError(f"Search type is not supported: {search_type}")
279+
raise ValueError(f"Index type is not supported: {index_type}")
280+
231281
query_tail = get_query_tail(
232282
retrieval_query,
233283
return_properties,
234-
fallback_return=f"RETURN node {{ .*, `{embedding_node_property}`: null }} AS node, labels(node) AS nodeLabels, elementId(node) AS elementId, score",
284+
fallback_return=fallback_return,
285+
index_type=index_type,
235286
)
287+
236288
return f"{query} {query_tail}", params
237289

238290

239291
def get_query_tail(
240292
retrieval_query: Optional[str] = None,
241293
return_properties: Optional[list[str]] = None,
242294
fallback_return: Optional[str] = None,
295+
index_type: IndexType = IndexType.NODE,
243296
) -> str:
244297
"""Build the RETURN statement after the search is performed
245298
@@ -257,5 +310,10 @@ def get_query_tail(
257310
return retrieval_query
258311
if return_properties:
259312
return_properties_cypher = ", ".join([f".{prop}" for prop in return_properties])
260-
return f"RETURN node {{{return_properties_cypher}}} AS node, labels(node) AS nodeLabels, elementId(node) AS elementId, score"
313+
if index_type == IndexType.NODE:
314+
return f"RETURN node {{{return_properties_cypher}}} AS node, labels(node) AS nodeLabels, elementId(node) AS elementId, score"
315+
elif index_type == IndexType.RELATIONSHIP:
316+
return f"RETURN relationship {{{return_properties_cypher}}} AS relationship, elementId(relationship) AS elementId, score"
317+
else:
318+
raise ValueError(f"Index type is not supported: {index_type}")
261319
return fallback_return if fallback_return else ""

src/neo4j_graphrag/retrievers/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def _fetch_index_infos(self, vector_index_name: str) -> None:
122122
try:
123123
result = query_result.records[0]
124124
self._node_label = result["labels"][0]
125-
self._embedding_node_property = result["properties"][0]
125+
self._embedding_property = result["properties"][0]
126126
self._embedding_dimension = result["dimensions"]
127127
except IndexError as e:
128128
raise Exception(f"No index with name {self.index_name} found") from e

src/neo4j_graphrag/retrievers/hybrid.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def __init__(
118118
else None
119119
)
120120
self.result_formatter = validated_data.result_formatter
121-
self._embedding_node_property = None
121+
self._embedding_property = None
122122
self._embedding_dimension = None
123123
self._fetch_index_infos(self.vector_index_name)
124124

@@ -193,9 +193,9 @@ def get_search_results(
193193
parameters["query_vector"] = query_vector
194194

195195
search_query, _ = get_search_query(
196-
SearchType.HYBRID,
197-
self.return_properties,
198-
embedding_node_property=self._embedding_node_property,
196+
search_type=SearchType.HYBRID,
197+
return_properties=self.return_properties,
198+
embedding_property=self._embedding_property,
199199
neo4j_version_is_5_23_or_above=self.neo4j_version_is_5_23_or_above,
200200
)
201201
sanitized_parameters = copy.deepcopy(parameters)
@@ -358,7 +358,7 @@ def get_search_results(
358358
del parameters["query_params"]
359359

360360
search_query, _ = get_search_query(
361-
SearchType.HYBRID,
361+
search_type=SearchType.HYBRID,
362362
retrieval_query=self.retrieval_query,
363363
neo4j_version_is_5_23_or_above=self.neo4j_version_is_5_23_or_above,
364364
)

src/neo4j_graphrag/retrievers/vector.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def __init__(
121121
)
122122
self.result_formatter = validated_data.result_formatter
123123
self._node_label = None
124-
self._embedding_node_property = None
124+
self._embedding_property = None
125125
self._embedding_dimension = None
126126
self._fetch_index_infos(self.index_name)
127127

@@ -198,10 +198,10 @@ def get_search_results(
198198
del parameters["query_text"]
199199

200200
search_query, search_params = get_search_query(
201-
SearchType.VECTOR,
202-
self.return_properties,
201+
search_type=SearchType.VECTOR,
202+
return_properties=self.return_properties,
203203
node_label=self._node_label,
204-
embedding_node_property=self._embedding_node_property,
204+
embedding_property=self._embedding_property,
205205
embedding_dimension=self._embedding_dimension,
206206
filters=filters,
207207
)
@@ -361,10 +361,10 @@ def get_search_results(
361361
del parameters["query_params"]
362362

363363
search_query, search_params = get_search_query(
364-
SearchType.VECTOR,
364+
search_type=SearchType.VECTOR,
365365
retrieval_query=self.retrieval_query,
366366
node_label=self._node_label,
367-
embedding_node_property=self._node_embedding_property,
367+
embedding_property=self._node_embedding_property,
368368
embedding_dimension=self._embedding_dimension,
369369
filters=filters,
370370
)

tests/e2e/test_vector_e2e.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def test_vector_retriever_search_text(
3737
assert isinstance(results, RetrieverResult)
3838
assert len(results.items) == 5
3939
for result in results.items:
40-
assert f"'{retriever._embedding_node_property}': None" in result.content
40+
assert f"'{retriever._embedding_property}': None" in result.content
4141
assert isinstance(result, RetrieverResultItem)
4242

4343

@@ -82,7 +82,7 @@ def test_vector_retriever_search_vector(driver: Driver) -> None:
8282
assert isinstance(results, RetrieverResult)
8383
assert len(results.items) == 5
8484
for result in results.items:
85-
assert f"'{retriever._embedding_node_property}': None" in result.content
85+
assert f"'{retriever._embedding_property}': None" in result.content
8686
assert isinstance(result, RetrieverResultItem)
8787

8888

tests/unit/retrievers/test_hybrid.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def test_hybrid_search_text_happy_path(
184184
driver, vector_index_name, fulltext_index_name, embedder
185185
)
186186
retriever.neo4j_version_is_5_23_or_above = True
187-
retriever._embedding_node_property = (
187+
retriever._embedding_property = (
188188
"embedding" # variable normally filled by fetch_index_infos
189189
)
190190
retriever.driver.execute_query.return_value = [ # type: ignore
@@ -194,7 +194,7 @@ def test_hybrid_search_text_happy_path(
194194
]
195195
search_query, _ = get_search_query(
196196
SearchType.HYBRID,
197-
embedding_node_property="embedding",
197+
embedding_property="embedding",
198198
neo4j_version_is_5_23_or_above=retriever.neo4j_version_is_5_23_or_above,
199199
)
200200

@@ -349,8 +349,8 @@ def test_hybrid_retriever_return_properties(
349349
None,
350350
]
351351
search_query, _ = get_search_query(
352-
SearchType.HYBRID,
353-
return_properties,
352+
search_type=SearchType.HYBRID,
353+
return_properties=return_properties,
354354
neo4j_version_is_5_23_or_above=retriever.neo4j_version_is_5_23_or_above,
355355
)
356356

@@ -417,7 +417,7 @@ def test_hybrid_cypher_retrieval_query_with_params(
417417
None,
418418
]
419419
search_query, _ = get_search_query(
420-
SearchType.HYBRID,
420+
search_type=SearchType.HYBRID,
421421
retrieval_query=retrieval_query,
422422
neo4j_version_is_5_23_or_above=retriever.neo4j_version_is_5_23_or_above,
423423
)

tests/unit/retrievers/test_vector.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,9 @@ def test_similarity_search_text_return_properties(
241241
None,
242242
None,
243243
]
244-
search_query, _ = get_search_query(SearchType.VECTOR, return_properties)
244+
search_query, _ = get_search_query(
245+
search_type=SearchType.VECTOR, return_properties=return_properties
246+
)
245247

246248
records = retriever.search(
247249
query_text=query_text,

tests/unit/test_neo4j_queries.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def test_vector_search_with_filters(_mock: Any) -> None:
9191
result, params = get_search_query(
9292
SearchType.VECTOR,
9393
node_label="Label",
94-
embedding_node_property="vector",
94+
embedding_property="vector",
9595
embedding_dimension=1,
9696
filters={"field": "value"},
9797
)
@@ -117,7 +117,7 @@ def test_vector_search_with_params_from_filters(_mock: Any) -> None:
117117
result, params = get_search_query(
118118
SearchType.VECTOR,
119119
node_label="Label",
120-
embedding_node_property="vector",
120+
embedding_property="vector",
121121
embedding_dimension=1,
122122
filters={"field": "value"},
123123
)

0 commit comments

Comments
 (0)