10
10
Any ,
11
11
Callable ,
12
12
Mapping ,
13
+ MutableMapping ,
13
14
Optional ,
14
15
Sequence ,
15
16
Union ,
30
31
SelfType ,
31
32
)
32
33
34
+ # The Temporal Nexus worker always builds a nexusrpc StartOperationContext or
35
+ # CancelOperationContext and passes it as the first parameter to the nexusrpc operation
36
+ # handler. In addition, it sets one of the following context vars.
33
37
34
- _temporal_operation_context : ContextVar [_TemporalNexusOperationContext ] = ContextVar (
35
- "temporal-operation-context"
38
+ _temporal_start_operation_context : ContextVar [_TemporalStartOperationContext ] = (
39
+ ContextVar ("temporal-start-operation-context" )
40
+ )
41
+
42
+ _temporal_cancel_operation_context : ContextVar [_TemporalCancelOperationContext ] = (
43
+ ContextVar ("temporal-cancel-operation-context" )
36
44
)
37
45
38
46
@@ -51,59 +59,126 @@ def info() -> Info:
51
59
"""
52
60
Get the current Nexus operation information.
53
61
"""
54
- return _TemporalNexusOperationContext . get ().info ()
62
+ return _temporal_context ().info ()
55
63
56
64
57
65
def client () -> temporalio .client .Client :
58
66
"""
59
67
Get the Temporal client used by the worker handling the current Nexus operation.
60
68
"""
61
- return _TemporalNexusOperationContext .get ().client
69
+ return _temporal_context ().client
70
+
71
+
72
+ def _temporal_context () -> (
73
+ Union [_TemporalStartOperationContext , _TemporalCancelOperationContext ]
74
+ ):
75
+ ctx = _try_temporal_context ()
76
+ if ctx is None :
77
+ raise RuntimeError ("Not in Nexus operation context." )
78
+ return ctx
79
+
80
+
81
+ def _try_temporal_context () -> (
82
+ Optional [Union [_TemporalStartOperationContext , _TemporalCancelOperationContext ]]
83
+ ):
84
+ start_ctx = _temporal_start_operation_context .get (None )
85
+ cancel_ctx = _temporal_cancel_operation_context .get (None )
86
+ if start_ctx and cancel_ctx :
87
+ raise RuntimeError ("Cannot be in both start and cancel operation contexts." )
88
+ return start_ctx or cancel_ctx
62
89
63
90
64
91
@dataclass
65
- class _TemporalNexusOperationContext :
92
+ class _TemporalStartOperationContext :
66
93
"""
67
- Context for a Nexus operation being handled by a Temporal Nexus Worker.
94
+ Context for a Nexus start operation being handled by a Temporal Nexus Worker.
68
95
"""
69
96
70
- info : Callable [[], Info ]
71
- """Information about the running Nexus operation."""
97
+ nexus_context : StartOperationContext
98
+ """Nexus-specific start operation context ."""
72
99
73
- nexus_operation_context : Union [StartOperationContext , CancelOperationContext ]
100
+ info : Callable [[], Info ]
101
+ """Temporal information about the running Nexus operation."""
74
102
75
103
client : temporalio .client .Client
76
104
"""The Temporal client in use by the worker handling this Nexus operation."""
77
105
78
106
@classmethod
79
- def get (cls ) -> _TemporalNexusOperationContext :
80
- ctx = _temporal_operation_context .get (None )
107
+ def get (cls ) -> _TemporalStartOperationContext :
108
+ ctx = _temporal_start_operation_context .get (None )
81
109
if ctx is None :
82
110
raise RuntimeError ("Not in Nexus operation context." )
83
111
return ctx
84
112
85
- @property
86
- def _temporal_start_operation_context (
113
+ def set (self ) -> None :
114
+ _temporal_start_operation_context .set (self )
115
+
116
+ def get_completion_callbacks (
87
117
self ,
88
- ) -> Optional [_TemporalStartOperationContext ]:
89
- ctx = self .nexus_operation_context
90
- if not isinstance (ctx , StartOperationContext ):
91
- return None
92
- return _TemporalStartOperationContext (ctx )
118
+ ) -> list [temporalio .client .NexusCompletionCallback ]:
119
+ ctx = self .nexus_context
120
+ return (
121
+ [
122
+ # TODO(nexus-prerelease): For WorkflowRunOperation, when it handles the Nexus
123
+ # request, it needs to copy the links to the callback in
124
+ # StartWorkflowRequest.CompletionCallbacks and to StartWorkflowRequest.Links
125
+ # (for backwards compatibility). PR reference in Go SDK:
126
+ # https://github.com/temporalio/sdk-go/pull/1945
127
+ temporalio .client .NexusCompletionCallback (
128
+ url = ctx .callback_url ,
129
+ header = ctx .callback_headers ,
130
+ )
131
+ ]
132
+ if ctx .callback_url
133
+ else []
134
+ )
93
135
94
- @property
95
- def _temporal_cancel_operation_context (
136
+ def get_workflow_event_links (
96
137
self ,
97
- ) -> Optional [_TemporalCancelOperationContext ]:
98
- ctx = self .nexus_operation_context
99
- if not isinstance (ctx , CancelOperationContext ):
100
- return None
101
- return _TemporalCancelOperationContext (ctx )
138
+ ) -> list [temporalio .api .common .v1 .Link .WorkflowEvent ]:
139
+ event_links = []
140
+ for inbound_link in self .nexus_context .inbound_links :
141
+ if link := _nexus_link_to_workflow_event (inbound_link ):
142
+ event_links .append (link )
143
+ return event_links
144
+
145
+ def add_outbound_links (
146
+ self , workflow_handle : temporalio .client .WorkflowHandle [Any , Any ]
147
+ ):
148
+ try :
149
+ link = _workflow_event_to_nexus_link (
150
+ _workflow_handle_to_workflow_execution_started_event_link (
151
+ workflow_handle
152
+ )
153
+ )
154
+ except Exception as e :
155
+ logger .warning (
156
+ f"Failed to create WorkflowExecutionStarted event link for workflow { id } : { e } "
157
+ )
158
+ else :
159
+ self .nexus_context .outbound_links .append (
160
+ # TODO(nexus-prerelease): Before, WorkflowRunOperation was generating an EventReference
161
+ # link to send back to the caller. Now, it checks if the server returned
162
+ # the link in the StartWorkflowExecutionResponse, and if so, send the link
163
+ # from the response to the caller. Fallback to generating the link for
164
+ # backwards compatibility. PR reference in Go SDK:
165
+ # https://github.com/temporalio/sdk-go/pull/1934
166
+ link
167
+ )
168
+ return workflow_handle
102
169
103
170
104
171
@dataclass
105
172
class WorkflowRunOperationContext :
106
- start_operation_context : StartOperationContext
173
+ temporal_context : _TemporalStartOperationContext
174
+
175
+ @property
176
+ def nexus_context (self ) -> StartOperationContext :
177
+ return self .temporal_context .nexus_context
178
+
179
+ @classmethod
180
+ def get (cls ) -> WorkflowRunOperationContext :
181
+ return cls (_TemporalStartOperationContext .get ())
107
182
108
183
# Overload for single-param workflow
109
184
# TODO(nexus-prerelease): bring over other overloads
@@ -164,14 +239,6 @@ async def start_workflow(
164
239
Nexus caller is itself a workflow, this means that the workflow in the caller
165
240
namespace web UI will contain links to the started workflow, and vice versa.
166
241
"""
167
- tctx = _TemporalNexusOperationContext .get ()
168
- start_operation_context = tctx ._temporal_start_operation_context
169
- if not start_operation_context :
170
- raise RuntimeError (
171
- "WorkflowRunOperationContext.start_workflow() must be called from "
172
- "within a Nexus start operation context"
173
- )
174
-
175
242
# TODO(nexus-preview): When sdk-python supports on_conflict_options, Typescript does this:
176
243
# if (workflowOptions.workflowIdConflictPolicy === 'USE_EXISTING') {
177
244
# internalOptions.onConflictOptions = {
@@ -184,11 +251,11 @@ async def start_workflow(
184
251
# We must pass nexus_completion_callbacks, workflow_event_links, and request_id,
185
252
# but these are deliberately not exposed in overloads, hence the type-check
186
253
# violation.
187
- wf_handle = await tctx .client .start_workflow ( # type: ignore
254
+ wf_handle = await self . temporal_context .client .start_workflow ( # type: ignore
188
255
workflow = workflow ,
189
256
arg = arg ,
190
257
id = id ,
191
- task_queue = task_queue or tctx .info ().task_queue ,
258
+ task_queue = task_queue or self . temporal_context .info ().task_queue ,
192
259
execution_timeout = execution_timeout ,
193
260
run_timeout = run_timeout ,
194
261
task_timeout = task_timeout ,
@@ -208,78 +275,40 @@ async def start_workflow(
208
275
request_eager_start = request_eager_start ,
209
276
priority = priority ,
210
277
versioning_override = versioning_override ,
211
- nexus_completion_callbacks = start_operation_context .get_completion_callbacks (),
212
- workflow_event_links = start_operation_context .get_workflow_event_links (),
213
- request_id = start_operation_context . nexus_operation_context .request_id ,
278
+ nexus_completion_callbacks = self . temporal_context .get_completion_callbacks (),
279
+ workflow_event_links = self . temporal_context .get_workflow_event_links (),
280
+ request_id = self . temporal_context . nexus_context .request_id ,
214
281
)
215
282
216
- start_operation_context .add_outbound_links (wf_handle )
283
+ self . temporal_context .add_outbound_links (wf_handle )
217
284
218
285
return WorkflowHandle [ReturnType ]._unsafe_from_client_workflow_handle (wf_handle )
219
286
220
287
221
288
@dataclass
222
- class _TemporalStartOperationContext :
223
- nexus_operation_context : StartOperationContext
289
+ class _TemporalCancelOperationContext :
290
+ """
291
+ Context for a Nexus cancel operation being handled by a Temporal Nexus Worker.
292
+ """
224
293
225
- def get_completion_callbacks (
226
- self ,
227
- ) -> list [temporalio .client .NexusCompletionCallback ]:
228
- ctx = self .nexus_operation_context
229
- return (
230
- [
231
- # TODO(nexus-prerelease): For WorkflowRunOperation, when it handles the Nexus
232
- # request, it needs to copy the links to the callback in
233
- # StartWorkflowRequest.CompletionCallbacks and to StartWorkflowRequest.Links
234
- # (for backwards compatibility). PR reference in Go SDK:
235
- # https://github.com/temporalio/sdk-go/pull/1945
236
- temporalio .client .NexusCompletionCallback (
237
- url = ctx .callback_url ,
238
- header = ctx .callback_headers ,
239
- )
240
- ]
241
- if ctx .callback_url
242
- else []
243
- )
294
+ nexus_context : CancelOperationContext
295
+ """Nexus-specific cancel operation context."""
244
296
245
- def get_workflow_event_links (
246
- self ,
247
- ) -> list [temporalio .api .common .v1 .Link .WorkflowEvent ]:
248
- event_links = []
249
- for inbound_link in self .nexus_operation_context .inbound_links :
250
- if link := _nexus_link_to_workflow_event (inbound_link ):
251
- event_links .append (link )
252
- return event_links
297
+ info : Callable [[], Info ]
298
+ """Temporal information about the running Nexus cancel operation."""
253
299
254
- def add_outbound_links (
255
- self , workflow_handle : temporalio .client .WorkflowHandle [Any , Any ]
256
- ):
257
- try :
258
- link = _workflow_event_to_nexus_link (
259
- _workflow_handle_to_workflow_execution_started_event_link (
260
- workflow_handle
261
- )
262
- )
263
- except Exception as e :
264
- logger .warning (
265
- f"Failed to create WorkflowExecutionStarted event link for workflow { id } : { e } "
266
- )
267
- else :
268
- self .nexus_operation_context .outbound_links .append (
269
- # TODO(nexus-prerelease): Before, WorkflowRunOperation was generating an EventReference
270
- # link to send back to the caller. Now, it checks if the server returned
271
- # the link in the StartWorkflowExecutionResponse, and if so, send the link
272
- # from the response to the caller. Fallback to generating the link for
273
- # backwards compatibility. PR reference in Go SDK:
274
- # https://github.com/temporalio/sdk-go/pull/1934
275
- link
276
- )
277
- return workflow_handle
300
+ client : temporalio .client .Client
301
+ """The Temporal client in use by the worker handling the current Nexus operation."""
278
302
303
+ @classmethod
304
+ def get (cls ) -> _TemporalCancelOperationContext :
305
+ ctx = _temporal_cancel_operation_context .get (None )
306
+ if ctx is None :
307
+ raise RuntimeError ("Not in Nexus cancel operation context." )
308
+ return ctx
279
309
280
- @dataclass
281
- class _TemporalCancelOperationContext :
282
- nexus_operation_context : CancelOperationContext
310
+ def set (self ) -> None :
311
+ _temporal_cancel_operation_context .set (self )
283
312
284
313
285
314
def _workflow_handle_to_workflow_execution_started_event_link (
@@ -376,9 +405,9 @@ def process(
376
405
self , msg : Any , kwargs : MutableMapping [str , Any ]
377
406
) -> tuple [Any , MutableMapping [str , Any ]]:
378
407
extra = dict (self .extra or {})
379
- if tctx := _temporal_operation_context . get ( None ):
380
- extra ["service" ] = tctx .nexus_operation_context .service
381
- extra ["operation" ] = tctx .nexus_operation_context .operation
408
+ if tctx := _try_temporal_context ( ):
409
+ extra ["service" ] = tctx .nexus_context .service
410
+ extra ["operation" ] = tctx .nexus_context .operation
382
411
extra ["task_queue" ] = tctx .info ().task_queue
383
412
kwargs ["extra" ] = extra | kwargs .get ("extra" , {})
384
413
return msg , kwargs
0 commit comments