Skip to content

Count tokens #315

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged — 4 commits merged on May 3, 2024
Show file tree
Hide file tree
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 34 additions & 12 deletions google/generativeai/generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,35 +322,57 @@ async def generate_content_async(
# fmt: off
def count_tokens(
    self,
    contents: content_types.ContentsType = None,
    *,
    generation_config: generation_types.GenerationConfigType | None = None,
    safety_settings: safety_types.SafetySettingOptions | None = None,
    tools: content_types.FunctionLibraryType | None = None,
    tool_config: content_types.ToolConfigType | None = None,
    request_options: dict[str, Any] | None = None,
) -> glm.CountTokensResponse:
    """Count the tokens the model would consume for a full generate request.

    The token count is computed from the same request `generate_content`
    would send (built via `self._prepare_request`), so it reflects the
    `generation_config`, `safety_settings`, `tools`, and `tool_config`
    in addition to `contents`.

    Args:
        contents: The prompt content to count tokens for.
        generation_config: Optional generation parameters to include in the request.
        safety_settings: Optional safety settings to include in the request.
        tools: Optional function-calling tools to include in the request.
        tool_config: Optional tool configuration to include in the request.
        request_options: Keyword options forwarded to the underlying client call.

    Returns:
        The service's `glm.CountTokensResponse` for the assembled request.
    """
    if request_options is None:
        request_options = {}

    # Lazily create the client so object construction stays cheap and
    # does not require credentials until the first RPC.
    if self._client is None:
        self._client = client.get_default_generative_client()

    request = glm.CountTokensRequest(
        model=self.model_name,
        generate_content_request=self._prepare_request(
            contents=contents,
            generation_config=generation_config,
            safety_settings=safety_settings,
            tools=tools,
            tool_config=tool_config,
        ),
    )
    return self._client.count_tokens(request, **request_options)

async def count_tokens_async(
    self,
    contents: content_types.ContentsType = None,
    *,
    generation_config: generation_types.GenerationConfigType | None = None,
    safety_settings: safety_types.SafetySettingOptions | None = None,
    tools: content_types.FunctionLibraryType | None = None,
    tool_config: content_types.ToolConfigType | None = None,
    request_options: dict[str, Any] | None = None,
) -> glm.CountTokensResponse:
    """Async version of `count_tokens`.

    Builds the same request `generate_content_async` would send (via
    `self._prepare_request`) and asks the service how many tokens it uses,
    so the count reflects tools and configuration as well as `contents`.

    Args:
        contents: The prompt content to count tokens for.
        generation_config: Optional generation parameters to include in the request.
        safety_settings: Optional safety settings to include in the request.
        tools: Optional function-calling tools to include in the request.
        tool_config: Optional tool configuration to include in the request.
        request_options: Keyword options forwarded to the underlying client call.

    Returns:
        The service's `glm.CountTokensResponse` for the assembled request.
    """
    if request_options is None:
        request_options = {}

    # Lazily create the async client on first use, mirroring the sync path.
    if self._async_client is None:
        self._async_client = client.get_default_generative_async_client()

    request = glm.CountTokensRequest(
        model=self.model_name,
        generate_content_request=self._prepare_request(
            contents=contents,
            generation_config=generation_config,
            safety_settings=safety_settings,
            tools=tools,
            tool_config=tool_config,
        ),
    )
    return await self._async_client.count_tokens(request, **request_options)

# fmt: on

Expand Down
16 changes: 14 additions & 2 deletions google/generativeai/types/generation_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@

import collections
import contextlib
import sys
from collections.abc import Iterable, AsyncIterable
import dataclasses
import itertools
import json
import sys
import textwrap
from typing import Union
from typing_extensions import TypedDict
Expand Down Expand Up @@ -250,6 +251,7 @@ def _join_candidates(candidates: Iterable[glm.Candidate]):
finish_reason=candidates[-1].finish_reason,
safety_ratings=_join_safety_ratings_lists([c.safety_ratings for c in candidates]),
citation_metadata=_join_citation_metadatas([c.citation_metadata for c in candidates]),
token_count=candidates[-1].token_count,
)


Expand All @@ -276,9 +278,11 @@ def _join_prompt_feedbacks(


def _join_chunks(chunks: Iterable[glm.GenerateContentResponse]):
    """Fold a sequence of streaming response chunks into one response.

    Candidates and prompt feedback are merged across all chunks; the
    usage metadata is taken from the final chunk, which carries the
    totals for the stream.
    """
    # Materialize once: the iterable is consumed twice and indexed at the end.
    chunk_list = tuple(chunks)
    merged_candidates = _join_candidate_lists(c.candidates for c in chunk_list)
    merged_feedback = _join_prompt_feedbacks(c.prompt_feedback for c in chunk_list)
    return glm.GenerateContentResponse(
        candidates=merged_candidates,
        prompt_feedback=merged_feedback,
        usage_metadata=chunk_list[-1].usage_metadata,
    )


Expand Down Expand Up @@ -373,13 +377,21 @@ def text(self):
def prompt_feedback(self):
return self._result.prompt_feedback

@property
def usage_metadata(self):
    """Usage metadata recorded on the underlying response proto."""
    result = self._result
    return result.usage_metadata

def __str__(self) -> str:
if self._done:
_iterator = "None"
else:
_iterator = f"<{self._iterator.__class__.__name__}>"

_result = f"glm.GenerateContentResponse({type(self._result).to_dict(self._result)})"
as_dict = type(self._result).to_dict(self._result)
json_str = json.dumps(as_dict, indent=2)

_result = f"glm.GenerateContentResponse({json_str})"
_result = _result.replace("\n", "\n ")

if self._error:
_error = f",\nerror=<{self._error.__class__.__name__}> {self._error}"
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def get_version():
release_status = "Development Status :: 5 - Production/Stable"

dependencies = [
"google-ai-generativelanguage==0.6.2",
"google-ai-generativelanguage@https://storage.googleapis.com/generativeai-downloads/preview/ai-generativelanguage-v1beta-py-2.tar.gz",
"google-api-core",
"google-api-python-client",
"google-auth>=2.15.0", # 2.15 adds API key auth support
Expand Down
38 changes: 36 additions & 2 deletions tests/test_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,24 @@ def test_repr_for_generate_content_response_from_response(self):
GenerateContentResponse(
done=True,
iterator=None,
result=glm.GenerateContentResponse({'candidates': [{'content': {'parts': [{'text': 'Hello world!'}], 'role': ''}, 'finish_reason': 0, 'safety_ratings': [], 'token_count': 0, 'grounding_attributions': []}]}),
result=glm.GenerateContentResponse({
"candidates": [
{
"content": {
"parts": [
{
"text": "Hello world!"
}
],
"role": ""
},
"finish_reason": 0,
"safety_ratings": [],
"token_count": 0,
"grounding_attributions": []
}
]
}),
)"""
)
self.assertEqual(expected, result)
Expand All @@ -522,7 +539,24 @@ def test_repr_for_generate_content_response_from_iterator(self):
GenerateContentResponse(
done=False,
iterator=<list_iterator>,
result=glm.GenerateContentResponse({'candidates': [{'content': {'parts': [{'text': 'a'}], 'role': ''}, 'finish_reason': 0, 'safety_ratings': [], 'token_count': 0, 'grounding_attributions': []}]}),
result=glm.GenerateContentResponse({
"candidates": [
{
"content": {
"parts": [
{
"text": "a"
}
],
"role": ""
},
"finish_reason": 0,
"safety_ratings": [],
"token_count": 0,
"grounding_attributions": []
}
]
}),
)"""
)
self.assertEqual(expected, result)
Expand Down
Loading
Loading