diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 64008622b..ee6a659b8 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -397,6 +397,95 @@ def _map_response_schema(self, o: OutputObjectDefinition) -> dict[str, Any]: return response_schema + async def count_tokens( + self, + messages: list[ModelMessage], + model_settings: GeminiModelSettings | None, + model_request_parameters: ModelRequestParameters, + ) -> usage.Usage: + check_allow_model_requests() + async with self._make_count_request(messages, model_settings or {}, model_request_parameters) as http_response: + data = await http_response.aread() + response = _gemini_count_tokens_response_ta.validate_json(data) + return self._process_count_tokens_response(response) + + @asynccontextmanager + async def _make_count_request( + self, + messages: list[ModelMessage], + model_settings: GeminiModelSettings, + model_request_parameters: ModelRequestParameters, + ) -> AsyncIterator[HTTPResponse]: + tools = self._get_tools(model_request_parameters) + tool_config = self._get_tool_config(model_request_parameters, tools) + sys_prompt_parts, contents = await self._message_to_gemini_content(messages) + + request_data = _GeminiCountTokensRequest(contents=contents) + + generation_config = _settings_to_generation_config(model_settings) + if model_request_parameters.output_mode == 'native': + if tools: + raise UserError('Gemini does not support structured output and tools at the same time.') + generation_config['response_mime_type'] = 'application/json' + output_object = model_request_parameters.output_object + assert output_object is not None + generation_config['response_schema'] = self._map_response_schema(output_object) + elif model_request_parameters.output_mode == 'prompted' and not tools: + generation_config['response_mime_type'] = 'application/json' + + if generation_config: + request_data['generateContentRequest'] = { + 'contents': contents, + 'generationConfig': generation_config, + } + if sys_prompt_parts: + request_data['generateContentRequest']['systemInstruction'] = _GeminiTextContent( + role='user', parts=sys_prompt_parts + ) + if tools is not None: + request_data['generateContentRequest']['tools'] = tools + if tool_config is not None: + request_data['generateContentRequest']['toolConfig'] = tool_config + + headers = {'Content-Type': 'application/json', 'User-Agent': get_user_agent()} + url = f'/models/{self._model_name}:countTokens' + + request_json = _gemini_count_tokens_request_ta.dump_json(request_data, by_alias=True) + async with self.client.stream( + 'POST', + url, + content=request_json, + headers=headers, + timeout=model_settings.get('timeout', USE_CLIENT_DEFAULT), + ) as r: + if (status_code := r.status_code) != 200: + await r.aread() + if status_code >= 400: + raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=r.text) + raise UnexpectedModelBehavior( # pragma: no cover + f'Unexpected response from gemini {status_code}', r.text + ) + yield r + + def _process_count_tokens_response(self, response: _GeminiCountTokensResponse) -> usage.Usage: + details: dict[str, int] = {} + if cached_content_token_count := response.get('cachedContentTokenCount'): + details['cached_content_tokens'] = cached_content_token_count + + for key, metadata_details in response.items(): + if key.endswith('TokensDetails') and metadata_details: + metadata_details = cast(list[_GeminiModalityTokenCount], metadata_details) + suffix = key.removesuffix('TokensDetails').lower() + for detail in metadata_details: + details[f'{detail["modality"].lower()}_{suffix}'] = detail['token_count'] + + return usage.Usage( + request_tokens=response.get('totalTokens', 0), + response_tokens=0, # countTokens does not provide response tokens + total_tokens=response.get('totalTokens', 0), + details=details, + ) + def _settings_to_generation_config(model_settings: GeminiModelSettings) -> _GeminiGenerationConfig: config: _GeminiGenerationConfig = {} @@ -809,6 +898,34 @@ class _GeminiResponse(TypedDict): vendor_id: NotRequired[Annotated[str, pydantic.Field(alias='responseId')]] +@pydantic.with_config(pydantic.ConfigDict(defer_build=True)) +class _GeminiCountTokensRequest(TypedDict): + """Schema for a countTokens API request to the Gemini API. + + See for API docs. + """ + + contents: NotRequired[list[_GeminiContent]] + generateContentRequest: NotRequired[_GeminiRequest] + + +@pydantic.with_config(pydantic.ConfigDict(defer_build=True)) +class _GeminiCountTokensResponse(TypedDict): + """Schema for the response from the Gemini countTokens API. + + See for API docs. + """ + + totalTokens: int + cachedContentTokenCount: NotRequired[int] + promptTokensDetails: NotRequired[list[_GeminiModalityTokenCount]] + cacheTokensDetails: NotRequired[list[_GeminiModalityTokenCount]] + + +_gemini_count_tokens_request_ta = pydantic.TypeAdapter(_GeminiCountTokensRequest) +_gemini_count_tokens_response_ta = pydantic.TypeAdapter(_GeminiCountTokensResponse) + + class _GeminiCandidates(TypedDict): """See ."""