From 3f0838259194ea2f93e71b5551f4fe2f91673edb Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 14 Sep 2024 17:41:09 -0700 Subject: [PATCH 1/7] feat(aws_base_llm.py): prevents recreating boto3 credentials during high traffic Leads to 100ms perf boost in local testing --- litellm/llms/base_aws_llm.py | 66 ++- litellm/llms/bedrock/chat.py | 226 +-------- litellm/llms/bedrock/chat/converse.py | 228 +++++++++ .../bedrock/chat/converse_transformation.py | 431 ++++++++++++++++++ litellm/llms/bedrock/common_utils.py | 62 +-- litellm/llms/bedrock/embed/embedding.py | 192 +------- litellm/main.py | 1 + litellm/proxy/_new_secret_config.yaml | 21 +- 8 files changed, 801 insertions(+), 426 deletions(-) create mode 100644 litellm/llms/bedrock/chat/converse.py create mode 100644 litellm/llms/bedrock/chat/converse_transformation.py diff --git a/litellm/llms/base_aws_llm.py b/litellm/llms/base_aws_llm.py index 70f333eb6a0e..25429c4db6f8 100644 --- a/litellm/llms/base_aws_llm.py +++ b/litellm/llms/base_aws_llm.py @@ -1,5 +1,6 @@ +import hashlib import json -from typing import List, Optional +from typing import Dict, List, Optional, Tuple import httpx @@ -28,6 +29,14 @@ def __init__(self) -> None: self.iam_cache = DualCache() super().__init__() + def get_cache_key(self, credential_args: Dict[str, Optional[str]]) -> str: + """ + Generate a unique cache key based on the credential arguments. + """ + # Convert credential arguments to a JSON string and hash it to create a unique key + credential_str = json.dumps(credential_args, sort_keys=True) + return hashlib.sha256(credential_str.encode()).hexdigest() + def get_credentials( self, aws_access_key_id: Optional[str] = None, @@ -43,7 +52,12 @@ def get_credentials( """ Return a boto3.Credentials object """ + args = locals() + args.pop("self") + cache_key = self.get_cache_key(args) + import boto3 + from botocore.credentials import Credentials ## CHECK IS 'os.environ/' passed in params_to_check: List[Optional[str]] = [ @@ -186,7 +200,6 @@ def get_credentials( # Extract the credentials from the response and convert to Session Credentials sts_credentials = sts_response["Credentials"] - from botocore.credentials import Credentials credentials = Credentials( access_key=sts_credentials["AccessKeyId"], @@ -211,12 +224,59 @@ def get_credentials( secret_key=aws_secret_access_key, token=aws_session_token, ) + return credentials else: + + # Check if credentials are already in cache. These credentials have no expiry time. 
+ cached_credentials: Optional[Credentials] = self.iam_cache.get_cache( + cache_key + ) + if cached_credentials: + return cached_credentials + session = boto3.Session( aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, region_name=aws_region_name, ) - return session.get_credentials() + credentials = session.get_credentials() + + self.iam_cache.set_cache(cache_key, credentials, ttl=3600 - 60) + + return credentials + + def get_runtime_endpoint( + self, + api_base: Optional[str], + aws_bedrock_runtime_endpoint: Optional[str], + aws_region_name: str, + ) -> Tuple[str, str]: + env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT") + if api_base is not None: + endpoint_url = api_base + elif aws_bedrock_runtime_endpoint is not None and isinstance( + aws_bedrock_runtime_endpoint, str + ): + endpoint_url = aws_bedrock_runtime_endpoint + elif env_aws_bedrock_runtime_endpoint and isinstance( + env_aws_bedrock_runtime_endpoint, str + ): + endpoint_url = env_aws_bedrock_runtime_endpoint + else: + endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com" + + # Determine proxy_endpoint_url + if env_aws_bedrock_runtime_endpoint and isinstance( + env_aws_bedrock_runtime_endpoint, str + ): + proxy_endpoint_url = env_aws_bedrock_runtime_endpoint + elif aws_bedrock_runtime_endpoint is not None and isinstance( + aws_bedrock_runtime_endpoint, str + ): + proxy_endpoint_url = aws_bedrock_runtime_endpoint + else: + proxy_endpoint_url = endpoint_url + + return endpoint_url, proxy_endpoint_url diff --git a/litellm/llms/bedrock/chat.py b/litellm/llms/bedrock/chat.py index 8d6d98ba62da..8068e5b03d39 100644 --- a/litellm/llms/bedrock/chat.py +++ b/litellm/llms/bedrock/chat.py @@ -64,7 +64,8 @@ parse_xml_params, prompt_factory, ) -from .common_utils import BedrockError, ModelResponseIterator, get_runtime_endpoint +from .common_utils import BedrockError, ModelResponseIterator, get_bedrock_tool_name +from .converse_chat.converse_transformation import AmazonConverseConfig BEDROCK_CONVERSE_MODELS = [ "anthropic.claude-3-5-sonnet-20240620-v1:0", @@ -416,6 +417,7 @@ def process_response( except: raise BedrockError(message=response.text, status_code=422) + outputText: Optional[str] = None try: if provider == "cohere": if "text" in completion_response: @@ -565,23 +567,27 @@ def process_response( try: if ( - len(outputText) > 0 + outputText is not None + and len(outputText) > 0 and hasattr(model_response.choices[0], "message") - and getattr(model_response.choices[0].message, "tool_calls", None) + and getattr(model_response.choices[0].message, "tool_calls", None) # type: ignore is None ): - model_response.choices[0].message.content = outputText + model_response.choices[0].message.content = outputText # type: ignore elif ( hasattr(model_response.choices[0], "message") - and getattr(model_response.choices[0].message, "tool_calls", None) + and getattr(model_response.choices[0].message, "tool_calls", None) # type: ignore is not None ): pass else: raise Exception() - except: + except Exception as e: raise BedrockError( - message=json.dumps(outputText), status_code=response.status_code + message="Error parsing received text={}.\nError-{}".format( + outputText, str(e) + ), + status_code=response.status_code, ) if stream and provider == "ai21": @@ -593,8 +599,8 @@ def process_response( streaming_choice = litellm.utils.StreamingChoices() streaming_choice.index = model_response.choices[0].index delta_obj = litellm.utils.Delta( - 
content=getattr(model_response.choices[0].message, "content", None), - role=model_response.choices[0].message.role, + content=getattr(model_response.choices[0].message, "content", None), # type: ignore + role=model_response.choices[0].message.role, # type: ignore ) streaming_choice.delta = delta_obj streaming_model_response.choices = [streaming_choice] @@ -730,7 +736,7 @@ def completion( ) ### SET RUNTIME ENDPOINT ### - endpoint_url, proxy_endpoint_url = get_runtime_endpoint( + endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint( api_base=api_base, aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, aws_region_name=aws_region_name, @@ -1001,7 +1007,7 @@ def completion( response.raise_for_status() except httpx.HTTPStatusError as err: error_code = err.response.status_code - raise BedrockError(status_code=error_code, message=response.text) + raise BedrockError(status_code=error_code, message=err.response.text) except httpx.TimeoutException as e: raise BedrockError(status_code=408, message="Timeout error occurred.") @@ -1112,182 +1118,6 @@ def embedding(self, *args, **kwargs): return super().embedding(*args, **kwargs) -class AmazonConverseConfig: - """ - Reference - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html - #2 - https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#conversation-inference-supported-models-features - """ - - maxTokens: Optional[int] - stopSequences: Optional[List[str]] - temperature: Optional[int] - topP: Optional[int] - - def __init__( - self, - maxTokens: Optional[int] = None, - stopSequences: Optional[List[str]] = None, - temperature: Optional[int] = None, - topP: Optional[int] = None, - ) -> None: - locals_ = locals() - for key, value in locals_.items(): - if key != "self" and value is not None: - setattr(self.__class__, key, value) - - @classmethod - def get_config(cls): - return { - k: v - for k, v in cls.__dict__.items() - if not k.startswith("__") - and not isinstance( - v, - ( - types.FunctionType, - types.BuiltinFunctionType, - classmethod, - staticmethod, - ), - ) - and v is not None - } - - def get_supported_openai_params(self, model: str) -> List[str]: - supported_params = [ - "max_tokens", - "stream", - "stream_options", - "stop", - "temperature", - "top_p", - "extra_headers", - "response_format", - ] - - if ( - model.startswith("anthropic") - or model.startswith("mistral") - or model.startswith("cohere") - or model.startswith("meta.llama3-1") - ): - supported_params.append("tools") - - if model.startswith("anthropic") or model.startswith("mistral"): - # only anthropic and mistral support tool choice config. otherwise (E.g. cohere) will fail the call - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html - supported_params.append("tool_choice") - - return supported_params - - def map_tool_choice_values( - self, model: str, tool_choice: Union[str, dict], drop_params: bool - ) -> Optional[ToolChoiceValuesBlock]: - if tool_choice == "none": - if litellm.drop_params is True or drop_params is True: - return None - else: - raise litellm.utils.UnsupportedParamsError( - message="Bedrock doesn't support tool_choice={}. 
To drop it from the call, set `litellm.drop_params = True.".format( - tool_choice - ), - status_code=400, - ) - elif tool_choice == "required": - return ToolChoiceValuesBlock(any={}) - elif tool_choice == "auto": - return ToolChoiceValuesBlock(auto={}) - elif isinstance(tool_choice, dict): - # only supported for anthropic + mistral models - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html - specific_tool = SpecificToolChoiceBlock( - name=tool_choice.get("function", {}).get("name", "") - ) - return ToolChoiceValuesBlock(tool=specific_tool) - else: - raise litellm.utils.UnsupportedParamsError( - message="Bedrock doesn't support tool_choice={}. Supported tool_choice values=['auto', 'required', json object]. To drop it from the call, set `litellm.drop_params = True.".format( - tool_choice - ), - status_code=400, - ) - - def get_supported_image_types(self) -> List[str]: - return ["png", "jpeg", "gif", "webp"] - - def map_openai_params( - self, - model: str, - non_default_params: dict, - optional_params: dict, - drop_params: bool, - ) -> dict: - for param, value in non_default_params.items(): - if param == "response_format": - json_schema: Optional[dict] = None - schema_name: str = "" - if "response_schema" in value: - json_schema = value["response_schema"] - schema_name = "json_tool_call" - elif "json_schema" in value: - json_schema = value["json_schema"]["schema"] - schema_name = value["json_schema"]["name"] - """ - Follow similar approach to anthropic - translate to a single tool call. - - When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode - - You usually want to provide a single tool - - You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool - - Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model’s perspective. - """ - if json_schema is not None: - _tool_choice = self.map_tool_choice_values( - model=model, tool_choice="required", drop_params=drop_params # type: ignore - ) - - _tool = ChatCompletionToolParam( - type="function", - function=ChatCompletionToolParamFunctionChunk( - name=schema_name, parameters=json_schema - ), - ) - - optional_params["tools"] = [_tool] - optional_params["tool_choice"] = _tool_choice - optional_params["json_mode"] = True - else: - if litellm.drop_params is True or drop_params is True: - pass - else: - raise litellm.utils.UnsupportedParamsError( - message="Bedrock doesn't support response_format={}. 
To drop it from the call, set `litellm.drop_params = True.".format( - value - ), - status_code=400, - ) - if param == "max_tokens": - optional_params["maxTokens"] = value - if param == "stream": - optional_params["stream"] = value - if param == "stop": - if isinstance(value, str): - if len(value) == 0: # converse raises error for empty strings - continue - value = [value] - optional_params["stopSequences"] = value - if param == "temperature": - optional_params["temperature"] = value - if param == "top_p": - optional_params["topP"] = value - if param == "tools": - optional_params["tools"] = value - if param == "tool_choice": - _tool_choice_value = self.map_tool_choice_values( - model=model, tool_choice=value, drop_params=drop_params # type: ignore - ) - if _tool_choice_value is not None: - optional_params["tool_choice"] = _tool_choice_value - return optional_params - - class BedrockConverseLLM(BaseAWSLLM): def __init__(self) -> None: super().__init__() @@ -1612,7 +1442,7 @@ def completion( ) ### SET RUNTIME ENDPOINT ### - endpoint_url, proxy_endpoint_url = get_runtime_endpoint( + endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint( api_base=api_base, aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, aws_region_name=aws_region_name, @@ -1811,7 +1641,7 @@ def completion( response.raise_for_status() except httpx.HTTPStatusError as err: error_code = err.response.status_code - raise BedrockError(status_code=error_code, message=response.text) + raise BedrockError(status_code=error_code, message=err.response.text) except httpx.TimeoutException: raise BedrockError(status_code=408, message="Timeout error occurred.") @@ -1845,24 +1675,6 @@ def get_response_stream_shape(): return _response_stream_shape_cache -def get_bedrock_tool_name(response_tool_name: str) -> str: - """ - If litellm formatted the input tool name, we need to convert it back to the original name. - - Args: - response_tool_name (str): The name of the tool as received from the response. - - Returns: - str: The original name of the tool. 
- """ - - if response_tool_name in litellm.bedrock_tool_name_mappings.cache_dict: - response_tool_name = litellm.bedrock_tool_name_mappings.cache_dict[ - response_tool_name - ] - return response_tool_name - - class AWSEventStreamDecoder: def __init__(self, model: str) -> None: from botocore.parsers import EventStreamJSONParser diff --git a/litellm/llms/bedrock/chat/converse.py b/litellm/llms/bedrock/chat/converse.py new file mode 100644 index 000000000000..23f1932277cc --- /dev/null +++ b/litellm/llms/bedrock/chat/converse.py @@ -0,0 +1,228 @@ +import json +from typing import Callable, Optional, Union + +import httpx + +import litellm +from litellm.llms.custom_httpx.http_handler import ( + AsyncHTTPHandler, + HTTPHandler, + get_async_httpx_client, +) +from litellm.types.utils import ModelResponse +from litellm.utils import CustomStreamWrapper, get_secret + +from ...base_aws_llm import BaseAWSLLM +from ..common_utils import BedrockError + +# class FakeBedrockConverseLLM(BaseAWSLLM): +# def __init__(self) -> None: +# super().__init__() + +# async def async_completion( +# self, +# model: str, +# messages: list, +# api_base: str, +# model_response: ModelResponse, +# print_verbose: Callable, +# data: str, +# timeout: Optional[Union[float, httpx.Timeout]], +# encoding, +# logging_obj, +# stream, +# optional_params: dict, +# litellm_params=None, +# logger_fn=None, +# headers={}, +# client: Optional[AsyncHTTPHandler] = None, +# ) -> Union[ModelResponse, CustomStreamWrapper]: +# if client is None or not isinstance(client, AsyncHTTPHandler): +# _params = {} +# if timeout is not None: +# if isinstance(timeout, float) or isinstance(timeout, int): +# timeout = httpx.Timeout(timeout) +# _params["timeout"] = timeout +# client = get_async_httpx_client( +# params=_params, llm_provider=litellm.LlmProviders.BEDROCK +# ) +# else: +# client = client # type: ignore + +# try: +# response = await client.post(url=api_base, headers=headers, data=data) # type: ignore +# response.raise_for_status() +# except httpx.HTTPStatusError as err: +# error_code = err.response.status_code +# raise BedrockError(status_code=error_code, message=err.response.text) +# except httpx.TimeoutException as e: +# raise BedrockError(status_code=408, message="Timeout error occurred.") + +# return litellm.AmazonConverseConfig()._transform_response( +# model=model, +# response=response, +# model_response=model_response, +# stream=stream if isinstance(stream, bool) else False, +# logging_obj=logging_obj, +# api_key="", +# data=data, +# messages=messages, +# print_verbose=print_verbose, +# optional_params=optional_params, +# encoding=encoding, +# ) + +# def completion( +# self, +# model: str, +# messages: list, +# api_base: Optional[str], +# custom_prompt_dict: dict, +# model_response: ModelResponse, +# print_verbose: Callable, +# encoding, +# logging_obj, +# optional_params: dict, +# acompletion: bool, +# timeout: Optional[Union[float, httpx.Timeout]], +# litellm_params: dict, +# logger_fn=None, +# extra_headers: Optional[dict] = None, +# client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None, +# ): +# try: +# import boto3 +# from botocore.auth import SigV4Auth +# from botocore.awsrequest import AWSRequest +# from botocore.credentials import Credentials +# except ImportError: +# raise ImportError("Missing boto3 to call bedrock. 
Run 'pip install boto3'.") + +# ## SETUP ## +# stream = optional_params.pop("stream", None) +# modelId = optional_params.pop("model_id", None) +# if modelId is not None: +# modelId = self.encode_model_id(model_id=modelId) +# else: +# modelId = model + +# provider = model.split(".")[0] + +# # CREDENTIALS +# # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them +# aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) +# aws_access_key_id = optional_params.pop("aws_access_key_id", None) +# aws_session_token = optional_params.pop("aws_session_token", None) +# aws_region_name = optional_params.pop("aws_region_name", None) +# aws_role_name = optional_params.pop("aws_role_name", None) +# aws_session_name = optional_params.pop("aws_session_name", None) +# aws_profile_name = optional_params.pop("aws_profile_name", None) +# aws_bedrock_runtime_endpoint = optional_params.pop( +# "aws_bedrock_runtime_endpoint", None +# ) # https://bedrock-runtime.{region_name}.amazonaws.com +# aws_web_identity_token = optional_params.pop("aws_web_identity_token", None) +# aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None) + +# # if aws_region_name is None: +# # # check env # +# # litellm_aws_region_name = get_secret("AWS_REGION_NAME", None) + +# # if litellm_aws_region_name is not None and isinstance( +# # litellm_aws_region_name, str +# # ): +# # aws_region_name = litellm_aws_region_name + +# # standard_aws_region_name = get_secret("AWS_REGION", None) +# # if standard_aws_region_name is not None and isinstance( +# # standard_aws_region_name, str +# # ): +# # aws_region_name = standard_aws_region_name + +# # if aws_region_name is None: +# # aws_region_name = "us-west-2" + +# if aws_region_name is None: +# aws_region_name = "us-west-2" + +# credentials = super().get_credentials( +# aws_access_key_id=aws_access_key_id, +# aws_secret_access_key=aws_secret_access_key, +# aws_session_token=aws_session_token, +# aws_region_name=aws_region_name, +# aws_session_name=aws_session_name, +# aws_profile_name=aws_profile_name, +# aws_role_name=aws_role_name, +# aws_web_identity_token=aws_web_identity_token, +# aws_sts_endpoint=aws_sts_endpoint, +# ) +# # sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name) + +# ### SET RUNTIME ENDPOINT ### +# endpoint_url, proxy_endpoint_url = super().get_runtime_endpoint( +# api_base=api_base, +# aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, +# aws_region_name=aws_region_name, +# ) + +# if (stream is not None and stream is True) and provider != "ai21": +# endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream" +# proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse-stream" +# else: +# endpoint_url = f"{endpoint_url}/model/{modelId}/converse" +# proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse" + +# _data = litellm.AmazonConverseConfig()._transform_request( +# model=model, +# messages=messages, +# optional_params=optional_params, +# litellm_params=litellm_params, +# ) + +# data = json.dumps(_data) + +# ## COMPLETION CALL + +# # headers = {"Content-Type": "application/json"} +# # if extra_headers is not None: +# # headers = {"Content-Type": "application/json", **extra_headers} +# # print(f"endpoint_url: {endpoint_url}, data: {data}, headers: {headers}") +# # request = AWSRequest( +# # method="POST", url=endpoint_url, data=data, headers=headers +# # ) +# # sigv4.add_auth(request) + +# # if ( +# # extra_headers is not None and "Authorization" 
in extra_headers +# # ): # prevent sigv4 from overwriting the auth header +# # request.headers["Authorization"] = extra_headers["Authorization"] +# # prepped = request.prepare() + +# # ## LOGGING +# # logging_obj.pre_call( +# # input=messages, +# # api_key="", +# # additional_args={ +# # "complete_input_dict": data, +# # "api_base": proxy_endpoint_url, +# # "headers": prepped.headers, +# # }, +# # ) + +# return self.async_completion( +# model=model, +# messages=messages, +# data=data, +# api_base=proxy_endpoint_url, +# model_response=model_response, +# print_verbose=print_verbose, +# encoding=encoding, +# logging_obj=logging_obj, +# optional_params=optional_params, +# stream=False, +# litellm_params=litellm_params, +# logger_fn=logger_fn, +# # headers=prepped.headers, +# headers={"Authorization": "my-fake-key"}, +# timeout=timeout, +# client=client, +# ) # type: ignore diff --git a/litellm/llms/bedrock/chat/converse_transformation.py b/litellm/llms/bedrock/chat/converse_transformation.py new file mode 100644 index 000000000000..5f6141cb4e62 --- /dev/null +++ b/litellm/llms/bedrock/chat/converse_transformation.py @@ -0,0 +1,431 @@ +""" +Translating between OpenAI's `/chat/completion` format and Amazon's `/converse` format +""" + +import copy +import time +import types +from typing import List, Optional, Union + +import httpx + +import litellm +from litellm.litellm_core_utils.core_helpers import map_finish_reason +from litellm.litellm_core_utils.litellm_logging import Logging +from litellm.types.llms.bedrock import * +from litellm.types.llms.openai import ( + AllMessageValues, + ChatCompletionResponseMessage, + ChatCompletionToolCallChunk, + ChatCompletionToolCallFunctionChunk, + ChatCompletionToolParam, + ChatCompletionToolParamFunctionChunk, +) +from litellm.types.utils import ModelResponse, Usage +from litellm.utils import CustomStreamWrapper + +from ...prompt_templates.factory import _bedrock_converse_messages_pt, _bedrock_tools_pt +from ..common_utils import BedrockError, get_bedrock_tool_name + + +class AmazonConverseConfig: + """ + Reference - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html + #2 - https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#conversation-inference-supported-models-features + """ + + maxTokens: Optional[int] + stopSequences: Optional[List[str]] + temperature: Optional[int] + topP: Optional[int] + + def __init__( + self, + maxTokens: Optional[int] = None, + stopSequences: Optional[List[str]] = None, + temperature: Optional[int] = None, + topP: Optional[int] = None, + ) -> None: + locals_ = locals() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } + + def get_supported_openai_params(self, model: str) -> List[str]: + supported_params = [ + "max_tokens", + "stream", + "stream_options", + "stop", + "temperature", + "top_p", + "extra_headers", + "response_format", + ] + + if ( + model.startswith("anthropic") + or model.startswith("mistral") + or model.startswith("cohere") + or model.startswith("meta.llama3-1") + ): + supported_params.append("tools") + + if model.startswith("anthropic") or model.startswith("mistral"): + # only anthropic and mistral support tool choice config. 
otherwise (E.g. cohere) will fail the call - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html + supported_params.append("tool_choice") + + return supported_params + + def map_tool_choice_values( + self, model: str, tool_choice: Union[str, dict], drop_params: bool + ) -> Optional[ToolChoiceValuesBlock]: + if tool_choice == "none": + if litellm.drop_params is True or drop_params is True: + return None + else: + raise litellm.utils.UnsupportedParamsError( + message="Bedrock doesn't support tool_choice={}. To drop it from the call, set `litellm.drop_params = True.".format( + tool_choice + ), + status_code=400, + ) + elif tool_choice == "required": + return ToolChoiceValuesBlock(any={}) + elif tool_choice == "auto": + return ToolChoiceValuesBlock(auto={}) + elif isinstance(tool_choice, dict): + # only supported for anthropic + mistral models - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html + specific_tool = SpecificToolChoiceBlock( + name=tool_choice.get("function", {}).get("name", "") + ) + return ToolChoiceValuesBlock(tool=specific_tool) + else: + raise litellm.utils.UnsupportedParamsError( + message="Bedrock doesn't support tool_choice={}. Supported tool_choice values=['auto', 'required', json object]. To drop it from the call, set `litellm.drop_params = True.".format( + tool_choice + ), + status_code=400, + ) + + def get_supported_image_types(self) -> List[str]: + return ["png", "jpeg", "gif", "webp"] + + def map_openai_params( + self, + model: str, + non_default_params: dict, + optional_params: dict, + drop_params: bool, + ) -> dict: + for param, value in non_default_params.items(): + if param == "response_format": + json_schema: Optional[dict] = None + schema_name: str = "" + if "response_schema" in value: + json_schema = value["response_schema"] + schema_name = "json_tool_call" + elif "json_schema" in value: + json_schema = value["json_schema"]["schema"] + schema_name = value["json_schema"]["name"] + """ + Follow similar approach to anthropic - translate to a single tool call. + + When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode + - You usually want to provide a single tool + - You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool + - Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model’s perspective. + """ + if json_schema is not None: + _tool_choice = self.map_tool_choice_values( + model=model, tool_choice="required", drop_params=drop_params # type: ignore + ) + + _tool = ChatCompletionToolParam( + type="function", + function=ChatCompletionToolParamFunctionChunk( + name=schema_name, parameters=json_schema + ), + ) + + optional_params["tools"] = [_tool] + optional_params["tool_choice"] = _tool_choice + optional_params["json_mode"] = True + else: + if litellm.drop_params is True or drop_params is True: + pass + else: + raise litellm.utils.UnsupportedParamsError( + message="Bedrock doesn't support response_format={}. 
To drop it from the call, set `litellm.drop_params = True.".format( + value + ), + status_code=400, + ) + if param == "max_tokens": + optional_params["maxTokens"] = value + if param == "stream": + optional_params["stream"] = value + if param == "stop": + if isinstance(value, str): + if len(value) == 0: # converse raises error for empty strings + continue + value = [value] + optional_params["stopSequences"] = value + if param == "temperature": + optional_params["temperature"] = value + if param == "top_p": + optional_params["topP"] = value + if param == "tools": + optional_params["tools"] = value + if param == "tool_choice": + _tool_choice_value = self.map_tool_choice_values( + model=model, tool_choice=value, drop_params=drop_params # type: ignore + ) + if _tool_choice_value is not None: + optional_params["tool_choice"] = _tool_choice_value + return optional_params + + def _transform_request( + self, + model: str, + messages: List[AllMessageValues], + optional_params: dict, + litellm_params: dict, + ) -> RequestObject: + system_prompt_indices = [] + system_content_blocks: List[SystemContentBlock] = [] + for idx, message in enumerate(messages): + if message["role"] == "system": + _system_content_block: Optional[SystemContentBlock] = None + if isinstance(message["content"], str) and len(message["content"]) > 0: + _system_content_block = SystemContentBlock(text=message["content"]) + elif isinstance(message["content"], list): + for m in message["content"]: + if m.get("type", "") == "text" and len(m["text"]) > 0: + _system_content_block = SystemContentBlock(text=m["text"]) + if _system_content_block is not None: + system_content_blocks.append(_system_content_block) + system_prompt_indices.append(idx) + if len(system_prompt_indices) > 0: + for idx in reversed(system_prompt_indices): + messages.pop(idx) + + inference_params = copy.deepcopy(optional_params) + additional_request_keys = [] + additional_request_params = {} + supported_converse_params = AmazonConverseConfig.__annotations__.keys() + supported_tool_call_params = ["tools", "tool_choice"] + supported_guardrail_params = ["guardrailConfig"] + json_mode: Optional[bool] = inference_params.pop( + "json_mode", None + ) # used for handling json_schema + ## TRANSFORMATION ## + + bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt( + messages=messages, + model=model, + llm_provider="bedrock_converse", + user_continue_message=litellm_params.pop("user_continue_message", None), + ) + + # send all model-specific params in 'additional_request_params' + for k, v in inference_params.items(): + if ( + k not in supported_converse_params + and k not in supported_tool_call_params + and k not in supported_guardrail_params + ): + additional_request_params[k] = v + additional_request_keys.append(k) + for key in additional_request_keys: + inference_params.pop(key, None) + + bedrock_tools: List[ToolBlock] = _bedrock_tools_pt( + inference_params.pop("tools", []) + ) + bedrock_tool_config: Optional[ToolConfigBlock] = None + if len(bedrock_tools) > 0: + tool_choice_values: ToolChoiceValuesBlock = inference_params.pop( + "tool_choice", None + ) + bedrock_tool_config = ToolConfigBlock( + tools=bedrock_tools, + ) + if tool_choice_values is not None: + bedrock_tool_config["toolChoice"] = tool_choice_values + + _data: RequestObject = { + "messages": bedrock_messages, + "additionalModelRequestFields": additional_request_params, + "system": system_content_blocks, + "inferenceConfig": InferenceConfig(**inference_params), + } + + # Guardrail Config + 
guardrail_config: Optional[GuardrailConfigBlock] = None + request_guardrails_config = inference_params.pop("guardrailConfig", None) + if request_guardrails_config is not None: + guardrail_config = GuardrailConfigBlock(**request_guardrails_config) + _data["guardrailConfig"] = guardrail_config + + # Tool Config + if bedrock_tool_config is not None: + _data["toolConfig"] = bedrock_tool_config + + return _data + + def _transform_response( + self, + model: str, + response: httpx.Response, + model_response: ModelResponse, + stream: bool, + logging_obj: Optional[Logging], + optional_params: dict, + api_key: str, + data: Union[dict, str], + messages: List, + print_verbose, + encoding, + ) -> Union[ModelResponse, CustomStreamWrapper]: + + ## LOGGING + if logging_obj is not None: + logging_obj.post_call( + input=messages, + api_key=api_key, + original_response=response.text, + additional_args={"complete_input_dict": data}, + ) + print_verbose(f"raw model_response: {response.text}") + json_mode: Optional[bool] = optional_params.pop("json_mode", None) + ## RESPONSE OBJECT + try: + completion_response = ConverseResponseBlock(**response.json()) # type: ignore + except Exception as e: + raise BedrockError( + message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format( + response.text, str(e) + ), + status_code=422, + ) + + """ + Bedrock Response Object has optional message block + + completion_response["output"].get("message", None) + + A message block looks like this (Example 1): + "output": { + "message": { + "role": "assistant", + "content": [ + { + "text": "Is there anything else you'd like to talk about? Perhaps I can help with some economic questions or provide some information about economic concepts?" 
+ } + ] + } + }, + (Example 2): + "output": { + "message": { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "tooluse_hbTgdi0CSLq_hM4P8csZJA", + "name": "top_song", + "input": { + "sign": "WZPZ" + } + } + } + ] + } + } + + """ + message: Optional[MessageBlock] = completion_response["output"]["message"] + chat_completion_message: ChatCompletionResponseMessage = {"role": "assistant"} + content_str = "" + tools: List[ChatCompletionToolCallChunk] = [] + if message is not None: + for idx, content in enumerate(message["content"]): + """ + - Content is either a tool response or text + """ + if "text" in content: + content_str += content["text"] + if "toolUse" in content: + + ## check tool name was formatted by litellm + _response_tool_name = content["toolUse"]["name"] + response_tool_name = get_bedrock_tool_name( + response_tool_name=_response_tool_name + ) + _function_chunk = ChatCompletionToolCallFunctionChunk( + name=response_tool_name, + arguments=json.dumps(content["toolUse"]["input"]), + ) + + _tool_response_chunk = ChatCompletionToolCallChunk( + id=content["toolUse"]["toolUseId"], + type="function", + function=_function_chunk, + index=idx, + ) + tools.append(_tool_response_chunk) + chat_completion_message["content"] = content_str + + if json_mode is True and tools is not None and len(tools) == 1: + # to support 'json_schema' logic on bedrock models + json_mode_content_str: Optional[str] = tools[0]["function"].get("arguments") + if json_mode_content_str is not None: + chat_completion_message["content"] = json_mode_content_str + else: + chat_completion_message["tool_calls"] = tools + + ## CALCULATING USAGE - bedrock returns usage in the headers + input_tokens = completion_response["usage"]["inputTokens"] + output_tokens = completion_response["usage"]["outputTokens"] + total_tokens = completion_response["usage"]["totalTokens"] + + model_response.choices = [ + litellm.Choices( + finish_reason=map_finish_reason(completion_response["stopReason"]), + index=0, + message=litellm.Message(**chat_completion_message), + ) + ] + model_response.created = int(time.time()) + model_response.model = model + usage = Usage( + prompt_tokens=input_tokens, + completion_tokens=output_tokens, + total_tokens=total_tokens, + ) + setattr(model_response, "usage", usage) + + # Add "trace" from Bedrock guardrails - if user has opted in to returning it + if "trace" in completion_response: + setattr(model_response, "trace", completion_response["trace"]) + + return model_response diff --git a/litellm/llms/bedrock/common_utils.py b/litellm/llms/bedrock/common_utils.py index 25379474ef50..37ff8cbb7798 100644 --- a/litellm/llms/bedrock/common_utils.py +++ b/litellm/llms/bedrock/common_utils.py @@ -5,7 +5,7 @@ import os import types from enum import Enum -from typing import List, Optional, Union, Tuple +from typing import List, Optional, Tuple, Union import httpx @@ -575,7 +575,7 @@ def init_bedrock_client( # Iterate over parameters and update if needed for i, param in enumerate(params_to_check): if param and param.startswith("os.environ/"): - params_to_check[i] = get_secret(param) + params_to_check[i] = get_secret(param) # type: ignore # Assign updated values back to parameters ( aws_access_key_id, @@ -618,13 +618,13 @@ def init_bedrock_client( import boto3 if isinstance(timeout, float): - config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout) + config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout) # type: ignore elif isinstance(timeout, httpx.Timeout): - 
config = boto3.session.Config( + config = boto3.session.Config( # type: ignore connect_timeout=timeout.connect, read_timeout=timeout.read ) else: - config = boto3.session.Config() + config = boto3.session.Config() # type: ignore ### CHECK STS ### if ( @@ -725,40 +725,6 @@ def init_bedrock_client( return client -def get_runtime_endpoint( - api_base: Optional[str], - aws_bedrock_runtime_endpoint: Optional[str], - aws_region_name: str, -) -> Tuple[str, str]: - env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT") - if api_base is not None: - endpoint_url = api_base - elif aws_bedrock_runtime_endpoint is not None and isinstance( - aws_bedrock_runtime_endpoint, str - ): - endpoint_url = aws_bedrock_runtime_endpoint - elif env_aws_bedrock_runtime_endpoint and isinstance( - env_aws_bedrock_runtime_endpoint, str - ): - endpoint_url = env_aws_bedrock_runtime_endpoint - else: - endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com" - - # Determine proxy_endpoint_url - if env_aws_bedrock_runtime_endpoint and isinstance( - env_aws_bedrock_runtime_endpoint, str - ): - proxy_endpoint_url = env_aws_bedrock_runtime_endpoint - elif aws_bedrock_runtime_endpoint is not None and isinstance( - aws_bedrock_runtime_endpoint, str - ): - proxy_endpoint_url = aws_bedrock_runtime_endpoint - else: - proxy_endpoint_url = endpoint_url - - return endpoint_url, proxy_endpoint_url - - class ModelResponseIterator: def __init__(self, model_response): self.model_response = model_response @@ -783,3 +749,21 @@ async def __anext__(self): raise StopAsyncIteration self.is_done = True return self.model_response + + +def get_bedrock_tool_name(response_tool_name: str) -> str: + """ + If litellm formatted the input tool name, we need to convert it back to the original name. + + Args: + response_tool_name (str): The name of the tool as received from the response. + + Returns: + str: The original name of the tool. 
+ """ + + if response_tool_name in litellm.bedrock_tool_name_mappings.cache_dict: + response_tool_name = litellm.bedrock_tool_name_mappings.cache_dict[ + response_tool_name + ] + return response_tool_name diff --git a/litellm/llms/bedrock/embed/embedding.py b/litellm/llms/bedrock/embed/embedding.py index 7d2e441da146..e25fa0dc0841 100644 --- a/litellm/llms/bedrock/embed/embedding.py +++ b/litellm/llms/bedrock/embed/embedding.py @@ -23,7 +23,7 @@ from litellm.types.utils import Embedding, EmbeddingResponse, Usage from ...base_aws_llm import BaseAWSLLM -from ..common_utils import BedrockError, get_runtime_endpoint +from ..common_utils import BedrockError from .amazon_titan_g1_transformation import AmazonTitanG1Config from .amazon_titan_multimodal_transformation import ( AmazonTitanMultimodalEmbeddingG1Config, @@ -141,7 +141,7 @@ async def _make_async_call( response.raise_for_status() except httpx.HTTPStatusError as err: error_code = err.response.status_code - raise BedrockError(status_code=error_code, message=response.text) + raise BedrockError(status_code=error_code, message=err.response.text) except httpx.TimeoutException: raise BedrockError(status_code=408, message="Timeout error occurred.") @@ -197,7 +197,7 @@ def _single_func_embeddings( client=client, timeout=timeout, api_base=prepped.url, - headers=prepped.headers, + headers=dict(prepped.headers), data=data, ) @@ -288,7 +288,7 @@ async def _async_single_func_embeddings( client=client, timeout=timeout, api_base=prepped.url, - headers=prepped.headers, + headers=dict(prepped.headers), data=data, ) @@ -342,8 +342,8 @@ def embeddings( timeout: Optional[Union[float, httpx.Timeout]], aembedding: Optional[bool], extra_headers: Optional[dict], - optional_params=None, - litellm_params=None, + optional_params: dict, + litellm_params: dict, ) -> EmbeddingResponse: try: import boto3 @@ -392,10 +392,21 @@ def embeddings( transformed_request = AmazonTitanV2Config()._transform_request( input=i, inference_params=inference_params ) + else: + raise Exception( + "Unmapped model. Received={}. 
Expected={}".format( + model, + [ + "amazon.titan-embed-image-v1", + "amazon.titan-embed-text-v1", + "amazon.titan-embed-text-v2:0", + ], + ) + ) batch_data.append(transformed_request) ### SET RUNTIME ENDPOINT ### - endpoint_url, proxy_endpoint_url = get_runtime_endpoint( + endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint( api_base=api_base, aws_bedrock_runtime_endpoint=optional_params.pop( "aws_bedrock_runtime_endpoint", None @@ -467,170 +478,5 @@ def embeddings( aembedding=aembedding, timeout=timeout, client=client, - headers=prepped.headers, + headers=dict(prepped.headers), ) - - # def _embedding_func_single( - # model: str, - # input: str, - # client: Any, - # optional_params=None, - # encoding=None, - # logging_obj=None, - # ): - # if isinstance(input, str) is False: - # raise BedrockError( - # message="Bedrock Embedding API input must be type str | List[str]", - # status_code=400, - # ) - # # logic for parsing in - calling - parsing out model embedding calls - # ## FORMAT EMBEDDING INPUT ## - # provider = model.split(".")[0] - # inference_params = copy.deepcopy(optional_params) - # inference_params.pop( - # "user", None - # ) # make sure user is not passed in for bedrock call - # modelId = ( - # optional_params.pop("model_id", None) or model - # ) # default to model if not passed - # if provider == "amazon": - # input = input.replace(os.linesep, " ") - # data = {"inputText": input, **inference_params} - # # data = json.dumps(data) - # elif provider == "cohere": - # inference_params["input_type"] = inference_params.get( - # "input_type", "search_document" - # ) # aws bedrock example default - https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/providers?model=cohere.embed-english-v3 - # data = {"texts": [input], **inference_params} # type: ignore - # body = json.dumps(data).encode("utf-8") # type: ignore - # ## LOGGING - # request_str = f""" - # response = client.invoke_model( - # body={body}, - # modelId={modelId}, - # accept="*/*", - # contentType="application/json", - # )""" # type: ignore - # logging_obj.pre_call( - # input=input, - # api_key="", # boto3 is used for init. 
- # additional_args={ - # "complete_input_dict": {"model": modelId, "texts": input}, - # "request_str": request_str, - # }, - # ) - # try: - # response = client.invoke_model( - # body=body, - # modelId=modelId, - # accept="*/*", - # contentType="application/json", - # ) - # response_body = json.loads(response.get("body").read()) - # ## LOGGING - # logging_obj.post_call( - # input=input, - # api_key="", - # additional_args={"complete_input_dict": data}, - # original_response=json.dumps(response_body), - # ) - # if provider == "cohere": - # response = response_body.get("embeddings") - # # flatten list - # response = [item for sublist in response for item in sublist] - # return response - # elif provider == "amazon": - # return response_body.get("embedding") - # except Exception as e: - # raise BedrockError( - # message=f"Embedding Error with model {model}: {e}", status_code=500 - # ) - - # def embedding( - # model: str, - # input: Union[list, str], - # model_response: litellm.EmbeddingResponse, - # api_key: Optional[str] = None, - # logging_obj=None, - # optional_params=None, - # encoding=None, - # ): - # ### BOTO3 INIT ### - # # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them - # aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) - # aws_access_key_id = optional_params.pop("aws_access_key_id", None) - # aws_region_name = optional_params.pop("aws_region_name", None) - # aws_role_name = optional_params.pop("aws_role_name", None) - # aws_session_name = optional_params.pop("aws_session_name", None) - # aws_bedrock_runtime_endpoint = optional_params.pop( - # "aws_bedrock_runtime_endpoint", None - # ) - # aws_web_identity_token = optional_params.pop("aws_web_identity_token", None) - - # # use passed in BedrockRuntime.Client if provided, otherwise create a new one - # client = init_bedrock_client( - # aws_access_key_id=aws_access_key_id, - # aws_secret_access_key=aws_secret_access_key, - # aws_region_name=aws_region_name, - # aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, - # aws_web_identity_token=aws_web_identity_token, - # aws_role_name=aws_role_name, - # aws_session_name=aws_session_name, - # ) - # if isinstance(input, str): - # ## Embedding Call - # embeddings = [ - # _embedding_func_single( - # model, - # input, - # optional_params=optional_params, - # client=client, - # logging_obj=logging_obj, - # ) - # ] - # elif isinstance(input, list): - # ## Embedding Call - assuming this is a List[str] - # embeddings = [ - # _embedding_func_single( - # model, - # i, - # optional_params=optional_params, - # client=client, - # logging_obj=logging_obj, - # ) - # for i in input - # ] # [TODO]: make these parallel calls - # else: - # # enters this branch if input = int, ex. 
input=2 - # raise BedrockError( - # message="Bedrock Embedding API input must be type str | List[str]", - # status_code=400, - # ) - - # ## Populate OpenAI compliant dictionary - # embedding_response = [] - # for idx, embedding in enumerate(embeddings): - # embedding_response.append( - # { - # "object": "embedding", - # "index": idx, - # "embedding": embedding, - # } - # ) - # model_response.object = "list" - # model_response.data = embedding_response - # model_response.model = model - # input_tokens = 0 - - # input_str = "".join(input) - - # input_tokens += len(encoding.encode(input_str)) - - # usage = Usage( - # prompt_tokens=input_tokens, - # completion_tokens=0, - # total_tokens=input_tokens + 0, - # ) - # model_response.usage = usage - - # return model_response diff --git a/litellm/main.py b/litellm/main.py index a50c908c673a..3f1d57f6e25c 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -2367,6 +2367,7 @@ def completion( ) if model in litellm.BEDROCK_CONVERSE_MODELS: + response = bedrock_converse_chat_completion.completion( model=model, messages=messages, diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index c940e744fdaa..4533bf9114c0 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -1,7 +1,20 @@ model_list: - - model_name: gpt-3.5-turbo + - model_name: fake-claude-endpoint + litellm_params: + model: anthropic.claude-3-sonnet-20240229-v1:0 + api_base: https://exampleopenaiendpoint-production.up.railway.app + # aws_session_token: "IQoJb3JpZ2luX2VjELj//////////wEaCXVzLXdlc3QtMiJHMEUCIQDatCRVkIZERLcrR6P7Qd1vNfZ8r8xB/LUeaVaTW/lBTwIgAgmHSBe41d65GVRKSkpgVonjsCmOmAS7s/yklM9NsZcq3AEI4P//////////ARABGgw4ODg2MDIyMjM0MjgiDJrio0/CHYEfyt5EqyqwAfyWO4t3bFVWAOIwTyZ1N6lszeJKfMNus2hzVc+r73hia2Anv88uwPxNg2uqnXQNJumEo0DcBt30ZwOw03Isboy0d5l05h8gjb4nl9feyeKmKAnRdcqElrEWtCC1Qcefv78jQv53AbUipH1ssa5NPvptqZZpZYDPMlBEnV3YdvJJiuE23u2yOkCt+EoUJLaOYjOryoRyrSfbWB+JaUsB68R3rNTHzReeN3Nob/9Ic4HrMMmzmLcGOpgBZxclO4w8Z7i6TcVqbCwDOskxuR6bZaiFxKFG+9tDrWS7jaQKpq/YP9HUT0YwYpZplaBEEZR5sbIndg5yb4dRZrSHplblqKz8XLaUf5tuuyRJmwr96PTpw/dyEVk9gicFX6JfLBEv0v5rN2Z0JMFLdfIP4kC1U2PjcPOWoglWO3fLmJ4Lol2a3c5XDSMwMxjcJXq+c8Ue1v0=" + aws_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY + aws_access_key_id: os.environ/AWS_ACCESS_KEY_ID + - model_name: gemini-vision + litellm_params: + model: vertex_ai/gemini-1.0-pro-vision-001 + api_base: https://exampleopenaiendpoint-production.up.railway.app/v1/projects/adroit-crow-413218/locations/us-central1/publishers/google/models/gemini-1.0-pro-vision-001 + vertex_project: "adroit-crow-413218" + vertex_location: "us-central1" + + - model_name: fake-openai-endpoint litellm_params: model: gpt-3.5-turbo - -router_settings: - model_group_alias: {"gpt-4": {"model": "gpt-3.5-turbo", "hidden": false}} \ No newline at end of file + api_base: https://exampleopenaiendpoint-production.up.railway.app + \ No newline at end of file From 13dfe1f2035793364c3feef2aba26d469650dc94 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 14 Sep 2024 17:45:18 -0700 Subject: [PATCH 2/7] fix(base_aws_llm.py): fix credential caching check to see if token is set --- litellm/llms/base_aws_llm.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/litellm/llms/base_aws_llm.py b/litellm/llms/base_aws_llm.py index 25429c4db6f8..c42d847c351c 100644 --- a/litellm/llms/base_aws_llm.py +++ b/litellm/llms/base_aws_llm.py @@ -243,7 +243,10 @@ def get_credentials( credentials = session.get_credentials() - 
self.iam_cache.set_cache(cache_key, credentials, ttl=3600 - 60) + if ( + credentials.token is None + ): # don't cache if session token exists. The expiry time for that is not known. + self.iam_cache.set_cache(cache_key, credentials, ttl=3600 - 60) return credentials From de9bd9de61131bd1d60b967e3ec60ccaa71dc9e1 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 14 Sep 2024 18:18:21 -0700 Subject: [PATCH 3/7] refactor(bedrock/chat): separate converse api and invoke api + isolate converse api transformation logic Make it easier to see how requests are transformed for /converse --- litellm/__init__.py | 2 +- litellm/llms/bedrock/chat/__init__.py | 2 + litellm/llms/bedrock/chat/converse.py | 228 ------- litellm/llms/bedrock/chat/converse_handler.py | 407 ++++++++++++ .../{chat.py => chat/invoke_handler.py} | 606 +----------------- .../google_ai_studio_endpoints.py | 22 +- litellm/tests/test_bedrock_completion.py | 2 +- litellm/tests/test_secret_manager.py | 2 +- 8 files changed, 431 insertions(+), 840 deletions(-) create mode 100644 litellm/llms/bedrock/chat/__init__.py delete mode 100644 litellm/llms/bedrock/chat/converse.py create mode 100644 litellm/llms/bedrock/chat/converse_handler.py rename litellm/llms/bedrock/{chat.py => chat/invoke_handler.py} (69%) diff --git a/litellm/__init__.py b/litellm/__init__.py index 047927dd9fdc..b884dbc91207 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -911,7 +911,7 @@ class LlmProviders(str, Enum): from .llms.ollama import OllamaConfig from .llms.ollama_chat import OllamaChatConfig from .llms.maritalk import MaritTalkConfig -from .llms.bedrock.chat import ( +from .llms.bedrock.chat.invoke_handler import ( AmazonCohereChatConfig, AmazonConverseConfig, BEDROCK_CONVERSE_MODELS, diff --git a/litellm/llms/bedrock/chat/__init__.py b/litellm/llms/bedrock/chat/__init__.py new file mode 100644 index 000000000000..c3f6aef6d238 --- /dev/null +++ b/litellm/llms/bedrock/chat/__init__.py @@ -0,0 +1,2 @@ +from .converse_handler import BedrockConverseLLM +from .invoke_handler import BedrockLLM diff --git a/litellm/llms/bedrock/chat/converse.py b/litellm/llms/bedrock/chat/converse.py deleted file mode 100644 index 23f1932277cc..000000000000 --- a/litellm/llms/bedrock/chat/converse.py +++ /dev/null @@ -1,228 +0,0 @@ -import json -from typing import Callable, Optional, Union - -import httpx - -import litellm -from litellm.llms.custom_httpx.http_handler import ( - AsyncHTTPHandler, - HTTPHandler, - get_async_httpx_client, -) -from litellm.types.utils import ModelResponse -from litellm.utils import CustomStreamWrapper, get_secret - -from ...base_aws_llm import BaseAWSLLM -from ..common_utils import BedrockError - -# class FakeBedrockConverseLLM(BaseAWSLLM): -# def __init__(self) -> None: -# super().__init__() - -# async def async_completion( -# self, -# model: str, -# messages: list, -# api_base: str, -# model_response: ModelResponse, -# print_verbose: Callable, -# data: str, -# timeout: Optional[Union[float, httpx.Timeout]], -# encoding, -# logging_obj, -# stream, -# optional_params: dict, -# litellm_params=None, -# logger_fn=None, -# headers={}, -# client: Optional[AsyncHTTPHandler] = None, -# ) -> Union[ModelResponse, CustomStreamWrapper]: -# if client is None or not isinstance(client, AsyncHTTPHandler): -# _params = {} -# if timeout is not None: -# if isinstance(timeout, float) or isinstance(timeout, int): -# timeout = httpx.Timeout(timeout) -# _params["timeout"] = timeout -# client = get_async_httpx_client( -# params=_params, 
llm_provider=litellm.LlmProviders.BEDROCK -# ) -# else: -# client = client # type: ignore - -# try: -# response = await client.post(url=api_base, headers=headers, data=data) # type: ignore -# response.raise_for_status() -# except httpx.HTTPStatusError as err: -# error_code = err.response.status_code -# raise BedrockError(status_code=error_code, message=err.response.text) -# except httpx.TimeoutException as e: -# raise BedrockError(status_code=408, message="Timeout error occurred.") - -# return litellm.AmazonConverseConfig()._transform_response( -# model=model, -# response=response, -# model_response=model_response, -# stream=stream if isinstance(stream, bool) else False, -# logging_obj=logging_obj, -# api_key="", -# data=data, -# messages=messages, -# print_verbose=print_verbose, -# optional_params=optional_params, -# encoding=encoding, -# ) - -# def completion( -# self, -# model: str, -# messages: list, -# api_base: Optional[str], -# custom_prompt_dict: dict, -# model_response: ModelResponse, -# print_verbose: Callable, -# encoding, -# logging_obj, -# optional_params: dict, -# acompletion: bool, -# timeout: Optional[Union[float, httpx.Timeout]], -# litellm_params: dict, -# logger_fn=None, -# extra_headers: Optional[dict] = None, -# client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None, -# ): -# try: -# import boto3 -# from botocore.auth import SigV4Auth -# from botocore.awsrequest import AWSRequest -# from botocore.credentials import Credentials -# except ImportError: -# raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.") - -# ## SETUP ## -# stream = optional_params.pop("stream", None) -# modelId = optional_params.pop("model_id", None) -# if modelId is not None: -# modelId = self.encode_model_id(model_id=modelId) -# else: -# modelId = model - -# provider = model.split(".")[0] - -# # CREDENTIALS -# # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them -# aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) -# aws_access_key_id = optional_params.pop("aws_access_key_id", None) -# aws_session_token = optional_params.pop("aws_session_token", None) -# aws_region_name = optional_params.pop("aws_region_name", None) -# aws_role_name = optional_params.pop("aws_role_name", None) -# aws_session_name = optional_params.pop("aws_session_name", None) -# aws_profile_name = optional_params.pop("aws_profile_name", None) -# aws_bedrock_runtime_endpoint = optional_params.pop( -# "aws_bedrock_runtime_endpoint", None -# ) # https://bedrock-runtime.{region_name}.amazonaws.com -# aws_web_identity_token = optional_params.pop("aws_web_identity_token", None) -# aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None) - -# # if aws_region_name is None: -# # # check env # -# # litellm_aws_region_name = get_secret("AWS_REGION_NAME", None) - -# # if litellm_aws_region_name is not None and isinstance( -# # litellm_aws_region_name, str -# # ): -# # aws_region_name = litellm_aws_region_name - -# # standard_aws_region_name = get_secret("AWS_REGION", None) -# # if standard_aws_region_name is not None and isinstance( -# # standard_aws_region_name, str -# # ): -# # aws_region_name = standard_aws_region_name - -# # if aws_region_name is None: -# # aws_region_name = "us-west-2" - -# if aws_region_name is None: -# aws_region_name = "us-west-2" - -# credentials = super().get_credentials( -# aws_access_key_id=aws_access_key_id, -# aws_secret_access_key=aws_secret_access_key, -# 
aws_session_token=aws_session_token, -# aws_region_name=aws_region_name, -# aws_session_name=aws_session_name, -# aws_profile_name=aws_profile_name, -# aws_role_name=aws_role_name, -# aws_web_identity_token=aws_web_identity_token, -# aws_sts_endpoint=aws_sts_endpoint, -# ) -# # sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name) - -# ### SET RUNTIME ENDPOINT ### -# endpoint_url, proxy_endpoint_url = super().get_runtime_endpoint( -# api_base=api_base, -# aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, -# aws_region_name=aws_region_name, -# ) - -# if (stream is not None and stream is True) and provider != "ai21": -# endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream" -# proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse-stream" -# else: -# endpoint_url = f"{endpoint_url}/model/{modelId}/converse" -# proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse" - -# _data = litellm.AmazonConverseConfig()._transform_request( -# model=model, -# messages=messages, -# optional_params=optional_params, -# litellm_params=litellm_params, -# ) - -# data = json.dumps(_data) - -# ## COMPLETION CALL - -# # headers = {"Content-Type": "application/json"} -# # if extra_headers is not None: -# # headers = {"Content-Type": "application/json", **extra_headers} -# # print(f"endpoint_url: {endpoint_url}, data: {data}, headers: {headers}") -# # request = AWSRequest( -# # method="POST", url=endpoint_url, data=data, headers=headers -# # ) -# # sigv4.add_auth(request) - -# # if ( -# # extra_headers is not None and "Authorization" in extra_headers -# # ): # prevent sigv4 from overwriting the auth header -# # request.headers["Authorization"] = extra_headers["Authorization"] -# # prepped = request.prepare() - -# # ## LOGGING -# # logging_obj.pre_call( -# # input=messages, -# # api_key="", -# # additional_args={ -# # "complete_input_dict": data, -# # "api_base": proxy_endpoint_url, -# # "headers": prepped.headers, -# # }, -# # ) - -# return self.async_completion( -# model=model, -# messages=messages, -# data=data, -# api_base=proxy_endpoint_url, -# model_response=model_response, -# print_verbose=print_verbose, -# encoding=encoding, -# logging_obj=logging_obj, -# optional_params=optional_params, -# stream=False, -# litellm_params=litellm_params, -# logger_fn=logger_fn, -# # headers=prepped.headers, -# headers={"Authorization": "my-fake-key"}, -# timeout=timeout, -# client=client, -# ) # type: ignore diff --git a/litellm/llms/bedrock/chat/converse_handler.py b/litellm/llms/bedrock/chat/converse_handler.py new file mode 100644 index 000000000000..4b14f1016b5c --- /dev/null +++ b/litellm/llms/bedrock/chat/converse_handler.py @@ -0,0 +1,407 @@ +import json +import urllib +from typing import Any, Callable, Optional, Union + +import httpx + +import litellm +from litellm.llms.custom_httpx.http_handler import ( + AsyncHTTPHandler, + HTTPHandler, + _get_httpx_client, + get_async_httpx_client, +) +from litellm.types.utils import ModelResponse +from litellm.utils import CustomStreamWrapper, get_secret + +from ...base_aws_llm import BaseAWSLLM +from ..common_utils import BedrockError +from .invoke_handler import AWSEventStreamDecoder, MockResponseIterator, make_call + + +def make_sync_call( + client: Optional[HTTPHandler], + api_base: str, + headers: dict, + data: str, + model: str, + messages: list, + logging_obj, +): + if client is None: + client = _get_httpx_client() # Create a new client if none provided + + response = client.post( + api_base, + headers=headers, + 
data=data, + stream=True if "ai21" not in api_base else False, + ) + + if response.status_code != 200: + raise BedrockError(status_code=response.status_code, message=response.read()) + + if "ai21" in api_base: + aws_bedrock_process_response = BedrockConverseLLM() + model_response: ModelResponse = aws_bedrock_process_response.process_response( + model=model, + response=response, + model_response=litellm.ModelResponse(), + stream=True, + logging_obj=logging_obj, + optional_params={}, + api_key="", + data=data, + messages=messages, + print_verbose=litellm.print_verbose, + encoding=litellm.encoding, + ) # type: ignore + completion_stream: Any = MockResponseIterator(model_response=model_response) + else: + decoder = AWSEventStreamDecoder(model=model) + completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024)) + + # LOGGING + logging_obj.post_call( + input=messages, + api_key="", + original_response="first stream response received", + additional_args={"complete_input_dict": data}, + ) + + return completion_stream + + +class BedrockConverseLLM(BaseAWSLLM): + def __init__(self) -> None: + super().__init__() + + def encode_model_id(self, model_id: str) -> str: + """ + Double encode the model ID to ensure it matches the expected double-encoded format. + Args: + model_id (str): The model ID to encode. + Returns: + str: The double-encoded model ID. + """ + return urllib.parse.quote(model_id, safe="") # type: ignore + + async def async_streaming( + self, + model: str, + messages: list, + api_base: str, + model_response: ModelResponse, + print_verbose: Callable, + data: str, + timeout: Optional[Union[float, httpx.Timeout]], + encoding, + logging_obj, + stream, + optional_params: dict, + litellm_params=None, + logger_fn=None, + headers={}, + client: Optional[AsyncHTTPHandler] = None, + ) -> CustomStreamWrapper: + + completion_stream = await make_call( + client=client, + api_base=api_base, + headers=headers, + data=data, + model=model, + messages=messages, + logging_obj=logging_obj, + ) + streaming_response = CustomStreamWrapper( + completion_stream=completion_stream, + model=model, + custom_llm_provider="bedrock", + logging_obj=logging_obj, + ) + return streaming_response + + async def async_completion( + self, + model: str, + messages: list, + api_base: str, + model_response: ModelResponse, + print_verbose: Callable, + data: str, + timeout: Optional[Union[float, httpx.Timeout]], + encoding, + logging_obj, + stream, + optional_params: dict, + litellm_params=None, + logger_fn=None, + headers={}, + client: Optional[AsyncHTTPHandler] = None, + ) -> Union[ModelResponse, CustomStreamWrapper]: + if client is None or not isinstance(client, AsyncHTTPHandler): + _params = {} + if timeout is not None: + if isinstance(timeout, float) or isinstance(timeout, int): + timeout = httpx.Timeout(timeout) + _params["timeout"] = timeout + client = get_async_httpx_client( + params=_params, llm_provider=litellm.LlmProviders.BEDROCK + ) + else: + client = client # type: ignore + + try: + response = await client.post(url=api_base, headers=headers, data=data) # type: ignore + response.raise_for_status() + except httpx.HTTPStatusError as err: + error_code = err.response.status_code + raise BedrockError(status_code=error_code, message=err.response.text) + except httpx.TimeoutException as e: + raise BedrockError(status_code=408, message="Timeout error occurred.") + + return litellm.AmazonConverseConfig()._transform_response( + model=model, + response=response, + model_response=model_response, + stream=stream if 
isinstance(stream, bool) else False, + logging_obj=logging_obj, + api_key="", + data=data, + messages=messages, + print_verbose=print_verbose, + optional_params=optional_params, + encoding=encoding, + ) + + def completion( + self, + model: str, + messages: list, + api_base: Optional[str], + custom_prompt_dict: dict, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + logging_obj, + optional_params: dict, + acompletion: bool, + timeout: Optional[Union[float, httpx.Timeout]], + litellm_params: dict, + logger_fn=None, + extra_headers: Optional[dict] = None, + client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None, + ): + try: + import boto3 + from botocore.auth import SigV4Auth + from botocore.awsrequest import AWSRequest + from botocore.credentials import Credentials + except ImportError: + raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.") + + ## SETUP ## + stream = optional_params.pop("stream", None) + modelId = optional_params.pop("model_id", None) + if modelId is not None: + modelId = self.encode_model_id(model_id=modelId) + else: + modelId = model + + provider = model.split(".")[0] + + ## CREDENTIALS ## + # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them + aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) + aws_access_key_id = optional_params.pop("aws_access_key_id", None) + aws_session_token = optional_params.pop("aws_session_token", None) + aws_region_name = optional_params.pop("aws_region_name", None) + aws_role_name = optional_params.pop("aws_role_name", None) + aws_session_name = optional_params.pop("aws_session_name", None) + aws_profile_name = optional_params.pop("aws_profile_name", None) + aws_bedrock_runtime_endpoint = optional_params.pop( + "aws_bedrock_runtime_endpoint", None + ) # https://bedrock-runtime.{region_name}.amazonaws.com + aws_web_identity_token = optional_params.pop("aws_web_identity_token", None) + aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None) + + ### SET REGION NAME ### + if aws_region_name is None: + # check env # + litellm_aws_region_name = get_secret("AWS_REGION_NAME", None) + + if litellm_aws_region_name is not None and isinstance( + litellm_aws_region_name, str + ): + aws_region_name = litellm_aws_region_name + + standard_aws_region_name = get_secret("AWS_REGION", None) + if standard_aws_region_name is not None and isinstance( + standard_aws_region_name, str + ): + aws_region_name = standard_aws_region_name + + if aws_region_name is None: + aws_region_name = "us-west-2" + + credentials: Credentials = self.get_credentials( + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + aws_region_name=aws_region_name, + aws_session_name=aws_session_name, + aws_profile_name=aws_profile_name, + aws_role_name=aws_role_name, + aws_web_identity_token=aws_web_identity_token, + aws_sts_endpoint=aws_sts_endpoint, + ) + + ### SET RUNTIME ENDPOINT ### + endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint( + api_base=api_base, + aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, + aws_region_name=aws_region_name, + ) + if (stream is not None and stream is True) and provider != "ai21": + endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream" + proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse-stream" + else: + endpoint_url = f"{endpoint_url}/model/{modelId}/converse" + proxy_endpoint_url = 
f"{proxy_endpoint_url}/model/{modelId}/converse" + + sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name) + + ## TRANSFORMATION ## + + _data = litellm.AmazonConverseConfig()._transform_request( + model=model, + messages=messages, + optional_params=optional_params, + litellm_params=litellm_params, + ) + data = json.dumps(_data) + ## COMPLETION CALL + + headers = {"Content-Type": "application/json"} + if extra_headers is not None: + headers = {"Content-Type": "application/json", **extra_headers} + request = AWSRequest( + method="POST", url=endpoint_url, data=data, headers=headers + ) + sigv4.add_auth(request) + if ( + extra_headers is not None and "Authorization" in extra_headers + ): # prevent sigv4 from overwriting the auth header + request.headers["Authorization"] = extra_headers["Authorization"] + prepped = request.prepare() + + ## LOGGING + logging_obj.pre_call( + input=messages, + api_key="", + additional_args={ + "complete_input_dict": data, + "api_base": proxy_endpoint_url, + "headers": prepped.headers, + }, + ) + + ### ROUTING (ASYNC, STREAMING, SYNC) + if acompletion: + if isinstance(client, HTTPHandler): + client = None + if stream is True: + return self.async_streaming( + model=model, + messages=messages, + data=data, + api_base=proxy_endpoint_url, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + logging_obj=logging_obj, + optional_params=optional_params, + stream=True, + litellm_params=litellm_params, + logger_fn=logger_fn, + headers=prepped.headers, + timeout=timeout, + client=client, + ) # type: ignore + ### ASYNC COMPLETION + return self.async_completion( + model=model, + messages=messages, + data=data, + api_base=proxy_endpoint_url, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + logging_obj=logging_obj, + optional_params=optional_params, + stream=stream, # type: ignore + litellm_params=litellm_params, + logger_fn=logger_fn, + headers=prepped.headers, + timeout=timeout, + client=client, + ) # type: ignore + + if client is None or isinstance(client, AsyncHTTPHandler): + _params = {} + if timeout is not None: + if isinstance(timeout, float) or isinstance(timeout, int): + timeout = httpx.Timeout(timeout) + _params["timeout"] = timeout + client = _get_httpx_client(_params) # type: ignore + else: + client = client + + if stream is not None and stream is True: + completion_stream = make_sync_call( + client=( + client + if client is not None and isinstance(client, HTTPHandler) + else None + ), + api_base=proxy_endpoint_url, + headers=prepped.headers, # type: ignore + data=data, + model=model, + messages=messages, + logging_obj=logging_obj, + ) + streaming_response = CustomStreamWrapper( + completion_stream=completion_stream, + model=model, + custom_llm_provider="bedrock", + logging_obj=logging_obj, + ) + + return streaming_response + + ### COMPLETION + + try: + response = client.post(url=proxy_endpoint_url, headers=prepped.headers, data=data) # type: ignore + response.raise_for_status() + except httpx.HTTPStatusError as err: + error_code = err.response.status_code + raise BedrockError(status_code=error_code, message=err.response.text) + except httpx.TimeoutException: + raise BedrockError(status_code=408, message="Timeout error occurred.") + + return litellm.AmazonConverseConfig()._transform_response( + model=model, + response=response, + model_response=model_response, + stream=stream if isinstance(stream, bool) else False, + logging_obj=logging_obj, + api_key="", + data=data, + messages=messages, + 
print_verbose=print_verbose, + optional_params=optional_params, + encoding=encoding, + ) diff --git a/litellm/llms/bedrock/chat.py b/litellm/llms/bedrock/chat/invoke_handler.py similarity index 69% rename from litellm/llms/bedrock/chat.py rename to litellm/llms/bedrock/chat/invoke_handler.py index 8068e5b03d39..fbad0940ce82 100644 --- a/litellm/llms/bedrock/chat.py +++ b/litellm/llms/bedrock/chat/invoke_handler.py @@ -52,8 +52,8 @@ from litellm.types.utils import GenericStreamingChunk as GChunk from litellm.utils import CustomStreamWrapper, ModelResponse, Usage, get_secret -from ..base_aws_llm import BaseAWSLLM -from ..prompt_templates.factory import ( +from ...base_aws_llm import BaseAWSLLM +from ...prompt_templates.factory import ( _bedrock_converse_messages_pt, _bedrock_tools_pt, cohere_message_pt, @@ -64,8 +64,8 @@ parse_xml_params, prompt_factory, ) -from .common_utils import BedrockError, ModelResponseIterator, get_bedrock_tool_name -from .converse_chat.converse_transformation import AmazonConverseConfig +from ..common_utils import BedrockError, ModelResponseIterator, get_bedrock_tool_name +from .converse_transformation import AmazonConverseConfig BEDROCK_CONVERSE_MODELS = [ "anthropic.claude-3-5-sonnet-20240620-v1:0", @@ -225,10 +225,9 @@ async def make_call( raise BedrockError(status_code=response.status_code, message=response.text) if "ai21" in api_base: - aws_bedrock_process_response = BedrockConverseLLM() model_response: ( ModelResponse - ) = aws_bedrock_process_response.process_response( + ) = litellm.AmazonConverseConfig()._transform_response( model=model, response=response, model_response=litellm.ModelResponse(), @@ -266,59 +265,6 @@ async def make_call( raise BedrockError(status_code=500, message=str(e)) -def make_sync_call( - client: Optional[HTTPHandler], - api_base: str, - headers: dict, - data: str, - model: str, - messages: list, - logging_obj, -): - if client is None: - client = _get_httpx_client() # Create a new client if none provided - - response = client.post( - api_base, - headers=headers, - data=data, - stream=True if "ai21" not in api_base else False, - ) - - if response.status_code != 200: - raise BedrockError(status_code=response.status_code, message=response.read()) - - if "ai21" in api_base: - aws_bedrock_process_response = BedrockConverseLLM() - model_response: ModelResponse = aws_bedrock_process_response.process_response( - model=model, - response=response, - model_response=litellm.ModelResponse(), - stream=True, - logging_obj=logging_obj, - optional_params={}, - api_key="", - data=data, - messages=messages, - print_verbose=litellm.print_verbose, - encoding=litellm.encoding, - ) # type: ignore - completion_stream: Any = MockResponseIterator(model_response=model_response) - else: - decoder = AWSEventStreamDecoder(model=model) - completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024)) - - # LOGGING - logging_obj.post_call( - input=messages, - api_key="", - original_response="first stream response received", - additional_args={"complete_input_dict": data}, - ) - - return completion_stream - - class BedrockLLM(BaseAWSLLM): """ Example call @@ -1118,548 +1064,6 @@ def embedding(self, *args, **kwargs): return super().embedding(*args, **kwargs) -class BedrockConverseLLM(BaseAWSLLM): - def __init__(self) -> None: - super().__init__() - - def process_response( - self, - model: str, - response: Union[requests.Response, httpx.Response], - model_response: ModelResponse, - stream: bool, - logging_obj: Optional[Logging], - optional_params: dict, - 
api_key: str, - data: Union[dict, str], - messages: List, - print_verbose, - encoding, - ) -> Union[ModelResponse, CustomStreamWrapper]: - - ## LOGGING - if logging_obj is not None: - logging_obj.post_call( - input=messages, - api_key=api_key, - original_response=response.text, - additional_args={"complete_input_dict": data}, - ) - print_verbose(f"raw model_response: {response.text}") - json_mode: Optional[bool] = optional_params.pop("json_mode", None) - ## RESPONSE OBJECT - try: - completion_response = ConverseResponseBlock(**response.json()) # type: ignore - except Exception as e: - raise BedrockError( - message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format( - response.text, str(e) - ), - status_code=422, - ) - - """ - Bedrock Response Object has optional message block - - completion_response["output"].get("message", None) - - A message block looks like this (Example 1): - "output": { - "message": { - "role": "assistant", - "content": [ - { - "text": "Is there anything else you'd like to talk about? Perhaps I can help with some economic questions or provide some information about economic concepts?" - } - ] - } - }, - (Example 2): - "output": { - "message": { - "role": "assistant", - "content": [ - { - "toolUse": { - "toolUseId": "tooluse_hbTgdi0CSLq_hM4P8csZJA", - "name": "top_song", - "input": { - "sign": "WZPZ" - } - } - } - ] - } - } - - """ - message: Optional[MessageBlock] = completion_response["output"]["message"] - chat_completion_message: ChatCompletionResponseMessage = {"role": "assistant"} - content_str = "" - tools: List[ChatCompletionToolCallChunk] = [] - if message is not None: - for idx, content in enumerate(message["content"]): - """ - - Content is either a tool response or text - """ - if "text" in content: - content_str += content["text"] - if "toolUse" in content: - - ## check tool name was formatted by litellm - _response_tool_name = content["toolUse"]["name"] - response_tool_name = get_bedrock_tool_name( - response_tool_name=_response_tool_name - ) - _function_chunk = ChatCompletionToolCallFunctionChunk( - name=response_tool_name, - arguments=json.dumps(content["toolUse"]["input"]), - ) - - _tool_response_chunk = ChatCompletionToolCallChunk( - id=content["toolUse"]["toolUseId"], - type="function", - function=_function_chunk, - index=idx, - ) - tools.append(_tool_response_chunk) - chat_completion_message["content"] = content_str - - if json_mode is True and tools is not None and len(tools) == 1: - # to support 'json_schema' logic on bedrock models - json_mode_content_str: Optional[str] = tools[0]["function"].get("arguments") - if json_mode_content_str is not None: - chat_completion_message["content"] = json_mode_content_str - else: - chat_completion_message["tool_calls"] = tools - - ## CALCULATING USAGE - bedrock returns usage in the headers - input_tokens = completion_response["usage"]["inputTokens"] - output_tokens = completion_response["usage"]["outputTokens"] - total_tokens = completion_response["usage"]["totalTokens"] - - model_response.choices = [ - litellm.Choices( - finish_reason=map_finish_reason(completion_response["stopReason"]), - index=0, - message=litellm.Message(**chat_completion_message), - ) - ] - model_response.created = int(time.time()) - model_response.model = model - usage = Usage( - prompt_tokens=input_tokens, - completion_tokens=output_tokens, - total_tokens=total_tokens, - ) - setattr(model_response, "usage", usage) - - # Add "trace" from Bedrock 
guardrails - if user has opted in to returning it - if "trace" in completion_response: - setattr(model_response, "trace", completion_response["trace"]) - - return model_response - - def encode_model_id(self, model_id: str) -> str: - """ - Double encode the model ID to ensure it matches the expected double-encoded format. - Args: - model_id (str): The model ID to encode. - Returns: - str: The double-encoded model ID. - """ - return urllib.parse.quote(model_id, safe="") - - async def async_streaming( - self, - model: str, - messages: list, - api_base: str, - model_response: ModelResponse, - print_verbose: Callable, - data: str, - timeout: Optional[Union[float, httpx.Timeout]], - encoding, - logging_obj, - stream, - optional_params: dict, - litellm_params=None, - logger_fn=None, - headers={}, - client: Optional[AsyncHTTPHandler] = None, - ) -> CustomStreamWrapper: - streaming_response = CustomStreamWrapper( - completion_stream=None, - make_call=partial( - make_call, - client=client, - api_base=api_base, - headers=headers, - data=data, - model=model, - messages=messages, - logging_obj=logging_obj, - ), - model=model, - custom_llm_provider="bedrock", - logging_obj=logging_obj, - ) - return streaming_response - - async def async_completion( - self, - model: str, - messages: list, - api_base: str, - model_response: ModelResponse, - print_verbose: Callable, - data: str, - timeout: Optional[Union[float, httpx.Timeout]], - encoding, - logging_obj, - stream, - optional_params: dict, - litellm_params=None, - logger_fn=None, - headers={}, - client: Optional[AsyncHTTPHandler] = None, - ) -> Union[ModelResponse, CustomStreamWrapper]: - if client is None or not isinstance(client, AsyncHTTPHandler): - _params = {} - if timeout is not None: - if isinstance(timeout, float) or isinstance(timeout, int): - timeout = httpx.Timeout(timeout) - _params["timeout"] = timeout - client = get_async_httpx_client( - params=_params, llm_provider=litellm.LlmProviders.BEDROCK - ) - else: - client = client # type: ignore - - try: - response = await client.post(url=api_base, headers=headers, data=data) # type: ignore - response.raise_for_status() - except httpx.HTTPStatusError as err: - error_code = err.response.status_code - raise BedrockError(status_code=error_code, message=err.response.text) - except httpx.TimeoutException as e: - raise BedrockError(status_code=408, message="Timeout error occurred.") - - return self.process_response( - model=model, - response=response, - model_response=model_response, - stream=stream if isinstance(stream, bool) else False, - logging_obj=logging_obj, - api_key="", - data=data, - messages=messages, - print_verbose=print_verbose, - optional_params=optional_params, - encoding=encoding, - ) - - def completion( - self, - model: str, - messages: list, - api_base: Optional[str], - custom_prompt_dict: dict, - model_response: ModelResponse, - print_verbose: Callable, - encoding, - logging_obj, - optional_params: dict, - acompletion: bool, - timeout: Optional[Union[float, httpx.Timeout]], - litellm_params: dict, - logger_fn=None, - extra_headers: Optional[dict] = None, - client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None, - ): - try: - import boto3 - from botocore.auth import SigV4Auth - from botocore.awsrequest import AWSRequest - from botocore.credentials import Credentials - except ImportError: - raise ImportError("Missing boto3 to call bedrock. 
Run 'pip install boto3'.") - - ## SETUP ## - stream = optional_params.pop("stream", None) - modelId = optional_params.pop("model_id", None) - if modelId is not None: - modelId = self.encode_model_id(model_id=modelId) - else: - modelId = model - - provider = model.split(".")[0] - - ## CREDENTIALS ## - # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them - aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) - aws_access_key_id = optional_params.pop("aws_access_key_id", None) - aws_session_token = optional_params.pop("aws_session_token", None) - aws_region_name = optional_params.pop("aws_region_name", None) - aws_role_name = optional_params.pop("aws_role_name", None) - aws_session_name = optional_params.pop("aws_session_name", None) - aws_profile_name = optional_params.pop("aws_profile_name", None) - aws_bedrock_runtime_endpoint = optional_params.pop( - "aws_bedrock_runtime_endpoint", None - ) # https://bedrock-runtime.{region_name}.amazonaws.com - aws_web_identity_token = optional_params.pop("aws_web_identity_token", None) - aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None) - - ### SET REGION NAME ### - if aws_region_name is None: - # check env # - litellm_aws_region_name = get_secret("AWS_REGION_NAME", None) - - if litellm_aws_region_name is not None and isinstance( - litellm_aws_region_name, str - ): - aws_region_name = litellm_aws_region_name - - standard_aws_region_name = get_secret("AWS_REGION", None) - if standard_aws_region_name is not None and isinstance( - standard_aws_region_name, str - ): - aws_region_name = standard_aws_region_name - - if aws_region_name is None: - aws_region_name = "us-west-2" - - credentials: Credentials = self.get_credentials( - aws_access_key_id=aws_access_key_id, - aws_secret_access_key=aws_secret_access_key, - aws_session_token=aws_session_token, - aws_region_name=aws_region_name, - aws_session_name=aws_session_name, - aws_profile_name=aws_profile_name, - aws_role_name=aws_role_name, - aws_web_identity_token=aws_web_identity_token, - aws_sts_endpoint=aws_sts_endpoint, - ) - - ### SET RUNTIME ENDPOINT ### - endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint( - api_base=api_base, - aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, - aws_region_name=aws_region_name, - ) - if (stream is not None and stream is True) and provider != "ai21": - endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream" - proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse-stream" - else: - endpoint_url = f"{endpoint_url}/model/{modelId}/converse" - proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse" - - sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name) - - # Separate system prompt from rest of message - system_prompt_indices = [] - system_content_blocks: List[SystemContentBlock] = [] - for idx, message in enumerate(messages): - if message["role"] == "system": - _system_content_block: Optional[SystemContentBlock] = None - if isinstance(message["content"], str) and len(message["content"]) > 0: - _system_content_block = SystemContentBlock(text=message["content"]) - elif isinstance(message["content"], list): - for m in message["content"]: - if m.get("type", "") == "text" and len(m["text"]) > 0: - _system_content_block = SystemContentBlock(text=m["text"]) - if _system_content_block is not None: - system_content_blocks.append(_system_content_block) - system_prompt_indices.append(idx) - if len(system_prompt_indices) > 0: - 
for idx in reversed(system_prompt_indices): - messages.pop(idx) - - inference_params = copy.deepcopy(optional_params) - additional_request_keys = [] - additional_request_params = {} - supported_converse_params = AmazonConverseConfig.__annotations__.keys() - supported_tool_call_params = ["tools", "tool_choice"] - supported_guardrail_params = ["guardrailConfig"] - json_mode: Optional[bool] = inference_params.pop( - "json_mode", None - ) # used for handling json_schema - ## TRANSFORMATION ## - - bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt( - messages=messages, - model=model, - llm_provider="bedrock_converse", - user_continue_message=litellm_params.pop("user_continue_message", None), - ) - - # send all model-specific params in 'additional_request_params' - for k, v in inference_params.items(): - if ( - k not in supported_converse_params - and k not in supported_tool_call_params - and k not in supported_guardrail_params - ): - additional_request_params[k] = v - additional_request_keys.append(k) - for key in additional_request_keys: - inference_params.pop(key, None) - - bedrock_tools: List[ToolBlock] = _bedrock_tools_pt( - inference_params.pop("tools", []) - ) - bedrock_tool_config: Optional[ToolConfigBlock] = None - if len(bedrock_tools) > 0: - tool_choice_values: ToolChoiceValuesBlock = inference_params.pop( - "tool_choice", None - ) - bedrock_tool_config = ToolConfigBlock( - tools=bedrock_tools, - ) - if tool_choice_values is not None: - bedrock_tool_config["toolChoice"] = tool_choice_values - - _data: RequestObject = { - "messages": bedrock_messages, - "additionalModelRequestFields": additional_request_params, - "system": system_content_blocks, - "inferenceConfig": InferenceConfig(**inference_params), - } - - # Guardrail Config - guardrail_config: Optional[GuardrailConfigBlock] = None - request_guardrails_config = inference_params.pop("guardrailConfig", None) - if request_guardrails_config is not None: - guardrail_config = GuardrailConfigBlock(**request_guardrails_config) - _data["guardrailConfig"] = guardrail_config - - # Tool Config - if bedrock_tool_config is not None: - _data["toolConfig"] = bedrock_tool_config - - data = json.dumps(_data) - ## COMPLETION CALL - - headers = {"Content-Type": "application/json"} - if extra_headers is not None: - headers = {"Content-Type": "application/json", **extra_headers} - request = AWSRequest( - method="POST", url=endpoint_url, data=data, headers=headers - ) - sigv4.add_auth(request) - if ( - extra_headers is not None and "Authorization" in extra_headers - ): # prevent sigv4 from overwriting the auth header - request.headers["Authorization"] = extra_headers["Authorization"] - prepped = request.prepare() - - ## LOGGING - logging_obj.pre_call( - input=messages, - api_key="", - additional_args={ - "complete_input_dict": data, - "api_base": proxy_endpoint_url, - "headers": prepped.headers, - }, - ) - - ### ROUTING (ASYNC, STREAMING, SYNC) - if acompletion: - if isinstance(client, HTTPHandler): - client = None - if stream is True: - return self.async_streaming( - model=model, - messages=messages, - data=data, - api_base=proxy_endpoint_url, - model_response=model_response, - print_verbose=print_verbose, - encoding=encoding, - logging_obj=logging_obj, - optional_params=optional_params, - stream=True, - litellm_params=litellm_params, - logger_fn=logger_fn, - headers=prepped.headers, - timeout=timeout, - client=client, - ) # type: ignore - ### ASYNC COMPLETION - return self.async_completion( - model=model, - messages=messages, - 
data=data, - api_base=proxy_endpoint_url, - model_response=model_response, - print_verbose=print_verbose, - encoding=encoding, - logging_obj=logging_obj, - optional_params=optional_params, - stream=stream, # type: ignore - litellm_params=litellm_params, - logger_fn=logger_fn, - headers=prepped.headers, - timeout=timeout, - client=client, - ) # type: ignore - - if stream is not None and stream is True: - - streaming_response = CustomStreamWrapper( - completion_stream=None, - make_call=partial( - make_sync_call, - client=None, - api_base=proxy_endpoint_url, - headers=prepped.headers, # type: ignore - data=data, - model=model, - messages=messages, - logging_obj=logging_obj, - ), - model=model, - custom_llm_provider="bedrock", - logging_obj=logging_obj, - ) - - return streaming_response - ### COMPLETION - - if client is None or isinstance(client, AsyncHTTPHandler): - _params = {} - if timeout is not None: - if isinstance(timeout, float) or isinstance(timeout, int): - timeout = httpx.Timeout(timeout) - _params["timeout"] = timeout - client = _get_httpx_client(_params) # type: ignore - else: - client = client - try: - response = client.post(url=proxy_endpoint_url, headers=prepped.headers, data=data) # type: ignore - response.raise_for_status() - except httpx.HTTPStatusError as err: - error_code = err.response.status_code - raise BedrockError(status_code=error_code, message=err.response.text) - except httpx.TimeoutException: - raise BedrockError(status_code=408, message="Timeout error occurred.") - - return self.process_response( - model=model, - response=response, - model_response=model_response, - stream=stream if isinstance(stream, bool) else False, - logging_obj=logging_obj, - optional_params=optional_params, - api_key="", - data=data, - messages=messages, - print_verbose=print_verbose, - encoding=encoding, - ) - - def get_response_stream_shape(): global _response_stream_shape_cache if _response_stream_shape_cache is None: diff --git a/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py b/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py index a272af46b3ec..c4c71820b931 100644 --- a/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py +++ b/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py @@ -80,7 +80,13 @@ async def gemini_proxy_route( updated_url = base_url.copy_with(path=encoded_endpoint) # Add or update query parameters - gemini_api_key = litellm.utils.get_secret(secret_name="GEMINI_API_KEY") + gemini_api_key: Optional[str] = litellm.utils.get_secret( # type: ignore + secret_name="GEMINI_API_KEY" + ) + if gemini_api_key is None: + raise Exception( + "Required 'GEMINI_API_KEY' in environment to make pass-through calls to Google AI Studio." 
+ ) # Merge query parameters, giving precedence to those in updated_url merged_params = dict(request.query_params) merged_params.update({"key": gemini_api_key}) @@ -99,8 +105,8 @@ async def gemini_proxy_route( request, fastapi_response, user_api_key_dict, - query_params=merged_params, - stream=is_streaming_request, + query_params=merged_params, # type: ignore + stream=is_streaming_request, # type: ignore ) return received_value @@ -142,7 +148,7 @@ async def cohere_proxy_route( request, fastapi_response, user_api_key_dict, - stream=is_streaming_request, + stream=is_streaming_request, # type: ignore ) return received_value @@ -208,15 +214,15 @@ async def bedrock_proxy_route( endpoint_func = create_pass_through_route( endpoint=endpoint, target=str(prepped.url), - custom_headers=prepped.headers, + custom_headers=dict(prepped.headers), ) # dynamically construct pass-through endpoint based on incoming path received_value = await endpoint_func( request, fastapi_response, user_api_key_dict, - stream=is_streaming_request, - custom_body=data, - query_params={}, + stream=is_streaming_request, # type: ignore + custom_body=data, # type: ignore + query_params={}, # type: ignore ) return received_value diff --git a/litellm/tests/test_bedrock_completion.py b/litellm/tests/test_bedrock_completion.py index e949d1ee7d2c..f80ace9fa4f5 100644 --- a/litellm/tests/test_bedrock_completion.py +++ b/litellm/tests/test_bedrock_completion.py @@ -27,7 +27,7 @@ completion_cost, embedding, ) -from litellm.llms.bedrock.chat import BedrockLLM, ToolBlock +from litellm.llms.bedrock.chat.invoke_handler import BedrockLLM, ToolBlock from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.llms.prompt_templates.factory import _bedrock_tools_pt diff --git a/litellm/tests/test_secret_manager.py b/litellm/tests/test_secret_manager.py index 397128ecb060..0ca3a6cff22d 100644 --- a/litellm/tests/test_secret_manager.py +++ b/litellm/tests/test_secret_manager.py @@ -17,7 +17,7 @@ import pytest from litellm.llms.AzureOpenAI.azure import get_azure_ad_token_from_oidc -from litellm.llms.bedrock.chat import BedrockConverseLLM, BedrockLLM +from litellm.llms.bedrock.chat.invoke_handler import BedrockConverseLLM, BedrockLLM from litellm.secret_managers.aws_secret_manager import load_aws_secret_manager from litellm.secret_managers.main import get_secret From edcda3170c9d5466c971c54f19a5b6c3c24bc6df Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 14 Sep 2024 18:21:42 -0700 Subject: [PATCH 4/7] fix: fix imports --- litellm/main.py | 2 +- litellm/tests/test_bedrock_completion.py | 2 +- litellm/tests/test_secret_manager.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/litellm/main.py b/litellm/main.py index 3f1d57f6e25c..53c60213b9cd 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -3552,7 +3552,7 @@ def embedding( client=client, timeout=timeout, aembedding=aembedding, - litellm_params=litellm_params, + litellm_params={}, api_base=api_base, print_verbose=print_verbose, extra_headers=extra_headers, diff --git a/litellm/tests/test_bedrock_completion.py b/litellm/tests/test_bedrock_completion.py index f80ace9fa4f5..d14363983e14 100644 --- a/litellm/tests/test_bedrock_completion.py +++ b/litellm/tests/test_bedrock_completion.py @@ -27,7 +27,7 @@ completion_cost, embedding, ) -from litellm.llms.bedrock.chat.invoke_handler import BedrockLLM, ToolBlock +from litellm.llms.bedrock.chat import BedrockLLM from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler 
from litellm.llms.prompt_templates.factory import _bedrock_tools_pt diff --git a/litellm/tests/test_secret_manager.py b/litellm/tests/test_secret_manager.py index 0ca3a6cff22d..397128ecb060 100644 --- a/litellm/tests/test_secret_manager.py +++ b/litellm/tests/test_secret_manager.py @@ -17,7 +17,7 @@ import pytest from litellm.llms.AzureOpenAI.azure import get_azure_ad_token_from_oidc -from litellm.llms.bedrock.chat.invoke_handler import BedrockConverseLLM, BedrockLLM +from litellm.llms.bedrock.chat import BedrockConverseLLM, BedrockLLM from litellm.secret_managers.aws_secret_manager import load_aws_secret_manager from litellm.secret_managers.main import get_secret From 3507c2698c5674cc1e5c99fa9290e043f4438382 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 14 Sep 2024 18:49:33 -0700 Subject: [PATCH 5/7] fix(bedrock/embed): fix reordering of headers --- litellm/llms/bedrock/embed/embedding.py | 7 ++++--- .../vertex_ai_endpoints/google_ai_studio_endpoints.py | 2 +- litellm/utils.py | 4 +--- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/litellm/llms/bedrock/embed/embedding.py b/litellm/llms/bedrock/embed/embedding.py index e25fa0dc0841..8203eb6e6789 100644 --- a/litellm/llms/bedrock/embed/embedding.py +++ b/litellm/llms/bedrock/embed/embedding.py @@ -197,7 +197,7 @@ def _single_func_embeddings( client=client, timeout=timeout, api_base=prepped.url, - headers=dict(prepped.headers), + headers=prepped.headers, # type: ignore data=data, ) @@ -288,7 +288,7 @@ async def _async_single_func_embeddings( client=client, timeout=timeout, api_base=prepped.url, - headers=dict(prepped.headers), + headers=prepped.headers, # type: ignore data=data, ) @@ -454,6 +454,7 @@ def embeddings( headers = {"Content-Type": "application/json"} if extra_headers is not None: headers = {"Content-Type": "application/json", **extra_headers} + request = AWSRequest( method="POST", url=endpoint_url, data=json.dumps(data), headers=headers ) @@ -478,5 +479,5 @@ def embeddings( aembedding=aembedding, timeout=timeout, client=client, - headers=dict(prepped.headers), + headers=prepped.headers, # type: ignore ) diff --git a/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py b/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py index c4c71820b931..6051af37c8be 100644 --- a/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py +++ b/litellm/proxy/vertex_ai_endpoints/google_ai_studio_endpoints.py @@ -214,7 +214,7 @@ async def bedrock_proxy_route( endpoint_func = create_pass_through_route( endpoint=endpoint, target=str(prepped.url), - custom_headers=dict(prepped.headers), + custom_headers=prepped.headers, # type: ignore ) # dynamically construct pass-through endpoint based on incoming path received_value = await endpoint_func( request, diff --git a/litellm/utils.py b/litellm/utils.py index d3e757ae814e..907839d51ed8 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -8153,9 +8153,7 @@ def exception_type( exception_mapping_worked = True if hasattr(original_exception, "request"): raise APIConnectionError( - message="{}\n{}".format( - str(original_exception), traceback.format_exc() - ), + message="{} - {}".format(exception_provider, error_str), llm_provider=custom_llm_provider, model=model, request=original_exception.request, From 1a17a8ddaa057ad8d94c67395e09be0e9232bc62 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 14 Sep 2024 19:52:26 -0700 Subject: [PATCH 6/7] fix(base_aws_llm.py): fix get credential logic --- litellm/llms/base_aws_llm.py | 38 
++++++++++++++++++++---- litellm/tests/test_bedrock_completion.py | 38 ++++++++++++++++++++++++ 2 files changed, 71 insertions(+), 5 deletions(-) diff --git a/litellm/llms/base_aws_llm.py b/litellm/llms/base_aws_llm.py index c42d847c351c..348a84180059 100644 --- a/litellm/llms/base_aws_llm.py +++ b/litellm/llms/base_aws_llm.py @@ -1,5 +1,6 @@ import hashlib import json +import os from typing import Dict, List, Optional, Tuple import httpx @@ -52,14 +53,22 @@ def get_credentials( """ Return a boto3.Credentials object """ - args = locals() - args.pop("self") - cache_key = self.get_cache_key(args) import boto3 from botocore.credentials import Credentials ## CHECK IS 'os.environ/' passed in + param_names = [ + "aws_access_key_id", + "aws_secret_access_key", + "aws_session_token", + "aws_region_name", + "aws_session_name", + "aws_profile_name", + "aws_role_name", + "aws_web_identity_token", + "aws_sts_endpoint", + ] params_to_check: List[Optional[str]] = [ aws_access_key_id, aws_secret_access_key, @@ -78,6 +87,11 @@ def get_credentials( _v = get_secret(param) if _v is not None and isinstance(_v, str): params_to_check[i] = _v + elif param is None: # check if uppercase value in env + key = param_names[i] + if key.upper() in os.environ: + params_to_check[i] = os.getenv(key) + # Assign updated values back to parameters ( aws_access_key_id, @@ -91,6 +105,10 @@ def get_credentials( aws_sts_endpoint, ) = params_to_check + # create cache key for non-expiring auth flows + args = {k: v for k, v in locals().items() if k.startswith("aws_")} + cache_key = self.get_cache_key(args) + verbose_logger.debug( "in get credentials\n" "aws_access_key_id=%s\n" @@ -226,8 +244,11 @@ def get_credentials( ) return credentials - else: - + elif ( + aws_access_key_id is not None + and aws_secret_access_key is not None + and aws_region_name is not None + ): # Check if credentials are already in cache. These credentials have no expiry time. cached_credentials: Optional[Credentials] = self.iam_cache.get_cache( cache_key @@ -249,6 +270,13 @@ def get_credentials( self.iam_cache.set_cache(cache_key, credentials, ttl=3600 - 60) return credentials + else: + # check env var. Do not cache the response from this. + session = boto3.Session() + + credentials = session.get_credentials() + + return credentials def get_runtime_endpoint( self, diff --git a/litellm/tests/test_bedrock_completion.py b/litellm/tests/test_bedrock_completion.py index d14363983e14..e786e4fc8e8c 100644 --- a/litellm/tests/test_bedrock_completion.py +++ b/litellm/tests/test_bedrock_completion.py @@ -1287,3 +1287,41 @@ def test_bedrock_converse_translation_tool_message(): ], } ] + + +def test_base_aws_llm_get_credentials(): + import time + + import boto3 + + from litellm.llms.base_aws_llm import BaseAWSLLM + + start_time = time.time() + session = boto3.Session( + aws_access_key_id="test", + aws_secret_access_key="test2", + region_name="test3", + ) + credentials = session.get_credentials().get_frozen_credentials() + end_time = time.time() + + print( + "Total time for credentials - {}. Credentials - {}".format( + end_time - start_time, credentials + ) + ) + + start_time = time.time() + credentials = BaseAWSLLM().get_credentials( + aws_access_key_id="test", + aws_secret_access_key="test2", + aws_region_name="test3", + ) + + end_time = time.time() + + print( + "Total time for credentials - {}. 
Credentials - {}".format( + end_time - start_time, credentials.get_frozen_credentials() + ) + ) From 9af8f7c90c6507c6498b9324fe466c50f3e7940c Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 14 Sep 2024 20:17:16 -0700 Subject: [PATCH 7/7] fix(converse_handler.py): fix ai21 streaming response --- litellm/llms/bedrock/chat/converse_handler.py | 5 +++-- litellm/tests/test_streaming.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/litellm/llms/bedrock/chat/converse_handler.py b/litellm/llms/bedrock/chat/converse_handler.py index 4b14f1016b5c..caf6113ac4a1 100644 --- a/litellm/llms/bedrock/chat/converse_handler.py +++ b/litellm/llms/bedrock/chat/converse_handler.py @@ -42,8 +42,9 @@ def make_sync_call( raise BedrockError(status_code=response.status_code, message=response.read()) if "ai21" in api_base: - aws_bedrock_process_response = BedrockConverseLLM() - model_response: ModelResponse = aws_bedrock_process_response.process_response( + model_response: ( + ModelResponse + ) = litellm.AmazonConverseConfig()._transform_response( model=model, response=response, model_response=litellm.ModelResponse(), diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py index 2c4190382bb5..1debf817170b 100644 --- a/litellm/tests/test_streaming.py +++ b/litellm/tests/test_streaming.py @@ -1454,7 +1454,7 @@ async def test_bedrock_httpx_streaming(sync_mode, model, region): has_finish_reason = True break complete_response += chunk - if has_finish_reason == False: + if has_finish_reason is False: raise Exception("finish reason not set") if complete_response.strip() == "": raise Exception("Empty response received")
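
For reference, a minimal sketch of the cached credential lookup after this series, assuming litellm with these patches applied and the same dummy static keys used in test_base_aws_llm_get_credentials (static keys need no AWS call, so the timing difference on the second lookup comes from skipping boto3.Session setup):

    import time

    from litellm.llms.base_aws_llm import BaseAWSLLM

    llm = BaseAWSLLM()

    for attempt in range(2):
        start = time.time()
        # First call builds a boto3.Session() and stores the resulting
        # Credentials object in iam_cache under a hash of the aws_* args
        # (ttl=3600 - 60); the second call should return the cached object.
        credentials = llm.get_credentials(
            aws_access_key_id="test",
            aws_secret_access_key="test2",
            aws_region_name="test3",
        )
        print(attempt, time.time() - start, credentials.get_frozen_credentials())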