Skip to content

Commit 27cc67f

Browse files
authored
Interrupt heartbeating activity on pause (#854)
* Init commit - waiting for core changes * add activity paused error usage * working for async activities * working for sync activities * linting and cleanup * add cancellation details arg to testing ActivityEnvironment cancel * remove .vscode * formatting * use object reference instead of function, picklable * nits * use holder * docstrings * add test for pause/unpause * linting, reduce heartbeat timeouts for faster test * make cancellation details non-optional for testing activity env * address pr suggestion * rebase conflict fixes * include is_worker_shutdown as reason for requested cancellation, test fix * skip if time-skipping server (does not support pause/unpause yet) * remove sleep calls from tests, add cancellation details to async cancellation errors from external activities * replace racy heartbeat check with heartbeat details check
1 parent b24326c commit 27cc67f

File tree

9 files changed

+501
-16
lines changed

9 files changed

+501
-16
lines changed

temporalio/activity.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@
3434
overload,
3535
)
3636

37+
import temporalio.bridge
38+
import temporalio.bridge.proto
39+
import temporalio.bridge.proto.activity_task
3740
import temporalio.common
3841
import temporalio.converter
3942

@@ -135,6 +138,34 @@ def _logger_details(self) -> Mapping[str, Any]:
135138
_current_context: contextvars.ContextVar[_Context] = contextvars.ContextVar("activity")
136139

137140

141+
@dataclass
142+
class _ActivityCancellationDetailsHolder:
143+
details: Optional[ActivityCancellationDetails] = None
144+
145+
146+
@dataclass(frozen=True)
147+
class ActivityCancellationDetails:
148+
"""Provides the reasons for the activity's cancellation. Cancellation details are set once and do not change once set."""
149+
150+
not_found: bool = False
151+
cancel_requested: bool = False
152+
paused: bool = False
153+
timed_out: bool = False
154+
worker_shutdown: bool = False
155+
156+
@staticmethod
157+
def _from_proto(
158+
proto: temporalio.bridge.proto.activity_task.ActivityCancellationDetails,
159+
) -> ActivityCancellationDetails:
160+
return ActivityCancellationDetails(
161+
not_found=proto.is_not_found,
162+
cancel_requested=proto.is_cancelled,
163+
paused=proto.is_paused,
164+
timed_out=proto.is_timed_out,
165+
worker_shutdown=proto.is_worker_shutdown,
166+
)
167+
168+
138169
@dataclass
139170
class _Context:
140171
info: Callable[[], Info]
@@ -148,6 +179,7 @@ class _Context:
148179
temporalio.converter.PayloadConverter,
149180
]
150181
runtime_metric_meter: Optional[temporalio.common.MetricMeter]
182+
cancellation_details: _ActivityCancellationDetailsHolder
151183
_logger_details: Optional[Mapping[str, Any]] = None
152184
_payload_converter: Optional[temporalio.converter.PayloadConverter] = None
153185
_metric_meter: Optional[temporalio.common.MetricMeter] = None
@@ -260,6 +292,11 @@ def info() -> Info:
260292
return _Context.current().info()
261293

262294

295+
def cancellation_details() -> Optional[ActivityCancellationDetails]:
296+
"""Cancellation details of the current activity, if any. Once set, cancellation details do not change."""
297+
return _Context.current().cancellation_details.details
298+
299+
263300
def heartbeat(*details: Any) -> None:
264301
"""Send a heartbeat for the current activity.
265302

temporalio/bridge/src/client.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,9 @@ impl ClientRef {
235235
"patch_schedule" => {
236236
rpc_call!(retry_client, call, patch_schedule)
237237
}
238+
"pause_activity" => {
239+
rpc_call!(retry_client, call, pause_activity)
240+
}
238241
"poll_activity_task_queue" => {
239242
rpc_call!(retry_client, call, poll_activity_task_queue)
240243
}
@@ -325,6 +328,9 @@ impl ClientRef {
325328
"trigger_workflow_rule" => {
326329
rpc_call!(retry_client, call, trigger_workflow_rule)
327330
}
331+
"unpause_activity" => {
332+
rpc_call!(retry_client, call, unpause_activity)
333+
}
328334
"update_namespace" => {
329335
rpc_call_on_trait!(retry_client, call, WorkflowService, update_namespace)
330336
}

temporalio/client.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
import temporalio.runtime
5757
import temporalio.service
5858
import temporalio.workflow
59+
from temporalio.activity import ActivityCancellationDetails
5960
from temporalio.service import (
6061
HttpConnectProxyConfig,
6162
KeepAliveConfig,
@@ -5145,9 +5146,10 @@ def __init__(self) -> None:
51455146
class AsyncActivityCancelledError(temporalio.exceptions.TemporalError):
51465147
"""Error that occurs when async activity attempted heartbeat but was cancelled."""
51475148

5148-
def __init__(self) -> None:
5149+
def __init__(self, details: Optional[ActivityCancellationDetails] = None) -> None:
51495150
"""Create async activity cancelled error."""
51505151
super().__init__("Activity cancelled")
5152+
self.details = details
51515153

51525154

51535155
class ScheduleAlreadyRunningError(temporalio.exceptions.TemporalError):
@@ -6287,8 +6289,14 @@ async def heartbeat_async_activity(
62876289
metadata=input.rpc_metadata,
62886290
timeout=input.rpc_timeout,
62896291
)
6290-
if resp_by_id.cancel_requested:
6291-
raise AsyncActivityCancelledError()
6292+
if resp_by_id.cancel_requested or resp_by_id.activity_paused:
6293+
raise AsyncActivityCancelledError(
6294+
details=ActivityCancellationDetails(
6295+
cancel_requested=resp_by_id.cancel_requested,
6296+
paused=resp_by_id.activity_paused,
6297+
)
6298+
)
6299+
62926300
else:
62936301
resp = await self._client.workflow_service.record_activity_task_heartbeat(
62946302
temporalio.api.workflowservice.v1.RecordActivityTaskHeartbeatRequest(
@@ -6301,8 +6309,13 @@ async def heartbeat_async_activity(
63016309
metadata=input.rpc_metadata,
63026310
timeout=input.rpc_timeout,
63036311
)
6304-
if resp.cancel_requested:
6305-
raise AsyncActivityCancelledError()
6312+
if resp.cancel_requested or resp.activity_paused:
6313+
raise AsyncActivityCancelledError(
6314+
details=ActivityCancellationDetails(
6315+
cancel_requested=resp.cancel_requested,
6316+
paused=resp.activity_paused,
6317+
)
6318+
)
63066319

63076320
async def complete_async_activity(self, input: CompleteAsyncActivityInput) -> None:
63086321
result = (

temporalio/testing/_activity.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,15 +74,29 @@ def __init__(self) -> None:
7474
self._cancelled = False
7575
self._worker_shutdown = False
7676
self._activities: Set[_Activity] = set()
77+
self._cancellation_details = (
78+
temporalio.activity._ActivityCancellationDetailsHolder()
79+
)
7780

78-
def cancel(self) -> None:
81+
def cancel(
82+
self,
83+
cancellation_details: temporalio.activity.ActivityCancellationDetails = temporalio.activity.ActivityCancellationDetails(
84+
cancel_requested=True
85+
),
86+
) -> None:
7987
"""Cancel the activity.
8088
89+
Args:
90+
cancellation_details: details about the cancellation. These will
91+
be accessible through temporalio.activity.cancellation_details()
92+
in the activity after cancellation.
93+
8194
This only has an effect on the first call.
8295
"""
8396
if self._cancelled:
8497
return
8598
self._cancelled = True
99+
self._cancellation_details.details = cancellation_details
86100
for act in self._activities:
87101
act.cancel()
88102

@@ -154,6 +168,7 @@ def __init__(
154168
else self.cancel_thread_raiser.shielded,
155169
payload_converter_class_or_instance=env.payload_converter,
156170
runtime_metric_meter=env.metric_meter,
171+
cancellation_details=env._cancellation_details,
157172
)
158173
self.task: Optional[asyncio.Task] = None
159174

temporalio/worker/_activity.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import warnings
1616
from abc import ABC, abstractmethod
1717
from contextlib import contextmanager
18-
from dataclasses import dataclass
18+
from dataclasses import dataclass, field
1919
from datetime import datetime, timedelta, timezone
2020
from typing import (
2121
Any,
@@ -216,7 +216,13 @@ def _cancel(
216216
warnings.warn(f"Cannot find activity to cancel for token {task_token!r}")
217217
return
218218
logger.debug("Cancelling activity %s, reason: %s", task_token, cancel.reason)
219-
activity.cancel(cancelled_by_request=True)
219+
activity.cancellation_details.details = (
220+
temporalio.activity.ActivityCancellationDetails._from_proto(cancel.details)
221+
)
222+
activity.cancel(
223+
cancelled_by_request=cancel.details.is_cancelled
224+
or cancel.details.is_worker_shutdown
225+
)
220226

221227
def _heartbeat(self, task_token: bytes, *details: Any) -> None:
222228
# We intentionally make heartbeating non-async, but since the data
@@ -303,6 +309,24 @@ async def _run_activity(
303309
await self._data_converter.encode_failure(
304310
err, completion.result.failed.failure
305311
)
312+
elif (
313+
isinstance(
314+
err,
315+
(asyncio.CancelledError, temporalio.exceptions.CancelledError),
316+
)
317+
and running_activity.cancellation_details.details
318+
and running_activity.cancellation_details.details.paused
319+
):
320+
temporalio.activity.logger.warning(
321+
f"Completing as failure due to unhandled cancel error produced by activity pause",
322+
)
323+
await self._data_converter.encode_failure(
324+
temporalio.exceptions.ApplicationError(
325+
type="ActivityPause",
326+
message="Unhandled activity cancel error produced by activity pause",
327+
),
328+
completion.result.failed.failure,
329+
)
306330
elif (
307331
isinstance(
308332
err,
@@ -336,7 +360,6 @@ async def _run_activity(
336360
await self._data_converter.encode_failure(
337361
err, completion.result.failed.failure
338362
)
339-
340363
# For broken executors, we have to fail the entire worker
341364
if isinstance(err, concurrent.futures.BrokenExecutor):
342365
self._fail_worker_exception_queue.put_nowait(err)
@@ -524,6 +547,7 @@ async def _execute_activity(
524547
else running_activity.cancel_thread_raiser.shielded,
525548
payload_converter_class_or_instance=self._data_converter.payload_converter,
526549
runtime_metric_meter=None if sync_non_threaded else self._metric_meter,
550+
cancellation_details=running_activity.cancellation_details,
527551
)
528552
)
529553
temporalio.activity.logger.debug("Starting activity")
@@ -570,6 +594,9 @@ class _RunningActivity:
570594
done: bool = False
571595
cancelled_by_request: bool = False
572596
cancelled_due_to_heartbeat_error: Optional[Exception] = None
597+
cancellation_details: temporalio.activity._ActivityCancellationDetailsHolder = (
598+
field(default_factory=temporalio.activity._ActivityCancellationDetailsHolder)
599+
)
573600

574601
def cancel(
575602
self,
@@ -659,6 +686,7 @@ async def execute_activity(self, input: ExecuteActivityInput) -> Any:
659686
# can set the initializer on the executor).
660687
ctx = temporalio.activity._Context.current()
661688
info = ctx.info()
689+
cancellation_details = ctx.cancellation_details
662690

663691
# Heartbeat calls internally use a data converter which is async so
664692
# they need to be called on the event loop
@@ -717,6 +745,7 @@ async def heartbeat_with_context(*details: Any) -> None:
717745
worker_shutdown_event.thread_event,
718746
payload_converter_class_or_instance,
719747
ctx.runtime_metric_meter,
748+
cancellation_details,
720749
input.fn,
721750
*input.args,
722751
]
@@ -732,7 +761,6 @@ async def heartbeat_with_context(*details: Any) -> None:
732761
finally:
733762
if shared_manager:
734763
await shared_manager.unregister_heartbeater(info.task_token)
735-
736764
# Otherwise for async activity, just run
737765
return await input.fn(*input.args)
738766

@@ -764,6 +792,7 @@ def _execute_sync_activity(
764792
temporalio.converter.PayloadConverter,
765793
],
766794
runtime_metric_meter: Optional[temporalio.common.MetricMeter],
795+
cancellation_details: temporalio.activity._ActivityCancellationDetailsHolder,
767796
fn: Callable[..., Any],
768797
*args: Any,
769798
) -> Any:
@@ -795,6 +824,7 @@ def _execute_sync_activity(
795824
else cancel_thread_raiser.shielded,
796825
payload_converter_class_or_instance=payload_converter_class_or_instance,
797826
runtime_metric_meter=runtime_metric_meter,
827+
cancellation_details=cancellation_details,
798828
)
799829
)
800830
return fn(*args)

tests/conftest.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,8 @@ async def env(env_type: str) -> AsyncGenerator[WorkflowEnvironment, None]:
115115
"frontend.workerVersioningDataAPIs=true",
116116
"--dynamic-config-value",
117117
"system.enableDeploymentVersions=true",
118+
"--dynamic-config-value",
119+
"frontend.activityAPIsEnabled=true",
118120
],
119121
dev_server_download_version=DEV_SERVER_DOWNLOAD_VERSION,
120122
)

tests/helpers/__init__.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@
1313
ListSearchAttributesRequest,
1414
)
1515
from temporalio.api.update.v1 import UpdateRef
16-
from temporalio.api.workflowservice.v1 import PollWorkflowExecutionUpdateRequest
16+
from temporalio.api.workflow.v1 import PendingActivityInfo
17+
from temporalio.api.workflowservice.v1 import (
18+
PauseActivityRequest,
19+
PollWorkflowExecutionUpdateRequest,
20+
UnpauseActivityRequest,
21+
)
1722
from temporalio.client import BuildIdOpAddNewDefault, Client, WorkflowHandle
1823
from temporalio.common import SearchAttributeKey
1924
from temporalio.service import RPCError, RPCStatusCode
@@ -210,3 +215,75 @@ async def check_workflow_exists() -> bool:
210215
await assert_eq_eventually(True, check_workflow_exists)
211216
assert handle is not None
212217
return handle
218+
219+
220+
async def assert_pending_activity_exists_eventually(
221+
handle: WorkflowHandle,
222+
activity_id: str,
223+
timeout: timedelta = timedelta(seconds=5),
224+
) -> PendingActivityInfo:
225+
"""Wait until a pending activity with the given ID exists and return it."""
226+
227+
async def check() -> PendingActivityInfo:
228+
act_info = await get_pending_activity_info(handle, activity_id)
229+
if act_info is not None:
230+
return act_info
231+
raise AssertionError(
232+
f"Activity with ID {activity_id} not found in pending activities"
233+
)
234+
235+
return await assert_eventually(check, timeout=timeout)
236+
237+
238+
async def get_pending_activity_info(
239+
handle: WorkflowHandle,
240+
activity_id: str,
241+
) -> Optional[PendingActivityInfo]:
242+
"""Get pending activity info by ID, or None if not found."""
243+
desc = await handle.describe()
244+
for act in desc.raw_description.pending_activities:
245+
if act.activity_id == activity_id:
246+
return act
247+
return None
248+
249+
250+
async def pause_and_assert(client: Client, handle: WorkflowHandle, activity_id: str):
251+
"""Pause the given activity and assert it becomes paused."""
252+
desc = await handle.describe()
253+
req = PauseActivityRequest(
254+
namespace=client.namespace,
255+
execution=WorkflowExecution(
256+
workflow_id=desc.raw_description.workflow_execution_info.execution.workflow_id,
257+
run_id=desc.raw_description.workflow_execution_info.execution.run_id,
258+
),
259+
id=activity_id,
260+
)
261+
await client.workflow_service.pause_activity(req)
262+
263+
# Assert eventually paused
264+
async def check_paused() -> bool:
265+
info = await assert_pending_activity_exists_eventually(handle, activity_id)
266+
return info.paused
267+
268+
await assert_eventually(check_paused)
269+
270+
271+
async def unpause_and_assert(client: Client, handle: WorkflowHandle, activity_id: str):
272+
"""Unpause the given activity and assert it is not paused."""
273+
desc = await handle.describe()
274+
req = UnpauseActivityRequest(
275+
namespace=client.namespace,
276+
execution=WorkflowExecution(
277+
workflow_id=desc.raw_description.workflow_execution_info.execution.workflow_id,
278+
run_id=desc.raw_description.workflow_execution_info.execution.run_id,
279+
),
280+
id=activity_id,
281+
)
282+
await client.workflow_service.unpause_activity(req)
283+
284+
# Assert eventually not paused
285+
async def check_unpaused() -> bool:
286+
info = await assert_pending_activity_exists_eventually(handle, activity_id)
287+
return not info.paused
288+
289+
await assert_eventually(check_unpaused)

0 commit comments

Comments
 (0)