|
| 1 | +"""Functions that can be called inside of activities. |
| 2 | +
|
| 3 | +Most of these functions use :py:mod:`contextvars` to obtain the current activity |
| 4 | +in context. This is already set before the start of the activity. Activities |
| 5 | +that make calls that do not automatically propagate the context, such as calls |
| 6 | +in another thread, should not use the calls herein unless the context is |
| 7 | +explicitly propagated. |
| 8 | +""" |
| 9 | + |
| 10 | +from __future__ import annotations |
| 11 | + |
| 12 | +import asyncio |
| 13 | +import contextvars |
| 14 | +import logging |
| 15 | +import threading |
| 16 | +from dataclasses import dataclass |
| 17 | +from datetime import datetime, timedelta |
| 18 | +from typing import ( |
| 19 | + Any, |
| 20 | + Callable, |
| 21 | + Iterable, |
| 22 | + Mapping, |
| 23 | + MutableMapping, |
| 24 | + NoReturn, |
| 25 | + Optional, |
| 26 | + Tuple, |
| 27 | +) |
| 28 | + |
| 29 | +import temporalio.api.common.v1 |
| 30 | +import temporalio.common |
| 31 | +import temporalio.exceptions |
| 32 | + |
| 33 | + |
| 34 | +@dataclass(frozen=True) |
| 35 | +class Info: |
| 36 | + """Information about the running activity. |
| 37 | +
|
| 38 | + Retrieved inside an activity via :py:func:`info`. |
| 39 | + """ |
| 40 | + |
| 41 | + activity_id: str |
| 42 | + activity_type: str |
| 43 | + attempt: int |
| 44 | + current_attempt_scheduled_time: datetime |
| 45 | + header: Mapping[str, temporalio.api.common.v1.Payload] |
| 46 | + heartbeat_details: Iterable[Any] |
| 47 | + heartbeat_timeout: Optional[timedelta] |
| 48 | + is_local: bool |
| 49 | + retry_policy: Optional[temporalio.common.RetryPolicy] |
| 50 | + schedule_to_close_timeout: Optional[timedelta] |
| 51 | + scheduled_time: datetime |
| 52 | + start_to_close_timeout: Optional[timedelta] |
| 53 | + started_time: datetime |
| 54 | + task_queue: str |
| 55 | + task_token: bytes |
| 56 | + workflow_id: str |
| 57 | + workflow_namespace: str |
| 58 | + workflow_run_id: str |
| 59 | + workflow_type: str |
| 60 | + # TODO(cretz): Consider putting identity on here for "worker_id" for logger? |
| 61 | + |
| 62 | + def _logger_details(self) -> Mapping[str, Any]: |
| 63 | + return { |
| 64 | + "activity_id": self.activity_id, |
| 65 | + "activity_type": self.activity_type, |
| 66 | + "attempt": self.attempt, |
| 67 | + "namespace": self.workflow_namespace, |
| 68 | + "task_queue": self.task_queue, |
| 69 | + "workflow_id": self.workflow_id, |
| 70 | + "workflow_run_id": self.workflow_run_id, |
| 71 | + "workflow_type": self.workflow_type, |
| 72 | + } |
| 73 | + |
| 74 | + |
| 75 | +_current_context: contextvars.ContextVar[_Context] = contextvars.ContextVar("activity") |
| 76 | + |
| 77 | + |
| 78 | +@dataclass |
| 79 | +class _Context: |
| 80 | + info: Callable[[], Info] |
| 81 | + # This is optional because during interceptor init it is not present |
| 82 | + heartbeat: Optional[Callable[..., None]] |
| 83 | + cancelled_event: _CompositeEvent |
| 84 | + worker_shutdown_event: _CompositeEvent |
| 85 | + _logger_details: Optional[Mapping[str, Any]] = None |
| 86 | + |
| 87 | + @staticmethod |
| 88 | + def current() -> _Context: |
| 89 | + context = _current_context.get(None) |
| 90 | + if not context: |
| 91 | + raise RuntimeError("Not in activity context") |
| 92 | + return context |
| 93 | + |
| 94 | + @staticmethod |
| 95 | + def set(context: _Context) -> None: |
| 96 | + _current_context.set(context) |
| 97 | + |
| 98 | + @property |
| 99 | + def logger_details(self) -> Mapping[str, Any]: |
| 100 | + if self._logger_details is None: |
| 101 | + self._logger_details = self.info()._logger_details() |
| 102 | + return self._logger_details |
| 103 | + |
| 104 | + |
| 105 | +@dataclass |
| 106 | +class _CompositeEvent: |
| 107 | + # This should always be present, but is sometimes lazily set internally |
| 108 | + thread_event: Optional[threading.Event] |
| 109 | + # Async event only for async activities |
| 110 | + async_event: Optional[asyncio.Event] |
| 111 | + |
| 112 | + def set(self) -> None: |
| 113 | + if not self.thread_event: |
| 114 | + raise RuntimeError("Missing event") |
| 115 | + self.thread_event.set() |
| 116 | + if self.async_event: |
| 117 | + self.async_event.set() |
| 118 | + |
| 119 | + def is_set(self) -> bool: |
| 120 | + if not self.thread_event: |
| 121 | + raise RuntimeError("Missing event") |
| 122 | + return self.thread_event.is_set() |
| 123 | + |
| 124 | + async def wait(self) -> None: |
| 125 | + if not self.async_event: |
| 126 | + raise RuntimeError("not in async activity") |
| 127 | + await self.async_event.wait() |
| 128 | + |
| 129 | + def wait_sync(self, timeout: Optional[float] = None) -> None: |
| 130 | + if not self.thread_event: |
| 131 | + raise RuntimeError("Missing event") |
| 132 | + self.thread_event.wait(timeout) |
| 133 | + |
| 134 | + |
| 135 | +def in_activity() -> bool: |
| 136 | + """Whether the current code is inside an activity. |
| 137 | +
|
| 138 | + Returns: |
| 139 | + True if in an activity, False otherwise. |
| 140 | + """ |
| 141 | + return not _current_context.get(None) is None |
| 142 | + |
| 143 | + |
| 144 | +def info() -> Info: |
| 145 | + """Current activity's info. |
| 146 | +
|
| 147 | + Returns: |
| 148 | + Info for the currently running activity. |
| 149 | +
|
| 150 | + Raises: |
| 151 | + RuntimeError: When not in an activity. |
| 152 | + """ |
| 153 | + return _Context.current().info() |
| 154 | + |
| 155 | + |
| 156 | +def heartbeat(*details: Any) -> None: |
| 157 | + """Send a heartbeat for the current activity. |
| 158 | +
|
| 159 | + Raises: |
| 160 | + RuntimeError: When not in an activity. |
| 161 | + """ |
| 162 | + heartbeat_fn = _Context.current().heartbeat |
| 163 | + if not heartbeat_fn: |
| 164 | + raise RuntimeError("Can only execute heartbeat after interceptor init") |
| 165 | + heartbeat_fn(*details) |
| 166 | + |
| 167 | + |
| 168 | +def is_cancelled() -> bool: |
| 169 | + """Whether a cancellation was ever requested on this activity. |
| 170 | +
|
| 171 | + Returns: |
| 172 | + True if the activity has had a cancellation request, False otherwise. |
| 173 | +
|
| 174 | + Raises: |
| 175 | + RuntimeError: When not in an activity. |
| 176 | + """ |
| 177 | + return _Context.current().cancelled_event.is_set() |
| 178 | + |
| 179 | + |
| 180 | +async def wait_for_cancelled() -> None: |
| 181 | + """Asynchronously wait for this activity to get a cancellation request. |
| 182 | +
|
| 183 | + Raises: |
| 184 | + RuntimeError: When not in an async activity. |
| 185 | + """ |
| 186 | + await _Context.current().cancelled_event.wait() |
| 187 | + |
| 188 | + |
| 189 | +def wait_for_cancelled_sync(timeout: Optional[float] = None) -> None: |
| 190 | + """Synchronously block while waiting for a cancellation request on this |
| 191 | + activity. |
| 192 | +
|
| 193 | + This is essentially a wrapper around :py:meth:`threading.Event.wait`. |
| 194 | +
|
| 195 | + Args: |
| 196 | + timeout: Max amount of time to wait for cancellation. |
| 197 | +
|
| 198 | + Raises: |
| 199 | + RuntimeError: When not in an activity. |
| 200 | + """ |
| 201 | + _Context.current().cancelled_event.wait_sync(timeout) |
| 202 | + |
| 203 | + |
| 204 | +def is_worker_shutdown() -> bool: |
| 205 | + """Whether shutdown has been invoked on the worker. |
| 206 | +
|
| 207 | + Returns: |
| 208 | + True if shutdown has been called on the worker, False otherwise. |
| 209 | +
|
| 210 | + Raises: |
| 211 | + RuntimeError: When not in an activity. |
| 212 | + """ |
| 213 | + return _Context.current().worker_shutdown_event.is_set() |
| 214 | + |
| 215 | + |
| 216 | +async def wait_for_worker_shutdown() -> None: |
| 217 | + """Asynchronously wait for shutdown to be called on the worker. |
| 218 | +
|
| 219 | + Raises: |
| 220 | + RuntimeError: When not in an async activity. |
| 221 | + """ |
| 222 | + await _Context.current().worker_shutdown_event.wait() |
| 223 | + |
| 224 | + |
| 225 | +def wait_for_worker_shutdown_sync(timeout: Optional[float] = None) -> None: |
| 226 | + """Synchronously block while waiting for shutdown to be called on the |
| 227 | + worker. |
| 228 | +
|
| 229 | + This is essentially a wrapper around :py:meth:`threading.Event.wait`. |
| 230 | +
|
| 231 | + Args: |
| 232 | + timeout: Max amount of time to wait for shutdown to be called on the |
| 233 | + worker. |
| 234 | +
|
| 235 | + Raises: |
| 236 | + RuntimeError: When not in an activity. |
| 237 | + """ |
| 238 | + _Context.current().worker_shutdown_event.wait_sync(timeout) |
| 239 | + |
| 240 | + |
| 241 | +def raise_complete_async() -> NoReturn: |
| 242 | + """Raise an error that says the activity will be completed |
| 243 | + asynchronously. |
| 244 | + """ |
| 245 | + raise _CompleteAsyncError() |
| 246 | + |
| 247 | + |
| 248 | +class _CompleteAsyncError(temporalio.exceptions.TemporalError): |
| 249 | + pass |
| 250 | + |
| 251 | + |
| 252 | +class LoggerAdapter(logging.LoggerAdapter): |
| 253 | + """Adapter that adds details to the log about the running activity. |
| 254 | +
|
| 255 | + Attributes: |
| 256 | + activity_info_on_message: Boolean for whether a string representation of |
| 257 | + a dict of some activity info will be appended to each message. |
| 258 | + Default is True. |
| 259 | + activity_info_on_extra: Boolean for whether an ``activity_info`` value |
| 260 | + will be added to the ``extra`` dictionary, making it present on the |
| 261 | + ``LogRecord.__dict__`` for use by others. |
| 262 | + """ |
| 263 | + |
| 264 | + def __init__( |
| 265 | + self, logger: logging.Logger, extra: Optional[Mapping[str, Any]] |
| 266 | + ) -> None: |
| 267 | + """Create the logger adapter.""" |
| 268 | + super().__init__(logger, extra or {}) |
| 269 | + self.activity_info_on_message = True |
| 270 | + self.activity_info_on_extra = True |
| 271 | + |
| 272 | + def process( |
| 273 | + self, msg: Any, kwargs: MutableMapping[str, Any] |
| 274 | + ) -> Tuple[Any, MutableMapping[str, Any]]: |
| 275 | + """Override to add activity details.""" |
| 276 | + msg, kwargs = super().process(msg, kwargs) |
| 277 | + if self.activity_info_on_extra or self.activity_info_on_extra: |
| 278 | + context = _current_context.get(None) |
| 279 | + if context: |
| 280 | + if self.activity_info_on_message: |
| 281 | + msg = f"{msg} ({context.logger_details})" |
| 282 | + if self.activity_info_on_extra: |
| 283 | + # Extra can be absent or None, this handles both |
| 284 | + extra = kwargs.get("extra", None) or {} |
| 285 | + extra["activity_info"] = context.info() |
| 286 | + kwargs["extra"] = extra |
| 287 | + return (msg, kwargs) |
| 288 | + |
| 289 | + @property |
| 290 | + def base_logger(self) -> logging.Logger: |
| 291 | + """Underlying logger usable for actions such as adding |
| 292 | + handlers/formatters. |
| 293 | + """ |
| 294 | + return self.logger |
| 295 | + |
| 296 | + |
| 297 | +#: Logger that will have contextual activity details embedded. |
| 298 | +logger = LoggerAdapter(logging.getLogger(__name__), None) |
0 commit comments