diff --git a/llama-index-integrations/llms/llama-index-llms-bedrock-converse/llama_index/llms/bedrock_converse/base.py b/llama-index-integrations/llms/llama-index-llms-bedrock-converse/llama_index/llms/bedrock_converse/base.py index 2daa04611c14a..e61e9b4e9bc7b 100644 --- a/llama-index-integrations/llms/llama-index-llms-bedrock-converse/llama_index/llms/bedrock_converse/base.py +++ b/llama-index-integrations/llms/llama-index-llms-bedrock-converse/llama_index/llms/bedrock_converse/base.py @@ -325,6 +325,10 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse: response ) + dict_response = dict(response) + # Add Bedrock's token count to usage dict to match OpenAI's format + dict_response["usage"] = self._get_response_token_counts(dict_response) + return ChatResponse( message=ChatMessage( role=MessageRole.ASSISTANT, @@ -335,7 +339,7 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse: "status": status, }, ), - raw=dict(response), + raw=dict_response, additional_kwargs=self._get_response_token_counts(dict(response)), ) @@ -367,6 +371,10 @@ def stream_chat( **all_kwargs, ) + dict_response = dict(response) + # Add Bedrock's token count to usage dict to match OpenAI's format + dict_response["usage"] = self._get_response_token_counts(dict_response) + def gen() -> ChatResponseGen: content = {} role = MessageRole.ASSISTANT @@ -392,7 +400,7 @@ def gen() -> ChatResponseGen: }, ), delta=content_delta.get("text", ""), - raw=response, + raw=dict_response, additional_kwargs=self._get_response_token_counts( dict(response) ), @@ -417,7 +425,7 @@ def gen() -> ChatResponseGen: "status": status, }, ), - raw=response, + raw=dict_response, additional_kwargs=self._get_response_token_counts( dict(response) ), @@ -458,6 +466,10 @@ async def achat( response ) + dict_response = dict(response) + # Add Bedrock's token count to usage dict to match OpenAI's format + dict_response["usage"] = self._get_response_token_counts(dict_response) + return ChatResponse( message=ChatMessage( role=MessageRole.ASSISTANT, @@ -468,7 +480,7 @@ async def achat( "status": status, }, ), - raw=dict(response), + raw=dict_response, additional_kwargs=self._get_response_token_counts(dict(response)), )