Skip to content

Commit 6a43919

Browse files
authored
Support raising cancellation in sync multithreaded activities (#217)
1 parent 89b6e66 commit 6a43919

File tree

13 files changed

+339
-50
lines changed

13 files changed

+339
-50
lines changed

README.md

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -850,8 +850,8 @@ activities no special worker parameters are needed.
850850

851851
Cancellation for asynchronous activities is done via
852852
[`asyncio.Task.cancel`](https://docs.python.org/3/library/asyncio-task.html#asyncio.Task.cancel). This means that
853-
`asyncio.CancelledError` will be raised (and can be caught, but it is not recommended). An activity must heartbeat to
854-
receive cancellation and there are other ways to be notified about cancellation (see "Activity Context" and
853+
`asyncio.CancelledError` will be raised (and can be caught, but it is not recommended). A non-local activity must
854+
heartbeat to receive cancellation and there are other ways to be notified about cancellation (see "Activity Context" and
855855
"Heartbeating and Cancellation" later).
856856

857857
##### Synchronous Activities
@@ -860,10 +860,10 @@ Synchronous activities, i.e. functions that do not have `async def`, can be used
860860
`activity_executor` worker parameter must be set with a `concurrent.futures.Executor` instance to use for executing the
861861
activities.
862862

863-
Cancellation for synchronous activities is done in the background and the activity must choose to listen for it and
864-
react appropriately. If after cancellation is obtained an unwrapped `temporalio.exceptions.CancelledError` is raised,
865-
the activity will be marked cancelled. An activity must heartbeat to receive cancellation and there are other ways to be
866-
notified about cancellation (see "Activity Context" and "Heartbeating and Cancellation" later).
863+
All long running, non-local activities should heartbeat so they can be cancelled. Cancellation in threaded activities
864+
throws but multiprocess/other activities does not. The sections below on each synchronous type explain further. There
865+
are also calls on the context that can check for cancellation. For more information, see "Activity Context" and
866+
"Heartbeating and Cancellation" sections later.
867867

868868
Note, all calls from an activity to functions in the `temporalio.activity` package are powered by
869869
[contextvars](https://docs.python.org/3/library/contextvars.html). Therefore, new threads starting _inside_ of
@@ -876,6 +876,15 @@ If `activity_executor` is set to an instance of `concurrent.futures.ThreadPoolEx
876876
are considered multithreaded activities. Besides `activity_executor`, no other worker parameters are required for
877877
synchronous multithreaded activities.
878878

879+
By default, cancellation of a synchronous multithreaded activity is done via a `temporalio.exceptions.CancelledError`
880+
thrown into the activity thread. Activities that do not wish to have cancellation thrown can set
881+
`no_thread_cancel_exception=True` in the `@activity.defn` decorator.
882+
883+
Code that wishes to be temporarily shielded from the cancellation exception can run inside
884+
`with activity.shield_thread_cancel_exception():`. But once the last nested form of that block is finished, even if
885+
there is a return statement within, it will throw the cancellation if there was one. A `try` +
886+
`except temporalio.exceptions.CancelledError` would have to surround the `with` to handle the cancellation explicitly.
887+
879888
###### Synchronous Multiprocess/Other Activities
880889

881890
If `activity_executor` is set to an instance of `concurrent.futures.Executor` that is _not_
@@ -901,6 +910,8 @@ calls in the `temporalio.activity` package make use of it. Specifically:
901910
* `is_cancelled()` - Whether a cancellation has been requested on this activity
902911
* `wait_for_cancelled()` - `async` call to wait for cancellation request
903912
* `wait_for_cancelled_sync(timeout)` - Synchronous blocking call to wait for cancellation request
913+
* `shield_thread_cancel_exception()` - Context manager for use in `with` clauses by synchronous multithreaded activities
914+
to prevent cancel exception from being thrown during the block of code
904915
* `is_worker_shutdown()` - Whether the worker has started graceful shutdown
905916
* `wait_for_worker_shutdown()` - `async` call to wait for start of graceful worker shutdown
906917
* `wait_for_worker_shutdown_sync(timeout)` - Synchronous blocking call to wait for start of graceful worker shutdown
@@ -912,15 +923,17 @@ occurs. Synchronous activities cannot call any of the `async` functions.
912923

913924
##### Heartbeating and Cancellation
914925

915-
In order for an activity to be notified of cancellation requests, they must invoke `temporalio.activity.heartbeat()`.
916-
It is strongly recommended that all but the fastest executing activities call this function regularly. "Types of
917-
Activities" has specifics on cancellation for asynchronous and synchronous activities.
926+
In order for a non-local activity to be notified of cancellation requests, it must invoke
927+
`temporalio.activity.heartbeat()`. It is strongly recommended that all but the fastest executing activities call this
928+
function regularly. "Types of Activities" has specifics on cancellation for asynchronous and synchronous activities.
918929

919930
In addition to obtaining cancellation information, heartbeats also support detail data that is persisted on the server
920931
for retrieval during activity retry. If an activity calls `temporalio.activity.heartbeat(123, 456)` and then fails and
921932
is retried, `temporalio.activity.info().heartbeat_details` will return an iterable containing `123` and `456` on the
922933
next run.
923934

935+
Heartbeating has no effect on local activities.
936+
924937
##### Worker Shutdown
925938

926939
An activity can react to a worker shutdown. Using `is_worker_shutdown` or one of the `wait_for_worker_shutdown`

temporalio/activity.py

Lines changed: 58 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,13 @@
1515
import inspect
1616
import logging
1717
import threading
18+
from contextlib import AbstractContextManager, contextmanager
1819
from dataclasses import dataclass
1920
from datetime import datetime, timedelta
20-
from functools import partial
2121
from typing import (
2222
Any,
2323
Callable,
24+
Iterator,
2425
List,
2526
Mapping,
2627
MutableMapping,
@@ -34,7 +35,6 @@
3435
)
3536

3637
import temporalio.common
37-
import temporalio.exceptions
3838

3939
from .types import CallableType
4040

@@ -45,32 +45,41 @@ def defn(fn: CallableType) -> CallableType:
4545

4646

4747
@overload
48-
def defn(*, name: str) -> Callable[[CallableType], CallableType]:
48+
def defn(
49+
*, name: Optional[str] = None, no_thread_cancel_exception: bool = False
50+
) -> Callable[[CallableType], CallableType]:
4951
...
5052

5153

52-
def defn(fn: Optional[CallableType] = None, *, name: Optional[str] = None):
54+
def defn(
55+
fn: Optional[CallableType] = None,
56+
*,
57+
name: Optional[str] = None,
58+
no_thread_cancel_exception: bool = False,
59+
):
5360
"""Decorator for activity functions.
5461
5562
Activities can be async or non-async.
5663
5764
Args:
5865
fn: The function to decorate.
5966
name: Name to use for the activity. Defaults to function ``__name__``.
67+
no_thread_cancel_exception: If set to true, an exception will not be
68+
raised in synchronous, threaded activities upon cancellation.
6069
"""
6170

62-
def with_name(name: str, fn: CallableType) -> CallableType:
71+
def decorator(fn: CallableType) -> CallableType:
6372
# This performs validation
64-
_Definition._apply_to_callable(fn, name)
73+
_Definition._apply_to_callable(
74+
fn,
75+
activity_name=name or fn.__name__,
76+
no_thread_cancel_exception=no_thread_cancel_exception,
77+
)
6578
return fn
6679

67-
# If name option is available, return decorator function
68-
if name is not None:
69-
return partial(with_name, name)
70-
if fn is None:
71-
raise RuntimeError("Cannot invoke defn without function or name")
72-
# Otherwise just run decorator function
73-
return with_name(fn.__name__, fn)
80+
if fn is not None:
81+
return decorator(fn)
82+
return decorator
7483

7584

7685
@dataclass(frozen=True)
@@ -122,6 +131,7 @@ class _Context:
122131
heartbeat: Optional[Callable[..., None]]
123132
cancelled_event: _CompositeEvent
124133
worker_shutdown_event: _CompositeEvent
134+
shield_thread_cancel_exception: Optional[Callable[[], AbstractContextManager]]
125135
_logger_details: Optional[Mapping[str, Any]] = None
126136

127137
@staticmethod
@@ -221,6 +231,36 @@ def is_cancelled() -> bool:
221231
return _Context.current().cancelled_event.is_set()
222232

223233

234+
@contextmanager
235+
def shield_thread_cancel_exception() -> Iterator[None]:
236+
"""Context manager for synchronous multithreaded activities to delay
237+
cancellation exceptions.
238+
239+
By default, synchronous multithreaded activities have an exception thrown
240+
inside when cancellation occurs. Code within a "with" block of this context
241+
manager will delay that throwing until the end. Even if the block returns a
242+
value or throws its own exception, if a cancellation exception is pending,
243+
it is thrown instead. Therefore users are encouraged to not throw out of
244+
this block and can surround this with a try/except if they wish to catch a
245+
cancellation.
246+
247+
This properly supports nested calls and will only throw after the last one.
248+
249+
This just runs the blocks with no extra effects for async activities or
250+
synchronous multiprocess/other activities.
251+
252+
Raises:
253+
temporalio.exceptions.CancelledError: If a cancellation occurs anytime
254+
during this block and this is not nested in another shield block.
255+
"""
256+
shield_context = _Context.current().shield_thread_cancel_exception
257+
if not shield_context:
258+
yield None
259+
else:
260+
with shield_context():
261+
yield None
262+
263+
224264
async def wait_for_cancelled() -> None:
225265
"""Asynchronously wait for this activity to get a cancellation request.
226266
@@ -353,6 +393,7 @@ class _Definition:
353393
name: str
354394
fn: Callable
355395
is_async: bool
396+
no_thread_cancel_exception: bool
356397
# Types loaded on post init if both are None
357398
arg_types: Optional[List[Type]] = None
358399
ret_type: Optional[Type] = None
@@ -379,7 +420,9 @@ def must_from_callable(fn: Callable) -> _Definition:
379420
)
380421

381422
@staticmethod
382-
def _apply_to_callable(fn: Callable, activity_name: str) -> None:
423+
def _apply_to_callable(
424+
fn: Callable, *, activity_name: str, no_thread_cancel_exception: bool = False
425+
) -> None:
383426
# Validate the activity
384427
if hasattr(fn, "__temporal_activity_definition"):
385428
raise ValueError("Function already contains activity definition")
@@ -399,6 +442,7 @@ def _apply_to_callable(fn: Callable, activity_name: str) -> None:
399442
# iscoroutinefunction does not return true for async __call__
400443
# TODO(cretz): Why can't MyPy handle this?
401444
is_async=inspect.iscoroutinefunction(fn) or inspect.iscoroutinefunction(fn.__call__), # type: ignore
445+
no_thread_cancel_exception=no_thread_cancel_exception,
402446
),
403447
)
404448

temporalio/bridge/runtime.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from __future__ import annotations
77

88
from dataclasses import dataclass
9-
from typing import ClassVar, Mapping, Optional
9+
from typing import ClassVar, Mapping, Optional, Type
1010

1111
import temporalio.bridge.temporal_sdk_bridge
1212

@@ -54,6 +54,13 @@ def set_default(runtime: Runtime, *, error_if_already_set: bool = True) -> None:
5454
raise RuntimeError("Runtime default already set")
5555
_default_runtime = runtime
5656

57+
@staticmethod
58+
def _raise_in_thread(thread_id: int, exc_type: Type[BaseException]) -> bool:
59+
"""Internal helper for raising an exception in thread."""
60+
return temporalio.bridge.temporal_sdk_bridge.raise_in_thread(
61+
thread_id, exc_type
62+
)
63+
5764
def __init__(self, *, telemetry: TelemetryConfig) -> None:
5865
"""Create a default runtime with the given telemetry config.
5966

temporalio/bridge/src/lib.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ fn temporal_sdk_bridge(py: Python, m: &PyModule) -> PyResult<()> {
1616
// Runtime stuff
1717
m.add_class::<runtime::RuntimeRef>()?;
1818
m.add_function(wrap_pyfunction!(init_runtime, m)?)?;
19+
m.add_function(wrap_pyfunction!(raise_in_thread, m)?)?;
1920

2021
// Testing stuff
2122
m.add_class::<testing::EphemeralServerRef>()?;
@@ -48,6 +49,11 @@ fn init_runtime(telemetry_config: runtime::TelemetryConfig) -> PyResult<runtime:
4849
runtime::init_runtime(telemetry_config)
4950
}
5051

52+
#[pyfunction]
53+
fn raise_in_thread<'a>(py: Python<'a>, thread_id: std::os::raw::c_long, exc: &PyAny) -> bool {
54+
runtime::raise_in_thread(py, thread_id, exc)
55+
}
56+
5157
#[pyfunction]
5258
fn start_temporalite<'a>(
5359
py: Python<'a>,

temporalio/bridge/src/runtime.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use pyo3::exceptions::{PyRuntimeError, PyValueError};
22
use pyo3::prelude::*;
3+
use pyo3::AsPyPointer;
34
use std::collections::HashMap;
45
use std::future::Future;
56
use std::net::SocketAddr;
@@ -75,6 +76,10 @@ pub fn init_runtime(telemetry_config: TelemetryConfig) -> PyResult<RuntimeRef> {
7576
})
7677
}
7778

79+
pub fn raise_in_thread<'a>(_py: Python<'a>, thread_id: std::os::raw::c_long, exc: &PyAny) -> bool {
80+
unsafe { pyo3::ffi::PyThreadState_SetAsyncExc(thread_id, exc.as_ptr()) == 1 }
81+
}
82+
7883
impl Runtime {
7984
pub fn future_into_py<'a, F, T>(&self, py: Python<'a>, fut: F) -> PyResult<&'a PyAny>
8085
where

temporalio/exceptions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def non_retryable(self) -> bool:
9797
class CancelledError(FailureError):
9898
"""Error raised on workflow/activity cancellation."""
9999

100-
def __init__(self, message: str, *details: Any) -> None:
100+
def __init__(self, message: str = "Cancelled", *details: Any) -> None:
101101
"""Initialize a cancelled error."""
102102
super().__init__(message)
103103
self._details = details

temporalio/testing/_activity.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from typing_extensions import ParamSpec
1313

1414
import temporalio.activity
15+
import temporalio.exceptions
16+
import temporalio.worker._activity
1517

1618
_Params = ParamSpec("_Params")
1719
_Return = TypeVar("_Return")
@@ -111,6 +113,17 @@ def __init__(
111113
self.env = env
112114
self.fn = fn
113115
self.is_async = inspect.iscoroutinefunction(fn)
116+
self.cancel_thread_raiser: Optional[
117+
temporalio.worker._activity._ThreadExceptionRaiser
118+
] = None
119+
if not self.is_async:
120+
# If there is a definition and they disable thread raising, don't
121+
# set
122+
defn = temporalio.activity._Definition.from_callable(fn)
123+
if not defn or not defn.no_thread_cancel_exception:
124+
self.cancel_thread_raiser = (
125+
temporalio.worker._activity._ThreadExceptionRaiser()
126+
)
114127
# Create context
115128
self.context = temporalio.activity._Context(
116129
info=lambda: env.info,
@@ -123,10 +136,18 @@ def __init__(
123136
thread_event=threading.Event(),
124137
async_event=asyncio.Event() if self.is_async else None,
125138
),
139+
shield_thread_cancel_exception=None
140+
if not self.cancel_thread_raiser
141+
else self.cancel_thread_raiser.shielded,
126142
)
127143
self.task: Optional[asyncio.Task] = None
128144

129145
def run(self, *args, **kwargs) -> Any:
146+
if self.cancel_thread_raiser:
147+
thread_id = threading.current_thread().ident
148+
if thread_id is not None:
149+
self.cancel_thread_raiser.set_thread_id(thread_id)
150+
130151
@contextmanager
131152
def activity_context():
132153
# Set cancelled and shutdown if already so in environment
@@ -163,6 +184,10 @@ async def run_async():
163184
def cancel(self) -> None:
164185
if not self.context.cancelled_event.is_set():
165186
self.context.cancelled_event.set()
187+
if self.cancel_thread_raiser:
188+
self.cancel_thread_raiser.raise_in_thread(
189+
temporalio.exceptions.CancelledError
190+
)
166191
if self.task and not self.task.done():
167192
self.task.cancel()
168193

0 commit comments

Comments
 (0)