diff --git a/litellm/__init__.py b/litellm/__init__.py index 6439af29ea81..8c11ff1c4ec5 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -914,7 +914,7 @@ class LlmProviders(str, Enum): from .llms.ollama import OllamaConfig from .llms.ollama_chat import OllamaChatConfig from .llms.maritalk import MaritTalkConfig -from .llms.bedrock.chat import ( +from .llms.bedrock.chat.invoke_handler import ( AmazonCohereChatConfig, AmazonConverseConfig, BEDROCK_CONVERSE_MODELS, diff --git a/litellm/llms/base_aws_llm.py b/litellm/llms/base_aws_llm.py index 70f333eb6a0e..348a84180059 100644 --- a/litellm/llms/base_aws_llm.py +++ b/litellm/llms/base_aws_llm.py @@ -1,5 +1,7 @@ +import hashlib import json -from typing import List, Optional +import os +from typing import Dict, List, Optional, Tuple import httpx @@ -28,6 +30,14 @@ def __init__(self) -> None: self.iam_cache = DualCache() super().__init__() + def get_cache_key(self, credential_args: Dict[str, Optional[str]]) -> str: + """ + Generate a unique cache key based on the credential arguments. + """ + # Convert credential arguments to a JSON string and hash it to create a unique key + credential_str = json.dumps(credential_args, sort_keys=True) + return hashlib.sha256(credential_str.encode()).hexdigest() + def get_credentials( self, aws_access_key_id: Optional[str] = None, @@ -43,9 +53,22 @@ def get_credentials( """ Return a boto3.Credentials object """ + import boto3 + from botocore.credentials import Credentials ## CHECK IS 'os.environ/' passed in + param_names = [ + "aws_access_key_id", + "aws_secret_access_key", + "aws_session_token", + "aws_region_name", + "aws_session_name", + "aws_profile_name", + "aws_role_name", + "aws_web_identity_token", + "aws_sts_endpoint", + ] params_to_check: List[Optional[str]] = [ aws_access_key_id, aws_secret_access_key, @@ -64,6 +87,11 @@ def get_credentials( _v = get_secret(param) if _v is not None and isinstance(_v, str): params_to_check[i] = _v + elif param is None: # check if uppercase value in env + key = param_names[i] + if key.upper() in os.environ: + params_to_check[i] = os.getenv(key) + # Assign updated values back to parameters ( aws_access_key_id, @@ -77,6 +105,10 @@ def get_credentials( aws_sts_endpoint, ) = params_to_check + # create cache key for non-expiring auth flows + args = {k: v for k, v in locals().items() if k.startswith("aws_")} + cache_key = self.get_cache_key(args) + verbose_logger.debug( "in get credentials\n" "aws_access_key_id=%s\n" @@ -186,7 +218,6 @@ def get_credentials( # Extract the credentials from the response and convert to Session Credentials sts_credentials = sts_response["Credentials"] - from botocore.credentials import Credentials credentials = Credentials( access_key=sts_credentials["AccessKeyId"], @@ -211,12 +242,72 @@ def get_credentials( secret_key=aws_secret_access_key, token=aws_session_token, ) + return credentials - else: + elif ( + aws_access_key_id is not None + and aws_secret_access_key is not None + and aws_region_name is not None + ): + # Check if credentials are already in cache. These credentials have no expiry time. + cached_credentials: Optional[Credentials] = self.iam_cache.get_cache( + cache_key + ) + if cached_credentials: + return cached_credentials + session = boto3.Session( aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, region_name=aws_region_name, ) - return session.get_credentials() + credentials = session.get_credentials() + + if ( + credentials.token is None + ): # don't cache if session token exists. 
The expiry time for that is not known. + self.iam_cache.set_cache(cache_key, credentials, ttl=3600 - 60) + + return credentials + else: + # check env var. Do not cache the response from this. + session = boto3.Session() + + credentials = session.get_credentials() + + return credentials + + def get_runtime_endpoint( + self, + api_base: Optional[str], + aws_bedrock_runtime_endpoint: Optional[str], + aws_region_name: str, + ) -> Tuple[str, str]: + env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT") + if api_base is not None: + endpoint_url = api_base + elif aws_bedrock_runtime_endpoint is not None and isinstance( + aws_bedrock_runtime_endpoint, str + ): + endpoint_url = aws_bedrock_runtime_endpoint + elif env_aws_bedrock_runtime_endpoint and isinstance( + env_aws_bedrock_runtime_endpoint, str + ): + endpoint_url = env_aws_bedrock_runtime_endpoint + else: + endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com" + + # Determine proxy_endpoint_url + if env_aws_bedrock_runtime_endpoint and isinstance( + env_aws_bedrock_runtime_endpoint, str + ): + proxy_endpoint_url = env_aws_bedrock_runtime_endpoint + elif aws_bedrock_runtime_endpoint is not None and isinstance( + aws_bedrock_runtime_endpoint, str + ): + proxy_endpoint_url = aws_bedrock_runtime_endpoint + else: + proxy_endpoint_url = endpoint_url + + return endpoint_url, proxy_endpoint_url diff --git a/litellm/llms/bedrock/chat/__init__.py b/litellm/llms/bedrock/chat/__init__.py new file mode 100644 index 000000000000..c3f6aef6d238 --- /dev/null +++ b/litellm/llms/bedrock/chat/__init__.py @@ -0,0 +1,2 @@ +from .converse_handler import BedrockConverseLLM +from .invoke_handler import BedrockLLM diff --git a/litellm/llms/bedrock/chat/converse_handler.py b/litellm/llms/bedrock/chat/converse_handler.py new file mode 100644 index 000000000000..caf6113ac4a1 --- /dev/null +++ b/litellm/llms/bedrock/chat/converse_handler.py @@ -0,0 +1,408 @@ +import json +import urllib +from typing import Any, Callable, Optional, Union + +import httpx + +import litellm +from litellm.llms.custom_httpx.http_handler import ( + AsyncHTTPHandler, + HTTPHandler, + _get_httpx_client, + get_async_httpx_client, +) +from litellm.types.utils import ModelResponse +from litellm.utils import CustomStreamWrapper, get_secret + +from ...base_aws_llm import BaseAWSLLM +from ..common_utils import BedrockError +from .invoke_handler import AWSEventStreamDecoder, MockResponseIterator, make_call + + +def make_sync_call( + client: Optional[HTTPHandler], + api_base: str, + headers: dict, + data: str, + model: str, + messages: list, + logging_obj, +): + if client is None: + client = _get_httpx_client() # Create a new client if none provided + + response = client.post( + api_base, + headers=headers, + data=data, + stream=True if "ai21" not in api_base else False, + ) + + if response.status_code != 200: + raise BedrockError(status_code=response.status_code, message=response.read()) + + if "ai21" in api_base: + model_response: ( + ModelResponse + ) = litellm.AmazonConverseConfig()._transform_response( + model=model, + response=response, + model_response=litellm.ModelResponse(), + stream=True, + logging_obj=logging_obj, + optional_params={}, + api_key="", + data=data, + messages=messages, + print_verbose=litellm.print_verbose, + encoding=litellm.encoding, + ) # type: ignore + completion_stream: Any = MockResponseIterator(model_response=model_response) + else: + decoder = AWSEventStreamDecoder(model=model) + completion_stream = 
decoder.iter_bytes(response.iter_bytes(chunk_size=1024)) + + # LOGGING + logging_obj.post_call( + input=messages, + api_key="", + original_response="first stream response received", + additional_args={"complete_input_dict": data}, + ) + + return completion_stream + + +class BedrockConverseLLM(BaseAWSLLM): + def __init__(self) -> None: + super().__init__() + + def encode_model_id(self, model_id: str) -> str: + """ + Double encode the model ID to ensure it matches the expected double-encoded format. + Args: + model_id (str): The model ID to encode. + Returns: + str: The double-encoded model ID. + """ + return urllib.parse.quote(model_id, safe="") # type: ignore + + async def async_streaming( + self, + model: str, + messages: list, + api_base: str, + model_response: ModelResponse, + print_verbose: Callable, + data: str, + timeout: Optional[Union[float, httpx.Timeout]], + encoding, + logging_obj, + stream, + optional_params: dict, + litellm_params=None, + logger_fn=None, + headers={}, + client: Optional[AsyncHTTPHandler] = None, + ) -> CustomStreamWrapper: + + completion_stream = await make_call( + client=client, + api_base=api_base, + headers=headers, + data=data, + model=model, + messages=messages, + logging_obj=logging_obj, + ) + streaming_response = CustomStreamWrapper( + completion_stream=completion_stream, + model=model, + custom_llm_provider="bedrock", + logging_obj=logging_obj, + ) + return streaming_response + + async def async_completion( + self, + model: str, + messages: list, + api_base: str, + model_response: ModelResponse, + print_verbose: Callable, + data: str, + timeout: Optional[Union[float, httpx.Timeout]], + encoding, + logging_obj, + stream, + optional_params: dict, + litellm_params=None, + logger_fn=None, + headers={}, + client: Optional[AsyncHTTPHandler] = None, + ) -> Union[ModelResponse, CustomStreamWrapper]: + if client is None or not isinstance(client, AsyncHTTPHandler): + _params = {} + if timeout is not None: + if isinstance(timeout, float) or isinstance(timeout, int): + timeout = httpx.Timeout(timeout) + _params["timeout"] = timeout + client = get_async_httpx_client( + params=_params, llm_provider=litellm.LlmProviders.BEDROCK + ) + else: + client = client # type: ignore + + try: + response = await client.post(url=api_base, headers=headers, data=data) # type: ignore + response.raise_for_status() + except httpx.HTTPStatusError as err: + error_code = err.response.status_code + raise BedrockError(status_code=error_code, message=err.response.text) + except httpx.TimeoutException as e: + raise BedrockError(status_code=408, message="Timeout error occurred.") + + return litellm.AmazonConverseConfig()._transform_response( + model=model, + response=response, + model_response=model_response, + stream=stream if isinstance(stream, bool) else False, + logging_obj=logging_obj, + api_key="", + data=data, + messages=messages, + print_verbose=print_verbose, + optional_params=optional_params, + encoding=encoding, + ) + + def completion( + self, + model: str, + messages: list, + api_base: Optional[str], + custom_prompt_dict: dict, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + logging_obj, + optional_params: dict, + acompletion: bool, + timeout: Optional[Union[float, httpx.Timeout]], + litellm_params: dict, + logger_fn=None, + extra_headers: Optional[dict] = None, + client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None, + ): + try: + import boto3 + from botocore.auth import SigV4Auth + from botocore.awsrequest import AWSRequest + from 
botocore.credentials import Credentials + except ImportError: + raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.") + + ## SETUP ## + stream = optional_params.pop("stream", None) + modelId = optional_params.pop("model_id", None) + if modelId is not None: + modelId = self.encode_model_id(model_id=modelId) + else: + modelId = model + + provider = model.split(".")[0] + + ## CREDENTIALS ## + # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them + aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) + aws_access_key_id = optional_params.pop("aws_access_key_id", None) + aws_session_token = optional_params.pop("aws_session_token", None) + aws_region_name = optional_params.pop("aws_region_name", None) + aws_role_name = optional_params.pop("aws_role_name", None) + aws_session_name = optional_params.pop("aws_session_name", None) + aws_profile_name = optional_params.pop("aws_profile_name", None) + aws_bedrock_runtime_endpoint = optional_params.pop( + "aws_bedrock_runtime_endpoint", None + ) # https://bedrock-runtime.{region_name}.amazonaws.com + aws_web_identity_token = optional_params.pop("aws_web_identity_token", None) + aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None) + + ### SET REGION NAME ### + if aws_region_name is None: + # check env # + litellm_aws_region_name = get_secret("AWS_REGION_NAME", None) + + if litellm_aws_region_name is not None and isinstance( + litellm_aws_region_name, str + ): + aws_region_name = litellm_aws_region_name + + standard_aws_region_name = get_secret("AWS_REGION", None) + if standard_aws_region_name is not None and isinstance( + standard_aws_region_name, str + ): + aws_region_name = standard_aws_region_name + + if aws_region_name is None: + aws_region_name = "us-west-2" + + credentials: Credentials = self.get_credentials( + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + aws_region_name=aws_region_name, + aws_session_name=aws_session_name, + aws_profile_name=aws_profile_name, + aws_role_name=aws_role_name, + aws_web_identity_token=aws_web_identity_token, + aws_sts_endpoint=aws_sts_endpoint, + ) + + ### SET RUNTIME ENDPOINT ### + endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint( + api_base=api_base, + aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, + aws_region_name=aws_region_name, + ) + if (stream is not None and stream is True) and provider != "ai21": + endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream" + proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse-stream" + else: + endpoint_url = f"{endpoint_url}/model/{modelId}/converse" + proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse" + + sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name) + + ## TRANSFORMATION ## + + _data = litellm.AmazonConverseConfig()._transform_request( + model=model, + messages=messages, + optional_params=optional_params, + litellm_params=litellm_params, + ) + data = json.dumps(_data) + ## COMPLETION CALL + + headers = {"Content-Type": "application/json"} + if extra_headers is not None: + headers = {"Content-Type": "application/json", **extra_headers} + request = AWSRequest( + method="POST", url=endpoint_url, data=data, headers=headers + ) + sigv4.add_auth(request) + if ( + extra_headers is not None and "Authorization" in extra_headers + ): # prevent sigv4 from overwriting the auth header + 
request.headers["Authorization"] = extra_headers["Authorization"] + prepped = request.prepare() + + ## LOGGING + logging_obj.pre_call( + input=messages, + api_key="", + additional_args={ + "complete_input_dict": data, + "api_base": proxy_endpoint_url, + "headers": prepped.headers, + }, + ) + + ### ROUTING (ASYNC, STREAMING, SYNC) + if acompletion: + if isinstance(client, HTTPHandler): + client = None + if stream is True: + return self.async_streaming( + model=model, + messages=messages, + data=data, + api_base=proxy_endpoint_url, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + logging_obj=logging_obj, + optional_params=optional_params, + stream=True, + litellm_params=litellm_params, + logger_fn=logger_fn, + headers=prepped.headers, + timeout=timeout, + client=client, + ) # type: ignore + ### ASYNC COMPLETION + return self.async_completion( + model=model, + messages=messages, + data=data, + api_base=proxy_endpoint_url, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + logging_obj=logging_obj, + optional_params=optional_params, + stream=stream, # type: ignore + litellm_params=litellm_params, + logger_fn=logger_fn, + headers=prepped.headers, + timeout=timeout, + client=client, + ) # type: ignore + + if client is None or isinstance(client, AsyncHTTPHandler): + _params = {} + if timeout is not None: + if isinstance(timeout, float) or isinstance(timeout, int): + timeout = httpx.Timeout(timeout) + _params["timeout"] = timeout + client = _get_httpx_client(_params) # type: ignore + else: + client = client + + if stream is not None and stream is True: + completion_stream = make_sync_call( + client=( + client + if client is not None and isinstance(client, HTTPHandler) + else None + ), + api_base=proxy_endpoint_url, + headers=prepped.headers, # type: ignore + data=data, + model=model, + messages=messages, + logging_obj=logging_obj, + ) + streaming_response = CustomStreamWrapper( + completion_stream=completion_stream, + model=model, + custom_llm_provider="bedrock", + logging_obj=logging_obj, + ) + + return streaming_response + + ### COMPLETION + + try: + response = client.post(url=proxy_endpoint_url, headers=prepped.headers, data=data) # type: ignore + response.raise_for_status() + except httpx.HTTPStatusError as err: + error_code = err.response.status_code + raise BedrockError(status_code=error_code, message=err.response.text) + except httpx.TimeoutException: + raise BedrockError(status_code=408, message="Timeout error occurred.") + + return litellm.AmazonConverseConfig()._transform_response( + model=model, + response=response, + model_response=model_response, + stream=stream if isinstance(stream, bool) else False, + logging_obj=logging_obj, + api_key="", + data=data, + messages=messages, + print_verbose=print_verbose, + optional_params=optional_params, + encoding=encoding, + ) diff --git a/litellm/llms/bedrock/chat/converse_transformation.py b/litellm/llms/bedrock/chat/converse_transformation.py new file mode 100644 index 000000000000..5f6141cb4e62 --- /dev/null +++ b/litellm/llms/bedrock/chat/converse_transformation.py @@ -0,0 +1,431 @@ +""" +Translating between OpenAI's `/chat/completion` format and Amazon's `/converse` format +""" + +import copy +import time +import types +from typing import List, Optional, Union + +import httpx + +import litellm +from litellm.litellm_core_utils.core_helpers import map_finish_reason +from litellm.litellm_core_utils.litellm_logging import Logging +from litellm.types.llms.bedrock import * 
+from litellm.types.llms.openai import ( + AllMessageValues, + ChatCompletionResponseMessage, + ChatCompletionToolCallChunk, + ChatCompletionToolCallFunctionChunk, + ChatCompletionToolParam, + ChatCompletionToolParamFunctionChunk, +) +from litellm.types.utils import ModelResponse, Usage +from litellm.utils import CustomStreamWrapper + +from ...prompt_templates.factory import _bedrock_converse_messages_pt, _bedrock_tools_pt +from ..common_utils import BedrockError, get_bedrock_tool_name + + +class AmazonConverseConfig: + """ + Reference - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html + #2 - https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#conversation-inference-supported-models-features + """ + + maxTokens: Optional[int] + stopSequences: Optional[List[str]] + temperature: Optional[int] + topP: Optional[int] + + def __init__( + self, + maxTokens: Optional[int] = None, + stopSequences: Optional[List[str]] = None, + temperature: Optional[int] = None, + topP: Optional[int] = None, + ) -> None: + locals_ = locals() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } + + def get_supported_openai_params(self, model: str) -> List[str]: + supported_params = [ + "max_tokens", + "stream", + "stream_options", + "stop", + "temperature", + "top_p", + "extra_headers", + "response_format", + ] + + if ( + model.startswith("anthropic") + or model.startswith("mistral") + or model.startswith("cohere") + or model.startswith("meta.llama3-1") + ): + supported_params.append("tools") + + if model.startswith("anthropic") or model.startswith("mistral"): + # only anthropic and mistral support tool choice config. otherwise (E.g. cohere) will fail the call - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html + supported_params.append("tool_choice") + + return supported_params + + def map_tool_choice_values( + self, model: str, tool_choice: Union[str, dict], drop_params: bool + ) -> Optional[ToolChoiceValuesBlock]: + if tool_choice == "none": + if litellm.drop_params is True or drop_params is True: + return None + else: + raise litellm.utils.UnsupportedParamsError( + message="Bedrock doesn't support tool_choice={}. To drop it from the call, set `litellm.drop_params = True.".format( + tool_choice + ), + status_code=400, + ) + elif tool_choice == "required": + return ToolChoiceValuesBlock(any={}) + elif tool_choice == "auto": + return ToolChoiceValuesBlock(auto={}) + elif isinstance(tool_choice, dict): + # only supported for anthropic + mistral models - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html + specific_tool = SpecificToolChoiceBlock( + name=tool_choice.get("function", {}).get("name", "") + ) + return ToolChoiceValuesBlock(tool=specific_tool) + else: + raise litellm.utils.UnsupportedParamsError( + message="Bedrock doesn't support tool_choice={}. Supported tool_choice values=['auto', 'required', json object]. 
To drop it from the call, set `litellm.drop_params = True.".format( + tool_choice + ), + status_code=400, + ) + + def get_supported_image_types(self) -> List[str]: + return ["png", "jpeg", "gif", "webp"] + + def map_openai_params( + self, + model: str, + non_default_params: dict, + optional_params: dict, + drop_params: bool, + ) -> dict: + for param, value in non_default_params.items(): + if param == "response_format": + json_schema: Optional[dict] = None + schema_name: str = "" + if "response_schema" in value: + json_schema = value["response_schema"] + schema_name = "json_tool_call" + elif "json_schema" in value: + json_schema = value["json_schema"]["schema"] + schema_name = value["json_schema"]["name"] + """ + Follow similar approach to anthropic - translate to a single tool call. + + When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode + - You usually want to provide a single tool + - You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool + - Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model’s perspective. + """ + if json_schema is not None: + _tool_choice = self.map_tool_choice_values( + model=model, tool_choice="required", drop_params=drop_params # type: ignore + ) + + _tool = ChatCompletionToolParam( + type="function", + function=ChatCompletionToolParamFunctionChunk( + name=schema_name, parameters=json_schema + ), + ) + + optional_params["tools"] = [_tool] + optional_params["tool_choice"] = _tool_choice + optional_params["json_mode"] = True + else: + if litellm.drop_params is True or drop_params is True: + pass + else: + raise litellm.utils.UnsupportedParamsError( + message="Bedrock doesn't support response_format={}. 
To drop it from the call, set `litellm.drop_params = True.".format( + value + ), + status_code=400, + ) + if param == "max_tokens": + optional_params["maxTokens"] = value + if param == "stream": + optional_params["stream"] = value + if param == "stop": + if isinstance(value, str): + if len(value) == 0: # converse raises error for empty strings + continue + value = [value] + optional_params["stopSequences"] = value + if param == "temperature": + optional_params["temperature"] = value + if param == "top_p": + optional_params["topP"] = value + if param == "tools": + optional_params["tools"] = value + if param == "tool_choice": + _tool_choice_value = self.map_tool_choice_values( + model=model, tool_choice=value, drop_params=drop_params # type: ignore + ) + if _tool_choice_value is not None: + optional_params["tool_choice"] = _tool_choice_value + return optional_params + + def _transform_request( + self, + model: str, + messages: List[AllMessageValues], + optional_params: dict, + litellm_params: dict, + ) -> RequestObject: + system_prompt_indices = [] + system_content_blocks: List[SystemContentBlock] = [] + for idx, message in enumerate(messages): + if message["role"] == "system": + _system_content_block: Optional[SystemContentBlock] = None + if isinstance(message["content"], str) and len(message["content"]) > 0: + _system_content_block = SystemContentBlock(text=message["content"]) + elif isinstance(message["content"], list): + for m in message["content"]: + if m.get("type", "") == "text" and len(m["text"]) > 0: + _system_content_block = SystemContentBlock(text=m["text"]) + if _system_content_block is not None: + system_content_blocks.append(_system_content_block) + system_prompt_indices.append(idx) + if len(system_prompt_indices) > 0: + for idx in reversed(system_prompt_indices): + messages.pop(idx) + + inference_params = copy.deepcopy(optional_params) + additional_request_keys = [] + additional_request_params = {} + supported_converse_params = AmazonConverseConfig.__annotations__.keys() + supported_tool_call_params = ["tools", "tool_choice"] + supported_guardrail_params = ["guardrailConfig"] + json_mode: Optional[bool] = inference_params.pop( + "json_mode", None + ) # used for handling json_schema + ## TRANSFORMATION ## + + bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt( + messages=messages, + model=model, + llm_provider="bedrock_converse", + user_continue_message=litellm_params.pop("user_continue_message", None), + ) + + # send all model-specific params in 'additional_request_params' + for k, v in inference_params.items(): + if ( + k not in supported_converse_params + and k not in supported_tool_call_params + and k not in supported_guardrail_params + ): + additional_request_params[k] = v + additional_request_keys.append(k) + for key in additional_request_keys: + inference_params.pop(key, None) + + bedrock_tools: List[ToolBlock] = _bedrock_tools_pt( + inference_params.pop("tools", []) + ) + bedrock_tool_config: Optional[ToolConfigBlock] = None + if len(bedrock_tools) > 0: + tool_choice_values: ToolChoiceValuesBlock = inference_params.pop( + "tool_choice", None + ) + bedrock_tool_config = ToolConfigBlock( + tools=bedrock_tools, + ) + if tool_choice_values is not None: + bedrock_tool_config["toolChoice"] = tool_choice_values + + _data: RequestObject = { + "messages": bedrock_messages, + "additionalModelRequestFields": additional_request_params, + "system": system_content_blocks, + "inferenceConfig": InferenceConfig(**inference_params), + } + + # Guardrail Config + 
guardrail_config: Optional[GuardrailConfigBlock] = None + request_guardrails_config = inference_params.pop("guardrailConfig", None) + if request_guardrails_config is not None: + guardrail_config = GuardrailConfigBlock(**request_guardrails_config) + _data["guardrailConfig"] = guardrail_config + + # Tool Config + if bedrock_tool_config is not None: + _data["toolConfig"] = bedrock_tool_config + + return _data + + def _transform_response( + self, + model: str, + response: httpx.Response, + model_response: ModelResponse, + stream: bool, + logging_obj: Optional[Logging], + optional_params: dict, + api_key: str, + data: Union[dict, str], + messages: List, + print_verbose, + encoding, + ) -> Union[ModelResponse, CustomStreamWrapper]: + + ## LOGGING + if logging_obj is not None: + logging_obj.post_call( + input=messages, + api_key=api_key, + original_response=response.text, + additional_args={"complete_input_dict": data}, + ) + print_verbose(f"raw model_response: {response.text}") + json_mode: Optional[bool] = optional_params.pop("json_mode", None) + ## RESPONSE OBJECT + try: + completion_response = ConverseResponseBlock(**response.json()) # type: ignore + except Exception as e: + raise BedrockError( + message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format( + response.text, str(e) + ), + status_code=422, + ) + + """ + Bedrock Response Object has optional message block + + completion_response["output"].get("message", None) + + A message block looks like this (Example 1): + "output": { + "message": { + "role": "assistant", + "content": [ + { + "text": "Is there anything else you'd like to talk about? Perhaps I can help with some economic questions or provide some information about economic concepts?" 
+ } + ] + } + }, + (Example 2): + "output": { + "message": { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "tooluse_hbTgdi0CSLq_hM4P8csZJA", + "name": "top_song", + "input": { + "sign": "WZPZ" + } + } + } + ] + } + } + + """ + message: Optional[MessageBlock] = completion_response["output"]["message"] + chat_completion_message: ChatCompletionResponseMessage = {"role": "assistant"} + content_str = "" + tools: List[ChatCompletionToolCallChunk] = [] + if message is not None: + for idx, content in enumerate(message["content"]): + """ + - Content is either a tool response or text + """ + if "text" in content: + content_str += content["text"] + if "toolUse" in content: + + ## check tool name was formatted by litellm + _response_tool_name = content["toolUse"]["name"] + response_tool_name = get_bedrock_tool_name( + response_tool_name=_response_tool_name + ) + _function_chunk = ChatCompletionToolCallFunctionChunk( + name=response_tool_name, + arguments=json.dumps(content["toolUse"]["input"]), + ) + + _tool_response_chunk = ChatCompletionToolCallChunk( + id=content["toolUse"]["toolUseId"], + type="function", + function=_function_chunk, + index=idx, + ) + tools.append(_tool_response_chunk) + chat_completion_message["content"] = content_str + + if json_mode is True and tools is not None and len(tools) == 1: + # to support 'json_schema' logic on bedrock models + json_mode_content_str: Optional[str] = tools[0]["function"].get("arguments") + if json_mode_content_str is not None: + chat_completion_message["content"] = json_mode_content_str + else: + chat_completion_message["tool_calls"] = tools + + ## CALCULATING USAGE - bedrock returns usage in the headers + input_tokens = completion_response["usage"]["inputTokens"] + output_tokens = completion_response["usage"]["outputTokens"] + total_tokens = completion_response["usage"]["totalTokens"] + + model_response.choices = [ + litellm.Choices( + finish_reason=map_finish_reason(completion_response["stopReason"]), + index=0, + message=litellm.Message(**chat_completion_message), + ) + ] + model_response.created = int(time.time()) + model_response.model = model + usage = Usage( + prompt_tokens=input_tokens, + completion_tokens=output_tokens, + total_tokens=total_tokens, + ) + setattr(model_response, "usage", usage) + + # Add "trace" from Bedrock guardrails - if user has opted in to returning it + if "trace" in completion_response: + setattr(model_response, "trace", completion_response["trace"]) + + return model_response diff --git a/litellm/llms/bedrock/chat.py b/litellm/llms/bedrock/chat/invoke_handler.py similarity index 62% rename from litellm/llms/bedrock/chat.py rename to litellm/llms/bedrock/chat/invoke_handler.py index 35f0c794a35f..e40a40372fb0 100644 --- a/litellm/llms/bedrock/chat.py +++ b/litellm/llms/bedrock/chat/invoke_handler.py @@ -52,8 +52,8 @@ from litellm.types.utils import GenericStreamingChunk as GChunk from litellm.utils import CustomStreamWrapper, ModelResponse, Usage, get_secret -from ..base_aws_llm import BaseAWSLLM -from ..prompt_templates.factory import ( +from ...base_aws_llm import BaseAWSLLM +from ...prompt_templates.factory import ( _bedrock_converse_messages_pt, _bedrock_tools_pt, cohere_message_pt, @@ -64,7 +64,8 @@ parse_xml_params, prompt_factory, ) -from .common_utils import BedrockError, ModelResponseIterator, get_runtime_endpoint +from ..common_utils import BedrockError, ModelResponseIterator, get_bedrock_tool_name +from .converse_transformation import AmazonConverseConfig BEDROCK_CONVERSE_MODELS = [ 
"anthropic.claude-3-5-sonnet-20240620-v1:0", @@ -225,10 +226,9 @@ async def make_call( raise BedrockError(status_code=response.status_code, message=response.text) if "ai21" in api_base: - aws_bedrock_process_response = BedrockConverseLLM() model_response: ( ModelResponse - ) = aws_bedrock_process_response.process_response( + ) = litellm.AmazonConverseConfig()._transform_response( model=model, response=response, model_response=litellm.ModelResponse(), @@ -266,59 +266,6 @@ async def make_call( raise BedrockError(status_code=500, message=str(e)) -def make_sync_call( - client: Optional[HTTPHandler], - api_base: str, - headers: dict, - data: str, - model: str, - messages: list, - logging_obj, -): - if client is None: - client = _get_httpx_client() # Create a new client if none provided - - response = client.post( - api_base, - headers=headers, - data=data, - stream=True if "ai21" not in api_base else False, - ) - - if response.status_code != 200: - raise BedrockError(status_code=response.status_code, message=response.read()) - - if "ai21" in api_base: - aws_bedrock_process_response = BedrockConverseLLM() - model_response: ModelResponse = aws_bedrock_process_response.process_response( - model=model, - response=response, - model_response=litellm.ModelResponse(), - stream=True, - logging_obj=logging_obj, - optional_params={}, - api_key="", - data=data, - messages=messages, - print_verbose=litellm.print_verbose, - encoding=litellm.encoding, - ) # type: ignore - completion_stream: Any = MockResponseIterator(model_response=model_response) - else: - decoder = AWSEventStreamDecoder(model=model) - completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024)) - - # LOGGING - logging_obj.post_call( - input=messages, - api_key="", - original_response="first stream response received", - additional_args={"complete_input_dict": data}, - ) - - return completion_stream - - class BedrockLLM(BaseAWSLLM): """ Example call @@ -417,6 +364,7 @@ def process_response( except: raise BedrockError(message=response.text, status_code=422) + outputText: Optional[str] = None try: if provider == "cohere": if "text" in completion_response: @@ -566,23 +514,27 @@ def process_response( try: if ( - len(outputText) > 0 + outputText is not None + and len(outputText) > 0 and hasattr(model_response.choices[0], "message") - and getattr(model_response.choices[0].message, "tool_calls", None) + and getattr(model_response.choices[0].message, "tool_calls", None) # type: ignore is None ): - model_response.choices[0].message.content = outputText + model_response.choices[0].message.content = outputText # type: ignore elif ( hasattr(model_response.choices[0], "message") - and getattr(model_response.choices[0].message, "tool_calls", None) + and getattr(model_response.choices[0].message, "tool_calls", None) # type: ignore is not None ): pass else: raise Exception() - except: + except Exception as e: raise BedrockError( - message=json.dumps(outputText), status_code=response.status_code + message="Error parsing received text={}.\nError-{}".format( + outputText, str(e) + ), + status_code=response.status_code, ) if stream and provider == "ai21": @@ -594,8 +546,8 @@ def process_response( streaming_choice = litellm.utils.StreamingChoices() streaming_choice.index = model_response.choices[0].index delta_obj = litellm.utils.Delta( - content=getattr(model_response.choices[0].message, "content", None), - role=model_response.choices[0].message.role, + content=getattr(model_response.choices[0].message, "content", None), # type: ignore + 
role=model_response.choices[0].message.role, # type: ignore ) streaming_choice.delta = delta_obj streaming_model_response.choices = [streaming_choice] @@ -731,7 +683,7 @@ def completion( ) ### SET RUNTIME ENDPOINT ### - endpoint_url, proxy_endpoint_url = get_runtime_endpoint( + endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint( api_base=api_base, aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, aws_region_name=aws_region_name, @@ -1002,7 +954,7 @@ def completion( response.raise_for_status() except httpx.HTTPStatusError as err: error_code = err.response.status_code - raise BedrockError(status_code=error_code, message=response.text) + raise BedrockError(status_code=error_code, message=err.response.text) except httpx.TimeoutException as e: raise BedrockError(status_code=408, message="Timeout error occurred.") @@ -1113,725 +1065,6 @@ def embedding(self, *args, **kwargs): return super().embedding(*args, **kwargs) -class AmazonConverseConfig: - """ - Reference - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html - #2 - https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#conversation-inference-supported-models-features - """ - - maxTokens: Optional[int] - stopSequences: Optional[List[str]] - temperature: Optional[int] - topP: Optional[int] - - def __init__( - self, - maxTokens: Optional[int] = None, - stopSequences: Optional[List[str]] = None, - temperature: Optional[int] = None, - topP: Optional[int] = None, - ) -> None: - locals_ = locals() - for key, value in locals_.items(): - if key != "self" and value is not None: - setattr(self.__class__, key, value) - - @classmethod - def get_config(cls): - return { - k: v - for k, v in cls.__dict__.items() - if not k.startswith("__") - and not isinstance( - v, - ( - types.FunctionType, - types.BuiltinFunctionType, - classmethod, - staticmethod, - ), - ) - and v is not None - } - - def get_supported_openai_params(self, model: str) -> List[str]: - supported_params = [ - "max_tokens", - "max_completion_tokens", - "stream", - "stream_options", - "stop", - "temperature", - "top_p", - "extra_headers", - "response_format", - ] - - if ( - model.startswith("anthropic") - or model.startswith("mistral") - or model.startswith("cohere") - or model.startswith("meta.llama3-1") - ): - supported_params.append("tools") - - if model.startswith("anthropic") or model.startswith("mistral"): - # only anthropic and mistral support tool choice config. otherwise (E.g. cohere) will fail the call - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html - supported_params.append("tool_choice") - - return supported_params - - def map_tool_choice_values( - self, model: str, tool_choice: Union[str, dict], drop_params: bool - ) -> Optional[ToolChoiceValuesBlock]: - if tool_choice == "none": - if litellm.drop_params is True or drop_params is True: - return None - else: - raise litellm.utils.UnsupportedParamsError( - message="Bedrock doesn't support tool_choice={}. 
To drop it from the call, set `litellm.drop_params = True.".format( - tool_choice - ), - status_code=400, - ) - elif tool_choice == "required": - return ToolChoiceValuesBlock(any={}) - elif tool_choice == "auto": - return ToolChoiceValuesBlock(auto={}) - elif isinstance(tool_choice, dict): - # only supported for anthropic + mistral models - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html - specific_tool = SpecificToolChoiceBlock( - name=tool_choice.get("function", {}).get("name", "") - ) - return ToolChoiceValuesBlock(tool=specific_tool) - else: - raise litellm.utils.UnsupportedParamsError( - message="Bedrock doesn't support tool_choice={}. Supported tool_choice values=['auto', 'required', json object]. To drop it from the call, set `litellm.drop_params = True.".format( - tool_choice - ), - status_code=400, - ) - - def get_supported_image_types(self) -> List[str]: - return ["png", "jpeg", "gif", "webp"] - - def map_openai_params( - self, - model: str, - non_default_params: dict, - optional_params: dict, - drop_params: bool, - ) -> dict: - for param, value in non_default_params.items(): - if param == "response_format": - json_schema: Optional[dict] = None - schema_name: str = "" - if "response_schema" in value: - json_schema = value["response_schema"] - schema_name = "json_tool_call" - elif "json_schema" in value: - json_schema = value["json_schema"]["schema"] - schema_name = value["json_schema"]["name"] - """ - Follow similar approach to anthropic - translate to a single tool call. - - When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode - - You usually want to provide a single tool - - You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool - - Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model’s perspective. - """ - if json_schema is not None: - _tool_choice = self.map_tool_choice_values( - model=model, tool_choice="required", drop_params=drop_params # type: ignore - ) - - _tool = ChatCompletionToolParam( - type="function", - function=ChatCompletionToolParamFunctionChunk( - name=schema_name, parameters=json_schema - ), - ) - - optional_params["tools"] = [_tool] - optional_params["tool_choice"] = _tool_choice - optional_params["json_mode"] = True - else: - if litellm.drop_params is True or drop_params is True: - pass - else: - raise litellm.utils.UnsupportedParamsError( - message="Bedrock doesn't support response_format={}. 
To drop it from the call, set `litellm.drop_params = True.".format( - value - ), - status_code=400, - ) - if param == "max_tokens" or param == "max_completion_tokens": - optional_params["maxTokens"] = value - if param == "stream": - optional_params["stream"] = value - if param == "stop": - if isinstance(value, str): - if len(value) == 0: # converse raises error for empty strings - continue - value = [value] - optional_params["stopSequences"] = value - if param == "temperature": - optional_params["temperature"] = value - if param == "top_p": - optional_params["topP"] = value - if param == "tools": - optional_params["tools"] = value - if param == "tool_choice": - _tool_choice_value = self.map_tool_choice_values( - model=model, tool_choice=value, drop_params=drop_params # type: ignore - ) - if _tool_choice_value is not None: - optional_params["tool_choice"] = _tool_choice_value - return optional_params - - -class BedrockConverseLLM(BaseAWSLLM): - def __init__(self) -> None: - super().__init__() - - def process_response( - self, - model: str, - response: Union[requests.Response, httpx.Response], - model_response: ModelResponse, - stream: bool, - logging_obj: Optional[Logging], - optional_params: dict, - api_key: str, - data: Union[dict, str], - messages: List, - print_verbose, - encoding, - ) -> Union[ModelResponse, CustomStreamWrapper]: - - ## LOGGING - if logging_obj is not None: - logging_obj.post_call( - input=messages, - api_key=api_key, - original_response=response.text, - additional_args={"complete_input_dict": data}, - ) - print_verbose(f"raw model_response: {response.text}") - json_mode: Optional[bool] = optional_params.pop("json_mode", None) - ## RESPONSE OBJECT - try: - completion_response = ConverseResponseBlock(**response.json()) # type: ignore - except Exception as e: - raise BedrockError( - message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format( - response.text, str(e) - ), - status_code=422, - ) - - """ - Bedrock Response Object has optional message block - - completion_response["output"].get("message", None) - - A message block looks like this (Example 1): - "output": { - "message": { - "role": "assistant", - "content": [ - { - "text": "Is there anything else you'd like to talk about? Perhaps I can help with some economic questions or provide some information about economic concepts?" 
- } - ] - } - }, - (Example 2): - "output": { - "message": { - "role": "assistant", - "content": [ - { - "toolUse": { - "toolUseId": "tooluse_hbTgdi0CSLq_hM4P8csZJA", - "name": "top_song", - "input": { - "sign": "WZPZ" - } - } - } - ] - } - } - - """ - message: Optional[MessageBlock] = completion_response["output"]["message"] - chat_completion_message: ChatCompletionResponseMessage = {"role": "assistant"} - content_str = "" - tools: List[ChatCompletionToolCallChunk] = [] - if message is not None: - for idx, content in enumerate(message["content"]): - """ - - Content is either a tool response or text - """ - if "text" in content: - content_str += content["text"] - if "toolUse" in content: - - ## check tool name was formatted by litellm - _response_tool_name = content["toolUse"]["name"] - response_tool_name = get_bedrock_tool_name( - response_tool_name=_response_tool_name - ) - _function_chunk = ChatCompletionToolCallFunctionChunk( - name=response_tool_name, - arguments=json.dumps(content["toolUse"]["input"]), - ) - - _tool_response_chunk = ChatCompletionToolCallChunk( - id=content["toolUse"]["toolUseId"], - type="function", - function=_function_chunk, - index=idx, - ) - tools.append(_tool_response_chunk) - chat_completion_message["content"] = content_str - - if json_mode is True and tools is not None and len(tools) == 1: - # to support 'json_schema' logic on bedrock models - json_mode_content_str: Optional[str] = tools[0]["function"].get("arguments") - if json_mode_content_str is not None: - chat_completion_message["content"] = json_mode_content_str - else: - chat_completion_message["tool_calls"] = tools - - ## CALCULATING USAGE - bedrock returns usage in the headers - input_tokens = completion_response["usage"]["inputTokens"] - output_tokens = completion_response["usage"]["outputTokens"] - total_tokens = completion_response["usage"]["totalTokens"] - - model_response.choices = [ - litellm.Choices( - finish_reason=map_finish_reason(completion_response["stopReason"]), - index=0, - message=litellm.Message(**chat_completion_message), - ) - ] - model_response.created = int(time.time()) - model_response.model = model - usage = Usage( - prompt_tokens=input_tokens, - completion_tokens=output_tokens, - total_tokens=total_tokens, - ) - setattr(model_response, "usage", usage) - - # Add "trace" from Bedrock guardrails - if user has opted in to returning it - if "trace" in completion_response: - setattr(model_response, "trace", completion_response["trace"]) - - return model_response - - def encode_model_id(self, model_id: str) -> str: - """ - Double encode the model ID to ensure it matches the expected double-encoded format. - Args: - model_id (str): The model ID to encode. - Returns: - str: The double-encoded model ID. 
- """ - return urllib.parse.quote(model_id, safe="") - - async def async_streaming( - self, - model: str, - messages: list, - api_base: str, - model_response: ModelResponse, - print_verbose: Callable, - data: str, - timeout: Optional[Union[float, httpx.Timeout]], - encoding, - logging_obj, - stream, - optional_params: dict, - litellm_params=None, - logger_fn=None, - headers={}, - client: Optional[AsyncHTTPHandler] = None, - ) -> CustomStreamWrapper: - streaming_response = CustomStreamWrapper( - completion_stream=None, - make_call=partial( - make_call, - client=client, - api_base=api_base, - headers=headers, - data=data, - model=model, - messages=messages, - logging_obj=logging_obj, - ), - model=model, - custom_llm_provider="bedrock", - logging_obj=logging_obj, - ) - return streaming_response - - async def async_completion( - self, - model: str, - messages: list, - api_base: str, - model_response: ModelResponse, - print_verbose: Callable, - data: str, - timeout: Optional[Union[float, httpx.Timeout]], - encoding, - logging_obj, - stream, - optional_params: dict, - litellm_params=None, - logger_fn=None, - headers={}, - client: Optional[AsyncHTTPHandler] = None, - ) -> Union[ModelResponse, CustomStreamWrapper]: - if client is None or not isinstance(client, AsyncHTTPHandler): - _params = {} - if timeout is not None: - if isinstance(timeout, float) or isinstance(timeout, int): - timeout = httpx.Timeout(timeout) - _params["timeout"] = timeout - client = get_async_httpx_client( - params=_params, llm_provider=litellm.LlmProviders.BEDROCK - ) - else: - client = client # type: ignore - - try: - response = await client.post(url=api_base, headers=headers, data=data) # type: ignore - response.raise_for_status() - except httpx.HTTPStatusError as err: - error_code = err.response.status_code - raise BedrockError(status_code=error_code, message=err.response.text) - except httpx.TimeoutException as e: - raise BedrockError(status_code=408, message="Timeout error occurred.") - - return self.process_response( - model=model, - response=response, - model_response=model_response, - stream=stream if isinstance(stream, bool) else False, - logging_obj=logging_obj, - api_key="", - data=data, - messages=messages, - print_verbose=print_verbose, - optional_params=optional_params, - encoding=encoding, - ) - - def completion( - self, - model: str, - messages: list, - api_base: Optional[str], - custom_prompt_dict: dict, - model_response: ModelResponse, - print_verbose: Callable, - encoding, - logging_obj, - optional_params: dict, - acompletion: bool, - timeout: Optional[Union[float, httpx.Timeout]], - litellm_params: dict, - logger_fn=None, - extra_headers: Optional[dict] = None, - client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None, - ): - try: - import boto3 - from botocore.auth import SigV4Auth - from botocore.awsrequest import AWSRequest - from botocore.credentials import Credentials - except ImportError: - raise ImportError("Missing boto3 to call bedrock. 
Run 'pip install boto3'.") - - ## SETUP ## - stream = optional_params.pop("stream", None) - modelId = optional_params.pop("model_id", None) - if modelId is not None: - modelId = self.encode_model_id(model_id=modelId) - else: - modelId = model - - provider = model.split(".")[0] - - ## CREDENTIALS ## - # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them - aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) - aws_access_key_id = optional_params.pop("aws_access_key_id", None) - aws_session_token = optional_params.pop("aws_session_token", None) - aws_region_name = optional_params.pop("aws_region_name", None) - aws_role_name = optional_params.pop("aws_role_name", None) - aws_session_name = optional_params.pop("aws_session_name", None) - aws_profile_name = optional_params.pop("aws_profile_name", None) - aws_bedrock_runtime_endpoint = optional_params.pop( - "aws_bedrock_runtime_endpoint", None - ) # https://bedrock-runtime.{region_name}.amazonaws.com - aws_web_identity_token = optional_params.pop("aws_web_identity_token", None) - aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None) - - ### SET REGION NAME ### - if aws_region_name is None: - # check env # - litellm_aws_region_name = get_secret("AWS_REGION_NAME", None) - - if litellm_aws_region_name is not None and isinstance( - litellm_aws_region_name, str - ): - aws_region_name = litellm_aws_region_name - - standard_aws_region_name = get_secret("AWS_REGION", None) - if standard_aws_region_name is not None and isinstance( - standard_aws_region_name, str - ): - aws_region_name = standard_aws_region_name - - if aws_region_name is None: - aws_region_name = "us-west-2" - - credentials: Credentials = self.get_credentials( - aws_access_key_id=aws_access_key_id, - aws_secret_access_key=aws_secret_access_key, - aws_session_token=aws_session_token, - aws_region_name=aws_region_name, - aws_session_name=aws_session_name, - aws_profile_name=aws_profile_name, - aws_role_name=aws_role_name, - aws_web_identity_token=aws_web_identity_token, - aws_sts_endpoint=aws_sts_endpoint, - ) - - ### SET RUNTIME ENDPOINT ### - endpoint_url, proxy_endpoint_url = get_runtime_endpoint( - api_base=api_base, - aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, - aws_region_name=aws_region_name, - ) - if (stream is not None and stream is True) and provider != "ai21": - endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream" - proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse-stream" - else: - endpoint_url = f"{endpoint_url}/model/{modelId}/converse" - proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse" - - sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name) - - # Separate system prompt from rest of message - system_prompt_indices = [] - system_content_blocks: List[SystemContentBlock] = [] - for idx, message in enumerate(messages): - if message["role"] == "system": - _system_content_block: Optional[SystemContentBlock] = None - if isinstance(message["content"], str) and len(message["content"]) > 0: - _system_content_block = SystemContentBlock(text=message["content"]) - elif isinstance(message["content"], list): - for m in message["content"]: - if m.get("type", "") == "text" and len(m["text"]) > 0: - _system_content_block = SystemContentBlock(text=m["text"]) - if _system_content_block is not None: - system_content_blocks.append(_system_content_block) - system_prompt_indices.append(idx) - if len(system_prompt_indices) > 0: - for 
idx in reversed(system_prompt_indices): - messages.pop(idx) - - inference_params = copy.deepcopy(optional_params) - additional_request_keys = [] - additional_request_params = {} - supported_converse_params = AmazonConverseConfig.__annotations__.keys() - supported_tool_call_params = ["tools", "tool_choice"] - supported_guardrail_params = ["guardrailConfig"] - json_mode: Optional[bool] = inference_params.pop( - "json_mode", None - ) # used for handling json_schema - ## TRANSFORMATION ## - - bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt( - messages=messages, - model=model, - llm_provider="bedrock_converse", - user_continue_message=litellm_params.pop("user_continue_message", None), - ) - - # send all model-specific params in 'additional_request_params' - for k, v in inference_params.items(): - if ( - k not in supported_converse_params - and k not in supported_tool_call_params - and k not in supported_guardrail_params - ): - additional_request_params[k] = v - additional_request_keys.append(k) - for key in additional_request_keys: - inference_params.pop(key, None) - - bedrock_tools: List[ToolBlock] = _bedrock_tools_pt( - inference_params.pop("tools", []) - ) - bedrock_tool_config: Optional[ToolConfigBlock] = None - if len(bedrock_tools) > 0: - tool_choice_values: ToolChoiceValuesBlock = inference_params.pop( - "tool_choice", None - ) - bedrock_tool_config = ToolConfigBlock( - tools=bedrock_tools, - ) - if tool_choice_values is not None: - bedrock_tool_config["toolChoice"] = tool_choice_values - - _data: RequestObject = { - "messages": bedrock_messages, - "additionalModelRequestFields": additional_request_params, - "system": system_content_blocks, - "inferenceConfig": InferenceConfig(**inference_params), - } - - # Guardrail Config - guardrail_config: Optional[GuardrailConfigBlock] = None - request_guardrails_config = inference_params.pop("guardrailConfig", None) - if request_guardrails_config is not None: - guardrail_config = GuardrailConfigBlock(**request_guardrails_config) - _data["guardrailConfig"] = guardrail_config - - # Tool Config - if bedrock_tool_config is not None: - _data["toolConfig"] = bedrock_tool_config - - data = json.dumps(_data) - ## COMPLETION CALL - - headers = {"Content-Type": "application/json"} - if extra_headers is not None: - headers = {"Content-Type": "application/json", **extra_headers} - request = AWSRequest( - method="POST", url=endpoint_url, data=data, headers=headers - ) - sigv4.add_auth(request) - if ( - extra_headers is not None and "Authorization" in extra_headers - ): # prevent sigv4 from overwriting the auth header - request.headers["Authorization"] = extra_headers["Authorization"] - prepped = request.prepare() - - ## LOGGING - logging_obj.pre_call( - input=messages, - api_key="", - additional_args={ - "complete_input_dict": data, - "api_base": proxy_endpoint_url, - "headers": prepped.headers, - }, - ) - - ### ROUTING (ASYNC, STREAMING, SYNC) - if acompletion: - if isinstance(client, HTTPHandler): - client = None - if stream is True: - return self.async_streaming( - model=model, - messages=messages, - data=data, - api_base=proxy_endpoint_url, - model_response=model_response, - print_verbose=print_verbose, - encoding=encoding, - logging_obj=logging_obj, - optional_params=optional_params, - stream=True, - litellm_params=litellm_params, - logger_fn=logger_fn, - headers=prepped.headers, - timeout=timeout, - client=client, - ) # type: ignore - ### ASYNC COMPLETION - return self.async_completion( - model=model, - messages=messages, - data=data, 
- api_base=proxy_endpoint_url, - model_response=model_response, - print_verbose=print_verbose, - encoding=encoding, - logging_obj=logging_obj, - optional_params=optional_params, - stream=stream, # type: ignore - litellm_params=litellm_params, - logger_fn=logger_fn, - headers=prepped.headers, - timeout=timeout, - client=client, - ) # type: ignore - - if stream is not None and stream is True: - - streaming_response = CustomStreamWrapper( - completion_stream=None, - make_call=partial( - make_sync_call, - client=None, - api_base=proxy_endpoint_url, - headers=prepped.headers, # type: ignore - data=data, - model=model, - messages=messages, - logging_obj=logging_obj, - ), - model=model, - custom_llm_provider="bedrock", - logging_obj=logging_obj, - ) - - return streaming_response - ### COMPLETION - - if client is None or isinstance(client, AsyncHTTPHandler): - _params = {} - if timeout is not None: - if isinstance(timeout, float) or isinstance(timeout, int): - timeout = httpx.Timeout(timeout) - _params["timeout"] = timeout - client = _get_httpx_client(_params) # type: ignore - else: - client = client - try: - response = client.post(url=proxy_endpoint_url, headers=prepped.headers, data=data) # type: ignore - response.raise_for_status() - except httpx.HTTPStatusError as err: - error_code = err.response.status_code - raise BedrockError(status_code=error_code, message=response.text) - except httpx.TimeoutException: - raise BedrockError(status_code=408, message="Timeout error occurred.") - - return self.process_response( - model=model, - response=response, - model_response=model_response, - stream=stream if isinstance(stream, bool) else False, - logging_obj=logging_obj, - optional_params=optional_params, - api_key="", - data=data, - messages=messages, - print_verbose=print_verbose, - encoding=encoding, - ) - - def get_response_stream_shape(): global _response_stream_shape_cache if _response_stream_shape_cache is None: @@ -1847,24 +1080,6 @@ def get_response_stream_shape(): return _response_stream_shape_cache -def get_bedrock_tool_name(response_tool_name: str) -> str: - """ - If litellm formatted the input tool name, we need to convert it back to the original name. - - Args: - response_tool_name (str): The name of the tool as received from the response. - - Returns: - str: The original name of the tool. 
- """ - - if response_tool_name in litellm.bedrock_tool_name_mappings.cache_dict: - response_tool_name = litellm.bedrock_tool_name_mappings.cache_dict[ - response_tool_name - ] - return response_tool_name - - class AWSEventStreamDecoder: def __init__(self, model: str) -> None: from botocore.parsers import EventStreamJSONParser diff --git a/litellm/llms/bedrock/common_utils.py b/litellm/llms/bedrock/common_utils.py index 86fced96ad20..cc402f6f3a03 100644 --- a/litellm/llms/bedrock/common_utils.py +++ b/litellm/llms/bedrock/common_utils.py @@ -583,7 +583,7 @@ def init_bedrock_client( # Iterate over parameters and update if needed for i, param in enumerate(params_to_check): if param and param.startswith("os.environ/"): - params_to_check[i] = get_secret(param) + params_to_check[i] = get_secret(param) # type: ignore # Assign updated values back to parameters ( aws_access_key_id, @@ -626,13 +626,13 @@ def init_bedrock_client( import boto3 if isinstance(timeout, float): - config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout) + config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout) # type: ignore elif isinstance(timeout, httpx.Timeout): - config = boto3.session.Config( + config = boto3.session.Config( # type: ignore connect_timeout=timeout.connect, read_timeout=timeout.read ) else: - config = boto3.session.Config() + config = boto3.session.Config() # type: ignore ### CHECK STS ### if ( @@ -733,40 +733,6 @@ def init_bedrock_client( return client -def get_runtime_endpoint( - api_base: Optional[str], - aws_bedrock_runtime_endpoint: Optional[str], - aws_region_name: str, -) -> Tuple[str, str]: - env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT") - if api_base is not None: - endpoint_url = api_base - elif aws_bedrock_runtime_endpoint is not None and isinstance( - aws_bedrock_runtime_endpoint, str - ): - endpoint_url = aws_bedrock_runtime_endpoint - elif env_aws_bedrock_runtime_endpoint and isinstance( - env_aws_bedrock_runtime_endpoint, str - ): - endpoint_url = env_aws_bedrock_runtime_endpoint - else: - endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com" - - # Determine proxy_endpoint_url - if env_aws_bedrock_runtime_endpoint and isinstance( - env_aws_bedrock_runtime_endpoint, str - ): - proxy_endpoint_url = env_aws_bedrock_runtime_endpoint - elif aws_bedrock_runtime_endpoint is not None and isinstance( - aws_bedrock_runtime_endpoint, str - ): - proxy_endpoint_url = aws_bedrock_runtime_endpoint - else: - proxy_endpoint_url = endpoint_url - - return endpoint_url, proxy_endpoint_url - - class ModelResponseIterator: def __init__(self, model_response): self.model_response = model_response @@ -791,3 +757,21 @@ async def __anext__(self): raise StopAsyncIteration self.is_done = True return self.model_response + + +def get_bedrock_tool_name(response_tool_name: str) -> str: + """ + If litellm formatted the input tool name, we need to convert it back to the original name. + + Args: + response_tool_name (str): The name of the tool as received from the response. + + Returns: + str: The original name of the tool. 
+ """ + + if response_tool_name in litellm.bedrock_tool_name_mappings.cache_dict: + response_tool_name = litellm.bedrock_tool_name_mappings.cache_dict[ + response_tool_name + ] + return response_tool_name diff --git a/litellm/llms/bedrock/embed/embedding.py b/litellm/llms/bedrock/embed/embedding.py index 7d2e441da146..8203eb6e6789 100644 --- a/litellm/llms/bedrock/embed/embedding.py +++ b/litellm/llms/bedrock/embed/embedding.py @@ -23,7 +23,7 @@ from litellm.types.utils import Embedding, EmbeddingResponse, Usage from ...base_aws_llm import BaseAWSLLM -from ..common_utils import BedrockError, get_runtime_endpoint +from ..common_utils import BedrockError from .amazon_titan_g1_transformation import AmazonTitanG1Config from .amazon_titan_multimodal_transformation import ( AmazonTitanMultimodalEmbeddingG1Config, @@ -141,7 +141,7 @@ async def _make_async_call( response.raise_for_status() except httpx.HTTPStatusError as err: error_code = err.response.status_code - raise BedrockError(status_code=error_code, message=response.text) + raise BedrockError(status_code=error_code, message=err.response.text) except httpx.TimeoutException: raise BedrockError(status_code=408, message="Timeout error occurred.") @@ -197,7 +197,7 @@ def _single_func_embeddings( client=client, timeout=timeout, api_base=prepped.url, - headers=prepped.headers, + headers=prepped.headers, # type: ignore data=data, ) @@ -288,7 +288,7 @@ async def _async_single_func_embeddings( client=client, timeout=timeout, api_base=prepped.url, - headers=prepped.headers, + headers=prepped.headers, # type: ignore data=data, ) @@ -342,8 +342,8 @@ def embeddings( timeout: Optional[Union[float, httpx.Timeout]], aembedding: Optional[bool], extra_headers: Optional[dict], - optional_params=None, - litellm_params=None, + optional_params: dict, + litellm_params: dict, ) -> EmbeddingResponse: try: import boto3 @@ -392,10 +392,21 @@ def embeddings( transformed_request = AmazonTitanV2Config()._transform_request( input=i, inference_params=inference_params ) + else: + raise Exception( + "Unmapped model. Received={}. 
Expected={}".format( + model, + [ + "amazon.titan-embed-image-v1", + "amazon.titan-embed-text-v1", + "amazon.titan-embed-text-v2:0", + ], + ) + ) batch_data.append(transformed_request) ### SET RUNTIME ENDPOINT ### - endpoint_url, proxy_endpoint_url = get_runtime_endpoint( + endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint( api_base=api_base, aws_bedrock_runtime_endpoint=optional_params.pop( "aws_bedrock_runtime_endpoint", None @@ -443,6 +454,7 @@ def embeddings( headers = {"Content-Type": "application/json"} if extra_headers is not None: headers = {"Content-Type": "application/json", **extra_headers} + request = AWSRequest( method="POST", url=endpoint_url, data=json.dumps(data), headers=headers ) @@ -467,170 +479,5 @@ def embeddings( aembedding=aembedding, timeout=timeout, client=client, - headers=prepped.headers, + headers=prepped.headers, # type: ignore ) - - # def _embedding_func_single( - # model: str, - # input: str, - # client: Any, - # optional_params=None, - # encoding=None, - # logging_obj=None, - # ): - # if isinstance(input, str) is False: - # raise BedrockError( - # message="Bedrock Embedding API input must be type str | List[str]", - # status_code=400, - # ) - # # logic for parsing in - calling - parsing out model embedding calls - # ## FORMAT EMBEDDING INPUT ## - # provider = model.split(".")[0] - # inference_params = copy.deepcopy(optional_params) - # inference_params.pop( - # "user", None - # ) # make sure user is not passed in for bedrock call - # modelId = ( - # optional_params.pop("model_id", None) or model - # ) # default to model if not passed - # if provider == "amazon": - # input = input.replace(os.linesep, " ") - # data = {"inputText": input, **inference_params} - # # data = json.dumps(data) - # elif provider == "cohere": - # inference_params["input_type"] = inference_params.get( - # "input_type", "search_document" - # ) # aws bedrock example default - https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/providers?model=cohere.embed-english-v3 - # data = {"texts": [input], **inference_params} # type: ignore - # body = json.dumps(data).encode("utf-8") # type: ignore - # ## LOGGING - # request_str = f""" - # response = client.invoke_model( - # body={body}, - # modelId={modelId}, - # accept="*/*", - # contentType="application/json", - # )""" # type: ignore - # logging_obj.pre_call( - # input=input, - # api_key="", # boto3 is used for init. 
-    #         additional_args={
-    #             "complete_input_dict": {"model": modelId, "texts": input},
-    #             "request_str": request_str,
-    #         },
-    #     )
-    #     try:
-    #         response = client.invoke_model(
-    #             body=body,
-    #             modelId=modelId,
-    #             accept="*/*",
-    #             contentType="application/json",
-    #         )
-    #         response_body = json.loads(response.get("body").read())
-    #         ## LOGGING
-    #         logging_obj.post_call(
-    #             input=input,
-    #             api_key="",
-    #             additional_args={"complete_input_dict": data},
-    #             original_response=json.dumps(response_body),
-    #         )
-    #         if provider == "cohere":
-    #             response = response_body.get("embeddings")
-    #             # flatten list
-    #             response = [item for sublist in response for item in sublist]
-    #             return response
-    #         elif provider == "amazon":
-    #             return response_body.get("embedding")
-    #     except Exception as e:
-    #         raise BedrockError(
-    #             message=f"Embedding Error with model {model}: {e}", status_code=500
-    #         )
-
-    # def embedding(
-    #     model: str,
-    #     input: Union[list, str],
-    #     model_response: litellm.EmbeddingResponse,
-    #     api_key: Optional[str] = None,
-    #     logging_obj=None,
-    #     optional_params=None,
-    #     encoding=None,
-    # ):
-    #     ### BOTO3 INIT ###
-    #     # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
-    #     aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
-    #     aws_access_key_id = optional_params.pop("aws_access_key_id", None)
-    #     aws_region_name = optional_params.pop("aws_region_name", None)
-    #     aws_role_name = optional_params.pop("aws_role_name", None)
-    #     aws_session_name = optional_params.pop("aws_session_name", None)
-    #     aws_bedrock_runtime_endpoint = optional_params.pop(
-    #         "aws_bedrock_runtime_endpoint", None
-    #     )
-    #     aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
-
-    #     # use passed in BedrockRuntime.Client if provided, otherwise create a new one
-    #     client = init_bedrock_client(
-    #         aws_access_key_id=aws_access_key_id,
-    #         aws_secret_access_key=aws_secret_access_key,
-    #         aws_region_name=aws_region_name,
-    #         aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
-    #         aws_web_identity_token=aws_web_identity_token,
-    #         aws_role_name=aws_role_name,
-    #         aws_session_name=aws_session_name,
-    #     )
-    #     if isinstance(input, str):
-    #         ## Embedding Call
-    #         embeddings = [
-    #             _embedding_func_single(
-    #                 model,
-    #                 input,
-    #                 optional_params=optional_params,
-    #                 client=client,
-    #                 logging_obj=logging_obj,
-    #             )
-    #         ]
-    #     elif isinstance(input, list):
-    #         ## Embedding Call - assuming this is a List[str]
-    #         embeddings = [
-    #             _embedding_func_single(
-    #                 model,
-    #                 i,
-    #                 optional_params=optional_params,
-    #                 client=client,
-    #                 logging_obj=logging_obj,
-    #             )
-    #             for i in input
-    #         ]  # [TODO]: make these parallel calls
-    #     else:
-    #         # enters this branch if input = int, ex. input=2
-    #         raise BedrockError(
-    #             message="Bedrock Embedding API input must be type str | List[str]",
-    #             status_code=400,
-    #         )
-
-    #     ## Populate OpenAI compliant dictionary
-    #     embedding_response = []
-    #     for idx, embedding in enumerate(embeddings):
-    #         embedding_response.append(
-    #             {
-    #                 "object": "embedding",
-    #                 "index": idx,
-    #                 "embedding": embedding,
-    #             }
-    #         )
-    #     model_response.object = "list"
-    #     model_response.data = embedding_response
-    #     model_response.model = model
-    #     input_tokens = 0
-
-    #     input_str = "".join(input)
-
-    #     input_tokens += len(encoding.encode(input_str))
-
-    #     usage = Usage(
-    #         prompt_tokens=input_tokens,
-    #         completion_tokens=0,
-    #         total_tokens=input_tokens + 0,
-    #     )
-    #     model_response.usage = usage
-
-    #     return model_response
diff --git a/litellm/main.py b/litellm/main.py
index ec3d22f1690a..82d19a976ea8 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -2385,6 +2385,7 @@ def completion(
             )
         if model in litellm.BEDROCK_CONVERSE_MODELS:
+
             response = bedrock_converse_chat_completion.completion(
                 model=model,
                 messages=messages,
@@ -3570,7 +3571,7 @@ def embedding(
             client=client,
             timeout=timeout,
             aembedding=aembedding,
-            litellm_params=litellm_params,
+            litellm_params={},
             api_base=api_base,
             print_verbose=print_verbose,
             extra_headers=extra_headers,
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index c940e744fdaa..4533bf9114c0 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -1,7 +1,20 @@
 model_list:
-  - model_name: gpt-3.5-turbo
+  - model_name: fake-claude-endpoint
+    litellm_params:
+      model: anthropic.claude-3-sonnet-20240229-v1:0
+      api_base: https://exampleopenaiendpoint-production.up.railway.app
+      # aws_session_token: "IQoJb3JpZ2luX2VjELj//////////wEaCXVzLXdlc3QtMiJHMEUCIQDatCRVkIZERLcrR6P7Qd1vNfZ8r8xB/LUeaVaTW/lBTwIgAgmHSBe41d65GVRKSkpgVonjsCmOmAS7s/yklM9NsZcq3AEI4P//////////ARABGgw4ODg2MDIyMjM0MjgiDJrio0/CHYEfyt5EqyqwAfyWO4t3bFVWAOIwTyZ1N6lszeJKfMNus2hzVc+r73hia2Anv88uwPxNg2uqnXQNJumEo0DcBt30ZwOw03Isboy0d5l05h8gjb4nl9feyeKmKAnRdcqElrEWtCC1Qcefv78jQv53AbUipH1ssa5NPvptqZZpZYDPMlBEnV3YdvJJiuE23u2yOkCt+EoUJLaOYjOryoRyrSfbWB+JaUsB68R3rNTHzReeN3Nob/9Ic4HrMMmzmLcGOpgBZxclO4w8Z7i6TcVqbCwDOskxuR6bZaiFxKFG+9tDrWS7jaQKpq/YP9HUT0YwYpZplaBEEZR5sbIndg5yb4dRZrSHplblqKz8XLaUf5tuuyRJmwr96PTpw/dyEVk9gicFX6JfLBEv0v5rN2Z0JMFLdfIP4kC1U2PjcPOWoglWO3fLmJ4Lol2a3c5XDSMwMxjcJXq+c8Ue1v0="
+      aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY
+      aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID
+  - model_name: gemini-vision
+    litellm_params:
+      model: vertex_ai/gemini-1.0-pro-vision-001
+      api_base: https://exampleopenaiendpoint-production.up.railway.app/v1/projects/adroit-crow-413218/locations/us-central1/publishers/google/models/gemini-1.0-pro-vision-001
+      vertex_project: "adroit-crow-413218"
+      vertex_location: "us-central1"
+
+  - model_name: fake-openai-endpoint
     litellm_params:
       model: gpt-3.5-turbo
-
-router_settings:
-  model_group_alias: {"gpt-4": {"model": "gpt-3.5-turbo", "hidden": false}}
\ No newline at end of file
+      api_base: https://exampleopenaiendpoint-production.up.railway.app
+    
\ No newline at end of file
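The config above exposes a Bedrock Converse model through the proxy under an OpenAI-compatible alias. A hedged usage sketch against a locally running proxy; the base URL, port, and key are placeholders and assume something like `litellm --config _new_secret_config.yaml` is serving on localhost:4000.

import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://localhost:4000")

response = client.chat.completions.create(
    model="fake-claude-endpoint",  # model_name from the config above
    messages=[{"role": "user", "content": "ping"}],
)
print(response.choices[0].message.content)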
diff --git a/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py b/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py
index a272af46b3ec..6051af37c8be 100644
--- a/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py
+++ b/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py
@@ -80,7 +80,13 @@ async def gemini_proxy_route(
     updated_url = base_url.copy_with(path=encoded_endpoint)
 
     # Add or update query parameters
-    gemini_api_key = litellm.utils.get_secret(secret_name="GEMINI_API_KEY")
+    gemini_api_key: Optional[str] = litellm.utils.get_secret(  # type: ignore
+        secret_name="GEMINI_API_KEY"
+    )
+    if gemini_api_key is None:
+        raise Exception(
+            "Required 'GEMINI_API_KEY' in environment to make pass-through calls to Google AI Studio."
+        )
     # Merge query parameters, giving precedence to those in updated_url
     merged_params = dict(request.query_params)
     merged_params.update({"key": gemini_api_key})
@@ -99,8 +105,8 @@ async def gemini_proxy_route(
         request,
         fastapi_response,
         user_api_key_dict,
-        query_params=merged_params,
-        stream=is_streaming_request,
+        query_params=merged_params,  # type: ignore
+        stream=is_streaming_request,  # type: ignore
     )
 
     return received_value
@@ -142,7 +148,7 @@ async def cohere_proxy_route(
         request,
         fastapi_response,
         user_api_key_dict,
-        stream=is_streaming_request,
+        stream=is_streaming_request,  # type: ignore
     )
 
     return received_value
@@ -208,15 +214,15 @@ async def bedrock_proxy_route(
     endpoint_func = create_pass_through_route(
         endpoint=endpoint,
         target=str(prepped.url),
-        custom_headers=prepped.headers,
+        custom_headers=prepped.headers,  # type: ignore
     )  # dynamically construct pass-through endpoint based on incoming path
     received_value = await endpoint_func(
         request,
         fastapi_response,
         user_api_key_dict,
-        stream=is_streaming_request,
-        custom_body=data,
-        query_params={},
+        stream=is_streaming_request,  # type: ignore
+        custom_body=data,  # type: ignore
+        query_params={},  # type: ignore
     )
 
     return received_value
diff --git a/litellm/tests/test_bedrock_completion.py b/litellm/tests/test_bedrock_completion.py
index e949d1ee7d2c..e786e4fc8e8c 100644
--- a/litellm/tests/test_bedrock_completion.py
+++ b/litellm/tests/test_bedrock_completion.py
@@ -27,7 +27,7 @@
     completion_cost,
     embedding,
 )
-from litellm.llms.bedrock.chat import BedrockLLM, ToolBlock
+from litellm.llms.bedrock.chat import BedrockLLM
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.llms.prompt_templates.factory import _bedrock_tools_pt
@@ -1287,3 +1287,41 @@ def test_bedrock_converse_translation_tool_message():
         ],
     }
 ]
+
+
+def test_base_aws_llm_get_credentials():
+    import time
+
+    import boto3
+
+    from litellm.llms.base_aws_llm import BaseAWSLLM
+
+    start_time = time.time()
+    session = boto3.Session(
+        aws_access_key_id="test",
+        aws_secret_access_key="test2",
+        region_name="test3",
+    )
+    credentials = session.get_credentials().get_frozen_credentials()
+    end_time = time.time()
+
+    print(
+        "Total time for credentials - {}. Credentials - {}".format(
+            end_time - start_time, credentials
+        )
+    )
+
+    start_time = time.time()
+    credentials = BaseAWSLLM().get_credentials(
+        aws_access_key_id="test",
+        aws_secret_access_key="test2",
+        aws_region_name="test3",
+    )
+
+    end_time = time.time()
+
+    print(
+        "Total time for credentials - {}. Credentials - {}".format(
+            end_time - start_time, credentials.get_frozen_credentials()
+        )
+    )
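The new test above only prints the two timings. A hedged sketch of how the cached second lookup could be exercised more explicitly; the dummy keys match the test, no network call is made for static credentials, and timing comparisons are inherently flaky, so this is illustrative rather than a hard assertion.

import time

from litellm.llms.base_aws_llm import BaseAWSLLM

base_llm = BaseAWSLLM()
timings = []
for _ in range(2):
    start = time.time()
    base_llm.get_credentials(
        aws_access_key_id="test",
        aws_secret_access_key="test2",
        aws_region_name="test3",
    )
    timings.append(time.time() - start)

# Second call is expected to be served from the in-memory credential cache.
print(timings)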
Credentials - {}".format( + end_time - start_time, credentials.get_frozen_credentials() + ) + ) diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py index 2c4190382bb5..1debf817170b 100644 --- a/litellm/tests/test_streaming.py +++ b/litellm/tests/test_streaming.py @@ -1454,7 +1454,7 @@ async def test_bedrock_httpx_streaming(sync_mode, model, region): has_finish_reason = True break complete_response += chunk - if has_finish_reason == False: + if has_finish_reason is False: raise Exception("finish reason not set") if complete_response.strip() == "": raise Exception("Empty response received") diff --git a/litellm/utils.py b/litellm/utils.py index 45cc91819be0..280691c8a4c1 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -8159,9 +8159,7 @@ def exception_type( exception_mapping_worked = True if hasattr(original_exception, "request"): raise APIConnectionError( - message="{}\n{}".format( - str(original_exception), traceback.format_exc() - ), + message="{} - {}".format(exception_provider, error_str), llm_provider=custom_llm_provider, model=model, request=original_exception.request,