 14 |  14 |     GenericStreamingChunk,
 15 |  15 |     ModelResponse,
 16 |  16 |     convert_to_model_response_object,
    |  17 | +    get_model_info,
 17 |  18 | )
 18 |  19 | from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 19 |  20 | from llama_cpp import (  # type: ignore[attr-defined]
 24 |  25 |     LlamaRAMCache,
 25 |  26 | )
 26 |  27 |
    |  28 | +from raglite._config import RAGLiteConfig
    |  29 | +
 27 |  30 | # Reduce the logging level for LiteLLM and flashrank.
 28 |  31 | logging.getLogger("litellm").setLevel(logging.WARNING)
 29 |  32 | logging.getLogger("flashrank").setLevel(logging.WARNING)
@@ -259,3 +262,54 @@ async def astreaming( # type: ignore[misc,override] # noqa: PLR0913
259 | 262 |     {"provider": "llama-cpp-python", "custom_handler": LlamaCppPythonLLM()}
260 | 263 | )
261 | 264 | litellm.suppress_debug_info = True
    | 265 | +
    | 266 | +
    | 267 | +@cache
    | 268 | +def get_context_size(config: RAGLiteConfig, *, fallback: int = 2048) -> int:
    | 269 | +    """Get the context size for the configured LLM."""
    | 270 | +    # If the user has configured a llama-cpp-python model, we ensure that LiteLLM's model info is up
    | 271 | +    # to date by loading that LLM.
    | 272 | +    if config.llm.startswith("llama-cpp-python"):
    | 273 | +        _ = LlamaCppPythonLLM.llm(config.llm)
    | 274 | +    # Attempt to read the context size from LiteLLM's model info.
    | 275 | +    llm_provider = "llama-cpp-python" if config.llm.startswith("llama-cpp") else None
    | 276 | +    model_info = get_model_info(config.llm, custom_llm_provider=llm_provider)
    | 277 | +    max_tokens = model_info.get("max_tokens")
    | 278 | +    if isinstance(max_tokens, int) and max_tokens > 0:
    | 279 | +        return max_tokens
    | 280 | +    # Fall back to a default context size if the model info is not available.
    | 281 | +    if fallback > 0:
    | 282 | +        warnings.warn(
    | 283 | +            f"Could not determine the context size of {config.llm} from LiteLLM's model_info, using {fallback}.",
    | 284 | +            stacklevel=2,
    | 285 | +        )
    | 286 | +        return fallback
    | 287 | +    error_message = f"Could not determine the context size of {config.llm}."
    | 288 | +    raise ValueError(error_message)
    | 289 | +
    | 290 | +
    | 291 | +@cache
    | 292 | +def get_embedding_dim(config: RAGLiteConfig, *, fallback: bool = True) -> int:
    | 293 | +    """Get the embedding dimension for the configured embedder."""
    | 294 | +    # If the user has configured a llama-cpp-python model, we ensure that LiteLLM's model info is up
    | 295 | +    # to date by loading that model.
    | 296 | +    if config.embedder.startswith("llama-cpp-python"):
    | 297 | +        _ = LlamaCppPythonLLM.llm(config.embedder, embedding=True)
    | 298 | +    # Attempt to read the embedding dimension from LiteLLM's model info.
    | 299 | +    llm_provider = "llama-cpp-python" if config.embedder.startswith("llama-cpp") else None
    | 300 | +    model_info = get_model_info(config.embedder, custom_llm_provider=llm_provider)
    | 301 | +    embedding_dim = model_info.get("output_vector_size")
    | 302 | +    if isinstance(embedding_dim, int) and embedding_dim > 0:
    | 303 | +        return embedding_dim
    | 304 | +    # If that fails, fall back to embedding a single sentence and reading its embedding dimension.
    | 305 | +    if fallback:
    | 306 | +        from raglite._embed import embed_sentences
    | 307 | +
    | 308 | +        warnings.warn(
    | 309 | +            f"Could not determine the embedding dimension of {config.embedder} from LiteLLM's model_info, using fallback.",
    | 310 | +            stacklevel=2,
    | 311 | +        )
    | 312 | +        fallback_embeddings = embed_sentences(["Hello world"], config=config)
    | 313 | +        return fallback_embeddings.shape[1]
    | 314 | +    error_message = f"Could not determine the embedding dimension of {config.embedder}."
    | 315 | +    raise ValueError(error_message)
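
For reference, a minimal usage sketch of the two new helpers follows. The import path and the model identifier strings are assumptions for illustration; only `get_context_size`, `get_embedding_dim`, and the `RAGLiteConfig.llm` / `RAGLiteConfig.embedder` fields come from the diff above.

```python
# Minimal usage sketch; the import path and model identifiers below are assumptions.
from raglite._config import RAGLiteConfig
from raglite._litellm import get_context_size, get_embedding_dim  # assumed module path

# Hypothetical config with illustrative model identifiers.
my_config = RAGLiteConfig(
    llm="gpt-4o-mini",
    embedder="text-embedding-3-small",
)

# Read the context size from LiteLLM's model_info; warns and returns the fallback (2048) if unknown.
context_size = get_context_size(my_config)

# Read the embedding dimension from model_info; else embed a test sentence and use its dimension.
embedding_dim = get_embedding_dim(my_config)

print(f"context size: {context_size}, embedding dimension: {embedding_dim}")
```

Because both helpers are decorated with `functools.cache`, repeated lookups with the same config are memoized; this also assumes that `RAGLiteConfig` is hashable (e.g. a frozen dataclass).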