diff --git a/temporalio/client.py b/temporalio/client.py index f46297eb9..a149bbfb8 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -406,7 +406,7 @@ async def start_workflow( args: Sequence[Any] = [], id: str, task_queue: str, - result_type: Optional[Type] = None, + result_type: Optional[Type[ReturnType]] = None, execution_timeout: Optional[timedelta] = None, run_timeout: Optional[timedelta] = None, task_timeout: Optional[timedelta] = None, @@ -431,7 +431,7 @@ async def start_workflow( request_eager_start: bool = False, priority: temporalio.common.Priority = temporalio.common.Priority.default, versioning_override: Optional[temporalio.common.VersioningOverride] = None, - ) -> WorkflowHandle[Any, Any]: ... + ) -> WorkflowHandle[Any, ReturnType]: ... async def start_workflow( self, @@ -676,7 +676,7 @@ async def execute_workflow( args: Sequence[Any] = [], id: str, task_queue: str, - result_type: Optional[Type] = None, + result_type: Optional[Type[ReturnType]] = None, execution_timeout: Optional[timedelta] = None, run_timeout: Optional[timedelta] = None, task_timeout: Optional[timedelta] = None, @@ -701,7 +701,7 @@ async def execute_workflow( request_eager_start: bool = False, priority: temporalio.common.Priority = temporalio.common.Priority.default, versioning_override: Optional[temporalio.common.VersioningOverride] = None, - ) -> Any: ... + ) -> ReturnType: ... async def execute_workflow( self, @@ -889,10 +889,10 @@ async def execute_update_with_start_workflow( start_workflow_operation: WithStartWorkflowOperation[Any, Any], args: Sequence[Any] = [], id: Optional[str] = None, - result_type: Optional[Type] = None, + result_type: Optional[Type[LocalReturnType]] = None, rpc_metadata: Mapping[str, str] = {}, rpc_timeout: Optional[timedelta] = None, - ) -> Any: ... + ) -> LocalReturnType: ... async def execute_update_with_start_workflow( self, @@ -1013,10 +1013,10 @@ async def start_update_with_start_workflow( wait_for_stage: WorkflowUpdateStage, args: Sequence[Any] = [], id: Optional[str] = None, - result_type: Optional[Type] = None, + result_type: Optional[Type[LocalReturnType]] = None, rpc_metadata: Mapping[str, str] = {}, rpc_timeout: Optional[timedelta] = None, - ) -> WorkflowUpdateHandle[Any]: ... + ) -> WorkflowUpdateHandle[LocalReturnType]: ... async def start_update_with_start_workflow( self, diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index fcf06fa7a..f676b9b0e 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -194,7 +194,7 @@ async def run(self, param1: int, param2: str) -> str: async def test_workflow_multi_param(client: Client): # This test is mostly just here to confirm MyPy type checks the multi-param - # overload approach properly + # overload approach properly, and infers result type from result_type. async with new_worker( client, MultiParamWorkflow, activities=[multi_param_activity] ) as worker: @@ -206,6 +206,15 @@ async def test_workflow_multi_param(client: Client): ) assert result == "param1: 123, param2: val1" + result_via_name_overload = await client.execute_workflow( + "MultiParamWorkflow", + args=[123, "val1"], + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + result_type=str, + ) + assert result_via_name_overload == "param1: 123, param2: val1" + @workflow.defn class InfoWorkflow: @@ -8038,7 +8047,7 @@ async def test_workflow_logging_trace_identifier(client: Client): ) as worker: await client.execute_workflow( TaskFailOnceWorkflow.run, - id=f"workflow_failure_trace_identifier", + id="workflow_failure_trace_identifier", task_queue=worker.task_queue, ) @@ -8078,7 +8087,7 @@ async def test_in_workflow_sync(client: Client): ) as worker: res = await client.execute_workflow( UseInWorkflow.run, - id=f"test_in_workflow_sync", + id="test_in_workflow_sync", task_queue=worker.task_queue, execution_timeout=timedelta(minutes=1), )