Skip to content

Commit bf7745f

Browse files
seanzhougooglecopybara-github
authored andcommitted
fix: Create correct object for image and video content in litellm
PiperOrigin-RevId: 783478779
1 parent c2058f3 commit bf7745f

File tree

2 files changed

+61
-39
lines changed

2 files changed

+61
-39
lines changed

src/google/adk/models/lite_llm.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,13 @@
3636
from litellm import ChatCompletionAssistantToolCall
3737
from litellm import ChatCompletionDeveloperMessage
3838
from litellm import ChatCompletionFileObject
39+
from litellm import ChatCompletionImageObject
3940
from litellm import ChatCompletionImageUrlObject
4041
from litellm import ChatCompletionMessageToolCall
4142
from litellm import ChatCompletionTextObject
4243
from litellm import ChatCompletionToolMessage
4344
from litellm import ChatCompletionUserMessage
45+
from litellm import ChatCompletionVideoObject
4446
from litellm import ChatCompletionVideoUrlObject
4547
from litellm import completion
4648
from litellm import CustomStreamWrapper
@@ -250,17 +252,25 @@ def _get_content(
250252
data_uri = f"data:{part.inline_data.mime_type};base64,{base64_string}"
251253

252254
if part.inline_data.mime_type.startswith("image"):
255+
# Extract format from mime type (e.g., "image/png" -> "png")
256+
format_type = part.inline_data.mime_type.split("/")[-1]
253257
content_objects.append(
254-
ChatCompletionImageUrlObject(
258+
ChatCompletionImageObject(
255259
type="image_url",
256-
image_url=data_uri,
260+
image_url=ChatCompletionImageUrlObject(
261+
url=data_uri, format=format_type
262+
),
257263
)
258264
)
259265
elif part.inline_data.mime_type.startswith("video"):
266+
# Extract format from mime type (e.g., "video/mp4" -> "mp4")
267+
format_type = part.inline_data.mime_type.split("/")[-1]
260268
content_objects.append(
261-
ChatCompletionVideoUrlObject(
269+
ChatCompletionVideoObject(
262270
type="video_url",
263-
video_url=data_uri,
271+
video_url=ChatCompletionVideoUrlObject(
272+
url=data_uri, format=format_type
273+
),
264274
)
265275
)
266276
elif part.inline_data.mime_type == "application/pdf":

tests/unittests/models/test_litellm.py

Lines changed: 47 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -780,39 +780,6 @@ async def test_generate_content_async_with_tool_response(
780780
assert kwargs["messages"][2]["content"] == '{"result": "test_result"}'
781781

782782

783-
@pytest.mark.asyncio
784-
async def test_generate_content_async(mock_acompletion, lite_llm_instance):
785-
786-
async for response in lite_llm_instance.generate_content_async(
787-
LLM_REQUEST_WITH_FUNCTION_DECLARATION
788-
):
789-
assert response.content.role == "model"
790-
assert response.content.parts[0].text == "Test response"
791-
assert response.content.parts[1].function_call.name == "test_function"
792-
assert response.content.parts[1].function_call.args == {
793-
"test_arg": "test_value"
794-
}
795-
assert response.content.parts[1].function_call.id == "test_tool_call_id"
796-
797-
mock_acompletion.assert_called_once()
798-
799-
_, kwargs = mock_acompletion.call_args
800-
assert kwargs["model"] == "test_model"
801-
assert kwargs["messages"][0]["role"] == "user"
802-
assert kwargs["messages"][0]["content"] == "Test prompt"
803-
assert kwargs["tools"][0]["function"]["name"] == "test_function"
804-
assert (
805-
kwargs["tools"][0]["function"]["description"]
806-
== "Test function description"
807-
)
808-
assert (
809-
kwargs["tools"][0]["function"]["parameters"]["properties"]["test_arg"][
810-
"type"
811-
]
812-
== "string"
813-
)
814-
815-
816783
@pytest.mark.asyncio
817784
async def test_generate_content_async_with_usage_metadata(
818785
lite_llm_instance, mock_acompletion
@@ -924,6 +891,43 @@ def test_content_to_message_param_function_call():
924891
assert tool_call["function"]["arguments"] == '{"test_arg": "test_value"}'
925892

926893

894+
def test_content_to_message_param_multipart_content():
895+
"""Test handling of multipart content where final_content is a list with text objects."""
896+
content = types.Content(
897+
role="assistant",
898+
parts=[
899+
types.Part.from_text(text="text part"),
900+
types.Part.from_bytes(data=b"test_image_data", mime_type="image/png"),
901+
],
902+
)
903+
message = _content_to_message_param(content)
904+
assert message["role"] == "assistant"
905+
# When content is a list and the first element is a text object with type "text",
906+
# it should extract the text (for providers like ollama_chat that don't handle lists well)
907+
# This is the behavior implemented in the fix
908+
assert message["content"] == "text part"
909+
assert message["tool_calls"] is None
910+
911+
912+
def test_content_to_message_param_single_text_object_in_list():
913+
"""Test extraction of text from single text object in list (for ollama_chat compatibility)."""
914+
from unittest.mock import patch
915+
916+
# Mock _get_content to return a list with single text object
917+
with patch("google.adk.models.lite_llm._get_content") as mock_get_content:
918+
mock_get_content.return_value = [{"type": "text", "text": "single text"}]
919+
920+
content = types.Content(
921+
role="assistant",
922+
parts=[types.Part.from_text(text="single text")],
923+
)
924+
message = _content_to_message_param(content)
925+
assert message["role"] == "assistant"
926+
# Should extract the text from the single text object
927+
assert message["content"] == "single text"
928+
assert message["tool_calls"] is None
929+
930+
927931
def test_message_to_generate_content_response_text():
928932
message = ChatCompletionAssistantMessage(
929933
role="assistant",
@@ -971,7 +975,11 @@ def test_get_content_image():
971975
]
972976
content = _get_content(parts)
973977
assert content[0]["type"] == "image_url"
974-
assert content[0]["image_url"] == "data:image/png;base64,dGVzdF9pbWFnZV9kYXRh"
978+
assert (
979+
content[0]["image_url"]["url"]
980+
== "data:image/png;base64,dGVzdF9pbWFnZV9kYXRh"
981+
)
982+
assert content[0]["image_url"]["format"] == "png"
975983

976984

977985
def test_get_content_video():
@@ -980,7 +988,11 @@ def test_get_content_video():
980988
]
981989
content = _get_content(parts)
982990
assert content[0]["type"] == "video_url"
983-
assert content[0]["video_url"] == "data:video/mp4;base64,dGVzdF92aWRlb19kYXRh"
991+
assert (
992+
content[0]["video_url"]["url"]
993+
== "data:video/mp4;base64,dGVzdF92aWRlb19kYXRh"
994+
)
995+
assert content[0]["video_url"]["format"] == "mp4"
984996

985997

986998
def test_to_litellm_role():

0 commit comments

Comments
 (0)