Skip to content

Commit 09440e0

Browse files
authored
Add linear hybrid search ranker (#284)
* Add linear hybrid search ranker * Update CHANGELOG * Make alpha mandatory for linear ranker * Use query parameters for alpha to avoid Cypher injection * Refactor Cypher query string for linear ranker * Removed isinstance check for float in HybridSearchModel's alpha * Update E2E test for linear ranker * Remove delete of alpha from query parameters
1 parent eed1a04 commit 09440e0

File tree

8 files changed

+362
-8
lines changed

8 files changed

+362
-8
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
- Introduced Neo4jMessageHistory and InMemoryMessageHistory classes for managing LLM message histories.
1414
- Added examples and documentation for using message history with Neo4j and in-memory storage.
1515
- Updated LLM and GraphRAG classes to support new message history classes.
16-
16+
- Introduced a linear hybrid search ranker for HybridRetriever and HybridCypherRetriever, allowing customizable ranking with an `alpha` parameter.
1717
### Changed
1818

1919
- Refactored index-related functions for improved compatibility and functionality.

src/neo4j_graphrag/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,3 +124,7 @@ class PdfLoaderError(Neo4jGraphRagError):
124124

125125
class PromptMissingPlaceholderError(Neo4jGraphRagError):
126126
"""Exception raised when a prompt is missing an expected placeholder."""
127+
128+
129+
class InvalidHybridSearchRankerError(Neo4jGraphRagError):
    """Exception raised when an invalid ranker type for Hybrid Search is provided.

    The supported rankers are ``"naive"`` and ``"linear"``; the linear ranker
    additionally needs an ``alpha`` weight to be usable.
    """

src/neo4j_graphrag/neo4j_queries.py

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@
1515
from __future__ import annotations
1616

1717
import warnings
18-
from typing import Any, Optional
18+
from typing import Any, Optional, Union
1919

20+
from neo4j_graphrag.exceptions import InvalidHybridSearchRankerError
2021
from neo4j_graphrag.filters import get_metadata_filter
21-
from neo4j_graphrag.types import EntityType, SearchType
22+
from neo4j_graphrag.types import EntityType, SearchType, HybridSearchRanker
2223

2324
NODE_VECTOR_INDEX_QUERY = (
2425
"CALL db.index.vector.queryNodes"
@@ -171,6 +172,45 @@ def _get_hybrid_query(neo4j_version_is_5_23_or_above: bool) -> str:
171172
return call_prefix + query_body
172173

173174

175+
def _get_hybrid_query_linear(neo4j_version_is_5_23_or_above: bool, alpha: float) -> str:
    """
    Construct a Cypher query for hybrid search using a linear combination approach with an alpha parameter.

    This query retrieves normalized scores from both the vector index and full-text index. Each
    index is queried in its own branch of a subquery, where every score is divided by that
    branch's maximum score and then weighted. The outer query sums the weighted scores per node:

    ```
    final_score = alpha * (vector normalized score) + (1 - alpha) * (fulltext normalized score)
    ```

    If a node appears in only one index, the missing score is treated as 0.

    Args:
        neo4j_version_is_5_23_or_above (bool): Whether the Neo4j version is 5.23 or above; determines
            the call syntax (``CALL () {`` when True, ``CALL {`` otherwise).
        alpha (float): Weight for the vector index normalized score. The full-text score is weighted
            as (1 - alpha). The value is not interpolated into the query text: the query reads it
            through the ``$alpha`` query parameter, which must be supplied when the query is run.

    Returns:
        str: The constructed Cypher query string.
    """
    call_prefix = "CALL () { " if neo4j_version_is_5_23_or_above else "CALL { "

    query_body = (
        f"{NODE_VECTOR_INDEX_QUERY} "
        "WITH collect({node: node, score: score}) AS nodes, max(score) AS vector_index_max_score "
        "UNWIND nodes AS n "
        "WITH n.node AS node, (n.score / vector_index_max_score) AS rawScore "
        "RETURN node, rawScore * $alpha AS score "
        # UNION ALL rather than UNION: plain UNION removes duplicate rows, so a node
        # whose weighted score is identical in both branches (e.g. the top hit of both
        # indexes with alpha = 0.5) would lose one of its two contributions to the sum.
        "UNION ALL "
        f"{FULL_TEXT_SEARCH_QUERY} "
        "WITH collect({node: node, score: score}) AS nodes, max(score) AS ft_index_max_score "
        "UNWIND nodes AS n "
        "WITH n.node AS node, (n.score / ft_index_max_score) AS rawScore "
        "RETURN node, rawScore * (1 - $alpha) AS score } "
        "WITH node, sum(score) AS score ORDER BY score DESC LIMIT $top_k"
    )
    return call_prefix + query_body
212+
213+
174214
def _get_filtered_vector_query(
175215
filters: dict[str, Any],
176216
node_label: str,
@@ -223,6 +263,8 @@ def get_search_query(
223263
filters: Optional[dict[str, Any]] = None,
224264
neo4j_version_is_5_23_or_above: bool = False,
225265
use_parallel_runtime: bool = False,
266+
ranker: Union[str, HybridSearchRanker] = HybridSearchRanker.NAIVE,
267+
alpha: Optional[float] = None,
226268
) -> tuple[str, dict[str, Any]]:
227269
"""
228270
Constructs a search query for vector or hybrid search, including optional pre-filtering
@@ -243,6 +285,8 @@ def get_search_query(
243285
neo4j_version_is_5_23_or_above (Optional[bool]): Whether the Neo4j version is 5.23 or above.
244286
use_parallel_runtime (bool): Whether or not use the parallel runtime to run the query.
245287
Defaults to False.
288+
ranker (HybridSearchRanker): Type of ranker to order the results from retrieval.
289+
alpha (Optional[float]): Weight for the vector score when using the linear ranker; the full-text score is weighted by (1 - alpha). Only used when ranker is 'linear', in which case it is required.
246290
247291
Returns:
248292
tuple[str, dict[str, Any]]: A tuple containing the constructed query string and
@@ -262,7 +306,14 @@ def get_search_query(
262306
if search_type == SearchType.HYBRID:
263307
if filters:
264308
raise Exception("Filters are not supported with hybrid search")
265-
query = _get_hybrid_query(neo4j_version_is_5_23_or_above)
309+
if ranker == HybridSearchRanker.NAIVE:
310+
query = _get_hybrid_query(neo4j_version_is_5_23_or_above)
311+
elif ranker == HybridSearchRanker.LINEAR and alpha:
312+
query = _get_hybrid_query_linear(
313+
neo4j_version_is_5_23_or_above, alpha=alpha
314+
)
315+
else:
316+
raise InvalidHybridSearchRankerError()
266317
params: dict[str, Any] = {}
267318
elif search_type == SearchType.VECTOR:
268319
if filters:

src/neo4j_graphrag/retrievers/hybrid.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import copy
1818
import logging
19-
from typing import Any, Callable, Optional
19+
from typing import Any, Callable, Optional, Union
2020

2121
import neo4j
2222
from pydantic import ValidationError
@@ -39,6 +39,7 @@
3939
RawSearchResult,
4040
RetrieverResultItem,
4141
SearchType,
42+
HybridSearchRanker,
4243
)
4344

4445
logger = logging.getLogger(__name__)
@@ -142,6 +143,8 @@ def get_search_results(
142143
query_vector: Optional[list[float]] = None,
143144
top_k: int = 5,
144145
effective_search_ratio: int = 1,
146+
ranker: Union[str, HybridSearchRanker] = HybridSearchRanker.NAIVE,
147+
alpha: Optional[float] = None,
145148
) -> RawSearchResult:
146149
"""Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
147150
Both query_vector and query_text can be provided.
@@ -162,6 +165,10 @@ def get_search_results(
162165
top_k (int, optional): The number of neighbors to return. Defaults to 5.
163166
effective_search_ratio (int): Controls the candidate pool size for the vector index by multiplying top_k to balance query
164167
accuracy and performance. Defaults to 1.
168+
ranker (str, HybridSearchRanker): Type of ranker to order the results from retrieval.
169+
alpha (Optional[float]): Weight for the vector score when using the linear ranker.
170+
The fulltext index score is multiplied by (1 - alpha).
171+
**Required** when using the linear ranker; must be between 0 and 1.
165172
166173
Raises:
167174
SearchValidationError: If validation of the input arguments fail.
@@ -176,6 +183,8 @@ def get_search_results(
176183
query_text=query_text,
177184
top_k=top_k,
178185
effective_search_ratio=effective_search_ratio,
186+
ranker=ranker,
187+
alpha=alpha,
179188
)
180189
except ValidationError as e:
181190
raise SearchValidationError(e.errors()) from e
@@ -191,13 +200,18 @@ def get_search_results(
191200
)
192201
query_vector = self.embedder.embed_query(query_text)
193202
parameters["query_vector"] = query_vector
194-
195203
search_query, _ = get_search_query(
196204
search_type=SearchType.HYBRID,
197205
return_properties=self.return_properties,
198206
embedding_node_property=self._embedding_node_property,
199207
neo4j_version_is_5_23_or_above=self.neo4j_version_is_5_23_or_above,
208+
ranker=validated_data.ranker,
209+
alpha=validated_data.alpha,
200210
)
211+
212+
if "ranker" in parameters:
213+
del parameters["ranker"]
214+
201215
sanitized_parameters = copy.deepcopy(parameters)
202216
if "query_vector" in sanitized_parameters:
203217
sanitized_parameters["query_vector"] = "..."
@@ -301,6 +315,8 @@ def get_search_results(
301315
top_k: int = 5,
302316
effective_search_ratio: int = 1,
303317
query_params: Optional[dict[str, Any]] = None,
318+
ranker: Union[str, HybridSearchRanker] = HybridSearchRanker.NAIVE,
319+
alpha: Optional[float] = None,
304320
) -> RawSearchResult:
305321
"""Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
306322
Both query_vector and query_text can be provided.
@@ -320,7 +336,10 @@ def get_search_results(
320336
effective_search_ratio (int): Controls the candidate pool size for the vector index by multiplying top_k to balance query
321337
accuracy and performance. Defaults to 1.
322338
query_params (Optional[dict[str, Any]]): Parameters for the Cypher query. Defaults to None.
323-
339+
ranker (str, HybridSearchRanker): Type of ranker to order the results from retrieval.
340+
alpha (Optional[float]): Weight for the vector score when using the linear ranker.
341+
The fulltext index score is multiplied by (1 - alpha).
342+
**Required** when using the linear ranker; must be between 0 and 1.
324343
Raises:
325344
SearchValidationError: If validation of the input arguments fail.
326345
EmbeddingRequiredError: If no embedder is provided.
@@ -334,6 +353,8 @@ def get_search_results(
334353
query_text=query_text,
335354
top_k=top_k,
336355
effective_search_ratio=effective_search_ratio,
356+
ranker=ranker,
357+
alpha=alpha,
337358
query_params=query_params,
338359
)
339360
except ValidationError as e:
@@ -361,7 +382,13 @@ def get_search_results(
361382
search_type=SearchType.HYBRID,
362383
retrieval_query=self.retrieval_query,
363384
neo4j_version_is_5_23_or_above=self.neo4j_version_is_5_23_or_above,
385+
ranker=validated_data.ranker,
386+
alpha=validated_data.alpha,
364387
)
388+
389+
if "ranker" in parameters:
390+
del parameters["ranker"]
391+
365392
sanitized_parameters = copy.deepcopy(parameters)
366393
if "query_vector" in sanitized_parameters:
367394
sanitized_parameters["query_vector"] = "..."

src/neo4j_graphrag/types.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515
from __future__ import annotations
1616

17+
import warnings
1718
from enum import Enum
1819
from typing import Any, Callable, Literal, Optional, TypedDict, Union
1920

@@ -25,6 +26,7 @@
2526
field_validator,
2627
model_validator,
2728
)
29+
from typing_extensions import Self
2830

2931
from neo4j_graphrag.utils.validation import validate_search_query_input
3032

@@ -137,11 +139,53 @@ class VectorCypherSearchModel(VectorSearchModel):
137139
query_params: Optional[dict[str, Any]] = None
138140

139141

142+
class HybridSearchRanker(Enum):
    """Enumerator of Hybrid search rankers."""

    # Default ranker for hybrid search.
    NAIVE = "naive"
    # Weighted-sum ranker: combines the vector and full-text scores using an
    # alpha weight, which must be provided when this ranker is selected.
    LINEAR = "linear"
147+
148+
140149
class HybridSearchModel(BaseModel):
    """Validated input arguments for a hybrid (vector + full-text) search.

    ``ranker`` accepts either a ``HybridSearchRanker`` member or its string
    value (case-insensitive) and is always normalised to the enum member.
    ``alpha`` is required and must lie in [0, 1] for the linear ranker; for any
    other ranker it is reset to ``None``, with a warning if a value was given.
    """

    query_text: str
    query_vector: Optional[list[float]] = None
    top_k: PositiveInt = 5
    effective_search_ratio: PositiveInt = 1
    ranker: Union[str, HybridSearchRanker] = HybridSearchRanker.NAIVE
    alpha: Optional[float] = None

    @field_validator("ranker", mode="before")
    @classmethod
    def validate_ranker(cls, v: Union[str, HybridSearchRanker]) -> HybridSearchRanker:
        """Coerce ``v`` to a ``HybridSearchRanker`` member.

        Raises:
            ValueError: If ``v`` is a string that names no ranker, or is neither
                a string nor a ``HybridSearchRanker``.
        """
        if isinstance(v, HybridSearchRanker):
            return v

        allowed = ", ".join(r.value for r in HybridSearchRanker)
        if not isinstance(v, str):
            raise ValueError(f"Invalid ranker type. Allowed values are: {allowed}.")
        try:
            return HybridSearchRanker(v.lower())
        except ValueError:
            # The enum's own lookup error adds nothing to the message below,
            # so it is deliberately not chained.
            raise ValueError(
                f"Invalid ranker value. Allowed values are: {allowed}."
            ) from None

    @model_validator(mode="after")
    def validate_alpha(self) -> Self:
        """Check ``alpha`` against the selected ranker.

        Raises:
            ValueError: If the linear ranker is selected and ``alpha`` is
                missing or outside [0, 1].
        """
        if self.ranker == HybridSearchRanker.LINEAR:
            if self.alpha is None:
                raise ValueError("alpha must be provided when using the linear ranker")
            if not (0.0 <= self.alpha <= 1.0):
                raise ValueError("alpha must be between 0 and 1")
            return self

        # alpha has no meaning for the other rankers: tell the caller, then drop it.
        if self.alpha is not None:
            warnings.warn(
                "alpha parameter is only used when ranker is 'linear'. Ignoring alpha.",
                UserWarning,
            )
        self.alpha = None
        return self
145189

146190

147191
class HybridCypherSearchModel(HybridSearchModel):

tests/e2e/test_hybrid_e2e.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,3 +176,27 @@ def test_hybrid_retriever_return_properties(driver: Driver) -> None:
176176
assert len(results.items) == 5
177177
for result in results.items:
178178
assert isinstance(result, RetrieverResultItem)
179+
180+
181+
@pytest.mark.usefixtures("setup_neo4j_for_retrieval")
def test_hybrid_retriever_search_text_linear_ranker(
    driver: Driver, random_embedder: Embedder
) -> None:
    """Run a text-only hybrid search with the linear ranker and check the result shape."""
    retriever = HybridRetriever(
        driver, "vector-index-name", "fulltext-index-name", random_embedder
    )

    results = retriever.search(
        query_text="Find me a book about Fremen",
        top_k=5,
        effective_search_ratio=2,
        ranker="linear",
        alpha=0.9,
    )

    assert isinstance(results, RetrieverResult)
    assert len(results.items) == 5
    assert all(isinstance(item, RetrieverResultItem) for item in results.items)

0 commit comments

Comments
 (0)