Skip to content

Commit c665665

Browse files
committed
Infer result type from result_type arg under string name overload
1 parent 83d2ae4 commit c665665

File tree

2 files changed

+20
-11
lines changed

2 files changed

+20
-11
lines changed

temporalio/client.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ async def start_workflow(
406406
args: Sequence[Any] = [],
407407
id: str,
408408
task_queue: str,
409-
result_type: Optional[Type] = None,
409+
result_type: Optional[Type[ReturnType]] = None,
410410
execution_timeout: Optional[timedelta] = None,
411411
run_timeout: Optional[timedelta] = None,
412412
task_timeout: Optional[timedelta] = None,
@@ -431,7 +431,7 @@ async def start_workflow(
431431
request_eager_start: bool = False,
432432
priority: temporalio.common.Priority = temporalio.common.Priority.default,
433433
versioning_override: Optional[temporalio.common.VersioningOverride] = None,
434-
) -> WorkflowHandle[Any, Any]: ...
434+
) -> WorkflowHandle[Any, ReturnType]: ...
435435

436436
async def start_workflow(
437437
self,
@@ -676,7 +676,7 @@ async def execute_workflow(
676676
args: Sequence[Any] = [],
677677
id: str,
678678
task_queue: str,
679-
result_type: Optional[Type] = None,
679+
result_type: Optional[Type[ReturnType]] = None,
680680
execution_timeout: Optional[timedelta] = None,
681681
run_timeout: Optional[timedelta] = None,
682682
task_timeout: Optional[timedelta] = None,
@@ -701,7 +701,7 @@ async def execute_workflow(
701701
request_eager_start: bool = False,
702702
priority: temporalio.common.Priority = temporalio.common.Priority.default,
703703
versioning_override: Optional[temporalio.common.VersioningOverride] = None,
704-
) -> Any: ...
704+
) -> ReturnType: ...
705705

706706
async def execute_workflow(
707707
self,
@@ -889,10 +889,10 @@ async def execute_update_with_start_workflow(
889889
start_workflow_operation: WithStartWorkflowOperation[Any, Any],
890890
args: Sequence[Any] = [],
891891
id: Optional[str] = None,
892-
result_type: Optional[Type] = None,
892+
result_type: Optional[Type[LocalReturnType]] = None,
893893
rpc_metadata: Mapping[str, str] = {},
894894
rpc_timeout: Optional[timedelta] = None,
895-
) -> Any: ...
895+
) -> LocalReturnType: ...
896896

897897
async def execute_update_with_start_workflow(
898898
self,
@@ -1013,10 +1013,10 @@ async def start_update_with_start_workflow(
10131013
wait_for_stage: WorkflowUpdateStage,
10141014
args: Sequence[Any] = [],
10151015
id: Optional[str] = None,
1016-
result_type: Optional[Type] = None,
1016+
result_type: Optional[Type[LocalReturnType]] = None,
10171017
rpc_metadata: Mapping[str, str] = {},
10181018
rpc_timeout: Optional[timedelta] = None,
1019-
) -> WorkflowUpdateHandle[Any]: ...
1019+
) -> WorkflowUpdateHandle[LocalReturnType]: ...
10201020

10211021
async def start_update_with_start_workflow(
10221022
self,

tests/worker/test_workflow.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ async def run(self, param1: int, param2: str) -> str:
194194

195195
async def test_workflow_multi_param(client: Client):
196196
# This test is mostly just here to confirm MyPy type checks the multi-param
197-
# overload approach properly
197+
# overload approach properly, and infers result type from result_type.
198198
async with new_worker(
199199
client, MultiParamWorkflow, activities=[multi_param_activity]
200200
) as worker:
@@ -206,6 +206,15 @@ async def test_workflow_multi_param(client: Client):
206206
)
207207
assert result == "param1: 123, param2: val1"
208208

209+
result_via_name_overload = await client.execute_workflow(
210+
"MultiParamWorkflow",
211+
args=[123, "val1"],
212+
id=f"workflow-{uuid.uuid4()}",
213+
task_queue=worker.task_queue,
214+
result_type=str,
215+
)
216+
assert result_via_name_overload == "param1: 123, param2: val1"
217+
209218

210219
@workflow.defn
211220
class InfoWorkflow:
@@ -8042,7 +8051,7 @@ async def test_workflow_logging_trace_identifier(client: Client):
80428051
) as worker:
80438052
await client.execute_workflow(
80448053
TaskFailOnceWorkflow.run,
8045-
id=f"workflow_failure_trace_identifier",
8054+
id="workflow_failure_trace_identifier",
80468055
task_queue=worker.task_queue,
80478056
)
80488057

@@ -8082,7 +8091,7 @@ async def test_in_workflow_sync(client: Client):
80828091
) as worker:
80838092
res = await client.execute_workflow(
80848093
UseInWorkflow.run,
8085-
id=f"test_in_workflow_sync",
8094+
id="test_in_workflow_sync",
80868095
task_queue=worker.task_queue,
80878096
execution_timeout=timedelta(minutes=1),
80888097
)

0 commit comments

Comments
 (0)