Skip to content

Commit 54f6121

Browse files
committed
Separate Temporal context for each operation verb
1 parent 1f8e56e commit 54f6121

File tree

8 files changed

+162
-133
lines changed

8 files changed

+162
-133
lines changed

temporalio/nexus/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
WorkflowRunOperationContext as WorkflowRunOperationContext,
55
)
66
from ._operation_context import (
7-
_temporal_operation_context as _temporal_operation_context,
7+
_TemporalCancelOperationContext as _TemporalCancelOperationContext,
88
)
99
from ._operation_context import (
10-
_TemporalNexusOperationContext as _TemporalNexusOperationContext,
10+
_TemporalStartOperationContext as _TemporalStartOperationContext,
1111
)
1212
from ._operation_context import client as client
1313
from ._operation_context import info as info

temporalio/nexus/_decorators.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@
1616
StartOperationContext,
1717
)
1818

19-
from temporalio.nexus._operation_context import WorkflowRunOperationContext
19+
from temporalio.nexus._operation_context import (
20+
WorkflowRunOperationContext,
21+
_TemporalStartOperationContext,
22+
)
2023
from temporalio.nexus._operation_handlers import (
2124
WorkflowRunOperationHandler,
2225
)
@@ -112,7 +115,8 @@ def operation_handler_factory(
112115
async def _start(
113116
ctx: StartOperationContext, input: InputT
114117
) -> WorkflowHandle[OutputT]:
115-
return await start(self, WorkflowRunOperationContext(ctx), input)
118+
tctx = _TemporalStartOperationContext.get()
119+
return await start(self, WorkflowRunOperationContext(tctx), input)
116120

117121
_start.__doc__ = start.__doc__
118122
return WorkflowRunOperationHandler(_start, input_type, output_type)

temporalio/nexus/_operation_context.py

Lines changed: 128 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
Any,
1111
Callable,
1212
Mapping,
13+
MutableMapping,
1314
Optional,
1415
Sequence,
1516
Union,
@@ -30,9 +31,16 @@
3031
SelfType,
3132
)
3233

34+
# The Temporal Nexus worker always builds a nexusrpc StartOperationContext or
35+
# CancelOperationContext and passes it as the first parameter to the nexusrpc operation
36+
# handler. In addition, it sets one of the following context vars.
3337

34-
_temporal_operation_context: ContextVar[_TemporalNexusOperationContext] = ContextVar(
35-
"temporal-operation-context"
38+
_temporal_start_operation_context: ContextVar[_TemporalStartOperationContext] = (
39+
ContextVar("temporal-start-operation-context")
40+
)
41+
42+
_temporal_cancel_operation_context: ContextVar[_TemporalCancelOperationContext] = (
43+
ContextVar("temporal-cancel-operation-context")
3644
)
3745

3846

@@ -51,59 +59,126 @@ def info() -> Info:
5159
"""
5260
Get the current Nexus operation information.
5361
"""
54-
return _TemporalNexusOperationContext.get().info()
62+
return _temporal_context().info()
5563

5664

5765
def client() -> temporalio.client.Client:
5866
"""
5967
Get the Temporal client used by the worker handling the current Nexus operation.
6068
"""
61-
return _TemporalNexusOperationContext.get().client
69+
return _temporal_context().client
70+
71+
72+
def _temporal_context() -> (
73+
Union[_TemporalStartOperationContext, _TemporalCancelOperationContext]
74+
):
75+
ctx = _try_temporal_context()
76+
if ctx is None:
77+
raise RuntimeError("Not in Nexus operation context.")
78+
return ctx
79+
80+
81+
def _try_temporal_context() -> (
82+
Optional[Union[_TemporalStartOperationContext, _TemporalCancelOperationContext]]
83+
):
84+
start_ctx = _temporal_start_operation_context.get(None)
85+
cancel_ctx = _temporal_cancel_operation_context.get(None)
86+
if start_ctx and cancel_ctx:
87+
raise RuntimeError("Cannot be in both start and cancel operation contexts.")
88+
return start_ctx or cancel_ctx
6289

6390

6491
@dataclass
65-
class _TemporalNexusOperationContext:
92+
class _TemporalStartOperationContext:
6693
"""
67-
Context for a Nexus operation being handled by a Temporal Nexus Worker.
94+
Context for a Nexus start operation being handled by a Temporal Nexus Worker.
6895
"""
6996

70-
info: Callable[[], Info]
71-
"""Information about the running Nexus operation."""
97+
nexus_context: StartOperationContext
98+
"""Nexus-specific start operation context."""
7299

73-
nexus_operation_context: Union[StartOperationContext, CancelOperationContext]
100+
info: Callable[[], Info]
101+
"""Temporal information about the running Nexus operation."""
74102

75103
client: temporalio.client.Client
76104
"""The Temporal client in use by the worker handling this Nexus operation."""
77105

78106
@classmethod
79-
def get(cls) -> _TemporalNexusOperationContext:
80-
ctx = _temporal_operation_context.get(None)
107+
def get(cls) -> _TemporalStartOperationContext:
108+
ctx = _temporal_start_operation_context.get(None)
81109
if ctx is None:
82110
raise RuntimeError("Not in Nexus operation context.")
83111
return ctx
84112

85-
@property
86-
def _temporal_start_operation_context(
113+
def set(self) -> None:
114+
_temporal_start_operation_context.set(self)
115+
116+
def get_completion_callbacks(
87117
self,
88-
) -> Optional[_TemporalStartOperationContext]:
89-
ctx = self.nexus_operation_context
90-
if not isinstance(ctx, StartOperationContext):
91-
return None
92-
return _TemporalStartOperationContext(ctx)
118+
) -> list[temporalio.client.NexusCompletionCallback]:
119+
ctx = self.nexus_context
120+
return (
121+
[
122+
# TODO(nexus-prerelease): For WorkflowRunOperation, when it handles the Nexus
123+
# request, it needs to copy the links to the callback in
124+
# StartWorkflowRequest.CompletionCallbacks and to StartWorkflowRequest.Links
125+
# (for backwards compatibility). PR reference in Go SDK:
126+
# https://github.com/temporalio/sdk-go/pull/1945
127+
temporalio.client.NexusCompletionCallback(
128+
url=ctx.callback_url,
129+
header=ctx.callback_headers,
130+
)
131+
]
132+
if ctx.callback_url
133+
else []
134+
)
93135

94-
@property
95-
def _temporal_cancel_operation_context(
136+
def get_workflow_event_links(
96137
self,
97-
) -> Optional[_TemporalCancelOperationContext]:
98-
ctx = self.nexus_operation_context
99-
if not isinstance(ctx, CancelOperationContext):
100-
return None
101-
return _TemporalCancelOperationContext(ctx)
138+
) -> list[temporalio.api.common.v1.Link.WorkflowEvent]:
139+
event_links = []
140+
for inbound_link in self.nexus_context.inbound_links:
141+
if link := _nexus_link_to_workflow_event(inbound_link):
142+
event_links.append(link)
143+
return event_links
144+
145+
def add_outbound_links(
146+
self, workflow_handle: temporalio.client.WorkflowHandle[Any, Any]
147+
):
148+
try:
149+
link = _workflow_event_to_nexus_link(
150+
_workflow_handle_to_workflow_execution_started_event_link(
151+
workflow_handle
152+
)
153+
)
154+
except Exception as e:
155+
logger.warning(
156+
f"Failed to create WorkflowExecutionStarted event link for workflow {id}: {e}"
157+
)
158+
else:
159+
self.nexus_context.outbound_links.append(
160+
# TODO(nexus-prerelease): Before, WorkflowRunOperation was generating an EventReference
161+
# link to send back to the caller. Now, it checks if the server returned
162+
# the link in the StartWorkflowExecutionResponse, and if so, send the link
163+
# from the response to the caller. Fallback to generating the link for
164+
# backwards compatibility. PR reference in Go SDK:
165+
# https://github.com/temporalio/sdk-go/pull/1934
166+
link
167+
)
168+
return workflow_handle
102169

103170

104171
@dataclass
105172
class WorkflowRunOperationContext:
106-
start_operation_context: StartOperationContext
173+
temporal_context: _TemporalStartOperationContext
174+
175+
@property
176+
def nexus_context(self) -> StartOperationContext:
177+
return self.temporal_context.nexus_context
178+
179+
@classmethod
180+
def get(cls) -> WorkflowRunOperationContext:
181+
return cls(_TemporalStartOperationContext.get())
107182

108183
# Overload for single-param workflow
109184
# TODO(nexus-prerelease): bring over other overloads
@@ -164,14 +239,6 @@ async def start_workflow(
164239
Nexus caller is itself a workflow, this means that the workflow in the caller
165240
namespace web UI will contain links to the started workflow, and vice versa.
166241
"""
167-
tctx = _TemporalNexusOperationContext.get()
168-
start_operation_context = tctx._temporal_start_operation_context
169-
if not start_operation_context:
170-
raise RuntimeError(
171-
"WorkflowRunOperationContext.start_workflow() must be called from "
172-
"within a Nexus start operation context"
173-
)
174-
175242
# TODO(nexus-preview): When sdk-python supports on_conflict_options, Typescript does this:
176243
# if (workflowOptions.workflowIdConflictPolicy === 'USE_EXISTING') {
177244
# internalOptions.onConflictOptions = {
@@ -184,11 +251,11 @@ async def start_workflow(
184251
# We must pass nexus_completion_callbacks, workflow_event_links, and request_id,
185252
# but these are deliberately not exposed in overloads, hence the type-check
186253
# violation.
187-
wf_handle = await tctx.client.start_workflow( # type: ignore
254+
wf_handle = await self.temporal_context.client.start_workflow( # type: ignore
188255
workflow=workflow,
189256
arg=arg,
190257
id=id,
191-
task_queue=task_queue or tctx.info().task_queue,
258+
task_queue=task_queue or self.temporal_context.info().task_queue,
192259
execution_timeout=execution_timeout,
193260
run_timeout=run_timeout,
194261
task_timeout=task_timeout,
@@ -208,78 +275,40 @@ async def start_workflow(
208275
request_eager_start=request_eager_start,
209276
priority=priority,
210277
versioning_override=versioning_override,
211-
nexus_completion_callbacks=start_operation_context.get_completion_callbacks(),
212-
workflow_event_links=start_operation_context.get_workflow_event_links(),
213-
request_id=start_operation_context.nexus_operation_context.request_id,
278+
nexus_completion_callbacks=self.temporal_context.get_completion_callbacks(),
279+
workflow_event_links=self.temporal_context.get_workflow_event_links(),
280+
request_id=self.temporal_context.nexus_context.request_id,
214281
)
215282

216-
start_operation_context.add_outbound_links(wf_handle)
283+
self.temporal_context.add_outbound_links(wf_handle)
217284

218285
return WorkflowHandle[ReturnType]._unsafe_from_client_workflow_handle(wf_handle)
219286

220287

221288
@dataclass
222-
class _TemporalStartOperationContext:
223-
nexus_operation_context: StartOperationContext
289+
class _TemporalCancelOperationContext:
290+
"""
291+
Context for a Nexus cancel operation being handled by a Temporal Nexus Worker.
292+
"""
224293

225-
def get_completion_callbacks(
226-
self,
227-
) -> list[temporalio.client.NexusCompletionCallback]:
228-
ctx = self.nexus_operation_context
229-
return (
230-
[
231-
# TODO(nexus-prerelease): For WorkflowRunOperation, when it handles the Nexus
232-
# request, it needs to copy the links to the callback in
233-
# StartWorkflowRequest.CompletionCallbacks and to StartWorkflowRequest.Links
234-
# (for backwards compatibility). PR reference in Go SDK:
235-
# https://github.com/temporalio/sdk-go/pull/1945
236-
temporalio.client.NexusCompletionCallback(
237-
url=ctx.callback_url,
238-
header=ctx.callback_headers,
239-
)
240-
]
241-
if ctx.callback_url
242-
else []
243-
)
294+
nexus_context: CancelOperationContext
295+
"""Nexus-specific cancel operation context."""
244296

245-
def get_workflow_event_links(
246-
self,
247-
) -> list[temporalio.api.common.v1.Link.WorkflowEvent]:
248-
event_links = []
249-
for inbound_link in self.nexus_operation_context.inbound_links:
250-
if link := _nexus_link_to_workflow_event(inbound_link):
251-
event_links.append(link)
252-
return event_links
297+
info: Callable[[], Info]
298+
"""Temporal information about the running Nexus cancel operation."""
253299

254-
def add_outbound_links(
255-
self, workflow_handle: temporalio.client.WorkflowHandle[Any, Any]
256-
):
257-
try:
258-
link = _workflow_event_to_nexus_link(
259-
_workflow_handle_to_workflow_execution_started_event_link(
260-
workflow_handle
261-
)
262-
)
263-
except Exception as e:
264-
logger.warning(
265-
f"Failed to create WorkflowExecutionStarted event link for workflow {id}: {e}"
266-
)
267-
else:
268-
self.nexus_operation_context.outbound_links.append(
269-
# TODO(nexus-prerelease): Before, WorkflowRunOperation was generating an EventReference
270-
# link to send back to the caller. Now, it checks if the server returned
271-
# the link in the StartWorkflowExecutionResponse, and if so, send the link
272-
# from the response to the caller. Fallback to generating the link for
273-
# backwards compatibility. PR reference in Go SDK:
274-
# https://github.com/temporalio/sdk-go/pull/1934
275-
link
276-
)
277-
return workflow_handle
300+
client: temporalio.client.Client
301+
"""The Temporal client in use by the worker handling the current Nexus operation."""
278302

303+
@classmethod
304+
def get(cls) -> _TemporalCancelOperationContext:
305+
ctx = _temporal_cancel_operation_context.get(None)
306+
if ctx is None:
307+
raise RuntimeError("Not in Nexus cancel operation context.")
308+
return ctx
279309

280-
@dataclass
281-
class _TemporalCancelOperationContext:
282-
nexus_operation_context: CancelOperationContext
310+
def set(self) -> None:
311+
_temporal_cancel_operation_context.set(self)
283312

284313

285314
def _workflow_handle_to_workflow_execution_started_event_link(
@@ -376,9 +405,9 @@ def process(
376405
self, msg: Any, kwargs: MutableMapping[str, Any]
377406
) -> tuple[Any, MutableMapping[str, Any]]:
378407
extra = dict(self.extra or {})
379-
if tctx := _temporal_operation_context.get(None):
380-
extra["service"] = tctx.nexus_operation_context.service
381-
extra["operation"] = tctx.nexus_operation_context.operation
408+
if tctx := _try_temporal_context():
409+
extra["service"] = tctx.nexus_context.service
410+
extra["operation"] = tctx.nexus_context.operation
382411
extra["task_queue"] = tctx.info().task_queue
383412
kwargs["extra"] = extra | kwargs.get("extra", {})
384413
return msg, kwargs

temporalio/nexus/_operation_handlers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
from temporalio import client
2828
from temporalio.nexus._operation_context import (
29-
_temporal_operation_context,
29+
_temporal_start_operation_context,
3030
)
3131
from temporalio.nexus._token import WorkflowHandle
3232

@@ -114,7 +114,7 @@ async def fetch_result(
114114
type=HandlerErrorType.NOT_FOUND,
115115
cause=err,
116116
)
117-
ctx = _temporal_operation_context.get()
117+
ctx = _temporal_start_operation_context.get()
118118
try:
119119
client_handle = nexus_handle.to_workflow_handle(
120120
ctx.client, result_type=self._output_type
@@ -148,7 +148,7 @@ async def cancel_operation(
148148
cause=err,
149149
)
150150

151-
ctx = _temporal_operation_context.get()
151+
ctx = _temporal_start_operation_context.get()
152152
try:
153153
client_workflow_handle = nexus_workflow_handle._to_client_workflow_handle(
154154
ctx.client

0 commit comments

Comments
 (0)