Skip to content

Litellm stable dev #5711

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Sep 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion litellm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -914,7 +914,7 @@ class LlmProviders(str, Enum):
from .llms.ollama import OllamaConfig
from .llms.ollama_chat import OllamaChatConfig
from .llms.maritalk import MaritTalkConfig
from .llms.bedrock.chat import (
from .llms.bedrock.chat.invoke_handler import (
AmazonCohereChatConfig,
AmazonConverseConfig,
BEDROCK_CONVERSE_MODELS,
Expand Down
99 changes: 95 additions & 4 deletions litellm/llms/base_aws_llm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import hashlib
import json
from typing import List, Optional
import os
from typing import Dict, List, Optional, Tuple

import httpx

Expand Down Expand Up @@ -28,6 +30,14 @@ def __init__(self) -> None:
self.iam_cache = DualCache()
super().__init__()

def get_cache_key(self, credential_args: Dict[str, Optional[str]]) -> str:
    """
    Derive a deterministic cache key from a set of credential arguments.

    The arguments are serialized to JSON with sorted keys, so the insertion
    order of the dict does not influence the result, and the serialized text
    is then reduced to its SHA-256 hex digest.
    """
    serialized = json.dumps(credential_args, sort_keys=True).encode()
    digest = hashlib.sha256()
    digest.update(serialized)
    return digest.hexdigest()

def get_credentials(
self,
aws_access_key_id: Optional[str] = None,
Expand All @@ -43,9 +53,22 @@ def get_credentials(
"""
Return a boto3.Credentials object
"""

import boto3
from botocore.credentials import Credentials

## CHECK IS 'os.environ/' passed in
param_names = [
"aws_access_key_id",
"aws_secret_access_key",
"aws_session_token",
"aws_region_name",
"aws_session_name",
"aws_profile_name",
"aws_role_name",
"aws_web_identity_token",
"aws_sts_endpoint",
]
params_to_check: List[Optional[str]] = [
aws_access_key_id,
aws_secret_access_key,
Expand All @@ -64,6 +87,11 @@ def get_credentials(
_v = get_secret(param)
if _v is not None and isinstance(_v, str):
params_to_check[i] = _v
elif param is None: # check if uppercase value in env
key = param_names[i]
if key.upper() in os.environ:
params_to_check[i] = os.getenv(key)

# Assign updated values back to parameters
(
aws_access_key_id,
Expand All @@ -77,6 +105,10 @@ def get_credentials(
aws_sts_endpoint,
) = params_to_check

# create cache key for non-expiring auth flows
args = {k: v for k, v in locals().items() if k.startswith("aws_")}
cache_key = self.get_cache_key(args)

verbose_logger.debug(
"in get credentials\n"
"aws_access_key_id=%s\n"
Expand Down Expand Up @@ -186,7 +218,6 @@ def get_credentials(

# Extract the credentials from the response and convert to Session Credentials
sts_credentials = sts_response["Credentials"]
from botocore.credentials import Credentials

credentials = Credentials(
access_key=sts_credentials["AccessKeyId"],
Expand All @@ -211,12 +242,72 @@ def get_credentials(
secret_key=aws_secret_access_key,
token=aws_session_token,
)

return credentials
else:
elif (
aws_access_key_id is not None
and aws_secret_access_key is not None
and aws_region_name is not None
):
# Check if credentials are already in cache. These credentials have no expiry time.
cached_credentials: Optional[Credentials] = self.iam_cache.get_cache(
cache_key
)
if cached_credentials:
return cached_credentials

session = boto3.Session(
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
region_name=aws_region_name,
)

return session.get_credentials()
credentials = session.get_credentials()

if (
credentials.token is None
): # don't cache if session token exists. The expiry time for that is not known.
self.iam_cache.set_cache(cache_key, credentials, ttl=3600 - 60)

return credentials
else:
# check env var. Do not cache the response from this.
session = boto3.Session()

credentials = session.get_credentials()

return credentials

def get_runtime_endpoint(
    self,
    api_base: Optional[str],
    aws_bedrock_runtime_endpoint: Optional[str],
    aws_region_name: str,
) -> Tuple[str, str]:
    """
    Resolve the Bedrock runtime endpoint and the proxy endpoint.

    Returns a ``(endpoint_url, proxy_endpoint_url)`` tuple.

    ``endpoint_url`` precedence: ``api_base``, then the
    ``aws_bedrock_runtime_endpoint`` argument, then the
    ``AWS_BEDROCK_RUNTIME_ENDPOINT`` secret, then the regional default built
    from ``aws_region_name``.

    ``proxy_endpoint_url`` precedence: the ``AWS_BEDROCK_RUNTIME_ENDPOINT``
    secret, then the ``aws_bedrock_runtime_endpoint`` argument, then whatever
    ``endpoint_url`` resolved to.
    """
    secret_value = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")

    # Normalize both optional overrides to "a usable string or None" once,
    # so the two precedence chains below stay short.
    # The secret only counts when it is a non-empty string.
    env_override: Optional[str] = None
    if secret_value and isinstance(secret_value, str):
        env_override = secret_value

    # The explicit argument counts whenever it is a string (even an empty one).
    arg_override: Optional[str] = None
    if isinstance(aws_bedrock_runtime_endpoint, str):
        arg_override = aws_bedrock_runtime_endpoint

    # Endpoint used for the request itself.
    if api_base is not None:
        endpoint_url = api_base
    elif arg_override is not None:
        endpoint_url = arg_override
    elif env_override is not None:
        endpoint_url = env_override
    else:
        endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"

    # Determine proxy_endpoint_url: here the secret outranks the argument.
    if env_override is not None:
        proxy_endpoint_url = env_override
    elif arg_override is not None:
        proxy_endpoint_url = arg_override
    else:
        proxy_endpoint_url = endpoint_url

    return endpoint_url, proxy_endpoint_url
2 changes: 2 additions & 0 deletions litellm/llms/bedrock/chat/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .converse_handler import BedrockConverseLLM
from .invoke_handler import BedrockLLM
Loading