@@ -370,8 +370,9 @@ async def _run_activities(self) -> None:
370
370
371
371
if task .HasField ("start" ):
372
372
# Cancelled event and sync field will be updated inside
373
- # _run_activity when the activity function is obtained
374
- activity = _RunningActivity ()
373
+ # _run_activity when the activity function is obtained. Max
374
+ # size of 1000 should be plenty for the heartbeat queue.
375
+ activity = _RunningActivity (pending_heartbeats = asyncio .Queue (1000 ))
375
376
activity .task = asyncio .create_task (
376
377
self ._run_activity (task .task_token , task .start , activity )
377
378
)
@@ -409,22 +410,27 @@ def _heartbeat_activity(self, task_token: bytes, *details: Any) -> None:
409
410
logger = temporalio .activity .logger
410
411
activity = self ._running_activities .get (task_token )
411
412
if activity and not activity .done :
412
- # Just set as next pending if one is already running
413
- coro = self ._heartbeat_activity_async (
414
- logger , activity , task_token , * details
413
+ # Put on queue and schedule a task. We will let the queue-full error
414
+ # be thrown here
415
+ activity .pending_heartbeats .put_nowait (details )
416
+ activity .last_heartbeat_task = asyncio .create_task (
417
+ self ._heartbeat_activity_async (logger , activity , task_token )
415
418
)
416
- if activity .current_heartbeat_task :
417
- activity .pending_heartbeat = coro
418
- else :
419
- activity .current_heartbeat_task = asyncio .create_task (coro )
420
419
421
420
async def _heartbeat_activity_async (
422
421
self ,
423
422
logger : logging .LoggerAdapter ,
424
423
activity : _RunningActivity ,
425
424
task_token : bytes ,
426
- * details : Any ,
427
425
) -> None :
426
+ # Drain the queue, only taking the last value to actually heartbeat
427
+ details : Optional [Iterable [Any ]] = None
428
+ while not activity .pending_heartbeats .empty ():
429
+ details = activity .pending_heartbeats .get_nowait ()
430
+ if details is None :
431
+ return
432
+
433
+ # Perform the heartbeat
428
434
try :
429
435
heartbeat = temporalio .bridge .proto .ActivityHeartbeat (task_token = task_token )
430
436
if details :
@@ -437,16 +443,7 @@ async def _heartbeat_activity_async(
437
443
)
438
444
logger .debug ("Recording heartbeat with details %s" , details )
439
445
self ._bridge_worker .record_activity_heartbeat (heartbeat )
440
- # If there is one pending, schedule it
441
- if activity .pending_heartbeat :
442
- activity .current_heartbeat_task = asyncio .create_task (
443
- activity .pending_heartbeat
444
- )
445
- activity .pending_heartbeat = None
446
- else :
447
- activity .current_heartbeat_task = None
448
446
except Exception as err :
449
- activity .current_heartbeat_task = None
450
447
# If the activity is done, nothing we can do but log
451
448
if activity .done :
452
449
logger .exception (
@@ -696,12 +693,12 @@ async def _run_activity(
696
693
697
694
# Do final completion
698
695
try :
699
- # We mark the activity as done and let the currently running (and next
700
- # pending) heartbeat task finish
696
+ # We mark the activity as done and let the currently running
697
+ # heartbeat task finish
701
698
running_activity .done = True
702
- while running_activity .current_heartbeat_task :
699
+ if running_activity .last_heartbeat_task :
703
700
try :
704
- await running_activity .current_heartbeat_task
701
+ await running_activity .last_heartbeat_task
705
702
except :
706
703
# Should never happen because it's trapped in-task
707
704
temporalio .activity .logger .exception (
@@ -749,12 +746,12 @@ class _ActivityDefinition:
749
746
750
747
@dataclass
751
748
class _RunningActivity :
749
+ pending_heartbeats : asyncio .Queue [Iterable [Any ]]
752
750
# Most of these optional values are set before use
753
751
info : Optional [temporalio .activity .Info ] = None
754
752
task : Optional [asyncio .Task ] = None
755
753
cancelled_event : Optional [temporalio .activity ._CompositeEvent ] = None
756
- pending_heartbeat : Optional [Coroutine ] = None
757
- current_heartbeat_task : Optional [asyncio .Task ] = None
754
+ last_heartbeat_task : Optional [asyncio .Task ] = None
758
755
sync : bool = False
759
756
done : bool = False
760
757
cancelled_by_request : bool = False
@@ -895,19 +892,16 @@ async def execute_activity(self, input: ExecuteActivityInput) -> Any:
895
892
# loop (even though it's sync). So we need a call that puts the
896
893
# context back on the activity and calls heartbeat, then another
897
894
# call schedules it.
898
- def heartbeat_with_context (* details : Any ) -> None :
895
+ async def heartbeat_with_context (* details : Any ) -> None :
899
896
temporalio .activity ._Context .set (ctx )
900
897
assert orig_heartbeat
901
898
orig_heartbeat (* details )
902
899
903
- def thread_safe_heartbeat (* details : Any ) -> None :
904
- # TODO(cretz): Final heartbeat can be flaky if we don't wait on
905
- # result here, but waiting on result of
906
- # asyncio.run_coroutine_threadsafe times out in rare cases.
907
- # Need more investigation: https://github.com/temporalio/sdk-python/issues/12
908
- loop .call_soon_threadsafe (heartbeat_with_context , * details )
909
-
910
- ctx .heartbeat = thread_safe_heartbeat
900
+ # Invoke the async heartbeat waiting a max of 10 seconds for
901
+ # accepting
902
+ ctx .heartbeat = lambda * details : asyncio .run_coroutine_threadsafe (
903
+ heartbeat_with_context (* details ), loop
904
+ ).result (10 )
911
905
912
906
# For heartbeats, we use the existing heartbeat callable for thread
913
907
# pool executors or a multiprocessing queue for others
@@ -917,7 +911,7 @@ def thread_safe_heartbeat(*details: Any) -> None:
917
911
# Should always be present in worker, pre-checked on init
918
912
shared_manager = input ._worker ._config ["shared_state_manager" ]
919
913
assert shared_manager
920
- heartbeat = shared_manager .register_heartbeater (
914
+ heartbeat = await shared_manager .register_heartbeater (
921
915
info .task_token , ctx .heartbeat
922
916
)
923
917
@@ -935,7 +929,7 @@ def thread_safe_heartbeat(*details: Any) -> None:
935
929
)
936
930
finally :
937
931
if shared_manager :
938
- shared_manager .unregister_heartbeater (info .task_token )
932
+ await shared_manager .unregister_heartbeater (info .task_token )
939
933
940
934
# Otherwise for async activity, just run
941
935
return await input .fn (* input .args )
@@ -1032,7 +1026,7 @@ def new_event(self) -> threading.Event:
1032
1026
raise NotImplementedError
1033
1027
1034
1028
@abstractmethod
1035
- def register_heartbeater (
1029
+ async def register_heartbeater (
1036
1030
self , task_token : bytes , heartbeat : Callable [..., None ]
1037
1031
) -> SharedHeartbeatSender :
1038
1032
"""Register a heartbeat function.
@@ -1048,7 +1042,7 @@ def register_heartbeater(
1048
1042
raise NotImplementedError
1049
1043
1050
1044
@abstractmethod
1051
- def unregister_heartbeater (self , task_token : bytes ) -> None :
1045
+ async def unregister_heartbeater (self , task_token : bytes ) -> None :
1052
1046
"""Unregisters a previously registered heartbeater for the task
1053
1047
token. This should also flush any pending heartbeats.
1054
1048
"""
@@ -1084,12 +1078,12 @@ def __init__(
1084
1078
1000
1085
1079
)
1086
1080
self ._heartbeats : Dict [bytes , Callable [..., None ]] = {}
1087
- self ._heartbeat_completions : Dict [bytes , Callable [[], None ] ] = {}
1081
+ self ._heartbeat_completions : Dict [bytes , Callable ] = {}
1088
1082
1089
1083
def new_event (self ) -> threading .Event :
1090
1084
return self ._mgr .Event ()
1091
1085
1092
- def register_heartbeater (
1086
+ async def register_heartbeater (
1093
1087
self , task_token : bytes , heartbeat : Callable [..., None ]
1094
1088
) -> SharedHeartbeatSender :
1095
1089
self ._heartbeats [task_token ] = heartbeat
@@ -1098,17 +1092,19 @@ def register_heartbeater(
1098
1092
self ._queue_poller_executor .submit (self ._heartbeat_processor )
1099
1093
return _MultiprocessingSharedHeartbeatSender (self ._heartbeat_queue )
1100
1094
1101
- def unregister_heartbeater (self , task_token : bytes ) -> None :
1102
- # Put a completion on the queue and wait for it to happen
1103
- flush_complete = threading .Event ()
1104
- self ._heartbeat_completions [task_token ] = flush_complete .set
1095
+ async def unregister_heartbeater (self , task_token : bytes ) -> None :
1096
+ # Put a callback on the queue and wait for it to happen
1097
+ loop = asyncio .get_running_loop ()
1098
+ finish_event = asyncio .Event ()
1099
+ self ._heartbeat_completions [task_token ] = lambda : loop .call_soon_threadsafe (
1100
+ finish_event .set
1101
+ )
1105
1102
try :
1106
- # 30 seconds to put complete, 30 to get notified should be plenty
1103
+ # We only give the queue a few seconds to have enough room
1107
1104
self ._heartbeat_queue .put (
1108
- (task_token , _multiprocess_heartbeat_complete ), True , 30
1105
+ (task_token , _multiprocess_heartbeat_complete ), True , 5
1109
1106
)
1110
- if not flush_complete .wait (30 ):
1111
- raise RuntimeError ("Timeout waiting for heartbeat flush" )
1107
+ await finish_event .wait ()
1112
1108
finally :
1113
1109
del self ._heartbeat_completions [task_token ]
1114
1110
0 commit comments