Skip to content

Commit bda9b1d

Browse files
authored
Add distributed lock for scheduled task (#732)
* Add distributed lock for scheduled task * Add the task to extend lock * Fix the close
1 parent e0a106e commit bda9b1d

File tree

3 files changed

+122
-21
lines changed

3 files changed

+122
-21
lines changed

backend/app/task/celery.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ def init_celery() -> celery.Celery:
4343
'group': OVERWRITE_CELERY_RESULT_GROUP_TABLE_NAME,
4444
},
4545
result_extended=True,
46-
# result_expires=0, # 任务结果自动清理,0 或 None 表示不清理
46+
# result_expires=0, # 清理任务结果,默认每天凌晨 4 点,0 或 None 表示不清理
47+
# beat_sync_every=1, # 保存任务状态周期,默认 3 * 60 秒
4748
beat_schedule=LOCAL_BEAT_SCHEDULE,
4849
beat_scheduler='backend.app.task.utils.schedulers:DatabaseScheduler',
4950
task_cls='backend.app.task.tasks.base:TaskBase',

backend/app/task/utils/schedulers.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99

1010
from celery import current_app, schedules
1111
from celery.beat import ScheduleEntry, Scheduler
12+
from celery.signals import beat_init
1213
from celery.utils.log import get_logger
14+
from redis.asyncio.lock import Lock
1315
from sqlalchemy import select
1416
from sqlalchemy.exc import DatabaseError, InterfaceError
1517

@@ -28,6 +30,38 @@
2830
# 此计划程序必须比常规的 5 分钟更频繁地唤醒,因为它需要考虑对计划的外部更改
2931
DEFAULT_MAX_INTERVAL = 5 # seconds
3032

33+
# 计划锁时长,避免重复创建
34+
DEFAULT_MAX_LOCK_TIMEOUT = 300 # seconds
35+
36+
# 锁检测周期,应小于计划锁时长
37+
DEFAULT_LOCK_INTERVAL = 60 # seconds
38+
39+
# Copied from:
40+
# https://github.com/andymccurdy/redis-py/blob/master/redis/lock.py#L33
41+
# Changes:
42+
# The second line from the bottom: The original Lua script intends
43+
# to extend time to (lock remaining time + additional time); while
44+
# the script here extends the time to an expected expiration time.
45+
# KEYS[1] - lock name
46+
# ARGV[1] - token
47+
# ARGV[2] - new lock lifetime in milliseconds (passed to PEXPIRE as-is)
48+
# return 1 if the locks time was extended, otherwise 0
49+
LUA_EXTEND_TO_SCRIPT = """
50+
local token = redis.call('get', KEYS[1])
51+
if not token or token ~= ARGV[1] then
52+
return 0
53+
end
54+
local expiration = redis.call('pttl', KEYS[1])
55+
if not expiration then
56+
expiration = 0
57+
end
58+
if expiration < 0 then
59+
return 0
60+
end
61+
redis.call('pexpire', KEYS[1], ARGV[2])
62+
return 1
63+
"""
64+
3165
logger = get_logger('fba.schedulers')
3266

3367

@@ -260,13 +294,18 @@ def _unpack_options(
260294

261295

262296
class DatabaseScheduler(Scheduler):
297+
"""数据库调度程序"""
298+
263299
Entry = ModelEntry
264300

265301
_schedule = None
266302
_last_update = None
267303
_initial_read = True
268304
_heap_invalidated = False
269305

306+
lock: Lock | None = None
307+
lock_key = f'{settings.CELERY_REDIS_PREFIX}:beat_lock'
308+
270309
def __init__(self, *args, **kwargs):
271310
self.app = kwargs['app']
272311
self._dirty = set()
@@ -315,6 +354,16 @@ def reserve(self, entry):
315354
self._dirty.add(new_entry.name)
316355
return new_entry
317356

357+
def close(self):
    """
    Release the distributed beat lock (if held) and then close the scheduler.

    The lock is released only when this instance still owns it. The parent
    ``close()`` is invoked from a ``finally`` clause so that an error raised
    while checking or releasing the lock cannot prevent the parent's
    shutdown logic from running; such an error still propagates afterwards.

    :return:
    """
    # Detach the lock first so a failing release is never retried on a
    # half-closed scheduler.
    lock, self.lock = self.lock, None
    try:
        if lock:
            logger.info('beat: Releasing lock')
            if run_await(lock.owned)():
                run_await(lock.release)()
    finally:
        super().close()
366+
318367
def sync(self):
319368
"""重写父函数"""
320369
_tried = set()
@@ -401,3 +450,48 @@ def schedule(self) -> dict[str, ModelEntry]:
401450

402451
# logger.debug(self._schedule)
403452
return self._schedule
453+
454+
455+
async def extend_scheduler_lock(lock):
    """
    Periodically extend the scheduler lock.

    Every ``DEFAULT_LOCK_INTERVAL`` seconds ``lock.extend`` is awaited with
    ``DEFAULT_MAX_LOCK_TIMEOUT``. A failed extension is logged and retried on
    the next cycle; the loop itself never terminates on its own.

    :param lock: scheduler lock; when falsy there is nothing to extend and
        the coroutine returns immediately
    :return:
    """
    if not lock:
        # The argument never changes, so looping would only leave a task
        # sleeping forever without doing any work.
        return

    while True:
        await asyncio.sleep(DEFAULT_LOCK_INTERVAL)
        try:
            await lock.extend(DEFAULT_MAX_LOCK_TIMEOUT)
        except Exception as e:
            logger.error(f'Failed to extend lock: {e}')
469+
470+
471+
@beat_init.connect
def acquire_distributed_beat_lock(sender=None, *args, **kwargs):
    """
    Try to acquire the distributed beat lock at startup.

    Connected to the ``beat_init`` signal. Creates a Redis lock named by
    ``scheduler.lock_key``, blocks until it is acquired, stores it on the
    scheduler and schedules the task that keeps extending it.

    :param sender: the signal sender; must expose a ``scheduler`` attribute
    :return:
    """
    scheduler = sender.scheduler
    if not scheduler.lock_key:
        return

    logger.debug('beat: Acquiring lock...')
    lock = redis_client.lock(
        scheduler.lock_key,
        timeout=DEFAULT_MAX_LOCK_TIMEOUT,
        sleep=scheduler.max_interval,
    )
    # Overwrite redis-py's extend script, which adds the extra time to the
    # remaining TTL instead of setting a new expiration time.
    lock.lua_extend = redis_client.register_script(LUA_EXTEND_TO_SCRIPT)
    run_await(lock.acquire)()
    logger.info('beat: Acquired lock')
    scheduler.lock = lock

    loop = asyncio.get_event_loop()
    # Keep a strong reference to the task: the event loop only holds weak
    # references, so an unreferenced task may be garbage-collected mid-run.
    scheduler._lock_extend_task = loop.create_task(extend_scheduler_lock(scheduler.lock))

backend/utils/_await.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
1-
#!/usr/bin/env python3
2-
# -*- coding: utf-8 -*-
31
import asyncio
42
import atexit
53
import threading
64
import weakref
75

8-
from typing import Awaitable, Callable, TypeVar
6+
from functools import wraps
7+
from typing import Any, Awaitable, Callable, Coroutine, TypeVar
98

109
T = TypeVar('T')
1110

1211

1312
class _TaskRunner:
14-
"""A task runner that runs an asyncio event loop on a background thread."""
13+
"""在后台线程上运行 asyncio 事件循环的任务运行器"""
1514

1615
def __init__(self):
1716
self.__loop: asyncio.AbstractEventLoop | None = None
@@ -20,48 +19,55 @@ def __init__(self):
2019
atexit.register(self.close)
2120

2221
def close(self):
    """
    Stop the background event loop, wait for its thread and deregister.

    Safe to call more than once; later calls find nothing left to stop.
    """
    loop, thread = self.__loop, self.__thread

    if loop is not None:
        try:
            # loop.stop() is not thread-safe; schedule it on the loop's own
            # thread so the selector is woken up and run_forever() returns.
            loop.call_soon_threadsafe(loop.stop)
        except RuntimeError:
            # The loop has already been closed by its thread.
            pass

    # Joining the current thread would raise, so skip it in that case.
    if thread is not None and thread is not threading.current_thread():
        thread.join()

    # Clear the references only after the thread has finished, so the
    # thread target can still close the loop it was running.
    self.__loop = None
    self.__thread = None

    # Remove this runner's own registry entry; close() may be running on a
    # different thread than the one whose ident was used as the key.
    for key, runner in list(_runner_map.items()):
        if runner is self:
            _runner_map.pop(key, None)
2631

2732
def _target(self):
    """Target of the background thread: run the event loop until stopped, then close it."""
    # Bind the loop locally: close() resets self.__loop to None, possibly
    # before this thread reaches the finally clause below.
    loop = self.__loop
    try:
        loop.run_forever()
    finally:
        loop.close()
3438

35-
def run(self, coro: Awaitable[T]) -> T:
    """
    Run a coroutine on the background event loop and return its result.

    The loop and its daemon thread are created on first use. The calling
    thread blocks until the coroutine finishes; an exception raised by the
    coroutine is re-raised here.

    :param coro: the awaitable to execute on the background loop
    :return: the awaitable's result
    """
    with self.__lock:
        if self.__loop is None:
            self.__loop = asyncio.new_event_loop()
            self.__thread = threading.Thread(
                target=self._target,
                daemon=True,
                name=f'TaskRunner-{threading.get_ident()}',
            )
            self.__thread.start()
        future = asyncio.run_coroutine_threadsafe(coro, self.__loop)
    # Wait outside the mutex: only loop creation and submission need it.
    return future.result()
4549

4650

4751
_runner_map = weakref.WeakValueDictionary()
4852

4953

50-
def run_await(coro: Callable[..., Awaitable[T]]) -> Callable[..., T]:
51-
"""将协程包装在一个函数中,该函数会阻塞,直到它执行完为止"""
54+
def run_await(coro: Callable[..., Awaitable[T]] | Callable[..., Coroutine[Any, Any, T]]) -> Callable[..., T]:
55+
"""将协程包装在函数中,该函数将在后台事件循环上运行,直到它执行完为止"""
5256

57+
@wraps(coro)
5358
def wrapped(*args, **kwargs):
54-
name = threading.current_thread().name
5559
inner = coro(*args, **kwargs)
60+
if not asyncio.iscoroutine(inner) and not asyncio.isfuture(inner):
61+
raise TypeError(f'Expected coroutine, got {type(inner)}')
5662
try:
57-
# 如果当前此线程中正在运行循环
58-
# 使用任务运行程序
63+
# 如果事件循环正在运行,则使用任务调用
5964
asyncio.get_running_loop()
65+
name = f'TaskRunner-{threading.get_ident()}'
6066
if name not in _runner_map:
6167
_runner_map[name] = _TaskRunner()
6268
return _runner_map[name].run(inner)
6369
except RuntimeError:
64-
# 如果没有,请创建一个新的事件循环
70+
# 如果没有,则创建一个新的事件循环
6571
loop = asyncio.get_event_loop()
6672
return loop.run_until_complete(inner)
6773

0 commit comments

Comments
 (0)