Skip to content

Commit 036428c

Browse files
Sumered and DouweM authored
Added support for google specific arguments for video analysis (#2110)
Co-authored-by: Douwe Maan <douwe@pydantic.dev>
1 parent f49197c commit 036428c

File tree

9 files changed

+270
-9
lines changed

9 files changed

+270
-9
lines changed

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,13 @@ class FileUrl(ABC):
9999
* If False, the URL is sent directly to the model and no download is performed.
100100
"""
101101

102+
vendor_metadata: dict[str, Any] | None = None
103+
"""Vendor-specific metadata for the file.
104+
105+
Supported by:
106+
- `GoogleModel`: `VideoUrl.vendor_metadata` is used as `video_metadata`: https://ai.google.dev/gemini-api/docs/video-understanding#customize-video-processing
107+
"""
108+
102109
@property
103110
@abstractmethod
104111
def media_type(self) -> str:
@@ -263,6 +270,13 @@ class BinaryContent:
263270
media_type: AudioMediaType | ImageMediaType | DocumentMediaType | str
264271
"""The media type of the binary data."""
265272

273+
vendor_metadata: dict[str, Any] | None = None
274+
"""Vendor-specific metadata for the file.
275+
276+
Supported by:
277+
- `GoogleModel`: `BinaryContent.vendor_metadata` is used as `video_metadata`: https://ai.google.dev/gemini-api/docs/video-understanding#customize-video-processing
278+
"""
279+
266280
kind: Literal['binary'] = 'binary'
267281
"""Type identifier, this is available on all parts as a discriminator."""
268282

pydantic_ai_slim/pydantic_ai/models/google.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
GenerateContentConfigDict,
5656
GenerateContentResponse,
5757
HttpOptionsDict,
58+
MediaResolution,
5859
Part,
5960
PartDict,
6061
SafetySettingDict,
@@ -120,6 +121,12 @@ class GoogleModelSettings(ModelSettings, total=False):
120121
See the [Gemini API docs](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/add-labels-to-api-calls) for use cases and limitations.
121122
"""
122123

124+
google_video_resolution: MediaResolution
125+
"""The video resolution to use for the model.
126+
127+
See <https://ai.google.dev/api/generate-content#MediaResolution> for more information.
128+
"""
129+
123130

124131
@dataclass(init=False)
125132
class GoogleModel(Model):
@@ -291,6 +298,7 @@ async def _generate_content(
291298
safety_settings=model_settings.get('google_safety_settings'),
292299
thinking_config=model_settings.get('google_thinking_config'),
293300
labels=model_settings.get('google_labels'),
301+
media_resolution=model_settings.get('google_video_resolution'),
294302
tools=cast(ToolListUnionDict, tools),
295303
tool_config=tool_config,
296304
response_mime_type=response_mime_type,
@@ -398,9 +406,15 @@ async def _map_user_prompt(self, part: UserPromptPart) -> list[PartDict]:
398406
elif isinstance(item, BinaryContent):
399407
# NOTE: The type from Google GenAI is incorrect, it should be `str`, not `bytes`.
400408
base64_encoded = base64.b64encode(item.data).decode('utf-8')
401-
content.append({'inline_data': {'data': base64_encoded, 'mime_type': item.media_type}}) # type: ignore
409+
inline_data_dict = {'inline_data': {'data': base64_encoded, 'mime_type': item.media_type}}
410+
if item.vendor_metadata:
411+
inline_data_dict['video_metadata'] = item.vendor_metadata
412+
content.append(inline_data_dict) # type: ignore
402413
elif isinstance(item, VideoUrl) and item.is_youtube:
403-
content.append({'file_data': {'file_uri': item.url, 'mime_type': item.media_type}})
414+
file_data_dict = {'file_data': {'file_uri': item.url, 'mime_type': item.media_type}}
415+
if item.vendor_metadata:
416+
file_data_dict['video_metadata'] = item.vendor_metadata
417+
content.append(file_data_dict) # type: ignore
404418
elif isinstance(item, FileUrl):
405419
if self.system == 'google-gla' or item.force_download:
406420
downloaded_item = await download_item(item, data_format='base64')

pydantic_ai_slim/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ logfire = ["logfire>=3.11.0"]
6464
openai = ["openai>=1.76.0"]
6565
cohere = ["cohere>=5.13.11; platform_system != 'Emscripten'"]
6666
vertexai = ["google-auth>=2.36.0", "requests>=2.32.2"]
67-
google = ["google-genai>=1.15.0"]
67+
google = ["google-genai>=1.24.0"]
6868
anthropic = ["anthropic>=0.52.0"]
6969
groq = ["groq>=0.19.0"]
7070
mistral = ["mistralai>=1.2.5"]

tests/conftest.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from contextlib import contextmanager
1111
from dataclasses import dataclass
1212
from datetime import datetime
13+
from functools import cached_property
1314
from pathlib import Path
1415
from types import ModuleType
1516
from typing import TYPE_CHECKING, Any, Callable
@@ -19,7 +20,7 @@
1920
from _pytest.assertion.rewrite import AssertionRewritingHook
2021
from pytest_mock import MockerFixture
2122
from typing_extensions import TypeAlias
22-
from vcr import VCR
23+
from vcr import VCR, request as vcr_request
2324

2425
import pydantic_ai.models
2526
from pydantic_ai.messages import BinaryContent
@@ -194,6 +195,29 @@ def pytest_recording_configure(config: Any, vcr: VCR):
194195

195196
vcr.register_serializer('yaml', json_body_serializer)
196197

198+
def method_matcher(r1: vcr_request.Request, r2: vcr_request.Request) -> None:
199+
if r1.method.upper() != r2.method.upper():
200+
raise AssertionError(f'{r1.method} != {r2.method}')
201+
202+
vcr.register_matcher('method', method_matcher)
203+
204+
205+
@pytest.fixture(autouse=True)
206+
def mock_vcr_aiohttp_content(mocker: MockerFixture):
207+
try:
208+
from vcr.stubs import aiohttp_stubs
209+
except ImportError:
210+
return
211+
212+
# google-genai calls `self.response_stream.content.readline()` where `self.response_stream` is a `MockClientResponse`,
213+
# which creates a new `MockStream` each time instead of returning the same one, resulting in the readline cursor not being respected.
214+
# So we turn `content` into a cached property to return the same one each time.
215+
# VCR issue: https://github.com/kevin1024/vcrpy/issues/927. Once that's resolved, we can remove this patch.
216+
cached_content = cached_property(aiohttp_stubs.MockClientResponse.content.fget) # type: ignore
217+
cached_content.__set_name__(aiohttp_stubs.MockClientResponse, 'content')
218+
mocker.patch('vcr.stubs.aiohttp_stubs.MockClientResponse.content', new=cached_content)
219+
mocker.patch('vcr.stubs.aiohttp_stubs.MockStream.set_exception', return_value=None)
220+
197221

198222
@pytest.fixture(scope='module')
199223
def vcr_config():

tests/models/cassettes/test_google/test_google_model_video_as_binary_content_input_with_vendor_metadata.yaml

Lines changed: 80 additions & 0 deletions
Large diffs are not rendered by default.
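tests/models/cassettes/test_google/test_google_model_youtube_video_url_input_with_vendor_metadata.yaml

Lines changed: 73 additions & 0 deletions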
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
interactions:
2+
- request:
3+
headers:
4+
content-type:
5+
- application/json
6+
method: post
7+
parsed_body:
8+
contents:
9+
- parts:
10+
- text: Explain me this video
11+
- fileData:
12+
fileUri: https://youtu.be/lCdaVNyHtjU
13+
mimeType: video/mp4
14+
videoMetadata:
15+
fps: 0.2
16+
role: user
17+
generationConfig: {}
18+
systemInstruction:
19+
parts:
20+
- text: You are a helpful chatbot.
21+
role: user
22+
uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent
23+
response:
24+
headers:
25+
alt-svc:
26+
- h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
27+
content-length:
28+
- '2759'
29+
content-type:
30+
- application/json; charset=UTF-8
31+
server-timing:
32+
- gfet4t7; dur=11467
33+
transfer-encoding:
34+
- chunked
35+
vary:
36+
- Origin
37+
- X-Origin
38+
- Referer
39+
parsed_body:
40+
candidates:
41+
- avgLogprobs: -0.4793745385795377
42+
content:
43+
parts:
44+
- text: |-
45+
Okay, based on the image, here's what I can infer:
46+
47+
* **A camera monitor is mounted on top of a camera.**
48+
* **The monitor's screen is on, displaying a view of the rocky mountains.**
49+
* **This setting suggests a professional video shoot.**
50+
51+
If you'd like a more detailed explanation, please provide additional information about the video.
52+
role: model
53+
finishReason: STOP
54+
modelVersion: gemini-2.0-flash
55+
responseId: ldpraPqBM6HshMIPgsi60QI
56+
usageMetadata:
57+
candidatesTokenCount: 459
58+
candidatesTokensDetails:
59+
- modality: TEXT
60+
tokenCount: 459
61+
promptTokenCount: 4605
62+
promptTokensDetails:
63+
- modality: TEXT
64+
tokenCount: 10
65+
- modality: AUDIO
66+
tokenCount: 1475
67+
- modality: VIDEO
68+
tokenCount: 3120
69+
totalTokenCount: 5064
70+
status:
71+
code: 200
72+
message: OK
73+
version: 1

tests/models/test_google.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,23 @@ async def test_google_model_video_as_binary_content_input(
420420
""")
421421

422422

423+
async def test_google_model_video_as_binary_content_input_with_vendor_metadata(
424+
allow_model_requests: None, video_content: BinaryContent, google_provider: GoogleProvider
425+
):
426+
m = GoogleModel('gemini-2.0-flash', provider=google_provider)
427+
agent = Agent(m, system_prompt='You are a helpful chatbot.')
428+
video_content.vendor_metadata = {'start_offset': '2s', 'end_offset': '10s'}
429+
430+
result = await agent.run(['Explain me this video', video_content])
431+
assert result.output == snapshot("""\
432+
Okay, I can describe what is visible in the image.
433+
434+
The image shows a camera setup in an outdoor setting. The camera is mounted on a tripod and has an external monitor attached to it. The monitor is displaying a scene that appears to be a desert landscape with rocky formations and mountains in the background. The foreground and background of the overall image, outside of the camera monitor, is also a blurry, desert landscape. The colors in the background are warm and suggest either sunrise, sunset, or reflected light off the rock formations.
435+
436+
It looks like someone is either reviewing footage on the monitor, or using it as an aid for framing the shot.\
437+
""")
438+
439+
423440
async def test_google_model_image_url_input(allow_model_requests: None, google_provider: GoogleProvider):
424441
m = GoogleModel('gemini-2.0-flash', provider=google_provider)
425442
agent = Agent(m, system_prompt='You are a helpful chatbot.')
@@ -454,6 +471,32 @@ async def test_google_model_video_url_input(allow_model_requests: None, google_p
454471
""")
455472

456473

474+
async def test_google_model_youtube_video_url_input_with_vendor_metadata(
475+
allow_model_requests: None, google_provider: GoogleProvider
476+
):
477+
m = GoogleModel('gemini-2.0-flash', provider=google_provider)
478+
agent = Agent(m, system_prompt='You are a helpful chatbot.')
479+
480+
result = await agent.run(
481+
[
482+
'Explain me this video',
483+
VideoUrl(
484+
url='https://youtu.be/lCdaVNyHtjU',
485+
vendor_metadata={'fps': 0.2},
486+
),
487+
]
488+
)
489+
assert result.output == snapshot("""\
490+
Okay, based on the image, here's what I can infer:
491+
492+
* **A camera monitor is mounted on top of a camera.**
493+
* **The monitor's screen is on, displaying a view of the rocky mountains.**
494+
* **This setting suggests a professional video shoot.**
495+
496+
If you'd like a more detailed explanation, please provide additional information about the video.\
497+
""")
498+
499+
457500
async def test_google_model_document_url_input(allow_model_requests: None, google_provider: GoogleProvider):
458501
m = GoogleModel('gemini-2.0-flash', provider=google_provider)
459502
agent = Agent(m, system_prompt='You are a helpful chatbot.')

tests/test_agent.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2829,7 +2829,10 @@ def test_binary_content_all_messages_json():
28292829
{
28302830
'parts': [
28312831
{
2832-
'content': ['Hello', {'data': 'SGVsbG8=', 'media_type': 'text/plain', 'kind': 'binary'}],
2832+
'content': [
2833+
'Hello',
2834+
{'data': 'SGVsbG8=', 'media_type': 'text/plain', 'vendor_metadata': None, 'kind': 'binary'},
2835+
],
28332836
'timestamp': IsStr(),
28342837
'part_kind': 'user-prompt',
28352838
}

uv.lock

Lines changed: 14 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)