Skip to content

Commit ef07d77

Browse files
authored
Fix Anthropic streaming (#1686)
1 parent 88ad258 commit ef07d77

File tree

3 files changed

+98
-28
lines changed

3 files changed

+98
-28
lines changed

pydantic_ai_slim/pydantic_ai/models/anthropic.py

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -393,36 +393,31 @@ def _map_tool_definition(f: ToolDefinition) -> ToolParam:
393393
def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> usage.Usage:
394394
if isinstance(message, AnthropicMessage):
395395
response_usage = message.usage
396+
elif isinstance(message, RawMessageStartEvent):
397+
response_usage = message.message.usage
398+
elif isinstance(message, RawMessageDeltaEvent):
399+
response_usage = message.usage
396400
else:
397-
if isinstance(message, RawMessageStartEvent):
398-
response_usage = message.message.usage
399-
elif isinstance(message, RawMessageDeltaEvent):
400-
response_usage = message.usage
401-
else:
402-
# No usage information provided in:
403-
# - RawMessageStopEvent
404-
# - RawContentBlockStartEvent
405-
# - RawContentBlockDeltaEvent
406-
# - RawContentBlockStopEvent
407-
response_usage = None
408-
409-
if response_usage is None:
401+
# No usage information provided in:
402+
# - RawMessageStopEvent
403+
# - RawContentBlockStartEvent
404+
# - RawContentBlockDeltaEvent
405+
# - RawContentBlockStopEvent
410406
return usage.Usage()
411407

412-
# Store all integer-typed usage values in the details dict
413-
response_usage_dict = response_usage.model_dump()
414-
details: dict[str, int] = {}
415-
for key, value in response_usage_dict.items():
416-
if isinstance(value, int):
417-
details[key] = value
408+
# Store all integer-typed usage values in the details, except 'output_tokens' which is represented exactly by
409+
# `response_tokens`
410+
details: dict[str, int] = {
411+
key: value for key, value in response_usage.model_dump().items() if isinstance(value, int)
412+
}
418413

419-
# Usage coming from the RawMessageDeltaEvent doesn't have input token data, hence the getattr call
414+
# Usage coming from the RawMessageDeltaEvent doesn't have input token data, hence using `get`
420415
# Tokens are only counted once between input_tokens, cache_creation_input_tokens, and cache_read_input_tokens
421416
# This approach maintains request_tokens as the count of all input tokens, with cached counts as details
422417
request_tokens = (
423-
getattr(response_usage, 'input_tokens', 0)
424-
+ (getattr(response_usage, 'cache_creation_input_tokens', 0) or 0) # These can be missing, None, or int
425-
+ (getattr(response_usage, 'cache_read_input_tokens', 0) or 0)
418+
details.get('input_tokens', 0)
419+
+ details.get('cache_creation_input_tokens', 0)
420+
+ details.get('cache_read_input_tokens', 0)
426421
)
427422

428423
return usage.Usage(

tests/models/test_anthropic.py

Lines changed: 77 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from dataclasses import dataclass, field
77
from datetime import timezone
88
from functools import cached_property
9-
from typing import Any, TypeVar, Union, cast
9+
from typing import Any, Callable, TypeVar, Union, cast
1010

1111
import httpx
1212
import pytest
@@ -52,7 +52,11 @@
5252
)
5353
from anthropic.types.raw_message_delta_event import Delta
5454

55-
from pydantic_ai.models.anthropic import AnthropicModel, AnthropicModelSettings
55+
from pydantic_ai.models.anthropic import (
56+
AnthropicModel,
57+
AnthropicModelSettings,
58+
_map_usage, # pyright: ignore[reportPrivateUsage]
59+
)
5660
from pydantic_ai.providers.anthropic import AnthropicProvider
5761

5862
# note: we use Union here so that casting works with Python 3.9
@@ -921,3 +925,74 @@ def simple_instructions():
921925
),
922926
]
923927
)
928+
929+
930+
def anth_msg(usage: AnthropicUsage) -> AnthropicMessage:
931+
return AnthropicMessage(
932+
id='x',
933+
content=[],
934+
model='claude-3-7-sonnet-latest',
935+
role='assistant',
936+
type='message',
937+
usage=usage,
938+
)
939+
940+
941+
@pytest.mark.parametrize(
942+
'message_callback,usage',
943+
[
944+
pytest.param(
945+
lambda: anth_msg(AnthropicUsage(input_tokens=1, output_tokens=1)),
946+
snapshot(
947+
Usage(
948+
request_tokens=1, response_tokens=1, total_tokens=2, details={'input_tokens': 1, 'output_tokens': 1}
949+
)
950+
),
951+
id='AnthropicMessage',
952+
),
953+
pytest.param(
954+
lambda: anth_msg(
955+
AnthropicUsage(
956+
input_tokens=1, output_tokens=1, cache_creation_input_tokens=2, cache_read_input_tokens=3
957+
)
958+
),
959+
snapshot(
960+
Usage(
961+
request_tokens=6,
962+
response_tokens=1,
963+
total_tokens=7,
964+
details={
965+
'cache_creation_input_tokens': 2,
966+
'cache_read_input_tokens': 3,
967+
'input_tokens': 1,
968+
'output_tokens': 1,
969+
},
970+
)
971+
),
972+
id='AnthropicMessage-cached',
973+
),
974+
pytest.param(
975+
lambda: RawMessageStartEvent(
976+
message=anth_msg(AnthropicUsage(input_tokens=1, output_tokens=1)), type='message_start'
977+
),
978+
snapshot(
979+
Usage(
980+
request_tokens=1, response_tokens=1, total_tokens=2, details={'input_tokens': 1, 'output_tokens': 1}
981+
)
982+
),
983+
id='RawMessageStartEvent',
984+
),
985+
pytest.param(
986+
lambda: RawMessageDeltaEvent(
987+
delta=Delta(),
988+
usage=MessageDeltaUsage(output_tokens=5),
989+
type='message_delta',
990+
),
991+
snapshot(Usage(response_tokens=5, total_tokens=5, details={'output_tokens': 5})),
992+
id='RawMessageDeltaEvent',
993+
),
994+
pytest.param(lambda: RawMessageStopEvent(type='message_stop'), snapshot(Usage()), id='RawMessageStopEvent'),
995+
],
996+
)
997+
def test_usage(message_callback: Callable[[], AnthropicMessage | RawMessageStreamEvent], usage: Usage):
998+
assert _map_usage(message_callback()) == usage

uv.lock

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)