Refactor vector store for dependency injection #37

Merged · 5 commits · Nov 22, 2024

Changes from all commits
4 changes: 2 additions & 2 deletions llm-service/app/ai/indexing/index.py

@@ -49,8 +49,8 @@
 from llama_index.core.schema import BaseNode, Document, TextNode
 from llama_index.readers.file import DocxReader

+from ...ai.vector_stores.vector_store import VectorStore
 from ...services.utils import batch_sequence, flatten_sequence
-from ...services.vector_store import VectorStore
 from .readers.csv import CSVReader
 from .readers.json import JSONReader
 from .readers.nop import NopReader
@@ -124,7 +124,7 @@ def index_file(self, file_path: Path, document_id: str) -> None:
             # we're capturing "text".
             converted_chunks: List[BaseNode] = [chunk for chunk in chunk_batch]

-            chunks_vector_store = self.chunks_vector_store.access_vector_store()
+            chunks_vector_store = self.chunks_vector_store.llama_vector_store()
             chunks_vector_store.add(converted_chunks)

         logger.debug(f"Indexing file: {file_path} completed")
llm-service/app/services/rag_vector_store (filename not captured in the page; inferred from import paths elsewhere in this diff)

@@ -35,14 +35,3 @@
 # BUSINESS ADVANTAGE OR UNAVAILABILITY, OR LOSS OR CORRUPTION OF
 # DATA.
 #
-
-from .rag_qdrant_vector_store import RagQdrantVectorStore
-from .vector_store import VectorStore
-
-
-def create_rag_vector_store(data_source_id: int) -> VectorStore:
-    return RagQdrantVectorStore(table_name=f"index_{data_source_id}")
-
-
-def create_summary_vector_store(data_source_id: int) -> VectorStore:
-    return RagQdrantVectorStore(table_name=f"summary_index_{data_source_id}")
llm-service/app/ai/vector_stores/qdrant.py (filename not captured in the page; inferred from import paths elsewhere in this diff)

@@ -37,21 +37,43 @@
 #

 import os
+from typing import Optional

 import qdrant_client
 from llama_index.core.vector_stores.types import BasePydanticVectorStore
-from llama_index.vector_stores.qdrant import QdrantVectorStore
+from llama_index.vector_stores.qdrant import (
+    QdrantVectorStore as LlamaIndexQdrantVectorStore,
+)
 from qdrant_client.http.models import CountResult

 from .vector_store import VectorStore


-class RagQdrantVectorStore(VectorStore):
-    def __init__(self, table_name: str, memory_store: bool = False):
-        self.client = self._create_qdrant_clients(memory_store)
+def new_qdrant_client() -> qdrant_client.QdrantClient:
+    host = os.environ.get("QDRANT_HOST", "localhost")
+    port = 6333
+    return qdrant_client.QdrantClient(host=host, port=port)
+
+
+class QdrantVectorStore(VectorStore):
+    @staticmethod
+    def for_chunks(
+        data_source_id: int, client: Optional[qdrant_client.QdrantClient] = None
+    ) -> "QdrantVectorStore":
+        return QdrantVectorStore(table_name=f"index_{data_source_id}", client=client)
+
+    @staticmethod
+    def for_summaries(
+        data_source_id: int, client: Optional[qdrant_client.QdrantClient] = None
+    ) -> "QdrantVectorStore":
+        return QdrantVectorStore(
+            table_name=f"summary_index_{data_source_id}", client=client
+        )
+
+    def __init__(
+        self, table_name: str, client: Optional[qdrant_client.QdrantClient] = None
+    ):
+        self.client = client or new_qdrant_client()
         self.table_name = table_name

     def size(self) -> int:
@@ -70,14 +92,6 @@ def delete(self) -> None:
     def exists(self) -> bool:
         return self.client.collection_exists(self.table_name)

-    def _create_qdrant_clients(self, memory_store: bool) -> qdrant_client.QdrantClient:
-        if memory_store:
-            client = qdrant_client.QdrantClient(":memory:")
-        else:
-            client = qdrant_client.QdrantClient(host=self.host, port=self.port)
-
-        return client
-
-    def access_vector_store(self) -> BasePydanticVectorStore:
-        vector_store = QdrantVectorStore(self.table_name, self.client)
+    def llama_vector_store(self) -> BasePydanticVectorStore:
+        vector_store = LlamaIndexQdrantVectorStore(self.table_name, self.client)
         return vector_store
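
With the client injectable, the store can be pointed at an in-memory Qdrant in tests instead of the environment-configured server, replacing the old memory_store=True constructor flag. A minimal sketch of the idea (the module path app.ai.vector_stores.qdrant is inferred from this diff's imports, not confirmed by it):

# Sketch only: exercises the new factory with an injected in-memory client.
import qdrant_client

from app.ai.vector_stores.qdrant import QdrantVectorStore  # path inferred


def test_for_chunks_with_injected_client() -> None:
    client = qdrant_client.QdrantClient(":memory:")
    store = QdrantVectorStore.for_chunks(data_source_id=1, client=client)
    assert store.table_name == "index_1"
    assert not store.exists()  # no collection has been created yet

The design choice here is that the decision between a real and an in-memory backend now lives with the caller rather than inside the store itself.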
llm-service/app/ai/vector_stores/vector_store.py (filename not captured in the page; inferred from import paths elsewhere in this diff)

@@ -55,7 +55,7 @@ def delete(self) -> None:
         """Delete the vector store"""

     @abstractmethod
-    def access_vector_store(self) -> BasePydanticVectorStore:
+    def llama_vector_store(self) -> BasePydanticVectorStore:
         """Access the underlying llama-index vector store implementation"""

     @abstractmethod
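
For reference, the abstract interface being renamed here, reconstructed from the lines visible in this diff; a hedged sketch, since the hunk is truncated and the real file may declare additional members:

from abc import ABC, abstractmethod

from llama_index.core.vector_stores.types import BasePydanticVectorStore


class VectorStore(ABC):
    """Sketch of the interface that QdrantVectorStore implements."""

    @abstractmethod
    def size(self) -> int:
        """Return the number of chunks in the store."""

    @abstractmethod
    def delete(self) -> None:
        """Delete the vector store"""

    @abstractmethod
    def llama_vector_store(self) -> BasePydanticVectorStore:
        """Access the underlying llama-index vector store implementation"""

    @abstractmethod
    def exists(self) -> bool:
        """Return whether the underlying store exists."""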
208 changes: 120 additions & 88 deletions llm-service/app/routers/index/data_source/__init__.py

@@ -33,13 +33,17 @@
 import tempfile
 from pathlib import Path

-from fastapi import APIRouter
+from fastapi import APIRouter, Depends
+from fastapi_utils.cbv import cbv
+from llama_index.core.indices import VectorStoreIndex
 from llama_index.core.node_parser import SentenceSplitter
 from pydantic import BaseModel

 from .... import exceptions
 from ....ai.indexing.index import Indexer
-from ....services import doc_summaries, models, qdrant, rag_vector_store, s3
+from ....ai.vector_stores.qdrant import QdrantVectorStore
+from ....ai.vector_stores.vector_store import VectorStore
+from ....services import doc_summaries, models, s3

 logger = logging.getLogger(__name__)

@@ -51,58 +55,6 @@ class SummarizeDocumentRequest(BaseModel):
     s3_document_key: str


-@router.get("/size", summary="Returns the number of chunks in the data source.")
-@exceptions.propagates
-def size(data_source_id: int) -> int:
-    data_source_size = qdrant.size_of(data_source_id)
-    qdrant.check_data_source_exists(data_source_size)
-    return data_source_size
-
-
-@router.get("/chunks/{chunk_id}", summary="Returns the content of a chunk.")
-@exceptions.propagates
-def chunk_contents(data_source_id: int, chunk_id: str) -> str:
-    return qdrant.chunk_contents(data_source_id, chunk_id)
-
-
-@router.delete("", summary="Deletes the data source from the index.")
-@exceptions.propagates
-def delete(data_source_id: int) -> None:
-    qdrant.delete(data_source_id)
-    doc_summaries.delete_data_source(data_source_id)
-
-
-@router.get("/documents/{doc_id}/summary", summary="summarize a single document")
-@exceptions.propagates
-def get_document_summary(data_source_id: int, doc_id: str) -> str:
-    summaries = doc_summaries.read_summary(data_source_id, doc_id)
-    return summaries
-
-
-@router.get("/summary", summary="summarize all documents for a datasource")
-@exceptions.propagates
-def get_document_summary_of_summaries(data_source_id: int) -> str:
-    return doc_summaries.summarize_data_source(data_source_id)
-
-
-@router.post("/summarize-document", summary="summarize a document")
-@exceptions.propagates
-def summarize_document(
-    data_source_id: int,
-    request: SummarizeDocumentRequest,
-) -> str:
-    return doc_summaries.generate_summary(
-        data_source_id, request.s3_bucket_name, request.s3_document_key
-    )
-
-
-@router.delete("/documents/{doc_id}", summary="delete a single document")
-@exceptions.propagates
-def delete_document(data_source_id: int, doc_id: str) -> None:
-    qdrant.delete_document(data_source_id, doc_id)
-    doc_summaries.delete_document(data_source_id, doc_id)
-
-
 class RagIndexDocumentConfiguration(BaseModel):
     # TODO: Add more params
     chunk_size: int = 512  # this is llama-index's default
@@ -116,38 +68,118 @@ class RagIndexDocumentRequest(BaseModel):
     configuration: RagIndexDocumentConfiguration = RagIndexDocumentConfiguration()


-@router.post(
-    "/documents/download-and-index",
-    summary="Download and index document",
-    description="Download document from S3 and index in Pinecone",
-)
-@exceptions.propagates
-def download_and_index(
-    data_source_id: int,
-    request: RagIndexDocumentRequest,
-) -> None:
-    with tempfile.TemporaryDirectory() as tmpdirname:
-        logger.debug("created temporary directory %s", tmpdirname)
-        s3.download(tmpdirname, request.s3_bucket_name, request.s3_document_key)
-        # Get the single file in the directory
-        files = os.listdir(tmpdirname)
-        if len(files) != 1:
-            raise ValueError("Expected a single file in the temporary directory")
-        file_path = Path(os.path.join(tmpdirname, files[0]))
-
-        indexer = Indexer(
-            data_source_id,
-            splitter=SentenceSplitter(
-                chunk_size=request.configuration.chunk_size,
-                chunk_overlap=int(
-                    request.configuration.chunk_overlap
-                    * 0.01
-                    * request.configuration.chunk_size
-                ),
-            ),
-            embedding_model=models.get_embedding_model(),
-            chunks_vector_store=rag_vector_store.create_rag_vector_store(
-                data_source_id
-            ),
-        )
-        indexer.index_file(file_path, request.document_id)
+@cbv(router)
+class DataSourceController:
+    chunks_vector_store: VectorStore = Depends(
+        lambda data_source_id: QdrantVectorStore.for_chunks(data_source_id)
+    )
+
+    @router.get(
+        "/size",
+        summary="Returns the number of chunks in the data source.",
+        response_model=None,
+    )
+    @exceptions.propagates
+    def size(self) -> int:
+        return self.chunks_vector_store.size()
+
+    @router.get(
+        "/chunks/{chunk_id}",
+        summary="Returns the content of a chunk.",
+        response_model=None,
+    )
+    @exceptions.propagates
+    def chunk_contents(self, chunk_id: str) -> str:
+        return (
+            self.chunks_vector_store.llama_vector_store()
+            .get_nodes([chunk_id])[0]
+            .get_content()
+        )
+
+    @router.delete(
+        "/", summary="Deletes the data source from the index.", response_model=None
+    )
+    @exceptions.propagates
+    def delete(self, data_source_id: int) -> None:
+        self.chunks_vector_store.delete()
+        doc_summaries.delete_data_source(data_source_id)
+
+    @router.get(
+        "/documents/{doc_id}/summary",
+        summary="summarize a single document",
+        response_model=None,
+    )
+    @exceptions.propagates
+    def get_document_summary(self, data_source_id: int, doc_id: str) -> str:
+        summaries = doc_summaries.read_summary(data_source_id, doc_id)
+        return summaries
+
+    @router.get(
+        "/summary",
+        summary="summarize all documents for a datasource",
+        response_model=None,
+    )
+    @exceptions.propagates
+    def get_document_summary_of_summaries(self, data_source_id: int) -> str:
+        return doc_summaries.summarize_data_source(data_source_id)
+
+    @router.post(
+        "/summarize-document", summary="summarize a document", response_model=None
+    )
+    @exceptions.propagates
+    def summarize_document(
+        self,
+        data_source_id: int,
+        request: SummarizeDocumentRequest,
+    ) -> str:
+        return doc_summaries.generate_summary(
+            data_source_id, request.s3_bucket_name, request.s3_document_key
+        )
+
+    @router.delete(
+        "/documents/{doc_id}", summary="delete a single document", response_model=None
+    )
+    @exceptions.propagates
+    def delete_document(self, data_source_id: int, doc_id: str) -> None:
+        index = VectorStoreIndex.from_vector_store(
+            vector_store=self.chunks_vector_store.llama_vector_store(),
+            embed_model=models.get_embedding_model(),
+        )
+        index.delete_ref_doc(doc_id)
+        doc_summaries.delete_document(data_source_id, doc_id)
+
+    @router.post(
+        "/documents/download-and-index",
+        summary="Download and index document",
+        description="Download document from S3 and index in Pinecone",
+        response_model=None,
+    )
+    @exceptions.propagates
+    def download_and_index(
+        self,
+        data_source_id: int,
+        request: RagIndexDocumentRequest,
+    ) -> None:
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            logger.debug("created temporary directory %s", tmpdirname)
+            s3.download(tmpdirname, request.s3_bucket_name, request.s3_document_key)
+            # Get the single file in the directory
+            files = os.listdir(tmpdirname)
+            if len(files) != 1:
+                raise ValueError("Expected a single file in the temporary directory")
            file_path = Path(os.path.join(tmpdirname, files[0]))

+            indexer = Indexer(
+                data_source_id,
+                splitter=SentenceSplitter(
+                    chunk_size=request.configuration.chunk_size,
+                    chunk_overlap=int(
+                        request.configuration.chunk_overlap
+                        * 0.01
+                        * request.configuration.chunk_size
+                    ),
+                ),
+                embedding_model=models.get_embedding_model(),
+                chunks_vector_store=self.chunks_vector_store,
+            )
+            indexer.index_file(file_path, request.document_id)
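
The controller relies on fastapi-utils' class-based views: @cbv(router) turns class attributes declared with Depends into dependencies resolved once per request and shared by every endpoint on the class, which is how the data_source_id path parameter becomes a ready-made chunks_vector_store. A self-contained sketch of the mechanism (router prefix and names are illustrative, not taken from this repo):

from fastapi import APIRouter, Depends, FastAPI
from fastapi_utils.cbv import cbv

router = APIRouter(prefix="/data_sources/{data_source_id}")


def get_store(data_source_id: int) -> str:
    # Stand-in for QdrantVectorStore.for_chunks(data_source_id); FastAPI feeds
    # the path parameter into the dependency automatically.
    return f"store-for-{data_source_id}"


@cbv(router)
class DemoController:
    store: str = Depends(get_store)  # resolved per request, shared by endpoints

    @router.get("/size")
    def size(self) -> str:
        return self.store


app = FastAPI()
app.include_router(router)

Because the vector store is a single class-level dependency, a test can swap in a different VectorStore (for example one built on an in-memory client) by overriding that one dependency rather than patching each endpoint.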
10 changes: 6 additions & 4 deletions llm-service/app/routers/index/sessions/__init__.py

@@ -38,12 +38,13 @@
 import time
 import uuid

-from fastapi import APIRouter
+from fastapi import APIRouter, HTTPException
 from pydantic import BaseModel

 from .... import exceptions
+from ....ai.vector_stores.qdrant import QdrantVectorStore
 from ....rag_types import RagPredictConfiguration
-from ....services import llm_completion, qdrant
+from ....services import llm_completion
 from ....services.chat import generate_suggested_questions, v2_chat
 from ....services.chat_store import RagStudioChatMessage, chat_store

@@ -130,8 +131,9 @@
     session_id: int,
     request: SuggestQuestionsRequest,
 ) -> RagSuggestedQuestionsResponse:
-    data_source_size = qdrant.size_of(request.data_source_id)
-    qdrant.check_data_source_exists(data_source_size)
+    data_source_size = QdrantVectorStore.for_chunks(request.data_source_id).size()
+    if data_source_size == -1:
+        raise HTTPException(status_code=404, detail="Knowledge base not found.")
     suggested_questions = generate_suggested_questions(
         request.configuration, request.data_source_id, data_source_size, session_id
     )
3 changes: 2 additions & 1 deletion llm-service/app/services/chat.py

@@ -43,6 +43,7 @@
 from llama_index.core.base.llms.types import MessageRole
 from llama_index.core.chat_engine.types import AgentChatResponse

+from ..ai.vector_stores.qdrant import QdrantVectorStore
 from ..rag_types import RagPredictConfiguration
 from . import evaluators, qdrant
 from .chat_store import (
@@ -61,7 +62,7 @@ def v2_chat(
     configuration: RagPredictConfiguration,
 ) -> RagStudioChatMessage:
     response_id = str(uuid.uuid4())
-    if qdrant.size_of(data_source_id) == 0:
+    if QdrantVectorStore.for_chunks(data_source_id).size() == 0:
         return RagStudioChatMessage(
             id=response_id,
             source_nodes=[],