@@ -138,24 +138,21 @@ def __init__(
138
138
self ._dynamic_activity = defn
139
139
140
140
async def run (self ) -> None :
141
- # Create a task that fails when we get a failure on the queue
142
- async def raise_from_queue () -> NoReturn :
141
+ """Continually poll for activity tasks and dispatch to handlers."""
142
+
143
+ async def raise_from_exception_queue () -> NoReturn :
143
144
raise await self ._fail_worker_exception_queue .get ()
144
145
145
- exception_task = asyncio .create_task (raise_from_queue ())
146
+ exception_task = asyncio .create_task (raise_from_exception_queue ())
146
147
147
- # Continually poll for activity work
148
148
while True :
149
149
try :
150
- # Poll for a task
151
150
poll_task = asyncio .create_task (
152
151
self ._bridge_worker ().poll_activity_task ()
153
152
)
154
153
await asyncio .wait (
155
154
[poll_task , exception_task ], return_when = asyncio .FIRST_COMPLETED
156
- ) # type: ignore
157
- # If exception for failing the worker happened, raise it.
158
- # Otherwise, the poll succeeded.
155
+ )
159
156
if exception_task .done ():
160
157
poll_task .cancel ()
161
158
await exception_task
@@ -167,11 +164,14 @@ async def raise_from_queue() -> NoReturn:
167
164
# size of 1000 should be plenty for the heartbeat queue.
168
165
activity = _RunningActivity (pending_heartbeats = asyncio .Queue (1000 ))
169
166
activity .task = asyncio .create_task (
170
- self ._run_activity (task .task_token , task .start , activity )
167
+ self ._handle_start_activity_task (
168
+ task .task_token , task .start , activity
169
+ )
171
170
)
172
171
self ._running_activities [task .task_token ] = activity
173
172
elif task .HasField ("cancel" ):
174
- self ._cancel (task .task_token , task .cancel )
173
+ # TODO(nexus-prerelease): does the task get removed from running_activities?
174
+ self ._handle_cancel_activity_task (task .task_token , task .cancel )
175
175
else :
176
176
raise RuntimeError (f"Unrecognized activity task: { task } " )
177
177
except temporalio .bridge .worker .PollShutdownError :
@@ -208,9 +208,10 @@ async def wait_all_completed(self) -> None:
208
208
if running_tasks :
209
209
await asyncio .gather (* running_tasks , return_exceptions = False )
210
210
211
- def _cancel (
211
+ def _handle_cancel_activity_task (
212
212
self , task_token : bytes , cancel : temporalio .bridge .proto .activity_task .Cancel
213
213
) -> None :
214
+ """Request cancellation of a running activity task."""
214
215
activity = self ._running_activities .get (task_token )
215
216
if not activity :
216
217
warnings .warn (f"Cannot find activity to cancel for token { task_token !r} " )
@@ -275,12 +276,17 @@ async def _heartbeat_async(
275
276
)
276
277
activity .cancel (cancelled_due_to_heartbeat_error = err )
277
278
278
- async def _run_activity (
279
+ async def _handle_start_activity_task (
279
280
self ,
280
281
task_token : bytes ,
281
282
start : temporalio .bridge .proto .activity_task .Start ,
282
283
running_activity : _RunningActivity ,
283
284
) -> None :
285
+ """Handle a start activity task.
286
+
287
+ Attempt to execute the user activity function and invoke the data converter on
288
+ the result. Handle errors and send the task completion.
289
+ """
284
290
logger .debug ("Running activity %s (token %s)" , start .activity_type , task_token )
285
291
# We choose to surround interceptor creation and activity invocation in
286
292
# a try block so we can mark the workflow as failed on any error instead
@@ -289,7 +295,9 @@ async def _run_activity(
289
295
task_token = task_token
290
296
)
291
297
try :
292
- await self ._execute_activity (start , running_activity , completion )
298
+ result = await self ._execute_activity (start , running_activity , task_token )
299
+ [payload ] = await self ._data_converter .encode ([result ])
300
+ completion .result .completed .result .CopyFrom (payload )
293
301
except BaseException as err :
294
302
try :
295
303
if isinstance (err , temporalio .activity ._CompleteAsyncError ):
@@ -318,7 +326,7 @@ async def _run_activity(
318
326
and running_activity .cancellation_details .details .paused
319
327
):
320
328
temporalio .activity .logger .warning (
321
- f "Completing as failure due to unhandled cancel error produced by activity pause" ,
329
+ "Completing as failure due to unhandled cancel error produced by activity pause" ,
322
330
)
323
331
await self ._data_converter .encode_failure (
324
332
temporalio .exceptions .ApplicationError (
@@ -402,8 +410,12 @@ async def _execute_activity(
402
410
self ,
403
411
start : temporalio .bridge .proto .activity_task .Start ,
404
412
running_activity : _RunningActivity ,
405
- completion : temporalio .bridge .proto .ActivityTaskCompletion ,
406
- ):
413
+ task_token : bytes ,
414
+ ) -> Any :
415
+ """Invoke the user's activity function.
416
+
417
+ Exceptions are handled by a caller of this function.
418
+ """
407
419
# Find activity or fail
408
420
activity_def = self ._activities .get (start .activity_type , self ._dynamic_activity )
409
421
if not activity_def :
@@ -523,7 +535,7 @@ async def _execute_activity(
523
535
else None ,
524
536
started_time = _proto_to_datetime (start .started_time ),
525
537
task_queue = self ._task_queue ,
526
- task_token = completion . task_token ,
538
+ task_token = task_token ,
527
539
workflow_id = start .workflow_execution .workflow_id ,
528
540
workflow_namespace = start .workflow_namespace ,
529
541
workflow_run_id = start .workflow_execution .run_id ,
@@ -562,16 +574,9 @@ async def _execute_activity(
562
574
impl : ActivityInboundInterceptor = _ActivityInboundImpl (self , running_activity )
563
575
for interceptor in reversed (list (self ._interceptors )):
564
576
impl = interceptor .intercept_activity (impl )
565
- # Init
577
+
566
578
impl .init (_ActivityOutboundImpl (self , running_activity .info ))
567
- # Exec
568
- result = await impl .execute_activity (input )
569
- # Convert result even if none. Since Python essentially only
570
- # supports single result types (even if they are tuples), we will do
571
- # the same.
572
- completion .result .completed .result .CopyFrom (
573
- (await self ._data_converter .encode ([result ]))[0 ]
574
- )
579
+ return await impl .execute_activity (input )
575
580
576
581
def assert_activity_valid (self , activity ) -> None :
577
582
if self ._dynamic_activity :
0 commit comments