First commit for google rerank as postprocessor #18441

Open · wants to merge 9 commits into base: main
Changes from 3 commits
@@ -0,0 +1 @@
poetry_requirements(name="poetry", module_mapping={"google-genai": ["google"]})
@@ -0,0 +1,6 @@
python_sources()

resource(
name="py_typed",
source="py.typed",
)
@@ -0,0 +1,3 @@
from llama_index.postprocessor.google_rerank.base import GoogleRerank

__all__ = ["GoogleRerank"]
@@ -0,0 +1,153 @@
import os
import json
from enum import Enum
from typing import Any, List, Optional, TypedDict

from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.callbacks import CBEventType, EventPayload
from llama_index.core.instrumentation import get_dispatcher
from llama_index.core.instrumentation.events.rerank import (
ReRankEndEvent,
ReRankStartEvent,
)
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.schema import MetadataMode, NodeWithScore, QueryBundle

import google.auth
from google.cloud import discoveryengine_v1 as discoveryengine
from google.oauth2 import service_account


dispatcher = get_dispatcher(__name__)

class VertexAIConfig(TypedDict, total=False):
    # Optional Vertex AI connection settings; every key may be omitted.
    credentials: Optional[google.auth.credentials.Credentials]
    project: Optional[str]
    location: Optional[str]


class Models(str, Enum):
SEMANTIC_RERANK_512_003 = "semantic-ranker-512-003"


class GoogleRerank(BaseNodePostprocessor):
    """Re-rank nodes with Google's semantic ranker via the Discovery Engine RankService."""

    top_n: int = Field(default=2, description="Top N nodes to return.")
    rerank_model_name: str = Field(
        default=Models.SEMANTIC_RERANK_512_003.value,
        description="The ID of the Vertex AI ranking model to use.",
    )
_client: discoveryengine.RankServiceClient = PrivateAttr()
_ranking_config: str = PrivateAttr()

def __init__(
self,
top_n: int = 2,
rerank_model_name: str = Models.SEMANTIC_RERANK_512_003.value,
client: Optional[discoveryengine.RankServiceClient] = None,
vertexai_config: Optional[VertexAIConfig] = None,
        ranking_config: str = "default_ranking_config",
**kwargs: Any,
):
super().__init__(**kwargs)
self.top_n = top_n
self.rerank_model_name = rerank_model_name

        project = (vertexai_config or {}).get("project") or os.getenv("GOOGLE_CLOUD_PROJECT")
        location = (vertexai_config or {}).get("location") or os.getenv("GOOGLE_CLOUD_LOCATION")

        # Credentials may be passed directly as a google.auth Credentials object;
        # otherwise GOOGLE_CLOUD_CREDENTIALS is expected to hold service-account JSON.
        credentials = (vertexai_config or {}).get("credentials")
        if credentials is None:
            credentials_json = os.getenv("GOOGLE_CLOUD_CREDENTIALS")
            if credentials_json:
                credentials = service_account.Credentials.from_service_account_info(
                    json.loads(credentials_json)
                )

        if client is not None:
            self._client = client
        else:
            # Fall back to the credentials resolved above, or to Application
            # Default Credentials when none were provided at all.
            self._client = discoveryengine.RankServiceClient(credentials=credentials)

        self._ranking_config = self._client.ranking_config_path(
            project=project,
            location=location,
            ranking_config=ranking_config,
        )

@classmethod
def class_name(cls) -> str:
return "GoogleRerank"

def _postprocess_nodes(
self,
nodes: List[NodeWithScore],
query_bundle: Optional[QueryBundle] = None,
) -> List[NodeWithScore]:
if dispatcher:
dispatcher.event(
ReRankStartEvent(
query=query_bundle,
nodes=nodes,
top_n=self.top_n,
model_name=self.rerank_model_name,
)
)

if query_bundle is None:
raise ValueError("Missing query bundle in extra info.")
if len(nodes) == 0:
return []

with self.callback_manager.event(
CBEventType.RERANKING,
payload={
EventPayload.NODES: nodes,
EventPayload.MODEL_NAME: self.rerank_model_name,
EventPayload.QUERY_STR: query_bundle.query_str,
EventPayload.TOP_K: self.top_n,
},
) as event:

# Prepare the text sources for Google Reranker
text_sources = []
for index, node in enumerate(nodes):
text_sources.append(
discoveryengine.RankingRecord(
id=str(index),
content=node.node.get_content(metadata_mode=MetadataMode.EMBED),
),
)
            # Never request more results than there are candidate nodes.
            top_n = min(self.top_n, len(nodes))

            try:
                request = discoveryengine.RankRequest(
                    ranking_config=self._ranking_config,
                    model=self.rerank_model_name,
                    top_n=top_n,
                    query=query_bundle.query_str,
                    records=text_sources,
                )
                response = self._client.rank(request=request)
                results = response.records
            except Exception as e:
                raise RuntimeError(f"Failed to invoke VertexAI model: {e}") from e

            new_nodes = []
            for result in results:
                index = int(result.id)
                relevance_score = result.score
new_node_with_score = NodeWithScore(
node=nodes[index].node,
score=relevance_score,
)
new_nodes.append(new_node_with_score)

event.on_end(payload={EventPayload.NODES: new_nodes})

dispatcher.event(ReRankEndEvent(nodes=new_nodes))
return new_nodes
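For reference, a minimal usage sketch of the new postprocessor (not part of the diff; it assumes Application Default Credentials are configured, and "my-gcp-project" / "global" are placeholders for a real project and location with the ranking API enabled):

from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode
from llama_index.postprocessor.google_rerank import GoogleRerank

# Placeholder project/location; replace with real values.
reranker = GoogleRerank(
    top_n=2,
    vertexai_config={"project": "my-gcp-project", "location": "global"},
)

nodes = [
    NodeWithScore(node=TextNode(text="Paris is the capital of France.")),
    NodeWithScore(node=TextNode(text="The Eiffel Tower is in Paris.")),
    NodeWithScore(node=TextNode(text="Bananas are rich in potassium.")),
]

reranked = reranker.postprocess_nodes(
    nodes, query_bundle=QueryBundle(query_str="Where is the Eiffel Tower?")
)
for n in reranked:
    print(round(n.score, 3), n.node.get_content())

The same object can also be passed to a query engine via node_postprocessors=[reranker], as with the other LlamaIndex rerankers.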
@@ -0,0 +1,3 @@
python_tests(
interpreter_constraints=[">=3.10"],
)
@@ -0,0 +1,72 @@
from unittest import TestCase, mock

from google.auth.credentials import AnonymousCredentials
from google.cloud import discoveryengine_v1 as discoveryengine

from llama_index.core.postprocessor.types import (
BaseNodePostprocessor,
NodeWithScore,
QueryBundle,
)
from llama_index.core.schema import TextNode
from llama_index.postprocessor.google_rerank import GoogleRerank


class TestGoogleRerank(TestCase):
def test_class(self):
names_of_base_classes = [b.__name__ for b in GoogleRerank.__mro__]
self.assertIn(BaseNodePostprocessor.__name__, names_of_base_classes)

    def test_google_rerank(self):
        exp_rerank_response = discoveryengine.RankResponse(
            records=[
                discoveryengine.RankingRecord(id="2", score=0.9),
                discoveryengine.RankingRecord(id="3", score=0.8),
            ]
        )

input_nodes = [
NodeWithScore(node=TextNode(id_="1", text="first 1")),
NodeWithScore(node=TextNode(id_="2", text="first 2")),
NodeWithScore(node=TextNode(id_="3", text="last 1")),
NodeWithScore(node=TextNode(id_="4", text="last 2")),
]

expected_nodes = [
NodeWithScore(node=TextNode(id_="3", text="last 1"), score=0.9),
NodeWithScore(node=TextNode(id_="4", text="last 2"), score=0.8),
]

        # Use anonymous credentials so the test needs no real GCP project or
        # network access; the rank() call itself is mocked below.
        reranker_client = discoveryengine.RankServiceClient(
            credentials=AnonymousCredentials()
        )
        reranker = GoogleRerank(client=reranker_client, top_n=2)

        with mock.patch.object(
            reranker_client, "rank", return_value=exp_rerank_response
        ):
query_bundle = QueryBundle(query_str="last")

actual_nodes = reranker.postprocess_nodes(
input_nodes, query_bundle=query_bundle
)

self.assertEqual(len(actual_nodes), len(expected_nodes))
for actual_node_with_score, expected_node_with_score in zip(
actual_nodes, expected_nodes
):
self.assertEqual(
actual_node_with_score.node.get_content(),
expected_node_with_score.node.get_content(),
)
self.assertAlmostEqual(
actual_node_with_score.score, expected_node_with_score.score
)
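For context, the real request that the mocked rank() call stands in for looks roughly like this (a sketch assuming Application Default Credentials; the project and location values are placeholders):

from google.cloud import discoveryengine_v1 as discoveryengine

client = discoveryengine.RankServiceClient()
request = discoveryengine.RankRequest(
    ranking_config=client.ranking_config_path(
        project="my-gcp-project",
        location="global",
        ranking_config="default_ranking_config",
    ),
    model="semantic-ranker-512-003",
    top_n=2,
    query="last",
    records=[
        discoveryengine.RankingRecord(id="0", content="first 1"),
        discoveryengine.RankingRecord(id="2", content="last 1"),
    ],
)
response = client.rank(request=request)
for record in response.records:
    print(record.id, record.score)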