
Commit c8bad12

Support field labels for GeminiModel and GoogleModel on Vertex AI (#1056)
1 parent f6ed7c7 commit c8bad12

File tree: 8 files changed (+412, -16 lines)

pydantic_ai_slim/pydantic_ai/models/gemini.py

Lines changed: 33 additions & 16 deletions
@@ -80,6 +80,7 @@ class GeminiModelSettings(ModelSettings, total=False):
     """

     gemini_safety_settings: list[GeminiSafetySettings]
+    """Safety settings options for Gemini model request."""

     gemini_thinking_config: ThinkingConfig
     """Thinking is "on" by default in both the API and AI Studio.
@@ -93,6 +94,12 @@ class GeminiModelSettings(ModelSettings, total=False):
     See more about it on <https://ai.google.dev/gemini-api/docs/thinking>.
     """

+    gemini_labels: dict[str, str]
+    """User-defined metadata to break down billed charges. Only supported by the Vertex AI provider.
+
+    See the [Gemini API docs](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/add-labels-to-api-calls) for use cases and limitations.
+    """
+

 @dataclass(init=False)
 class GeminiModel(Model):
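
For context, a minimal usage sketch of the new gemini_labels setting. This is illustrative only: the project ID is a placeholder, and the GoogleVertexProvider arguments and the result attribute name are assumptions about the surrounding pydantic-ai API rather than part of this commit.

```python
from pydantic_ai import Agent
from pydantic_ai.models.gemini import GeminiModel, GeminiModelSettings
from pydantic_ai.providers.google_vertex import GoogleVertexProvider

# Labels are only forwarded when the model talks to Vertex AI; with the
# Generative Language API they are skipped.
model = GeminiModel(
    'gemini-2.0-flash',
    provider=GoogleVertexProvider(project_id='my-project'),  # placeholder project
)
agent = Agent(model)

result = agent.run_sync(
    'What is the capital of France?',
    model_settings=GeminiModelSettings(
        gemini_labels={'environment': 'test', 'team': 'analytics'},
    ),
)
print(result.output)
```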
@@ -223,25 +230,17 @@ async def _make_request(
         if tool_config is not None:
             request_data['toolConfig'] = tool_config

-        generation_config: _GeminiGenerationConfig = {}
-        if model_settings:
-            if (max_tokens := model_settings.get('max_tokens')) is not None:
-                generation_config['max_output_tokens'] = max_tokens
-            if (temperature := model_settings.get('temperature')) is not None:
-                generation_config['temperature'] = temperature
-            if (top_p := model_settings.get('top_p')) is not None:
-                generation_config['top_p'] = top_p
-            if (presence_penalty := model_settings.get('presence_penalty')) is not None:
-                generation_config['presence_penalty'] = presence_penalty
-            if (frequency_penalty := model_settings.get('frequency_penalty')) is not None:
-                generation_config['frequency_penalty'] = frequency_penalty
-            if (thinkingConfig := model_settings.get('gemini_thinking_config')) is not None:
-                generation_config['thinking_config'] = thinkingConfig  # pragma: no cover
-            if (gemini_safety_settings := model_settings.get('gemini_safety_settings')) is not None:
-                request_data['safetySettings'] = gemini_safety_settings
+        generation_config = _settings_to_generation_config(model_settings)
         if generation_config:
             request_data['generationConfig'] = generation_config

+        if gemini_safety_settings := model_settings.get('gemini_safety_settings'):
+            request_data['safetySettings'] = gemini_safety_settings
+
+        if gemini_labels := model_settings.get('gemini_labels'):
+            if self._system == 'google-vertex':
+                request_data['labels'] = gemini_labels
+
         headers = {'Content-Type': 'application/json', 'User-Agent': get_user_agent()}
         url = f'/{self._model_name}:{"streamGenerateContent" if streamed else "generateContent"}'

@@ -362,6 +361,23 @@ async def _map_user_prompt(self, part: UserPromptPart) -> list[_GeminiPartUnion]
         return content


+def _settings_to_generation_config(model_settings: GeminiModelSettings) -> _GeminiGenerationConfig:
+    config: _GeminiGenerationConfig = {}
+    if (max_tokens := model_settings.get('max_tokens')) is not None:
+        config['max_output_tokens'] = max_tokens
+    if (temperature := model_settings.get('temperature')) is not None:
+        config['temperature'] = temperature
+    if (top_p := model_settings.get('top_p')) is not None:
+        config['top_p'] = top_p
+    if (presence_penalty := model_settings.get('presence_penalty')) is not None:
+        config['presence_penalty'] = presence_penalty
+    if (frequency_penalty := model_settings.get('frequency_penalty')) is not None:
+        config['frequency_penalty'] = frequency_penalty
+    if (thinkingConfig := model_settings.get('gemini_thinking_config')) is not None:
+        config['thinking_config'] = thinkingConfig  # pragma: no cover
+    return config
+
+
 class AuthProtocol(Protocol):
     """Abstract definition for Gemini authentication."""

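To make the extracted helper concrete, here is a small illustrative check of the mapping it performs. It is not part of this commit and calls a private helper directly, which ordinary code should not do; it only shows how model settings are renamed to the generationConfig keys.

```python
from pydantic_ai.models.gemini import GeminiModelSettings, _settings_to_generation_config

# Only the settings that are present end up in the config, renamed to the
# spelling used by the Gemini generationConfig payload; unset keys are omitted.
settings = GeminiModelSettings(max_tokens=100, temperature=0.2, top_p=0.9)
assert _settings_to_generation_config(settings) == {
    'max_output_tokens': 100,
    'temperature': 0.2,
    'top_p': 0.9,
}
```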
@@ -483,6 +499,7 @@ class _GeminiRequest(TypedDict):
     <https://ai.google.dev/gemini-api/docs/system-instructions?lang=rest>
     """
     generationConfig: NotRequired[_GeminiGenerationConfig]
+    labels: NotRequired[dict[str, str]]


 class GeminiSafetySettings(TypedDict):

pydantic_ai_slim/pydantic_ai/models/google.py

Lines changed: 7 additions & 0 deletions
@@ -115,6 +115,12 @@ class GoogleModelSettings(ModelSettings, total=False):
     See <https://ai.google.dev/gemini-api/docs/thinking> for more information.
     """

+    google_labels: dict[str, str]
+    """User-defined metadata to break down billed charges. Only supported by the Vertex AI API.
+
+    See the [Gemini API docs](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/add-labels-to-api-calls) for use cases and limitations.
+    """
+

 @dataclass(init=False)
 class GoogleModel(Model):
@@ -269,6 +275,7 @@ async def _generate_content(
             frequency_penalty=model_settings.get('frequency_penalty'),
             safety_settings=model_settings.get('google_safety_settings'),
             thinking_config=model_settings.get('google_thinking_config'),
+            labels=model_settings.get('google_labels'),
             tools=cast(ToolListUnionDict, tools),
             tool_config=tool_config,
         )
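
A comparable usage sketch for GoogleModel. Again illustrative: the project and location are placeholders, and the GoogleProvider constructor arguments and the result attribute name are assumptions about the surrounding pydantic-ai API, not part of this commit.

```python
from pydantic_ai import Agent
from pydantic_ai.models.google import GoogleModel, GoogleModelSettings
from pydantic_ai.providers.google import GoogleProvider

# google_labels is only honoured when the underlying google-genai client
# targets Vertex AI; on the Gemini Developer API it is not sent.
model = GoogleModel(
    'gemini-2.0-flash',
    provider=GoogleProvider(vertexai=True, project='my-project', location='us-central1'),
)
agent = Agent(model)

result = agent.run_sync(
    'What is the capital of France?',
    model_settings=GoogleModelSettings(
        google_labels={'environment': 'test', 'team': 'analytics'},
    ),
)
print(result.output)
```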
Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+interactions:
+- request:
+    headers:
+      accept:
+      - "*/*"
+      accept-encoding:
+      - gzip, deflate
+      connection:
+      - keep-alive
+      content-length:
+      - "82"
+      content-type:
+      - application/json
+      host:
+      - generativelanguage.googleapis.com
+    method: POST
+    parsed_body:
+      contents:
+      - parts:
+        - text: What is the capital of France?
+        role: user
+    uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent
+  response:
+    headers:
+      alt-svc:
+      - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
+      content-length:
+      - "637"
+      content-type:
+      - application/json; charset=UTF-8
+      server-timing:
+      - gfet4t7; dur=426
+      transfer-encoding:
+      - chunked
+      vary:
+      - Origin
+      - X-Origin
+      - Referer
+    parsed_body:
+      candidates:
+      - avgLogprobs: -0.02703852951526642
+        content:
+          parts:
+          - text: |
+              The capital of France is **Paris**.
+          role: model
+        finishReason: STOP
+      modelVersion: gemini-2.0-flash
+      usageMetadata:
+        candidatesTokenCount: 9
+        candidatesTokensDetails:
+        - modality: TEXT
+          tokenCount: 9
+        promptTokenCount: 7
+        promptTokensDetails:
+        - modality: TEXT
+          tokenCount: 7
+        totalTokenCount: 16
+    status:
+      code: 200
+      message: OK
+version: 1
Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
+interactions:
+- request:
+    body: grant_type=%5B%27refresh_token%27%5D&client_id=%5B%27764086051850-6qr4p6gpi6hn506pt8ejuq83di341hur.apps.googleusercontent.com%27%5D&client_secret=%5B%27scrubbed%27%5D&refresh_token=%5B%27scrubbed%27%5D
+    headers:
+      accept:
+      - "*/*"
+      accept-encoding:
+      - gzip, deflate
+      connection:
+      - keep-alive
+      content-length:
+      - "268"
+      content-type:
+      - application/x-www-form-urlencoded
+    method: POST
+    uri: https://oauth2.googleapis.com/token
+  response:
+    headers:
+      alt-svc:
+      - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
+      cache-control:
+      - no-cache, no-store, max-age=0, must-revalidate
+      content-length:
+      - "1419"
+      content-type:
+      - application/json; charset=utf-8
+      expires:
+      - Mon, 01 Jan 1990 00:00:00 GMT
+      pragma:
+      - no-cache
+      transfer-encoding:
+      - chunked
+      vary:
+      - Origin
+      - X-Origin
+      - Referer
+    parsed_body:
+      access_token: scrubbed
+      expires_in: 3599
+      id_token: eyJhbGciOiJSUzI1NiIsImtpZCI6IjgyMWYzYmM2NmYwNzUxZjc4NDA2MDY3OTliMWFkZjllOWZiNjBkZmIiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI3NjQwODYwNTE4NTAtNnFyNHA2Z3BpNmhuNTA2cHQ4ZWp1cTgzZGkzNDFodXIuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI3NjQwODYwNTE4NTAtNnFyNHA2Z3BpNmhuNTA2cHQ4ZWp1cTgzZGkzNDFodXIuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMDY1Njg0NzQzMTU3NzkyMTI1NTkiLCJoZCI6InB5ZGFudGljLmRldiIsImVtYWlsIjoibWFyY2Vsb0BweWRhbnRpYy5kZXYiLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6ImlyckNRNE00c0Z0Z2dfS2hRTVNjekEiLCJpYXQiOjE3NDM0MTM3NzcsImV4cCI6MTc0MzQxNzM3N30.BAvb4TlcIoYcQODNLFqwtUQoSNJJbpAR0lk2OyFxXK9rSZ7m1e1_Dp1O4ApxPUS7f_NX34eSCuDJN2IXgh8VBv4k3IhI7CbMydYeqXuwlbgOOp1Z0farGEKneU1M7TvdngigAJ9wT-2LHjKd_GEcGau-CUvzXpcT1IOnNNyXGVqtuGmEfcw5jjPkKJNECUheeNHE3zeImatTstOLuKmI1ZK-etl41l3poSNuQkZkrbQ80Vst8BdT-b1tnJnXP1KGATBIamDy99OOiB9a7a9m_ikXYEyN91yR76DYot3hpDPlOX0H9hF-BOSqoOtlSS2TMBkMvFiiYWjID1e_9VlNUg
+      scope: https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/userinfo.email openid https://www.googleapis.com/auth/sqlservice.login
+      token_type: Bearer
+    status:
+      code: 200
+      message: OK
+- request:
+    headers:
+      accept:
+      - "*/*"
+      accept-encoding:
+      - gzip, deflate
+      connection:
+      - keep-alive
+      content-length:
+      - "133"
+      content-type:
+      - application/json
+      host:
+      - us-central1-aiplatform.googleapis.com
+    method: POST
+    parsed_body:
+      contents:
+      - parts:
+        - text: What is the capital of France?
+        role: user
+      labels:
+        environment: test
+        team: analytics
+    uri: https://us-central1-aiplatform.googleapis.com/v1/projects/pydantic-ai/locations/us-central1/publishers/google/models/gemini-2.0-flash:generateContent
+  response:
+    headers:
+      alt-svc:
+      - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
+      content-length:
+      - "759"
+      content-type:
+      - application/json; charset=UTF-8
+      transfer-encoding:
+      - chunked
+      vary:
+      - Origin
+      - X-Origin
+      - Referer
+    parsed_body:
+      candidates:
+      - avgLogprobs: -0.02703852951526642
+        content:
+          parts:
+          - text: |
+              The capital of France is **Paris**.
+          role: model
+        finishReason: STOP
+      createTime: "2025-05-23T07:53:55.494386Z"
+      modelVersion: gemini-2.0-flash
+      responseId: kykwaLKWHti5nvgPmN2T8AE
+      usageMetadata:
+        candidatesTokenCount: 9
+        candidatesTokensDetails:
+        - modality: TEXT
+          tokenCount: 9
+        promptTokenCount: 7
+        promptTokensDetails:
+        - modality: TEXT
+          tokenCount: 7
+        totalTokenCount: 16
+        trafficType: ON_DEMAND
+    status:
+      code: 200
+      message: OK
+version: 1
Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
+interactions:
+- request:
+    body: grant_type=%5B%27refresh_token%27%5D&client_id=%5B%27764086051850-6qr4p6gpi6hn506pt8ejuq83di341hur.apps.googleusercontent.com%27%5D&client_secret=%5B%27scrubbed%27%5D&refresh_token=%5B%27scrubbed%27%5D
+    headers:
+      accept:
+      - "*/*"
+      accept-encoding:
+      - gzip, deflate
+      connection:
+      - keep-alive
+      content-length:
+      - "268"
+      content-type:
+      - application/x-www-form-urlencoded
+    method: POST
+    uri: https://oauth2.googleapis.com/token
+  response:
+    headers:
+      alt-svc:
+      - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
+      cache-control:
+      - no-cache, no-store, max-age=0, must-revalidate
+      content-length:
+      - "1420"
+      content-type:
+      - application/json; charset=utf-8
+      expires:
+      - Mon, 01 Jan 1990 00:00:00 GMT
+      pragma:
+      - no-cache
+      transfer-encoding:
+      - chunked
+      vary:
+      - Origin
+      - X-Origin
+      - Referer
+    parsed_body:
+      access_token: scrubbed
+      expires_in: 3599
+      id_token: eyJhbGciOiJSUzI1NiIsImtpZCI6IjY2MGVmM2I5Nzg0YmRmNTZlYmU4NTlmNTc3ZjdmYjJlOGMxY2VmZmIiLCJ0eXAiOiJKV1QifQ.eyJpc3MiOiJodHRwczovL2FjY291bnRzLmdvb2dsZS5jb20iLCJhenAiOiI3NjQwODYwNTE4NTAtNnFyNHA2Z3BpNmhuNTA2cHQ4ZWp1cTgzZGkzNDFodXIuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJhdWQiOiI3NjQwODYwNTE4NTAtNnFyNHA2Z3BpNmhuNTA2cHQ4ZWp1cTgzZGkzNDFodXIuYXBwcy5nb29nbGV1c2VyY29udGVudC5jb20iLCJzdWIiOiIxMDY1Njg0NzQzMTU3NzkyMTI1NTkiLCJoZCI6InB5ZGFudGljLmRldiIsImVtYWlsIjoibWFyY2Vsb0BweWRhbnRpYy5kZXYiLCJlbWFpbF92ZXJpZmllZCI6dHJ1ZSwiYXRfaGFzaCI6Ii1CeV9XOWwtRHg1ekg0YTVOV25fV3ciLCJpYXQiOjE3NDc1NzQxOTEsImV4cCI6MTc0NzU3Nzc5MX0.dHg3qRlYoQ8WyIml7-kGqsuefvkl5deuZ0yTQM-RvKuuqtF_t6p8TrWbndEuSbZpRn9JhVPnsoEAYVPexbGy-pon4gu1aHH0dJNq3ghhdim7qp5JWpegLaZqvNvELvEHjj2VNLWXQ70-5wEaI_HCtAWTjlROAHQxvoWHJAdeH0Yf9zoljEBQvx3VLDLEpdCcMd-UGNCBucpQlFHcCJs5Qq8yj8R8f27BCEmRo7z9K3Axuedj_wcJ_tWV1x1tWxojUloJaKsIfztFOPFxzOdNPOlTHXsE47d4v43v87a8LhdDGloD72xN_kLapfIqyTIwRTj4cQvQp5H0u7As49fvMA
+      scope: https://www.googleapis.com/auth/userinfo.email openid https://www.googleapis.com/auth/sqlservice.login https://www.googleapis.com/auth/cloud-platform
+      token_type: Bearer
+    status:
+      code: 200
+      message: OK
+- request:
+    headers:
+      accept:
+      - "*/*"
+      accept-encoding:
+      - gzip, deflate
+      connection:
+      - keep-alive
+      content-length:
+      - "257"
+      content-type:
+      - application/json
+      host:
+      - aiplatform.googleapis.com
+    method: POST
+    parsed_body:
+      contents:
+      - parts:
+        - text: What is the capital of France?
+        role: user
+      generationConfig: {}
+      labels:
+        environment: test
+        team: analytics
+      systemInstruction:
+        parts:
+        - text: You are a helpful chatbot.
+        role: user
+    uri: https://aiplatform.googleapis.com/v1beta1/projects/pydantic-ai/locations/global/publishers/google/models/gemini-2.0-flash:generateContent
+  response:
+    headers:
+      alt-svc:
+      - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
+      content-length:
+      - "759"
+      content-type:
+      - application/json; charset=UTF-8
+      transfer-encoding:
+      - chunked
+      vary:
+      - Origin
+      - X-Origin
+      - Referer
+    parsed_body:
+      candidates:
+      - avgLogprobs: -0.0005532301729544997
+        content:
+          parts:
+          - text: |
+              The capital of France is Paris.
+          role: model
+        finishReason: STOP
+      createTime: "2025-05-23T07:09:59.524624Z"
+      modelVersion: gemini-2.0-flash
+      responseId: sN0paKOZFtmtyOgPqMyL6AE
+      usageMetadata:
+        candidatesTokenCount: 8
+        candidatesTokensDetails:
+        - modality: TEXT
+          tokenCount: 8
+        promptTokenCount: 13
+        promptTokensDetails:
+        - modality: TEXT
+          tokenCount: 13
+        totalTokenCount: 21
+        trafficType: ON_DEMAND
+    status:
+      code: 200
+      message: OK
+version: 1
