Skip to content

Commit fc528ee

Browse files
authored
Move to proper async heartbeat queuing (#14)
Fixes #12
1 parent c1efe42 commit fc528ee

File tree

2 files changed

+48
-54
lines changed

2 files changed

+48
-54
lines changed

temporalio/worker.py

Lines changed: 44 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -370,8 +370,9 @@ async def _run_activities(self) -> None:
370370

371371
if task.HasField("start"):
372372
# Cancelled event and sync field will be updated inside
373-
# _run_activity when the activity function is obtained
374-
activity = _RunningActivity()
373+
# _run_activity when the activity function is obtained. Max
374+
# size of 1000 should be plenty for the heartbeat queue.
375+
activity = _RunningActivity(pending_heartbeats=asyncio.Queue(1000))
375376
activity.task = asyncio.create_task(
376377
self._run_activity(task.task_token, task.start, activity)
377378
)
@@ -409,22 +410,27 @@ def _heartbeat_activity(self, task_token: bytes, *details: Any) -> None:
409410
logger = temporalio.activity.logger
410411
activity = self._running_activities.get(task_token)
411412
if activity and not activity.done:
412-
# Just set as next pending if one is already running
413-
coro = self._heartbeat_activity_async(
414-
logger, activity, task_token, *details
413+
# Put on queue and schedule a task. We deliberately let the queue-full
414+
# error (asyncio.QueueFull) propagate to the caller here
415+
activity.pending_heartbeats.put_nowait(details)
416+
activity.last_heartbeat_task = asyncio.create_task(
417+
self._heartbeat_activity_async(logger, activity, task_token)
415418
)
416-
if activity.current_heartbeat_task:
417-
activity.pending_heartbeat = coro
418-
else:
419-
activity.current_heartbeat_task = asyncio.create_task(coro)
420419

421420
async def _heartbeat_activity_async(
422421
self,
423422
logger: logging.LoggerAdapter,
424423
activity: _RunningActivity,
425424
task_token: bytes,
426-
*details: Any,
427425
) -> None:
426+
# Drain the queue, only taking the last value to actually heartbeat
427+
details: Optional[Iterable[Any]] = None
428+
while not activity.pending_heartbeats.empty():
429+
details = activity.pending_heartbeats.get_nowait()
430+
if details is None:
431+
return
432+
433+
# Perform the heartbeat
428434
try:
429435
heartbeat = temporalio.bridge.proto.ActivityHeartbeat(task_token=task_token)
430436
if details:
@@ -437,16 +443,7 @@ async def _heartbeat_activity_async(
437443
)
438444
logger.debug("Recording heartbeat with details %s", details)
439445
self._bridge_worker.record_activity_heartbeat(heartbeat)
440-
# If there is one pending, schedule it
441-
if activity.pending_heartbeat:
442-
activity.current_heartbeat_task = asyncio.create_task(
443-
activity.pending_heartbeat
444-
)
445-
activity.pending_heartbeat = None
446-
else:
447-
activity.current_heartbeat_task = None
448446
except Exception as err:
449-
activity.current_heartbeat_task = None
450447
# If the activity is done, nothing we can do but log
451448
if activity.done:
452449
logger.exception(
@@ -696,12 +693,12 @@ async def _run_activity(
696693

697694
# Do final completion
698695
try:
699-
# We mark the activity as done and let the currently running (and next
700-
# pending) heartbeat task finish
696+
# We mark the activity as done and let the currently running
697+
# heartbeat task finish
701698
running_activity.done = True
702-
while running_activity.current_heartbeat_task:
699+
if running_activity.last_heartbeat_task:
703700
try:
704-
await running_activity.current_heartbeat_task
701+
await running_activity.last_heartbeat_task
705702
except:
706703
# Should never happen because it's trapped in-task
707704
temporalio.activity.logger.exception(
@@ -749,12 +746,12 @@ class _ActivityDefinition:
749746

750747
@dataclass
751748
class _RunningActivity:
749+
pending_heartbeats: asyncio.Queue[Iterable[Any]]
752750
# Most of these optional values are set before use
753751
info: Optional[temporalio.activity.Info] = None
754752
task: Optional[asyncio.Task] = None
755753
cancelled_event: Optional[temporalio.activity._CompositeEvent] = None
756-
pending_heartbeat: Optional[Coroutine] = None
757-
current_heartbeat_task: Optional[asyncio.Task] = None
754+
last_heartbeat_task: Optional[asyncio.Task] = None
758755
sync: bool = False
759756
done: bool = False
760757
cancelled_by_request: bool = False
@@ -895,19 +892,16 @@ async def execute_activity(self, input: ExecuteActivityInput) -> Any:
895892
# loop (even though it's sync). So we need a call that puts the
896893
# context back on the activity and calls heartbeat, then another
897894
# call schedules it.
898-
def heartbeat_with_context(*details: Any) -> None:
895+
async def heartbeat_with_context(*details: Any) -> None:
899896
temporalio.activity._Context.set(ctx)
900897
assert orig_heartbeat
901898
orig_heartbeat(*details)
902899

903-
def thread_safe_heartbeat(*details: Any) -> None:
904-
# TODO(cretz): Final heartbeat can be flaky if we don't wait on
905-
# result here, but waiting on result of
906-
# asyncio.run_coroutine_threadsafe times out in rare cases.
907-
# Need more investigation: https://github.com/temporalio/sdk-python/issues/12
908-
loop.call_soon_threadsafe(heartbeat_with_context, *details)
909-
910-
ctx.heartbeat = thread_safe_heartbeat
900+
# Invoke the async heartbeat on the event loop, blocking the calling
901+
# thread for at most 10 seconds until it has been accepted
902+
ctx.heartbeat = lambda *details: asyncio.run_coroutine_threadsafe(
903+
heartbeat_with_context(*details), loop
904+
).result(10)
911905

912906
# For heartbeats, we use the existing heartbeat callable for thread
913907
# pool executors or a multiprocessing queue for others
@@ -917,7 +911,7 @@ def thread_safe_heartbeat(*details: Any) -> None:
917911
# Should always be present in worker, pre-checked on init
918912
shared_manager = input._worker._config["shared_state_manager"]
919913
assert shared_manager
920-
heartbeat = shared_manager.register_heartbeater(
914+
heartbeat = await shared_manager.register_heartbeater(
921915
info.task_token, ctx.heartbeat
922916
)
923917

@@ -935,7 +929,7 @@ def thread_safe_heartbeat(*details: Any) -> None:
935929
)
936930
finally:
937931
if shared_manager:
938-
shared_manager.unregister_heartbeater(info.task_token)
932+
await shared_manager.unregister_heartbeater(info.task_token)
939933

940934
# Otherwise for async activity, just run
941935
return await input.fn(*input.args)
@@ -1032,7 +1026,7 @@ def new_event(self) -> threading.Event:
10321026
raise NotImplementedError
10331027

10341028
@abstractmethod
1035-
def register_heartbeater(
1029+
async def register_heartbeater(
10361030
self, task_token: bytes, heartbeat: Callable[..., None]
10371031
) -> SharedHeartbeatSender:
10381032
"""Register a heartbeat function.
@@ -1048,7 +1042,7 @@ def register_heartbeater(
10481042
raise NotImplementedError
10491043

10501044
@abstractmethod
1051-
def unregister_heartbeater(self, task_token: bytes) -> None:
1045+
async def unregister_heartbeater(self, task_token: bytes) -> None:
10521046
"""Unregisters a previously registered heartbeater for the task
10531047
token. This should also flush any pending heartbeats.
10541048
"""
@@ -1084,12 +1078,12 @@ def __init__(
10841078
1000
10851079
)
10861080
self._heartbeats: Dict[bytes, Callable[..., None]] = {}
1087-
self._heartbeat_completions: Dict[bytes, Callable[[], None]] = {}
1081+
self._heartbeat_completions: Dict[bytes, Callable] = {}
10881082

10891083
def new_event(self) -> threading.Event:
10901084
return self._mgr.Event()
10911085

1092-
def register_heartbeater(
1086+
async def register_heartbeater(
10931087
self, task_token: bytes, heartbeat: Callable[..., None]
10941088
) -> SharedHeartbeatSender:
10951089
self._heartbeats[task_token] = heartbeat
@@ -1098,17 +1092,19 @@ def register_heartbeater(
10981092
self._queue_poller_executor.submit(self._heartbeat_processor)
10991093
return _MultiprocessingSharedHeartbeatSender(self._heartbeat_queue)
11001094

1101-
def unregister_heartbeater(self, task_token: bytes) -> None:
1102-
# Put a completion on the queue and wait for it to happen
1103-
flush_complete = threading.Event()
1104-
self._heartbeat_completions[task_token] = flush_complete.set
1095+
async def unregister_heartbeater(self, task_token: bytes) -> None:
1096+
# Register a completion callback, put a completion marker on the queue, and wait for the callback to fire
1097+
loop = asyncio.get_running_loop()
1098+
finish_event = asyncio.Event()
1099+
self._heartbeat_completions[task_token] = lambda: loop.call_soon_threadsafe(
1100+
finish_event.set
1101+
)
11051102
try:
1106-
# 30 seconds to put complete, 30 to get notified should be plenty
1103+
# We only give the queue up to 5 seconds to have enough room
11071104
self._heartbeat_queue.put(
1108-
(task_token, _multiprocess_heartbeat_complete), True, 30
1105+
(task_token, _multiprocess_heartbeat_complete), True, 5
11091106
)
1110-
if not flush_complete.wait(30):
1111-
raise RuntimeError("Timeout waiting for heartbeat flush")
1107+
await finish_event.wait()
11121108
finally:
11131109
del self._heartbeat_completions[task_token]
11141110

tests/test_worker.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,8 @@ async def test_sync_activity_thread_cancel(
217217
):
218218
def wait_cancel() -> str:
219219
while not temporalio.activity.is_cancelled():
220-
temporalio.activity.heartbeat()
221220
time.sleep(1)
221+
temporalio.activity.heartbeat()
222222
return "Cancelled"
223223

224224
with concurrent.futures.ThreadPoolExecutor() as executor:
@@ -228,16 +228,16 @@ def wait_cancel() -> str:
228228
wait_cancel,
229229
cancel_after_ms=100,
230230
wait_for_cancellation=True,
231-
heartbeat_timeout_ms=30000,
231+
heartbeat_timeout_ms=3000,
232232
worker_config={"activity_executor": executor},
233233
)
234234
assert result.result == "Cancelled"
235235

236236

237237
def picklable_activity_wait_cancel() -> str:
238238
while not temporalio.activity.is_cancelled():
239-
temporalio.activity.heartbeat()
240239
time.sleep(1)
240+
temporalio.activity.heartbeat()
241241
return "Cancelled"
242242

243243

@@ -251,7 +251,7 @@ async def test_sync_activity_process_cancel(
251251
picklable_activity_wait_cancel,
252252
cancel_after_ms=100,
253253
wait_for_cancellation=True,
254-
heartbeat_timeout_ms=30000,
254+
heartbeat_timeout_ms=3000,
255255
worker_config={"activity_executor": executor},
256256
)
257257
assert result.result == "Cancelled"
@@ -430,8 +430,6 @@ def picklable_heartbeat_details_activity() -> str:
430430
some_list.append(f"attempt: {info.attempt}")
431431
temporalio.activity.logger.debug("Heartbeating with value: %s", some_list)
432432
temporalio.activity.heartbeat(some_list)
433-
# TODO(cretz): Remove when we fix multiprocess heartbeats
434-
time.sleep(1)
435433
if len(some_list) < 2:
436434
raise RuntimeError(f"Try again, list contains: {some_list}")
437435
return ", ".join(some_list)

0 commit comments

Comments
 (0)