
Support discovering gemini, anthropic, xai models by calling their /v1/models endpoint #9530

Merged
9 commits merged on Mar 28, 2025
14 changes: 14 additions & 0 deletions docs/my-website/docs/set_keys.md
@@ -188,14 +188,28 @@ Currently implemented for:
- OpenAI (if OPENAI_API_KEY is set)
- Fireworks AI (if FIREWORKS_AI_API_KEY is set)
- LiteLLM Proxy (if LITELLM_PROXY_API_KEY is set)
- Gemini (if GEMINI_API_KEY is set)
- XAI (if XAI_API_KEY is set)
- Anthropic (if ANTHROPIC_API_KEY is set)

You can check all providers at once, or restrict the check to a single provider:

**All providers**:
```python
from litellm import get_valid_models

valid_models = get_valid_models(check_provider_endpoint=True)
print(valid_models)
```

**Specific provider**:
```python
from litellm import get_valid_models

valid_models = get_valid_models(check_provider_endpoint=True, custom_llm_provider="openai")
print(valid_models)
```

### `validate_environment(model: str)`

This helper tells you if you have all the required environment variables for a model and, if not, what's missing.
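A quick end-to-end sketch of the feature this docs change describes, assuming at least one provider key (e.g. ANTHROPIC_API_KEY) is exported; the model ids in the comments are illustrative, not guaranteed output:

```python
from litellm import completion, get_valid_models

# Discover models by querying the provider's live /models endpoint.
valid_models = get_valid_models(
    check_provider_endpoint=True, custom_llm_provider="anthropic"
)
print(valid_models)  # e.g. ["anthropic/claude-3-5-sonnet-20241022", ...]

# Each name comes back already prefixed with the provider route, so it can
# be passed straight to litellm.completion().
if valid_models:
    response = completion(
        model=valid_models[0],
        messages=[{"role": "user", "content": "Say hello."}],
    )
    print(response.choices[0].message.content)
```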
3 changes: 3 additions & 0 deletions litellm/__init__.py
@@ -813,6 +813,7 @@ def add_known_models():
from .llms.maritalk import MaritalkConfig
from .llms.openrouter.chat.transformation import OpenrouterConfig
from .llms.anthropic.chat.transformation import AnthropicConfig
from .llms.anthropic.common_utils import AnthropicModelInfo
from .llms.groq.stt.transformation import GroqSTTConfig
from .llms.anthropic.completion.transformation import AnthropicTextConfig
from .llms.triton.completion.transformation import TritonConfig
@@ -848,6 +849,7 @@ def add_known_models():
VertexGeminiConfig,
VertexGeminiConfig as VertexAIConfig,
)
from .llms.gemini.common_utils import GeminiModelInfo
from .llms.gemini.chat.transformation import (
GoogleAIStudioGeminiConfig,
GoogleAIStudioGeminiConfig as GeminiConfig, # aliased to maintain backwards compatibility
@@ -984,6 +986,7 @@ def add_known_models():
from .llms.friendliai.chat.transformation import FriendliaiChatConfig
from .llms.jina_ai.embedding.transformation import JinaAIEmbeddingConfig
from .llms.xai.chat.transformation import XAIChatConfig
from .llms.xai.common_utils import XAIModelInfo
from .llms.volcengine import VolcEngineConfig
from .llms.codestral.completion.transformation import CodestralTextCompletionConfig
from .llms.azure.azure import (
51 changes: 51 additions & 0 deletions litellm/llms/anthropic/common_utils.py
@@ -6,7 +6,10 @@

import httpx

import litellm
from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.secret_managers.main import get_secret_str


class AnthropicError(BaseLLMException):
@@ -19,6 +22,54 @@ def __init__(
super().__init__(status_code=status_code, message=message, headers=headers)


class AnthropicModelInfo(BaseLLMModelInfo):
@staticmethod
def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
return (
api_base
or get_secret_str("ANTHROPIC_API_BASE")
or "https://api.anthropic.com"
)

@staticmethod
def get_api_key(api_key: Optional[str] = None) -> Optional[str]:
return api_key or get_secret_str("ANTHROPIC_API_KEY")

@staticmethod
def get_base_model(model: str) -> Optional[str]:
return model.replace("anthropic/", "")

def get_models(
self, api_key: Optional[str] = None, api_base: Optional[str] = None
) -> List[str]:
api_base = AnthropicModelInfo.get_api_base(api_base)
api_key = AnthropicModelInfo.get_api_key(api_key)
if api_base is None or api_key is None:
raise ValueError(
"ANTHROPIC_API_BASE or ANTHROPIC_API_KEY is not set. Please set the environment variable, to query Anthropic's `/models` endpoint."
)
response = litellm.module_level_client.get(
url=f"{api_base}/v1/models",
headers={"x-api-key": api_key, "anthropic-version": "2023-06-01"},
)

try:
response.raise_for_status()
except httpx.HTTPStatusError:
raise Exception(
f"Failed to fetch models from Anthropic. Status code: {response.status_code}, Response: {response.text}"
)

models = response.json()["data"]

litellm_model_names = []
for model in models:
stripped_model_name = model["id"]
litellm_model_name = "anthropic/" + stripped_model_name
litellm_model_names.append(litellm_model_name)
return litellm_model_names


def process_anthropic_headers(headers: Union[httpx.Headers, dict]) -> dict:
openai_headers = {}
if "anthropic-ratelimit-requests-limit" in headers:
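For reference, the request `AnthropicModelInfo.get_models()` issues boils down to one authenticated GET; a standalone sketch with httpx, assuming ANTHROPIC_API_KEY is exported (header names and response shape taken from the code above):

```python
import os

import httpx

# Mirrors AnthropicModelInfo.get_models() without litellm's shared client.
response = httpx.get(
    "https://api.anthropic.com/v1/models",
    headers={
        "x-api-key": os.environ["ANTHROPIC_API_KEY"],
        "anthropic-version": "2023-06-01",
    },
)
response.raise_for_status()

# Prefix each id so it routes through litellm's anthropic provider.
print(["anthropic/" + m["id"] for m in response.json()["data"]])
```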
12 changes: 10 additions & 2 deletions litellm/llms/base_llm/base_utils.py
@@ -19,11 +19,19 @@ def get_provider_info(
self,
model: str,
) -> Optional[ProviderSpecificModelInfo]:
"""
Default values that all models of this provider support.
"""
return None

@abstractmethod
def get_models(self) -> List[str]:
pass
def get_models(
self, api_key: Optional[str] = None, api_base: Optional[str] = None
) -> List[str]:
"""
Returns a list of models supported by this provider.
"""
return []

@staticmethod
@abstractmethod
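Since `get_models` now has a default body, new subclasses only need to override it when a live endpoint exists. A minimal (hypothetical) provider stub, assuming the remaining abstract hooks are the three static methods seen on the provider classes in this PR:

```python
from typing import List, Optional

from litellm.llms.base_llm.base_utils import BaseLLMModelInfo


class MyProviderModelInfo(BaseLLMModelInfo):  # hypothetical example provider
    @staticmethod
    def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
        return api_base or "https://api.myprovider.example"

    @staticmethod
    def get_api_key(api_key: Optional[str] = None) -> Optional[str]:
        return api_key

    @staticmethod
    def get_base_model(model: str) -> Optional[str]:
        return model.replace("myprovider/", "")

    # get_models is optional now: omit it to inherit the empty default,
    # or override it to query the provider's /models endpoint.
```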
52 changes: 52 additions & 0 deletions litellm/llms/gemini/common_utils.py
@@ -0,0 +1,52 @@
from typing import List, Optional

import litellm
from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
from litellm.secret_managers.main import get_secret_str


class GeminiModelInfo(BaseLLMModelInfo):
@staticmethod
def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
return (
api_base
or get_secret_str("GEMINI_API_BASE")
or "https://generativelanguage.googleapis.com/v1beta"
)

@staticmethod
def get_api_key(api_key: Optional[str] = None) -> Optional[str]:
return api_key or get_secret_str("GEMINI_API_KEY")

@staticmethod
def get_base_model(model: str) -> Optional[str]:
return model.replace("gemini/", "")

def get_models(
self, api_key: Optional[str] = None, api_base: Optional[str] = None
) -> List[str]:

api_base = GeminiModelInfo.get_api_base(api_base)
api_key = GeminiModelInfo.get_api_key(api_key)
if api_base is None or api_key is None:
raise ValueError(
"GEMINI_API_BASE or GEMINI_API_KEY is not set. Please set the environment variable, to query Gemini's `/models` endpoint."
)

response = litellm.module_level_client.get(
url=f"{api_base}/models?key={api_key}",
)

if response.status_code != 200:
raise ValueError(
f"Failed to fetch models from Gemini. Status code: {response.status_code}, Response: {response.json()}"
)

models = response.json()["models"]

litellm_model_names = []
for model in models:
stripped_model_name = model["name"].strip("models/")
litellm_model_name = "gemini/" + stripped_model_name
litellm_model_names.append(litellm_model_name)
return litellm_model_names
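A note on the prefix handling above: `str.strip` treats its argument as a set of characters to remove from both ends, not as a literal prefix, so it can silently chop trailing letters off real model ids; that is why the literal `replace` form is used. A quick demonstration:

```python
# str.strip removes any of the characters {m, o, d, e, l, s, /} from BOTH ends:
print("models/gemini-pro".strip("models/"))        # -> "gemini-pr" (trailing "o" eaten)

# Replacing the literal prefix keeps the id intact:
print("models/gemini-pro".replace("models/", ""))  # -> "gemini-pro"
```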
4 changes: 3 additions & 1 deletion litellm/llms/topaz/common_utils.py
@@ -11,7 +11,9 @@ class TopazException(BaseLLMException):


class TopazModelInfo(BaseLLMModelInfo):
def get_models(self) -> List[str]:
def get_models(
self, api_key: Optional[str] = None, api_base: Optional[str] = None
) -> List[str]:
return [
"topaz/Standard V2",
"topaz/Low Resolution V2",
51 changes: 51 additions & 0 deletions litellm/llms/xai/common_utils.py
@@ -0,0 +1,51 @@
from typing import List, Optional

import httpx

import litellm
from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
from litellm.secret_managers.main import get_secret_str


class XAIModelInfo(BaseLLMModelInfo):
@staticmethod
def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
return api_base or get_secret_str("XAI_API_BASE") or "https://api.x.ai"

@staticmethod
def get_api_key(api_key: Optional[str] = None) -> Optional[str]:
return api_key or get_secret_str("XAI_API_KEY")

@staticmethod
def get_base_model(model: str) -> Optional[str]:
return model.replace("xai/", "")

def get_models(
self, api_key: Optional[str] = None, api_base: Optional[str] = None
) -> List[str]:
api_base = self.get_api_base(api_base)
api_key = self.get_api_key(api_key)
if api_base is None or api_key is None:
raise ValueError(
"XAI_API_BASE or XAI_API_KEY is not set. Please set the environment variable, to query XAI's `/models` endpoint."
)
response = litellm.module_level_client.get(
url=f"{api_base}/v1/models",
headers={"Authorization": f"Bearer {api_key}"},
)

try:
response.raise_for_status()
except httpx.HTTPStatusError:
raise Exception(
f"Failed to fetch models from XAI. Status code: {response.status_code}, Response: {response.text}"
)

models = response.json()["data"]

litellm_model_names = []
for model in models:
stripped_model_name = model["id"]
litellm_model_name = "xai/" + stripped_model_name
litellm_model_names.append(litellm_model_name)
return litellm_model_names
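All three new ModelInfo classes follow the same fetch-then-prefix pattern; a quick interactive check of the xAI one, assuming XAI_API_KEY is set (the ids shown are illustrative):

```python
from litellm.llms.xai.common_utils import XAIModelInfo

# Resolves api_base/api_key from the arguments or env vars, then GETs /v1/models.
models = XAIModelInfo().get_models()
print(models)  # e.g. ["xai/grok-2-latest", ...] -- actual ids depend on the account
```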
24 changes: 21 additions & 3 deletions litellm/utils.py
@@ -5744,13 +5744,15 @@ def trim_messages(
return messages


def get_valid_models(check_provider_endpoint: bool = False) -> List[str]:
def get_valid_models(
check_provider_endpoint: bool = False, custom_llm_provider: Optional[str] = None
) -> List[str]:
"""
Returns a list of valid LLMs based on the set environment variables

Args:
check_provider_endpoint: If True, will check the provider's endpoint for valid models.
custom_llm_provider: If provided, only that provider's endpoint is checked for valid models.

Returns:
A list of valid LLMs
"""
@@ -5762,6 +5764,9 @@ def get_valid_models(check_provider_endpoint: bool = False) -> List[str]:
valid_models = []

for provider in litellm.provider_list:
if custom_llm_provider and provider != custom_llm_provider:
continue

# edge case litellm has together_ai as a provider, it should be togetherai
env_provider_1 = provider.replace("_", "")
env_provider_2 = provider
@@ -5783,10 +5788,17 @@ def get_valid_models(check_provider_endpoint: bool = False) -> List[str]:
provider=LlmProviders(provider),
)

if custom_llm_provider and provider != custom_llm_provider:
continue

if provider == "azure":
valid_models.append("Azure-LLM")
elif provider_config is not None and check_provider_endpoint:
valid_models.extend(provider_config.get_models())
try:
models = provider_config.get_models()
valid_models.extend(models)
except Exception as e:
verbose_logger.debug(f"Error getting valid models for provider {provider}: {e}")
else:
models_for_provider = litellm.models_by_provider.get(provider, [])
valid_models.extend(models_for_provider)
@@ -6400,10 +6412,16 @@ def get_provider_model_info(
return litellm.FireworksAIConfig()
elif LlmProviders.OPENAI == provider:
return litellm.OpenAIGPTConfig()
elif LlmProviders.GEMINI == provider:
return litellm.GeminiModelInfo()
elif LlmProviders.LITELLM_PROXY == provider:
return litellm.LiteLLMProxyChatConfig()
elif LlmProviders.TOPAZ == provider:
return litellm.TopazModelInfo()
elif LlmProviders.ANTHROPIC == provider:
return litellm.AnthropicModelInfo()
elif LlmProviders.XAI == provider:
return litellm.XAIModelInfo()

return None

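Because the try/except above swallows endpoint failures into a debug log, a provider whose `/models` call errors simply contributes nothing to the result. A sketch for surfacing that while testing, assuming litellm's standard `set_verbose` flag:

```python
import litellm
from litellm import get_valid_models

# Per-provider fetch failures are logged at debug level and skipped, so turn
# on verbose logging to see why a provider returned no models.
litellm.set_verbose = True  # assumption: standard litellm debug switch

models = get_valid_models(check_provider_endpoint=True, custom_llm_provider="gemini")
print(models or "no models discovered -- check GEMINI_API_KEY and the debug output")
```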
18 changes: 18 additions & 0 deletions tests/litellm_utils_tests/test_utils.py
@@ -303,6 +303,24 @@ def test_aget_valid_models():
os.environ = old_environ


@pytest.mark.parametrize("custom_llm_provider", ["gemini", "anthropic", "xai"])
def test_get_valid_models_with_custom_llm_provider(custom_llm_provider):
from litellm.utils import ProviderConfigManager
from litellm.types.utils import LlmProviders

provider_config = ProviderConfigManager.get_provider_model_info(
model=None,
provider=LlmProviders(custom_llm_provider),
)
assert provider_config is not None
valid_models = get_valid_models(
check_provider_endpoint=True, custom_llm_provider=custom_llm_provider
)
print(valid_models)
assert len(valid_models) > 0
assert provider_config.get_models() == valid_models


# test_get_valid_models()


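To exercise just the new parametrized test locally, one option (assuming pytest is installed and the GEMINI/ANTHROPIC/XAI keys are set, since the test hits live endpoints):

```python
import pytest

# Runs only the new discovery test across its three provider parametrizations.
pytest.main([
    "tests/litellm_utils_tests/test_utils.py",
    "-k", "test_get_valid_models_with_custom_llm_provider",
])
```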