Skip to content

feat: Include vectorized text in search queries #953

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
from typing import List
from urllib.parse import urljoin
from azure.identity import DefaultAzureCredential, get_bearer_token_provider

Expand All @@ -15,6 +14,7 @@ class AzureComputerVisionClient:

__TOKEN_SCOPE = "https://cognitiveservices.azure.com/.default"
__VECTORIZE_IMAGE_PATH = "computervision/retrieval:vectorizeImage"
__VECTORIZE_TEXT_PATH = "computervision/retrieval:vectorizeText"
__RESPONSE_VECTOR_KEY = "vector"

def __init__(self, env_helper: EnvHelper) -> None:
Expand All @@ -27,15 +27,29 @@ def __init__(self, env_helper: EnvHelper) -> None:
env_helper.AZURE_COMPUTER_VISION_VECTORIZE_IMAGE_MODEL_VERSION
)

def vectorize_image(self, image_url: str) -> list[float]:
    """Vectorize the image at *image_url* using Azure Computer Vision.

    Sends the URL to the retrieval:vectorizeImage endpoint and returns
    the embedding vector from the response.

    Raises:
        Exception: if the request fails, returns a non-200 status,
            returns malformed JSON, or contains no vector.
    """
    logger.info(f"Making call to computer vision to vectorize image: {image_url}")
    response = self.__make_request(
        self.__VECTORIZE_IMAGE_PATH,
        body={"url": image_url},
    )
    self.__validate_response(response)

    response_json = self.__get_json_body(response)
    return self.__get_vectors(response_json)

def vectorize_text(self, text: str) -> list[float]:
    """Return the Azure Computer Vision embedding vector for *text*."""
    logger.debug(f"Making call to computer vision to vectorize text: {text}")
    response = self.__make_request(self.__VECTORIZE_TEXT_PATH, body={"text": text})
    self.__validate_response(response)
    parsed_body = self.__get_json_body(response)
    return self.__get_vectors(parsed_body)

def __make_request(self, image_url: str) -> Response:
def __make_request(self, path: str, body) -> Response:
try:
headers = {}
if self.use_keys:
Expand All @@ -47,36 +61,36 @@ def __make_request(self, image_url: str) -> Response:
headers["Authorization"] = "Bearer " + token_provider()

return requests.post(
url=urljoin(self.host, self.__VECTORIZE_IMAGE_PATH),
url=urljoin(self.host, path),
params={
"api-version": self.api_version,
"model-version": self.model_version,
},
json={"url": image_url},
json=body,
headers=headers,
timeout=self.timeout,
)
except Exception as e:
raise Exception(f"Call to vectorize image failed: {image_url}") from e
raise Exception("Call to Azure Computer Vision failed") from e

def __validate_response(self, response: Response):
    """Raise a descriptive error when the call did not return HTTP 200."""
    if response.status_code != 200:
        raise Exception(
            f"Call to Azure Computer Vision failed with status: {response.status_code}, body: {response.text}"
        )

def __get_json_body(self, response: Response) -> dict:
    """Parse the response body as JSON, raising a descriptive error on failure."""
    try:
        return response.json()
    except Exception as e:
        raise Exception(
            f"Call to Azure Computer Vision returned malformed response body: {response.text}",
        ) from e

def __get_vectors(self, response_json: dict) -> list[float]:
    """Extract the embedding vector from the parsed response body.

    Raises:
        Exception: if the response JSON has no "vector" key.
    """
    if self.__RESPONSE_VECTOR_KEY in response_json:
        return response_json[self.__RESPONSE_VECTOR_KEY]
    else:
        raise Exception(
            f"Call to Azure Computer Vision returned no vector: {response_json}"
        )
14 changes: 13 additions & 1 deletion code/backend/batch/utilities/helpers/azure_search_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
VectorSearchAlgorithmMetric,
VectorSearchProfile,
)

from ..helpers.azure_computer_vision_client import AzureComputerVisionClient
from .llm_helper import LLMHelper
from .env_helper import EnvHelper

Expand All @@ -32,6 +34,7 @@

class AzureSearchHelper:
_search_dimension: int | None = None
_image_search_dimension: int | None = None

def __init__(self):
self.llm_helper = LLMHelper()
Expand All @@ -40,6 +43,7 @@ def __init__(self):
search_credential = self._search_credential()
self.search_client = self._create_search_client(search_credential)
self.search_index_client = self._create_search_index_client(search_credential)
self.azure_computer_vision_client = AzureComputerVisionClient(self.env_helper)

def _search_credential(self):
if self.env_helper.is_auth_type_keys():
Expand Down Expand Up @@ -75,6 +79,14 @@ def search_dimensions(self) -> int:
)
return AzureSearchHelper._search_dimension

@property
def image_search_dimensions(self) -> int:
    """Dimensionality of image embeddings, probed once and cached on the class."""
    if AzureSearchHelper._image_search_dimension is None:
        probe_vector = self.azure_computer_vision_client.vectorize_text("Text")
        AzureSearchHelper._image_search_dimension = len(probe_vector)
    return AzureSearchHelper._image_search_dimension

def create_index(self):
fields = [
SimpleField(
Expand Down Expand Up @@ -128,7 +140,7 @@ def create_index(self):
name="image_vector",
type=SearchFieldDataType.Collection(SearchFieldDataType.Single),
searchable=True,
vector_search_dimensions=1024,
vector_search_dimensions=self.image_search_dimensions,
vector_search_profile_name="myHnswProfile",
),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,6 @@ def __embed(
and file_extension
in self.config.get_advanced_image_processing_image_types()
):
logger.warning("Advanced image processing is not supported yet")

caption = self.__generate_image_caption(source_url)
caption_vector = self.llm_helper.generate_embeddings(caption)

Expand Down
60 changes: 52 additions & 8 deletions code/backend/batch/utilities/search/azure_search_handler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import List
from .search_handler_base import SearchHandlerBase
from ..helpers.llm_helper import LLMHelper
from ..helpers.azure_computer_vision_client import AzureComputerVisionClient
from ..helpers.azure_search_helper import AzureSearchHelper
from ..common.source_document import SourceDocument
import json
Expand All @@ -9,13 +10,12 @@


class AzureSearchHandler(SearchHandlerBase):

_ENCODER_NAME = "cl100k_base"
_VECTOR_FIELD = "content_vector"

def __init__(self, env_helper):
super().__init__(env_helper)
self.llm_helper = LLMHelper()
self.azure_computer_vision_client = AzureComputerVisionClient(env_helper)

def create_search_client(self):
return AzureSearchHelper().get_search_client()
Expand Down Expand Up @@ -66,22 +66,50 @@ def delete_files(self, files):
def query_search(self, question) -> List[SourceDocument]:
    """Search the index for *question* and return matching source documents.

    Tokenises the question for the text embedding, additionally vectorizes
    it with Azure Computer Vision when advanced image processing is enabled,
    then dispatches to semantic or hybrid search per configuration.
    """
    encoding = tiktoken.get_encoding(self._ENCODER_NAME)
    tokenised_question = encoding.encode(question)

    if self.env_helper.USE_ADVANCED_IMAGE_PROCESSING:
        vectorized_question = self.azure_computer_vision_client.vectorize_text(
            question
        )
    else:
        # No image-vector query is added when the feature is disabled.
        vectorized_question = None

    if self.env_helper.AZURE_SEARCH_USE_SEMANTIC_SEARCH:
        results = self._semantic_search(
            question, tokenised_question, vectorized_question
        )
    else:
        results = self._hybrid_search(
            question, tokenised_question, vectorized_question
        )

    return self._convert_to_source_documents(results)

def _semantic_search(self, question: str, tokenised_question: list[int]):
def _semantic_search(
self,
question: str,
tokenised_question: list[int],
vectorized_question: list[float] | None,
):
return self.search_client.search(
search_text=question,
vector_queries=[
VectorizedQuery(
vector=self.llm_helper.generate_embeddings(tokenised_question),
k_nearest_neighbors=self.env_helper.AZURE_SEARCH_TOP_K,
fields=self._VECTOR_FIELD,
)
),
*(
[
VectorizedQuery(
vector=vectorized_question,
k_nearest_neighbors=self.env_helper.AZURE_SEARCH_TOP_K,
fields=self._IMAGE_VECTOR_FIELD,
)
]
if vectorized_question is not None
else []
),
],
filter=self.env_helper.AZURE_SEARCH_FILTER,
query_type="semantic",
Expand All @@ -91,7 +119,12 @@ def _semantic_search(self, question: str, tokenised_question: list[int]):
top=self.env_helper.AZURE_SEARCH_TOP_K,
)

def _hybrid_search(self, question: str, tokenised_question: list[int]):
def _hybrid_search(
self,
question: str,
tokenised_question: list[int],
vectorized_question: list[float] | None,
):
return self.search_client.search(
search_text=question,
vector_queries=[
Expand All @@ -100,7 +133,18 @@ def _hybrid_search(self, question: str, tokenised_question: list[int]):
k_nearest_neighbors=self.env_helper.AZURE_SEARCH_TOP_K,
filter=self.env_helper.AZURE_SEARCH_FILTER,
fields=self._VECTOR_FIELD,
)
),
*(
[
VectorizedQuery(
vector=vectorized_question,
k_nearest_neighbors=self.env_helper.AZURE_SEARCH_TOP_K,
fields=self._IMAGE_VECTOR_FIELD,
)
]
if vectorized_question is not None
else []
),
],
query_type="simple", # this is the default value
filter=self.env_helper.AZURE_SEARCH_FILTER,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def _hybrid_search(self, question: str):
vector_query = VectorizableTextQuery(
text=question,
k_nearest_neighbors=self.env_helper.AZURE_SEARCH_TOP_K,
fields="content_vector",
fields=self._VECTOR_FIELD,
exhaustive=True,
)
return self.search_client.search(
Expand All @@ -94,7 +94,7 @@ def _semantic_search(self, question: str):
vector_query = VectorizableTextQuery(
text=question,
k_nearest_neighbors=self.env_helper.AZURE_SEARCH_TOP_K,
fields="content_vector",
fields=self._VECTOR_FIELD,
exhaustive=True,
)
return self.search_client.search(
Expand Down
7 changes: 5 additions & 2 deletions code/backend/batch/utilities/search/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,21 @@
from ..search.integrated_vectorization_search_handler import (
IntegratedVectorizationSearchHandler,
)
from ..search.search_handler_base import SearchHandlerBase
from ..common.source_document import SourceDocument
from ..helpers.env_helper import EnvHelper


class Search:
    """Factory and facade over the configured search handler."""

    @staticmethod
    def get_search_handler(env_helper: EnvHelper) -> SearchHandlerBase:
        """Return the handler matching the integrated-vectorization setting."""
        if env_helper.AZURE_SEARCH_USE_INTEGRATED_VECTORIZATION:
            return IntegratedVectorizationSearchHandler(env_helper)
        else:
            return AzureSearchHandler(env_helper)

    @staticmethod
    def get_source_documents(
        search_handler: SearchHandlerBase, question: str
    ) -> list[SourceDocument]:
        """Delegate the query to the given handler."""
        return search_handler.query_search(question)
7 changes: 5 additions & 2 deletions code/backend/batch/utilities/search/search_handler_base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from abc import ABC, abstractmethod
from ..helpers.env_helper import EnvHelper

from ..common.source_document import SourceDocument
from azure.search.documents import SearchClient


class SearchHandlerBase(ABC):
_VECTOR_FIELD = "content_vector"
_IMAGE_VECTOR_FIELD = "image_vector"

def __init__(self, env_helper: EnvHelper):
self.env_helper = env_helper
self.search_client = self.create_search_client()
Expand All @@ -20,7 +23,7 @@ def get_unique_files(self, results, facet_key: str):
return []

@abstractmethod
def create_search_client(self):
def create_search_client(self) -> SearchClient:
pass

@abstractmethod
Expand Down
2 changes: 2 additions & 0 deletions code/tests/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@

# Paths and HTTP methods used to stub the Azure Computer Vision
# vectorize-image / vectorize-text endpoints in the functional-test HTTP server.
COMPUTER_VISION_VECTORIZE_IMAGE_PATH = "/computervision/retrieval:vectorizeImage"
COMPUTER_VISION_VECTORIZE_IMAGE_REQUEST_METHOD = "POST"
COMPUTER_VISION_VECTORIZE_TEXT_PATH = "/computervision/retrieval:vectorizeText"
COMPUTER_VISION_VECTORIZE_TEXT_REQUEST_METHOD = "POST"
7 changes: 7 additions & 0 deletions code/tests/functional/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
AZURE_STORAGE_CONFIG_FILE_NAME,
COMPUTER_VISION_VECTORIZE_IMAGE_PATH,
COMPUTER_VISION_VECTORIZE_IMAGE_REQUEST_METHOD,
COMPUTER_VISION_VECTORIZE_TEXT_PATH,
COMPUTER_VISION_VECTORIZE_TEXT_REQUEST_METHOD,
)


Expand Down Expand Up @@ -128,6 +130,11 @@ def setup_default_mocking(httpserver: HTTPServer, app_config: AppConfig):
COMPUTER_VISION_VECTORIZE_IMAGE_REQUEST_METHOD,
).respond_with_json({"modelVersion": "2022-04-11", "vector": [1.0, 2.0, 3.0]})

httpserver.expect_request(
COMPUTER_VISION_VECTORIZE_TEXT_PATH,
COMPUTER_VISION_VECTORIZE_TEXT_REQUEST_METHOD,
).respond_with_json({"modelVersion": "2022-04-11", "vector": [1.0, 2.0, 3.0]})

httpserver.expect_request(
f"/indexes('{app_config.get('AZURE_SEARCH_INDEX')}')/docs/search.index",
method="POST",
Expand Down
2 changes: 2 additions & 0 deletions code/tests/functional/tests/backend_api/default/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def app_config(make_httpserver, ca):
"AZURE_CONTENT_SAFETY_ENDPOINT": f"https://localhost:{make_httpserver.port}/",
"AZURE_SPEECH_REGION_ENDPOINT": f"https://localhost:{make_httpserver.port}/",
"AZURE_STORAGE_ACCOUNT_ENDPOINT": f"https://localhost:{make_httpserver.port}/",
"AZURE_COMPUTER_VISION_ENDPOINT": f"https://localhost:{make_httpserver.port}/",
"USE_ADVANCED_IMAGE_PROCESSING": "True",
"SSL_CERT_FILE": ca_temp_path,
"CURL_CA_BUNDLE": ca_temp_path,
}
Expand Down
Loading
Loading