Skip to content

Commit fd6ec19

Browse files
committed
Added effective search ratio to vector and hybrid searches
1 parent 8602761 commit fd6ec19

File tree

9 files changed

+137
-23
lines changed

9 files changed

+137
-23
lines changed

src/neo4j_graphrag/neo4j_queries.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@
2121
from neo4j_graphrag.types import SearchType
2222

2323
VECTOR_INDEX_QUERY = (
24-
"CALL db.index.vector.queryNodes($vector_index_name, $top_k, $query_vector) "
25-
"YIELD node, score"
24+
"CALL db.index.vector.queryNodes($vector_index_name, $top_k * $effective_search_ratio, $query_vector) "
25+
"YIELD node, score "
26+
"WITH node, score LIMIT $top_k"
2627
)
2728

2829
VECTOR_EXACT_QUERY = (
@@ -84,7 +85,6 @@
8485
"RETURN elementId(rel)"
8586
)
8687

87-
8888
UPSERT_RELATIONSHIP_QUERY_VARIABLE_SCOPE_CLAUSE = (
8989
"UNWIND $rows as row "
9090
"MATCH (start:__KGBuilder__ {id: row.start_node_id}) "

src/neo4j_graphrag/retrievers/hybrid.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ def get_search_results(
141141
query_text: str,
142142
query_vector: Optional[list[float]] = None,
143143
top_k: int = 5,
144+
effective_search_ratio: int = 1,
144145
) -> RawSearchResult:
145146
"""Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
146147
Both query_vector and query_text can be provided.
@@ -159,6 +160,8 @@ def get_search_results(
159160
query_text (str): The text to get the closest neighbors of.
160161
query_vector (Optional[list[float]], optional): The vector embeddings to get the closest neighbors of. Defaults to None.
161162
top_k (int, optional): The number of neighbors to return. Defaults to 5.
163+
effective_search_ratio (int): Controls the candidate pool size for the vector index by multiplying top_k to balance query
164+
accuracy and performance. Defaults to 1.
162165
163166
Raises:
164167
SearchValidationError: If validation of the input arguments fails.
@@ -172,6 +175,7 @@ def get_search_results(
172175
query_vector=query_vector,
173176
query_text=query_text,
174177
top_k=top_k,
178+
effective_search_ratio=effective_search_ratio,
175179
)
176180
except ValidationError as e:
177181
raise SearchValidationError(e.errors()) from e
@@ -295,6 +299,7 @@ def get_search_results(
295299
query_text: str,
296300
query_vector: Optional[list[float]] = None,
297301
top_k: int = 5,
302+
effective_search_ratio: int = 1,
298303
query_params: Optional[dict[str, Any]] = None,
299304
) -> RawSearchResult:
300305
"""Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
@@ -312,6 +317,8 @@ def get_search_results(
312317
query_text (str): The text to get the closest neighbors of.
313318
query_vector (Optional[list[float]]): The vector embeddings to get the closest neighbors of. Defaults to None.
314319
top_k (int): The number of neighbors to return. Defaults to 5.
320+
effective_search_ratio (int): Controls the candidate pool size for the vector index by multiplying top_k to balance query
321+
accuracy and performance. Defaults to 1.
315322
query_params (Optional[dict[str, Any]]): Parameters for the Cypher query. Defaults to None.
316323
317324
Raises:
@@ -326,6 +333,7 @@ def get_search_results(
326333
query_vector=query_vector,
327334
query_text=query_text,
328335
top_k=top_k,
336+
effective_search_ratio=effective_search_ratio,
329337
query_params=query_params,
330338
)
331339
except ValidationError as e:

src/neo4j_graphrag/retrievers/vector.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ def get_search_results(
146146
query_vector: Optional[list[float]] = None,
147147
query_text: Optional[str] = None,
148148
top_k: int = 5,
149+
effective_search_ratio: int = 1,
149150
filters: Optional[dict[str, Any]] = None,
150151
) -> RawSearchResult:
151152
"""Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
@@ -160,6 +161,8 @@ def get_search_results(
160161
query_vector (Optional[list[float]]): The vector embeddings to get the closest neighbors of. Defaults to None.
161162
query_text (Optional[str]): The text to get the closest neighbors of. Defaults to None.
162163
top_k (int): The number of neighbors to return. Defaults to 5.
164+
effective_search_ratio (int): Controls the candidate pool size by multiplying top_k to balance query accuracy and performance.
165+
Defaults to 1.
163166
filters (Optional[dict[str, Any]]): Filters for metadata pre-filtering. Defaults to None.
164167
165168
Raises:
@@ -174,6 +177,7 @@ def get_search_results(
174177
query_vector=query_vector,
175178
query_text=query_text,
176179
top_k=top_k,
180+
effective_search_ratio=effective_search_ratio,
177181
filters=filters,
178182
)
179183
except ValidationError as e:
@@ -297,6 +301,7 @@ def get_search_results(
297301
query_vector: Optional[list[float]] = None,
298302
query_text: Optional[str] = None,
299303
top_k: int = 5,
304+
effective_search_ratio: int = 1,
300305
query_params: Optional[dict[str, Any]] = None,
301306
filters: Optional[dict[str, Any]] = None,
302307
) -> RawSearchResult:
@@ -312,6 +317,8 @@ def get_search_results(
312317
query_vector (Optional[list[float]]): The vector embeddings to get the closest neighbors of. Defaults to None.
313318
query_text (Optional[str]): The text to get the closest neighbors of. Defaults to None.
314319
top_k (int): The number of neighbors to return. Defaults to 5.
320+
effective_search_ratio (int): Controls the candidate pool size by multiplying top_k to balance query accuracy and performance.
321+
Defaults to 1.
315322
query_params (Optional[dict[str, Any]]): Parameters for the Cypher query. Defaults to None.
316323
filters (Optional[dict[str, Any]]): Filters for metadata pre-filtering. Defaults to None.
317324
@@ -327,6 +334,7 @@ def get_search_results(
327334
query_vector=query_vector,
328335
query_text=query_text,
329336
top_k=top_k,
337+
effective_search_ratio=effective_search_ratio,
330338
query_params=query_params,
331339
filters=filters,
332340
)

src/neo4j_graphrag/types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ class VectorSearchModel(BaseModel):
120120
query_vector: Optional[list[float]] = None
121121
query_text: Optional[str] = None
122122
top_k: PositiveInt = 5
123+
effective_search_ratio: PositiveInt = 1
123124
filters: Optional[dict[str, Any]] = None
124125

125126
@model_validator(mode="before")
@@ -140,6 +141,7 @@ class HybridSearchModel(BaseModel):
140141
query_text: str
141142
query_vector: Optional[list[float]] = None
142143
top_k: PositiveInt = 5
144+
effective_search_ratio: PositiveInt = 1
143145

144146

145147
class HybridCypherSearchModel(HybridSearchModel):

tests/e2e/test_hybrid_e2e.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,12 @@ def test_hybrid_retriever_search_text(
3333
)
3434

3535
top_k = 5
36-
results = retriever.search(query_text="Find me a book about Fremen", top_k=top_k)
36+
effective_search_ratio = 2
37+
results = retriever.search(
38+
query_text="Find me a book about Fremen",
39+
top_k=top_k,
40+
effective_search_ratio=effective_search_ratio,
41+
)
3742

3843
assert isinstance(results, RetrieverResult)
3944
assert len(results.items) == 5
@@ -51,8 +56,13 @@ def test_hybrid_retriever_no_neo4j_deprecation_warning(
5156
)
5257

5358
top_k = 5
59+
effective_search_ratio = 2
5460
with caplog.at_level(logging.WARNING):
55-
retriever.search(query_text="Find me a book about Fremen", top_k=top_k)
61+
retriever.search(
62+
query_text="Find me a book about Fremen",
63+
top_k=top_k,
64+
effective_search_ratio=effective_search_ratio,
65+
)
5666

5767
for record in caplog.records:
5868
if (
@@ -78,7 +88,12 @@ def test_hybrid_cypher_retriever_search_text(
7888
)
7989

8090
top_k = 5
81-
results = retriever.search(query_text="Find me a book about Fremen", top_k=top_k)
91+
effective_search_ratio = 2
92+
results = retriever.search(
93+
query_text="Find me a book about Fremen",
94+
top_k=top_k,
95+
effective_search_ratio=effective_search_ratio,
96+
)
8297

8398
assert isinstance(results, RetrieverResult)
8499
assert len(results.items) == 5
@@ -96,10 +111,12 @@ def test_hybrid_retriever_search_vector(driver: Driver) -> None:
96111
)
97112

98113
top_k = 5
114+
effective_search_ratio = 2
99115
results = retriever.search(
100116
query_text="Find me a book about Fremen",
101117
query_vector=[1.0 for _ in range(1536)],
102118
top_k=top_k,
119+
effective_search_ratio=effective_search_ratio,
103120
)
104121

105122
assert isinstance(results, RetrieverResult)
@@ -121,10 +138,12 @@ def test_hybrid_cypher_retriever_search_vector(driver: Driver) -> None:
121138
)
122139

123140
top_k = 5
141+
effective_search_ratio = 2
124142
results = retriever.search(
125143
query_text="Find me a book about Fremen",
126144
query_vector=[1.0 for _ in range(1536)],
127145
top_k=top_k,
146+
effective_search_ratio=effective_search_ratio,
128147
)
129148

130149
assert isinstance(results, RetrieverResult)
@@ -145,10 +164,12 @@ def test_hybrid_retriever_return_properties(driver: Driver) -> None:
145164
)
146165

147166
top_k = 5
167+
effective_search_ratio = 2
148168
results = retriever.search(
149169
query_text="Find me a book about Fremen",
150170
query_vector=[1.0 for _ in range(1536)],
151171
top_k=top_k,
172+
effective_search_ratio=effective_search_ratio,
152173
)
153174

154175
assert isinstance(results, RetrieverResult)

tests/e2e/test_vector_e2e.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,12 @@ def test_vector_retriever_search_text(
2727
retriever = VectorRetriever(driver, "vector-index-name", random_embedder)
2828

2929
top_k = 5
30-
results = retriever.search(query_text="Find me a book about Fremen", top_k=top_k)
30+
effective_search_ratio = 2
31+
results = retriever.search(
32+
query_text="Find me a book about Fremen",
33+
top_k=top_k,
34+
effective_search_ratio=effective_search_ratio,
35+
)
3136

3237
assert isinstance(results, RetrieverResult)
3338
assert len(results.items) == 5
@@ -48,7 +53,12 @@ def test_vector_cypher_retriever_search_text(
4853
)
4954

5055
top_k = 5
51-
results = retriever.search(query_text="Find me a book about Fremen", top_k=top_k)
56+
effective_search_ratio = 2
57+
results = retriever.search(
58+
query_text="Find me a book about Fremen",
59+
top_k=top_k,
60+
effective_search_ratio=effective_search_ratio,
61+
)
5262

5363
assert isinstance(results, RetrieverResult)
5464
assert len(results.items) == 5
@@ -62,7 +72,12 @@ def test_vector_retriever_search_vector(driver: Driver) -> None:
6272
retriever = VectorRetriever(driver, "vector-index-name")
6373

6474
top_k = 5
65-
results = retriever.search(query_vector=[1.0 for _ in range(1536)], top_k=top_k)
75+
effective_search_ratio = 2
76+
results = retriever.search(
77+
query_vector=[1.0 for _ in range(1536)],
78+
top_k=top_k,
79+
effective_search_ratio=effective_search_ratio,
80+
)
6681

6782
assert isinstance(results, RetrieverResult)
6883
assert len(results.items) == 5
@@ -79,7 +94,12 @@ def test_vector_cypher_retriever_search_vector(driver: Driver) -> None:
7994
retriever = VectorCypherRetriever(driver, "vector-index-name", retrieval_query)
8095

8196
top_k = 5
82-
results = retriever.search(query_vector=[1.0 for _ in range(1536)], top_k=top_k)
97+
effective_search_ratio = 2
98+
results = retriever.search(
99+
query_vector=[1.0 for _ in range(1536)],
100+
top_k=top_k,
101+
effective_search_ratio=effective_search_ratio,
102+
)
83103

84104
assert isinstance(results, RetrieverResult)
85105
assert len(results.items) == 5

tests/unit/retrievers/test_hybrid.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ def test_hybrid_search_text_happy_path(
178178
fulltext_index_name = "fulltext-index"
179179
query_text = "may thy knife chip and shatter"
180180
top_k = 5
181+
effective_search_ratio = 2
181182

182183
retriever = HybridRetriever(
183184
driver, vector_index_name, fulltext_index_name, embedder
@@ -197,13 +198,18 @@ def test_hybrid_search_text_happy_path(
197198
neo4j_version_is_5_23_or_above=retriever.neo4j_version_is_5_23_or_above,
198199
)
199200

200-
records = retriever.search(query_text=query_text, top_k=top_k)
201+
records = retriever.search(
202+
query_text=query_text,
203+
top_k=top_k,
204+
effective_search_ratio=effective_search_ratio,
205+
)
201206

202207
retriever.driver.execute_query.assert_called_once_with( # type: ignore
203208
search_query,
204209
{
205210
"vector_index_name": vector_index_name,
206211
"top_k": top_k,
212+
"effective_search_ratio": effective_search_ratio,
207213
"query_text": query_text,
208214
"fulltext_index_name": fulltext_index_name,
209215
"query_vector": embed_query_vector,
@@ -238,6 +244,7 @@ def test_hybrid_search_favors_query_vector_over_embedding_vector(
238244
fulltext_index_name = "fulltext-index"
239245
query_text = "may thy knife chip and shatter"
240246
top_k = 5
247+
effective_search_ratio = 2
241248
database = "neo4j"
242249
retriever = HybridRetriever(
243250
driver,
@@ -257,13 +264,19 @@ def test_hybrid_search_favors_query_vector_over_embedding_vector(
257264
neo4j_version_is_5_23_or_above=retriever.neo4j_version_is_5_23_or_above,
258265
)
259266

260-
retriever.search(query_text=query_text, query_vector=query_vector, top_k=top_k)
267+
retriever.search(
268+
query_text=query_text,
269+
query_vector=query_vector,
270+
top_k=top_k,
271+
effective_search_ratio=effective_search_ratio,
272+
)
261273

262274
retriever.driver.execute_query.assert_called_once_with( # type: ignore
263275
search_query,
264276
{
265277
"vector_index_name": vector_index_name,
266278
"top_k": top_k,
279+
"effective_search_ratio": effective_search_ratio,
267280
"query_text": query_text,
268281
"fulltext_index_name": fulltext_index_name,
269282
"query_vector": query_vector,
@@ -320,6 +333,7 @@ def test_hybrid_retriever_return_properties(
320333
fulltext_index_name = "fulltext-index"
321334
query_text = "may thy knife chip and shatter"
322335
top_k = 5
336+
effective_search_ratio = 2
323337
return_properties = ["node-property-1", "node-property-2"]
324338
retriever = HybridRetriever(
325339
driver,
@@ -340,14 +354,19 @@ def test_hybrid_retriever_return_properties(
340354
neo4j_version_is_5_23_or_above=retriever.neo4j_version_is_5_23_or_above,
341355
)
342356

343-
records = retriever.search(query_text=query_text, top_k=top_k)
357+
records = retriever.search(
358+
query_text=query_text,
359+
top_k=top_k,
360+
effective_search_ratio=effective_search_ratio,
361+
)
344362

345363
embedder.embed_query.assert_called_once_with(query_text)
346364
driver.execute_query.assert_called_once_with(
347365
search_query,
348366
{
349367
"vector_index_name": vector_index_name,
350368
"top_k": top_k,
369+
"effective_search_ratio": effective_search_ratio,
351370
"query_text": query_text,
352371
"fulltext_index_name": fulltext_index_name,
353372
"query_vector": embed_query_vector,
@@ -377,6 +396,7 @@ def test_hybrid_cypher_retrieval_query_with_params(
377396
fulltext_index_name = "fulltext-index"
378397
query_text = "may thy knife chip and shatter"
379398
top_k = 5
399+
effective_search_ratio = 2
380400
retrieval_query = """
381401
RETURN node.id AS node_id, node.text AS text, score, {test: $param} AS metadata
382402
"""
@@ -405,6 +425,7 @@ def test_hybrid_cypher_retrieval_query_with_params(
405425
records = retriever.search(
406426
query_text=query_text,
407427
top_k=top_k,
428+
effective_search_ratio=effective_search_ratio,
408429
query_params=query_params,
409430
)
410431

@@ -415,6 +436,7 @@ def test_hybrid_cypher_retrieval_query_with_params(
415436
{
416437
"vector_index_name": vector_index_name,
417438
"top_k": top_k,
439+
"effective_search_ratio": effective_search_ratio,
418440
"query_text": query_text,
419441
"fulltext_index_name": fulltext_index_name,
420442
"query_vector": embed_query_vector,

0 commit comments

Comments
 (0)