15
15
from typing_inspection import typing_objects
16
16
17
17
from . import _utils , exceptions , mermaid
18
- from ._utils import AbstractSpan
18
+ from ._utils import AbstractSpan , get_traceparent
19
19
from .nodes import BaseNode , DepsT , End , GraphRunContext , NodeDef , RunEndT , StateT
20
20
from .persistence import BaseStatePersistence
21
21
from .persistence .in_mem import SimpleStatePersistence
@@ -257,8 +257,14 @@ async def iter(
257
257
entered_span = stack .enter_context (logfire_api .span ('run graph {graph.name}' , graph = self ))
258
258
else :
259
259
entered_span = stack .enter_context (span )
260
+ traceparent = None if entered_span is None else get_traceparent (entered_span )
260
261
yield GraphRun [StateT , DepsT , RunEndT ](
261
- graph = self , start_node = start_node , persistence = persistence , state = state , deps = deps , span = entered_span
262
+ graph = self ,
263
+ start_node = start_node ,
264
+ persistence = persistence ,
265
+ state = state ,
266
+ deps = deps ,
267
+ traceparent = traceparent ,
262
268
)
263
269
264
270
@asynccontextmanager
@@ -301,14 +307,15 @@ async def iter_from_persistence(
301
307
302
308
with ExitStack () as stack :
303
309
entered_span = None if span is None else stack .enter_context (span )
310
+ traceparent = None if entered_span is None else get_traceparent (entered_span )
304
311
yield GraphRun [StateT , DepsT , RunEndT ](
305
312
graph = self ,
306
313
start_node = snapshot .node ,
307
314
persistence = persistence ,
308
315
state = snapshot .state ,
309
316
deps = deps ,
310
317
snapshot_id = snapshot .id ,
311
- span = entered_span ,
318
+ traceparent = traceparent ,
312
319
)
313
320
314
321
async def initialize (
@@ -369,7 +376,7 @@ async def next(
369
376
persistence = persistence ,
370
377
state = state ,
371
378
deps = deps ,
372
- span = None ,
379
+ traceparent = None ,
373
380
)
374
381
return await run .next (node )
375
382
@@ -644,7 +651,7 @@ def __init__(
644
651
persistence : BaseStatePersistence [StateT , RunEndT ],
645
652
state : StateT ,
646
653
deps : DepsT ,
647
- span : AbstractSpan | None ,
654
+ traceparent : str | None ,
648
655
snapshot_id : str | None = None ,
649
656
):
650
657
"""Create a new run for a given graph, starting at the specified node.
@@ -659,7 +666,7 @@ def __init__(
659
666
to all nodes via `ctx.state`.
660
667
deps: Optional dependencies that each node can access via `ctx.deps`, e.g. database connections,
661
668
configuration, or logging clients.
662
- span : The span used for the graph run.
669
+ traceparent : The traceparent for the span used for the graph run.
663
670
snapshot_id: The ID of the snapshot the node came from.
664
671
"""
665
672
self .graph = graph
@@ -668,18 +675,18 @@ def __init__(
668
675
self .state = state
669
676
self .deps = deps
670
677
671
- self .__span = span
678
+ self .__traceparent = traceparent
672
679
self ._next_node : BaseNode [StateT , DepsT , RunEndT ] | End [RunEndT ] = start_node
673
680
self ._is_started : bool = False
674
681
675
682
@overload
676
- def _span (self , * , required : typing_extensions .Literal [False ]) -> AbstractSpan | None : ...
683
+ def _traceparent (self , * , required : typing_extensions .Literal [False ]) -> str | None : ...
677
684
@overload
678
- def _span (self ) -> AbstractSpan : ...
679
- def _span (self , * , required : bool = True ) -> AbstractSpan | None :
680
- if self .__span is None and required : # pragma: no cover
681
- raise exceptions .GraphRuntimeError ('No span available for this graph run. ' )
682
- return self .__span
685
+ def _traceparent (self ) -> str : ...
686
+ def _traceparent (self , * , required : bool = True ) -> str | None :
687
+ if self .__traceparent is None and required : # pragma: no cover
688
+ raise exceptions .GraphRuntimeError ('No span was created for this graph run' )
689
+ return self .__traceparent
683
690
684
691
@property
685
692
def next_node (self ) -> BaseNode [StateT , DepsT , RunEndT ] | End [RunEndT ]:
@@ -695,7 +702,10 @@ def result(self) -> GraphRunResult[StateT, RunEndT] | None:
695
702
if not isinstance (self ._next_node , End ):
696
703
return None # The GraphRun has not finished running
697
704
return GraphRunResult [StateT , RunEndT ](
698
- self ._next_node .data , state = self .state , persistence = self .persistence , span = self ._span (required = False )
705
+ self ._next_node .data ,
706
+ state = self .state ,
707
+ persistence = self .persistence ,
708
+ traceparent = self ._traceparent (required = False ),
699
709
)
700
710
701
711
async def next (
@@ -816,18 +826,18 @@ def __init__(
816
826
output : RunEndT ,
817
827
state : StateT ,
818
828
persistence : BaseStatePersistence [StateT , RunEndT ],
819
- span : AbstractSpan | None = None ,
829
+ traceparent : str | None = None ,
820
830
):
821
831
self .output = output
822
832
self .state = state
823
833
self .persistence = persistence
824
- self .__span = span
834
+ self .__traceparent = traceparent
825
835
826
836
@overload
827
- def _span (self , * , required : typing_extensions .Literal [False ]) -> AbstractSpan | None : ...
837
+ def _traceparent (self , * , required : typing_extensions .Literal [False ]) -> str | None : ...
828
838
@overload
829
- def _span (self ) -> AbstractSpan : ...
830
- def _span (self , * , required : bool = True ) -> AbstractSpan | None : # pragma: no cover
831
- if self .__span is None and required :
832
- raise exceptions .GraphRuntimeError ('No span available for this graph run.' )
833
- return self .__span
839
+ def _traceparent (self ) -> str : ...
840
+ def _traceparent (self , * , required : bool = True ) -> str | None : # pragma: no cover
841
+ if self .__traceparent is None and required :
842
+ raise exceptions .GraphRuntimeError ('No span was created for this graph run.' )
843
+ return self .__traceparent
0 commit comments