19 changes: 15 additions & 4 deletions temporalio/contrib/opentelemetry.py
@@ -26,8 +26,7 @@
 import opentelemetry.trace.propagation.tracecontext
 import opentelemetry.util.types
 from opentelemetry.context import Context
-from opentelemetry.trace import Span, SpanKind, Status, StatusCode, _Links
-from opentelemetry.util import types
+from opentelemetry.trace import Status, StatusCode
 from typing_extensions import Protocol, TypeAlias, TypedDict
 
 import temporalio.activity
@@ -473,7 +472,12 @@ async def handle_query(self, input: temporalio.worker.HandleQueryInput) -> Any:
             )
             return await super().handle_query(input)
         finally:
-            opentelemetry.context.detach(token)
+            # In some exceptional cases this finally is executed with a
+            # different contextvars.Context than the one the token was created
+            # on. As such we do a best effort detach to avoid using a mismatched
+            # token.
+            if context is opentelemetry.context.get_current():
+                opentelemetry.context.detach(token)
 
     def handle_update_validator(
         self, input: temporalio.worker.HandleUpdateInput
@@ -545,6 +549,7 @@ def _top_level_workflow_context(
         exception: Optional[Exception] = None
         # Run under this context
         token = opentelemetry.context.attach(context)
+
         try:
             yield None
             success = True
@@ -561,7 +566,13 @@
                 exception=exception,
                 kind=opentelemetry.trace.SpanKind.INTERNAL,
             )
-            opentelemetry.context.detach(token)
+
+            # In some exceptional cases this finally is executed with a
+            # different contextvars.Context than the one the token was created
+            # on. As such we do a best effort detach to avoid using a mismatched
+            # token.
+            if context is opentelemetry.context.get_current():
+                opentelemetry.context.detach(token)
 
     def _context_to_headers(
         self, headers: Mapping[str, temporalio.api.common.v1.Payload]

[Inline review comment from a repo Member on the guarded detach: "Nice"]
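The guard works because a contextvars token can only be reset from the
contextvars.Context in which it was created. A minimal standard-library sketch
of that failure mode (hypothetical, not part of this PR):

import contextvars

var = contextvars.ContextVar("otel_context")

def attach_in_copy():
    # Runs inside a copied Context, so the returned token is bound to that copy
    return var.set("attached")

token = contextvars.Context().run(attach_in_copy)

try:
    # Resetting from the original Context raises ValueError; this is the
    # mismatched-token case the interceptor now avoids by comparing against
    # opentelemetry.context.get_current() before detaching
    var.reset(token)
except ValueError as err:
    print(f"mismatched token: {err}")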
75 changes: 73 additions & 2 deletions tests/contrib/test_opentelemetry.py
@@ -2,12 +2,14 @@
 
 import asyncio
 import logging
+import sys
 import uuid
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
 from datetime import timedelta
 from typing import Iterable, List, Optional
 
+import opentelemetry.context
 from opentelemetry.sdk.trace import ReadableSpan, TracerProvider
 from opentelemetry.sdk.trace.export import SimpleSpanProcessor
 from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
@@ -21,6 +23,12 @@
 from temporalio.exceptions import ApplicationError, ApplicationErrorCategory
 from temporalio.testing import WorkflowEnvironment
 from temporalio.worker import UnsandboxedWorkflowRunner, Worker
+from tests.helpers import LogCapturer
+from tests.helpers.cache_evitction import (
+    CacheEvictionTearDownWorkflow,
+    WaitForeverWorkflow,
+    wait_forever_activity,
+)
 
 # Passing through because Python 3.9 has an import bug at
 # https://github.com/python/cpython/issues/91351
@@ -424,7 +432,10 @@ def dump_spans(
         span_links: List[str] = []
         for link in span.links:
             for link_span in spans:
-                if link_span.context.span_id == link.context.span_id:
+                if (
+                    link_span.context is not None
+                    and link_span.context.span_id == link.context.span_id
+                ):
                     span_links.append(link_span.name)
         span_str += f" (links: {', '.join(span_links)})"
         # Signals can duplicate in rare situations, so we make sure not to
@@ -434,7 +445,7 @@
         ret.append(span_str)
         ret += dump_spans(
             spans,
-            parent_id=span.context.span_id,
+            parent_id=span.context.span_id if span.context else None,
             with_attributes=with_attributes,
             indent_depth=indent_depth + 1,
         )
@@ -551,3 +562,63 @@ async def test_opentelemetry_benign_exception(client: Client):
 # * workflow failure and wft failure
 # * signal with start
 # * signal failure and wft failure from signal
+
+
+async def test_opentelemetry_safe_detach(client: Client):
+    # This test simulates forced eviction. Eviction purposely raises
+    # GeneratorExit on GC, which triggers finally blocks that can run on
+    # whatever thread Python chooses. When this occurs, we should not detach
+    # the token from the context because the context no longer exists.
+
+    # Create a tracer that has an in-memory exporter
+    exporter = InMemorySpanExporter()
+    provider = TracerProvider()
+    provider.add_span_processor(SimpleSpanProcessor(exporter))
+    tracer = get_tracer(__name__, tracer_provider=provider)
+
+    async with Worker(
+        client,
+        workflows=[CacheEvictionTearDownWorkflow, WaitForeverWorkflow],
+        activities=[wait_forever_activity],
+        max_cached_workflows=0,
+        task_queue=f"task_queue_{uuid.uuid4()}",
+        disable_safe_workflow_eviction=True,
+        interceptors=[TracingInterceptor(tracer)],
+    ) as worker:
+        # Put a hook in place to catch unraisable exceptions
+        old_hook = sys.unraisablehook
+        hook_calls: List[sys.UnraisableHookArgs] = []
+        sys.unraisablehook = hook_calls.append
+
+        with LogCapturer().logs_captured(opentelemetry.context.logger) as capturer:
+            try:
+                handle = await client.start_workflow(
+                    CacheEvictionTearDownWorkflow.run,
+                    id=f"wf-{uuid.uuid4()}",
+                    task_queue=worker.task_queue,
+                )
+
+                # CacheEvictionTearDownWorkflow requires 3 signals to be sent
+                await handle.signal(CacheEvictionTearDownWorkflow.signal)
+                await handle.signal(CacheEvictionTearDownWorkflow.signal)
+                await handle.signal(CacheEvictionTearDownWorkflow.signal)
+
+                await handle.result()
+            finally:
+                sys.unraisablehook = old_hook
+
+        # Confirm at least 1 exception
+        if len(hook_calls) < 1:
+            logging.warning(
+                "Expected at least 1 exception. Unable to properly verify context detachment"
+            )
+
+        def otel_context_error(record: logging.LogRecord) -> bool:
+            return (
+                record.name == "opentelemetry.context"
+                and "Failed to detach context" in record.message
+            )
+
+        assert (
+            capturer.find(otel_context_error) is None
+        ), "Detach from context message should not be logged"
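The log assertion relies on OpenTelemetry catching a failed detach and logging
it rather than raising. A hypothetical repro (not in the PR, and assuming the
default contextvars-based runtime context) that produces exactly the
"Failed to detach context" record on the opentelemetry.context logger:

import contextvars
import logging

import opentelemetry.context

logging.basicConfig(level=logging.ERROR)

def attach_elsewhere():
    # Attach inside a separate Context copy; the token belongs to that copy
    return opentelemetry.context.attach(opentelemetry.context.get_current())

token = contextvars.Context().run(attach_elsewhere)

# Detaching from the original context fails internally; the library logs
# "Failed to detach context" instead of raising
opentelemetry.context.detach(token)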
49 changes: 47 additions & 2 deletions tests/helpers/__init__.py
@@ -1,11 +1,25 @@
 import asyncio
 import logging
+import logging.handlers
+import queue
 import socket
 import time
 import uuid
-from contextlib import closing
+from contextlib import closing, contextmanager
 from dataclasses import dataclass
 from datetime import datetime, timedelta, timezone
-from typing import Any, Awaitable, Callable, Optional, Sequence, Type, TypeVar, Union
+from typing import (
+    Any,
+    Awaitable,
+    Callable,
+    List,
+    Optional,
+    Sequence,
+    Type,
+    TypeVar,
+    Union,
+    cast,
+)
 
 from temporalio.api.common.v1 import WorkflowExecution
 from temporalio.api.enums.v1 import EventType as EventType
@@ -401,3 +415,34 @@ def _format_row(items: list[str], truncate: bool = False) -> str:
     padding = len(f" *: {elapsed_ms:>4} ")
     summary_row[col_idx] = f"{' ' * padding}[{summary}]"[: col_width - 3]
     print(_format_row(summary_row))
+
+
+class LogCapturer:
+    def __init__(self) -> None:
+        self.log_queue: queue.Queue[logging.LogRecord] = queue.Queue()
+
+    @contextmanager
+    def logs_captured(self, *loggers: logging.Logger):
+        handler = logging.handlers.QueueHandler(self.log_queue)
+
+        prev_levels = [l.level for l in loggers]
+        for l in loggers:
+            l.setLevel(logging.INFO)
+            l.addHandler(handler)
+        try:
+            yield self
+        finally:
+            for i, l in enumerate(loggers):
+                l.removeHandler(handler)
+                l.setLevel(prev_levels[i])
+
+    def find_log(self, starts_with: str) -> Optional[logging.LogRecord]:
+        return self.find(lambda l: l.message.startswith(starts_with))
+
+    def find(
+        self, pred: Callable[[logging.LogRecord], bool]
+    ) -> Optional[logging.LogRecord]:
+        for record in cast(List[logging.LogRecord], self.log_queue.queue):
+            if pred(record):
+                return record
+        return None
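A short usage sketch for this helper (hypothetical): attach the capturer to a
logger, emit a record, then query what was captured. QueueHandler formats each
record on enqueue, which is what makes record.message available to find_log:

import logging

capturer = LogCapturer()
target = logging.getLogger("opentelemetry.context")
with capturer.logs_captured(target):
    target.error("Failed to detach context")

assert capturer.find_log("Failed to detach") is not None
assert capturer.find(lambda r: r.levelno == logging.ERROR) is not None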
68 changes: 68 additions & 0 deletions tests/helpers/cache_evitction.py
@@ -0,0 +1,68 @@
+import asyncio
+from datetime import timedelta
+
+from temporalio import activity, workflow
+
+
+@activity.defn
+async def wait_forever_activity() -> None:
+    await asyncio.Future()
+
+
+@workflow.defn
+class WaitForeverWorkflow:
+    @workflow.run
+    async def run(self) -> None:
+        await asyncio.Future()
+
+
+@workflow.defn
+class CacheEvictionTearDownWorkflow:
+    def __init__(self) -> None:
+        self._signal_count = 0
+
+    @workflow.run
+    async def run(self) -> None:
+        # Start several things in the background. This is just to show that
+        # eviction can work even with these things running.
+        tasks = [
+            asyncio.create_task(
+                workflow.execute_activity(
+                    wait_forever_activity, start_to_close_timeout=timedelta(hours=1)
+                )
+            ),
+            asyncio.create_task(
+                workflow.execute_child_workflow(WaitForeverWorkflow.run)
+            ),
+            asyncio.create_task(asyncio.sleep(1000)),
+            asyncio.shield(
+                workflow.execute_activity(
+                    wait_forever_activity, start_to_close_timeout=timedelta(hours=1)
+                )
+            ),
+            asyncio.create_task(workflow.wait_condition(lambda: False)),
+        ]
+        gather_fut = asyncio.gather(*tasks, return_exceptions=True)
+        # Also start something in the background that we never wait on
+        asyncio.create_task(asyncio.sleep(1000))
+        try:
+            # Wait for the signal count to reach 2
+            await asyncio.sleep(0.01)
+            await workflow.wait_condition(lambda: self._signal_count > 1)
+        finally:
+            # On eviction this finally is actually called, but the commands it
+            # generates should be ignored
+            await asyncio.sleep(0.01)
+            await workflow.wait_condition(lambda: self._signal_count > 2)
+            # Cancel the gather tasks and wait on them, ignoring errors
+            for task in tasks:
+                task.cancel()
+            await gather_fut
+
+    @workflow.signal
+    async def signal(self) -> None:
+        self._signal_count += 1
+
+    @workflow.query
+    def signal_count(self) -> int:
+        return self._signal_count
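Why this workflow makes eviction observable: its finally block suspends again
(await workflow.wait_condition(...)), and a coroutine that suspends while
handling GeneratorExit raises RuntimeError("coroutine ignored GeneratorExit"),
which is reported through sys.unraisablehook during garbage collection. A
standard-library-only sketch of that mechanism (hypothetical, not part of the
PR):

import sys

class Suspend:
    def __await__(self):
        yield  # suspend once, no event loop required

async def workflow_like():
    try:
        await Suspend()
    finally:
        await Suspend()  # suspending while handling GeneratorExit is illegal

calls = []
old_hook = sys.unraisablehook
sys.unraisablehook = calls.append
try:
    coro = workflow_like()
    coro.send(None)  # advance to the first suspension point
    del coro  # GC closes the coroutine: GeneratorExit -> finally -> RuntimeError
finally:
    sys.unraisablehook = old_hook

print(len(calls))  # expect 1 unraisable "coroutine ignored GeneratorExit"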