Commit 776e0c5

fix: add fallbacks for model info (#44)
1 parent: fdf803b

File tree: 3 files changed (+61, −20 lines)

src/raglite/_database.py
src/raglite/_litellm.py
src/raglite/_rag.py

src/raglite/_database.py

Lines changed: 3 additions & 10 deletions

@@ -8,7 +8,6 @@
 from typing import Any

 import numpy as np
-from litellm import get_model_info  # type: ignore[attr-defined]
 from markdown_it import MarkdownIt
 from pydantic import ConfigDict
 from sqlalchemy.engine import Engine, make_url
@@ -24,7 +23,7 @@
 )

 from raglite._config import RAGLiteConfig
-from raglite._litellm import LlamaCppPythonLLM
+from raglite._litellm import get_embedding_dim
 from raglite._typing import Embedding, FloatMatrix, FloatVector, PickledObject


@@ -274,14 +273,8 @@ def create_database_engine(config: RAGLiteConfig | None = None) -> Engine:
         with Session(engine) as session:
             session.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
             session.commit()
-    # If the user has configured a llama-cpp-python model, we ensure that LiteLLM's model info is up
-    # to date by loading that LLM.
-    if config.embedder.startswith("llama-cpp-python"):
-        _ = LlamaCppPythonLLM.llm(config.embedder, embedding=True)
-    llm_provider = "llama-cpp-python" if config.embedder.startswith("llama-cpp") else None
-    model_info = get_model_info(config.embedder, custom_llm_provider=llm_provider)
-    embedding_dim = model_info.get("output_vector_size") or -1
-    assert embedding_dim > 0
+    # Get the embedding dimension.
+    embedding_dim = get_embedding_dim(config)
     # Create all SQLModel tables.
     ChunkEmbedding.set_embedding_dim(embedding_dim)
     SQLModel.metadata.create_all(engine)
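
For context, a minimal sketch of what this change means at the call site; the embedder name below is illustrative and not part of the commit:

from raglite._config import RAGLiteConfig
from raglite._database import create_database_engine

# Previously, an embedder missing from LiteLLM's model info would trip the
# bare `assert embedding_dim > 0` during engine creation. With this commit,
# get_embedding_dim falls back to measuring the dimension by embedding a
# probe sentence instead of asserting.
engine = create_database_engine(RAGLiteConfig(embedder="openai/text-embedding-3-small"))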

src/raglite/_litellm.py

Lines changed: 54 additions & 0 deletions

@@ -14,6 +14,7 @@
     GenericStreamingChunk,
     ModelResponse,
     convert_to_model_response_object,
+    get_model_info,
 )
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from llama_cpp import (  # type: ignore[attr-defined]
@@ -24,6 +25,8 @@
     LlamaRAMCache,
 )

+from raglite._config import RAGLiteConfig
+
 # Reduce the logging level for LiteLLM and flashrank.
 logging.getLogger("litellm").setLevel(logging.WARNING)
 logging.getLogger("flashrank").setLevel(logging.WARNING)
@@ -259,3 +262,54 @@ async def astreaming(  # type: ignore[misc,override]  # noqa: PLR0913
     {"provider": "llama-cpp-python", "custom_handler": LlamaCppPythonLLM()}
 )
 litellm.suppress_debug_info = True
+
+
+@cache
+def get_context_size(config: RAGLiteConfig, *, fallback: int = 2048) -> int:
+    """Get the context size for the configured LLM."""
+    # If the user has configured a llama-cpp-python model, we ensure that LiteLLM's model info is up
+    # to date by loading that LLM.
+    if config.llm.startswith("llama-cpp-python"):
+        _ = LlamaCppPythonLLM.llm(config.llm)
+    # Attempt to read the context size from LiteLLM's model info.
+    llm_provider = "llama-cpp-python" if config.llm.startswith("llama-cpp") else None
+    model_info = get_model_info(config.llm, custom_llm_provider=llm_provider)
+    max_tokens = model_info.get("max_tokens")
+    if isinstance(max_tokens, int) and max_tokens > 0:
+        return max_tokens
+    # Fall back to a default context size if the model info is not available.
+    if fallback > 0:
+        warnings.warn(
+            f"Could not determine the context size of {config.llm} from LiteLLM's model_info, using {fallback}.",
+            stacklevel=2,
+        )
+        return fallback
+    error_message = f"Could not determine the context size of {config.llm}."
+    raise ValueError(error_message)
+
+
+@cache
+def get_embedding_dim(config: RAGLiteConfig, *, fallback: bool = True) -> int:
+    """Get the embedding dimension for the configured embedder."""
+    # If the user has configured a llama-cpp-python model, we ensure that LiteLLM's model info is up
+    # to date by loading that LLM.
+    if config.embedder.startswith("llama-cpp-python"):
+        _ = LlamaCppPythonLLM.llm(config.embedder, embedding=True)
+    # Attempt to read the embedding dimension from LiteLLM's model info.
+    llm_provider = "llama-cpp-python" if config.embedder.startswith("llama-cpp") else None
+    model_info = get_model_info(config.embedder, custom_llm_provider=llm_provider)
+    embedding_dim = model_info.get("output_vector_size")
+    if isinstance(embedding_dim, int) and embedding_dim > 0:
+        return embedding_dim
+    # If that fails, fall back to embedding a single sentence and reading its embedding dimension.
+    if fallback:
+        from raglite._embed import embed_sentences
+
+        warnings.warn(
+            f"Could not determine the embedding dimension of {config.embedder} from LiteLLM's model_info, using fallback.",
+            stacklevel=2,
+        )
+        fallback_embeddings = embed_sentences(["Hello world"], config=config)
+        return fallback_embeddings.shape[1]
+    error_message = f"Could not determine the embedding dimension of {config.embedder}."
+    raise ValueError(error_message)
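
A minimal usage sketch of the two new helpers (the model names are illustrative; both functions and RAGLiteConfig come from the diffs above):

from raglite._config import RAGLiteConfig
from raglite._litellm import get_context_size, get_embedding_dim

# Illustrative model identifiers; any LiteLLM-recognized names behave the same.
config = RAGLiteConfig(llm="gpt-4o-mini", embedder="openai/text-embedding-3-small")

# Returns model_info["max_tokens"] when LiteLLM knows it; otherwise warns and
# returns the fallback context size (2048 by default).
context_size = get_context_size(config)

# Returns model_info["output_vector_size"] when available; otherwise warns,
# embeds "Hello world", and reads the dimension off the resulting matrix.
embedding_dim = get_embedding_dim(config)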

src/raglite/_rag.py

Lines changed: 4 additions & 10 deletions

@@ -2,11 +2,11 @@

 from collections.abc import AsyncIterator, Iterator

-from litellm import acompletion, completion, get_model_info  # type: ignore[attr-defined]
+from litellm import acompletion, completion

 from raglite._config import RAGLiteConfig
 from raglite._database import Chunk
-from raglite._litellm import LlamaCppPythonLLM
+from raglite._litellm import get_context_size
 from raglite._search import hybrid_search, rerank_chunks, retrieve_segments
 from raglite._typing import SearchMethod

@@ -27,15 +27,9 @@ def _max_contexts(
     config: RAGLiteConfig | None = None,
 ) -> int:
     """Determine the maximum number of contexts for RAG."""
-    # If the user has configured a llama-cpp-python model, we ensure that LiteLLM's model info is up
-    # to date by loading that LLM.
+    # Get the model's context size.
     config = config or RAGLiteConfig()
-    if config.llm.startswith("llama-cpp-python"):
-        _ = LlamaCppPythonLLM.llm(config.llm)
-    # Get the model's maximum context size.
-    llm_provider = "llama-cpp-python" if config.llm.startswith("llama-cpp") else None
-    model_info = get_model_info(config.llm, custom_llm_provider=llm_provider)
-    max_tokens = model_info.get("max_tokens") or 2048
+    max_tokens = get_context_size(config)
     # Reduce the maximum number of contexts to take into account the LLM's context size.
     max_context_tokens = (
         max_tokens
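
Design note: both helpers are wrapped in functools.cache, so the model-info lookup (and, for get_embedding_dim, the one-off probe embedding) runs at most once per RAGLiteConfig value. Callers such as _max_contexts can therefore call get_context_size on every request without repeating the lookup; this presumes RAGLiteConfig is hashable.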
