Skip to content

Commit 7487ab4

Browse files
dmontaguDouweM
andauthored
Make agent and graph runs serializable again (#1545)
Co-authored-by: Douwe Maan <hi@douwe.me>
1 parent 8679aa5 commit 7487ab4

File tree

5 files changed

+96
-51
lines changed

5 files changed

+96
-51
lines changed

pydantic_ai_slim/pydantic_ai/_utils.py

+1-10
Original file line numberDiff line numberDiff line change
@@ -291,13 +291,4 @@ async def __anext__(self) -> T:
291291

292292

293293
def get_traceparent(x: AgentRun | AgentRunResult | GraphRun | GraphRunResult) -> str:
294-
import logfire
295-
import logfire_api
296-
from logfire.experimental.annotations import get_traceparent
297-
298-
span: AbstractSpan | None = x._span(required=False) # type: ignore[reportPrivateUsage]
299-
if not span: # pragma: no cover
300-
return ''
301-
if isinstance(span, logfire_api.LogfireSpan): # pragma: no cover
302-
assert isinstance(span, logfire.LogfireSpan)
303-
return get_traceparent(span)
294+
return x._traceparent(required=False) or '' # type: ignore[reportPrivateUsage]

pydantic_ai_slim/pydantic_ai/agent.py

+15-16
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
result,
2828
usage as _usage,
2929
)
30-
from ._utils import AbstractSpan
3130
from .models.instrumented import InstrumentationSettings, InstrumentedModel
3231
from .result import FinalResult, OutputDataT, StreamedRunResult, ToolOutput
3332
from .settings import ModelSettings, merge_model_settings
@@ -1683,14 +1682,14 @@ async def main():
16831682
]
16841683

16851684
@overload
1686-
def _span(self, *, required: Literal[False]) -> AbstractSpan | None: ...
1685+
def _traceparent(self, *, required: Literal[False]) -> str | None: ...
16871686
@overload
1688-
def _span(self) -> AbstractSpan: ...
1689-
def _span(self, *, required: bool = True) -> AbstractSpan | None:
1690-
span = self._graph_run._span(required=False) # type: ignore[reportPrivateUsage]
1691-
if span is None and required: # pragma: no cover
1692-
raise AttributeError('Span is not available for this agent run')
1693-
return span
1687+
def _traceparent(self) -> str: ...
1688+
def _traceparent(self, *, required: bool = True) -> str | None:
1689+
traceparent = self._graph_run._traceparent(required=False) # type: ignore[reportPrivateUsage]
1690+
if traceparent is None and required: # pragma: no cover
1691+
raise AttributeError('No span was created for this agent run')
1692+
return traceparent
16941693

16951694
@property
16961695
def ctx(self) -> GraphRunContext[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any]]:
@@ -1729,7 +1728,7 @@ def result(self) -> AgentRunResult[OutputDataT] | None:
17291728
graph_run_result.output.tool_name,
17301729
graph_run_result.state,
17311730
self._graph_run.deps.new_message_index,
1732-
self._graph_run._span(required=False), # type: ignore[reportPrivateUsage]
1731+
self._traceparent(required=False),
17331732
)
17341733

17351734
def __aiter__(
@@ -1847,16 +1846,16 @@ class AgentRunResult(Generic[OutputDataT]):
18471846
_output_tool_name: str | None = dataclasses.field(repr=False)
18481847
_state: _agent_graph.GraphAgentState = dataclasses.field(repr=False)
18491848
_new_message_index: int = dataclasses.field(repr=False)
1850-
_span_value: AbstractSpan | None = dataclasses.field(repr=False)
1849+
_traceparent_value: str | None = dataclasses.field(repr=False)
18511850

18521851
@overload
1853-
def _span(self, *, required: Literal[False]) -> AbstractSpan | None: ...
1852+
def _traceparent(self, *, required: Literal[False]) -> str | None: ...
18541853
@overload
1855-
def _span(self) -> AbstractSpan: ...
1856-
def _span(self, *, required: bool = True) -> AbstractSpan | None:
1857-
if self._span_value is None and required: # pragma: no cover
1858-
raise AttributeError('Span is not available for this agent run')
1859-
return self._span_value
1854+
def _traceparent(self) -> str: ...
1855+
def _traceparent(self, *, required: bool = True) -> str | None:
1856+
if self._traceparent_value is None and required: # pragma: no cover
1857+
raise AttributeError('No span was created for this agent run')
1858+
return self._traceparent_value
18601859

18611860
@property
18621861
@deprecated('`result.data` is deprecated, use `result.output` instead.')

pydantic_graph/pydantic_graph/_utils.py

+33-2
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,47 @@
33
import asyncio
44
import types
55
from functools import partial
6-
from typing import Any, Callable, TypeVar
6+
from typing import TYPE_CHECKING, Any, Callable, TypeVar
77

88
from logfire_api import LogfireSpan
9-
from opentelemetry.trace import Span
109
from typing_extensions import ParamSpec, TypeAlias, TypeIs, get_args, get_origin
1110
from typing_inspection import typing_objects
1211
from typing_inspection.introspection import is_union_origin
1312

13+
if TYPE_CHECKING:
14+
from opentelemetry.trace import Span
15+
16+
1417
AbstractSpan: TypeAlias = 'LogfireSpan | Span'
1518

19+
try:
20+
from opentelemetry.trace import Span, set_span_in_context
21+
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
22+
23+
TRACEPARENT_PROPAGATOR = TraceContextTextMapPropagator()
24+
TRACEPARENT_NAME = 'traceparent'
25+
assert TRACEPARENT_NAME in TRACEPARENT_PROPAGATOR.fields
26+
27+
# Logic taken from logfire.experimental.annotations
28+
def get_traceparent(span: AbstractSpan) -> str | None:
29+
"""Get a string representing the span context to use for annotating spans."""
30+
real_span: Span
31+
if isinstance(span, Span):
32+
real_span = span
33+
else:
34+
real_span = span._span
35+
assert real_span
36+
context = set_span_in_context(real_span)
37+
carrier: dict[str, Any] = {}
38+
TRACEPARENT_PROPAGATOR.inject(carrier, context)
39+
return carrier.get(TRACEPARENT_NAME, '')
40+
41+
except ImportError: # pragma: no cover
42+
43+
def get_traceparent(span: AbstractSpan) -> str | None:
44+
# Opentelemetry wasn't installed, so we can't get the traceparent
45+
return None
46+
1647

1748
def get_event_loop():
1849
try:

pydantic_graph/pydantic_graph/graph.py

+32-22
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from typing_inspection import typing_objects
1616

1717
from . import _utils, exceptions, mermaid
18-
from ._utils import AbstractSpan
18+
from ._utils import AbstractSpan, get_traceparent
1919
from .nodes import BaseNode, DepsT, End, GraphRunContext, NodeDef, RunEndT, StateT
2020
from .persistence import BaseStatePersistence
2121
from .persistence.in_mem import SimpleStatePersistence
@@ -257,8 +257,14 @@ async def iter(
257257
entered_span = stack.enter_context(logfire_api.span('run graph {graph.name}', graph=self))
258258
else:
259259
entered_span = stack.enter_context(span)
260+
traceparent = None if entered_span is None else get_traceparent(entered_span)
260261
yield GraphRun[StateT, DepsT, RunEndT](
261-
graph=self, start_node=start_node, persistence=persistence, state=state, deps=deps, span=entered_span
262+
graph=self,
263+
start_node=start_node,
264+
persistence=persistence,
265+
state=state,
266+
deps=deps,
267+
traceparent=traceparent,
262268
)
263269

264270
@asynccontextmanager
@@ -301,14 +307,15 @@ async def iter_from_persistence(
301307

302308
with ExitStack() as stack:
303309
entered_span = None if span is None else stack.enter_context(span)
310+
traceparent = None if entered_span is None else get_traceparent(entered_span)
304311
yield GraphRun[StateT, DepsT, RunEndT](
305312
graph=self,
306313
start_node=snapshot.node,
307314
persistence=persistence,
308315
state=snapshot.state,
309316
deps=deps,
310317
snapshot_id=snapshot.id,
311-
span=entered_span,
318+
traceparent=traceparent,
312319
)
313320

314321
async def initialize(
@@ -369,7 +376,7 @@ async def next(
369376
persistence=persistence,
370377
state=state,
371378
deps=deps,
372-
span=None,
379+
traceparent=None,
373380
)
374381
return await run.next(node)
375382

@@ -644,7 +651,7 @@ def __init__(
644651
persistence: BaseStatePersistence[StateT, RunEndT],
645652
state: StateT,
646653
deps: DepsT,
647-
span: AbstractSpan | None,
654+
traceparent: str | None,
648655
snapshot_id: str | None = None,
649656
):
650657
"""Create a new run for a given graph, starting at the specified node.
@@ -659,7 +666,7 @@ def __init__(
659666
to all nodes via `ctx.state`.
660667
deps: Optional dependencies that each node can access via `ctx.deps`, e.g. database connections,
661668
configuration, or logging clients.
662-
span: The span used for the graph run.
669+
traceparent: The traceparent for the span used for the graph run.
663670
snapshot_id: The ID of the snapshot the node came from.
664671
"""
665672
self.graph = graph
@@ -668,18 +675,18 @@ def __init__(
668675
self.state = state
669676
self.deps = deps
670677

671-
self.__span = span
678+
self.__traceparent = traceparent
672679
self._next_node: BaseNode[StateT, DepsT, RunEndT] | End[RunEndT] = start_node
673680
self._is_started: bool = False
674681

675682
@overload
676-
def _span(self, *, required: typing_extensions.Literal[False]) -> AbstractSpan | None: ...
683+
def _traceparent(self, *, required: typing_extensions.Literal[False]) -> str | None: ...
677684
@overload
678-
def _span(self) -> AbstractSpan: ...
679-
def _span(self, *, required: bool = True) -> AbstractSpan | None:
680-
if self.__span is None and required: # pragma: no cover
681-
raise exceptions.GraphRuntimeError('No span available for this graph run.')
682-
return self.__span
685+
def _traceparent(self) -> str: ...
686+
def _traceparent(self, *, required: bool = True) -> str | None:
687+
if self.__traceparent is None and required: # pragma: no cover
688+
raise exceptions.GraphRuntimeError('No span was created for this graph run')
689+
return self.__traceparent
683690

684691
@property
685692
def next_node(self) -> BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]:
@@ -695,7 +702,10 @@ def result(self) -> GraphRunResult[StateT, RunEndT] | None:
695702
if not isinstance(self._next_node, End):
696703
return None # The GraphRun has not finished running
697704
return GraphRunResult[StateT, RunEndT](
698-
self._next_node.data, state=self.state, persistence=self.persistence, span=self._span(required=False)
705+
self._next_node.data,
706+
state=self.state,
707+
persistence=self.persistence,
708+
traceparent=self._traceparent(required=False),
699709
)
700710

701711
async def next(
@@ -816,18 +826,18 @@ def __init__(
816826
output: RunEndT,
817827
state: StateT,
818828
persistence: BaseStatePersistence[StateT, RunEndT],
819-
span: AbstractSpan | None = None,
829+
traceparent: str | None = None,
820830
):
821831
self.output = output
822832
self.state = state
823833
self.persistence = persistence
824-
self.__span = span
834+
self.__traceparent = traceparent
825835

826836
@overload
827-
def _span(self, *, required: typing_extensions.Literal[False]) -> AbstractSpan | None: ...
837+
def _traceparent(self, *, required: typing_extensions.Literal[False]) -> str | None: ...
828838
@overload
829-
def _span(self) -> AbstractSpan: ...
830-
def _span(self, *, required: bool = True) -> AbstractSpan | None: # pragma: no cover
831-
if self.__span is None and required:
832-
raise exceptions.GraphRuntimeError('No span available for this graph run.')
833-
return self.__span
839+
def _traceparent(self) -> str: ...
840+
def _traceparent(self, *, required: bool = True) -> str | None: # pragma: no cover
841+
if self.__traceparent is None and required:
842+
raise exceptions.GraphRuntimeError('No span was created for this graph run.')
843+
return self.__traceparent

tests/test_agent.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@
88
import pytest
99
from dirty_equals import IsJson
1010
from inline_snapshot import snapshot
11-
from pydantic import BaseModel, field_validator
11+
from pydantic import BaseModel, TypeAdapter, field_validator
1212
from pydantic_core import to_json
1313

1414
from pydantic_ai import Agent, ModelRetry, RunContext, UnexpectedModelBehavior, UserError, capture_run_messages
15+
from pydantic_ai.agent import AgentRunResult
1516
from pydantic_ai.messages import (
1617
BinaryContent,
1718
ModelMessage,
@@ -1869,3 +1870,16 @@ def my_tool(x: int) -> int:
18691870
ModelResponse(parts=[], model_name='function:llm:', timestamp=IsNow(tz=timezone.utc)),
18701871
]
18711872
)
1873+
1874+
1875+
def test_agent_run_result_serialization() -> None:
1876+
agent = Agent('test', output_type=Foo)
1877+
result = agent.run_sync('Hello')
1878+
1879+
# Check that dump_json doesn't raise an error
1880+
adapter = TypeAdapter(AgentRunResult[Foo])
1881+
serialized_data = adapter.dump_json(result)
1882+
1883+
# Check that we can load the data back
1884+
deserialized_result = adapter.validate_json(serialized_data)
1885+
assert deserialized_result == result

0 commit comments

Comments
 (0)