|
11 | 11 | from celery.utils.serialization import b64decode
|
12 | 12 | from celery.worker.request import Request
|
13 | 13 | from celery.worker.strategy import hybrid_to_proto2
|
| 14 | +from django.test import TransactionTestCase |
14 | 15 |
|
15 | 16 | from django_celery_results.backends.database import DatabaseBackend
|
16 | 17 | from django_celery_results.models import ChordCounter, TaskResult
|
@@ -919,3 +920,82 @@ def test_backend_result_extended_is_false(self):
|
919 | 920 | tr = TaskResult.objects.get(task_id=tid2)
|
920 | 921 | assert tr.task_args is None
|
921 | 922 | assert tr.task_kwargs is None
|
| 923 | + |
| 924 | + |
| 925 | +class DjangoCeleryResultRouter: |
| 926 | + route_app_labels = {"django_celery_results"} |
| 927 | + |
| 928 | + def db_for_read(self, model, **hints): |
| 929 | + """Route read access to the read-only database""" |
| 930 | + if model._meta.app_label in self.route_app_labels: |
| 931 | + return "read-only" |
| 932 | + return None |
| 933 | + |
| 934 | + def db_for_write(self, model, **hints): |
| 935 | + """Route write access to the default database""" |
| 936 | + if model._meta.app_label in self.route_app_labels: |
| 937 | + return "default" |
| 938 | + return None |
| 939 | + |
| 940 | + |
| 941 | +class ChordPartReturnTestCase(TransactionTestCase): |
| 942 | + databases = {"default", "read-only"} |
| 943 | + |
| 944 | + def setUp(self): |
| 945 | + super().setUp() |
| 946 | + self.app.conf.result_serializer = 'json' |
| 947 | + self.app.conf.result_backend = ( |
| 948 | + 'django_celery_results.backends:DatabaseBackend') |
| 949 | + self.app.conf.result_extended = True |
| 950 | + self.b = DatabaseBackend(app=self.app) |
| 951 | + |
| 952 | + def test_on_chord_part_return_multiple_databases(self): |
| 953 | + """ |
| 954 | + Test if the ChordCounter is properly decremented and the callback is |
| 955 | + triggered after all chord parts have returned with multiple databases |
| 956 | + """ |
| 957 | + with self.settings(DATABASE_ROUTERS=[DjangoCeleryResultRouter()]): |
| 958 | + gid = uuid() |
| 959 | + tid1 = uuid() |
| 960 | + tid2 = uuid() |
| 961 | + subtasks = [AsyncResult(tid1), AsyncResult(tid2)] |
| 962 | + group = GroupResult(id=gid, results=subtasks) |
| 963 | + |
| 964 | + assert ChordCounter.objects.count() == 0 |
| 965 | + assert ChordCounter.objects.using("read-only").count() == 0 |
| 966 | + assert ChordCounter.objects.using("default").count() == 0 |
| 967 | + |
| 968 | + self.b.apply_chord(group, self.add.s()) |
| 969 | + |
| 970 | + # Check if the ChordCounter was created in the correct database |
| 971 | + assert ChordCounter.objects.count() == 1 |
| 972 | + assert ChordCounter.objects.using("read-only").count() == 1 |
| 973 | + assert ChordCounter.objects.using("default").count() == 1 |
| 974 | + |
| 975 | + chord_counter = ChordCounter.objects.get(group_id=gid) |
| 976 | + assert chord_counter.count == 2 |
| 977 | + |
| 978 | + request = mock.MagicMock() |
| 979 | + request.id = subtasks[0].id |
| 980 | + request.group = gid |
| 981 | + request.task = "my_task" |
| 982 | + request.args = ["a", 1, "password"] |
| 983 | + request.kwargs = {"c": 3, "d": "e", "password": "password"} |
| 984 | + request.argsrepr = "argsrepr" |
| 985 | + request.kwargsrepr = "kwargsrepr" |
| 986 | + request.hostname = "celery@ip-0-0-0-0" |
| 987 | + request.properties = {"periodic_task_name": "my_periodic_task"} |
| 988 | + request.ignore_result = False |
| 989 | + result = {"foo": "baz"} |
| 990 | + |
| 991 | + self.b.mark_as_done(tid1, result, request=request) |
| 992 | + |
| 993 | + chord_counter.refresh_from_db() |
| 994 | + assert chord_counter.count == 1 |
| 995 | + |
| 996 | + self.b.mark_as_done(tid2, result, request=request) |
| 997 | + |
| 998 | + with pytest.raises(ChordCounter.DoesNotExist): |
| 999 | + ChordCounter.objects.get(group_id=gid) |
| 1000 | + |
| 1001 | + request.chord.delay.assert_called_once() |
0 commit comments