From b5eec740ec5c987218dc5e730dbd930b193780fc Mon Sep 17 00:00:00 2001 From: Conrado Silva Miranda Date: Wed, 20 Nov 2024 16:06:05 -0800 Subject: [PATCH 1/3] Add ruff and mypy checks to CI --- .github/workflows/pr_build.yml | 10 +++ llm-service/__init__.py | 38 --------- llm-service/app/exceptions.py | 1 - llm-service/app/main.py | 2 +- llm-service/app/routers/index/__init__.py | 4 - .../app/routers/index/amp_update/__init__.py | 22 ++++- .../app/routers/index/data_source/__init__.py | 11 +-- .../app/routers/index/models/__init__.py | 8 +- .../app/routers/index/sessions/__init__.py | 49 +++++++---- .../app/services/CaiiEmbeddingModel.py | 25 +++--- llm-service/app/services/CaiiModel.py | 42 +++++----- llm-service/app/services/amp_update.py | 49 ++++++++--- llm-service/app/services/caii.py | 83 +++++++++++-------- llm-service/app/services/chat.py | 5 +- llm-service/app/services/chat_store.py | 45 ++++++---- llm-service/app/services/doc_summaries.py | 35 ++++---- llm-service/app/services/evaluators.py | 10 +-- llm-service/app/services/llama_utils.py | 32 ++++--- llm-service/app/services/llm_completion.py | 15 ++-- llm-service/app/services/models.py | 10 +-- llm-service/app/services/qdrant.py | 43 ++++++---- .../app/services/rag_qdrant_vector_store.py | 12 +-- llm-service/app/services/rag_vector_store.py | 2 +- llm-service/app/services/utils.py | 4 +- llm-service/app/services/vector_store.py | 8 +- llm-service/app/tests/conftest.py | 72 ++++++++-------- .../tests/routers/index/test_data_source.py | 63 ++++++++------ .../tests/routers/index/test_doc_summaries.py | 46 +++++++--- llm-service/pdm.lock | 70 +++++++++++++--- llm-service/pyproject.toml | 36 +++++--- 30 files changed, 517 insertions(+), 335 deletions(-) delete mode 100644 llm-service/__init__.py diff --git a/.github/workflows/pr_build.yml b/.github/workflows/pr_build.yml index 5496b180e..00f9b2671 100644 --- a/.github/workflows/pr_build.yml +++ b/.github/workflows/pr_build.yml @@ -70,6 +70,16 @@ jobs: pdm install working-directory: llm-service + - name: Run ruff + run: | + pdm run ruff check app + working-directory: llm-service + + - name: Run mypy + run: | + pdm run mypy app + working-directory: llm-service + - name: Test with pytest run: | pdm run pytest -sxvvra diff --git a/llm-service/__init__.py b/llm-service/__init__.py deleted file mode 100644 index 48cd003a9..000000000 --- a/llm-service/__init__.py +++ /dev/null @@ -1,38 +0,0 @@ -# ############################################################################## -# CLOUDERA APPLIED MACHINE LEARNING PROTOTYPE (AMP) -# (C) Cloudera, Inc. 2024 -# All rights reserved. -# -# Applicable Open Source License: Apache 2.0 -# -# NOTE: Cloudera open source products are modular software products -# made up of hundreds of individual components, each of which was -# individually copyrighted. Each Cloudera open source product is a -# collective work under U.S. Copyright Law. Your license to use the -# collective work is as provided in your written agreement with -# Cloudera. Used apart from the collective work, this file is -# licensed for your use pursuant to the open source license -# identified above. -# -# This code is provided to you pursuant a written agreement with -# (i) Cloudera, Inc. or (ii) a third-party authorized to distribute -# this code. If you do not have a written agreement with Cloudera nor -# with an authorized and properly licensed third party, you do not -# have any rights to access nor to use this code. -# -# Absent a written agreement with Cloudera, Inc. 
(“Cloudera”) to the -# contrary, A) CLOUDERA PROVIDES THIS CODE TO YOU WITHOUT WARRANTIES OF ANY -# KIND; (B) CLOUDERA DISCLAIMS ANY AND ALL EXPRESS AND IMPLIED -# WARRANTIES WITH RESPECT TO THIS CODE, INCLUDING BUT NOT LIMITED TO -# IMPLIED WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY AND -# FITNESS FOR A PARTICULAR PURPOSE; (C) CLOUDERA IS NOT LIABLE TO YOU, -# AND WILL NOT DEFEND, INDEMNIFY, NOR HOLD YOU HARMLESS FOR ANY CLAIMS -# ARISING FROM OR RELATED TO THE CODE; AND (D)WITH RESPECT TO YOUR EXERCISE -# OF ANY RIGHTS GRANTED TO YOU FOR THE CODE, CLOUDERA IS NOT LIABLE FOR ANY -# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, PUNITIVE OR -# CONSEQUENTIAL DAMAGES INCLUDING, BUT NOT LIMITED TO, DAMAGES -# RELATED TO LOST REVENUE, LOST PROFITS, LOSS OF INCOME, LOSS OF -# BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF -# DATA. -# ############################################################################## - diff --git a/llm-service/app/exceptions.py b/llm-service/app/exceptions.py index 9d14ed72f..9e17c8daf 100644 --- a/llm-service/app/exceptions.py +++ b/llm-service/app/exceptions.py @@ -119,7 +119,6 @@ def banana(): """ if inspect.iscoroutinefunction(f): - # for coroutines, the wrapper must be declared async, # and the wrapped function's result must be awaited @functools.wraps(f) diff --git a/llm-service/app/main.py b/llm-service/app/main.py index c586d71b1..995ae75d6 100644 --- a/llm-service/app/main.py +++ b/llm-service/app/main.py @@ -115,7 +115,7 @@ def _configure_app_logger(app_logger: logging.Logger) -> None: app_logger.setLevel(settings.rag_log_level) -def initialize_logging(): +def initialize_logging() -> None: logger.info("Initializing logging.") _configure_app_logger(logging.getLogger("uvicorn.access")) diff --git a/llm-service/app/routers/index/__init__.py b/llm-service/app/routers/index/__init__.py index 35607c7e0..4a46ab4ad 100644 --- a/llm-service/app/routers/index/__init__.py +++ b/llm-service/app/routers/index/__init__.py @@ -55,7 +55,3 @@ # include this for legacy UI calls router.include_router(amp_update.router, prefix="/index", deprecated=True) router.include_router(models.router) - - - - diff --git a/llm-service/app/routers/index/amp_update/__init__.py b/llm-service/app/routers/index/amp_update/__init__.py index bcec0f05e..1efffa91d 100644 --- a/llm-service/app/routers/index/amp_update/__init__.py +++ b/llm-service/app/routers/index/amp_update/__init__.py @@ -43,22 +43,36 @@ from .... 
import exceptions from ....services.amp_update import check_amp_update_status -router = APIRouter(prefix="/amp-update" , tags=["AMP Update"]) +router = APIRouter(prefix="/amp-update", tags=["AMP Update"]) + @router.get("", summary="Returns a boolean for whether AMP needs updating.") @exceptions.propagates def amp_up_to_date_status() -> bool: return check_amp_update_status() + @router.post("", summary="Updates AMP.") @exceptions.propagates def update_amp() -> str: - print(subprocess.run(["python /home/cdsw/llm-service/scripts/run_refresh_job.py"], shell=True, check=True)) + print( + subprocess.run( + ["python /home/cdsw/llm-service/scripts/run_refresh_job.py"], + shell=True, + check=True, + ) + ) return "OK" + @router.get("/job-status", summary="Get AMP Status.") @exceptions.propagates def get_amp_status() -> str: - process: CompletedProcess[bytes] = subprocess.run(["python /home/cdsw/llm-service/scripts/get_job_run_status.py"], shell=True, check=True, capture_output=True) - stdout = process.stdout.decode('utf-8') + process: CompletedProcess[bytes] = subprocess.run( + ["python /home/cdsw/llm-service/scripts/get_job_run_status.py"], + shell=True, + check=True, + capture_output=True, + ) + stdout = process.stdout.decode("utf-8") return stdout.strip() diff --git a/llm-service/app/routers/index/data_source/__init__.py b/llm-service/app/routers/index/data_source/__init__.py index aa7e133eb..a3d544bf3 100644 --- a/llm-service/app/routers/index/data_source/__init__.py +++ b/llm-service/app/routers/index/data_source/__init__.py @@ -100,7 +100,6 @@ def delete_document(data_source_id: int, doc_id: str) -> None: doc_summaries.delete_document(data_source_id, doc_id) - class RagIndexDocumentRequest(BaseModel): s3_bucket_name: str s3_document_key: str @@ -116,17 +115,13 @@ class RagIndexDocumentRequest(BaseModel): ) @exceptions.propagates def download_and_index( - data_source_id: int, - request: RagIndexDocumentRequest, + data_source_id: int, + request: RagIndexDocumentRequest, ) -> str: with tempfile.TemporaryDirectory() as tmpdirname: logger.debug("created temporary directory %s", tmpdirname) s3.download(tmpdirname, request.s3_bucket_name, request.s3_document_key) qdrant.download_and_index( - tmpdirname, - data_source_id, - request.configuration, - request.s3_document_key + tmpdirname, data_source_id, request.configuration, request.s3_document_key ) return http.HTTPStatus.OK.phrase - diff --git a/llm-service/app/routers/index/models/__init__.py b/llm-service/app/routers/index/models/__init__.py index 83ff3efa2..7235c9c38 100644 --- a/llm-service/app/routers/index/models/__init__.py +++ b/llm-service/app/routers/index/models/__init__.py @@ -35,7 +35,7 @@ # BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF # DATA. 
# -from typing import Literal +from typing import Any, Dict, List, Literal from fastapi import APIRouter @@ -54,13 +54,13 @@ @router.get("/llm", summary="Get LLM Inference models.") @exceptions.propagates -def get_llm_models() -> list: +def get_llm_models() -> List[Dict[str, Any]]: return get_available_llm_models() @router.get("/embeddings", summary="Get LLM Embedding models.") @exceptions.propagates -def get_llm_embedding_models() -> list: +def get_llm_embedding_models() -> List[Dict[str, Any]]: return get_available_embedding_models() @@ -79,4 +79,4 @@ def llm_model_test(model_name: str) -> Literal["ok"]: @router.get("/embedding/{model_name}/test", summary="Test Embedding model.") @exceptions.propagates def embedding_model_test(model_name: str) -> str: - return test_embedding_model(model_name) \ No newline at end of file + return test_embedding_model(model_name) diff --git a/llm-service/app/routers/index/sessions/__init__.py b/llm-service/app/routers/index/sessions/__init__.py index ca514f8a8..284d94814 100644 --- a/llm-service/app/routers/index/sessions/__init__.py +++ b/llm-service/app/routers/index/sessions/__init__.py @@ -39,27 +39,35 @@ import uuid from fastapi import APIRouter - from pydantic import BaseModel from .... import exceptions +from ....services import llm_completion, qdrant +from ....services.chat import generate_suggested_questions, v2_chat from ....services.chat_store import RagStudioChatMessage, chat_store -from ....services import qdrant, llm_completion -from ....services.chat import (v2_chat, generate_suggested_questions) +from ....services.qdrant import RagPredictConfiguration router = APIRouter(prefix="/sessions/{session_id}", tags=["Sessions"]) -@router.get("/chat-history", summary="Returns an array of chat messages for the provided session.") + +@router.get( + "/chat-history", + summary="Returns an array of chat messages for the provided session.", +) @exceptions.propagates def chat_history(session_id: int) -> list[RagStudioChatMessage]: return chat_store.retrieve_chat_history(session_id=session_id) -@router.delete("/chat-history", summary="Deletes the chat history for the provided session.") + +@router.delete( + "/chat-history", summary="Deletes the chat history for the provided session." +) @exceptions.propagates def clear_chat_history(session_id: int) -> str: chat_store.clear_chat_history(session_id=session_id) return "Chat history cleared." 
+ @router.delete("", summary="Deletes the requested session.") @exceptions.propagates def delete_chat_history(session_id: int) -> str: @@ -70,24 +78,29 @@ def delete_chat_history(session_id: int) -> str: class RagStudioChatRequest(BaseModel): data_source_id: int query: str - configuration: qdrant.RagPredictConfiguration + configuration: RagPredictConfiguration @router.post("/chat", summary="Chat with your documents in the requested datasource") @exceptions.propagates def chat( - session_id: int, - request: RagStudioChatRequest, + session_id: int, + request: RagStudioChatRequest, ) -> RagStudioChatMessage: if request.configuration.exclude_knowledge_base: return llm_talk(session_id, request) - return v2_chat(session_id, request.data_source_id, request.query, request.configuration) + return v2_chat( + session_id, request.data_source_id, request.query, request.configuration + ) + def llm_talk( - session_id: int, - request: RagStudioChatRequest, + session_id: int, + request: RagStudioChatRequest, ) -> RagStudioChatMessage: - chat_response = llm_completion.completion(session_id, request.query, request.configuration) + chat_response = llm_completion.completion( + session_id, request.query, request.configuration + ) new_chat_message = RagStudioChatMessage( id=str(uuid.uuid4()), source_nodes=[], @@ -96,7 +109,7 @@ def llm_talk( "user": request.query, "assistant": chat_response.message.content, }, - timestamp=time.time() + timestamp=time.time(), ) chat_store.append_to_history(session_id, [new_chat_message]) return new_chat_message @@ -104,20 +117,22 @@ def llm_talk( class SuggestQuestionsRequest(BaseModel): data_source_id: int - configuration: qdrant.RagPredictConfiguration = qdrant.RagPredictConfiguration() + configuration: RagPredictConfiguration = RagPredictConfiguration() + class RagSuggestedQuestionsResponse(BaseModel): suggested_questions: list[str] + @router.post("/suggest-questions", summary="Suggest questions with context") @exceptions.propagates def suggest_questions( - session_id: int, - request: SuggestQuestionsRequest, + session_id: int, + request: SuggestQuestionsRequest, ) -> RagSuggestedQuestionsResponse: data_source_size = qdrant.size_of(request.data_source_id) qdrant.check_data_source_exists(data_source_size) suggested_questions = generate_suggested_questions( request.configuration, request.data_source_id, data_source_size, session_id ) - return RagSuggestedQuestionsResponse(suggested_questions=suggested_questions) \ No newline at end of file + return RagSuggestedQuestionsResponse(suggested_questions=suggested_questions) diff --git a/llm-service/app/services/CaiiEmbeddingModel.py b/llm-service/app/services/CaiiEmbeddingModel.py index d0ed2ba8f..ca8019e15 100644 --- a/llm-service/app/services/CaiiEmbeddingModel.py +++ b/llm-service/app/services/CaiiEmbeddingModel.py @@ -63,17 +63,19 @@ def _get_query_embedding(self, query: str) -> Embedding: def _get_embedding(self, query: str, input_type: str) -> Embedding: model = self.endpoint["endpointmetadata"]["model_name"] - domain = os.environ['CAII_DOMAIN'] + domain = os.environ["CAII_DOMAIN"] connection = http_client.HTTPSConnection(domain, 443) headers = self.build_auth_headers() headers["Content-Type"] = "application/json" - body = json.dumps({ - "input": query, - "input_type": input_type, - "truncate": "END", - "model": model - }) + body = json.dumps( + { + "input": query, + "input_type": input_type, + "truncate": "END", + "model": model, + } + ) connection.request("POST", self.endpoint["url"], body=body, headers=headers) res = 
connection.getresponse() data = res.read() @@ -83,12 +85,9 @@ def _get_embedding(self, query: str, input_type: str) -> Embedding: return embedding - def build_auth_headers(self) -> dict: - with open('/tmp/jwt', 'r') as file: + with open("/tmp/jwt", "r") as file: jwt_contents = json.load(file) - access_token = jwt_contents['access_token'] - headers = { - "Authorization": f"Bearer {access_token}" - } + access_token = jwt_contents["access_token"] + headers = {"Authorization": f"Bearer {access_token}"} return headers diff --git a/llm-service/app/services/CaiiModel.py b/llm-service/app/services/CaiiModel.py index b457621fa..6ded40539 100644 --- a/llm-service/app/services/CaiiModel.py +++ b/llm-service/app/services/CaiiModel.py @@ -40,6 +40,7 @@ from llama_index.llms.mistralai.base import MistralAI from llama_index.llms.openai import OpenAI + class CaiiModel(OpenAI): context: int = Field( description="The context size", @@ -47,20 +48,22 @@ class CaiiModel(OpenAI): ) def __init__( - self, - model: str, - context: int, - api_base: str, - messages_to_prompt, - completion_to_prompt, - default_headers): + self, + model: str, + context: int, + api_base: str, + messages_to_prompt, + completion_to_prompt, + default_headers, + ): super().__init__( model=model, api_base=api_base, messages_to_prompt=messages_to_prompt, completion_to_prompt=completion_to_prompt, default_headers=default_headers, - context=context) + context=context, + ) self.context = context @property @@ -76,24 +79,25 @@ def metadata(self) -> LLMMetadata: class CaiiModelMistral(MistralAI): - def __init__( - self, - model: str, - context: int, - api_base: str, - messages_to_prompt, - completion_to_prompt, - default_headers): + self, + model: str, + context: int, + api_base: str, + messages_to_prompt, + completion_to_prompt, + default_headers, + ): super().__init__( api_key=default_headers.get("Authorization"), model=model, - endpoint=api_base.removesuffix("/v1"), # mistral expects the base url without the /v1 + endpoint=api_base.removesuffix( + "/v1" + ), # mistral expects the base url without the /v1 messages_to_prompt=messages_to_prompt, - completion_to_prompt=completion_to_prompt + completion_to_prompt=completion_to_prompt, ) - @property def metadata(self) -> LLMMetadata: ## todo: pull this info from somewhere diff --git a/llm-service/app/services/amp_update.py b/llm-service/app/services/amp_update.py index 3bd14e969..74ed5250c 100644 --- a/llm-service/app/services/amp_update.py +++ b/llm-service/app/services/amp_update.py @@ -38,47 +38,75 @@ import subprocess + def get_current_git_hash() -> str: """Retrieve the current git hash of the deployed AMP.""" try: - current_hash = subprocess.check_output(["git", "rev-parse", "HEAD"]).strip().decode("utf-8") + current_hash = ( + subprocess.check_output(["git", "rev-parse", "HEAD"]) + .strip() + .decode("utf-8") + ) return current_hash except subprocess.CalledProcessError: raise ValueError("Failed to retrieve current git hash.") -def get_latest_git_hash(current_branch) -> str: +def get_latest_git_hash(current_branch: str) -> str: """Retrieve the latest git hash from the remote repository for the current branch.""" try: # Fetch the latest updates from the remote subprocess.check_call(["git", "fetch", "origin", current_branch]) # Get the latest hash for the current branch - latest_hash = subprocess.check_output(["git", "rev-parse", f"origin/{current_branch}"]).strip().decode("utf-8") + latest_hash = ( + subprocess.check_output(["git", "rev-parse", f"origin/{current_branch}"]) + .strip() + 
.decode("utf-8") + ) return latest_hash except subprocess.CalledProcessError: - raise ValueError(f"Failed to retrieve latest git hash from remote for the branch: {current_branch}.") + raise ValueError( + f"Failed to retrieve latest git hash from remote for the branch: {current_branch}." + ) -def check_if_ahead_or_behind(current_hash, current_branch) -> tuple[int, int]: +def check_if_ahead_or_behind(current_hash: str, current_branch: str) -> tuple[int, int]: """Check if the current commit is ahead or behind the remote branch.""" try: # Get the number of commits ahead or behind - ahead_behind = subprocess.check_output( - ["git", "rev-list", "--left-right", "--count", f"{current_hash}...origin/{current_branch}"] - ).strip().decode("utf-8") + ahead_behind = ( + subprocess.check_output( + [ + "git", + "rev-list", + "--left-right", + "--count", + f"{current_hash}...origin/{current_branch}", + ] + ) + .strip() + .decode("utf-8") + ) ahead, behind = map(int, ahead_behind.split()) return ahead, behind except subprocess.CalledProcessError: - raise ValueError(f"Failed to determine if the branch {current_branch} is ahead or behind.") + raise ValueError( + f"Failed to determine if the branch {current_branch} is ahead or behind." + ) + def check_amp_update_status() -> bool: """Check if the AMP is up-to-date.""" # Retrieve the current branch only once - current_branch = subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"]).strip().decode("utf-8") + current_branch = ( + subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"]) + .strip() + .decode("utf-8") + ) # Retrieve the current and latest git hashes current_hash = get_current_git_hash() @@ -90,4 +118,3 @@ def check_amp_update_status() -> bool: return True return False - diff --git a/llm-service/app/services/caii.py b/llm-service/app/services/caii.py index 8ae247f05..fa2855445 100644 --- a/llm-service/app/services/caii.py +++ b/llm-service/app/services/caii.py @@ -35,45 +35,45 @@ # BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF # DATA. 
# -import requests import json import os +from typing import Any, Dict, List +import requests from fastapi import HTTPException from llama_index.core.base.embeddings.base import BaseEmbedding from llama_index.core.llms import LLM -from .CaiiModel import CaiiModel, CaiiModelMistral from .CaiiEmbeddingModel import CaiiEmbeddingModel +from .CaiiModel import CaiiModel, CaiiModelMistral -def describe_endpoint(domain: str, endpoint_name: str): - with open('/tmp/jwt', 'r') as file: + +def describe_endpoint(domain: str, endpoint_name: str) -> Any: + with open("/tmp/jwt", "r") as file: jwt_contents = json.load(file) - access_token = jwt_contents['access_token'] + access_token = jwt_contents["access_token"] - headers = { - "Authorization": f"Bearer {access_token}" - } - describe_url=f"https://{domain}/api/v1alpha1/describeEndpoint" - desc_json = { - "name": endpoint_name, - "namespace": "serving-default" - } + headers = {"Authorization": f"Bearer {access_token}"} + describe_url = f"https://{domain}/api/v1alpha1/describeEndpoint" + desc_json = {"name": endpoint_name, "namespace": "serving-default"} desc = requests.post(describe_url, headers=headers, json=desc_json) if desc.status_code == 404: - raise HTTPException(status_code=404, detail = f"Endpoint '{endpoint_name}' not found") + raise HTTPException( + status_code=404, detail=f"Endpoint '{endpoint_name}' not found" + ) return json.loads(desc.content) -def get_llm(domain: str, endpoint_name: str, messages_to_prompt, completion_to_prompt) -> LLM: + +def get_llm( + domain: str, endpoint_name: str, messages_to_prompt, completion_to_prompt +) -> LLM: endpoint = describe_endpoint(domain=domain, endpoint_name=endpoint_name) api_base = endpoint["url"].removesuffix("/chat/completions") - with open('/tmp/jwt', 'r') as file: + with open("/tmp/jwt", "r") as file: jwt_contents = json.load(file) - access_token = jwt_contents['access_token'] - headers = { - "Authorization": f"Bearer {access_token}" - } + access_token = jwt_contents["access_token"] + headers = {"Authorization": f"Bearer {access_token}"} model = endpoint["endpointmetadata"]["model_name"] if "mistral" in endpoint_name.lower(): @@ -98,22 +98,28 @@ def get_llm(domain: str, endpoint_name: str, messages_to_prompt, completion_to_p return llm + def get_embedding_model() -> BaseEmbedding: - domain = os.environ['CAII_DOMAIN'] - endpoint_name = os.environ['CAII_EMBEDDING_ENDPOINT_NAME'] + domain = os.environ["CAII_DOMAIN"] + endpoint_name = os.environ["CAII_EMBEDDING_ENDPOINT_NAME"] endpoint = describe_endpoint(domain=domain, endpoint_name=endpoint_name) return CaiiEmbeddingModel(endpoint=endpoint) + ### metadata methods below here -def get_caii_llm_models(): - domain = os.environ['CAII_DOMAIN'] - endpoint_name = os.environ['CAII_INFERENCE_ENDPOINT_NAME'] + +def get_caii_llm_models() -> List[Dict[str, Any]]: + domain = os.environ["CAII_DOMAIN"] + endpoint_name = os.environ["CAII_INFERENCE_ENDPOINT_NAME"] try: models = describe_endpoint(domain=domain, endpoint_name=endpoint_name) except requests.exceptions.ConnectionError as e: print(e) - raise HTTPException(status_code=421, detail = f"Unable to connect to host {domain}. Please check your CAII_DOMAIN env variable.") + raise HTTPException( + status_code=421, + detail=f"Unable to connect to host {domain}. 
Please check your CAII_DOMAIN env variable.", + ) except HTTPException as e: if e.status_code == 404: return [{"model_id": endpoint_name}] @@ -121,17 +127,21 @@ def get_caii_llm_models(): raise e return build_model_response(models) -def get_caii_embedding_models(): + +def get_caii_embedding_models() -> List[Dict[str, Any]]: # notes: # NameResolutionError is we can't contact the CAII_DOMAIN - domain = os.environ['CAII_DOMAIN'] - endpoint_name = os.environ['CAII_EMBEDDING_ENDPOINT_NAME'] + domain = os.environ["CAII_DOMAIN"] + endpoint_name = os.environ["CAII_EMBEDDING_ENDPOINT_NAME"] try: models = describe_endpoint(domain=domain, endpoint_name=endpoint_name) except requests.exceptions.ConnectionError as e: print(e) - raise HTTPException(status_code=421, detail = f"Unable to connect to host {domain}. Please check your CAII_DOMAIN env variable.") + raise HTTPException( + status_code=421, + detail=f"Unable to connect to host {domain}. Please check your CAII_DOMAIN env variable.", + ) except HTTPException as e: if e.status_code == 404: return [{"model_id": endpoint_name}] @@ -139,6 +149,13 @@ def get_caii_embedding_models(): raise e return build_model_response(models) -def build_model_response(models): - return [{"model_id": models["name"], "name": models["name"], "available": models['replica_count'] > 0, - "replica_count": models["replica_count"]}] + +def build_model_response(models: Dict[str, Any]) -> List[Dict[str, Any]]: + return [ + { + "model_id": models["name"], + "name": models["name"], + "available": models["replica_count"] > 0, + "replica_count": models["replica_count"], + } + ] diff --git a/llm-service/app/services/chat.py b/llm-service/app/services/chat.py index 6f6a1d373..5d80f105a 100644 --- a/llm-service/app/services/chat.py +++ b/llm-service/app/services/chat.py @@ -38,6 +38,7 @@ import time import uuid +from typing import List from llama_index.core.base.llms.types import MessageRole @@ -96,9 +97,9 @@ def v2_chat( return new_chat_message -def retrieve_chat_history(session_id) -> list[RagContext]: +def retrieve_chat_history(session_id: int) -> List[RagContext]: chat_history = chat_store.retrieve_chat_history(session_id)[:10] - history: [RagContext] = list() + history: List[RagContext] = [] for message in chat_history: history.append( RagContext(role=MessageRole.USER, content=message.rag_message["user"]) diff --git a/llm-service/app/services/chat_store.py b/llm-service/app/services/chat_store.py index 53378a3bb..0a353df2b 100644 --- a/llm-service/app/services/chat_store.py +++ b/llm-service/app/services/chat_store.py @@ -37,11 +37,12 @@ # ############################################################################## import os -from typing import Literal +from typing import List, Literal from llama_index.core.base.llms.types import ChatMessage, MessageRole from llama_index.core.storage.chat_store import SimpleChatStore from pydantic import BaseModel + from ..config import settings @@ -75,7 +76,7 @@ def __init__(self, store_path: str): self.store_path = store_path # note: needs pagination in the future - def retrieve_chat_history(self, session_id: int) -> list[RagStudioChatMessage]: + def retrieve_chat_history(self, session_id: int) -> List[RagStudioChatMessage]: store = self.store_for_session(session_id) messages: list[ChatMessage] = store.get_messages( @@ -94,38 +95,46 @@ def retrieve_chat_history(self, session_id: int) -> list[RagStudioChatMessage]: assistant_message.role = MessageRole.ASSISTANT assistant_message.content = "" i = i - 1 - results.append(RagStudioChatMessage( - 
id=user_message.additional_kwargs["id"], - source_nodes=assistant_message.additional_kwargs.get("source_nodes", []), - rag_message={MessageRole.USER.value: user_message.content, MessageRole.ASSISTANT.value: assistant_message.content}, - evaluations=assistant_message.additional_kwargs.get("evaluations", []), - timestamp=assistant_message.additional_kwargs.get("timestamp", 0.0), - )) + results.append( + RagStudioChatMessage( + id=user_message.additional_kwargs["id"], + source_nodes=assistant_message.additional_kwargs.get( + "source_nodes", [] + ), + rag_message={ + MessageRole.USER.value: user_message.content, + MessageRole.ASSISTANT.value: assistant_message.content, + }, + evaluations=assistant_message.additional_kwargs.get( + "evaluations", [] + ), + timestamp=assistant_message.additional_kwargs.get("timestamp", 0.0), + ) + ) i += 2 return results - def store_for_session(self, session_id): + def store_for_session(self, session_id: int) -> SimpleChatStore: store = SimpleChatStore.from_persist_path( - persist_path=self.store_file(session_id)) + persist_path=self.store_file(session_id) + ) return store - def clear_chat_history(self, session_id): + def clear_chat_history(self, session_id: int) -> None: store = self.store_for_session(session_id) store.delete_messages(self.build_chat_key(session_id)) store.persist(self.store_file(session_id)) - def delete_chat_history(self, session_id): + def delete_chat_history(self, session_id: int) -> None: session_storage = self.store_file(session_id) if os.path.exists(session_storage): os.remove(session_storage) - def store_file(self, session_id): + def store_file(self, session_id: int) -> str: return os.path.join(self.store_path, f"chat_store-{session_id}.json") - def append_to_history( - self, session_id: int, messages: list[RagStudioChatMessage] - ): + def append_to_history(self, session_id: int, messages: list[RagStudioChatMessage]): store = self.store_for_session(session_id) for message in messages: @@ -155,7 +164,7 @@ def append_to_history( store.persist(self.store_file(session_id)) @staticmethod - def build_chat_key(session_id: int): + def build_chat_key(session_id: int) -> str: return "session_" + str(session_id) diff --git a/llm-service/app/services/doc_summaries.py b/llm-service/app/services/doc_summaries.py index c476f555d..1ad1303ad 100644 --- a/llm-service/app/services/doc_summaries.py +++ b/llm-service/app/services/doc_summaries.py @@ -41,9 +41,6 @@ from typing import cast from fastapi import HTTPException - -from . import models - from llama_index.core import ( DocumentSummaryIndex, Settings, @@ -53,18 +50,19 @@ from llama_index.core.node_parser import SentenceSplitter from llama_index.core.readers import SimpleDirectoryReader -from . import rag_vector_store +from ..config import settings +from . import models, rag_vector_store from .s3 import download from .utils import get_last_segment -from ..config import settings - SUMMARY_PROMPT = 'Summarize the document into a single sentence. If an adequate summary is not possible, please return "No summary available.".' 
def index_dir(data_source_id: int) -> str: """Return the directory name to be used for a data source's summary index.""" - return os.path.join(settings.rag_databases_dir, f"doc_summary_index_{data_source_id}") + return os.path.join( + settings.rag_databases_dir, f"doc_summary_index_{data_source_id}" + ) def read_summary(data_source_id: int, document_id: str) -> str: @@ -82,7 +80,6 @@ def read_summary(data_source_id: int, document_id: str) -> str: return doc_summary_index.get_document_summary(doc_id=document_id) - def generate_summary( data_source_id: int, s3_bucket_name: str, @@ -116,13 +113,13 @@ def generate_summary( ## todo: move to somewhere better; these are defaults to use when none are explicitly provided -def set_settings_globals(): +def set_settings_globals() -> None: Settings.llm = models.get_llm("meta.llama3-8b-instruct-v1:0") Settings.embed_model = models.get_embedding_model() Settings.splitter = SentenceSplitter(chunk_size=1024) -def initialize_summary_index_storage(data_source_id): +def initialize_summary_index_storage(data_source_id: int) -> None: set_settings_globals() doc_summary_index = DocumentSummaryIndex.from_documents( [], @@ -131,7 +128,9 @@ def initialize_summary_index_storage(data_source_id): doc_summary_index.storage_context.persist(persist_dir=index_dir(data_source_id)) -def load_document_summary_index(storage_context) -> DocumentSummaryIndex: +def load_document_summary_index( + storage_context: StorageContext, +) -> DocumentSummaryIndex: set_settings_globals() doc_summary_index: DocumentSummaryIndex = cast( DocumentSummaryIndex, @@ -156,19 +155,21 @@ def summarize_data_source(data_source_id: int) -> str: return response.text -def make_storage_context(data_source_id): +def make_storage_context(data_source_id: int) -> StorageContext: storage_context = StorageContext.from_defaults( persist_dir=index_dir(data_source_id), - vector_store=rag_vector_store.create_summary_vector_store(data_source_id).access_vector_store(), + vector_store=rag_vector_store.create_summary_vector_store( + data_source_id + ).access_vector_store(), ) return storage_context -def doc_summary_vector_table_name_from(data_source_id: int): +def doc_summary_vector_table_name_from(data_source_id: int) -> str: return f"summary_index_{data_source_id}" -def delete_data_source(data_source_id): +def delete_data_source(data_source_id: int) -> None: """Delete the summary index for `data_source_id`.""" index = index_dir(data_source_id) if os.path.exists(index): @@ -176,7 +177,7 @@ def delete_data_source(data_source_id): rag_vector_store.create_summary_vector_store(data_source_id).delete() -def delete_document(data_source_id, doc_id): +def delete_document(data_source_id: int, doc_id: str) -> None: index = index_dir(data_source_id) if not os.path.exists(index): return @@ -185,4 +186,4 @@ def delete_document(data_source_id, doc_id): if doc_id not in doc_summary_index.index_struct.doc_id_to_summary_id: return doc_summary_index.delete(doc_id) - doc_summary_index.storage_context.persist(persist_dir=index_dir(data_source_id)) \ No newline at end of file + doc_summary_index.storage_context.persist(persist_dir=index_dir(data_source_id)) diff --git a/llm-service/app/services/evaluators.py b/llm-service/app/services/evaluators.py index a032564f3..2603d052d 100644 --- a/llm-service/app/services/evaluators.py +++ b/llm-service/app/services/evaluators.py @@ -39,13 +39,14 @@ from llama_index.core.chat_engine.types import AgentChatResponse from llama_index.core.evaluation import FaithfulnessEvaluator, RelevancyEvaluator from 
llama_index.llms.bedrock import Bedrock + from .llama_utils import completion_to_prompt, messages_to_prompt def evaluate_response( - query: str, - chat_response: AgentChatResponse, -) -> tuple[float | None, float | None]: + query: str, + chat_response: AgentChatResponse, +) -> tuple[float, float]: evaluator_llm = Bedrock( model="meta.llama3-8b-instruct-v1:0", context_size=128000, @@ -61,5 +62,4 @@ def evaluate_response( faithfulness = faithfulness_evaluator.evaluate_response( query=query, response=chat_response ) - return relevance.score, faithfulness.score - + return relevance.score or 0, faithfulness.score or 0 diff --git a/llm-service/app/services/llama_utils.py b/llm-service/app/services/llama_utils.py index aba80a33e..d2a0bb493 100644 --- a/llm-service/app/services/llama_utils.py +++ b/llm-service/app/services/llama_utils.py @@ -66,7 +66,7 @@ def messages_to_prompt( # first message should always be a user user_message = messages[i] assert user_message.role == MessageRole.USER - str_message = '' + str_message = "" if i == 0: # make sure system prompt is included at the start str_message = f"{BOS}{SH}system{EH}\n\n{system_message_str.strip()}{EOT}\n" @@ -90,6 +90,7 @@ def messages_to_prompt( result = "".join(string_messages) return result + def messages_to_prompt_mistral( messages: Sequence[ChatMessage], system_prompt: Optional[str] = None ) -> str: @@ -101,7 +102,9 @@ def messages_to_prompt_mistral( string_messages[-1] += f"{EOT}\n" # include user message content - str_message = f"{SH}user{EH}\n\n{user_message.content}{EOT}\n{SH}assistant{EH}\n\n" + str_message = ( + f"{SH}user{EH}\n\n{user_message.content}{EOT}\n{SH}assistant{EH}\n\n" + ) if len(messages) > (i + 1): # if assistant message exists, add to str_message @@ -115,13 +118,14 @@ def messages_to_prompt_mistral( return result - def completion_to_prompt(completion: str, system_prompt: Optional[str] = None) -> str: system_prompt_str = system_prompt or DEFAULT_SYSTEM_PROMPT - result = (f"{BOS}{SH}system{EH}\n\n{system_prompt_str.strip()}{EOT}\n" \ - f"{SH}user{EH}\n\n{completion.strip()}{EOT}\n" \ - f"{SH}assistant{EH}\n\n") + result = ( + f"{BOS}{SH}system{EH}\n\n{system_prompt_str.strip()}{EOT}\n" + f"{SH}user{EH}\n\n{completion.strip()}{EOT}\n" + f"{SH}assistant{EH}\n\n" + ) return result @@ -129,8 +133,8 @@ def mistralv2_messages_to_prompt(messages): print(f"mistralv2_messages_to_prompt: {messages}") conversation = "" bos_token = "" - eos_token= "" - if messages[0].role == MessageRole.SYSTEM: + eos_token = "" + if messages[0].role == MessageRole.SYSTEM: loop_messages = messages[1:] system_message = messages[0].content else: @@ -139,14 +143,16 @@ def mistralv2_messages_to_prompt(messages): for index, message in enumerate(loop_messages): if (message.role == MessageRole.USER) != (index % 2 == 0): - raise Exception('HFI Conversation roles must alternate user/assistant/user/assistant/...') + raise Exception( + "HFI Conversation roles must alternate user/assistant/user/assistant/..." 
+ ) if index == 0 and system_message != False: - content = '<>\n' + system_message + '\n<>\n\n' + message.content + content = "<>\n" + system_message + "\n<>\n\n" + message.content else: content = message.content if message.role == MessageRole.USER: - conversation += bos_token + '[INST] ' + content.strip() + ' [/INST]' + conversation += bos_token + "[INST] " + content.strip() + " [/INST]" elif message.role == MessageRole.ASSISTANT: - conversation += ' ' + content.strip() + ' ' + eos_token + conversation += " " + content.strip() + " " + eos_token - return (conversation) \ No newline at end of file + return conversation diff --git a/llm-service/app/services/llm_completion.py b/llm-service/app/services/llm_completion.py index 16f990400..c5a49fc13 100644 --- a/llm-service/app/services/llm_completion.py +++ b/llm-service/app/services/llm_completion.py @@ -45,15 +45,20 @@ def make_chat_messages(x: RagStudioChatMessage) -> list[ChatMessage]: - user = ChatMessage.from_str(x.rag_message['user'], role="user") - assistant = ChatMessage.from_str(x.rag_message['assistant'], role="assistant") + user = ChatMessage.from_str(x.rag_message["user"], role="user") + assistant = ChatMessage.from_str(x.rag_message["assistant"], role="assistant") return [user, assistant] -def completion(session_id: int, question: str, configuration: RagPredictConfiguration) -> ChatResponse: +def completion( + session_id: int, question: str, configuration: RagPredictConfiguration +) -> ChatResponse: model = get_llm(configuration.model_name) chat_history = chat_store.retrieve_chat_history(session_id)[:10] - messages = list(itertools.chain.from_iterable(map(lambda x: make_chat_messages(x), chat_history))) + messages = list( + itertools.chain.from_iterable( + map(lambda x: make_chat_messages(x), chat_history) + ) + ) messages.append(ChatMessage.from_str(question, role="user")) return model.chat(messages) - diff --git a/llm-service/app/services/models.py b/llm-service/app/services/models.py index bae197aa1..236e4d708 100644 --- a/llm-service/app/services/models.py +++ b/llm-service/app/services/models.py @@ -37,7 +37,7 @@ # import os from enum import Enum -from typing import Literal +from typing import Any, Dict, List, Literal from fastapi import HTTPException from llama_index.core.base.embeddings.base import BaseEmbedding @@ -73,13 +73,13 @@ def get_llm(model_name: str = None) -> LLM: ) -def get_available_embedding_models(): +def get_available_embedding_models() -> List[Dict[str, Any]]: if is_caii_enabled(): return get_caii_embedding_models() return _get_bedrock_embedding_models() -def get_available_llm_models(): +def get_available_llm_models() -> List[Dict[str, Any]]: if is_caii_enabled(): return get_caii_llm_models() return _get_bedrock_llm_models() @@ -90,7 +90,7 @@ def is_caii_enabled() -> bool: return len(domain) > 0 -def _get_bedrock_llm_models(): +def _get_bedrock_llm_models() -> List[Dict[str, Any]]: return [ { "model_id": "meta.llama3-1-8b-instruct-v1:0", @@ -107,7 +107,7 @@ def _get_bedrock_llm_models(): ] -def _get_bedrock_embedding_models(): +def _get_bedrock_embedding_models() -> List[Dict[str, Any]]: return [ { "model_id": "cohere.embed-english-v3", diff --git a/llm-service/app/services/qdrant.py b/llm-service/app/services/qdrant.py index ffc1cca3d..94fd14bbd 100644 --- a/llm-service/app/services/qdrant.py +++ b/llm-service/app/services/qdrant.py @@ -51,10 +51,9 @@ from llama_index.core.storage import StorageContext from pydantic import BaseModel -from . 
import rag_vector_store from ..rag_types import RagPredictConfiguration +from . import models, rag_vector_store from .chat_store import RagContext -from . import models from .utils import get_last_segment logger = logging.getLogger(__name__) @@ -67,11 +66,11 @@ class RagIndexDocumentConfiguration(BaseModel): def download_and_index( - tmpdirname: str, - data_source_id: int, - configuration: RagIndexDocumentConfiguration, - s3_document_key: str, -): + tmpdirname: str, + data_source_id: int, + configuration: RagIndexDocumentConfiguration, + s3_document_key: str, +) -> None: try: documents = SimpleDirectoryReader(tmpdirname).load_data() document_id = get_last_segment(s3_document_key) @@ -89,7 +88,9 @@ def download_and_index( ) from e logger.info("instantiating vector store") - vector_store = rag_vector_store.create_rag_vector_store(data_source_id).access_vector_store() + vector_store = rag_vector_store.create_rag_vector_store( + data_source_id + ).access_vector_store() logger.info("instantiated vector store") storage_context = StorageContext.from_defaults(vector_store=vector_store) @@ -125,7 +126,9 @@ def size_of(data_source_id: int) -> int: def chunk_contents(data_source_id: int, chunk_id: str) -> str: - vector_store = rag_vector_store.create_rag_vector_store(data_source_id).access_vector_store() + vector_store = rag_vector_store.create_rag_vector_store( + data_source_id + ).access_vector_store() node = vector_store.get_nodes([chunk_id])[0] return node.get_content() @@ -136,7 +139,9 @@ def delete(data_source_id: int) -> None: def delete_document(data_source_id: int, document_id: str) -> None: - vector_store = rag_vector_store.create_rag_vector_store(data_source_id).access_vector_store() + vector_store = rag_vector_store.create_rag_vector_store( + data_source_id + ).access_vector_store() index = VectorStoreIndex.from_vector_store( vector_store=vector_store, embed_model=models.get_embedding_model(), @@ -145,12 +150,14 @@ def delete_document(data_source_id: int, document_id: str) -> None: def query( - data_source_id: int, - query_str: str, - configuration: RagPredictConfiguration, - chat_history: list[RagContext], + data_source_id: int, + query_str: str, + configuration: RagPredictConfiguration, + chat_history: list[RagContext], ) -> AgentChatResponse: - vector_store = rag_vector_store.create_rag_vector_store(data_source_id).access_vector_store() + vector_store = rag_vector_store.create_rag_vector_store( + data_source_id + ).access_vector_store() embedding_model = models.get_embedding_model() index = VectorStoreIndex.from_vector_store( vector_store=vector_store, @@ -161,14 +168,15 @@ def query( retriever = VectorIndexRetriever( index=index, similarity_top_k=configuration.top_k, - embed_model=embedding_model, # is this needed, really, if it's in the index? + embed_model=embedding_model, # is this needed, really, if it's in the index? 
) # TODO: factor out LLM and chat engine into a separate function llm = models.get_llm(model_name=configuration.model_name) response_synthesizer = get_response_synthesizer(llm=llm) query_engine = RetrieverQueryEngine( - retriever=retriever, response_synthesizer=response_synthesizer) + retriever=retriever, response_synthesizer=response_synthesizer + ) chat_engine = CondenseQuestionChatEngine.from_defaults( query_engine=query_engine, llm=llm, @@ -193,4 +201,3 @@ def query( status_code=json_error["ResponseMetadata"]["HTTPStatusCode"], detail=json_error["message"], ) from error - diff --git a/llm-service/app/services/rag_qdrant_vector_store.py b/llm-service/app/services/rag_qdrant_vector_store.py index 8e58b5239..55a8cd7f8 100644 --- a/llm-service/app/services/rag_qdrant_vector_store.py +++ b/llm-service/app/services/rag_qdrant_vector_store.py @@ -71,17 +71,19 @@ def delete(self): def exists(self) -> bool: return self.client.collection_exists(self.table_name) - def _create_qdrant_clients(self, memory_store) -> tuple[qdrant_client.QdrantClient, qdrant_client.AsyncQdrantClient]: + def _create_qdrant_clients( + self, memory_store + ) -> tuple[qdrant_client.QdrantClient, qdrant_client.AsyncQdrantClient]: if memory_store: client = qdrant_client.QdrantClient(":memory:") aclient = qdrant_client.AsyncQdrantClient(":memory:") else: client = qdrant_client.QdrantClient(host=self.host, port=self.port) - aclient = qdrant_client.AsyncQdrantClient(host=self.host, port=self.async_port) + aclient = qdrant_client.AsyncQdrantClient( + host=self.host, port=self.async_port + ) return client, aclient def access_vector_store(self) -> BasePydanticVectorStore: - vector_store = QdrantVectorStore( - self.table_name, self.client, self.aclient - ) + vector_store = QdrantVectorStore(self.table_name, self.client, self.aclient) return vector_store diff --git a/llm-service/app/services/rag_vector_store.py b/llm-service/app/services/rag_vector_store.py index f9793d819..332134289 100644 --- a/llm-service/app/services/rag_vector_store.py +++ b/llm-service/app/services/rag_vector_store.py @@ -43,6 +43,6 @@ def create_rag_vector_store(data_source_id: int) -> VectorStore: return RagQdrantVectorStore(table_name=f"index_{data_source_id}") + def create_summary_vector_store(data_source_id: int) -> VectorStore: return RagQdrantVectorStore(table_name=f"summary_index_{data_source_id}") - diff --git a/llm-service/app/services/utils.py b/llm-service/app/services/utils.py index 8d17bd8bf..6ac9893cf 100644 --- a/llm-service/app/services/utils.py +++ b/llm-service/app/services/utils.py @@ -42,6 +42,7 @@ # TODO delete this if it's not being used + def parse_choice_select_answer_fn( answer: str, num_choices: int, raise_error: bool = False ) -> Tuple[List[int], List[float]]: @@ -78,5 +79,6 @@ def parse_choice_select_answer_fn( answer_relevances.append(float(_answer_relevance)) return answer_nums, answer_relevances + def get_last_segment(path: str) -> str: - return path.split('/')[-1] \ No newline at end of file + return path.split("/")[-1] diff --git a/llm-service/app/services/vector_store.py b/llm-service/app/services/vector_store.py index 3d06b5798..77fd31f29 100644 --- a/llm-service/app/services/vector_store.py +++ b/llm-service/app/services/vector_store.py @@ -42,7 +42,7 @@ class VectorStore: - """ RAG Studio Vector Store functionality. Implementations of this should house the vectors for a single document collection.""" + """RAG Studio Vector Store functionality. 
Implementations of this should house the vectors for a single document collection.""" @abstractmethod def size(self) -> int: @@ -52,12 +52,12 @@ def size(self) -> int: @abstractmethod def delete(self) -> None: - """ Delete the vector store """ + """Delete the vector store""" @abstractmethod def access_vector_store(self) -> BasePydanticVectorStore: - """ Access the underlying llama-index vector store implementation """ + """Access the underlying llama-index vector store implementation""" @abstractmethod def exists(self) -> bool: - """ Does the vector store exist? """ + """Does the vector store exist?""" diff --git a/llm-service/app/tests/conftest.py b/llm-service/app/tests/conftest.py index 9ff9b7098..ede45b49d 100644 --- a/llm-service/app/tests/conftest.py +++ b/llm-service/app/tests/conftest.py @@ -46,8 +46,16 @@ import pytest from fastapi.testclient import TestClient from llama_index.core.base.embeddings.base import BaseEmbedding, Embedding -from llama_index.core.base.llms.types import CompletionResponseAsyncGen, ChatMessage, ChatResponseAsyncGen, \ - CompletionResponse, ChatResponse, CompletionResponseGen, ChatResponseGen, LLMMetadata +from llama_index.core.base.llms.types import ( + ChatMessage, + ChatResponse, + ChatResponseAsyncGen, + ChatResponseGen, + CompletionResponse, + CompletionResponseAsyncGen, + CompletionResponseGen, + LLMMetadata, +) from llama_index.core.llms import LLM from moto import mock_aws from pydantic import Field @@ -71,7 +79,7 @@ def databases_dir(monkeypatch: pytest.MonkeyPatch, tmp_path: pathlib.Path) -> st @pytest.fixture def s3( - monkeypatch: pytest.MonkeyPatch, + monkeypatch: pytest.MonkeyPatch, ) -> Iterator["s3.ServiceResource"]: """Mock all S3 interactions.""" @@ -106,7 +114,11 @@ class DummyLlm(LLM): completion_response = Field("this is a completion response") chat_response = Field("this is a chat response") - def __init__(self, completion_response: str = "this is a completion response", chat_response: str = "hello"): + def __init__( + self, + completion_response: str = "this is a completion response", + chat_response: str = "hello", + ): super().__init__() self.completion_response = completion_response self.chat_response = chat_response @@ -118,27 +130,11 @@ def metadata(self) -> LLMMetadata: def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse: return ChatResponse(message=ChatMessage.from_str(self.chat_response)) - def complete(self, prompt: str, formatted: bool = False, **kwargs: Any) -> CompletionResponse: + def complete( + self, prompt: str, formatted: bool = False, **kwargs: Any + ) -> CompletionResponse: return CompletionResponse(text=self.completion_response) - def stream_chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponseGen: - pass - - def stream_complete(self, prompt: str, formatted: bool = False, **kwargs: Any) -> CompletionResponseGen: - pass - - async def achat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse: - pass - - async def acomplete(self, prompt: str, formatted: bool = False, **kwargs: Any) -> CompletionResponse: - pass - - async def astream_chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponseAsyncGen: - pass - - async def astream_complete(self, prompt: str, formatted: bool = False, **kwargs: Any) -> CompletionResponseAsyncGen: - pass - class DummyEmbeddingModel(BaseEmbedding): def _get_query_embedding(self, query: str) -> Embedding: @@ -156,24 +152,34 @@ def _get_text_embedding(self, text: str) -> Embedding: table_name_to_vector_store = {} 
-def _get_vector_store_instance(data_source_id: int, table_prefix: str) -> RagQdrantVectorStore: +def _get_vector_store_instance( + data_source_id: int, table_prefix: str +) -> RagQdrantVectorStore: if data_source_id in table_name_to_vector_store: return table_name_to_vector_store[data_source_id] - res = RagQdrantVectorStore(table_name=f"{table_prefix}{data_source_id}", memory_store=True) + res = RagQdrantVectorStore( + table_name=f"{table_prefix}{data_source_id}", memory_store=True + ) table_name_to_vector_store[data_source_id] = res return res @pytest.fixture(autouse=True) def vector_store(monkeypatch: pytest.MonkeyPatch): - monkeypatch.setattr(rag_vector_store, 'create_rag_vector_store', - lambda ds_id: _get_vector_store_instance(ds_id, "index_")) + monkeypatch.setattr( + rag_vector_store, + "create_rag_vector_store", + lambda ds_id: _get_vector_store_instance(ds_id, "index_"), + ) @pytest.fixture(autouse=True) def summary_vector_store(monkeypatch: pytest.MonkeyPatch): - monkeypatch.setattr(rag_vector_store, 'create_summary_vector_store', - lambda ds_id: _get_vector_store_instance(ds_id, "summary_index_")) + monkeypatch.setattr( + rag_vector_store, + "create_summary_vector_store", + lambda ds_id: _get_vector_store_instance(ds_id, "summary_index_"), + ) @pytest.fixture(autouse=True) @@ -181,7 +187,7 @@ def embedding_model(monkeypatch: pytest.MonkeyPatch) -> BaseEmbedding: model = DummyEmbeddingModel() # Requires that the app usages import the file and not the function directly as python creates a copy when importing the function - monkeypatch.setattr(models, 'get_embedding_model', lambda: model) + monkeypatch.setattr(models, "get_embedding_model", lambda: model) return model @@ -190,13 +196,13 @@ def llm(monkeypatch: pytest.MonkeyPatch) -> LLM: model = DummyLlm() # Requires that the app usages import the file and not the function directly as python creates a copy when importing the function - monkeypatch.setattr(models, 'get_llm', lambda model_name: model) + monkeypatch.setattr(models, "get_llm", lambda model_name: model) return model @pytest.fixture def s3_object( - s3: "s3.ServiceResource", aws_region: str, document_id: str + s3: "s3.ServiceResource", aws_region: str, document_id: str ) -> "s3.Object": """Put and return a mocked S3 object""" bucket_name = "test_bucket" @@ -214,7 +220,7 @@ def s3_object( @pytest.fixture def client( - s3: "s3.ServiceResource", + s3: "s3.ServiceResource", ) -> Iterator[TestClient]: """Return a test client for making calls to the service. 
diff --git a/llm-service/app/tests/routers/index/test_data_source.py b/llm-service/app/tests/routers/index/test_data_source.py index 51cdbefaf..fece5b718 100644 --- a/llm-service/app/tests/routers/index/test_data_source.py +++ b/llm-service/app/tests/routers/index/test_data_source.py @@ -48,19 +48,22 @@ def get_vector_store_index(data_source_id) -> VectorStoreIndex: - vector_store = rag_vector_store.create_rag_vector_store(data_source_id).access_vector_store() - index = VectorStoreIndex.from_vector_store(vector_store, embed_model=models.get_embedding_model()) + vector_store = rag_vector_store.create_rag_vector_store( + data_source_id + ).access_vector_store() + index = VectorStoreIndex.from_vector_store( + vector_store, embed_model=models.get_embedding_model() + ) return index class TestDocumentIndexing: - @staticmethod def test_create_document( - client, - index_document_request_body: dict[str, Any], - document_id: str, - data_source_id: int, + client, + index_document_request_body: dict[str, Any], + document_id: str, + data_source_id: int, ) -> None: """Test POST /download-and-index.""" response = client.post( @@ -71,15 +74,17 @@ def test_create_document( assert response.status_code == 200 assert document_id is not None index = get_vector_store_index(data_source_id) - vectors = index.vector_store.query(VectorStoreQuery(query_embedding=[0.66] * 1024, doc_ids=[document_id])) + vectors = index.vector_store.query( + VectorStoreQuery(query_embedding=[0.66] * 1024, doc_ids=[document_id]) + ) assert len(vectors.nodes) == 1 @staticmethod def test_delete_data_source( - client, - data_source_id: int, - document_id: str, - index_document_request_body: dict[str, Any], + client, + data_source_id: int, + document_id: str, + index_document_request_body: dict[str, Any], ) -> None: """Test DELETE /data_sources/{data_source_id}.""" client.post( @@ -88,7 +93,9 @@ def test_delete_data_source( ) index = get_vector_store_index(data_source_id) - vectors = index.vector_store.query(VectorStoreQuery(query_embedding=[0.66] * 1024, doc_ids=[document_id])) + vectors = index.vector_store.query( + VectorStoreQuery(query_embedding=[0.66] * 1024, doc_ids=[document_id]) + ) assert len(vectors.nodes) == 1 response = client.delete(f"/data_sources/{data_source_id}") @@ -96,15 +103,17 @@ def test_delete_data_source( vector_store = rag_vector_store.create_rag_vector_store(data_source_id) assert vector_store.exists() is False - get_summary_response = client.get(f'/data_sources/{data_source_id}/documents/{document_id}/summary') + get_summary_response = client.get( + f"/data_sources/{data_source_id}/documents/{document_id}/summary" + ) assert get_summary_response.status_code == 404 @staticmethod def test_delete_document( - client, - data_source_id: int, - document_id: str, - index_document_request_body: dict[str, Any], + client, + data_source_id: int, + document_id: str, + index_document_request_body: dict[str, Any], ) -> None: """Test DELETE /data_sources/{data_source_id}/documents/{document_id}.""" client.post( @@ -113,21 +122,27 @@ def test_delete_document( ) index = get_vector_store_index(data_source_id) - vectors = index.vector_store.query(VectorStoreQuery(query_embedding=[.2] * 1024, doc_ids=[document_id])) + vectors = index.vector_store.query( + VectorStoreQuery(query_embedding=[0.2] * 1024, doc_ids=[document_id]) + ) assert len(vectors.nodes) == 1 - response = client.delete(f"/data_sources/{data_source_id}/documents/{document_id}") + response = client.delete( + f"/data_sources/{data_source_id}/documents/{document_id}" + 
) assert response.status_code == 200 index = get_vector_store_index(data_source_id) - vectors = index.vector_store.query(VectorStoreQuery(query_embedding=[.2] * 1024, doc_ids=[document_id])) + vectors = index.vector_store.query( + VectorStoreQuery(query_embedding=[0.2] * 1024, doc_ids=[document_id]) + ) assert len(vectors.nodes) == 0 @staticmethod def test_get_size( - client, - data_source_id: int, - index_document_request_body: dict[str, Any], + client, + data_source_id: int, + index_document_request_body: dict[str, Any], ) -> None: """Test GET /data_sources/{data_source_id}/size.""" client.post( diff --git a/llm-service/app/tests/routers/index/test_doc_summaries.py b/llm-service/app/tests/routers/index/test_doc_summaries.py index 80951213e..a8e460b9a 100644 --- a/llm-service/app/tests/routers/index/test_doc_summaries.py +++ b/llm-service/app/tests/routers/index/test_doc_summaries.py @@ -41,7 +41,13 @@ class TestDocumentSummaries: @staticmethod - def test_generate_summary(client, index_document_request_body: dict[str, Any], data_source_id, document_id, s3_object) -> None: + def test_generate_summary( + client, + index_document_request_body: dict[str, Any], + data_source_id, + document_id, + s3_object, + ) -> None: response = client.post( f"/data_sources/{data_source_id}/documents/download-and-index", json=index_document_request_body, @@ -50,25 +56,37 @@ def test_generate_summary(client, index_document_request_body: dict[str, Any], d assert response.status_code == 200 post_summarization_response = client.post( - f'/data_sources/{data_source_id}/summarize-document', - json={ "s3_bucket_name": s3_object.bucket_name, "s3_document_key": s3_object.key }) + f"/data_sources/{data_source_id}/summarize-document", + json={ + "s3_bucket_name": s3_object.bucket_name, + "s3_document_key": s3_object.key, + }, + ) assert post_summarization_response.status_code == 200 assert post_summarization_response.text == '"this is a completion response"' - get_summary_response = client.get(f'/data_sources/{data_source_id}/documents/{document_id}/summary') + get_summary_response = client.get( + f"/data_sources/{data_source_id}/documents/{document_id}/summary" + ) assert get_summary_response.status_code == 200 assert get_summary_response.text == '"this is a completion response"' - get_data_source_response = client.get(f'/data_sources/{data_source_id}/summary') + get_data_source_response = client.get(f"/data_sources/{data_source_id}/summary") assert get_data_source_response.status_code == 200 # our monkeypatched model always returns this. 
# todo: Figure out how to parameterize the monkey patch assert get_data_source_response.text == '"this is a completion response"' @staticmethod - def test_delete_document(client, index_document_request_body: dict[str, Any], data_source_id, document_id, s3_object) -> None: + def test_delete_document( + client, + index_document_request_body: dict[str, Any], + data_source_id, + document_id, + s3_object, + ) -> None: response = client.post( f"/data_sources/{data_source_id}/documents/download-and-index", json=index_document_request_body, @@ -77,16 +95,24 @@ def test_delete_document(client, index_document_request_body: dict[str, Any], da assert response.status_code == 200 post_summarization_response = client.post( - f'/data_sources/{data_source_id}/summarize-document', - json={ "s3_bucket_name": s3_object.bucket_name, "s3_document_key": s3_object.key }) + f"/data_sources/{data_source_id}/summarize-document", + json={ + "s3_bucket_name": s3_object.bucket_name, + "s3_document_key": s3_object.key, + }, + ) assert post_summarization_response.status_code == 200 - delete_document_response = client.delete(f'/data_sources/{data_source_id}/documents/{document_id}') + delete_document_response = client.delete( + f"/data_sources/{data_source_id}/documents/{document_id}" + ) assert delete_document_response.status_code == 200 - get_summary_response = client.get(f'/data_sources/{data_source_id}/documents/{document_id}/summary') + get_summary_response = client.get( + f"/data_sources/{data_source_id}/documents/{document_id}/summary" + ) assert get_summary_response.text == '"No summary found for this document."' assert get_summary_response.status_code == 200 diff --git a/llm-service/pdm.lock b/llm-service/pdm.lock index 167a254f6..3a3ec66ed 100644 --- a/llm-service/pdm.lock +++ b/llm-service/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev"] strategy = ["inherit_metadata"] lock_version = "4.5.0" -content_hash = "sha256:1b447e8bd4e21937a21505e71d488e708db69dc5fec8cdeed5608fae5dfbfeb5" +content_hash = "sha256:cf38cbf44250032e4b248c90dbc34037bb94662cdfe512d50fcb3e271309fd84" [[metadata.targets]] requires_python = "==3.10.*" @@ -165,23 +165,23 @@ files = [ [[package]] name = "boto3" -version = "1.34.26" -requires_python = ">= 3.8" +version = "1.35.66" +requires_python = ">=3.8" summary = "The AWS SDK for Python" groups = ["default", "dev"] dependencies = [ - "botocore<1.35.0,>=1.34.26", + "botocore<1.36.0,>=1.35.66", "jmespath<2.0.0,>=0.7.1", "s3transfer<0.11.0,>=0.10.0", ] files = [ - {file = "boto3-1.34.26-py3-none-any.whl", hash = "sha256:881b07d0d55e5d85b62e6c965efcb2820bdfbd8f23a73a7bc9dac3a4997a1343"}, - {file = "boto3-1.34.26.tar.gz", hash = "sha256:0491a65e55de999d07f42bb28ff6a38bad493934154b6304fcdfb4699a612d6c"}, + {file = "boto3-1.35.66-py3-none-any.whl", hash = "sha256:09a610f8cf4d3c22d4ca69c1f89079e3a1c82805ce94fa0eb4ecdd4d2ba6c4bc"}, + {file = "boto3-1.35.66.tar.gz", hash = "sha256:c392b9168b65e9c23483eaccb5b68d1f960232d7f967a1e00a045ba065ce050d"}, ] [[package]] name = "botocore" -version = "1.34.162" +version = "1.35.66" requires_python = ">=3.8" summary = "Low-level, data-driven core of boto 3." 
groups = ["default", "dev"] @@ -192,8 +192,8 @@ dependencies = [ "urllib3<1.27,>=1.25.4; python_version < \"3.10\"", ] files = [ - {file = "botocore-1.34.162-py3-none-any.whl", hash = "sha256:2d918b02db88d27a75b48275e6fb2506e9adaaddbec1ffa6a8a0898b34e769be"}, - {file = "botocore-1.34.162.tar.gz", hash = "sha256:adc23be4fb99ad31961236342b7cbf3c0bfc62532cd02852196032e8c0d682f3"}, + {file = "botocore-1.35.66-py3-none-any.whl", hash = "sha256:d0683e9c18bb6852f768da268086c3749d925332a664db0dd1459cfa7e96e475"}, + {file = "botocore-1.35.66.tar.gz", hash = "sha256:51f43220315f384959f02ea3266740db4d421592dd87576c18824e424b349fdb"}, ] [[package]] @@ -1088,12 +1088,33 @@ files = [ {file = "multidict-6.1.0.tar.gz", hash = "sha256:22ae2ebf9b0c69d206c003e2f6a914ea33f0a932d4aa16f236afc049d9958f4a"}, ] +[[package]] +name = "mypy" +version = "1.13.0" +requires_python = ">=3.8" +summary = "Optional static typing for Python" +groups = ["dev"] +dependencies = [ + "mypy-extensions>=1.0.0", + "tomli>=1.1.0; python_version < \"3.11\"", + "typing-extensions>=4.6.0", +] +files = [ + {file = "mypy-1.13.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6607e0f1dd1fb7f0aca14d936d13fd19eba5e17e1cd2a14f808fa5f8f6d8f60a"}, + {file = "mypy-1.13.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8a21be69bd26fa81b1f80a61ee7ab05b076c674d9b18fb56239d72e21d9f4c80"}, + {file = "mypy-1.13.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7b2353a44d2179846a096e25691d54d59904559f4232519d420d64da6828a3a7"}, + {file = "mypy-1.13.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:0730d1c6a2739d4511dc4253f8274cdd140c55c32dfb0a4cf8b7a43f40abfa6f"}, + {file = "mypy-1.13.0-cp310-cp310-win_amd64.whl", hash = "sha256:c5fc54dbb712ff5e5a0fca797e6e0aa25726c7e72c6a5850cfd2adbc1eb0a372"}, + {file = "mypy-1.13.0-py3-none-any.whl", hash = "sha256:9c250883f9fd81d212e0952c92dbfcc96fc237f4b7c92f56ac81fd48460b3e5a"}, + {file = "mypy-1.13.0.tar.gz", hash = "sha256:0291a61b6fbf3e6673e3405cfcc0e7650bebc7939659fdca2702958038bd835e"}, +] + [[package]] name = "mypy-extensions" version = "1.0.0" requires_python = ">=3.5" summary = "Type system extensions for programs checked with the mypy type checker." -groups = ["default"] +groups = ["default", "dev"] files = [ {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, @@ -1636,6 +1657,33 @@ files = [ {file = "rich-13.9.4.tar.gz", hash = "sha256:439594978a49a09530cff7ebc4b5c7103ef57baf48d5ea3184f21d9a2befa098"}, ] +[[package]] +name = "ruff" +version = "0.7.4" +requires_python = ">=3.7" +summary = "An extremely fast Python linter and code formatter, written in Rust." 
+groups = ["dev"] +files = [ + {file = "ruff-0.7.4-py3-none-linux_armv6l.whl", hash = "sha256:a4919925e7684a3f18e18243cd6bea7cfb8e968a6eaa8437971f681b7ec51478"}, + {file = "ruff-0.7.4-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:cfb365c135b830778dda8c04fb7d4280ed0b984e1aec27f574445231e20d6c63"}, + {file = "ruff-0.7.4-py3-none-macosx_11_0_arm64.whl", hash = "sha256:63a569b36bc66fbadec5beaa539dd81e0527cb258b94e29e0531ce41bacc1f20"}, + {file = "ruff-0.7.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d06218747d361d06fd2fdac734e7fa92df36df93035db3dc2ad7aa9852cb109"}, + {file = "ruff-0.7.4-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e0cea28d0944f74ebc33e9f934238f15c758841f9f5edd180b5315c203293452"}, + {file = "ruff-0.7.4-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:80094ecd4793c68b2571b128f91754d60f692d64bc0d7272ec9197fdd09bf9ea"}, + {file = "ruff-0.7.4-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:997512325c6620d1c4c2b15db49ef59543ef9cd0f4aa8065ec2ae5103cedc7e7"}, + {file = "ruff-0.7.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:00b4cf3a6b5fad6d1a66e7574d78956bbd09abfd6c8a997798f01f5da3d46a05"}, + {file = "ruff-0.7.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7dbdc7d8274e1422722933d1edddfdc65b4336abf0b16dfcb9dedd6e6a517d06"}, + {file = "ruff-0.7.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0e92dfb5f00eaedb1501b2f906ccabfd67b2355bdf117fea9719fc99ac2145bc"}, + {file = "ruff-0.7.4-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:3bd726099f277d735dc38900b6a8d6cf070f80828877941983a57bca1cd92172"}, + {file = "ruff-0.7.4-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:2e32829c429dd081ee5ba39aef436603e5b22335c3d3fff013cd585806a6486a"}, + {file = "ruff-0.7.4-py3-none-musllinux_1_2_i686.whl", hash = "sha256:662a63b4971807623f6f90c1fb664613f67cc182dc4d991471c23c541fee62dd"}, + {file = "ruff-0.7.4-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:876f5e09eaae3eb76814c1d3b68879891d6fde4824c015d48e7a7da4cf066a3a"}, + {file = "ruff-0.7.4-py3-none-win32.whl", hash = "sha256:75c53f54904be42dd52a548728a5b572344b50d9b2873d13a3f8c5e3b91f5cac"}, + {file = "ruff-0.7.4-py3-none-win_amd64.whl", hash = "sha256:745775c7b39f914238ed1f1b0bebed0b9155a17cd8bc0b08d3c87e4703b990d6"}, + {file = "ruff-0.7.4-py3-none-win_arm64.whl", hash = "sha256:11bff065102c3ae9d3ea4dc9ecdfe5a5171349cdd0787c1fc64761212fc9cf1f"}, + {file = "ruff-0.7.4.tar.gz", hash = "sha256:cd12e35031f5af6b9b93715d8c4f40360070b2041f81273d0527683d5708fce2"}, +] + [[package]] name = "s3transfer" version = "0.10.3" @@ -1889,7 +1937,7 @@ name = "typing-extensions" version = "4.12.2" requires_python = ">=3.8" summary = "Backported and Experimental Type Hints for Python 3.8+" -groups = ["default"] +groups = ["default", "dev"] files = [ {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"}, {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, diff --git a/llm-service/pyproject.toml b/llm-service/pyproject.toml index b6c2e69f7..830db6091 100644 --- a/llm-service/pyproject.toml +++ b/llm-service/pyproject.toml @@ -1,11 +1,3 @@ -[tool.pytest.ini_options] -addopts = [ - "--import-mode=importlib", -] - -[tool.pdm] -distribution = false - [project] name = "llm-service" version = "0.1.0" @@ -13,10 +5,34 @@ 
description = "Default template for PDM package" authors = [ {name = "Conrado Silva Miranda", email = "csilvamiranda@cloudera.com"}, ] -dependencies = ["llama-index-core==0.10.68", "llama-index-readers-file==0.1.33", "fastapi==0.111.0", "pydantic==2.8.2", "pydantic-settings==2.3.4", "boto3==1.34.26", "llama-index-embeddings-bedrock==0.2.1", "llama-index-llms-bedrock==0.1.13", "llama-index-llms-openai==0.1.31", "llama-index-llms-mistralai==0.1.20", "llama-index-embeddings-openai==0.1.11", "llama-index-vector-stores-qdrant==0.2.17"] +dependencies = ["llama-index-core==0.10.68", "llama-index-readers-file==0.1.33", "fastapi==0.111.0", "pydantic==2.8.2", "pydantic-settings==2.3.4", "boto3>=1.35.66", "llama-index-embeddings-bedrock==0.2.1", "llama-index-llms-bedrock==0.1.13", "llama-index-llms-openai==0.1.31", "llama-index-llms-mistralai==0.1.20", "llama-index-embeddings-openai==0.1.11", "llama-index-vector-stores-qdrant==0.2.17"] requires-python = "==3.10.*" readme = "README.md" license = {text = "APACHE"} [dependency-groups] -dev = ["moto[s3]>=5.0.21", "pytest>=8.3.3"] +dev = [ + "moto[s3]>=5.0.21", + "pytest>=8.3.3", + "ruff>=0.7.4", + "mypy>=1.13.0", +] + +[tool.pytest.ini_options] +addopts = [ + "--import-mode=importlib", +] + +[tool.pdm] +distribution = false + +[tool.mypy] +files = ["./app/"] +ignore_missing_imports = true +strict = true +disallow_untyped_calls = true +disallow_untyped_defs = true +warn_unused_ignores = true +warn_return_any = true +show_error_codes = true +disable_error_code = ["import-untyped", "assignment"] From 035f07a6f2d76501af6dee87178939f361d4b3ca Mon Sep 17 00:00:00 2001 From: Conrado Silva Miranda Date: Thu, 21 Nov 2024 09:22:49 -0800 Subject: [PATCH 2/3] fixes --- llm-service/app/exceptions.py | 8 +-- llm-service/app/main.py | 5 +- .../app/routers/index/sessions/__init__.py | 4 +- .../app/services/CaiiEmbeddingModel.py | 10 +-- llm-service/app/services/CaiiModel.py | 16 +++-- llm-service/app/services/caii.py | 8 ++- llm-service/app/services/chat.py | 12 ++-- llm-service/app/services/chat_store.py | 8 ++- llm-service/app/services/doc_summaries.py | 2 +- llm-service/app/services/evaluators.py | 5 +- llm-service/app/services/llama_utils.py | 6 +- llm-service/app/services/llm_completion.py | 4 +- llm-service/app/services/qdrant.py | 4 +- .../app/services/rag_qdrant_vector_store.py | 6 +- llm-service/app/services/s3.py | 2 +- llm-service/app/tests/conftest.py | 65 +++++++++++++++---- .../tests/routers/index/test_data_source.py | 23 ++++--- .../tests/routers/index/test_doc_summaries.py | 20 +++--- 18 files changed, 134 insertions(+), 74 deletions(-) diff --git a/llm-service/app/exceptions.py b/llm-service/app/exceptions.py index 9e17c8daf..5db3b5ff4 100644 --- a/llm-service/app/exceptions.py +++ b/llm-service/app/exceptions.py @@ -43,7 +43,7 @@ import inspect import logging from collections.abc import Callable, Iterator -from typing import ParamSpec, TypeVar +from typing import Awaitable, ParamSpec, Type, TypeVar, Union import requests from fastapi import HTTPException @@ -88,7 +88,7 @@ def _exception_propagation() -> Iterator[None]: ) from e -def propagates(f: Callable[P, T]) -> Callable[P, T]: +def propagates(f: Callable[P, T]) -> Union[Callable[P, T], Callable[P, Awaitable[T]]]: """ Function decorator for catching and propagating exceptions back to a client. 
@@ -124,8 +124,8 @@ def banana(): @functools.wraps(f) async def exception_propagation_wrapper(*args: P.args, **kwargs: P.kwargs) -> T: with _exception_propagation(): - return await f(*args, **kwargs) - + ret: T = await f(*args, **kwargs) + return ret else: @functools.wraps(f) diff --git a/llm-service/app/main.py b/llm-service/app/main.py index 995ae75d6..6123be675 100644 --- a/llm-service/app/main.py +++ b/llm-service/app/main.py @@ -42,6 +42,7 @@ import time from collections.abc import Awaitable, Callable from contextlib import asynccontextmanager +from typing import AsyncGenerator from fastapi import FastAPI, Request, Response from fastapi.middleware.cors import CORSMiddleware @@ -79,7 +80,7 @@ def _configure_logger() -> None: @functools.cache -def _get_app_log_handler(): +def _get_app_log_handler() -> logging.Handler: """Format and return a reusable handler for application logging.""" # match Java backend's formatting formatter = logging.Formatter( @@ -125,7 +126,7 @@ def initialize_logging() -> None: @asynccontextmanager -async def lifespan(app: FastAPI): +async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: initialize_logging() yield diff --git a/llm-service/app/routers/index/sessions/__init__.py b/llm-service/app/routers/index/sessions/__init__.py index 284d94814..422b988d1 100644 --- a/llm-service/app/routers/index/sessions/__init__.py +++ b/llm-service/app/routers/index/sessions/__init__.py @@ -42,10 +42,10 @@ from pydantic import BaseModel from .... import exceptions +from ....rag_types import RagPredictConfiguration from ....services import llm_completion, qdrant from ....services.chat import generate_suggested_questions, v2_chat from ....services.chat_store import RagStudioChatMessage, chat_store -from ....services.qdrant import RagPredictConfiguration router = APIRouter(prefix="/sessions/{session_id}", tags=["Sessions"]) @@ -107,7 +107,7 @@ def llm_talk( evaluations=[], rag_message={ "user": request.query, - "assistant": chat_response.message.content, + "assistant": str(chat_response.message.content), }, timestamp=time.time(), ) diff --git a/llm-service/app/services/CaiiEmbeddingModel.py b/llm-service/app/services/CaiiEmbeddingModel.py index ca8019e15..c550713c5 100644 --- a/llm-service/app/services/CaiiEmbeddingModel.py +++ b/llm-service/app/services/CaiiEmbeddingModel.py @@ -38,7 +38,7 @@ import http.client as http_client import json import os -from typing import List +from typing import Any, Dict, List from llama_index.core.base.embeddings.base import BaseEmbedding, Embedding from openai import OpenAI @@ -48,7 +48,7 @@ class CaiiEmbeddingModel(BaseEmbedding): endpoint = Field(any, description="The endpoint to use for embeddings") - def __init__(self, endpoint): + def __init__(self, endpoint: Dict[str, Any]): super().__init__() self.endpoint = endpoint @@ -56,7 +56,7 @@ def _get_text_embedding(self, text: str) -> Embedding: return self._get_embedding(text, "passage") async def _aget_query_embedding(self, query: str) -> Embedding: - pass + raise NotImplementedError("Not implemented") def _get_query_embedding(self, query: str) -> Embedding: return self._get_embedding(query, "query") @@ -82,10 +82,12 @@ def _get_embedding(self, query: str, input_type: str) -> Embedding: json_response = data.decode("utf-8") structured_response = json.loads(json_response) embedding = structured_response["data"][0]["embedding"] + assert isinstance(embedding, list) + assert all(isinstance(x, float) for x in embedding) return embedding - def build_auth_headers(self) -> dict: + def 
build_auth_headers(self) -> Dict[str, str]: with open("/tmp/jwt", "r") as file: jwt_contents = json.load(file) access_token = jwt_contents["access_token"] diff --git a/llm-service/app/services/CaiiModel.py b/llm-service/app/services/CaiiModel.py index 6ded40539..d8e02022f 100644 --- a/llm-service/app/services/CaiiModel.py +++ b/llm-service/app/services/CaiiModel.py @@ -35,7 +35,9 @@ # BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF # DATA. # -from llama_index.core.base.llms.types import LLMMetadata +from typing import Callable, Dict, Sequence + +from llama_index.core.base.llms.types import ChatMessage, LLMMetadata from llama_index.core.bridge.pydantic import Field from llama_index.llms.mistralai.base import MistralAI from llama_index.llms.openai import OpenAI @@ -52,9 +54,9 @@ def __init__( model: str, context: int, api_base: str, - messages_to_prompt, - completion_to_prompt, - default_headers, + messages_to_prompt: Callable[[Sequence[ChatMessage]], str], + completion_to_prompt: Callable[[str], str], + default_headers: Dict[str, str], ): super().__init__( model=model, @@ -84,9 +86,9 @@ def __init__( model: str, context: int, api_base: str, - messages_to_prompt, - completion_to_prompt, - default_headers, + messages_to_prompt: Callable[[Sequence[ChatMessage]], str], + completion_to_prompt: Callable[[str], str], + default_headers: Dict[str, str], ): super().__init__( api_key=default_headers.get("Authorization"), diff --git a/llm-service/app/services/caii.py b/llm-service/app/services/caii.py index fa2855445..34455fa4c 100644 --- a/llm-service/app/services/caii.py +++ b/llm-service/app/services/caii.py @@ -37,11 +37,12 @@ # import json import os -from typing import Any, Dict, List +from typing import Any, Callable, Dict, List, Sequence import requests from fastapi import HTTPException from llama_index.core.base.embeddings.base import BaseEmbedding +from llama_index.core.base.llms.types import ChatMessage from llama_index.core.llms import LLM from .CaiiEmbeddingModel import CaiiEmbeddingModel @@ -66,7 +67,10 @@ def describe_endpoint(domain: str, endpoint_name: str) -> Any: def get_llm( - domain: str, endpoint_name: str, messages_to_prompt, completion_to_prompt + domain: str, + endpoint_name: str, + messages_to_prompt: Callable[[Sequence[ChatMessage]], str], + completion_to_prompt: Callable[[str], str], ) -> LLM: endpoint = describe_endpoint(domain=domain, endpoint_name=endpoint_name) api_base = endpoint["url"].removesuffix("/chat/completions") diff --git a/llm-service/app/services/chat.py b/llm-service/app/services/chat.py index 5d80f105a..9df23f39f 100644 --- a/llm-service/app/services/chat.py +++ b/llm-service/app/services/chat.py @@ -41,6 +41,7 @@ from typing import List from llama_index.core.base.llms.types import MessageRole +from llama_index.core.chat_engine.types import AgentChatResponse from ..rag_types import RagPredictConfiguration from . 
import evaluators, qdrant @@ -112,7 +113,7 @@ def retrieve_chat_history(session_id: int) -> List[RagContext]: return history -def format_source_nodes(response): +def format_source_nodes(response: AgentChatResponse) -> List[RagPredictSourceNode]: response_source_nodes = [] for source_node in response.source_nodes: doc_id = source_node.node.metadata.get("document_id", source_node.node.node_id) @@ -121,7 +122,7 @@ def format_source_nodes(response): node_id=source_node.node.node_id, doc_id=doc_id, source_file_name=source_node.node.metadata["file_name"], - score=source_node.score, + score=source_node.score or 0.0, ) ) response_source_nodes = sorted( @@ -131,8 +132,11 @@ def format_source_nodes(response): def generate_suggested_questions( - configuration, data_source_id, data_source_size, session_id -): + configuration: RagPredictConfiguration, + data_source_id: int, + data_source_size: int, + session_id: int, +) -> List[str]: chat_history = retrieve_chat_history(session_id) if data_source_size == 0: suggested_questions = [] diff --git a/llm-service/app/services/chat_store.py b/llm-service/app/services/chat_store.py index 0a353df2b..3804cc524 100644 --- a/llm-service/app/services/chat_store.py +++ b/llm-service/app/services/chat_store.py @@ -102,8 +102,8 @@ def retrieve_chat_history(self, session_id: int) -> List[RagStudioChatMessage]: "source_nodes", [] ), rag_message={ - MessageRole.USER.value: user_message.content, - MessageRole.ASSISTANT.value: assistant_message.content, + MessageRole.USER.value: str(user_message.content), + MessageRole.ASSISTANT.value: str(assistant_message.content), }, evaluations=assistant_message.additional_kwargs.get( "evaluations", [] @@ -134,7 +134,9 @@ def delete_chat_history(self, session_id: int) -> None: def store_file(self, session_id: int) -> str: return os.path.join(self.store_path, f"chat_store-{session_id}.json") - def append_to_history(self, session_id: int, messages: list[RagStudioChatMessage]): + def append_to_history( + self, session_id: int, messages: List[RagStudioChatMessage] + ) -> None: store = self.store_for_session(session_id) for message in messages: diff --git a/llm-service/app/services/doc_summaries.py b/llm-service/app/services/doc_summaries.py index 1ad1303ad..7572d93ca 100644 --- a/llm-service/app/services/doc_summaries.py +++ b/llm-service/app/services/doc_summaries.py @@ -116,7 +116,7 @@ def generate_summary( def set_settings_globals() -> None: Settings.llm = models.get_llm("meta.llama3-8b-instruct-v1:0") Settings.embed_model = models.get_embedding_model() - Settings.splitter = SentenceSplitter(chunk_size=1024) + Settings.text_splitter = SentenceSplitter(chunk_size=1024) def initialize_summary_index_storage(data_source_id: int) -> None: diff --git a/llm-service/app/services/evaluators.py b/llm-service/app/services/evaluators.py index 2603d052d..125169497 100644 --- a/llm-service/app/services/evaluators.py +++ b/llm-service/app/services/evaluators.py @@ -36,6 +36,7 @@ # DATA. 
 # ##############################################################################

+from llama_index.core.base.response.schema import Response
 from llama_index.core.chat_engine.types import AgentChatResponse
 from llama_index.core.evaluation import FaithfulnessEvaluator, RelevancyEvaluator
 from llama_index.llms.bedrock import Bedrock
@@ -56,10 +57,10 @@ def evaluate_response(
     relevancy_evaluator = RelevancyEvaluator(llm=evaluator_llm)
     relevance = relevancy_evaluator.evaluate_response(
-        query=query, response=chat_response
+        query=query, response=Response(response=chat_response.response)
     )

     faithfulness_evaluator = FaithfulnessEvaluator(llm=evaluator_llm)
     faithfulness = faithfulness_evaluator.evaluate_response(
-        query=query, response=chat_response
+        query=query, response=Response(response=chat_response.response)
     )
     return relevance.score or 0, faithfulness.score or 0
diff --git a/llm-service/app/services/llama_utils.py b/llm-service/app/services/llama_utils.py
index d2a0bb493..734b2e9e6 100644
--- a/llm-service/app/services/llama_utils.py
+++ b/llm-service/app/services/llama_utils.py
@@ -129,7 +129,7 @@ def completion_to_prompt(completion: str, system_prompt: Optional[str] = None) -
     return result


-def mistralv2_messages_to_prompt(messages):
+def mistralv2_messages_to_prompt(messages: Sequence[ChatMessage]) -> str:
     print(f"mistralv2_messages_to_prompt: {messages}")
     conversation = ""
     bos_token = ""
@@ -139,14 +139,14 @@ def mistralv2_messages_to_prompt(messages):
         system_message = messages[0].content
     else:
         loop_messages = messages
-        system_message = False
+        system_message = None

     for index, message in enumerate(loop_messages):
         if (message.role == MessageRole.USER) != (index % 2 == 0):
             raise Exception(
                 "HFI Conversation roles must alternate user/assistant/user/assistant/..."
) - if index == 0 and system_message != False: + if index == 0 and system_message is not None: content = "<>\n" + system_message + "\n<>\n\n" + message.content else: content = message.content diff --git a/llm-service/app/services/llm_completion.py b/llm-service/app/services/llm_completion.py index c5a49fc13..c6ae9e2c8 100644 --- a/llm-service/app/services/llm_completion.py +++ b/llm-service/app/services/llm_completion.py @@ -39,8 +39,8 @@ from llama_index.core.base.llms.types import ChatMessage, ChatResponse -from .chat_store import chat_store, RagStudioChatMessage -from .qdrant import RagPredictConfiguration +from ..rag_types import RagPredictConfiguration +from .chat_store import RagStudioChatMessage, chat_store from .models import get_llm diff --git a/llm-service/app/services/qdrant.py b/llm-service/app/services/qdrant.py index 94fd14bbd..efeeff25d 100644 --- a/llm-service/app/services/qdrant.py +++ b/llm-service/app/services/qdrant.py @@ -183,7 +183,7 @@ def query( ) logger.info("querying chat engine") - chat_history = list( + chat_messages = list( map( lambda message: ChatMessage(role=message.role, content=message.content), chat_history, @@ -191,7 +191,7 @@ def query( ) try: - chat_response = chat_engine.chat(query_str, chat_history) + chat_response: AgentChatResponse = chat_engine.chat(query_str, chat_messages) logger.info("query response received from chat engine") return chat_response except botocore.exceptions.ClientError as error: diff --git a/llm-service/app/services/rag_qdrant_vector_store.py b/llm-service/app/services/rag_qdrant_vector_store.py index 55a8cd7f8..18fd1cf55 100644 --- a/llm-service/app/services/rag_qdrant_vector_store.py +++ b/llm-service/app/services/rag_qdrant_vector_store.py @@ -51,7 +51,7 @@ class RagQdrantVectorStore(VectorStore): port = 6333 async_port = 6334 - def __init__(self, table_name: str, memory_store=False): + def __init__(self, table_name: str, memory_store: bool = False): self.client, self.aclient = self._create_qdrant_clients(memory_store) self.table_name = table_name @@ -64,7 +64,7 @@ def size(self) -> int: document_count: CountResult = self.client.count(self.table_name) return document_count.count - def delete(self): + def delete(self) -> None: if self.exists(): self.client.delete_collection(self.table_name) @@ -72,7 +72,7 @@ def exists(self) -> bool: return self.client.collection_exists(self.table_name) def _create_qdrant_clients( - self, memory_store + self, memory_store: bool ) -> tuple[qdrant_client.QdrantClient, qdrant_client.AsyncQdrantClient]: if memory_store: client = qdrant_client.QdrantClient(":memory:") diff --git a/llm-service/app/services/s3.py b/llm-service/app/services/s3.py index 5e691b239..bd06833b4 100644 --- a/llm-service/app/services/s3.py +++ b/llm-service/app/services/s3.py @@ -45,7 +45,7 @@ logger = logging.getLogger(__name__) -def download(tmpdirname: str, bucket_name: str, document_key: str): +def download(tmpdirname: str, bucket_name: str, document_key: str) -> None: """ Download document from S3 """ diff --git a/llm-service/app/tests/conftest.py b/llm-service/app/tests/conftest.py index ede45b49d..5c6201fa6 100644 --- a/llm-service/app/tests/conftest.py +++ b/llm-service/app/tests/conftest.py @@ -40,10 +40,12 @@ import pathlib import uuid from collections.abc import Iterator -from typing import Any, Sequence +from dataclasses import dataclass +from typing import Any, Dict, Sequence import boto3 import pytest +from boto3.resources.base import ServiceResource from fastapi.testclient import TestClient from 
llama_index.core.base.embeddings.base import BaseEmbedding, Embedding from llama_index.core.base.llms.types import ( @@ -65,6 +67,12 @@ from app.services.rag_qdrant_vector_store import RagQdrantVectorStore +@dataclass +class BotoObject: + bucket_name: str + key: str + + @pytest.fixture def aws_region() -> str: return os.environ.get("AWS_DEFAULT_REGION", "us-west-2") @@ -78,9 +86,9 @@ def databases_dir(monkeypatch: pytest.MonkeyPatch, tmp_path: pathlib.Path) -> st @pytest.fixture -def s3( +def s3_client( monkeypatch: pytest.MonkeyPatch, -) -> Iterator["s3.ServiceResource"]: +) -> Iterator[ServiceResource]: """Mock all S3 interactions.""" with mock_aws(): @@ -98,7 +106,9 @@ def data_source_id() -> int: @pytest.fixture -def index_document_request_body(data_source_id, s3_object) -> dict[str, Any]: +def index_document_request_body( + data_source_id: int, s3_object: BotoObject +) -> Dict[str, Any]: return { "data_source_id": data_source_id, "s3_bucket_name": s3_object.bucket_name, @@ -135,6 +145,36 @@ def complete( ) -> CompletionResponse: return CompletionResponse(text=self.completion_response) + def stream_chat( + self, messages: Sequence[ChatMessage], **kwargs: Any + ) -> ChatResponseGen: + raise NotImplementedError("Not implemented") + + def stream_complete( + self, prompt: str, formatted: bool = False, **kwargs: Any + ) -> CompletionResponseGen: + raise NotImplementedError("Not implemented") + + async def achat( + self, messages: Sequence[ChatMessage], **kwargs: Any + ) -> ChatResponse: + raise NotImplementedError("Not implemented") + + async def acomplete( + self, prompt: str, formatted: bool = False, **kwargs: Any + ) -> CompletionResponse: + raise NotImplementedError("Not implemented") + + async def astream_chat( + self, messages: Sequence[ChatMessage], **kwargs: Any + ) -> ChatResponseAsyncGen: + raise NotImplementedError("Not implemented") + + async def astream_complete( + self, prompt: str, formatted: bool = False, **kwargs: Any + ) -> CompletionResponseAsyncGen: + raise NotImplementedError("Not implemented") + class DummyEmbeddingModel(BaseEmbedding): def _get_query_embedding(self, query: str) -> Embedding: @@ -149,7 +189,7 @@ def _get_text_embedding(self, text: str) -> Embedding: # We're hacking our vector stores to run in-memory. 
Since they are in memory, we need # to be sure to return the same instance for the same data source id -table_name_to_vector_store = {} +table_name_to_vector_store: Dict[int, RagQdrantVectorStore] = {} def _get_vector_store_instance( @@ -165,7 +205,7 @@ def _get_vector_store_instance( @pytest.fixture(autouse=True) -def vector_store(monkeypatch: pytest.MonkeyPatch): +def vector_store(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( rag_vector_store, "create_rag_vector_store", @@ -174,7 +214,7 @@ def vector_store(monkeypatch: pytest.MonkeyPatch): @pytest.fixture(autouse=True) -def summary_vector_store(monkeypatch: pytest.MonkeyPatch): +def summary_vector_store(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr( rag_vector_store, "create_summary_vector_store", @@ -202,25 +242,26 @@ def llm(monkeypatch: pytest.MonkeyPatch) -> LLM: @pytest.fixture def s3_object( - s3: "s3.ServiceResource", aws_region: str, document_id: str -) -> "s3.Object": + s3_client: ServiceResource, aws_region: str, document_id: str +) -> BotoObject: """Put and return a mocked S3 object""" bucket_name = "test_bucket" key = "test/" + document_id - bucket = s3.Bucket(bucket_name) + bucket = s3_client.Bucket(bucket_name) bucket.create(CreateBucketConfiguration={"LocationConstraint": aws_region}) - return bucket.put_object( + bucket.put_object( Key=key, # TODO: fixturize file Body=b"Some text to be summarized and indexed", Metadata={"originalfilename": "test.txt"}, ) + return BotoObject(bucket_name=bucket_name, key=key) @pytest.fixture def client( - s3: "s3.ServiceResource", + s3_client: ServiceResource, ) -> Iterator[TestClient]: """Return a test client for making calls to the service. diff --git a/llm-service/app/tests/routers/index/test_data_source.py b/llm-service/app/tests/routers/index/test_data_source.py index fece5b718..b6450f46f 100644 --- a/llm-service/app/tests/routers/index/test_data_source.py +++ b/llm-service/app/tests/routers/index/test_data_source.py @@ -40,14 +40,13 @@ from typing import Any +from app.services import models, rag_vector_store +from fastapi.testclient import TestClient from llama_index.core import VectorStoreIndex from llama_index.core.vector_stores import VectorStoreQuery -from app.services import models -from app.services import rag_vector_store - -def get_vector_store_index(data_source_id) -> VectorStoreIndex: +def get_vector_store_index(data_source_id: int) -> VectorStoreIndex: vector_store = rag_vector_store.create_rag_vector_store( data_source_id ).access_vector_store() @@ -60,7 +59,7 @@ def get_vector_store_index(data_source_id) -> VectorStoreIndex: class TestDocumentIndexing: @staticmethod def test_create_document( - client, + client: TestClient, index_document_request_body: dict[str, Any], document_id: str, data_source_id: int, @@ -77,11 +76,11 @@ def test_create_document( vectors = index.vector_store.query( VectorStoreQuery(query_embedding=[0.66] * 1024, doc_ids=[document_id]) ) - assert len(vectors.nodes) == 1 + assert len(vectors.nodes or []) == 1 @staticmethod def test_delete_data_source( - client, + client: TestClient, data_source_id: int, document_id: str, index_document_request_body: dict[str, Any], @@ -96,7 +95,7 @@ def test_delete_data_source( vectors = index.vector_store.query( VectorStoreQuery(query_embedding=[0.66] * 1024, doc_ids=[document_id]) ) - assert len(vectors.nodes) == 1 + assert len(vectors.nodes or []) == 1 response = client.delete(f"/data_sources/{data_source_id}") assert response.status_code == 200 @@ -110,7 +109,7 @@ def 
test_delete_data_source(

     @staticmethod
     def test_delete_document(
-        client,
+        client: TestClient,
         data_source_id: int,
         document_id: str,
         index_document_request_body: dict[str, Any],
@@ -125,7 +124,7 @@ def test_delete_document(
         vectors = index.vector_store.query(
             VectorStoreQuery(query_embedding=[0.2] * 1024, doc_ids=[document_id])
         )
-        assert len(vectors.nodes) == 1
+        assert len(vectors.nodes or []) == 1

         response = client.delete(
             f"/data_sources/{data_source_id}/documents/{document_id}"
@@ -136,11 +135,11 @@ def test_delete_document(
         vectors = index.vector_store.query(
             VectorStoreQuery(query_embedding=[0.2] * 1024, doc_ids=[document_id])
         )
-        assert len(vectors.nodes) == 0
+        assert len(vectors.nodes or []) == 0

     @staticmethod
     def test_get_size(
-        client,
+        client: TestClient,
         data_source_id: int,
         index_document_request_body: dict[str, Any],
     ) -> None:
diff --git a/llm-service/app/tests/routers/index/test_doc_summaries.py b/llm-service/app/tests/routers/index/test_doc_summaries.py
index a8e460b9a..2f850c847 100644
--- a/llm-service/app/tests/routers/index/test_doc_summaries.py
+++ b/llm-service/app/tests/routers/index/test_doc_summaries.py
@@ -38,15 +38,19 @@
 from typing import Any

+from fastapi.testclient import TestClient
+
+from ...conftest import BotoObject
+

 class TestDocumentSummaries:
     @staticmethod
     def test_generate_summary(
-        client,
+        client: TestClient,
         index_document_request_body: dict[str, Any],
-        data_source_id,
-        document_id,
-        s3_object,
+        data_source_id: int,
+        document_id: str,
+        s3_object: BotoObject,
     ) -> None:
         response = client.post(
             f"/data_sources/{data_source_id}/documents/download-and-index",
@@ -81,11 +85,11 @@ def test_generate_summary(

     @staticmethod
     def test_delete_document(
-        client,
+        client: TestClient,
         index_document_request_body: dict[str, Any],
-        data_source_id,
-        document_id,
-        s3_object,
+        data_source_id: int,
+        document_id: str,
+        s3_object: BotoObject,
     ) -> None:
         response = client.post(
             f"/data_sources/{data_source_id}/documents/download-and-index",

From d64e2cbec406fc0eccd1921aeaa9b81503f10814 Mon Sep 17 00:00:00 2001
From: Conrado Silva Miranda
Date: Thu, 21 Nov 2024 09:27:08 -0800
Subject: [PATCH 3/3] fix lint

---
 llm-service/app/exceptions.py                  | 2 +-
 llm-service/app/services/CaiiEmbeddingModel.py | 3 +--
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/llm-service/app/exceptions.py b/llm-service/app/exceptions.py
index 5db3b5ff4..d94e18e5f 100644
--- a/llm-service/app/exceptions.py
+++ b/llm-service/app/exceptions.py
@@ -43,7 +43,7 @@
 import inspect
 import logging
 from collections.abc import Callable, Iterator
-from typing import Awaitable, ParamSpec, Type, TypeVar, Union
+from typing import Awaitable, ParamSpec, TypeVar, Union

 import requests
 from fastapi import HTTPException
diff --git a/llm-service/app/services/CaiiEmbeddingModel.py b/llm-service/app/services/CaiiEmbeddingModel.py
index c550713c5..136430106 100644
--- a/llm-service/app/services/CaiiEmbeddingModel.py
+++ b/llm-service/app/services/CaiiEmbeddingModel.py
@@ -38,10 +38,9 @@
 import http.client as http_client
 import json
 import os
-from typing import Any, Dict, List
+from typing import Any, Dict

 from llama_index.core.base.embeddings.base import BaseEmbedding, Embedding
-from openai import OpenAI
 from pydantic import Field