7
7
from agents .tracing import (
8
8
get_trace_provider ,
9
9
)
10
- from agents .tracing .provider import DefaultTraceProvider
10
+ from agents .tracing .provider import (
11
+ DefaultTraceProvider ,
12
+ SynchronousMultiTracingProcessor ,
13
+ )
11
14
from agents .tracing .spans import Span
12
15
13
16
from temporalio import workflow
@@ -72,22 +75,35 @@ def activity_span(
72
75
)
73
76
74
77
75
- class _TemporalTracingProcessor (TracingProcessor ):
76
- def __init__ (self , impl : TracingProcessor ):
78
+ class _TemporalTracingProcessor (SynchronousMultiTracingProcessor ):
79
+ def __init__ (
80
+ self , impl : SynchronousMultiTracingProcessor , auto_close_in_workflows : bool
81
+ ):
77
82
super ().__init__ ()
78
83
self ._impl = impl
84
+ self ._auto_close_in_workflows = auto_close_in_workflows
85
+
86
+ def add_tracing_processor (self , tracing_processor : TracingProcessor ):
87
+ self ._impl .add_tracing_processor (tracing_processor )
88
+
89
+ def set_processors (self , processors : list [TracingProcessor ]):
90
+ self ._impl .set_processors (processors )
79
91
80
92
def on_trace_start (self , trace : Trace ) -> None :
81
93
if workflow .in_workflow () and workflow .unsafe .is_replaying ():
82
94
# In replay mode, don't report
83
95
return
84
96
85
97
self ._impl .on_trace_start (trace )
98
+ if self ._auto_close_in_workflows and workflow .in_workflow ():
99
+ self ._impl .on_trace_end (trace )
86
100
87
101
def on_trace_end (self , trace : Trace ) -> None :
88
102
if workflow .in_workflow () and workflow .unsafe .is_replaying ():
89
103
# In replay mode, don't report
90
104
return
105
+ if self ._auto_close_in_workflows and workflow .in_workflow ():
106
+ return
91
107
92
108
self ._impl .on_trace_end (trace )
93
109
@@ -97,11 +113,16 @@ def on_span_start(self, span: Span[Any]) -> None:
97
113
return
98
114
99
115
self ._impl .on_span_start (span )
116
+ if self ._auto_close_in_workflows and workflow .in_workflow ():
117
+ self ._impl .on_span_end (span )
100
118
101
119
def on_span_end (self , span : Span [Any ]) -> None :
102
120
if workflow .in_workflow () and workflow .unsafe .is_replaying ():
103
121
# In replay mode, don't report
104
122
return
123
+ if self ._auto_close_in_workflows and workflow .in_workflow ():
124
+ return
125
+
105
126
self ._impl .on_span_end (span )
106
127
107
128
def shutdown (self ) -> None :
@@ -114,12 +135,13 @@ def force_flush(self) -> None:
114
135
class TemporalTraceProvider (DefaultTraceProvider ):
115
136
"""A trace provider that integrates with Temporal workflows."""
116
137
117
- def __init__ (self ):
138
+ def __init__ (self , auto_close_in_workflows : bool = False ):
118
139
"""Initialize the TemporalTraceProvider."""
119
140
super ().__init__ ()
120
141
self ._original_provider = cast (DefaultTraceProvider , get_trace_provider ())
121
- self ._multi_processor = _TemporalTracingProcessor ( # type: ignore[assignment]
122
- self ._original_provider ._multi_processor
142
+ self ._multi_processor = _TemporalTracingProcessor (
143
+ self ._original_provider ._multi_processor ,
144
+ auto_close_in_workflows ,
123
145
)
124
146
125
147
def time_iso (self ) -> str :
0 commit comments