def batched_iterator(iterable, n):
    """
    Yield successive lists of at most ``n`` items from ``iterable``.

    The final batch may be shorter than ``n``; no empty batches are yielded.

    TODO: replace with itertools.batched when we drop support for Python < 3.12
    (note: ``itertools.batched`` yields tuples, this yields lists).

    :param iterable: any iterable to split into batches
    :param n: maximum batch size; must be at least 1
    :raises ValueError: if ``n`` is less than 1 — matching ``itertools.batched``;
        the previous implementation would loop forever yielding ``[]``
    """
    from itertools import islice

    if n < 1:
        raise ValueError('n must be at least one')
    iterator = iter(iterable)
    # islice consumes up to n items per pass; an empty chunk means exhaustion.
    while chunk := list(islice(iterator, n)):
        yield chunk
+ + :param tasks_queryset: Can be: + - list of IDs (integers) + - list of objects with 'id' attribute + - Django QuerySet + - set of IDs or objects + :return: list of unique IDs + """ + if isinstance(tasks_queryset, (list, tuple)): + if not tasks_queryset: + return [] + + # Check if it's a list of IDs (integers) + if isinstance(tasks_queryset[0], int): + return list(set(tasks_queryset)) # Remove duplicates + + # It's a list of objects with 'id' attribute + return list(set(obj.id for obj in tasks_queryset)) + + elif isinstance(tasks_queryset, set): + if not tasks_queryset: + return [] + + # Check if it's a set of IDs (integers) + first_item = next(iter(tasks_queryset)) + if isinstance(first_item, int): + return list(tasks_queryset) + + # It's a set of objects with 'id' attribute + return list(obj.id for obj in tasks_queryset) + + elif isinstance(tasks_queryset, QuerySet): + # It's a Django QuerySet + return list(tasks_queryset.values_list('id', flat=True)) + + else: + raise ValueError(f'Unsupported type for tasks_queryset: {type(tasks_queryset)}') + + def make_queryset_from_iterable(tasks_list): """ Make queryset from list/set of int/Tasks diff --git a/label_studio/projects/mixins.py b/label_studio/projects/mixins.py index 0fc49457a9bf..c09d59a411be 100644 --- a/label_studio/projects/mixins.py +++ b/label_studio/projects/mixins.py @@ -3,6 +3,7 @@ from core.redis import start_job_async_or_sync from django.db.models import QuerySet from django.utils.functional import cached_property +from projects.functions.utils import get_unique_ids_list if TYPE_CHECKING: from users.models import User @@ -22,11 +23,8 @@ def update_tasks_counters_and_is_labeled(self, tasks_queryset, from_scratch=True :param from_scratch: Skip calculated tasks """ # get only id from queryset to decrease data size in job - if not (isinstance(tasks_queryset, set) or isinstance(tasks_queryset, list)): - tasks_queryset = set(tasks_queryset.values_list('id', flat=True)) - start_job_async_or_sync( - 
self._update_tasks_counters_and_is_labeled, list(tasks_queryset), from_scratch=from_scratch - ) + task_ids = get_unique_ids_list(tasks_queryset) + start_job_async_or_sync(self._update_tasks_counters_and_is_labeled, task_ids, from_scratch=from_scratch) def update_tasks_counters_and_task_states( self, @@ -46,11 +44,10 @@ def update_tasks_counters_and_task_states( :param from_scratch: Skip calculated tasks """ # get only id from queryset to decrease data size in job - if not (isinstance(tasks_queryset, set) or isinstance(tasks_queryset, list)): - tasks_queryset = set(tasks_queryset.values_list('id', flat=True)) + task_ids = get_unique_ids_list(tasks_queryset) start_job_async_or_sync( self._update_tasks_counters_and_task_states, - tasks_queryset, + task_ids, maximum_annotations_changed, overlap_cohort_percentage_changed, tasks_number_changed, diff --git a/label_studio/tasks/functions.py b/label_studio/tasks/functions.py index 9f9312d404a9..2dbc5daba641 100644 --- a/label_studio/tasks/functions.py +++ b/label_studio/tasks/functions.py @@ -4,16 +4,14 @@ import shutil import sys -from core.bulk_update_utils import bulk_update from core.models import AsyncMigrationStatus from core.redis import start_job_async_or_sync -from core.utils.common import batch +from core.utils.common import batch, batched_iterator from data_export.mixins import ExportMixin from data_export.models import DataExport from data_export.serializers import ExportDataSerializer from data_manager.managers import TaskQuerySet from django.conf import settings -from django.db import transaction from django.db.models import Count, Q from organizations.models import Organization from projects.models import Project @@ -181,8 +179,6 @@ def update_tasks_counters(queryset, from_scratch=True): :param from_scratch: Skip calculated tasks :return: Count of updated tasks """ - objs = [] - total_annotations = Count('annotations', distinct=True, filter=Q(annotations__was_cancelled=False)) cancelled_annotations = 
Count('annotations', distinct=True, filter=Q(annotations__was_cancelled=True)) total_predictions = Count('predictions', distinct=True) @@ -211,15 +207,26 @@ def update_tasks_counters(queryset, from_scratch=True): new_total_predictions=total_predictions, ) - for task in queryset.only('id', 'total_annotations', 'cancelled_annotations', 'total_predictions'): - task.total_annotations = task.new_total_annotations - task.cancelled_annotations = task.new_cancelled_annotations - task.total_predictions = task.new_total_predictions - objs.append(task) - with transaction.atomic(): - bulk_update( - objs, - update_fields=['total_annotations', 'cancelled_annotations', 'total_predictions'], - batch_size=settings.BATCH_SIZE, - ) - return len(objs) + updated_count = 0 + + tasks_iterator = queryset.only('id', 'total_annotations', 'cancelled_annotations', 'total_predictions').iterator( + chunk_size=settings.BATCH_SIZE + ) + + for _batch in batched_iterator(tasks_iterator, settings.BATCH_SIZE): + batch_list = [] + for task in _batch: + task.total_annotations = task.new_total_annotations + task.cancelled_annotations = task.new_cancelled_annotations + task.total_predictions = task.new_total_predictions + batch_list.append(task) + + if batch_list: + Task.objects.bulk_update( + batch_list, + ['total_annotations', 'cancelled_annotations', 'total_predictions'], + batch_size=settings.BATCH_SIZE, + ) + updated_count += len(batch_list) + + return updated_count