From 86d598d6d26c1ba4a3e1073dc68f2d18eccc7a12 Mon Sep 17 00:00:00 2001
From: Val Brodsky
Date: Fri, 13 Sep 2024 13:46:18 -0700
Subject: [PATCH] Set tasks_remaining_count to None if labeling has not started

---
 .../schema/labeling_service_dashboard.py      | 52 ++++++-------
 .../unit/test_labeling_service_dashboard.py   | 75 +++++++++++++++++++
 2 files changed, 98 insertions(+), 29 deletions(-)
 create mode 100644 libs/labelbox/tests/unit/test_labeling_service_dashboard.py

diff --git a/libs/labelbox/src/labelbox/schema/labeling_service_dashboard.py b/libs/labelbox/src/labelbox/schema/labeling_service_dashboard.py
index 10a956a66..2052897f6 100644
--- a/libs/labelbox/src/labelbox/schema/labeling_service_dashboard.py
+++ b/libs/labelbox/src/labelbox/schema/labeling_service_dashboard.py
@@ -4,13 +4,13 @@
 
 from labelbox.exceptions import ResourceNotFoundError
 from labelbox.pagination import PaginatedCollection
-from pydantic import BaseModel, root_validator, Field
+from pydantic import BaseModel, model_validator, Field
 from labelbox.schema.search_filters import SearchFilter, build_search_filter
 from labelbox.utils import _CamelCaseMixin
 from .ontology_kind import EditorTaskType
 from labelbox.schema.media_type import MediaType
 from labelbox.schema.labeling_service_status import LabelingServiceStatus
-from labelbox.utils import _CamelCaseMixin, sentence_case
+from labelbox.utils import sentence_case
 
 GRAPHQL_QUERY_SELECTIONS = """
             id
@@ -58,7 +58,7 @@ class LabelingServiceDashboard(_CamelCaseMixin):
         status (LabelingServiceStatus): status of the labeling service
         data_rows_count (int): total number of data rows batched in the project
         tasks_completed_count (int): number of tasks completed (in the Done queue)
-        tasks_remaining_count (int): number of tasks remaining (in a queue other then Done)
+        tasks_remaining_count (int): number of tasks remaining (i.e. tasks in progress), None if labeling has not started
         tags (List[LabelingServiceDashboardTags]): tags associated with the project
         media_type (MediaType): media type of the project
         editor_task_type (EditorTaskType): editor task type of the project
@@ -73,7 +73,7 @@ class LabelingServiceDashboard(_CamelCaseMixin):
     status: LabelingServiceStatus = Field(frozen=True, default=None)
     data_rows_count: int = Field(frozen=True)
     tasks_completed_count: int = Field(frozen=True)
-    tasks_remaining_count: int = Field(frozen=True)
+    tasks_remaining_count: Optional[int] = Field(frozen=True, default=None)
     media_type: Optional[MediaType] = Field(frozen=True, default=None)
     editor_task_type: EditorTaskType = Field(frozen=True, default=None)
     tags: List[LabelingServiceDashboardTags] = Field(frozen=True, default=None)
@@ -84,8 +84,7 @@ def __init__(self, **kwargs):
         super().__init__(**kwargs)
         if not self.client.enable_experimental:
             raise RuntimeError(
-                "Please enable experimental in client to use LabelingService"
-            )
+                "Please enable experimental in client to use LabelingService")
 
     @property
     def service_type(self):
@@ -98,28 +97,20 @@ def service_type(self):
         if self.editor_task_type is None:
             return sentence_case(self.media_type.value)
 
-        if (
-            self.editor_task_type == EditorTaskType.OfflineModelChatEvaluation
-            and self.media_type == MediaType.Conversational
-        ):
+        if (self.editor_task_type == EditorTaskType.OfflineModelChatEvaluation
+                and self.media_type == MediaType.Conversational):
             return "Offline chat evaluation"
 
-        if (
-            self.editor_task_type == EditorTaskType.ModelChatEvaluation
-            and self.media_type == MediaType.Conversational
-        ):
+        if (self.editor_task_type == EditorTaskType.ModelChatEvaluation and
+                self.media_type == MediaType.Conversational):
             return "Live chat evaluation"
 
-        if (
-            self.editor_task_type == EditorTaskType.ResponseCreation
-            and self.media_type == MediaType.Text
-        ):
+        if (self.editor_task_type == EditorTaskType.ResponseCreation and
+                self.media_type == MediaType.Text):
             return "Response creation"
 
-        if (
-            self.media_type == MediaType.LLMPromptCreation
-            or self.media_type == MediaType.LLMPromptResponseCreation
-        ):
+        if (self.media_type == MediaType.LLMPromptCreation or
+                self.media_type == MediaType.LLMPromptResponseCreation):
             return "Prompt response creation"
 
         return sentence_case(self.media_type.value)
@@ -163,8 +154,7 @@ def get_all(
                 pageInfo { endCursor }
             }
         }
-        """
-        )
+        """)
         else:
             template = Template(
                 """query SearchProjectsPyApi($$first: Int, $$from: String) {
@@ -174,13 +164,11 @@ def get_all(
                 pageInfo { endCursor }
             }
         }
-        """
-        )
+        """)
         query_str = template.substitute(
             labeling_dashboard_selections=GRAPHQL_QUERY_SELECTIONS,
             search_query=build_search_filter(search_query)
-            if search_query
-            else None,
+            if search_query else None,
         )
 
         params: Dict[str, Union[str, int]] = {}
@@ -198,7 +186,7 @@ def convert_to_labeling_service_dashboard(client, data):
             experimental=True,
         )
 
-    @root_validator(pre=True)
+    @model_validator(mode='before')
     def convert_boost_data(cls, data):
         if "boostStatus" in data:
             data["status"] = LabelingServiceStatus(data.pop("boostStatus"))
@@ -212,6 +200,12 @@ def convert_boost_data(cls, data):
         if "boostRequestedBy" in data:
             data["created_by_id"] = data.pop("boostRequestedBy")
 
+        tasks_remaining_count = data.get("tasksRemainingCount", 0)
+        tasks_total_count = data.get("tasksTotalCount", 0)
+        # to avoid confusion, set tasks_remaining_count to None if no tasks have been completed and none are in flight
+        if tasks_total_count == 0 and tasks_remaining_count == 0:
+            data.pop("tasksRemainingCount")
+
         return data
 
     def dict(self, *args, **kwargs):
diff --git a/libs/labelbox/tests/unit/test_labeling_service_dashboard.py b/libs/labelbox/tests/unit/test_labeling_service_dashboard.py
new file mode 100644
index 000000000..8ecdef2f1
--- /dev/null
+++ b/libs/labelbox/tests/unit/test_labeling_service_dashboard.py
@@ -0,0 +1,75 @@
+from unittest.mock import MagicMock
+
+from labelbox.schema.labeling_service_dashboard import LabelingServiceDashboard
+
+
+def test_no_tasks_remaining_count():
+    labeling_service_dashboard_data = {
+        'id': 'cm0eeo4c301lg07061phfhva0',
+        'name': 'TestStatus',
+        'boostRequestedAt': '2024-08-28T22:08:07.446Z',
+        'boostUpdatedAt': '2024-08-28T22:08:07.446Z',
+        'boostRequestedBy': None,
+        'boostStatus': 'SET_UP',
+        'dataRowsCount': 0,
+        'dataRowsDoneCount': 0,
+        'dataRowsInReviewCount': 0,
+        'dataRowsInReworkCount': 0,
+        'tasksTotalCount': 0,
+        'tasksCompletedCount': 0,
+        'tasksRemainingCount': 0,
+        'mediaType': 'image',
+        'editorTaskType': None,
+        'tags': [],
+        'client': MagicMock()
+    }
+    lsd = LabelingServiceDashboard(**labeling_service_dashboard_data)
+    assert lsd.tasks_remaining_count is None
+
+
+def test_tasks_remaining_count_exists():
+    labeling_service_dashboard_data = {
+        'id': 'cm0eeo4c301lg07061phfhva0',
+        'name': 'TestStatus',
+        'boostRequestedAt': '2024-08-28T22:08:07.446Z',
+        'boostUpdatedAt': '2024-08-28T22:08:07.446Z',
+        'boostRequestedBy': None,
+        'boostStatus': 'SET_UP',
+        'dataRowsCount': 0,
+        'dataRowsDoneCount': 0,
+        'dataRowsInReviewCount': 0,
+        'dataRowsInReworkCount': 0,
+        'tasksTotalCount': 0,
+        'tasksCompletedCount': 0,
+        'tasksRemainingCount': 1,
+        'mediaType': 'image',
+        'editorTaskType': None,
+        'tags': [],
+        'client': MagicMock()
+    }
+    lsd = LabelingServiceDashboard(**labeling_service_dashboard_data)
+    assert lsd.tasks_remaining_count == 1
+
+
+def test_tasks_total_no_tasks_remaining_count():
+    labeling_service_dashboard_data = {
+        'id': 'cm0eeo4c301lg07061phfhva0',
+        'name': 'TestStatus',
+        'boostRequestedAt': '2024-08-28T22:08:07.446Z',
+        'boostUpdatedAt': '2024-08-28T22:08:07.446Z',
+        'boostRequestedBy': None,
+        'boostStatus': 'SET_UP',
+        'dataRowsCount': 0,
+        'dataRowsDoneCount': 0,
+        'dataRowsInReviewCount': 1,
+        'dataRowsInReworkCount': 0,
+        'tasksTotalCount': 1,
+        'tasksCompletedCount': 0,
+        'tasksRemainingCount': 0,
+        'mediaType': 'image',
+        'editorTaskType': None,
+        'tags': [],
+        'client': MagicMock()
+    }
+    lsd = LabelingServiceDashboard(**labeling_service_dashboard_data)
+    assert lsd.tasks_remaining_count == 0
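
For illustration only, a minimal standalone sketch of the pattern the new convert_boost_data logic relies on: a pydantic model_validator(mode="before") sees the raw camelCase payload and can drop the tasksRemainingCount key, so the Optional field falls back to its None default when labeling has not started. DashboardSketch, its field alias, and the validator name below are hypothetical stand-ins, not the real LabelingServiceDashboard model.

from typing import Optional

from pydantic import BaseModel, Field, model_validator


class DashboardSketch(BaseModel):
    # Hypothetical stand-in; the alias mirrors the camelCase key in the payload.
    tasks_remaining_count: Optional[int] = Field(default=None, alias="tasksRemainingCount")

    @model_validator(mode="before")
    @classmethod
    def drop_count_if_not_started(cls, data: dict) -> dict:
        # No tasks completed and none in flight means labeling has not started,
        # so drop the key and let the Optional field default to None.
        if data.get("tasksTotalCount", 0) == 0 and data.get("tasksRemainingCount", 0) == 0:
            data.pop("tasksRemainingCount", None)
        return data


# Usage: the first payload reports no tasks at all, the second has work in flight.
assert DashboardSketch(**{"tasksTotalCount": 0, "tasksRemainingCount": 0}).tasks_remaining_count is None
assert DashboardSketch(**{"tasksTotalCount": 1, "tasksRemainingCount": 1}).tasks_remaining_count == 1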