13 changes: 13 additions & 0 deletions camel/models/stub_model.py
@@ -31,6 +31,19 @@


class StubTokenCounter(BaseTokenCounter):
def extract_usage_from_response(
self, response: Any
) -> Optional[Dict[str, int]]:
r"""Stub implementation - always returns fixed usage
data for testing.

"""
return {
"prompt_tokens": 10,
"completion_tokens": 5,
"total_tokens": 15,
}

def count_tokens_from_messages(self, messages: List[OpenAIMessage]) -> int:
r"""Token counting for STUB models, directly returning a constant.

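A minimal sketch of how the stub counter's fixed usage data could be exercised in a test; it assumes StubTokenCounter can be instantiated with no arguments:

from camel.models.stub_model import StubTokenCounter


def test_stub_usage_is_constant():
    # Assumption: the stub counter takes no constructor arguments.
    counter = StubTokenCounter()
    # The stub ignores the response object, so any placeholder works here.
    usage = counter.extract_usage_from_response(response=None)
    assert usage == {
        "prompt_tokens": 10,
        "completion_tokens": 5,
        "total_tokens": 15,
    }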
259 changes: 257 additions & 2 deletions camel/utils/token_counting.py
@@ -18,7 +18,16 @@
from abc import ABC, abstractmethod
from io import BytesIO
from math import ceil
from typing import TYPE_CHECKING, List, Optional
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Dict,
Iterator,
List,
Optional,
Union,
)

from PIL import Image

@@ -77,10 +86,137 @@ def get_model_encoding(value_for_tiktoken: str):
class BaseTokenCounter(ABC):
r"""Base class for token counters of different kinds of models."""

@abstractmethod
def extract_usage_from_response(
self, response: Any
) -> Optional[Dict[str, int]]:
r"""Extract native usage data from model response.

Args:
response: The response object from the model API call.

Returns:
Dict with keys: prompt_tokens, completion_tokens, total_tokens,
and optionally cached_tokens or cache-related fields if supported
by the provider.
None if usage data is not available.
"""
pass

def extract_usage_from_streaming_response(
self, stream: Union[Iterator[Any], AsyncIterator[Any]]
) -> Optional[Dict[str, int]]:
r"""Extract native usage data from streaming response.

This method processes a streaming response to find usage data,
which is typically available in the final chunk when the stream is
created with ``stream_options={"include_usage": True}``.

Args:
stream: Iterator or AsyncIterator of streaming response chunks

Returns:
Dict with keys: prompt_tokens, completion_tokens, total_tokens.
None if usage data is not available.
"""
try:
# For sync streams
if hasattr(stream, '__iter__') and not hasattr(
stream, '__aiter__'
):
return self._extract_usage_from_sync_stream(stream)
# For async streams
elif hasattr(stream, '__aiter__'):
logger.warning(
"Async stream detected but sync method called. "
"Use extract_usage_from_async_streaming_response instead."
)
return None
else:
logger.debug("Unsupported stream type for usage extraction")
return None
except Exception as e:
logger.debug(
f"Failed to extract usage from streaming response: {e}"
)
return None

async def extract_usage_from_async_streaming_response(
self, stream: AsyncIterator[Any]
) -> Optional[Dict[str, int]]:
r"""Extract native usage data from async streaming response.

Args:
stream: AsyncIterator of streaming response chunks

Returns:
Dict with keys: prompt_tokens, completion_tokens, total_tokens.
None if usage data is not available.
"""
try:
return await self._extract_usage_from_async_stream(stream)
except Exception as e:
logger.debug(
f"Failed to extract usage from async streaming response: {e}"
)
return None

def _extract_usage_from_sync_stream(
self, stream: Iterator[Any]
) -> Optional[Dict[str, int]]:
r"""Extract usage from a synchronous streaming response.

Args:
stream (Iterator[Any]): Provider-specific synchronous stream iterator.

Returns:
Optional[Dict[str, int]]: Usage dict with prompt_tokens,
completion_tokens, and total_tokens, or None if unavailable.
"""
final_chunk = None
try:
for chunk in stream:
final_chunk = chunk
usage = self.extract_usage_from_response(chunk)
if usage:
return usage

if final_chunk:
return self.extract_usage_from_response(final_chunk)

except Exception as e:
logger.debug(f"Error processing sync stream: {e}")

return None

async def _extract_usage_from_async_stream(
self, stream: AsyncIterator[Any]
) -> Optional[Dict[str, int]]:
r"""Extract usage from asynchronous stream by consuming all chunks."""
final_chunk = None
try:
async for chunk in stream:
final_chunk = chunk
usage = self.extract_usage_from_response(chunk)
if usage:
return usage

if final_chunk:
return self.extract_usage_from_response(final_chunk)

except Exception as e:
logger.debug(f"Error processing async stream: {e}")

return None

@abstractmethod
def count_tokens_from_messages(self, messages: List[OpenAIMessage]) -> int:
r"""Count number of tokens in the provided message list.

.. note::
This method provides estimation-based token counting.
For more accurate token counts from actual API responses,
use :meth:`extract_usage_from_response` when possible.

Args:
messages (List[OpenAIMessage]): Message list with the chat history
in OpenAI API format.
@@ -161,6 +297,43 @@ def __init__(self, model: UnifiedModelType):

self.encoding = get_model_encoding(self.model)

def extract_usage_from_response(
self, response: Any
) -> Optional[Dict[str, int]]:
r"""Extract native usage data from OpenAI response.

Args:
response: OpenAI response object (ChatCompletion or similar)

Returns:
Dict with keys: prompt_tokens, completion_tokens, total_tokens,
and cached_tokens (if available).
None if usage data is not available.
"""
try:
if hasattr(response, 'usage') and response.usage is not None:
usage = response.usage
result = {
'prompt_tokens': getattr(usage, 'prompt_tokens', 0),
'completion_tokens': getattr(
usage, 'completion_tokens', 0
),
'total_tokens': getattr(usage, 'total_tokens', 0),
[Review comment on lines +317 to +321, Member: "would be better also record cached token"]
}
# Include cached_tokens if available (for prompt caching)
if hasattr(usage, 'prompt_tokens_details'):
details = usage.prompt_tokens_details
if hasattr(details, 'cached_tokens'):
result['cached_tokens'] = getattr(
details, 'cached_tokens', 0
)
return result

except Exception as e:
logger.debug(f"Failed to extract usage from OpenAI response: {e}")

return None

def count_tokens_from_messages(self, messages: List[OpenAIMessage]) -> int:
r"""Count number of tokens in the provided message list with the
help of package tiktoken.
@@ -314,6 +487,47 @@ def __init__(
self.client = Anthropic(api_key=api_key, base_url=base_url)
self.model = model

def extract_usage_from_response(
self, response: Any
) -> Optional[Dict[str, int]]:
r"""Extract native usage data from Anthropic response.

Args:
response: Anthropic response object (Message or similar)

Returns:
Dict with keys: prompt_tokens, completion_tokens, total_tokens,
plus cache_creation_input_tokens and cache_read_input_tokens
(if available).
None if usage data is not available.
"""
try:
if hasattr(response, 'usage') and response.usage is not None:
usage = response.usage
input_tokens = getattr(usage, 'input_tokens', 0)
output_tokens = getattr(usage, 'output_tokens', 0)
result = {
'prompt_tokens': input_tokens,
'completion_tokens': output_tokens,
'total_tokens': input_tokens + output_tokens,
}
# Include Anthropic prompt caching fields if available
cache_creation = getattr(
usage, 'cache_creation_input_tokens', None
)
cache_read = getattr(usage, 'cache_read_input_tokens', None)
if cache_creation is not None:
result['cache_creation_input_tokens'] = cache_creation
if cache_read is not None:
result['cache_read_input_tokens'] = cache_read
return result

except Exception as e:
logger.debug(
f"Failed to extract usage from Anthropic response: {e}"
)

return None

@dependencies_required('anthropic')
def count_tokens_from_messages(self, messages: List[OpenAIMessage]) -> int:
r"""Count number of tokens in the provided message list using
@@ -367,7 +581,7 @@ def decode(self, token_ids: List[int]) -> str:
)


class LiteLLMTokenCounter(BaseTokenCounter):
class LiteLLMTokenCounter(OpenAITokenCounter):
def __init__(self, model_type: UnifiedModelType):
r"""Constructor for the token counter for LiteLLM models.

@@ -395,6 +609,9 @@ def completion_cost(self):
self._completion_cost = completion_cost
return self._completion_cost

# Inherit extract_usage_from_response from OpenAITokenCounter, since
# LiteLLM standardizes the usage format to an OpenAI-compatible schema.

def count_tokens_from_messages(self, messages: List[OpenAIMessage]) -> int:
r"""Count number of tokens in the provided message list using
the tokenizer specific to this type of model.
@@ -473,6 +690,44 @@ def __init__(self, model_type: ModelType):

self.tokenizer = MistralTokenizer.from_model(model_name)

def extract_usage_from_response(
self, response: Any
) -> Optional[Dict[str, int]]:
[Review comment on lines +693 to +695, Member: "seems duplicated, as it's already defined in base class"]
[Reply, Collaborator (Author): "Same reasoning as above"]
r"""Extract native usage data from Mistral response.

Args:
response: Mistral response object

Returns:
Dict with keys: prompt_tokens, completion_tokens, total_tokens.
None if usage data is not available.
"""
try:
if hasattr(response, 'usage') and response.usage is not None:
usage = response.usage
prompt_tokens = getattr(usage, 'prompt_tokens', 0)
completion_tokens = getattr(usage, 'completion_tokens', 0)
total_tokens = getattr(
usage, 'total_tokens', prompt_tokens + completion_tokens
)
result = {
'prompt_tokens': prompt_tokens,
'completion_tokens': completion_tokens,
'total_tokens': total_tokens,
}
# Include cached tokens if available (for prompt caching)
if hasattr(usage, 'prompt_tokens_details'):
details = usage.prompt_tokens_details
cached = getattr(details, 'cached_tokens', None)
if cached is not None:
result['cached_tokens'] = cached
return result

except Exception as e:
logger.debug(f"Failed to extract usage from Mistral response: {e}")

return None

def count_tokens_from_messages(self, messages: List[OpenAIMessage]) -> int:
r"""Count number of tokens in the provided message list using
loaded tokenizer specific for this type of model.
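To illustrate the new extraction hook without a network call, here is a small sketch that feeds OpenAITokenCounter.extract_usage_from_response a minimal response-like object; only the usage attribute matters to the method, and constructing the counter directly from a ModelType value is an assumption based on how it is used elsewhere in CAMEL:

from types import SimpleNamespace

from camel.types import ModelType
from camel.utils.token_counting import OpenAITokenCounter

counter = OpenAITokenCounter(ModelType.GPT_4O_MINI)

# Stand-in for an OpenAI ChatCompletion: only the `usage` attribute is read.
fake_response = SimpleNamespace(
    usage=SimpleNamespace(
        prompt_tokens=24,
        completion_tokens=8,
        total_tokens=32,
        prompt_tokens_details=SimpleNamespace(cached_tokens=0),
    )
)

print(counter.extract_usage_from_response(fake_response))
# {'prompt_tokens': 24, 'completion_tokens': 8, 'total_tokens': 32,
#  'cached_tokens': 0}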
58 changes: 58 additions & 0 deletions examples/token_counting/token_counting.py
@@ -0,0 +1,58 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

from camel.agents import ChatAgent
from camel.messages import BaseMessage
from camel.models import ModelFactory
from camel.types import ModelPlatformType, ModelType

model = ModelFactory.create(
model_platform=ModelPlatformType.OPENAI,
model_type=ModelType.GPT_4O_MINI,
)

agent = ChatAgent(
system_message="You are a helpful assistant.",
model=model,
)

user_msg = BaseMessage.make_user_message(
role_name="User",
content="What is 2+2?",
)

response = agent.step(user_msg)

print(f"User: {user_msg.content}")
print(f"Assistant: {response.msg.content}")

# Extract token usage from response
if response.info and 'usage' in response.info:
usage = response.info['usage']
print("\nToken Usage:")
print(f" Prompt tokens: {usage.get('prompt_tokens', 0)}")
print(f" Completion tokens: {usage.get('completion_tokens', 0)}")
print(f" Total tokens: {usage.get('total_tokens', 0)}")

'''
===============================================================================
User: What is 2+2?
Assistant: 2 + 2 equals 4.

Token Usage:
Prompt tokens: 24
Completion tokens: 8
Total tokens: 32
===============================================================================
'''
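For the streaming path, a companion sketch along the same lines: usage typically arrives only in the final chunk (for example when the stream is created with stream_options={"include_usage": True}), and the base-class helper scans the chunks for it. The chunk objects below are stand-ins, and the counter construction mirrors the assumption made in the sketch above:

from types import SimpleNamespace

from camel.types import ModelType
from camel.utils.token_counting import OpenAITokenCounter

counter = OpenAITokenCounter(ModelType.GPT_4O_MINI)

# Simulated chunks: only the last one carries usage data.
chunks = [
    SimpleNamespace(usage=None),
    SimpleNamespace(usage=None),
    SimpleNamespace(
        usage=SimpleNamespace(
            prompt_tokens=24, completion_tokens=8, total_tokens=32
        )
    ),
]

print(counter.extract_usage_from_streaming_response(iter(chunks)))
# {'prompt_tokens': 24, 'completion_tokens': 8, 'total_tokens': 32}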