Skip to content

Commit e5bd9a9

Browse files
authored
Properly handle uncaught child/activity cancel during workflow cancel (#71)
Fixes #70
1 parent a079a5e commit e5bd9a9

File tree

2 files changed: 88 additions and 8 deletions (+88 -8 lines changed)

2 files changed: 88 additions and 8 deletions (+88 -8 lines changed)

temporalio/worker/workflow_instance.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ def __init__(self, det: WorkflowInstanceDetails) -> None:
149149
self._info = det.info
150150
self._primary_task: Optional[asyncio.Task[None]] = None
151151
self._time = 0.0
152+
self._cancel_requested = False
152153
self._current_history_length = 0
153154
# Handles which are ready to run on the next event loop iteration
154155
self._ready: Deque[asyncio.Handle] = collections.deque()
@@ -345,6 +346,7 @@ def _apply(
345346
def _apply_cancel_workflow(
346347
self, job: temporalio.bridge.proto.workflow_activation.CancelWorkflow
347348
) -> None:
349+
self._cancel_requested = True
348350
# TODO(cretz): Decide whether to pass along details or a cancel message here
349351
if self._primary_task:
350352
self._primary_task.cancel()
@@ -1126,6 +1128,20 @@ async def _run_top_level_workflow_function(self, coro: Awaitable[None]) -> None:
11261128
f"Workflow raised failure with run ID {self._info.run_id}",
11271129
exc_info=True,
11281130
)
1131+
# If a cancel was requested, and the failure is from an activity or
1132+
# child, and its cause was a cancellation, we want to use that cause
1133+
# instead because it means a cancel bubbled up while waiting on an
1134+
# activity or child.
1135+
if (
1136+
self._cancel_requested
1137+
and (
1138+
isinstance(err, temporalio.exceptions.ActivityError)
1139+
or isinstance(err, temporalio.exceptions.ChildWorkflowError)
1140+
)
1141+
and isinstance(err.cause, temporalio.exceptions.CancelledError)
1142+
):
1143+
err = err.cause
1144+
11291145
command = self._add_command()
11301146
command.fail_workflow_execution.failure.SetInParent()
11311147
try:

tests/worker/test_workflow.py

Lines changed: 72 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,13 @@
4545
)
4646
from temporalio.bridge.proto.workflow_activation import WorkflowActivation
4747
from temporalio.bridge.proto.workflow_completion import WorkflowActivationCompletion
48-
from temporalio.client import Client, WorkflowFailureError, WorkflowHandle
48+
from temporalio.client import (
49+
Client,
50+
RPCError,
51+
RPCStatusCode,
52+
WorkflowFailureError,
53+
WorkflowHandle,
54+
)
4955
from temporalio.common import RetryPolicy, SearchAttributes
5056
from temporalio.converter import DataConverter, PayloadCodec, decode_search_attributes
5157
from temporalio.exceptions import (
@@ -690,6 +696,58 @@ async def started() -> bool:
690696
assert isinstance(err.value.cause, CancelledError)
691697

692698

699+
@activity.defn
700+
async def wait_forever() -> NoReturn:
701+
await asyncio.Future()
702+
raise RuntimeError("Unreachable")
703+
704+
705+
@workflow.defn
706+
class UncaughtCancelWorkflow:
707+
@workflow.run
708+
async def run(self, activity: bool) -> NoReturn:
709+
self._started = True
710+
# Wait forever on activity or child workflow
711+
if activity:
712+
await workflow.execute_activity(
713+
wait_forever, start_to_close_timeout=timedelta(seconds=1000)
714+
)
715+
else:
716+
await workflow.execute_child_workflow(
717+
UncaughtCancelWorkflow.run,
718+
True,
719+
id=f"{workflow.info().workflow_id}_child",
720+
)
721+
722+
@workflow.query
723+
def started(self) -> bool:
724+
return self._started
725+
726+
727+
@pytest.mark.parametrize("activity", [True, False])
728+
async def test_workflow_uncaught_cancel(client: Client, activity: bool):
729+
async with new_worker(
730+
client, UncaughtCancelWorkflow, activities=[wait_forever]
731+
) as worker:
732+
# Start workflow waiting on activity or child workflow, cancel it, and
733+
# confirm the workflow is shown as cancelled
734+
handle = await client.start_workflow(
735+
UncaughtCancelWorkflow.run,
736+
activity,
737+
id=f"workflow-{uuid.uuid4()}",
738+
task_queue=worker.task_queue,
739+
)
740+
741+
async def started() -> bool:
742+
return await handle.query(UncaughtCancelWorkflow.started)
743+
744+
await assert_eq_eventually(True, started)
745+
await handle.cancel()
746+
with pytest.raises(WorkflowFailureError) as err:
747+
await handle.result()
748+
assert isinstance(err.value.cause, CancelledError)
749+
750+
693751
@workflow.defn
694752
class CancelChildWorkflow:
695753
def __init__(self) -> None:
@@ -732,13 +790,19 @@ async def test_workflow_cancel_child_started(client: Client, use_execute: bool):
732790
)
733791
# Wait until child started
734792
async def child_started() -> bool:
735-
return await handle.query(
736-
CancelChildWorkflow.ready
737-
) and await client.get_workflow_handle_for(
738-
LongSleepWorkflow.run, workflow_id=f"{handle.id}_child"
739-
).query(
740-
LongSleepWorkflow.started
741-
)
793+
try:
794+
return await handle.query(
795+
CancelChildWorkflow.ready
796+
) and await client.get_workflow_handle_for(
797+
LongSleepWorkflow.run, workflow_id=f"{handle.id}_child"
798+
).query(
799+
LongSleepWorkflow.started
800+
)
801+
except RPCError as err:
802+
# Ignore not-found because child may not have started yet
803+
if err.status == RPCStatusCode.NOT_FOUND:
804+
return False
805+
raise
742806

743807
await assert_eq_eventually(True, child_started)
744808
# Send cancel signal and wait on the handle

0 commit comments

Comments
 (0)