Adding CountToken to Gemini #2137


Open · wants to merge 9 commits into base: main
117 changes: 117 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/gemini.py
@@ -397,6 +397,95 @@ def _map_response_schema(self, o: OutputObjectDefinition) -> dict[str, Any]:

        return response_schema

    async def count_tokens(
        self,
        messages: list[ModelMessage],
        model_settings: GeminiModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> usage.Usage:
        """Count the tokens `messages` would consume, via the Gemini `countTokens` endpoint."""
        check_allow_model_requests()
        async with self._make_count_request(messages, model_settings or {}, model_request_parameters) as http_response:
            data = await http_response.aread()
            response = _gemini_count_tokens_response_ta.validate_json(data)
            return self._process_count_tokens_response(response)

    @asynccontextmanager
    async def _make_count_request(
        self,
        messages: list[ModelMessage],
        model_settings: GeminiModelSettings,
        model_request_parameters: ModelRequestParameters,
    ) -> AsyncIterator[HTTPResponse]:
        tools = self._get_tools(model_request_parameters)
        tool_config = self._get_tool_config(model_request_parameters, tools)
        sys_prompt_parts, contents = await self._message_to_gemini_content(messages)

        request_data = _GeminiCountTokensRequest(contents=contents)
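        # Per the API docs <https://ai.google.dev/api/tokens>, a countTokens request may carry
        # either bare `contents` or a full `generateContentRequest`; the latter is needed for
        # the count to reflect tools, system instructions, and generation config.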

        generation_config = _settings_to_generation_config(model_settings)
        if model_request_parameters.output_mode == 'native':
            if tools:
                raise UserError('Gemini does not support structured output and tools at the same time.')
            generation_config['response_mime_type'] = 'application/json'
            output_object = model_request_parameters.output_object
            assert output_object is not None
            generation_config['response_schema'] = self._map_response_schema(output_object)
        elif model_request_parameters.output_mode == 'prompted' and not tools:
            generation_config['response_mime_type'] = 'application/json'

        if generation_config:
            request_data['generateContentRequest'] = {
                'contents': contents,
                'generationConfig': generation_config,
            }
            if sys_prompt_parts:
                request_data['generateContentRequest']['systemInstruction'] = _GeminiTextContent(
                    role='user', parts=sys_prompt_parts
                )
            if tools is not None:
                request_data['generateContentRequest']['tools'] = tools
            if tool_config is not None:
                request_data['generateContentRequest']['toolConfig'] = tool_config

        headers = {'Content-Type': 'application/json', 'User-Agent': get_user_agent()}
        url = f'/models/{self._model_name}:countTokens'

        request_json = _gemini_count_tokens_request_ta.dump_json(request_data, by_alias=True)
        async with self.client.stream(
            'POST',
            url,
            content=request_json,
            headers=headers,
            timeout=model_settings.get('timeout', USE_CLIENT_DEFAULT),
        ) as r:
            if (status_code := r.status_code) != 200:
                await r.aread()
                if status_code >= 400:
                    raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=r.text)
                raise UnexpectedModelBehavior(  # pragma: no cover
                    f'Unexpected response from gemini {status_code}', r.text
                )
            yield r

    def _process_count_tokens_response(self, response: _GeminiCountTokensResponse) -> usage.Usage:
        details: dict[str, int] = {}
        if cached_content_token_count := response.get('cachedContentTokenCount'):
            details['cached_content_tokens'] = cached_content_token_count

        # Flatten per-modality breakdowns, e.g. `promptTokensDetails` -> `text_prompt`, `audio_prompt`, ...
        for key, metadata_details in response.items():
            if key.endswith('TokensDetails') and metadata_details:
                metadata_details = cast(list[_GeminiModalityTokenCount], metadata_details)
                suffix = key.removesuffix('TokensDetails').lower()
                for detail in metadata_details:
                    details[f'{detail["modality"].lower()}_{suffix}'] = detail['token_count']

        return usage.Usage(
            request_tokens=response.get('totalTokens', 0),
            response_tokens=0,  # countTokens does not provide response tokens
            total_tokens=response.get('totalTokens', 0),
            details=details,
        )


def _settings_to_generation_config(model_settings: GeminiModelSettings) -> _GeminiGenerationConfig:
    config: _GeminiGenerationConfig = {}
@@ -809,6 +898,34 @@ class _GeminiResponse(TypedDict):
    vendor_id: NotRequired[Annotated[str, pydantic.Field(alias='responseId')]]


@pydantic.with_config(pydantic.ConfigDict(defer_build=True))
class _GeminiCountTokensRequest(TypedDict):
    """Schema for a countTokens API request to the Gemini API.

    See <https://ai.google.dev/api/tokens#endpoint> for API docs.
    """

    contents: NotRequired[list[_GeminiContent]]
    generateContentRequest: NotRequired[_GeminiRequest]


@pydantic.with_config(pydantic.ConfigDict(defer_build=True))
class _GeminiCountTokensResponse(TypedDict):
    """Schema for the response from the Gemini countTokens API.

    See <https://ai.google.dev/api/tokens#endpoint> for API docs.
    """

    totalTokens: int
    cachedContentTokenCount: NotRequired[int]
    promptTokensDetails: NotRequired[list[_GeminiModalityTokenCount]]
    cacheTokensDetails: NotRequired[list[_GeminiModalityTokenCount]]
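    # Illustrative raw response shape (values invented, field names per the API docs):
    #   {"totalTokens": 10, "promptTokensDetails": [{"modality": "TEXT", "tokenCount": 10}]}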


_gemini_count_tokens_request_ta = pydantic.TypeAdapter(_GeminiCountTokensRequest)
_gemini_count_tokens_response_ta = pydantic.TypeAdapter(_GeminiCountTokensResponse)


class _GeminiCandidates(TypedDict):
"""See <https://ai.google.dev/api/generate-content#v1beta.Candidate>."""

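For reviewers who want to exercise the new endpoint, here is a minimal usage sketch. It assumes a `GEMINI_API_KEY` in the environment, that `ModelRequestParameters()` is constructible with its defaults, and an example model name; `GeminiModel`, `ModelRequest`, and `UserPromptPart` come from pydantic-ai's existing public API.

import asyncio

from pydantic_ai.messages import ModelRequest, UserPromptPart
from pydantic_ai.models import ModelRequestParameters
from pydantic_ai.models.gemini import GeminiModel


async def main() -> None:
    # The default 'google-gla' provider reads GEMINI_API_KEY from the environment.
    model = GeminiModel('gemini-1.5-flash')
    usage = await model.count_tokens(
        [ModelRequest(parts=[UserPromptPart(content='Hello, world!')])],
        None,  # no model settings
        ModelRequestParameters(),  # defaults: no tools, plain text output
    )
    # countTokens only reports prompt-side tokens, so response_tokens is always 0.
    print(usage.total_tokens, usage.details)


asyncio.run(main())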