Skip to content

Commit 83d2ae4

Browse files
authored
OpenAI trace fixes (#934)
* Don't create fake parent span, use actual processor base class * Remove print
1 parent 47027fc commit 83d2ae4

File tree

3 files changed

+35
-20
lines changed

3 files changed

+35
-20
lines changed

temporalio/contrib/openai_agents/_temporal_trace_provider.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@
77
from agents.tracing import (
88
get_trace_provider,
99
)
10-
from agents.tracing.provider import DefaultTraceProvider
10+
from agents.tracing.provider import (
11+
DefaultTraceProvider,
12+
SynchronousMultiTracingProcessor,
13+
)
1114
from agents.tracing.spans import Span
1215

1316
from temporalio import workflow
@@ -72,22 +75,35 @@ def activity_span(
7275
)
7376

7477

75-
class _TemporalTracingProcessor(TracingProcessor):
76-
def __init__(self, impl: TracingProcessor):
78+
class _TemporalTracingProcessor(SynchronousMultiTracingProcessor):
79+
def __init__(
80+
self, impl: SynchronousMultiTracingProcessor, auto_close_in_workflows: bool
81+
):
7782
super().__init__()
7883
self._impl = impl
84+
self._auto_close_in_workflows = auto_close_in_workflows
85+
86+
def add_tracing_processor(self, tracing_processor: TracingProcessor):
87+
self._impl.add_tracing_processor(tracing_processor)
88+
89+
def set_processors(self, processors: list[TracingProcessor]):
90+
self._impl.set_processors(processors)
7991

8092
def on_trace_start(self, trace: Trace) -> None:
8193
if workflow.in_workflow() and workflow.unsafe.is_replaying():
8294
# In replay mode, don't report
8395
return
8496

8597
self._impl.on_trace_start(trace)
98+
if self._auto_close_in_workflows and workflow.in_workflow():
99+
self._impl.on_trace_end(trace)
86100

87101
def on_trace_end(self, trace: Trace) -> None:
88102
if workflow.in_workflow() and workflow.unsafe.is_replaying():
89103
# In replay mode, don't report
90104
return
105+
if self._auto_close_in_workflows and workflow.in_workflow():
106+
return
91107

92108
self._impl.on_trace_end(trace)
93109

@@ -97,11 +113,16 @@ def on_span_start(self, span: Span[Any]) -> None:
97113
return
98114

99115
self._impl.on_span_start(span)
116+
if self._auto_close_in_workflows and workflow.in_workflow():
117+
self._impl.on_span_end(span)
100118

101119
def on_span_end(self, span: Span[Any]) -> None:
102120
if workflow.in_workflow() and workflow.unsafe.is_replaying():
103121
# In replay mode, don't report
104122
return
123+
if self._auto_close_in_workflows and workflow.in_workflow():
124+
return
125+
105126
self._impl.on_span_end(span)
106127

107128
def shutdown(self) -> None:
@@ -114,12 +135,13 @@ def force_flush(self) -> None:
114135
class TemporalTraceProvider(DefaultTraceProvider):
115136
"""A trace provider that integrates with Temporal workflows."""
116137

117-
def __init__(self):
138+
def __init__(self, auto_close_in_workflows: bool = False):
118139
"""Initialize the TemporalTraceProvider."""
119140
super().__init__()
120141
self._original_provider = cast(DefaultTraceProvider, get_trace_provider())
121-
self._multi_processor = _TemporalTracingProcessor( # type: ignore[assignment]
122-
self._original_provider._multi_processor
142+
self._multi_processor = _TemporalTracingProcessor(
143+
self._original_provider._multi_processor,
144+
auto_close_in_workflows,
123145
)
124146

125147
def time_iso(self) -> str:

temporalio/contrib/openai_agents/temporal_openai_agents.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
@contextmanager
2222
def set_open_ai_agent_temporal_overrides(
2323
model_params: ModelActivityParameters,
24+
auto_close_tracing_in_workflows: bool = False,
2425
):
2526
"""Configure Temporal-specific overrides for OpenAI agents.
2627
@@ -55,7 +56,9 @@ def set_open_ai_agent_temporal_overrides(
5556

5657
previous_runner = get_default_agent_runner()
5758
previous_trace_provider = get_trace_provider()
58-
provider = TemporalTraceProvider()
59+
provider = TemporalTraceProvider(
60+
auto_close_in_workflows=auto_close_tracing_in_workflows
61+
)
5962

6063
try:
6164
set_default_agent_runner(TemporalOpenAIRunner(model_params))

temporalio/contrib/openai_agents/trace_interceptor.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -60,16 +60,6 @@ def context_from_header(
6060
if span_info is None:
6161
yield
6262
else:
63-
span = SpanImpl(
64-
trace_id=str(span_info["traceId"]),
65-
span_id=span_info["spanId"],
66-
parent_id=None,
67-
span_data=CustomSpanData(
68-
name="Parent Temporal Span",
69-
data={},
70-
),
71-
processor=cast(DefaultTraceProvider, get_trace_provider())._multi_processor,
72-
)
7363
workflow_type = (
7464
activity.info().workflow_type
7565
if activity.in_activity()
@@ -94,11 +84,11 @@ def context_from_header(
9484
span_info["traceName"],
9585
trace_id=span_info["traceId"],
9686
metadata=metadata,
97-
):
98-
with custom_span(name=span_name, parent=span, data=data):
87+
) as t:
88+
with custom_span(name=span_name, parent=t, data=data):
9989
yield
10090
else:
101-
with custom_span(name=span_name, parent=span, data=data):
91+
with custom_span(name=span_name, parent=None, data=data):
10292
yield
10393

10494

0 commit comments

Comments
 (0)