Skip to content

Commit 4e3769a

Browse files
davide-andreoli authored and Kludex committed
Add vendor_id and finish_reason to Gemini/Google model responses (#1800)
1 parent cb4e539 commit 4e3769a

File tree

5 files changed

+73
-8
lines changed

5 files changed

+73
-8
lines changed

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -620,7 +620,7 @@ class ModelResponse:
620620
kind: Literal['response'] = 'response'
621621
"""Message type identifier, this is available on all parts as a discriminator."""
622622

623-
vendor_details: dict[str, Any] | None = field(default=None, repr=False)
623+
vendor_details: dict[str, Any] | None = field(default=None)
624624
"""Additional vendor-specific details in a serializable format.
625625
626626
This allows storing selected vendor-specific data that isn't mapped to standard ModelResponse fields.

pydantic_ai_slim/pydantic_ai/models/gemini.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,8 @@ async def _make_request(
260260
yield r
261261

262262
def _process_response(self, response: _GeminiResponse) -> ModelResponse:
263+
vendor_details: dict[str, Any] | None = None
264+
263265
if len(response['candidates']) != 1:
264266
raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response') # pragma: no cover
265267
if 'content' not in response['candidates'][0]:
@@ -270,9 +272,19 @@ def _process_response(self, response: _GeminiResponse) -> ModelResponse:
270272
'Content field missing from Gemini response', str(response)
271273
)
272274
parts = response['candidates'][0]['content']['parts']
275+
vendor_id = response.get('vendor_id', None)
276+
finish_reason = response['candidates'][0].get('finish_reason')
277+
if finish_reason:
278+
vendor_details = {'finish_reason': finish_reason}
273279
usage = _metadata_as_usage(response)
274280
usage.requests = 1
275-
return _process_response_from_parts(parts, response.get('model_version', self._model_name), usage)
281+
return _process_response_from_parts(
282+
parts,
283+
response.get('model_version', self._model_name),
284+
usage,
285+
vendor_id=vendor_id,
286+
vendor_details=vendor_details,
287+
)
276288

277289
async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse:
278290
"""Process a streamed response, and prepare a streaming response to return."""
@@ -612,7 +624,11 @@ def _function_call_part_from_call(tool: ToolCallPart) -> _GeminiFunctionCallPart
612624

613625

614626
def _process_response_from_parts(
615-
parts: Sequence[_GeminiPartUnion], model_name: GeminiModelName, usage: usage.Usage
627+
parts: Sequence[_GeminiPartUnion],
628+
model_name: GeminiModelName,
629+
usage: usage.Usage,
630+
vendor_id: str | None,
631+
vendor_details: dict[str, Any] | None = None,
616632
) -> ModelResponse:
617633
items: list[ModelResponsePart] = []
618634
for part in parts:
@@ -624,7 +640,9 @@ def _process_response_from_parts(
624640
raise UnexpectedModelBehavior(
625641
f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
626642
)
627-
return ModelResponse(parts=items, usage=usage, model_name=model_name)
643+
return ModelResponse(
644+
parts=items, usage=usage, model_name=model_name, vendor_id=vendor_id, vendor_details=vendor_details
645+
)
628646

629647

630648
class _GeminiFunctionCall(TypedDict):
@@ -736,6 +754,7 @@ class _GeminiResponse(TypedDict):
736754
usage_metadata: NotRequired[Annotated[_GeminiUsageMetaData, pydantic.Field(alias='usageMetadata')]]
737755
prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, pydantic.Field(alias='promptFeedback')]]
738756
model_version: NotRequired[Annotated[str, pydantic.Field(alias='modelVersion')]]
757+
vendor_id: NotRequired[Annotated[str, pydantic.Field(alias='responseId')]]
739758

740759

741760
class _GeminiCandidates(TypedDict):

pydantic_ai_slim/pydantic_ai/models/google.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from contextlib import asynccontextmanager
77
from dataclasses import dataclass, field, replace
88
from datetime import datetime
9-
from typing import Literal, Union, cast, overload
9+
from typing import Any, Literal, Union, cast, overload
1010
from uuid import uuid4
1111

1212
from typing_extensions import assert_never
@@ -302,9 +302,16 @@ def _process_response(self, response: GenerateContentResponse) -> ModelResponse:
302302
'Content field missing from Gemini response', str(response)
303303
) # pragma: no cover
304304
parts = response.candidates[0].content.parts or []
305+
vendor_id = response.response_id or None
306+
vendor_details: dict[str, Any] | None = None
307+
finish_reason = response.candidates[0].finish_reason
308+
if finish_reason: # pragma: no branch
309+
vendor_details = {'finish_reason': finish_reason.value}
305310
usage = _metadata_as_usage(response)
306311
usage.requests = 1
307-
return _process_response_from_parts(parts, response.model_version or self._model_name, usage)
312+
return _process_response_from_parts(
313+
parts, response.model_version or self._model_name, usage, vendor_id=vendor_id, vendor_details=vendor_details
314+
)
308315

309316
async def _process_streamed_response(self, response: AsyncIterator[GenerateContentResponse]) -> StreamedResponse:
310317
"""Process a streamed response, and prepare a streaming response to return."""
@@ -450,7 +457,13 @@ def _content_model_response(m: ModelResponse) -> ContentDict:
450457
return ContentDict(role='model', parts=parts)
451458

452459

453-
def _process_response_from_parts(parts: list[Part], model_name: GoogleModelName, usage: usage.Usage) -> ModelResponse:
460+
def _process_response_from_parts(
461+
parts: list[Part],
462+
model_name: GoogleModelName,
463+
usage: usage.Usage,
464+
vendor_id: str | None,
465+
vendor_details: dict[str, Any] | None = None,
466+
) -> ModelResponse:
454467
items: list[ModelResponsePart] = []
455468
for part in parts:
456469
if part.text:
@@ -465,7 +478,9 @@ def _process_response_from_parts(parts: list[Part], model_name: GoogleModelName,
465478
raise UnexpectedModelBehavior(
466479
f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
467480
)
468-
return ModelResponse(parts=items, model_name=model_name, usage=usage)
481+
return ModelResponse(
482+
parts=items, model_name=model_name, usage=usage, vendor_id=vendor_id, vendor_details=vendor_details
483+
)
469484

470485

471486
def _function_declaration_from_tool(tool: ToolDefinition) -> FunctionDeclarationDict:

tests/models/test_gemini.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -540,6 +540,7 @@ async def test_text_success(get_gemini_client: GetGeminiClient):
540540
usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}),
541541
model_name='gemini-1.5-flash-123',
542542
timestamp=IsNow(tz=timezone.utc),
543+
vendor_details={'finish_reason': 'STOP'},
543544
),
544545
]
545546
)
@@ -555,13 +556,15 @@ async def test_text_success(get_gemini_client: GetGeminiClient):
555556
usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}),
556557
model_name='gemini-1.5-flash-123',
557558
timestamp=IsNow(tz=timezone.utc),
559+
vendor_details={'finish_reason': 'STOP'},
558560
),
559561
ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
560562
ModelResponse(
561563
parts=[TextPart(content='Hello world')],
562564
usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}),
563565
model_name='gemini-1.5-flash-123',
564566
timestamp=IsNow(tz=timezone.utc),
567+
vendor_details={'finish_reason': 'STOP'},
565568
),
566569
]
567570
)
@@ -585,6 +588,7 @@ async def test_request_structured_response(get_gemini_client: GetGeminiClient):
585588
usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}),
586589
model_name='gemini-1.5-flash-123',
587590
timestamp=IsNow(tz=timezone.utc),
591+
vendor_details={'finish_reason': 'STOP'},
588592
),
589593
ModelRequest(
590594
parts=[
@@ -647,6 +651,7 @@ async def get_location(loc_name: str) -> str:
647651
usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}),
648652
model_name='gemini-1.5-flash-123',
649653
timestamp=IsNow(tz=timezone.utc),
654+
vendor_details={'finish_reason': 'STOP'},
650655
),
651656
ModelRequest(
652657
parts=[
@@ -666,6 +671,7 @@ async def get_location(loc_name: str) -> str:
666671
usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}),
667672
model_name='gemini-1.5-flash-123',
668673
timestamp=IsNow(tz=timezone.utc),
674+
vendor_details={'finish_reason': 'STOP'},
669675
),
670676
ModelRequest(
671677
parts=[
@@ -688,6 +694,7 @@ async def get_location(loc_name: str) -> str:
688694
usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}),
689695
model_name='gemini-1.5-flash-123',
690696
timestamp=IsNow(tz=timezone.utc),
697+
vendor_details={'finish_reason': 'STOP'},
691698
),
692699
]
693700
)
@@ -1099,6 +1106,7 @@ async def get_image() -> BinaryContent:
10991106
usage=Usage(requests=1, request_tokens=38, response_tokens=28, total_tokens=427, details={}),
11001107
model_name='gemini-2.5-pro-preview-03-25',
11011108
timestamp=IsDatetime(),
1109+
vendor_details={'finish_reason': 'STOP'},
11021110
),
11031111
ModelRequest(
11041112
parts=[
@@ -1122,6 +1130,7 @@ async def get_image() -> BinaryContent:
11221130
usage=Usage(requests=1, request_tokens=360, response_tokens=11, total_tokens=572, details={}),
11231131
model_name='gemini-2.5-pro-preview-03-25',
11241132
timestamp=IsDatetime(),
1133+
vendor_details={'finish_reason': 'STOP'},
11251134
),
11261135
]
11271136
)
@@ -1244,6 +1253,7 @@ async def test_gemini_model_instructions(allow_model_requests: None, gemini_api_
12441253
usage=Usage(requests=1, request_tokens=13, response_tokens=8, total_tokens=21, details={}),
12451254
model_name='gemini-1.5-flash',
12461255
timestamp=IsDatetime(),
1256+
vendor_details={'finish_reason': 'STOP'},
12471257
),
12481258
]
12491259
)
@@ -1284,3 +1294,18 @@ async def get_temperature(location: dict[str, CurrentLocation]) -> float: # pra
12841294
assert result.output == snapshot(
12851295
'I need a location dictionary to use the `get_temperature` function. I cannot provide the temperature in Tokyo without more information.\n'
12861296
)
1297+
1298+
1299+
async def test_gemini_no_finish_reason(get_gemini_client: GetGeminiClient):
1300+
response = gemini_response(
1301+
_content_model_response(ModelResponse(parts=[TextPart('Hello world')])), finish_reason=None
1302+
)
1303+
gemini_client = get_gemini_client(response)
1304+
m = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(http_client=gemini_client))
1305+
agent = Agent(m)
1306+
1307+
result = await agent.run('Hello World')
1308+
1309+
for message in result.all_messages():
1310+
if isinstance(message, ModelResponse):
1311+
assert message.vendor_details is None

tests/models/test_google.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ async def test_google_model(allow_model_requests: None, google_provider: GoogleP
8686
usage=Usage(requests=1, request_tokens=7, response_tokens=11, total_tokens=18, details={}),
8787
model_name='gemini-1.5-flash',
8888
timestamp=IsDatetime(),
89+
vendor_details={'finish_reason': 'STOP'},
8990
),
9091
]
9192
)
@@ -139,6 +140,7 @@ async def temperature(city: str, date: datetime.date) -> str:
139140
usage=Usage(requests=1, request_tokens=101, response_tokens=14, total_tokens=115, details={}),
140141
model_name='gemini-1.5-flash',
141142
timestamp=IsDatetime(),
143+
vendor_details={'finish_reason': 'STOP'},
142144
),
143145
ModelRequest(
144146
parts=[
@@ -158,6 +160,7 @@ async def temperature(city: str, date: datetime.date) -> str:
158160
usage=Usage(requests=1, request_tokens=123, response_tokens=21, total_tokens=144, details={}),
159161
model_name='gemini-1.5-flash',
160162
timestamp=IsDatetime(),
163+
vendor_details={'finish_reason': 'STOP'},
161164
),
162165
ModelRequest(
163166
parts=[
@@ -216,6 +219,7 @@ async def get_capital(country: str) -> str:
216219
),
217220
model_name='models/gemini-2.5-pro-preview-05-06',
218221
timestamp=IsDatetime(),
222+
vendor_details={'finish_reason': 'STOP'},
219223
),
220224
ModelRequest(
221225
parts=[
@@ -236,6 +240,7 @@ async def get_capital(country: str) -> str:
236240
usage=Usage(requests=1, request_tokens=104, response_tokens=18, total_tokens=122, details={}),
237241
model_name='models/gemini-2.5-pro-preview-05-06',
238242
timestamp=IsDatetime(),
243+
vendor_details={'finish_reason': 'STOP'},
239244
),
240245
]
241246
)
@@ -492,6 +497,7 @@ def instructions() -> str:
492497
usage=Usage(requests=1, request_tokens=13, response_tokens=8, total_tokens=21, details={}),
493498
model_name='gemini-2.0-flash',
494499
timestamp=IsDatetime(),
500+
vendor_details={'finish_reason': 'STOP'},
495501
),
496502
]
497503
)

0 commit comments

Comments (0)