Skip to content

Commit d0c1eca

Browse files
committed
Type-level cleanup/evolution in workflow caller
1 parent fd10067 commit d0c1eca

File tree

7 files changed

+56
-66
lines changed

7 files changed

+56
-66
lines changed

temporalio/nexus/_decorators.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,17 @@
44
Awaitable,
55
Callable,
66
Optional,
7+
TypeVar,
78
Union,
89
overload,
910
)
1011

1112
import nexusrpc
13+
from nexusrpc import InputT, OutputT
1214
from nexusrpc.handler import (
1315
OperationHandler,
1416
StartOperationContext,
1517
)
16-
from nexusrpc.types import InputT, OutputT, ServiceHandlerT
1718

1819
from temporalio.nexus._operation_context import WorkflowRunOperationContext
1920
from temporalio.nexus._operation_handlers import (
@@ -27,6 +28,8 @@
2728
get_workflow_run_start_method_input_and_output_type_annotations,
2829
)
2930

31+
ServiceHandlerT = TypeVar("ServiceHandlerT")
32+
3033

3134
@overload
3235
def workflow_run_operation(

temporalio/nexus/_operation_handlers.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@
88
Type,
99
)
1010

11-
from nexusrpc import OperationInfo
11+
from nexusrpc import (
12+
InputT,
13+
OperationInfo,
14+
OutputT,
15+
)
1216
from nexusrpc.handler import (
1317
CancelOperationContext,
1418
FetchOperationInfoContext,
@@ -19,10 +23,6 @@
1923
StartOperationContext,
2024
StartOperationResultAsync,
2125
)
22-
from nexusrpc.types import (
23-
InputT,
24-
OutputT,
25-
)
2626

2727
from temporalio import client
2828
from temporalio.nexus._operation_context import (

temporalio/nexus/_token.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from dataclasses import dataclass
66
from typing import Any, Generic, Literal, Optional, Type
77

8-
from nexusrpc.types import OutputT
8+
from nexusrpc import OutputT
99

1010
from temporalio import client
1111

temporalio/nexus/_util.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@
1010
Callable,
1111
Optional,
1212
Type,
13+
TypeVar,
1314
Union,
1415
)
1516

16-
from nexusrpc.types import (
17+
from nexusrpc import (
1718
InputT,
1819
OutputT,
19-
ServiceHandlerT,
2020
)
2121

2222
from temporalio.nexus._operation_context import WorkflowRunOperationContext
@@ -25,6 +25,8 @@
2525
WorkflowHandle as WorkflowHandle,
2626
)
2727

28+
ServiceHandlerT = TypeVar("ServiceHandlerT")
29+
2830

2931
def get_workflow_run_start_method_input_and_output_type_annotations(
3032
start: Callable[

temporalio/worker/_interceptor.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,7 @@
1919
)
2020

2121
import nexusrpc.handler
22-
from nexusrpc.types import (
23-
InputT,
24-
OutputT,
25-
)
22+
from nexusrpc import InputT, OutputT
2623

2724
import temporalio.activity
2825
import temporalio.api.common.v1
@@ -464,7 +461,7 @@ def start_local_activity(
464461
return self.next.start_local_activity(input)
465462

466463
async def start_nexus_operation(
467-
self, input: StartNexusOperationInput
468-
) -> temporalio.workflow.NexusOperationHandle[Any]:
464+
self, input: StartNexusOperationInput[InputT, OutputT]
465+
) -> temporalio.workflow.NexusOperationHandle[OutputT]:
469466
"""Called for every :py:func:`temporalio.workflow.start_nexus_operation` call."""
470467
return await self.next.start_nexus_operation(input)

temporalio/worker/_workflow_instance.py

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
)
4545

4646
import nexusrpc.handler
47+
from nexusrpc import InputT, OutputT
4748
from typing_extensions import Self, TypeAlias, TypedDict
4849

4950
import temporalio.activity
@@ -1498,12 +1499,12 @@ async def workflow_start_nexus_operation(
14981499
self,
14991500
endpoint: str,
15001501
service: str,
1501-
operation: Union[nexusrpc.Operation[I, O], str, Callable[..., Any]],
1502+
operation: Union[nexusrpc.Operation[InputT, OutputT], str, Callable[..., Any]],
15021503
input: Any,
1503-
output_type: Optional[Type[O]] = None,
1504+
output_type: Optional[Type[OutputT]] = None,
15041505
schedule_to_close_timeout: Optional[timedelta] = None,
15051506
headers: Optional[Mapping[str, str]] = None,
1506-
) -> temporalio.workflow.NexusOperationHandle[Any]:
1507+
) -> temporalio.workflow.NexusOperationHandle[OutputT]:
15071508
# start_nexus_operation
15081509
return await self._outbound.start_nexus_operation(
15091510
StartNexusOperationInput(
@@ -1824,8 +1825,8 @@ async def run_child() -> Any:
18241825
apply_child_cancel_error()
18251826

18261827
async def _outbound_start_nexus_operation(
1827-
self, input: StartNexusOperationInput
1828-
) -> _NexusOperationHandle[Any]:
1828+
self, input: StartNexusOperationInput[Any, OutputT]
1829+
) -> _NexusOperationHandle[OutputT]:
18291830
# A Nexus operation handle contains two futures: self._start_fut is resolved as a
18301831
# result of the Nexus operation starting (activation job:
18311832
# resolve_nexus_operation_start), and self._result_fut is resolved as a result of
@@ -1840,9 +1841,9 @@ async def _outbound_start_nexus_operation(
18401841
# and start will be resolved with an operation token). See comments in
18411842
# tests/worker/test_nexus.py for worked examples of the evolution of the resulting
18421843
# handle state machine in the sync and async Nexus response cases.
1843-
handle: _NexusOperationHandle
1844+
handle: _NexusOperationHandle[OutputT]
18441845

1845-
async def operation_handle_fn() -> Any:
1846+
async def operation_handle_fn() -> OutputT:
18461847
while True:
18471848
try:
18481849
return await asyncio.shield(handle._result_fut)
@@ -2601,8 +2602,8 @@ async def start_child_workflow(
26012602
return await self._instance._outbound_start_child_workflow(input)
26022603

26032604
async def start_nexus_operation(
2604-
self, input: StartNexusOperationInput
2605-
) -> temporalio.workflow.NexusOperationHandle[Any]:
2605+
self, input: StartNexusOperationInput[Any, OutputT]
2606+
) -> _NexusOperationHandle[OutputT]:
26062607
return await self._instance._outbound_start_nexus_operation(input)
26072608

26082609
def start_local_activity(
@@ -2991,27 +2992,23 @@ async def cancel(self) -> None:
29912992
await self._instance._cancel_external_workflow(command)
29922993

29932994

2994-
I = TypeVar("I")
2995-
O = TypeVar("O")
2996-
2997-
29982995
# TODO(dan): are we sure we don't want to inherit from asyncio.Task as ActivityHandle and
29992996
# ChildWorkflowHandle do? I worry that we should provide .done(), .result(), .exception()
30002997
# etc for consistency.
3001-
class _NexusOperationHandle(temporalio.workflow.NexusOperationHandle[O]):
2998+
class _NexusOperationHandle(temporalio.workflow.NexusOperationHandle[OutputT]):
30022999
def __init__(
30033000
self,
30043001
instance: _WorkflowInstanceImpl,
30053002
seq: int,
3006-
input: StartNexusOperationInput,
3007-
fn: Coroutine[Any, Any, O],
3003+
input: StartNexusOperationInput[Any, OutputT],
3004+
fn: Coroutine[Any, Any, OutputT],
30083005
):
30093006
self._instance = instance
30103007
self._seq = seq
30113008
self._input = input
30123009
self._task = asyncio.Task(fn)
30133010
self._start_fut: asyncio.Future[Optional[str]] = instance.create_future()
3014-
self._result_fut: asyncio.Future[Optional[O]] = instance.create_future()
3011+
self._result_fut: asyncio.Future[Optional[OutputT]] = instance.create_future()
30153012

30163013
@property
30173014
def operation_token(self) -> Optional[str]:
@@ -3025,10 +3022,10 @@ def operation_token(self) -> Optional[str]:
30253022
except BaseException:
30263023
return None
30273024

3028-
async def result(self) -> O:
3025+
async def result(self) -> OutputT:
30293026
return await self._task
30303027

3031-
def __await__(self) -> Generator[Any, Any, O]:
3028+
def __await__(self) -> Generator[Any, Any, OutputT]:
30323029
return self._task.__await__()
30333030

30343031
def __repr__(self) -> str:
@@ -3045,7 +3042,7 @@ def _resolve_start_success(self, operation_token: Optional[str]) -> None:
30453042
# We intentionally let this error if already done
30463043
self._start_fut.set_result(operation_token)
30473044

3048-
def _resolve_success(self, result: Any) -> None:
3045+
def _resolve_success(self, result: OutputT) -> None:
30493046
# We intentionally let this error if already done
30503047
self._result_fut.set_result(result)
30513048

temporalio/workflow.py

Lines changed: 22 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242

4343
import nexusrpc
4444
import nexusrpc.handler
45+
from nexusrpc import InputT, OutputT
4546
from typing_extensions import (
4647
Concatenate,
4748
Literal,
@@ -854,12 +855,12 @@ async def workflow_start_nexus_operation(
854855
self,
855856
endpoint: str,
856857
service: str,
857-
operation: Union[nexusrpc.Operation[I, O], str, Callable[..., Any]],
858+
operation: Union[nexusrpc.Operation[InputT, OutputT], str, Callable[..., Any]],
858859
input: Any,
859-
output_type: Optional[Type[O]] = None,
860+
output_type: Optional[Type[OutputT]] = None,
860861
schedule_to_close_timeout: Optional[timedelta] = None,
861862
headers: Optional[Mapping[str, str]] = None,
862-
) -> NexusOperationHandle[Any]: ...
863+
) -> NexusOperationHandle[OutputT]: ...
863864

864865
@abstractmethod
865866
def workflow_time_ns(self) -> int: ...
@@ -4383,14 +4384,8 @@ async def execute_child_workflow(
43834384
return await handle
43844385

43854386

4386-
# TODO(nexus-prerelease): use types from nexusrpc
4387-
I = TypeVar("I")
4388-
O = TypeVar("O")
4389-
S = TypeVar("S")
4390-
4391-
43924387
# TODO(nexus-prerelease): ABC / inherit from asyncio.Task?
4393-
class NexusOperationHandle(Generic[O]):
4388+
class NexusOperationHandle(Generic[OutputT]):
43944389
def cancel(self) -> bool:
43954390
# TODO(nexus-prerelease): docstring
43964391
"""
@@ -4404,7 +4399,7 @@ def cancel(self) -> bool:
44044399
"""
44054400
raise NotImplementedError
44064401

4407-
def __await__(self) -> Generator[Any, Any, O]:
4402+
def __await__(self) -> Generator[Any, Any, OutputT]:
44084403
raise NotImplementedError
44094404

44104405
# TODO(nexus-prerelease): check SDK-wide consistency for @property vs nullary accessor methods.
@@ -4416,13 +4411,13 @@ def operation_token(self) -> Optional[str]:
44164411
async def start_nexus_operation(
44174412
endpoint: str,
44184413
service: str,
4419-
operation: Union[nexusrpc.Operation[I, O], str, Callable[..., Any]],
4414+
operation: Union[nexusrpc.Operation[InputT, OutputT], str, Callable[..., Any]],
44204415
input: Any,
44214416
*,
4422-
output_type: Optional[Type[O]] = None,
4417+
output_type: Optional[Type[OutputT]] = None,
44234418
schedule_to_close_timeout: Optional[timedelta] = None,
44244419
headers: Optional[Mapping[str, str]] = None,
4425-
) -> NexusOperationHandle[Any]:
4420+
) -> NexusOperationHandle[OutputT]:
44264421
"""Start a Nexus operation and return its handle.
44274422
44284423
Args:
@@ -5161,17 +5156,13 @@ def _to_proto(self) -> temporalio.bridge.proto.common.VersioningIntent.ValueType
51615156

51625157
# Nexus
51635158

5159+
ServiceT = TypeVar("ServiceT")
51645160

5165-
class NexusClient(Generic[S]):
5161+
5162+
class NexusClient(Generic[ServiceT]):
51665163
def __init__(
51675164
self,
5168-
service: Union[
5169-
# TODO(nexus-prerelease): Type[S] is modeling the interface case as well the impl case, but
5170-
# the typevar S is used below only in the impl case. I think this is OK, but
5171-
# think about it again before deleting this TODO.
5172-
Type[S],
5173-
str,
5174-
],
5165+
service: Union[Type[ServiceT], str],
51755166
*,
51765167
endpoint: str,
51775168
) -> None:
@@ -5194,13 +5185,13 @@ def __init__(
51945185
# TODO(nexus-prerelease): should it be an error to use a reference to a method on a class other than that supplied?
51955186
async def start_operation(
51965187
self,
5197-
operation: Union[nexusrpc.Operation[I, O], str, Callable[..., Any]],
5198-
input: I,
5188+
operation: Union[nexusrpc.Operation[InputT, OutputT], str, Callable[..., Any]],
5189+
input: InputT,
51995190
*,
5200-
output_type: Optional[Type[O]] = None,
5191+
output_type: Optional[Type[OutputT]] = None,
52015192
schedule_to_close_timeout: Optional[timedelta] = None,
52025193
headers: Optional[Mapping[str, str]] = None,
5203-
) -> NexusOperationHandle[O]:
5194+
) -> NexusOperationHandle[OutputT]:
52045195
return await temporalio.workflow.start_nexus_operation(
52055196
endpoint=self._endpoint,
52065197
service=self._service_name,
@@ -5214,14 +5205,14 @@ async def start_operation(
52145205
# TODO(nexus-prerelease): overloads: no-input, ret type
52155206
async def execute_operation(
52165207
self,
5217-
operation: Union[nexusrpc.Operation[I, O], str, Callable[..., Any]],
5218-
input: I,
5208+
operation: Union[nexusrpc.Operation[InputT, OutputT], str, Callable[..., Any]],
5209+
input: InputT,
52195210
*,
5220-
output_type: Optional[Type[O]] = None,
5211+
output_type: Optional[Type[OutputT]] = None,
52215212
schedule_to_close_timeout: Optional[timedelta] = None,
52225213
headers: Optional[Mapping[str, str]] = None,
5223-
) -> O:
5224-
handle: NexusOperationHandle[O] = await self.start_operation(
5214+
) -> OutputT:
5215+
handle: NexusOperationHandle[OutputT] = await self.start_operation(
52255216
operation,
52265217
input,
52275218
output_type=output_type,

0 commit comments

Comments
 (0)