Skip to content

Commit a6bf7cb

Browse files
committed
Revert "Use TemporalStartOperationContext instead of WorkflowRunOperationContext"
This reverts commit 75d16b0.
1 parent e5c774c commit a6bf7cb

10 files changed

+59
-41
lines changed

temporalio/nexus/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
from ._operation_context import (
77
_TemporalCancelOperationContext as _TemporalCancelOperationContext,
88
)
9+
from ._operation_context import (
10+
WorkflowRunOperationContext as WorkflowRunOperationContext,
11+
)
912
from ._operation_context import client as client
1013
from ._operation_context import info as info
1114
from ._operation_context import logger as logger

temporalio/nexus/_decorators.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from temporalio.nexus._operation_context import (
2020
TemporalStartOperationContext,
21+
WorkflowRunOperationContext,
2122
)
2223
from temporalio.nexus._operation_handlers import (
2324
WorkflowRunOperationHandler,
@@ -36,11 +37,11 @@
3637
@overload
3738
def workflow_run_operation(
3839
start: Callable[
39-
[ServiceHandlerT, TemporalStartOperationContext, InputT],
40+
[ServiceHandlerT, WorkflowRunOperationContext, InputT],
4041
Awaitable[WorkflowHandle[OutputT]],
4142
],
4243
) -> Callable[
43-
[ServiceHandlerT, TemporalStartOperationContext, InputT],
44+
[ServiceHandlerT, WorkflowRunOperationContext, InputT],
4445
Awaitable[WorkflowHandle[OutputT]],
4546
]: ...
4647

@@ -52,12 +53,12 @@ def workflow_run_operation(
5253
) -> Callable[
5354
[
5455
Callable[
55-
[ServiceHandlerT, TemporalStartOperationContext, InputT],
56+
[ServiceHandlerT, WorkflowRunOperationContext, InputT],
5657
Awaitable[WorkflowHandle[OutputT]],
5758
]
5859
],
5960
Callable[
60-
[ServiceHandlerT, TemporalStartOperationContext, InputT],
61+
[ServiceHandlerT, WorkflowRunOperationContext, InputT],
6162
Awaitable[WorkflowHandle[OutputT]],
6263
],
6364
]: ...
@@ -66,26 +67,26 @@ def workflow_run_operation(
6667
def workflow_run_operation(
6768
start: Optional[
6869
Callable[
69-
[ServiceHandlerT, TemporalStartOperationContext, InputT],
70+
[ServiceHandlerT, WorkflowRunOperationContext, InputT],
7071
Awaitable[WorkflowHandle[OutputT]],
7172
]
7273
] = None,
7374
*,
7475
name: Optional[str] = None,
7576
) -> Union[
7677
Callable[
77-
[ServiceHandlerT, TemporalStartOperationContext, InputT],
78+
[ServiceHandlerT, WorkflowRunOperationContext, InputT],
7879
Awaitable[WorkflowHandle[OutputT]],
7980
],
8081
Callable[
8182
[
8283
Callable[
83-
[ServiceHandlerT, TemporalStartOperationContext, InputT],
84+
[ServiceHandlerT, WorkflowRunOperationContext, InputT],
8485
Awaitable[WorkflowHandle[OutputT]],
8586
]
8687
],
8788
Callable[
88-
[ServiceHandlerT, TemporalStartOperationContext, InputT],
89+
[ServiceHandlerT, WorkflowRunOperationContext, InputT],
8990
Awaitable[WorkflowHandle[OutputT]],
9091
],
9192
],
@@ -96,11 +97,11 @@ def workflow_run_operation(
9697

9798
def decorator(
9899
start: Callable[
99-
[ServiceHandlerT, TemporalStartOperationContext, InputT],
100+
[ServiceHandlerT, WorkflowRunOperationContext, InputT],
100101
Awaitable[WorkflowHandle[OutputT]],
101102
],
102103
) -> Callable[
103-
[ServiceHandlerT, TemporalStartOperationContext, InputT],
104+
[ServiceHandlerT, WorkflowRunOperationContext, InputT],
104105
Awaitable[WorkflowHandle[OutputT]],
105106
]:
106107
(
@@ -114,7 +115,8 @@ def operation_handler_factory(
114115
async def _start(
115116
ctx: StartOperationContext, input: InputT
116117
) -> WorkflowHandle[OutputT]:
117-
return await start(self, TemporalStartOperationContext.get(), input)
118+
tctx = TemporalStartOperationContext.get()
119+
return await start(self, WorkflowRunOperationContext(tctx), input)
118120

119121
_start.__doc__ = start.__doc__
120122
return WorkflowRunOperationHandler(_start, input_type, output_type)

temporalio/nexus/_operation_context.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,19 @@ def add_outbound_links(
167167
)
168168
return workflow_handle
169169

170+
171+
@dataclass
172+
class WorkflowRunOperationContext:
173+
temporal_context: TemporalStartOperationContext
174+
175+
@property
176+
def nexus_context(self) -> StartOperationContext:
177+
return self.temporal_context.nexus_context
178+
179+
@classmethod
180+
def get(cls) -> WorkflowRunOperationContext:
181+
return cls(TemporalStartOperationContext.get())
182+
170183
# Overload for single-param workflow
171184
# TODO(nexus-prerelease): bring over other overloads
172185
async def start_workflow(
@@ -238,11 +251,11 @@ async def start_workflow(
238251
# We must pass nexus_completion_callbacks, workflow_event_links, and request_id,
239252
# but these are deliberately not exposed in overloads, hence the type-check
240253
# violation.
241-
wf_handle = await self.client.start_workflow( # type: ignore
254+
wf_handle = await self.temporal_context.client.start_workflow( # type: ignore
242255
workflow=workflow,
243256
arg=arg,
244257
id=id,
245-
task_queue=task_queue or self.info().task_queue,
258+
task_queue=task_queue or self.temporal_context.info().task_queue,
246259
execution_timeout=execution_timeout,
247260
run_timeout=run_timeout,
248261
task_timeout=task_timeout,
@@ -262,12 +275,12 @@ async def start_workflow(
262275
request_eager_start=request_eager_start,
263276
priority=priority,
264277
versioning_override=versioning_override,
265-
nexus_completion_callbacks=self.get_completion_callbacks(),
266-
workflow_event_links=self.get_workflow_event_links(),
267-
request_id=self.nexus_context.request_id,
278+
nexus_completion_callbacks=self.temporal_context.get_completion_callbacks(),
279+
workflow_event_links=self.temporal_context.get_workflow_event_links(),
280+
request_id=self.temporal_context.nexus_context.request_id,
268281
)
269282

270-
self.add_outbound_links(wf_handle)
283+
self.temporal_context.add_outbound_links(wf_handle)
271284

272285
return WorkflowHandle[ReturnType]._unsafe_from_client_workflow_handle(wf_handle)
273286

temporalio/nexus/_operation_handlers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class WorkflowRunOperationHandler(OperationHandler[InputT, OutputT]):
4141
4242
Use this class to create an operation handler that starts a workflow by passing your
4343
``start`` method to the constructor. Your ``start`` method must use
44-
:py:func:`temporalio.nexus.TemporalStartOperationContext.start_workflow` to start the
44+
:py:func:`temporalio.nexus.WorkflowRunOperationContext.start_workflow` to start the
4545
workflow.
4646
"""
4747

@@ -77,7 +77,7 @@ async def start(
7777
if isinstance(handle, client.WorkflowHandle):
7878
raise RuntimeError(
7979
f"Expected {handle} to be a nexus.WorkflowHandle, but got a client.WorkflowHandle. "
80-
f"You must use TemporalStartOperationContext.start_workflow "
80+
f"You must use WorkflowRunOperationContext.start_workflow "
8181
"to start a workflow that will deliver the result of the Nexus operation, "
8282
"not client.Client.start_workflow."
8383
)

temporalio/nexus/_util.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
OutputT,
2020
)
2121

22-
from temporalio.nexus._operation_context import TemporalStartOperationContext
22+
from temporalio.nexus._operation_context import WorkflowRunOperationContext
2323

2424
from ._token import (
2525
WorkflowHandle as WorkflowHandle,
@@ -30,7 +30,7 @@
3030

3131
def get_workflow_run_start_method_input_and_output_type_annotations(
3232
start: Callable[
33-
[ServiceHandlerT, TemporalStartOperationContext, InputT],
33+
[ServiceHandlerT, WorkflowRunOperationContext, InputT],
3434
Awaitable[WorkflowHandle[OutputT]],
3535
],
3636
) -> tuple[
@@ -70,7 +70,7 @@ def get_workflow_run_start_method_input_and_output_type_annotations(
7070

7171
def _get_start_method_input_and_output_type_annotations(
7272
start: Callable[
73-
[ServiceHandlerT, TemporalStartOperationContext, InputT],
73+
[ServiceHandlerT, WorkflowRunOperationContext, InputT],
7474
Union[OutputT, Awaitable[OutputT]],
7575
],
7676
) -> tuple[
@@ -102,11 +102,11 @@ def _get_start_method_input_and_output_type_annotations(
102102
input_type = None
103103
else:
104104
ctx_type, input_type = type_annotations.values()
105-
if not issubclass(ctx_type, TemporalStartOperationContext):
105+
if not issubclass(ctx_type, WorkflowRunOperationContext):
106106
# TODO(preview): stacklevel
107107
warnings.warn(
108108
f"Expected first parameter of {start} to be an instance of "
109-
f"TemporalStartOperationContext, but is {ctx_type}."
109+
f"WorkflowRunOperationContext, but is {ctx_type}."
110110
)
111111
input_type = None
112112

tests/nexus/test_handler.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
from temporalio.client import Client
4646
from temporalio.common import WorkflowIDReusePolicy
4747
from temporalio.exceptions import ApplicationError
48-
from temporalio.nexus import TemporalStartOperationContext, workflow_run_operation
48+
from temporalio.nexus import WorkflowRunOperationContext, workflow_run_operation
4949
from temporalio.testing import WorkflowEnvironment
5050
from temporalio.worker import Worker
5151
from tests.helpers.nexus import (
@@ -208,7 +208,7 @@ async def log(self, ctx: StartOperationContext, input: Input) -> Output:
208208

209209
@workflow_run_operation
210210
async def workflow_run_operation_happy_path(
211-
self, ctx: TemporalStartOperationContext, input: Input
211+
self, ctx: WorkflowRunOperationContext, input: Input
212212
) -> nexus.WorkflowHandle[Output]:
213213
return await ctx.start_workflow(
214214
MyWorkflow.run,
@@ -266,7 +266,7 @@ async def workflow_run_operation_without_type_annotations(self, ctx, input):
266266

267267
@workflow_run_operation
268268
async def workflow_run_op_link_test(
269-
self, ctx: TemporalStartOperationContext, input: Input
269+
self, ctx: WorkflowRunOperationContext, input: Input
270270
) -> nexus.WorkflowHandle[Output]:
271271
assert any(
272272
link.url == "http://inbound-link/"
@@ -1022,7 +1022,7 @@ async def run(self, input: Input) -> Output:
10221022
class ServiceHandlerForRequestIdTest:
10231023
@workflow_run_operation
10241024
async def operation_backed_by_a_workflow(
1025-
self, ctx: TemporalStartOperationContext, input: Input
1025+
self, ctx: WorkflowRunOperationContext, input: Input
10261026
) -> nexus.WorkflowHandle[Output]:
10271027
return await ctx.start_workflow(
10281028
EchoWorkflow.run,
@@ -1033,7 +1033,7 @@ async def operation_backed_by_a_workflow(
10331033

10341034
@workflow_run_operation
10351035
async def operation_that_executes_a_workflow_before_starting_the_backing_workflow(
1036-
self, ctx: TemporalStartOperationContext, input: Input
1036+
self, ctx: WorkflowRunOperationContext, input: Input
10371037
) -> nexus.WorkflowHandle[Output]:
10381038
await nexus.client().start_workflow(
10391039
EchoWorkflow.run,

tests/nexus/test_handler_interface_implementation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from nexusrpc.handler import StartOperationContext, sync_operation
77

88
from temporalio import nexus
9-
from temporalio.nexus import TemporalStartOperationContext, workflow_run_operation
9+
from temporalio.nexus import WorkflowRunOperationContext, workflow_run_operation
1010

1111
HTTP_PORT = 7243
1212

@@ -37,7 +37,7 @@ class Interface:
3737
class Impl:
3838
@workflow_run_operation
3939
async def op(
40-
self, ctx: TemporalStartOperationContext, input: str
40+
self, ctx: WorkflowRunOperationContext, input: str
4141
) -> nexus.WorkflowHandle[int]: ...
4242

4343
error_message = None

tests/nexus/test_handler_operation_definitions.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import pytest
1111

1212
from temporalio import nexus
13-
from temporalio.nexus import TemporalStartOperationContext, workflow_run_operation
13+
from temporalio.nexus import WorkflowRunOperationContext, workflow_run_operation
1414

1515

1616
@dataclass
@@ -34,7 +34,7 @@ class NotCalled(_TestCase):
3434
class Service:
3535
@workflow_run_operation
3636
async def my_workflow_run_operation_handler(
37-
self, ctx: TemporalStartOperationContext, input: Input
37+
self, ctx: WorkflowRunOperationContext, input: Input
3838
) -> nexus.WorkflowHandle[Output]: ...
3939

4040
expected_operations = {
@@ -52,7 +52,7 @@ class CalledWithoutArgs(_TestCase):
5252
class Service:
5353
@workflow_run_operation
5454
async def my_workflow_run_operation_handler(
55-
self, ctx: TemporalStartOperationContext, input: Input
55+
self, ctx: WorkflowRunOperationContext, input: Input
5656
) -> nexus.WorkflowHandle[Output]: ...
5757

5858
expected_operations = NotCalled.expected_operations
@@ -63,7 +63,7 @@ class CalledWithNameOverride(_TestCase):
6363
class Service:
6464
@workflow_run_operation(name="operation-name")
6565
async def workflow_run_operation_with_name_override(
66-
self, ctx: TemporalStartOperationContext, input: Input
66+
self, ctx: WorkflowRunOperationContext, input: Input
6767
) -> nexus.WorkflowHandle[Output]: ...
6868

6969
expected_operations = {

tests/nexus/test_workflow_caller.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
)
3939
from temporalio.common import WorkflowIDConflictPolicy
4040
from temporalio.exceptions import CancelledError, NexusHandlerError, NexusOperationError
41-
from temporalio.nexus import TemporalStartOperationContext, workflow_run_operation
41+
from temporalio.nexus import WorkflowRunOperationContext, workflow_run_operation
4242
from temporalio.service import RPCError, RPCStatusCode
4343
from temporalio.worker import Worker
4444
from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name
@@ -160,7 +160,7 @@ async def start(
160160
# TODO(nexus-preview): what do we want the DX to be for a user who is
161161
# starting a Nexus backing workflow from a custom start method? (They may
162162
# need to do this in order to customize the cancel method).
163-
handle = await TemporalStartOperationContext.get().start_workflow(
163+
handle = await WorkflowRunOperationContext.get().start_workflow(
164164
HandlerWorkflow.run,
165165
HandlerWfInput(op_input=input),
166166
id=input.response_type.operation_workflow_id,
@@ -206,7 +206,7 @@ async def sync_operation(
206206

207207
@workflow_run_operation
208208
async def async_operation(
209-
self, ctx: TemporalStartOperationContext, input: OpInput
209+
self, ctx: WorkflowRunOperationContext, input: OpInput
210210
) -> nexus.WorkflowHandle[HandlerWfOutput]:
211211
assert isinstance(input.response_type, AsyncResponse)
212212
if input.response_type.exception_in_operation_start:
@@ -912,7 +912,7 @@ async def run(self, input: str) -> str:
912912
class ServiceImplWithOperationsThatExecuteWorkflowBeforeStartingBackingWorkflow:
913913
@workflow_run_operation
914914
async def my_workflow_run_operation(
915-
self, ctx: TemporalStartOperationContext, input: None
915+
self, ctx: WorkflowRunOperationContext, input: None
916916
) -> nexus.WorkflowHandle[str]:
917917
result_1 = await nexus.client().execute_workflow(
918918
EchoWorkflow.run,

tests/nexus/test_workflow_run_operation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from nexusrpc.handler._decorators import operation_handler
1414

1515
from temporalio import workflow
16-
from temporalio.nexus import TemporalStartOperationContext
16+
from temporalio.nexus import WorkflowRunOperationContext
1717
from temporalio.nexus._operation_handlers import WorkflowRunOperationHandler
1818
from temporalio.testing import WorkflowEnvironment
1919
from temporalio.worker import Worker
@@ -49,7 +49,7 @@ def __init__(self):
4949
async def start(
5050
self, ctx: StartOperationContext, input: Input
5151
) -> StartOperationResultAsync:
52-
handle = await TemporalStartOperationContext.get().start_workflow(
52+
handle = await WorkflowRunOperationContext.get().start_workflow(
5353
EchoWorkflow.run,
5454
input.value,
5555
id=str(uuid.uuid4()),

0 commit comments

Comments
 (0)