|
45 | 45 | )
|
46 | 46 | from temporalio.bridge.proto.workflow_activation import WorkflowActivation
|
47 | 47 | from temporalio.bridge.proto.workflow_completion import WorkflowActivationCompletion
|
48 |
| -from temporalio.client import Client, WorkflowFailureError, WorkflowHandle |
| 48 | +from temporalio.client import ( |
| 49 | + Client, |
| 50 | + RPCError, |
| 51 | + RPCStatusCode, |
| 52 | + WorkflowFailureError, |
| 53 | + WorkflowHandle, |
| 54 | +) |
49 | 55 | from temporalio.common import RetryPolicy, SearchAttributes
|
50 | 56 | from temporalio.converter import DataConverter, PayloadCodec, decode_search_attributes
|
51 | 57 | from temporalio.exceptions import (
|
@@ -690,6 +696,58 @@ async def started() -> bool:
|
690 | 696 | assert isinstance(err.value.cause, CancelledError)
|
691 | 697 |
|
692 | 698 |
|
| 699 | +@activity.defn |
| 700 | +async def wait_forever() -> NoReturn: |
| 701 | + await asyncio.Future() |
| 702 | + raise RuntimeError("Unreachable") |
| 703 | + |
| 704 | + |
| 705 | +@workflow.defn |
| 706 | +class UncaughtCancelWorkflow: |
| 707 | + @workflow.run |
| 708 | + async def run(self, activity: bool) -> NoReturn: |
| 709 | + self._started = True |
| 710 | + # Wait forever on activity or child workflow |
| 711 | + if activity: |
| 712 | + await workflow.execute_activity( |
| 713 | + wait_forever, start_to_close_timeout=timedelta(seconds=1000) |
| 714 | + ) |
| 715 | + else: |
| 716 | + await workflow.execute_child_workflow( |
| 717 | + UncaughtCancelWorkflow.run, |
| 718 | + True, |
| 719 | + id=f"{workflow.info().workflow_id}_child", |
| 720 | + ) |
| 721 | + |
| 722 | + @workflow.query |
| 723 | + def started(self) -> bool: |
| 724 | + return self._started |
| 725 | + |
| 726 | + |
| 727 | +@pytest.mark.parametrize("activity", [True, False]) |
| 728 | +async def test_workflow_uncaught_cancel(client: Client, activity: bool): |
| 729 | + async with new_worker( |
| 730 | + client, UncaughtCancelWorkflow, activities=[wait_forever] |
| 731 | + ) as worker: |
| 732 | + # Start workflow waiting on activity or child workflow, cancel it, and |
| 733 | + # confirm the workflow is shown as cancelled |
| 734 | + handle = await client.start_workflow( |
| 735 | + UncaughtCancelWorkflow.run, |
| 736 | + activity, |
| 737 | + id=f"workflow-{uuid.uuid4()}", |
| 738 | + task_queue=worker.task_queue, |
| 739 | + ) |
| 740 | + |
| 741 | + async def started() -> bool: |
| 742 | + return await handle.query(UncaughtCancelWorkflow.started) |
| 743 | + |
| 744 | + await assert_eq_eventually(True, started) |
| 745 | + await handle.cancel() |
| 746 | + with pytest.raises(WorkflowFailureError) as err: |
| 747 | + await handle.result() |
| 748 | + assert isinstance(err.value.cause, CancelledError) |
| 749 | + |
| 750 | + |
693 | 751 | @workflow.defn
|
694 | 752 | class CancelChildWorkflow:
|
695 | 753 | def __init__(self) -> None:
|
@@ -732,13 +790,19 @@ async def test_workflow_cancel_child_started(client: Client, use_execute: bool):
|
732 | 790 | )
|
733 | 791 | # Wait until child started
|
734 | 792 | async def child_started() -> bool:
|
735 |
| - return await handle.query( |
736 |
| - CancelChildWorkflow.ready |
737 |
| - ) and await client.get_workflow_handle_for( |
738 |
| - LongSleepWorkflow.run, workflow_id=f"{handle.id}_child" |
739 |
| - ).query( |
740 |
| - LongSleepWorkflow.started |
741 |
| - ) |
| 793 | + try: |
| 794 | + return await handle.query( |
| 795 | + CancelChildWorkflow.ready |
| 796 | + ) and await client.get_workflow_handle_for( |
| 797 | + LongSleepWorkflow.run, workflow_id=f"{handle.id}_child" |
| 798 | + ).query( |
| 799 | + LongSleepWorkflow.started |
| 800 | + ) |
| 801 | + except RPCError as err: |
| 802 | + # Ignore not-found because child may not have started yet |
| 803 | + if err.status == RPCStatusCode.NOT_FOUND: |
| 804 | + return False |
| 805 | + raise |
742 | 806 |
|
743 | 807 | await assert_eq_eventually(True, child_started)
|
744 | 808 | # Send cancel signal and wait on the handle
|
|
0 commit comments