Skip to content

Commit 4f257bd

Browse files
authored
Add token usage metrics to InstrumentedModel (#1898)
1 parent 78e006c commit 4f257bd

File tree

3 files changed

+172
-39
lines changed

3 files changed

+172
-39
lines changed

pydantic_ai_slim/pydantic_ai/models/instrumented.py

Lines changed: 98 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
EventLoggerProvider, # pyright: ignore[reportPrivateImportUsage]
1414
get_event_logger_provider, # pyright: ignore[reportPrivateImportUsage]
1515
)
16+
from opentelemetry.metrics import MeterProvider, get_meter_provider
1617
from opentelemetry.trace import Span, Tracer, TracerProvider, get_tracer_provider
1718
from opentelemetry.util.types import AttributeValue
1819
from pydantic import TypeAdapter
@@ -49,6 +50,10 @@
4950

5051
ANY_ADAPTER = TypeAdapter[Any](Any)
5152

53+
# Histogram bucket boundaries for token counts, as given in the OpenTelemetry GenAI metrics spec:
54+
# https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-metrics/#metric-gen_aiclienttokenusage
55+
TOKEN_HISTOGRAM_BOUNDARIES = (1, 4, 16, 64, 256, 1024, 4096, 16384, 65536, 262144, 1048576, 4194304, 16777216, 67108864)
56+
5257

5358
def instrument_model(model: Model, instrument: InstrumentationSettings | bool) -> Model:
5459
"""Instrument a model with OpenTelemetry/logfire."""
@@ -84,6 +89,7 @@ def __init__(
8489
*,
8590
event_mode: Literal['attributes', 'logs'] = 'attributes',
8691
tracer_provider: TracerProvider | None = None,
92+
meter_provider: MeterProvider | None = None,
8793
event_logger_provider: EventLoggerProvider | None = None,
8894
include_binary_content: bool = True,
8995
):
@@ -95,6 +101,9 @@ def __init__(
95101
tracer_provider: The OpenTelemetry tracer provider to use.
96102
If not provided, the global tracer provider is used.
97103
Calling `logfire.configure()` sets the global tracer provider, so most users don't need this.
104+
meter_provider: The OpenTelemetry meter provider to use.
105+
If not provided, the global meter provider is used.
106+
Calling `logfire.configure()` sets the global meter provider, so most users don't need this.
98107
event_logger_provider: The OpenTelemetry event logger provider to use.
99108
If not provided, the global event logger provider is used.
100109
Calling `logfire.configure()` sets the global event logger provider, so most users don't need this.
@@ -104,12 +113,33 @@ def __init__(
104113
from pydantic_ai import __version__
105114

106115
tracer_provider = tracer_provider or get_tracer_provider()
116+
meter_provider = meter_provider or get_meter_provider()
107117
event_logger_provider = event_logger_provider or get_event_logger_provider()
108-
self.tracer = tracer_provider.get_tracer('pydantic-ai', __version__)
109-
self.event_logger = event_logger_provider.get_event_logger('pydantic-ai', __version__)
118+
scope_name = 'pydantic-ai'
119+
self.tracer = tracer_provider.get_tracer(scope_name, __version__)
120+
self.meter = meter_provider.get_meter(scope_name, __version__)
121+
self.event_logger = event_logger_provider.get_event_logger(scope_name, __version__)
110122
self.event_mode = event_mode
111123
self.include_binary_content = include_binary_content
112124

125+
# As specified in the OpenTelemetry GenAI metrics spec:
126+
# https://opentelemetry.io/docs/specs/semconv/gen-ai/gen-ai-metrics/#metric-gen_aiclienttokenusage
127+
tokens_histogram_kwargs = dict(
128+
name='gen_ai.client.token.usage',
129+
unit='{token}',
130+
description='Measures number of input and output tokens used',
131+
)
132+
try:
133+
self.tokens_histogram = self.meter.create_histogram(
134+
**tokens_histogram_kwargs,
135+
explicit_bucket_boundaries_advisory=TOKEN_HISTOGRAM_BOUNDARIES,
136+
)
137+
except TypeError:
138+
# Older OTel/logfire versions don't accept explicit_bucket_boundaries_advisory (TypeError), so create the histogram without it.
139+
self.tokens_histogram = self.meter.create_histogram(
140+
**tokens_histogram_kwargs, # pyright: ignore
141+
)
142+
113143
def messages_to_otel_events(self, messages: list[ModelMessage]) -> list[Event]:
114144
"""Convert a list of model messages to OpenTelemetry events.
115145
@@ -224,38 +254,74 @@ def _instrument(
224254
if isinstance(value := model_settings.get(key), (float, int)):
225255
attributes[f'gen_ai.request.{key}'] = value
226256

227-
with self.settings.tracer.start_as_current_span(span_name, attributes=attributes) as span:
228-
229-
def finish(response: ModelResponse):
230-
if not span.is_recording():
231-
return
232-
233-
events = self.settings.messages_to_otel_events(messages)
234-
for event in self.settings.messages_to_otel_events([response]):
235-
events.append(
236-
Event(
237-
'gen_ai.choice',
238-
body={
239-
# TODO finish_reason
240-
'index': 0,
241-
'message': event.body,
242-
},
257+
record_metrics: Callable[[], None] | None = None
258+
try:
259+
with self.settings.tracer.start_as_current_span(span_name, attributes=attributes) as span:
260+
261+
def finish(response: ModelResponse):
262+
# FallbackModel updates these span attributes.
263+
attributes.update(getattr(span, 'attributes', {}))
264+
request_model = attributes[GEN_AI_REQUEST_MODEL_ATTRIBUTE]
265+
system = attributes[GEN_AI_SYSTEM_ATTRIBUTE]
266+
267+
response_model = response.model_name or request_model
268+
269+
def _record_metrics():
270+
metric_attributes = {
271+
GEN_AI_SYSTEM_ATTRIBUTE: system,
272+
'gen_ai.operation.name': operation,
273+
'gen_ai.request.model': request_model,
274+
'gen_ai.response.model': response_model,
275+
}
276+
if response.usage.request_tokens: # pragma: no branch
277+
self.settings.tokens_histogram.record(
278+
response.usage.request_tokens,
279+
{**metric_attributes, 'gen_ai.token.type': 'input'},
280+
)
281+
if response.usage.response_tokens: # pragma: no branch
282+
self.settings.tokens_histogram.record(
283+
response.usage.response_tokens,
284+
{**metric_attributes, 'gen_ai.token.type': 'output'},
285+
)
286+
287+
nonlocal record_metrics
288+
record_metrics = _record_metrics
289+
290+
if not span.is_recording():
291+
return
292+
293+
events = self.settings.messages_to_otel_events(messages)
294+
for event in self.settings.messages_to_otel_events([response]):
295+
events.append(
296+
Event(
297+
'gen_ai.choice',
298+
body={
299+
# TODO finish_reason
300+
'index': 0,
301+
'message': event.body,
302+
},
303+
)
243304
)
305+
span.set_attributes(
306+
{
307+
**response.usage.opentelemetry_attributes(),
308+
'gen_ai.response.model': response_model,
309+
}
244310
)
245-
new_attributes: dict[str, AttributeValue] = response.usage.opentelemetry_attributes() # pyright: ignore[reportAssignmentType]
246-
attributes.update(getattr(span, 'attributes', {}))
247-
request_model = attributes[GEN_AI_REQUEST_MODEL_ATTRIBUTE]
248-
new_attributes['gen_ai.response.model'] = response.model_name or request_model
249-
span.set_attributes(new_attributes)
250-
span.update_name(f'{operation} {request_model}')
251-
for event in events:
252-
event.attributes = {
253-
GEN_AI_SYSTEM_ATTRIBUTE: attributes[GEN_AI_SYSTEM_ATTRIBUTE],
254-
**(event.attributes or {}),
255-
}
256-
self._emit_events(span, events)
257-
258-
yield finish
311+
span.update_name(f'{operation} {request_model}')
312+
for event in events:
313+
event.attributes = {
314+
GEN_AI_SYSTEM_ATTRIBUTE: system,
315+
**(event.attributes or {}),
316+
}
317+
self._emit_events(span, events)
318+
319+
yield finish
320+
finally:
321+
if record_metrics:
322+
# We only want to record metrics after the span is finished,
323+
# to prevent them from being redundantly recorded in the span itself by logfire.
324+
record_metrics()
259325

260326
def _emit_events(self, span: Span, events: list[Event]) -> None:
261327
if self.settings.event_mode == 'logs':

tests/test_logfire.py

Lines changed: 71 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import Any, Callable
55

66
import pytest
7-
from dirty_equals import IsJson
7+
from dirty_equals import IsInt, IsJson, IsList
88
from inline_snapshot import snapshot
99
from typing_extensions import NotRequired, TypedDict
1010

@@ -71,7 +71,11 @@ def get_summary() -> LogfireSummary:
7171
InstrumentationSettings(event_mode='logs'),
7272
],
7373
)
74-
def test_logfire(get_logfire_summary: Callable[[], LogfireSummary], instrument: InstrumentationSettings | bool) -> None:
74+
def test_logfire(
75+
get_logfire_summary: Callable[[], LogfireSummary],
76+
instrument: InstrumentationSettings | bool,
77+
capfire: CaptureLogfire,
78+
) -> None:
7579
my_agent = Agent(model=TestModel(), instrument=instrument)
7680

7781
@my_agent.tool_plain
@@ -167,6 +171,70 @@ async def my_ret(x: int) -> str:
167171
)
168172
chat_span_attributes = summary.attributes[1]
169173
if instrument is True or instrument.event_mode == 'attributes':
174+
if hasattr(capfire, 'get_collected_metrics'):
175+
assert capfire.get_collected_metrics() == snapshot(
176+
[
177+
{
178+
'name': 'gen_ai.client.token.usage',
179+
'description': 'Measures number of input and output tokens used',
180+
'unit': '{token}',
181+
'data': {
182+
'data_points': [
183+
{
184+
'attributes': {
185+
'gen_ai.system': 'test',
186+
'gen_ai.operation.name': 'chat',
187+
'gen_ai.request.model': 'test',
188+
'gen_ai.response.model': 'test',
189+
'gen_ai.token.type': 'input',
190+
},
191+
'start_time_unix_nano': IsInt(),
192+
'time_unix_nano': IsInt(),
193+
'count': 2,
194+
'sum': 103,
195+
'scale': 12,
196+
'zero_count': 0,
197+
'positive': {
198+
'offset': 23234,
199+
'bucket_counts': IsList(length=...), # type: ignore
200+
},
201+
'negative': {'offset': 0, 'bucket_counts': [0]},
202+
'flags': 0,
203+
'min': 51,
204+
'max': 52,
205+
'exemplars': IsList(length=...), # type: ignore
206+
},
207+
{
208+
'attributes': {
209+
'gen_ai.system': 'test',
210+
'gen_ai.operation.name': 'chat',
211+
'gen_ai.request.model': 'test',
212+
'gen_ai.response.model': 'test',
213+
'gen_ai.token.type': 'output',
214+
},
215+
'start_time_unix_nano': IsInt(),
216+
'time_unix_nano': IsInt(),
217+
'count': 2,
218+
'sum': 12,
219+
'scale': 7,
220+
'zero_count': 0,
221+
'positive': {
222+
'offset': 255,
223+
'bucket_counts': IsList(length=...), # type: ignore
224+
},
225+
'negative': {'offset': 0, 'bucket_counts': [0]},
226+
'flags': 0,
227+
'min': 4,
228+
'max': 8,
229+
'exemplars': IsList(length=...), # type: ignore
230+
},
231+
],
232+
'aggregation_temporality': 1,
233+
},
234+
}
235+
]
236+
)
237+
170238
attribute_mode_attributes = {k: chat_span_attributes.pop(k) for k in ['events']}
171239
assert attribute_mode_attributes == snapshot(
172240
{
@@ -450,8 +518,7 @@ async def test_feedback(capfire: CaptureLogfire) -> None:
450518
'factuality': 0.1,
451519
'foo': 'bar',
452520
'logfire.feedback.comment': 'the agent lied',
453-
'logfire.disable_console_log': True,
454-
'logfire.json_schema': '{"type":"object","properties":{"logfire.feedback.name":{},"factuality":{},"foo":{},"logfire.feedback.comment":{},"logfire.span_type":{},"logfire.disable_console_log":{}}}',
521+
'logfire.json_schema': '{"type":"object","properties":{"logfire.feedback.name":{},"factuality":{},"foo":{},"logfire.feedback.comment":{},"logfire.span_type":{}}}',
455522
},
456523
},
457524
]

uv.lock

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)