Skip to content

Commit d72cad3

Browse files
gianbotdvdria
andauthored
Fix atomic transaction not routing to the the correct DB in DatabaseBackend.on_chord_part_return transaction.atomic (#427)
* using ChordCounter.objects.db in DatabaseBackend.on_chord_part_return transaction.atomic * WIP testing on chord part return with multiple databases * WIP testing on chord part return with multiple databases pre-committed * Completed testing on chord part return with multiple databases * Changed tests and transaction atomic using according to pull request message * Removed ports from settings --------- Co-authored-by: Davide Ria <d.ria@frankhood.it>
1 parent 23265e6 commit d72cad3

File tree

3 files changed

+103
-2
lines changed

3 files changed

+103
-2
lines changed

django_celery_results/backends/database.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from celery.result import GroupResult, allow_join_result, result_from_tuple
88
from celery.utils.log import get_logger
99
from celery.utils.serialization import b64decode, b64encode
10-
from django.db import connection, transaction
10+
from django.db import connection, router, transaction
1111
from django.db.utils import InterfaceError
1212
from kombu.exceptions import DecodeError
1313

@@ -246,7 +246,7 @@ def on_chord_part_return(self, request, state, result, **kwargs):
246246
if not gid or not tid:
247247
return
248248
call_callback = False
249-
with transaction.atomic():
249+
with transaction.atomic(using=router.db_for_write(ChordCounter)):
250250
# We need to know if `count` hits 0.
251251
# wrap the update in a transaction
252252
# with a `select_for_update` lock to prevent race conditions.

t/proj/settings.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,20 @@
4949
'MIRROR': 'default',
5050
},
5151
},
52+
'read-only': {
53+
'ENGINE': 'django.db.backends.postgresql',
54+
'HOST': 'localhost',
55+
'NAME': 'read-only-database',
56+
'USER': 'postgres',
57+
'PASSWORD': 'postgres',
58+
'OPTIONS': {
59+
'connect_timeout': 1000,
60+
'options': '-c default_transaction_read_only=on',
61+
},
62+
'TEST': {
63+
'MIRROR': 'default',
64+
},
65+
}
5266
}
5367
except ImportError:
5468
DATABASES = {
@@ -66,6 +80,13 @@
6680
'timeout': 1000,
6781
}
6882
},
83+
'read-only': {
84+
'ENGINE': 'django.db.backends.sqlite3',
85+
'NAME': os.path.join(BASE_DIR, 'db.sqlite3'),
86+
'OPTIONS': {
87+
'timeout': 1000,
88+
}
89+
}
6990
}
7091

7192
# Quick-start development settings - unsuitable for production

t/unit/backends/test_database.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from celery.utils.serialization import b64decode
1212
from celery.worker.request import Request
1313
from celery.worker.strategy import hybrid_to_proto2
14+
from django.test import TransactionTestCase
1415

1516
from django_celery_results.backends.database import DatabaseBackend
1617
from django_celery_results.models import ChordCounter, TaskResult
@@ -919,3 +920,82 @@ def test_backend_result_extended_is_false(self):
919920
tr = TaskResult.objects.get(task_id=tid2)
920921
assert tr.task_args is None
921922
assert tr.task_kwargs is None
923+
924+
925+
class DjangoCeleryResultRouter:
926+
route_app_labels = {"django_celery_results"}
927+
928+
def db_for_read(self, model, **hints):
929+
"""Route read access to the read-only database"""
930+
if model._meta.app_label in self.route_app_labels:
931+
return "read-only"
932+
return None
933+
934+
def db_for_write(self, model, **hints):
935+
"""Route write access to the default database"""
936+
if model._meta.app_label in self.route_app_labels:
937+
return "default"
938+
return None
939+
940+
941+
class ChordPartReturnTestCase(TransactionTestCase):
942+
databases = {"default", "read-only"}
943+
944+
def setUp(self):
945+
super().setUp()
946+
self.app.conf.result_serializer = 'json'
947+
self.app.conf.result_backend = (
948+
'django_celery_results.backends:DatabaseBackend')
949+
self.app.conf.result_extended = True
950+
self.b = DatabaseBackend(app=self.app)
951+
952+
def test_on_chord_part_return_multiple_databases(self):
953+
"""
954+
Test if the ChordCounter is properly decremented and the callback is
955+
triggered after all chord parts have returned with multiple databases
956+
"""
957+
with self.settings(DATABASE_ROUTERS=[DjangoCeleryResultRouter()]):
958+
gid = uuid()
959+
tid1 = uuid()
960+
tid2 = uuid()
961+
subtasks = [AsyncResult(tid1), AsyncResult(tid2)]
962+
group = GroupResult(id=gid, results=subtasks)
963+
964+
assert ChordCounter.objects.count() == 0
965+
assert ChordCounter.objects.using("read-only").count() == 0
966+
assert ChordCounter.objects.using("default").count() == 0
967+
968+
self.b.apply_chord(group, self.add.s())
969+
970+
# Check if the ChordCounter was created in the correct database
971+
assert ChordCounter.objects.count() == 1
972+
assert ChordCounter.objects.using("read-only").count() == 1
973+
assert ChordCounter.objects.using("default").count() == 1
974+
975+
chord_counter = ChordCounter.objects.get(group_id=gid)
976+
assert chord_counter.count == 2
977+
978+
request = mock.MagicMock()
979+
request.id = subtasks[0].id
980+
request.group = gid
981+
request.task = "my_task"
982+
request.args = ["a", 1, "password"]
983+
request.kwargs = {"c": 3, "d": "e", "password": "password"}
984+
request.argsrepr = "argsrepr"
985+
request.kwargsrepr = "kwargsrepr"
986+
request.hostname = "celery@ip-0-0-0-0"
987+
request.properties = {"periodic_task_name": "my_periodic_task"}
988+
request.ignore_result = False
989+
result = {"foo": "baz"}
990+
991+
self.b.mark_as_done(tid1, result, request=request)
992+
993+
chord_counter.refresh_from_db()
994+
assert chord_counter.count == 1
995+
996+
self.b.mark_as_done(tid2, result, request=request)
997+
998+
with pytest.raises(ChordCounter.DoesNotExist):
999+
ChordCounter.objects.get(group_id=gid)
1000+
1001+
request.chord.delay.assert_called_once()

0 commit comments

Comments
 (0)