def _get_text_embeddings(self, texts: List[str]) -> List[Embedding]:
    """Embed a batch of passages with a single request to the CAII endpoint.

    All texts are sent in one POST body (``input_type`` "passage",
    ``truncate`` "END") and one embedding per text is returned.

    Args:
        texts: The passages to embed.

    Returns:
        One embedding (a list of floats) per entry of ``texts``, in the order
        the service lists them in the response's ``data`` array.
        NOTE(review): this assumes the service returns ``data`` in request
        order -- confirm against the endpoint's API contract.

    Raises:
        KeyError: If ``CAII_DOMAIN`` is not set, or the endpoint metadata or
            the response lacks an expected key.
        AssertionError: If the response does not contain exactly one
            list-of-floats embedding per input text.
    """
    if not texts:
        # Nothing to embed; skip the network round-trip entirely.
        return []

    model = self.endpoint["endpointmetadata"]["model_name"]
    domain = os.environ["CAII_DOMAIN"]

    headers = self.build_auth_headers()
    headers["Content-Type"] = "application/json"
    body = json.dumps(
        {
            "input": texts,
            "input_type": "passage",
            "truncate": "END",
            "model": model,
        }
    )

    connection = http_client.HTTPSConnection(domain, 443)
    try:
        connection.request("POST", self.endpoint["url"], body=body, headers=headers)
        res = connection.getresponse()
        data = res.read()
    finally:
        # Always release the socket, even when the request or read fails.
        connection.close()

    structured_response = json.loads(data.decode("utf-8"))

    # Collect the embedding of every returned item. Reading only data[0]
    # would yield a single vector (a flat list of floats) instead of one
    # vector per input text.
    embeddings = [item["embedding"] for item in structured_response["data"]]

    assert len(embeddings) == len(texts)
    assert all(isinstance(x, list) for x in embeddings)
    assert all(all(isinstance(y, float) for y in x) for x in embeddings)

    return embeddings