Skip to content

Commit e9d3b01

Browse files
authored
Activity worker refactor (#860)
1 parent b53be98 commit e9d3b01

File tree

1 file changed

+177
-177
lines changed

1 file changed

+177
-177
lines changed

temporalio/worker/_activity.py

Lines changed: 177 additions & 177 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,8 @@ async def drain_poll_queue(self) -> None:
201201

202202
# Only call this after run()/drain_poll_queue() have returned. This will not
203203
# raise an exception.
204+
# TODO(dan): based on the comment above it looks like the intention may have been to use
205+
# return_exceptions=True
204206
async def wait_all_completed(self) -> None:
205207
running_tasks = [v.task for v in self._running_activities.values() if v.task]
206208
if running_tasks:
@@ -281,183 +283,7 @@ async def _run_activity(
281283
task_token=task_token
282284
)
283285
try:
284-
# Find activity or fail
285-
activity_def = self._activities.get(
286-
start.activity_type, self._dynamic_activity
287-
)
288-
if not activity_def:
289-
activity_names = ", ".join(sorted(self._activities.keys()))
290-
raise temporalio.exceptions.ApplicationError(
291-
f"Activity function {start.activity_type} for workflow {start.workflow_execution.workflow_id} "
292-
f"is not registered on this worker, available activities: {activity_names}",
293-
type="NotFoundError",
294-
)
295-
296-
# Create the worker shutdown event if not created
297-
if not self._worker_shutdown_event:
298-
self._worker_shutdown_event = temporalio.activity._CompositeEvent(
299-
thread_event=threading.Event(), async_event=asyncio.Event()
300-
)
301-
302-
# Setup events
303-
sync_non_threaded = False
304-
if not activity_def.is_async:
305-
running_activity.sync = True
306-
# If we're in a thread-pool executor we can use threading events
307-
# otherwise we must use manager events
308-
if isinstance(
309-
self._activity_executor, concurrent.futures.ThreadPoolExecutor
310-
):
311-
running_activity.cancelled_event = (
312-
temporalio.activity._CompositeEvent(
313-
thread_event=threading.Event(),
314-
# No async event
315-
async_event=None,
316-
)
317-
)
318-
if not activity_def.no_thread_cancel_exception:
319-
running_activity.cancel_thread_raiser = _ThreadExceptionRaiser()
320-
else:
321-
sync_non_threaded = True
322-
manager = self._shared_state_manager
323-
# Pre-checked on worker init
324-
assert manager
325-
running_activity.cancelled_event = (
326-
temporalio.activity._CompositeEvent(
327-
thread_event=manager.new_event(),
328-
# No async event
329-
async_event=None,
330-
)
331-
)
332-
# We also must set the worker shutdown thread event to a
333-
# manager event if this is the first sync event. We don't
334-
# want to create if there never is a sync event.
335-
if not self._seen_sync_activity:
336-
self._worker_shutdown_event.thread_event = manager.new_event()
337-
# Say we've seen a sync activity
338-
self._seen_sync_activity = True
339-
else:
340-
# We have to set the async form of events
341-
running_activity.cancelled_event = temporalio.activity._CompositeEvent(
342-
thread_event=threading.Event(),
343-
async_event=asyncio.Event(),
344-
)
345-
346-
# Convert arguments. We use raw value for dynamic. Otherwise, we
347-
# only use arg type hints if they match the input count.
348-
arg_types = activity_def.arg_types
349-
if not activity_def.name:
350-
# Dynamic is just the raw value for each input value
351-
arg_types = [temporalio.common.RawValue] * len(start.input)
352-
elif arg_types is not None and len(arg_types) != len(start.input):
353-
arg_types = None
354-
try:
355-
args = (
356-
[]
357-
if not start.input
358-
else await self._data_converter.decode(
359-
start.input, type_hints=arg_types
360-
)
361-
)
362-
except Exception as err:
363-
raise temporalio.exceptions.ApplicationError(
364-
"Failed decoding arguments"
365-
) from err
366-
# Put the args inside a list if dynamic
367-
if not activity_def.name:
368-
args = [args]
369-
370-
# Convert heartbeat details
371-
# TODO(cretz): Allow some way to configure heartbeat type hinting?
372-
try:
373-
heartbeat_details = (
374-
[]
375-
if not start.heartbeat_details
376-
else await self._data_converter.decode(start.heartbeat_details)
377-
)
378-
except Exception as err:
379-
raise temporalio.exceptions.ApplicationError(
380-
"Failed decoding heartbeat details", non_retryable=True
381-
) from err
382-
383-
# Build info
384-
info = temporalio.activity.Info(
385-
activity_id=start.activity_id,
386-
activity_type=start.activity_type,
387-
attempt=start.attempt,
388-
current_attempt_scheduled_time=_proto_to_datetime(
389-
start.current_attempt_scheduled_time
390-
),
391-
heartbeat_details=heartbeat_details,
392-
heartbeat_timeout=_proto_to_non_zero_timedelta(start.heartbeat_timeout)
393-
if start.HasField("heartbeat_timeout")
394-
else None,
395-
is_local=start.is_local,
396-
schedule_to_close_timeout=_proto_to_non_zero_timedelta(
397-
start.schedule_to_close_timeout
398-
)
399-
if start.HasField("schedule_to_close_timeout")
400-
else None,
401-
scheduled_time=_proto_to_datetime(start.scheduled_time),
402-
start_to_close_timeout=_proto_to_non_zero_timedelta(
403-
start.start_to_close_timeout
404-
)
405-
if start.HasField("start_to_close_timeout")
406-
else None,
407-
started_time=_proto_to_datetime(start.started_time),
408-
task_queue=self._task_queue,
409-
task_token=task_token,
410-
workflow_id=start.workflow_execution.workflow_id,
411-
workflow_namespace=start.workflow_namespace,
412-
workflow_run_id=start.workflow_execution.run_id,
413-
workflow_type=start.workflow_type,
414-
priority=temporalio.common.Priority._from_proto(start.priority),
415-
)
416-
running_activity.info = info
417-
input = ExecuteActivityInput(
418-
fn=activity_def.fn,
419-
args=args,
420-
executor=None if not running_activity.sync else self._activity_executor,
421-
headers=start.header_fields,
422-
)
423-
424-
# Set the context early so the logging adapter works and
425-
# interceptors have it
426-
temporalio.activity._Context.set(
427-
temporalio.activity._Context(
428-
info=lambda: info,
429-
heartbeat=None,
430-
cancelled_event=running_activity.cancelled_event,
431-
worker_shutdown_event=self._worker_shutdown_event,
432-
shield_thread_cancel_exception=None
433-
if not running_activity.cancel_thread_raiser
434-
else running_activity.cancel_thread_raiser.shielded,
435-
payload_converter_class_or_instance=self._data_converter.payload_converter,
436-
runtime_metric_meter=None
437-
if sync_non_threaded
438-
else self._metric_meter,
439-
)
440-
)
441-
temporalio.activity.logger.debug("Starting activity")
442-
443-
# Build the interceptors chaining in reverse. We build a context right
444-
# now even though the info() can't be intercepted and heartbeat() will
445-
# fail. The interceptors may want to use the info() during init.
446-
impl: ActivityInboundInterceptor = _ActivityInboundImpl(
447-
self, running_activity
448-
)
449-
for interceptor in reversed(list(self._interceptors)):
450-
impl = interceptor.intercept_activity(impl)
451-
# Init
452-
impl.init(_ActivityOutboundImpl(self, running_activity.info))
453-
# Exec
454-
result = await impl.execute_activity(input)
455-
# Convert result even if none. Since Python essentially only
456-
# supports single result types (even if they are tuples), we will do
457-
# the same.
458-
completion.result.completed.result.CopyFrom(
459-
(await self._data_converter.encode([result]))[0]
460-
)
286+
await self._execute_activity(start, running_activity, completion)
461287
except BaseException as err:
462288
try:
463289
if isinstance(err, temporalio.activity._CompleteAsyncError):
@@ -545,6 +371,180 @@ async def _run_activity(
545371
except Exception:
546372
temporalio.activity.logger.exception("Failed completing activity task")
547373

374+
async def _execute_activity(
375+
self,
376+
start: temporalio.bridge.proto.activity_task.Start,
377+
running_activity: _RunningActivity,
378+
completion: temporalio.bridge.proto.ActivityTaskCompletion,
379+
):
380+
# Find activity or fail
381+
activity_def = self._activities.get(start.activity_type, self._dynamic_activity)
382+
if not activity_def:
383+
activity_names = ", ".join(sorted(self._activities.keys()))
384+
raise temporalio.exceptions.ApplicationError(
385+
f"Activity function {start.activity_type} for workflow {start.workflow_execution.workflow_id} "
386+
f"is not registered on this worker, available activities: {activity_names}",
387+
type="NotFoundError",
388+
)
389+
390+
# Create the worker shutdown event if not created
391+
if not self._worker_shutdown_event:
392+
self._worker_shutdown_event = temporalio.activity._CompositeEvent(
393+
thread_event=threading.Event(), async_event=asyncio.Event()
394+
)
395+
396+
# Setup events
397+
sync_non_threaded = False
398+
if not activity_def.is_async:
399+
running_activity.sync = True
400+
# If we're in a thread-pool executor we can use threading events
401+
# otherwise we must use manager events
402+
if isinstance(
403+
self._activity_executor, concurrent.futures.ThreadPoolExecutor
404+
):
405+
running_activity.cancelled_event = temporalio.activity._CompositeEvent(
406+
thread_event=threading.Event(),
407+
# No async event
408+
async_event=None,
409+
)
410+
if not activity_def.no_thread_cancel_exception:
411+
running_activity.cancel_thread_raiser = _ThreadExceptionRaiser()
412+
else:
413+
sync_non_threaded = True
414+
manager = self._shared_state_manager
415+
# Pre-checked on worker init
416+
assert manager
417+
running_activity.cancelled_event = temporalio.activity._CompositeEvent(
418+
thread_event=manager.new_event(),
419+
# No async event
420+
async_event=None,
421+
)
422+
# We also must set the worker shutdown thread event to a
423+
# manager event if this is the first sync event. We don't
424+
# want to create if there never is a sync event.
425+
if not self._seen_sync_activity:
426+
self._worker_shutdown_event.thread_event = manager.new_event()
427+
# Say we've seen a sync activity
428+
self._seen_sync_activity = True
429+
else:
430+
# We have to set the async form of events
431+
running_activity.cancelled_event = temporalio.activity._CompositeEvent(
432+
thread_event=threading.Event(),
433+
async_event=asyncio.Event(),
434+
)
435+
436+
# Convert arguments. We use raw value for dynamic. Otherwise, we
437+
# only use arg type hints if they match the input count.
438+
arg_types = activity_def.arg_types
439+
if not activity_def.name:
440+
# Dynamic is just the raw value for each input value
441+
arg_types = [temporalio.common.RawValue] * len(start.input)
442+
elif arg_types is not None and len(arg_types) != len(start.input):
443+
arg_types = None
444+
try:
445+
args = (
446+
[]
447+
if not start.input
448+
else await self._data_converter.decode(
449+
start.input, type_hints=arg_types
450+
)
451+
)
452+
except Exception as err:
453+
raise temporalio.exceptions.ApplicationError(
454+
"Failed decoding arguments"
455+
) from err
456+
# Put the args inside a list if dynamic
457+
if not activity_def.name:
458+
args = [args]
459+
460+
# Convert heartbeat details
461+
# TODO(cretz): Allow some way to configure heartbeat type hinting?
462+
try:
463+
heartbeat_details = (
464+
[]
465+
if not start.heartbeat_details
466+
else await self._data_converter.decode(start.heartbeat_details)
467+
)
468+
except Exception as err:
469+
raise temporalio.exceptions.ApplicationError(
470+
"Failed decoding heartbeat details", non_retryable=True
471+
) from err
472+
473+
# Build info
474+
info = temporalio.activity.Info(
475+
activity_id=start.activity_id,
476+
activity_type=start.activity_type,
477+
attempt=start.attempt,
478+
current_attempt_scheduled_time=_proto_to_datetime(
479+
start.current_attempt_scheduled_time
480+
),
481+
heartbeat_details=heartbeat_details,
482+
heartbeat_timeout=_proto_to_non_zero_timedelta(start.heartbeat_timeout)
483+
if start.HasField("heartbeat_timeout")
484+
else None,
485+
is_local=start.is_local,
486+
schedule_to_close_timeout=_proto_to_non_zero_timedelta(
487+
start.schedule_to_close_timeout
488+
)
489+
if start.HasField("schedule_to_close_timeout")
490+
else None,
491+
scheduled_time=_proto_to_datetime(start.scheduled_time),
492+
start_to_close_timeout=_proto_to_non_zero_timedelta(
493+
start.start_to_close_timeout
494+
)
495+
if start.HasField("start_to_close_timeout")
496+
else None,
497+
started_time=_proto_to_datetime(start.started_time),
498+
task_queue=self._task_queue,
499+
task_token=completion.task_token,
500+
workflow_id=start.workflow_execution.workflow_id,
501+
workflow_namespace=start.workflow_namespace,
502+
workflow_run_id=start.workflow_execution.run_id,
503+
workflow_type=start.workflow_type,
504+
priority=temporalio.common.Priority._from_proto(start.priority),
505+
)
506+
running_activity.info = info
507+
input = ExecuteActivityInput(
508+
fn=activity_def.fn,
509+
args=args,
510+
executor=None if not running_activity.sync else self._activity_executor,
511+
headers=start.header_fields,
512+
)
513+
514+
# Set the context early so the logging adapter works and
515+
# interceptors have it
516+
temporalio.activity._Context.set(
517+
temporalio.activity._Context(
518+
info=lambda: info,
519+
heartbeat=None,
520+
cancelled_event=running_activity.cancelled_event,
521+
worker_shutdown_event=self._worker_shutdown_event,
522+
shield_thread_cancel_exception=None
523+
if not running_activity.cancel_thread_raiser
524+
else running_activity.cancel_thread_raiser.shielded,
525+
payload_converter_class_or_instance=self._data_converter.payload_converter,
526+
runtime_metric_meter=None if sync_non_threaded else self._metric_meter,
527+
)
528+
)
529+
temporalio.activity.logger.debug("Starting activity")
530+
531+
# Build the interceptors chaining in reverse. We build a context right
532+
# now even though the info() can't be intercepted and heartbeat() will
533+
# fail. The interceptors may want to use the info() during init.
534+
impl: ActivityInboundInterceptor = _ActivityInboundImpl(self, running_activity)
535+
for interceptor in reversed(list(self._interceptors)):
536+
impl = interceptor.intercept_activity(impl)
537+
# Init
538+
impl.init(_ActivityOutboundImpl(self, running_activity.info))
539+
# Exec
540+
result = await impl.execute_activity(input)
541+
# Convert result even if none. Since Python essentially only
542+
# supports single result types (even if they are tuples), we will do
543+
# the same.
544+
completion.result.completed.result.CopyFrom(
545+
(await self._data_converter.encode([result]))[0]
546+
)
547+
548548
def assert_activity_valid(self, activity) -> None:
549549
if self._dynamic_activity:
550550
return

0 commit comments

Comments
 (0)