Skip to content

Add linear hybrid search ranker #284

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Feb 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
- Introduced Neo4jMessageHistory and InMemoryMessageHistory classes for managing LLM message histories.
- Added examples and documentation for using message history with Neo4j and in-memory storage.
- Updated LLM and GraphRAG classes to support new message history classes.

- Introduced a linear hybrid search ranker for HybridRetriever and HybridCypherRetriever, allowing customizable ranking with an `alpha` parameter.
### Changed

- Refactored index-related functions for improved compatibility and functionality.
Expand Down
4 changes: 4 additions & 0 deletions src/neo4j_graphrag/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,7 @@ class PdfLoaderError(Neo4jGraphRagError):

class PromptMissingPlaceholderError(Neo4jGraphRagError):
    """Signals that a prompt does not contain a placeholder it was expected to have."""


class InvalidHybridSearchRankerError(Neo4jGraphRagError):
    """Signals that the ranker type supplied for Hybrid Search is not valid."""
57 changes: 54 additions & 3 deletions src/neo4j_graphrag/neo4j_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
from __future__ import annotations

import warnings
from typing import Any, Optional
from typing import Any, Optional, Union

from neo4j_graphrag.exceptions import InvalidHybridSearchRankerError
from neo4j_graphrag.filters import get_metadata_filter
from neo4j_graphrag.types import EntityType, SearchType
from neo4j_graphrag.types import EntityType, SearchType, HybridSearchRanker

NODE_VECTOR_INDEX_QUERY = (
"CALL db.index.vector.queryNodes"
Expand Down Expand Up @@ -171,6 +172,45 @@ def _get_hybrid_query(neo4j_version_is_5_23_or_above: bool) -> str:
return call_prefix + query_body


def _get_hybrid_query_linear(neo4j_version_is_5_23_or_above: bool, alpha: float) -> str:
    """
    Construct a Cypher query for hybrid search using a linear combination approach with an alpha parameter.

    The query runs the vector index search and the full-text index search as the two
    branches of a subquery. Inside each branch every score is divided by that branch's
    maximum score, then weighted. The outer query sums the weighted scores per node:

    ```
    final_score = alpha * (vector normalized score) + (1 - alpha) * (fulltext normalized score)
    ```

    If a node appears in only one index, only one term contributes to the sum, so the
    missing score is treated as 0.

    The branches are combined with ``UNION ALL``. A plain ``UNION`` removes duplicate
    rows, so a node whose weighted score happens to be identical in both branches
    (for example alpha = 0.5 and the node is the top hit of both indexes) would keep
    only one of its two contributions and be under-ranked.

    Args:
        neo4j_version_is_5_23_or_above (bool): Whether the Neo4j version is 5.23 or above; determines the call syntax
            (``CALL () {`` when True, ``CALL {`` otherwise).
        alpha (float): Weight for the vector index normalized score. The full-text score is weighted as (1 - alpha).
            The value is not interpolated into the returned text: the query refers to the ``$alpha``
            Cypher parameter, in the same way it refers to ``$top_k``.

    Returns:
        str: The constructed Cypher query string.
    """
    call_prefix = "CALL () { " if neo4j_version_is_5_23_or_above else "CALL { "

    query_body = (
        # Vector branch: normalise by the branch maximum, weight by alpha.
        f"{NODE_VECTOR_INDEX_QUERY} "
        "WITH collect({node: node, score: score}) AS nodes, max(score) AS vector_index_max_score "
        "UNWIND nodes AS n "
        "WITH n.node AS node, (n.score / vector_index_max_score) AS rawScore "
        "RETURN node, rawScore * $alpha AS score "
        # UNION ALL keeps both rows when a node has the same weighted score in
        # each branch; UNION would deduplicate them and halve the summed score.
        "UNION ALL "
        # Full-text branch: normalise by the branch maximum, weight by (1 - alpha).
        f"{FULL_TEXT_SEARCH_QUERY} "
        "WITH collect({node: node, score: score}) AS nodes, max(score) AS ft_index_max_score "
        "UNWIND nodes AS n "
        "WITH n.node AS node, (n.score / ft_index_max_score) AS rawScore "
        "RETURN node, rawScore * (1 - $alpha) AS score } "
        # Outer query: add up the contributions of each node and keep the best top_k.
        "WITH node, sum(score) AS score ORDER BY score DESC LIMIT $top_k"
    )
    return call_prefix + query_body


def _get_filtered_vector_query(
filters: dict[str, Any],
node_label: str,
Expand Down Expand Up @@ -223,6 +263,8 @@ def get_search_query(
filters: Optional[dict[str, Any]] = None,
neo4j_version_is_5_23_or_above: bool = False,
use_parallel_runtime: bool = False,
ranker: Union[str, HybridSearchRanker] = HybridSearchRanker.NAIVE,
alpha: Optional[float] = None,
) -> tuple[str, dict[str, Any]]:
"""
Constructs a search query for vector or hybrid search, including optional pre-filtering
Expand All @@ -243,6 +285,8 @@ def get_search_query(
neo4j_version_is_5_23_or_above (Optional[bool]): Whether the Neo4j version is 5.23 or above.
use_parallel_runtime (bool): Whether or not use the parallel runtime to run the query.
Defaults to False.
ranker (HybridSearchRanker): Type of ranker to order the results from retrieval.
alpha (Optional[float]): Weight for the vector score when using the linear ranker. Only used when ranker is 'linear'. Defaults to 0.5 if not provided.

Returns:
tuple[str, dict[str, Any]]: A tuple containing the constructed query string and
Expand All @@ -262,7 +306,14 @@ def get_search_query(
if search_type == SearchType.HYBRID:
if filters:
raise Exception("Filters are not supported with hybrid search")
query = _get_hybrid_query(neo4j_version_is_5_23_or_above)
if ranker == HybridSearchRanker.NAIVE:
query = _get_hybrid_query(neo4j_version_is_5_23_or_above)
elif ranker == HybridSearchRanker.LINEAR and alpha:
query = _get_hybrid_query_linear(
neo4j_version_is_5_23_or_above, alpha=alpha
)
else:
raise InvalidHybridSearchRankerError()
params: dict[str, Any] = {}
elif search_type == SearchType.VECTOR:
if filters:
Expand Down
33 changes: 30 additions & 3 deletions src/neo4j_graphrag/retrievers/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import copy
import logging
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, Union

import neo4j
from pydantic import ValidationError
Expand All @@ -39,6 +39,7 @@
RawSearchResult,
RetrieverResultItem,
SearchType,
HybridSearchRanker,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -142,6 +143,8 @@ def get_search_results(
query_vector: Optional[list[float]] = None,
top_k: int = 5,
effective_search_ratio: int = 1,
ranker: Union[str, HybridSearchRanker] = HybridSearchRanker.NAIVE,
alpha: Optional[float] = None,
) -> RawSearchResult:
"""Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
Both query_vector and query_text can be provided.
Expand All @@ -162,6 +165,10 @@ def get_search_results(
top_k (int, optional): The number of neighbors to return. Defaults to 5.
effective_search_ratio (int): Controls the candidate pool size for the vector index by multiplying top_k to balance query
accuracy and performance. Defaults to 1.
ranker (str, HybridSearchRanker): Type of ranker to order the results from retrieval.
alpha (Optional[float]): Weight for the vector score when using the linear ranker.
The fulltext index score is multiplied by (1 - alpha).
**Required** when using the linear ranker; must be between 0 and 1.

Raises:
SearchValidationError: If validation of the input arguments fail.
Expand All @@ -176,6 +183,8 @@ def get_search_results(
query_text=query_text,
top_k=top_k,
effective_search_ratio=effective_search_ratio,
ranker=ranker,
alpha=alpha,
)
except ValidationError as e:
raise SearchValidationError(e.errors()) from e
Expand All @@ -191,13 +200,18 @@ def get_search_results(
)
query_vector = self.embedder.embed_query(query_text)
parameters["query_vector"] = query_vector

search_query, _ = get_search_query(
search_type=SearchType.HYBRID,
return_properties=self.return_properties,
embedding_node_property=self._embedding_node_property,
neo4j_version_is_5_23_or_above=self.neo4j_version_is_5_23_or_above,
ranker=validated_data.ranker,
alpha=validated_data.alpha,
)

if "ranker" in parameters:
del parameters["ranker"]

sanitized_parameters = copy.deepcopy(parameters)
if "query_vector" in sanitized_parameters:
sanitized_parameters["query_vector"] = "..."
Expand Down Expand Up @@ -301,6 +315,8 @@ def get_search_results(
top_k: int = 5,
effective_search_ratio: int = 1,
query_params: Optional[dict[str, Any]] = None,
ranker: Union[str, HybridSearchRanker] = HybridSearchRanker.NAIVE,
alpha: Optional[float] = None,
) -> RawSearchResult:
"""Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
Both query_vector and query_text can be provided.
Expand All @@ -320,7 +336,10 @@ def get_search_results(
effective_search_ratio (int): Controls the candidate pool size for the vector index by multiplying top_k to balance query
accuracy and performance. Defaults to 1.
query_params (Optional[dict[str, Any]]): Parameters for the Cypher query. Defaults to None.

ranker (str, HybridSearchRanker): Type of ranker to order the results from retrieval.
alpha (Optional[float]): Weight for the vector score when using the linear ranker.
The fulltext index score is multiplied by (1 - alpha).
**Required** when using the linear ranker; must be between 0 and 1.
Raises:
SearchValidationError: If validation of the input arguments fail.
EmbeddingRequiredError: If no embedder is provided.
Expand All @@ -334,6 +353,8 @@ def get_search_results(
query_text=query_text,
top_k=top_k,
effective_search_ratio=effective_search_ratio,
ranker=ranker,
alpha=alpha,
query_params=query_params,
)
except ValidationError as e:
Expand Down Expand Up @@ -361,7 +382,13 @@ def get_search_results(
search_type=SearchType.HYBRID,
retrieval_query=self.retrieval_query,
neo4j_version_is_5_23_or_above=self.neo4j_version_is_5_23_or_above,
ranker=validated_data.ranker,
alpha=validated_data.alpha,
)

if "ranker" in parameters:
del parameters["ranker"]

sanitized_parameters = copy.deepcopy(parameters)
if "query_vector" in sanitized_parameters:
sanitized_parameters["query_vector"] = "..."
Expand Down
44 changes: 44 additions & 0 deletions src/neo4j_graphrag/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.
from __future__ import annotations

import warnings
from enum import Enum
from typing import Any, Callable, Literal, Optional, TypedDict, Union

Expand All @@ -25,6 +26,7 @@
field_validator,
model_validator,
)
from typing_extensions import Self

from neo4j_graphrag.utils.validation import validate_search_query_input

Expand Down Expand Up @@ -137,11 +139,53 @@ class VectorCypherSearchModel(VectorSearchModel):
query_params: Optional[dict[str, Any]] = None


class HybridSearchRanker(Enum):
    """Closed set of ranker choices for Hybrid search, each identified by a lowercase string."""

    # Selected with the string "naive".
    NAIVE = "naive"
    # Selected with the string "linear".
    LINEAR = "linear"


class HybridSearchModel(BaseModel):
    """Validated arguments for a hybrid search.

    ``ranker`` accepts either a ``HybridSearchRanker`` member or its string value
    (case-insensitive) and is always stored as a ``HybridSearchRanker``.
    ``alpha`` is mandatory and limited to [0, 1] for the linear ranker; for any
    other ranker a supplied ``alpha`` is discarded with a ``UserWarning``.
    """

    query_text: str
    query_vector: Optional[list[float]] = None
    top_k: PositiveInt = 5
    effective_search_ratio: PositiveInt = 1
    ranker: Union[str, HybridSearchRanker] = HybridSearchRanker.NAIVE
    alpha: Optional[float] = None

    @field_validator("ranker", mode="before")
    @classmethod
    def validate_ranker(cls, v: Any) -> HybridSearchRanker:
        """Coerce the raw ``ranker`` input to a ``HybridSearchRanker`` member.

        Args:
            v (Any): Raw input; runs before pydantic's own type validation, so it
                can be of any type.

        Returns:
            HybridSearchRanker: The member itself, or the member whose value equals
            the lower-cased string.

        Raises:
            ValueError: If ``v`` is a string that matches no ranker, or is neither a
                string nor a ``HybridSearchRanker``.
        """
        if isinstance(v, HybridSearchRanker):
            return v

        allowed = ", ".join(r.value for r in HybridSearchRanker)
        if not isinstance(v, str):
            raise ValueError(f"Invalid ranker type. Allowed values are: {allowed}.")

        try:
            return HybridSearchRanker(v.lower())
        except ValueError as err:
            # Replace the enum's lookup error with one listing the accepted values.
            raise ValueError(
                f"Invalid ranker value. Allowed values are: {allowed}."
            ) from err

    @model_validator(mode="after")
    def validate_alpha(self) -> Self:
        """Check ``alpha`` against the selected ranker.

        Returns:
            Self: The model, with ``alpha`` reset to None when the ranker is not linear.

        Raises:
            ValueError: If the ranker is linear and ``alpha`` is missing or outside [0, 1].
        """
        if self.ranker == HybridSearchRanker.LINEAR:
            if self.alpha is None:
                raise ValueError("alpha must be provided when using the linear ranker")
            # Written as a negated chained comparison so that NaN is rejected too.
            if not (0.0 <= self.alpha <= 1.0):
                raise ValueError("alpha must be between 0 and 1")
        elif self.alpha is not None:
            warnings.warn(
                "alpha parameter is only used when ranker is 'linear'. Ignoring alpha.",
                UserWarning,
            )
            self.alpha = None
        return self


class HybridCypherSearchModel(HybridSearchModel):
Expand Down
24 changes: 24 additions & 0 deletions tests/e2e/test_hybrid_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,3 +176,27 @@ def test_hybrid_retriever_return_properties(driver: Driver) -> None:
assert len(results.items) == 5
for result in results.items:
assert isinstance(result, RetrieverResultItem)


@pytest.mark.usefixtures("setup_neo4j_for_retrieval")
def test_hybrid_retriever_search_text_linear_ranker(
    driver: Driver, random_embedder: Embedder
) -> None:
    """A text search using the linear ranker (alpha=0.9) yields five result items."""
    hybrid_retriever = HybridRetriever(
        driver, "vector-index-name", "fulltext-index-name", random_embedder
    )

    search_output = hybrid_retriever.search(
        query_text="Find me a book about Fremen",
        top_k=5,
        effective_search_ratio=2,
        ranker="linear",
        alpha=0.9,
    )

    assert isinstance(search_output, RetrieverResult)
    assert len(search_output.items) == 5
    assert all(
        isinstance(item, RetrieverResultItem) for item in search_output.items
    )
Loading