Skip to content

Commit 12071e0

Browse files
committed
Fix type errors
1 parent c00dbab commit 12071e0

11 files changed

+77
-24
lines changed

temporalio/nexus/_operation_handlers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ def __init__(
6262
)
6363
self._start = start
6464
if start.__doc__:
65-
self.start.__func__.__doc__ = start.__doc__
65+
if start_func := getattr(self.start, "__func__", None):
66+
start_func.__doc__ = start.__doc__
6667
self._input_type = input_type
6768
self._output_type = output_type
6869

temporalio/nexus/_util.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
Optional,
1212
Type,
1313
TypeVar,
14-
Union,
1514
)
1615

1716
import nexusrpc
@@ -72,16 +71,12 @@ def get_workflow_run_start_method_input_and_output_type_annotations(
7271
def _get_start_method_input_and_output_type_annotations(
7372
start: Callable[
7473
[ServiceHandlerT, WorkflowRunOperationContext, InputT],
75-
Union[OutputT, Awaitable[OutputT]],
74+
Awaitable[WorkflowHandle[OutputT]],
7675
],
7776
) -> tuple[
7877
Optional[Type[InputT]],
7978
Optional[Type[OutputT]],
8079
]:
81-
"""Return operation input and output types.
82-
83-
`start` must be a type-annotated start method that returns a synchronous result.
84-
"""
8580
try:
8681
type_annotations = typing.get_type_hints(start)
8782
except TypeError:

temporalio/worker/_workflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def __init__(
105105
if interceptor_class:
106106
self._interceptor_classes.append(interceptor_class)
107107
self._extern_functions.update(
108-
**_WorkflowExternFunctions(
108+
**_WorkflowExternFunctions( # type: ignore
109109
__temporal_get_metric_meter=lambda: metric_meter,
110110
__temporal_assert_local_activity_valid=assert_local_activity_valid,
111111
)

temporalio/workflow.py

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5172,14 +5172,30 @@ async def start_operation(
51725172
headers: Optional[Mapping[str, str]] = None,
51735173
) -> NexusOperationHandle[OutputT]: ...
51745174

5175-
# Overload for sync_operation methods
5175+
# Overload for sync_operation methods (async def)
51765176
@overload
51775177
@abstractmethod
51785178
async def start_operation(
51795179
self,
51805180
operation: Callable[
51815181
[ServiceHandlerT, nexusrpc.handler.StartOperationContext, InputT],
5182-
Union[Awaitable[OutputT], OutputT],
5182+
Awaitable[OutputT],
5183+
],
5184+
input: InputT,
5185+
*,
5186+
output_type: Optional[Type[OutputT]] = None,
5187+
schedule_to_close_timeout: Optional[timedelta] = None,
5188+
headers: Optional[Mapping[str, str]] = None,
5189+
) -> NexusOperationHandle[OutputT]: ...
5190+
5191+
# Overload for sync_operation methods (def)
5192+
@overload
5193+
@abstractmethod
5194+
async def start_operation(
5195+
self,
5196+
operation: Callable[
5197+
[ServiceHandlerT, nexusrpc.handler.StartOperationContext, InputT],
5198+
OutputT,
51835199
],
51845200
input: InputT,
51855201
*,
@@ -5257,14 +5273,30 @@ async def execute_operation(
52575273
headers: Optional[Mapping[str, str]] = None,
52585274
) -> OutputT: ...
52595275

5260-
# Overload for sync_operation methods
5276+
# Overload for sync_operation methods (async def)
5277+
@overload
5278+
@abstractmethod
5279+
async def execute_operation(
5280+
self,
5281+
operation: Callable[
5282+
[ServiceHandlerT, nexusrpc.handler.StartOperationContext, InputT],
5283+
Awaitable[OutputT],
5284+
],
5285+
input: InputT,
5286+
*,
5287+
output_type: Optional[Type[OutputT]] = None,
5288+
schedule_to_close_timeout: Optional[timedelta] = None,
5289+
headers: Optional[Mapping[str, str]] = None,
5290+
) -> OutputT: ...
5291+
5292+
# Overload for sync_operation methods (def)
52615293
@overload
52625294
@abstractmethod
52635295
async def execute_operation(
52645296
self,
52655297
operation: Callable[
52665298
[ServiceHandlerT, nexusrpc.handler.StartOperationContext, InputT],
5267-
Union[Awaitable[OutputT], OutputT],
5299+
OutputT,
52685300
],
52695301
input: InputT,
52705302
*,
@@ -5352,6 +5384,22 @@ async def execute_operation(
53525384
return await handle
53535385

53545386

5387+
@overload
5388+
def create_nexus_client(
5389+
*,
5390+
service: Type[ServiceT],
5391+
endpoint: str,
5392+
) -> NexusClient[ServiceT]: ...
5393+
5394+
5395+
@overload
5396+
def create_nexus_client(
5397+
*,
5398+
service: str,
5399+
endpoint: str,
5400+
) -> NexusClient[Any]: ...
5401+
5402+
53555403
def create_nexus_client(
53565404
*,
53575405
service: Union[Type[ServiceT], str],

tests/helpers/nexus.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from dataclasses import dataclass
33
from typing import Any, Mapping, Optional
44

5-
import temporalio.api
5+
import temporalio.api.failure.v1
66
import temporalio.api.nexus.v1
77
import temporalio.api.operatorservice.v1
88
import temporalio.workflow

tests/nexus/test_handler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ async def workflow_run_op_link_test(
251251
)
252252

253253
class OperationHandlerReturningUnwrappedResult(OperationHandler[Input, Output]):
254-
async def start(
254+
async def start( # type: ignore[override] # intentional test error
255255
self,
256256
ctx: StartOperationContext,
257257
input: Input,
@@ -814,7 +814,7 @@ async def test_logger_uses_operation_context(env: WorkflowEnvironment, caplog: A
814814

815815
class _InstantiationCase:
816816
executor: bool
817-
handler: Callable[[], Any]
817+
handler: Callable[..., Any]
818818
exception: Optional[Type[Exception]]
819819
match: Optional[str]
820820

tests/nexus/test_handler_interface_implementation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ class Impl:
3838
@workflow_run_operation
3939
async def op(
4040
self, ctx: WorkflowRunOperationContext, input: str
41-
) -> nexus.WorkflowHandle[int]: ...
41+
) -> nexus.WorkflowHandle[int]:
42+
raise NotImplementedError
4243

4344
error_message = None
4445

tests/nexus/test_handler_operation_definitions.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ class Service:
3636
@workflow_run_operation
3737
async def my_workflow_run_operation_handler(
3838
self, ctx: WorkflowRunOperationContext, input: Input
39-
) -> nexus.WorkflowHandle[Output]: ...
39+
) -> nexus.WorkflowHandle[Output]:
40+
raise NotImplementedError
4041

4142
expected_operations = {
4243
"my_workflow_run_operation_handler": nexusrpc.Operation(
@@ -54,7 +55,8 @@ class Service:
5455
@workflow_run_operation
5556
async def my_workflow_run_operation_handler(
5657
self, ctx: WorkflowRunOperationContext, input: Input
57-
) -> nexus.WorkflowHandle[Output]: ...
58+
) -> nexus.WorkflowHandle[Output]:
59+
raise NotImplementedError
5860

5961
expected_operations = NotCalled.expected_operations
6062

@@ -65,7 +67,8 @@ class Service:
6567
@workflow_run_operation(name="operation-name")
6668
async def workflow_run_operation_with_name_override(
6769
self, ctx: WorkflowRunOperationContext, input: Input
68-
) -> nexus.WorkflowHandle[Output]: ...
70+
) -> nexus.WorkflowHandle[Output]:
71+
raise NotImplementedError
6972

7073
expected_operations = {
7174
"workflow_run_operation_with_name_override": nexusrpc.Operation(

tests/nexus/test_workflow_caller.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ async def run(
276276
) -> CallerWfOutput:
277277
op_input = input.op_input
278278
op_handle = await self.nexus_client.start_operation(
279-
self._get_operation(op_input),
279+
self._get_operation(op_input), # type: ignore[arg-type] # test uses non-public operation types
280280
op_input,
281281
headers=op_input.headers,
282282
)

tests/nexus/test_workflow_caller_error_chains.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828

2929
class ErrorConversionTestCase:
30-
action_in_nexus_operation: Callable[[], None]
30+
action_in_nexus_operation: Callable[..., Any]
3131
expected_exception_chain_in_workflow: list[tuple[type[Exception], dict[str, Any]]]
3232

3333
def __init_subclass__(cls, **kwargs):
@@ -369,7 +369,11 @@ def __init__(self, input: ErrorTestInput):
369369
@workflow.run
370370
async def invoke_nexus_op_and_assert_error(self, input: ErrorTestInput) -> None:
371371
try:
372-
await self.nexus_client.execute_operation(ErrorTestService.op, input)
372+
await self.nexus_client.execute_operation(
373+
ErrorTestService.op, # type: ignore[arg-type] # mypy can't infer OutputT=None in Union type
374+
input,
375+
output_type=None,
376+
)
373377
except BaseException as err:
374378
errs = [err]
375379
while err.__cause__:

0 commit comments

Comments
 (0)