From 6f8673522fe6e2f6f7eae7b41ac003c506577bb1 Mon Sep 17 00:00:00 2001 From: kauabh <56749351+kauabh@users.noreply.github.com> Date: Sun, 6 Jul 2025 04:27:35 +0530 Subject: [PATCH 1/9] Adding CountToken to Gemini Gemini Provides an endpoint to count token before sending an response https://ai.google.dev/api/tokens#method:-models.counttokens --- pydantic_ai_slim/pydantic_ai/models/gemini.py | 122 ++++++++++++++++++ 1 file changed, 122 insertions(+) diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 64008622b..9ca1b9dcc 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -397,6 +397,104 @@ def _map_response_schema(self, o: OutputObjectDefinition) -> dict[str, Any]: return response_schema + async def count_tokens( + self, + messages: list[ModelMessage], + model_settings: GeminiModelSettings | None, + model_request_parameters: ModelRequestParameters, + ) -> usage.Usage: + check_allow_model_requests() + async with self._make_count_request(messages, model_settings or {}, model_request_parameters) as http_response: + data = await http_response.aread() + response = _gemini_count_tokens_response_ta.validate_json(data) + return self._process_count_tokens_response(response) + + @asynccontextmanager + async def _make_count_request( + self, + messages: list[ModelMessage], + model_settings: GeminiModelSettings, + model_request_parameters: ModelRequestParameters, + ) -> AsyncIterator[HTTPResponse]: + tools = self._get_tools(model_request_parameters) + tool_config = self._get_tool_config(model_request_parameters, tools) + sys_prompt_parts, contents = await self._message_to_gemini_content(messages) + + request_data = _GeminiCountTokensRequest(contents=contents) + if sys_prompt_parts: + request_data['systemInstruction'] = _GeminiTextContent(role='user', parts=sys_prompt_parts) + if tools is not None: + request_data['tools'] = tools + if tool_config is not None: + request_data['toolConfig'] = tool_config + + generation_config = _settings_to_generation_config(model_settings) + if model_request_parameters.output_mode == 'native': + if tools: + raise UserError('Gemini does not support structured output and tools at the same time.') + generation_config['response_mime_type'] = 'application/json' + output_object = model_request_parameters.output_object + assert output_object is not None + generation_config['response_schema'] = self._map_response_schema(output_object) + elif model_request_parameters.output_mode == 'prompted' and not tools: + generation_config['response_mime_type'] = 'application/json' + + if generation_config: + request_data['generateContentRequest'] = { + 'contents': contents, + 'generationConfig': generation_config, + } + if sys_prompt_parts: + request_data['generateContentRequest']['systemInstruction'] = _GeminiTextContent(role='user', parts=sys_prompt_parts) + if tools is not None: + request_data['generateContentRequest']['tools'] = tools + if tool_config is not None: + request_data['generateContentRequest']['toolConfig'] = tool_config + + if gemini_safety_settings := model_settings.get('gemini_safety_settings'): + request_data['safetySettings'] = gemini_safety_settings + + if gemini_labels := model_settings.get('gemini_labels'): + if self._system == 'google-vertex': + request_data['labels'] = gemini_labels + + headers = {'Content-Type': 'application/json', 'User-Agent': get_user_agent()} + url = f'/models/{self._model_name}:countTokens' + + request_json = _gemini_count_tokens_request_ta.dump_json(request_data, by_alias=True) + async with self.client.stream( + 'POST', + url, + content=request_json, + headers=headers, + timeout=model_settings.get('timeout', USE_CLIENT_DEFAULT), + ) as r: + if (status_code := r.status_code) != 200: + await r.aread() + if status_code >= 400: + raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=r.text) + raise UnexpectedModelBehavior( # pragma: no cover + f'Unexpected response from gemini {status_code}', r.text) + yield r + + def _process_count_tokens_response(self, response: _GeminiCountTokensResponse) -> usage.Usage: + details: dict[str, int] = {} + if cached_content_token_count := response.get('cachedContentTokenCount'): + details['cached_content_tokens'] = cached_content_token_count + + for key, metadata_details in response.items(): + if key.endswith('TokensDetails') and metadata_details: + metadata_details = cast(list[_GeminiModalityTokenCount], metadata_details) + suffix = key.removesuffix('TokensDetails').lower() + for detail in metadata_details: + details[f'{detail["modality"].lower()}_{suffix}'] = detail['token_count'] + + return usage.Usage( + request_tokens=response.get('totalTokens', 0), + response_tokens=0, # countTokens does not provide response tokens + total_tokens=response.get('totalTokens', 0), + details=details, + ) def _settings_to_generation_config(model_settings: GeminiModelSettings) -> _GeminiGenerationConfig: config: _GeminiGenerationConfig = {} @@ -809,6 +907,30 @@ class _GeminiResponse(TypedDict): vendor_id: NotRequired[Annotated[str, pydantic.Field(alias='responseId')]] +@pydantic.with_config(pydantic.ConfigDict(defer_build=True)) +class _GeminiCountTokensRequest(TypedDict): + """Schema for a countTokens API request to the Gemini API. + + See for API docs. + """ + + contents: NotRequired[list[_GeminiContent]] + generateContentRequest: NotRequired[_GeminiRequest] + + +@pydantic.with_config(pydantic.ConfigDict(defer_build=True)) +class _GeminiCountTokensResponse(TypedDict): + """Schema for the response from the Gemini countTokens API. + + See for API docs. + """ + + totalTokens: int + cachedContentTokenCount: NotRequired[int] + promptTokensDetails: NotRequired[list[_GeminiModalityTokenCount]] + cacheTokensDetails: NotRequired[list[_GeminiModalityTokenCount]] + + class _GeminiCandidates(TypedDict): """See .""" From 5cd88e0e77c975427726188f6480448e6aef2e5d Mon Sep 17 00:00:00 2001 From: kauabh <56749351+kauabh@users.noreply.github.com> Date: Sun, 6 Jul 2025 04:52:53 +0530 Subject: [PATCH 2/9] Update gemini.py added type adaptor --- pydantic_ai_slim/pydantic_ai/models/gemini.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 9ca1b9dcc..14a5eed1e 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -929,7 +929,11 @@ class _GeminiCountTokensResponse(TypedDict): cachedContentTokenCount: NotRequired[int] promptTokensDetails: NotRequired[list[_GeminiModalityTokenCount]] cacheTokensDetails: NotRequired[list[_GeminiModalityTokenCount]] - + + +_gemini_count_tokens_request_ta = pydantic.TypeAdapter(_GeminiCountTokensRequest) +_gemini_count_tokens_response_ta = pydantic.TypeAdapter(_GeminiCountTokensResponse) + class _GeminiCandidates(TypedDict): """See .""" From a30234527a8dbe0b51a5c8e9752aed31648f6921 Mon Sep 17 00:00:00 2001 From: kauabh <56749351+kauabh@users.noreply.github.com> Date: Sun, 6 Jul 2025 05:06:24 +0530 Subject: [PATCH 3/9] Update gemini.py Removed extra assignment --- pydantic_ai_slim/pydantic_ai/models/gemini.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 14a5eed1e..f2199af09 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -421,12 +421,6 @@ async def _make_count_request( sys_prompt_parts, contents = await self._message_to_gemini_content(messages) request_data = _GeminiCountTokensRequest(contents=contents) - if sys_prompt_parts: - request_data['systemInstruction'] = _GeminiTextContent(role='user', parts=sys_prompt_parts) - if tools is not None: - request_data['tools'] = tools - if tool_config is not None: - request_data['toolConfig'] = tool_config generation_config = _settings_to_generation_config(model_settings) if model_request_parameters.output_mode == 'native': @@ -451,13 +445,6 @@ async def _make_count_request( if tool_config is not None: request_data['generateContentRequest']['toolConfig'] = tool_config - if gemini_safety_settings := model_settings.get('gemini_safety_settings'): - request_data['safetySettings'] = gemini_safety_settings - - if gemini_labels := model_settings.get('gemini_labels'): - if self._system == 'google-vertex': - request_data['labels'] = gemini_labels - headers = {'Content-Type': 'application/json', 'User-Agent': get_user_agent()} url = f'/models/{self._model_name}:countTokens' From 3b2e26ab82901c3af0e0dc23d4c00518d285ff95 Mon Sep 17 00:00:00 2001 From: kauabh <56749351+kauabh@users.noreply.github.com> Date: Sun, 6 Jul 2025 05:15:41 +0530 Subject: [PATCH 4/9] Update gemini.py Linting --- pydantic_ai_slim/pydantic_ai/models/gemini.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index f2199af09..917fd698e 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -401,8 +401,7 @@ async def count_tokens( self, messages: list[ModelMessage], model_settings: GeminiModelSettings | None, - model_request_parameters: ModelRequestParameters, - ) -> usage.Usage: + model_request_parameters: ModelRequestParameters,) -> usage.Usage: check_allow_model_requests() async with self._make_count_request(messages, model_settings or {}, model_request_parameters) as http_response: data = await http_response.aread() @@ -415,7 +414,7 @@ async def _make_count_request( messages: list[ModelMessage], model_settings: GeminiModelSettings, model_request_parameters: ModelRequestParameters, - ) -> AsyncIterator[HTTPResponse]: + ) -> AsyncIterator[HTTPResponse]: tools = self._get_tools(model_request_parameters) tool_config = self._get_tool_config(model_request_parameters, tools) sys_prompt_parts, contents = await self._message_to_gemini_content(messages) @@ -439,7 +438,9 @@ async def _make_count_request( 'generationConfig': generation_config, } if sys_prompt_parts: - request_data['generateContentRequest']['systemInstruction'] = _GeminiTextContent(role='user', parts=sys_prompt_parts) + request_data['generateContentRequest']['systemInstruction'] = _GeminiTextContent( + role='user', parts=sys_prompt_parts + ) if tools is not None: request_data['generateContentRequest']['tools'] = tools if tool_config is not None: @@ -460,7 +461,7 @@ async def _make_count_request( await r.aread() if status_code >= 400: raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=r.text) - raise UnexpectedModelBehavior( # pragma: no cover + raise UnexpectedModelBehavior( # pragma: no cover f'Unexpected response from gemini {status_code}', r.text) yield r From dc4d29b91e219a9d668bab21644c4fe68cfd0ebc Mon Sep 17 00:00:00 2001 From: kauabh <56749351+kauabh@users.noreply.github.com> Date: Sun, 6 Jul 2025 05:21:12 +0530 Subject: [PATCH 5/9] Update gemini.py Linting --- pydantic_ai_slim/pydantic_ai/models/gemini.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 917fd698e..dadc34feb 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -401,7 +401,8 @@ async def count_tokens( self, messages: list[ModelMessage], model_settings: GeminiModelSettings | None, - model_request_parameters: ModelRequestParameters,) -> usage.Usage: + model_request_parameters: ModelRequestParameters, + ) -> usage.Usage: check_allow_model_requests() async with self._make_count_request(messages, model_settings or {}, model_request_parameters) as http_response: data = await http_response.aread() @@ -462,7 +463,8 @@ async def _make_count_request( if status_code >= 400: raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=r.text) raise UnexpectedModelBehavior( # pragma: no cover - f'Unexpected response from gemini {status_code}', r.text) + f'Unexpected response from gemini {status_code}', r.text + ) yield r def _process_count_tokens_response(self, response: _GeminiCountTokensResponse) -> usage.Usage: From 16f18dcb699c919991f075a152a84661f36c13ef Mon Sep 17 00:00:00 2001 From: kauabh <56749351+kauabh@users.noreply.github.com> Date: Sun, 6 Jul 2025 05:32:00 +0530 Subject: [PATCH 6/9] Update gemini.py --- pydantic_ai_slim/pydantic_ai/models/gemini.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index dadc34feb..4b3f1b8ca 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -485,6 +485,7 @@ def _process_count_tokens_response(self, response: _GeminiCountTokensResponse) - total_tokens=response.get('totalTokens', 0), details=details, ) + def _settings_to_generation_config(model_settings: GeminiModelSettings) -> _GeminiGenerationConfig: config: _GeminiGenerationConfig = {} @@ -911,7 +912,6 @@ class _GeminiCountTokensRequest(TypedDict): @pydantic.with_config(pydantic.ConfigDict(defer_build=True)) class _GeminiCountTokensResponse(TypedDict): """Schema for the response from the Gemini countTokens API. - See for API docs. """ From 24d6c25a446916ed1cf2dbe6844930061e52b5cf Mon Sep 17 00:00:00 2001 From: kauabh <56749351+kauabh@users.noreply.github.com> Date: Sun, 6 Jul 2025 05:42:23 +0530 Subject: [PATCH 7/9] Update gemini.py --- pydantic_ai_slim/pydantic_ai/models/gemini.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 4b3f1b8ca..b23e80548 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -466,7 +466,7 @@ async def _make_count_request( f'Unexpected response from gemini {status_code}', r.text ) yield r - + def _process_count_tokens_response(self, response: _GeminiCountTokensResponse) -> usage.Usage: details: dict[str, int] = {} if cached_content_token_count := response.get('cachedContentTokenCount'): @@ -485,7 +485,7 @@ def _process_count_tokens_response(self, response: _GeminiCountTokensResponse) - total_tokens=response.get('totalTokens', 0), details=details, ) - + def _settings_to_generation_config(model_settings: GeminiModelSettings) -> _GeminiGenerationConfig: config: _GeminiGenerationConfig = {} @@ -907,7 +907,7 @@ class _GeminiCountTokensRequest(TypedDict): contents: NotRequired[list[_GeminiContent]] generateContentRequest: NotRequired[_GeminiRequest] - + @pydantic.with_config(pydantic.ConfigDict(defer_build=True)) class _GeminiCountTokensResponse(TypedDict): From 90fc8bbaca3172cecc4218d1605b9a0596887907 Mon Sep 17 00:00:00 2001 From: kauabh <56749351+kauabh@users.noreply.github.com> Date: Sun, 6 Jul 2025 05:46:24 +0530 Subject: [PATCH 8/9] Update gemini.py Linting --- pydantic_ai_slim/pydantic_ai/models/gemini.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index b23e80548..041a01f75 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -912,6 +912,7 @@ class _GeminiCountTokensRequest(TypedDict): @pydantic.with_config(pydantic.ConfigDict(defer_build=True)) class _GeminiCountTokensResponse(TypedDict): """Schema for the response from the Gemini countTokens API. + See for API docs. """ From 2bfc8d083c2ef1e4e7edfbbdd5dfe8813a0a2be5 Mon Sep 17 00:00:00 2001 From: kauabh <56749351+kauabh@users.noreply.github.com> Date: Sun, 6 Jul 2025 05:49:18 +0530 Subject: [PATCH 9/9] Update gemini.py Removed White Space --- pydantic_ai_slim/pydantic_ai/models/gemini.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 041a01f75..ee6a659b8 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -912,7 +912,7 @@ class _GeminiCountTokensRequest(TypedDict): @pydantic.with_config(pydantic.ConfigDict(defer_build=True)) class _GeminiCountTokensResponse(TypedDict): """Schema for the response from the Gemini countTokens API. - + See for API docs. """