
Add gemini audio input support + handle special tokens in sagemaker response #9640

Merged
merged 9 commits into from Mar 30, 2025
10 changes: 7 additions & 3 deletions litellm/llms/sagemaker/completion/transformation.py
@@ -19,6 +19,7 @@
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse, Usage
from litellm.utils import token_counter

from ..common_utils import SagemakerError

@@ -238,9 +239,12 @@ def transform_response(
)

## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here.
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
prompt_tokens = token_counter(
text=prompt, count_response_tokens=True
) # doesn't apply any default token count from openai's chat template
completion_tokens = token_counter(
text=model_response["choices"][0]["message"].get("content", ""),
count_response_tokens=True,
)

model_response.created = int(time.time())
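The SageMaker change above swaps the raw encoding.encode() calls for litellm's token_counter helper, which is the "handle special tokens in sagemaker response" half of the PR title: special tokens in a response no longer break usage calculation. A minimal usage sketch, assuming token_counter keeps the text=/count_response_tokens= signature used in the diff:

from litellm.utils import token_counter

# Hypothetical strings; a special token such as <|endoftext|> is the kind of input
# the old encoding-based count mishandled.
prompt = "Summarize this: <|endoftext|> hello"
completion_text = "Hello! Here is a short summary."

prompt_tokens = token_counter(text=prompt, count_response_tokens=True)
completion_tokens = token_counter(text=completion_text, count_response_tokens=True)
print(prompt_tokens, completion_tokens)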
100 changes: 83 additions & 17 deletions litellm/llms/vertex_ai/gemini/transformation.py
@@ -27,6 +27,8 @@
from litellm.types.llms.openai import (
AllMessageValues,
ChatCompletionAssistantMessage,
ChatCompletionAudioObject,
ChatCompletionFileObject,
ChatCompletionImageObject,
ChatCompletionTextObject,
)
@@ -103,24 +105,53 @@ def _get_image_mime_type_from_url(url: str) -> Optional[str]:
See gemini mime types: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/image-understanding#image-requirements

Supported by Gemini:
- PNG (`image/png`)
- JPEG (`image/jpeg`)
- WebP (`image/webp`)
Example:
url = https://example.com/image.jpg
Returns: image/jpeg
application/pdf
audio/mpeg
audio/mp3
audio/wav
image/png
image/jpeg
image/webp
text/plain
video/mov
video/mpeg
video/mp4
video/mpg
video/avi
video/wmv
video/mpegps
video/flv
"""
url = url.lower()
if url.endswith((".jpg", ".jpeg")):
return "image/jpeg"
elif url.endswith(".png"):
return "image/png"
elif url.endswith(".webp"):
return "image/webp"
elif url.endswith(".mp4"):
return "video/mp4"
elif url.endswith(".pdf"):
return "application/pdf"

# Map file extensions to mime types
mime_types = {
# Images
(".jpg", ".jpeg"): "image/jpeg",
(".png",): "image/png",
(".webp",): "image/webp",
# Videos
(".mp4",): "video/mp4",
(".mov",): "video/mov",
(".mpeg", ".mpg"): "video/mpeg",
(".avi",): "video/avi",
(".wmv",): "video/wmv",
(".mpegps",): "video/mpegps",
(".flv",): "video/flv",
# Audio
(".mp3",): "audio/mp3",
(".wav",): "audio/wav",
(".mpeg",): "audio/mpeg",
# Documents
(".pdf",): "application/pdf",
(".txt",): "text/plain",
}

# Check each extension group against the URL
for extensions, mime_type in mime_types.items():
if any(url.endswith(ext) for ext in extensions):
return mime_type

return None
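
A quick illustration of how the extension lookup above behaves; the URLs are made up, the mapping comes straight from the diff:

from litellm.llms.vertex_ai.gemini.transformation import _get_image_mime_type_from_url

print(_get_image_mime_type_from_url("https://example.com/Clip.MOV"))  # video/mov (lookup is case-insensitive)
print(_get_image_mime_type_from_url("gs://bucket/recording.wav"))     # audio/wav
print(_get_image_mime_type_from_url("gs://bucket/report.pdf"))        # application/pdf
print(_get_image_mime_type_from_url("gs://bucket/archive.zip"))       # None, so callers must pass a format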


@@ -152,7 +183,7 @@ def _gemini_convert_messages_with_history( # noqa: PLR0915
_message_content = messages[msg_i].get("content")
if _message_content is not None and isinstance(_message_content, list):
_parts: List[PartType] = []
for element in _message_content:
for element_idx, element in enumerate(_message_content):
if (
element["type"] == "text"
and "text" in element
@@ -174,6 +205,41 @@
image_url=image_url, format=format
)
_parts.append(_part)
elif element["type"] == "input_audio":
audio_element = cast(ChatCompletionAudioObject, element)
if audio_element["input_audio"].get("data") is not None:
_part = PartType(
inline_data=BlobType(
data=audio_element["input_audio"]["data"],
mime_type="audio/{}".format(
audio_element["input_audio"]["format"]
),
)
)
_parts.append(_part)
elif element["type"] == "file":
file_element = cast(ChatCompletionFileObject, element)
file_id = file_element["file"].get("file_id")
format = file_element["file"].get("format")

if not file_id:
continue
mime_type = format or _get_image_mime_type_from_url(file_id)

if mime_type is not None:
_part = PartType(
file_data=FileDataType(
file_uri=file_id,
mime_type=mime_type,
)
)
_parts.append(_part)
else:
raise Exception(
"Unable to determine mime type for file_id: {}, set this explicitly using message[{}].content[{}].file.format".format(
file_id, msg_i, element_idx
)
)
user_content.extend(_parts)
elif (
_message_content is not None
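Taken together, the two new branches let an OpenAI-format message carry inline audio (input_audio, converted to a Gemini inline_data blob) or a file reference (file, converted to file_data). A sketch of the calling side, with placeholder audio bytes and an illustrative gs:// URI, mirroring the test added later in this PR:

import base64
import litellm

# Placeholder bytes stand in for a real WAV recording.
encoded_audio = base64.b64encode(b"RIFF....WAVE....").decode("utf-8")

response = litellm.completion(
    model="gemini/gemini-2.0-flash",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this recording?"},
                # becomes PartType(inline_data=BlobType(data=..., mime_type="audio/wav"))
                {"type": "input_audio", "input_audio": {"data": encoded_audio, "format": "wav"}},
                # becomes PartType(file_data=FileDataType(file_uri=..., mime_type="audio/wav"))
                {"type": "file", "file": {"file_id": "gs://bucket/recording.wav"}},
            ],
        }
    ],
)
print(response.choices[0].message.content)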
2 changes: 2 additions & 0 deletions litellm/model_prices_and_context_window_backup.json
@@ -4696,6 +4696,8 @@
"supports_vision": true,
"supports_response_schema": true,
"supports_audio_output": true,
"supports_audio_input": true,
"supported_modalities": ["text", "image", "audio", "video"],
"supports_tool_choice": true,
"source": "https://ai.google.dev/pricing#2_0flash"
},
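The supports_audio_input and supported_modalities entries added here are what the new base test checks before exercising audio inputs. A small sketch of reading the flag, following the same local cost-map setup the test uses:

import os

import litellm
from litellm.utils import supports_audio_input

os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"  # read the bundled map instead of fetching it
litellm.model_cost = litellm.get_model_cost_map(url="")

print(supports_audio_input("gemini/gemini-2.0-flash", None))  # expected: True with this entry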
16 changes: 10 additions & 6 deletions litellm/proxy/_new_secret_config.yaml
@@ -16,14 +16,18 @@ model_list:
- model_name: "bedrock-nova"
litellm_params:
model: us.amazon.nova-pro-v1:0
- model_name: "gemini-2.0-flash"
litellm_params:
model: gemini/gemini-2.0-flash
api_key: os.environ/GEMINI_API_KEY

litellm_settings:
num_retries: 0
callbacks: ["prometheus"]
json_logs: true
# json_logs: true

router_settings:
routing_strategy: usage-based-routing-v2 # 👈 KEY CHANGE
redis_host: os.environ/REDIS_HOST
redis_password: os.environ/REDIS_PASSWORD
redis_port: os.environ/REDIS_PORT
# router_settings:
# routing_strategy: usage-based-routing-v2 # 👈 KEY CHANGE
# redis_host: os.environ/REDIS_HOST
# redis_password: os.environ/REDIS_PASSWORD
# redis_port: os.environ/REDIS_PORT
24 changes: 2 additions & 22 deletions litellm/proxy/management_endpoints/internal_user_endpoints.py
@@ -1370,22 +1370,6 @@ async def get_user_daily_activity(
default=None,
description="End date in YYYY-MM-DD format",
),
group_by: List[GroupByDimension] = fastapi.Query(
default=[GroupByDimension.DATE],
description="Dimensions to group by. Can combine multiple (e.g. date,team)",
),
view_by: Literal["team", "organization", "user"] = fastapi.Query(
default="user",
description="View spend at team/org/user level",
),
team_id: Optional[str] = fastapi.Query(
default=None,
description="Filter by specific team",
),
org_id: Optional[str] = fastapi.Query(
default=None,
description="Filter by specific organization",
),
model: Optional[str] = fastapi.Query(
default=None,
description="Filter by specific model",
@@ -1408,13 +1392,13 @@
Meant to optimize querying spend data for analytics for a user.

Returns:
(by date/team/org/user/model/api_key/model_group/provider)
(by date)
- spend
- prompt_tokens
- completion_tokens
- total_tokens
- api_requests
- breakdown by team, organization, user, model, api_key, model_group, provider
- breakdown by model, api_key, provider
"""
from litellm.proxy.proxy_server import prisma_client

@@ -1439,10 +1423,6 @@
}
}

if team_id:
where_conditions["team_id"] = team_id
if org_id:
where_conditions["organization_id"] = org_id
if model:
where_conditions["model"] = model
if api_key:
9 changes: 5 additions & 4 deletions litellm/types/llms/openai.py
@@ -505,10 +505,11 @@ class ChatCompletionDocumentObject(TypedDict):
citations: Optional[CitationsObject]


class ChatCompletionFileObjectFile(TypedDict):
file_data: Optional[str]
file_id: Optional[str]
filename: Optional[str]
class ChatCompletionFileObjectFile(TypedDict, total=False):
file_data: str
file_id: str
filename: str
format: str


class ChatCompletionFileObject(TypedDict):
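With total=False every key of ChatCompletionFileObjectFile becomes optional, so a file part can carry just a file_id plus the new format hint that the Gemini transformation consults before falling back to URL-based mime-type inference. A small sketch (the gs:// URI and field combination are illustrative, not a prescribed shape):

from litellm.types.llms.openai import ChatCompletionFileObject

file_part: ChatCompletionFileObject = {
    "type": "file",
    "file": {
        "file_id": "gs://bucket/recording.wav",
        "format": "audio/wav",  # explicit mime type; skips extension-based inference
    },
}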
2 changes: 2 additions & 0 deletions model_prices_and_context_window.json
@@ -4696,6 +4696,8 @@
"supports_vision": true,
"supports_response_schema": true,
"supports_audio_output": true,
"supports_audio_input": true,
"supported_modalities": ["text", "image", "audio", "video"],
"supports_tool_choice": true,
"source": "https://ai.google.dev/pricing#2_0flash"
},
71 changes: 71 additions & 0 deletions tests/llm_translation/base_llm_unit_tests.py
@@ -7,6 +7,7 @@
import os
import uuid
import time
import base64

sys.path.insert(
0, os.path.abspath("../..")
@@ -889,6 +890,74 @@

assert cost > 0

@pytest.mark.parametrize("input_type", ["input_audio", "audio_url"])
def test_supports_audio_input(self, input_type):
from litellm.utils import return_raw_request, supports_audio_input
from litellm.types.utils import CallTypes
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
litellm.model_cost = litellm.get_model_cost_map(url="")


litellm.drop_params = True
base_completion_call_args = self.get_base_completion_call_args()
if not supports_audio_input(base_completion_call_args["model"], None):
print("Model does not support audio input")
pytest.skip("Model does not support audio input")

url = "https://openaiassets.blob.core.windows.net/$web/API/docs/audio/alloy.wav"
response = httpx.get(url)
response.raise_for_status()
wav_data = response.content
audio_format = "wav"
encoded_string = base64.b64encode(wav_data).decode("utf-8")

audio_content = [
{
"type": "text",
"text": "What is in this recording?"
}
]

test_file_id = "gs://bucket/file.wav"

if input_type == "input_audio":
audio_content.append({
"type": "input_audio",
"input_audio": {"data": encoded_string, "format": audio_format},
})
elif input_type == "audio_url":
audio_content.append(
{
"type": "file",
"file": {
"file_id": test_file_id,
"filename": "my-sample-audio-file",
}
}
)



raw_request = return_raw_request(
endpoint=CallTypes.completion,
kwargs={
**base_completion_call_args,
"modalities": ["text", "audio"],
"audio": {"voice": "alloy", "format": audio_format},
"messages": [
{
"role": "user",
"content": audio_content,
},
]
}
)
print("raw_request: ", raw_request)

Check failure

Code scanning / CodeQL: Clear-text logging of sensitive information (High, test)

This expression logs sensitive data (secret) as clear text.
This expression logs sensitive data (password) as clear text.
Copilot Autofix (AI, 3 months ago)

To fix the problem, we need to ensure that sensitive information is not logged. This can be achieved by scrubbing the raw_request object of any sensitive data before logging it. We can use a utility function to remove or mask sensitive information from the raw_request object before printing it.

  • Add a utility function to scrub sensitive information from the raw_request object.
  • Use this utility function to clean the raw_request object before logging it.
  • Ensure that the changes are made in the tests/llm_translation/base_llm_unit_tests.py file.
Suggested changeset 1: tests/llm_translation/base_llm_unit_tests.py

Autofix patch. Run the following command in your local git repository to apply it:
cat << 'EOF' | git apply
diff --git a/tests/llm_translation/base_llm_unit_tests.py b/tests/llm_translation/base_llm_unit_tests.py
--- a/tests/llm_translation/base_llm_unit_tests.py
+++ b/tests/llm_translation/base_llm_unit_tests.py
@@ -11,2 +11,9 @@
 
+def scrub_sensitive_data(data):
+    if isinstance(data, dict):
+        for key in ["client_secret", "api_key", "azure_ad_token", "azure_username", "azure_password"]:
+            if key in data:
+                data[key] = "REDACTED"
+    return data
+
 sys.path.insert(
@@ -954,2 +961,3 @@
         )
+        raw_request = scrub_sensitive_data(raw_request)
         print("raw_request: ", raw_request)
EOF

if input_type == "input_audio":
assert encoded_string in json.dumps(raw_request), "Audio data not sent to gemini"
elif input_type == "audio_url":
assert test_file_id in json.dumps(raw_request), "Audio URL not sent to gemini"

class BaseOSeriesModelsTest(ABC): # test across azure/openai
@abstractmethod
@@ -1089,3 +1158,5 @@
)

print(response)


2 changes: 1 addition & 1 deletion tests/llm_translation/test_gemini.py
@@ -15,7 +15,7 @@

class TestGoogleAIStudioGemini(BaseLLMChatTest):
def get_base_completion_call_args(self) -> dict:
return {"model": "gemini/gemini-1.5-flash-002"}
return {"model": "gemini/gemini-2.0-flash"}

def test_tool_call_no_arguments(self, tool_call_no_arguments):
"""Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
7 changes: 5 additions & 2 deletions tests/llm_translation/test_gpt4o_audio.py
@@ -84,20 +84,22 @@ async def test_audio_output_from_model(stream):

@pytest.mark.asyncio
@pytest.mark.parametrize("stream", [True, False])
async def test_audio_input_to_model(stream):
@pytest.mark.parametrize("model", ["gpt-4o-audio-preview"]) # "gpt-4o-audio-preview",
async def test_audio_input_to_model(stream, model):
# Fetch the audio file and convert it to a base64 encoded string
audio_format = "pcm16"
if stream is False:
audio_format = "wav"
litellm._turn_on_debug()
litellm.drop_params = True
url = "https://openaiassets.blob.core.windows.net/$web/API/docs/audio/alloy.wav"
response = requests.get(url)
response.raise_for_status()
wav_data = response.content
encoded_string = base64.b64encode(wav_data).decode("utf-8")
try:
completion = await litellm.acompletion(
model="gpt-4o-audio-preview",
model=model,
modalities=["text", "audio"],
audio={"voice": "alloy", "format": audio_format},
stream=stream,
@@ -120,6 +122,7 @@ async def test_audio_input_to_model(stream):
except Exception as e:
if "openai-internal" in str(e):
pytest.skip("Skipping test due to openai-internal error")
raise e
if stream is True:
await check_streaming_response(completion)
else: