
Commit 6bcc1a8

kiqaps and Kludex committed
Enhance Gemini usage tracking to collect comprehensive token data (#1752)
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
1 parent c8bb611 commit 6bcc1a8

File tree: 4 files changed, +143 −32 lines

pydantic_ai_slim/pydantic_ai/models/gemini.py

Lines changed: 40 additions & 4 deletions
@@ -464,13 +464,12 @@ async def _get_gemini_responses(self) -> AsyncIterator[_GeminiResponse]:
             responses_to_yield = gemini_responses[:-1]
             for r in responses_to_yield[current_gemini_response_index:]:
                 current_gemini_response_index += 1
-                self._usage += _metadata_as_usage(r)
                 yield r

         # Now yield the final response, which should be complete
         if gemini_responses:  # pragma: no branch
             r = gemini_responses[-1]
-            self._usage += _metadata_as_usage(r)
+            self._usage = _metadata_as_usage(r)
             yield r

     @property
@@ -771,8 +770,17 @@ class _GeminiCandidates(TypedDict):
     safety_ratings: NotRequired[Annotated[list[_GeminiSafetyRating], pydantic.Field(alias='safetyRatings')]]


+class _GeminiModalityTokenCount(TypedDict):
+    """See <https://ai.google.dev/api/generate-content#modalitytokencount>."""
+
+    modality: Annotated[
+        Literal['MODALITY_UNSPECIFIED', 'TEXT', 'IMAGE', 'VIDEO', 'AUDIO', 'DOCUMENT'], pydantic.Field(alias='modality')
+    ]
+    token_count: Annotated[int, pydantic.Field(alias='tokenCount', default=0)]
+
+
 class _GeminiUsageMetaData(TypedDict, total=False):
-    """See <https://ai.google.dev/api/generate-content#FinishReason>.
+    """See <https://ai.google.dev/api/generate-content#UsageMetadata>.

     The docs suggest all fields are required, but some are actually not required, so we assume they are all optional.
     """
@@ -781,6 +789,20 @@ class _GeminiUsageMetaData(TypedDict, total=False):
     candidates_token_count: NotRequired[Annotated[int, pydantic.Field(alias='candidatesTokenCount')]]
     total_token_count: Annotated[int, pydantic.Field(alias='totalTokenCount')]
     cached_content_token_count: NotRequired[Annotated[int, pydantic.Field(alias='cachedContentTokenCount')]]
+    thoughts_token_count: NotRequired[Annotated[int, pydantic.Field(alias='thoughtsTokenCount')]]
+    tool_use_prompt_token_count: NotRequired[Annotated[int, pydantic.Field(alias='toolUsePromptTokenCount')]]
+    prompt_tokens_details: NotRequired[
+        Annotated[list[_GeminiModalityTokenCount], pydantic.Field(alias='promptTokensDetails')]
+    ]
+    cache_tokens_details: NotRequired[
+        Annotated[list[_GeminiModalityTokenCount], pydantic.Field(alias='cacheTokensDetails')]
+    ]
+    candidates_tokens_details: NotRequired[
+        Annotated[list[_GeminiModalityTokenCount], pydantic.Field(alias='candidatesTokensDetails')]
+    ]
+    tool_use_prompt_tokens_details: NotRequired[
+        Annotated[list[_GeminiModalityTokenCount], pydantic.Field(alias='toolUsePromptTokensDetails')]
+    ]


 def _metadata_as_usage(response: _GeminiResponse) -> usage.Usage:
@@ -789,7 +811,21 @@ def _metadata_as_usage(response: _GeminiResponse) -> usage.Usage:
         return usage.Usage()  # pragma: no cover
     details: dict[str, int] = {}
     if cached_content_token_count := metadata.get('cached_content_token_count'):
-        details['cached_content_token_count'] = cached_content_token_count  # pragma: no cover
+        details['cached_content_tokens'] = cached_content_token_count  # pragma: no cover
+
+    if thoughts_token_count := metadata.get('thoughts_token_count'):
+        details['thoughts_tokens'] = thoughts_token_count
+
+    if tool_use_prompt_token_count := metadata.get('tool_use_prompt_token_count'):
+        details['tool_use_prompt_tokens'] = tool_use_prompt_token_count  # pragma: no cover
+
+    for key, metadata_details in metadata.items():
+        if key.endswith('_details') and metadata_details:
+            metadata_details = cast(list[_GeminiModalityTokenCount], metadata_details)
+            suffix = key.removesuffix('_details')
+            for detail in metadata_details:
+                details[f'{detail["modality"].lower()}_{suffix}'] = detail['token_count']
+
     return usage.Usage(
         request_tokens=metadata.get('prompt_token_count', 0),
         response_tokens=metadata.get('candidates_token_count', 0),

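Editor's note: the `*_details` flattening above maps each Gemini modality entry to a `{modality}_{suffix}` counter in `Usage.details`. A minimal standalone sketch of that mapping (not code from the commit; sample values taken from the updated test snapshots below):

# Sketch of the detail-flattening logic; `metadata` mirrors a parsed
# _GeminiUsageMetaData dict, with sample numbers from the test snapshots.
metadata = {
    'prompt_token_count': 38,
    'candidates_token_count': 28,
    'total_token_count': 427,
    'thoughts_token_count': 361,
    'prompt_tokens_details': [{'modality': 'TEXT', 'token_count': 38}],
}

details: dict[str, int] = {}
if thoughts_token_count := metadata.get('thoughts_token_count'):
    details['thoughts_tokens'] = thoughts_token_count

for key, value in metadata.items():
    if key.endswith('_details') and value:
        suffix = key.removesuffix('_details')  # 'prompt_tokens_details' -> 'prompt_tokens'
        for detail in value:
            # 'TEXT' + 'prompt_tokens' -> 'text_prompt_tokens'
            details[f'{detail["modality"].lower()}_{suffix}'] = detail['token_count']

assert details == {'thoughts_tokens': 361, 'text_prompt_tokens': 38}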
pydantic_ai_slim/pydantic_ai/models/google.py

Lines changed: 22 additions & 11 deletions
@@ -410,7 +410,7 @@ class GeminiStreamedResponse(StreamedResponse):

     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
         async for chunk in self._response:
-            self._usage += _metadata_as_usage(chunk)
+            self._usage = _metadata_as_usage(chunk)

             assert chunk.candidates is not None
             candidate = chunk.candidates[0]
@@ -501,17 +501,28 @@ def _metadata_as_usage(response: GenerateContentResponse) -> usage.Usage:
     metadata = response.usage_metadata
     if metadata is None:
         return usage.Usage()  # pragma: no cover
-    # TODO(Marcelo): We exclude the `prompt_tokens_details` and `candidate_token_details` fields because on
-    # `usage.Usage.incr`, it will try to sum non-integer values with integers, which will fail. We should probably
-    # handle this in the `Usage` class.
-    details = metadata.model_dump(
-        exclude={'prompt_tokens_details', 'candidates_tokens_details', 'traffic_type'},
-        exclude_defaults=True,
-    )
+    metadata = metadata.model_dump(exclude_defaults=True)
+
+    details: dict[str, int] = {}
+    if cached_content_token_count := metadata.get('cached_content_token_count'):
+        details['cached_content_tokens'] = cached_content_token_count  # pragma: no cover
+
+    if thoughts_token_count := metadata.get('thoughts_token_count'):
+        details['thoughts_tokens'] = thoughts_token_count
+
+    if tool_use_prompt_token_count := metadata.get('tool_use_prompt_token_count'):
+        details['tool_use_prompt_tokens'] = tool_use_prompt_token_count  # pragma: no cover
+
+    for key, metadata_details in metadata.items():
+        if key.endswith('_details') and metadata_details:
+            suffix = key.removesuffix('_details')
+            for detail in metadata_details:
+                details[f'{detail["modality"].lower()}_{suffix}'] = detail['token_count']
+
     return usage.Usage(
-        request_tokens=details.pop('prompt_token_count', 0),
-        response_tokens=details.pop('candidates_token_count', 0),
-        total_tokens=details.pop('total_token_count', 0),
+        request_tokens=metadata.get('prompt_token_count', 0),
+        response_tokens=metadata.get('candidates_token_count', 0),
+        total_tokens=metadata.get('total_token_count', 0),
         details=details,
     )

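Editor's note: both backends now assign the streamed usage (`self._usage = _metadata_as_usage(chunk)`) instead of accumulating with `+=`. As the halved stream-test snapshots below suggest, Gemini's per-chunk `usage_metadata` is cumulative for the whole response, so summing chunks double-counts tokens. A hypothetical illustration with made-up chunk values:

# Illustrative only: two streamed chunks whose usage metadata is cumulative.
chunk_usage = [
    {'prompt_token_count': 1, 'candidates_token_count': 1, 'total_token_count': 2},
    {'prompt_token_count': 1, 'candidates_token_count': 2, 'total_token_count': 3},  # final chunk
]

# Old behaviour: summing over chunks double-counts.
summed = {k: sum(u[k] for u in chunk_usage) for k in chunk_usage[0]}
assert summed == {'prompt_token_count': 2, 'candidates_token_count': 3, 'total_token_count': 5}

# New behaviour: keep only the latest (cumulative) snapshot.
latest = chunk_usage[-1]
assert latest == {'prompt_token_count': 1, 'candidates_token_count': 2, 'total_token_count': 3}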
tests/models/test_gemini.py

Lines changed: 27 additions & 9 deletions
@@ -739,12 +739,12 @@ async def test_stream_text(get_gemini_client: GetGeminiClient):
             'Hello world',
         ]
     )
-    assert result.usage() == snapshot(Usage(requests=1, request_tokens=2, response_tokens=4, total_tokens=6))
+    assert result.usage() == snapshot(Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3))

     async with agent.run_stream('Hello') as result:
         chunks = [chunk async for chunk in result.stream_text(delta=True, debounce_by=None)]
         assert chunks == snapshot(['Hello ', 'world'])
-        assert result.usage() == snapshot(Usage(requests=1, request_tokens=2, response_tokens=4, total_tokens=6))
+        assert result.usage() == snapshot(Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3))


 async def test_stream_invalid_unicode_text(get_gemini_client: GetGeminiClient):
@@ -776,7 +776,7 @@ async def test_stream_invalid_unicode_text(get_gemini_client: GetGeminiClient):
     async with agent.run_stream('Hello') as result:
         chunks = [chunk async for chunk in result.stream(debounce_by=None)]
         assert chunks == snapshot(['abc', 'abc€def', 'abc€def'])
-        assert result.usage() == snapshot(Usage(requests=1, request_tokens=2, response_tokens=4, total_tokens=6))
+        assert result.usage() == snapshot(Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3))


 async def test_stream_text_no_data(get_gemini_client: GetGeminiClient):
@@ -847,7 +847,7 @@ async def bar(y: str) -> str:
     async with agent.run_stream('Hello') as result:
         response = await result.get_output()
         assert response == snapshot((1, 2))
-        assert result.usage() == snapshot(Usage(requests=2, request_tokens=3, response_tokens=6, total_tokens=9))
+        assert result.usage() == snapshot(Usage(requests=2, request_tokens=2, response_tokens=4, total_tokens=6))
         assert result.all_messages() == snapshot(
             [
                 ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
@@ -856,7 +856,7 @@ async def bar(y: str) -> str:
                         ToolCallPart(tool_name='foo', args={'x': 'a'}, tool_call_id=IsStr()),
                         ToolCallPart(tool_name='bar', args={'y': 'b'}, tool_call_id=IsStr()),
                     ],
-                    usage=Usage(request_tokens=2, response_tokens=4, total_tokens=6),
+                    usage=Usage(request_tokens=1, response_tokens=2, total_tokens=3, details={}),
                     model_name='gemini-1.5-flash',
                     timestamp=IsNow(tz=timezone.utc),
                 ),
@@ -872,7 +872,7 @@ async def bar(y: str) -> str:
                 ),
                 ModelResponse(
                     parts=[ToolCallPart(tool_name='final_result', args={'response': [1, 2]}, tool_call_id=IsStr())],
-                    usage=Usage(request_tokens=1, response_tokens=2, total_tokens=3),
+                    usage=Usage(request_tokens=1, response_tokens=2, total_tokens=3, details={}),
                     model_name='gemini-1.5-flash',
                     timestamp=IsNow(tz=timezone.utc),
                 ),
@@ -1103,7 +1103,13 @@ async def get_image() -> BinaryContent:
                     ),
                     ToolCallPart(tool_name='get_image', args={}, tool_call_id=IsStr()),
                 ],
-                usage=Usage(requests=1, request_tokens=38, response_tokens=28, total_tokens=427, details={}),
+                usage=Usage(
+                    requests=1,
+                    request_tokens=38,
+                    response_tokens=28,
+                    total_tokens=427,
+                    details={'thoughts_tokens': 361, 'text_prompt_tokens': 38},
+                ),
                 model_name='gemini-2.5-pro-preview-03-25',
                 timestamp=IsDatetime(),
                 vendor_details={'finish_reason': 'STOP'},
@@ -1127,7 +1133,13 @@ async def get_image() -> BinaryContent:
            ),
            ModelResponse(
                parts=[TextPart(content='The image shows a kiwi fruit, sliced in half.')],
-               usage=Usage(requests=1, request_tokens=360, response_tokens=11, total_tokens=572, details={}),
+               usage=Usage(
+                   requests=1,
+                   request_tokens=360,
+                   response_tokens=11,
+                   total_tokens=572,
+                   details={'thoughts_tokens': 201, 'text_prompt_tokens': 102, 'image_prompt_tokens': 258},
+               ),
                model_name='gemini-2.5-pro-preview-03-25',
                timestamp=IsDatetime(),
                vendor_details={'finish_reason': 'STOP'},
@@ -1250,7 +1262,13 @@ async def test_gemini_model_instructions(allow_model_requests: None, gemini_api_
            ),
            ModelResponse(
                parts=[TextPart(content='The capital of France is Paris.\n')],
-               usage=Usage(requests=1, request_tokens=13, response_tokens=8, total_tokens=21, details={}),
+               usage=Usage(
+                   requests=1,
+                   request_tokens=13,
+                   response_tokens=8,
+                   total_tokens=21,
+                   details={'text_prompt_tokens': 13, 'text_candidates_tokens': 8},
+               ),
                model_name='gemini-1.5-flash',
                timestamp=IsDatetime(),
                vendor_details={'finish_reason': 'STOP'},

tests/models/test_google.py

Lines changed: 54 additions & 8 deletions
@@ -66,7 +66,15 @@ async def test_google_model(allow_model_requests: None, google_provider: GoogleP

     result = await agent.run('Hello!')
     assert result.output == snapshot('Hello there! How can I help you today?\n')
-    assert result.usage() == snapshot(Usage(requests=1, request_tokens=7, response_tokens=11, total_tokens=18))
+    assert result.usage() == snapshot(
+        Usage(
+            requests=1,
+            request_tokens=7,
+            response_tokens=11,
+            total_tokens=18,
+            details={'text_prompt_tokens': 7, 'text_candidates_tokens': 11},
+        )
+    )
     assert result.all_messages() == snapshot(
         [
             ModelRequest(
@@ -83,7 +91,13 @@ async def test_google_model(allow_model_requests: None, google_provider: GoogleP
             ),
             ModelResponse(
                 parts=[TextPart(content='Hello there! How can I help you today?\n')],
-                usage=Usage(requests=1, request_tokens=7, response_tokens=11, total_tokens=18, details={}),
+                usage=Usage(
+                    requests=1,
+                    request_tokens=7,
+                    response_tokens=11,
+                    total_tokens=18,
+                    details={'text_prompt_tokens': 7, 'text_candidates_tokens': 11},
+                ),
                 model_name='gemini-1.5-flash',
                 timestamp=IsDatetime(),
                 vendor_details={'finish_reason': 'STOP'},
@@ -116,7 +130,15 @@ async def temperature(city: str, date: datetime.date) -> str:

     result = await agent.run('What was the temperature in London 1st January 2022?', output_type=Response)
     assert result.output == snapshot({'temperature': '30°C', 'date': datetime.date(2022, 1, 1), 'city': 'London'})
-    assert result.usage() == snapshot(Usage(requests=2, request_tokens=224, response_tokens=35, total_tokens=259))
+    assert result.usage() == snapshot(
+        Usage(
+            requests=2,
+            request_tokens=224,
+            response_tokens=35,
+            total_tokens=259,
+            details={'text_prompt_tokens': 224, 'text_candidates_tokens': 35},
+        )
+    )
     assert result.all_messages() == snapshot(
         [
             ModelRequest(
@@ -137,7 +159,13 @@ async def temperature(city: str, date: datetime.date) -> str:
                         tool_name='temperature', args={'date': '2022-01-01', 'city': 'London'}, tool_call_id=IsStr()
                     )
                 ],
-                usage=Usage(requests=1, request_tokens=101, response_tokens=14, total_tokens=115, details={}),
+                usage=Usage(
+                    requests=1,
+                    request_tokens=101,
+                    response_tokens=14,
+                    total_tokens=115,
+                    details={'text_prompt_tokens': 101, 'text_candidates_tokens': 14},
+                ),
                 model_name='gemini-1.5-flash',
                 timestamp=IsDatetime(),
                 vendor_details={'finish_reason': 'STOP'},
@@ -157,7 +185,13 @@ async def temperature(city: str, date: datetime.date) -> str:
                         tool_call_id=IsStr(),
                     )
                 ],
-                usage=Usage(requests=1, request_tokens=123, response_tokens=21, total_tokens=144, details={}),
+                usage=Usage(
+                    requests=1,
+                    request_tokens=123,
+                    response_tokens=21,
+                    total_tokens=144,
+                    details={'text_prompt_tokens': 123, 'text_candidates_tokens': 21},
+                ),
                 model_name='gemini-1.5-flash',
                 timestamp=IsDatetime(),
                 vendor_details={'finish_reason': 'STOP'},
@@ -215,7 +249,7 @@ async def get_capital(country: str) -> str:
                     request_tokens=57,
                    response_tokens=15,
                    total_tokens=173,
-                   details={'thoughts_token_count': 101},
+                   details={'thoughts_tokens': 101, 'text_prompt_tokens': 57},
                ),
                model_name='models/gemini-2.5-pro-preview-05-06',
                timestamp=IsDatetime(),
@@ -237,7 +271,13 @@ async def get_capital(country: str) -> str:
                        content='I am sorry, I cannot fulfill this request. The country you provided is not supported.'
                    )
                ],
-               usage=Usage(requests=1, request_tokens=104, response_tokens=18, total_tokens=122, details={}),
+               usage=Usage(
+                   requests=1,
+                   request_tokens=104,
+                   response_tokens=18,
+                   total_tokens=122,
+                   details={'text_prompt_tokens': 104},
+               ),
                model_name='models/gemini-2.5-pro-preview-05-06',
                timestamp=IsDatetime(),
                vendor_details={'finish_reason': 'STOP'},
@@ -494,7 +534,13 @@ def instructions() -> str:
            ),
            ModelResponse(
                parts=[TextPart(content='The capital of France is Paris.\n')],
-               usage=Usage(requests=1, request_tokens=13, response_tokens=8, total_tokens=21, details={}),
+               usage=Usage(
+                   requests=1,
+                   request_tokens=13,
+                   response_tokens=8,
+                   total_tokens=21,
+                   details={'text_prompt_tokens': 13, 'text_candidates_tokens': 8},
+               ),
                model_name='gemini-2.0-flash',
                timestamp=IsDatetime(),
                vendor_details={'finish_reason': 'STOP'},

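Editor's note: taken together, after this commit a Gemini run surfaces reasoning and per-modality token counts through `Usage.details`. A rough usage sketch, assuming a configured Google API key; the model string and printed values are illustrative, not from the commit:

from pydantic_ai import Agent

agent = Agent('google-gla:gemini-2.0-flash')
result = agent.run_sync('What is the capital of France?')
print(result.usage().details)
# e.g. {'text_prompt_tokens': 13, 'text_candidates_tokens': 8}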