Skip to content

Commit 4a18f1f

Browse files
authored
Activity worker: refactoring part 2 (#899)
1 parent 6bd7256 commit 4a18f1f

File tree

1 file changed

+31
-26
lines changed

1 file changed

+31
-26
lines changed

temporalio/worker/_activity.py

Lines changed: 31 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -138,24 +138,21 @@ def __init__(
138138
self._dynamic_activity = defn
139139

140140
async def run(self) -> None:
141-
# Create a task that fails when we get a failure on the queue
142-
async def raise_from_queue() -> NoReturn:
141+
"""Continually poll for activity tasks and dispatch to handlers."""
142+
143+
async def raise_from_exception_queue() -> NoReturn:
143144
raise await self._fail_worker_exception_queue.get()
144145

145-
exception_task = asyncio.create_task(raise_from_queue())
146+
exception_task = asyncio.create_task(raise_from_exception_queue())
146147

147-
# Continually poll for activity work
148148
while True:
149149
try:
150-
# Poll for a task
151150
poll_task = asyncio.create_task(
152151
self._bridge_worker().poll_activity_task()
153152
)
154153
await asyncio.wait(
155154
[poll_task, exception_task], return_when=asyncio.FIRST_COMPLETED
156-
) # type: ignore
157-
# If exception for failing the worker happened, raise it.
158-
# Otherwise, the poll succeeded.
155+
)
159156
if exception_task.done():
160157
poll_task.cancel()
161158
await exception_task
@@ -167,11 +164,14 @@ async def raise_from_queue() -> NoReturn:
167164
# size of 1000 should be plenty for the heartbeat queue.
168165
activity = _RunningActivity(pending_heartbeats=asyncio.Queue(1000))
169166
activity.task = asyncio.create_task(
170-
self._run_activity(task.task_token, task.start, activity)
167+
self._handle_start_activity_task(
168+
task.task_token, task.start, activity
169+
)
171170
)
172171
self._running_activities[task.task_token] = activity
173172
elif task.HasField("cancel"):
174-
self._cancel(task.task_token, task.cancel)
173+
# TODO(nexus-prerelease): does the task get removed from running_activities?
174+
self._handle_cancel_activity_task(task.task_token, task.cancel)
175175
else:
176176
raise RuntimeError(f"Unrecognized activity task: {task}")
177177
except temporalio.bridge.worker.PollShutdownError:
@@ -208,9 +208,10 @@ async def wait_all_completed(self) -> None:
208208
if running_tasks:
209209
await asyncio.gather(*running_tasks, return_exceptions=False)
210210

211-
def _cancel(
211+
def _handle_cancel_activity_task(
212212
self, task_token: bytes, cancel: temporalio.bridge.proto.activity_task.Cancel
213213
) -> None:
214+
"""Request cancellation of a running activity task."""
214215
activity = self._running_activities.get(task_token)
215216
if not activity:
216217
warnings.warn(f"Cannot find activity to cancel for token {task_token!r}")
@@ -275,12 +276,17 @@ async def _heartbeat_async(
275276
)
276277
activity.cancel(cancelled_due_to_heartbeat_error=err)
277278

278-
async def _run_activity(
279+
async def _handle_start_activity_task(
279280
self,
280281
task_token: bytes,
281282
start: temporalio.bridge.proto.activity_task.Start,
282283
running_activity: _RunningActivity,
283284
) -> None:
285+
"""Handle a start activity task.
286+
287+
Attempt to execute the user activity function and invoke the data converter on
288+
the result. Handle errors and send the task completion.
289+
"""
284290
logger.debug("Running activity %s (token %s)", start.activity_type, task_token)
285291
# We choose to surround interceptor creation and activity invocation in
286292
# a try block so we can mark the workflow as failed on any error instead
@@ -289,7 +295,9 @@ async def _run_activity(
289295
task_token=task_token
290296
)
291297
try:
292-
await self._execute_activity(start, running_activity, completion)
298+
result = await self._execute_activity(start, running_activity, task_token)
299+
[payload] = await self._data_converter.encode([result])
300+
completion.result.completed.result.CopyFrom(payload)
293301
except BaseException as err:
294302
try:
295303
if isinstance(err, temporalio.activity._CompleteAsyncError):
@@ -318,7 +326,7 @@ async def _run_activity(
318326
and running_activity.cancellation_details.details.paused
319327
):
320328
temporalio.activity.logger.warning(
321-
f"Completing as failure due to unhandled cancel error produced by activity pause",
329+
"Completing as failure due to unhandled cancel error produced by activity pause",
322330
)
323331
await self._data_converter.encode_failure(
324332
temporalio.exceptions.ApplicationError(
@@ -402,8 +410,12 @@ async def _execute_activity(
402410
self,
403411
start: temporalio.bridge.proto.activity_task.Start,
404412
running_activity: _RunningActivity,
405-
completion: temporalio.bridge.proto.ActivityTaskCompletion,
406-
):
413+
task_token: bytes,
414+
) -> Any:
415+
"""Invoke the user's activity function.
416+
417+
Exceptions are handled by a caller of this function.
418+
"""
407419
# Find activity or fail
408420
activity_def = self._activities.get(start.activity_type, self._dynamic_activity)
409421
if not activity_def:
@@ -523,7 +535,7 @@ async def _execute_activity(
523535
else None,
524536
started_time=_proto_to_datetime(start.started_time),
525537
task_queue=self._task_queue,
526-
task_token=completion.task_token,
538+
task_token=task_token,
527539
workflow_id=start.workflow_execution.workflow_id,
528540
workflow_namespace=start.workflow_namespace,
529541
workflow_run_id=start.workflow_execution.run_id,
@@ -562,16 +574,9 @@ async def _execute_activity(
562574
impl: ActivityInboundInterceptor = _ActivityInboundImpl(self, running_activity)
563575
for interceptor in reversed(list(self._interceptors)):
564576
impl = interceptor.intercept_activity(impl)
565-
# Init
577+
566578
impl.init(_ActivityOutboundImpl(self, running_activity.info))
567-
# Exec
568-
result = await impl.execute_activity(input)
569-
# Convert result even if none. Since Python essentially only
570-
# supports single result types (even if they are tuples), we will do
571-
# the same.
572-
completion.result.completed.result.CopyFrom(
573-
(await self._data_converter.encode([result]))[0]
574-
)
579+
return await impl.execute_activity(input)
575580

576581
def assert_activity_valid(self, activity) -> None:
577582
if self._dynamic_activity:

0 commit comments

Comments
 (0)