From 669c3fa2cb73fc48d7fc1b7ab2fa481dcee3a0e8 Mon Sep 17 00:00:00 2001 From: Val Brodsky Date: Mon, 19 Aug 2024 15:02:01 -0700 Subject: [PATCH 1/5] Add support for service_type --- .../schema/labeling_service_dashboard.py | 28 ++++++++++++++++++- .../src/labelbox/schema/ontology_kind.py | 22 +++++++++------ libs/labelbox/src/labelbox/utils.py | 9 ++++++ libs/labelbox/tests/unit/test_utils.py | 10 ++++++- 4 files changed, 59 insertions(+), 10 deletions(-) diff --git a/libs/labelbox/src/labelbox/schema/labeling_service_dashboard.py b/libs/labelbox/src/labelbox/schema/labeling_service_dashboard.py index 9621008e3..d97bf72bf 100644 --- a/libs/labelbox/src/labelbox/schema/labeling_service_dashboard.py +++ b/libs/labelbox/src/labelbox/schema/labeling_service_dashboard.py @@ -7,7 +7,10 @@ from labelbox.pydantic_compat import BaseModel, root_validator, Field from labelbox.schema.search_filters import SearchFilter, build_search_filter from labelbox.utils import _CamelCaseMixin +from .ontology_kind import EditorTaskType +from labelbox.schema.media_type import MediaType from labelbox.schema.labeling_service_status import LabelingServiceStatus +from labelbox.utils import _CamelCaseMixin, sentence_case GRAPHQL_QUERY_SELECTIONS = """ id @@ -21,6 +24,8 @@ dataRowsInReviewCount dataRowsInReworkCount dataRowsDoneCount + mediaType + editorTaskType """ @@ -38,7 +43,6 @@ class LabelingServiceDashboard(BaseModel): """ id: str = Field(frozen=True) name: str = Field(frozen=True) - service_type: Optional[str] = Field(frozen=True, default=None) created_at: Optional[datetime] = Field(frozen=True, default=None) updated_at: Optional[datetime] = Field(frozen=True, default=None) created_by_id: Optional[str] = Field(frozen=True, default=None) @@ -48,6 +52,9 @@ class LabelingServiceDashboard(BaseModel): data_rows_in_review_count: int = Field(frozen=True) data_rows_in_rework_count: int = Field(frozen=True) data_rows_done_count: int = Field(frozen=True) + media_type: MediaType = Field(frozen=True, default=MediaType.Unknown) + editor_task_type: EditorTaskType = Field(frozen=True, + default=EditorTaskType.Missing) client: Any # type Any to avoid circular import from client @@ -65,6 +72,25 @@ def tasks_completed(self): def tasks_remaining(self): return self.data_rows_count - self.data_rows_done_count + @property + def service_type(self): + if self.editor_task_type is None: + return sentence_case(self.media_type.value) + + if self.editor_task_type == EditorTaskType.OfflineModelChatEvaluation and self.media_type == MediaType.Conversational: + return "Offline chat evaluation" + + if self.editor_task_type == EditorTaskType.ModelChatEvaluation and self.media_type == MediaType.Conversational: + return "Live chat evaluation" + + if self.editor_task_type == EditorTaskType.ResponseCreation and self.media_type == MediaType.Text: + return "Response creation" + + if media_type == MediaType.LLMPromptCreation or media_type == MediaType.LLMPromptResponseCreation: + return "Prompt response creation" + + return sentence_case(self.media_type.value) + class Config(_CamelCaseMixin.Config): ... diff --git a/libs/labelbox/src/labelbox/schema/ontology_kind.py b/libs/labelbox/src/labelbox/schema/ontology_kind.py index e8b4475ae..7dd3311cb 100644 --- a/libs/labelbox/src/labelbox/schema/ontology_kind.py +++ b/libs/labelbox/src/labelbox/schema/ontology_kind.py @@ -21,24 +21,30 @@ def get_ontology_kind_validation_error(cls, ontology_kind): return TypeError(f"{ontology_kind}: is not a valid ontology kind. Use" f" any of {OntologyKind.__members__.items()}" " from OntologyKind.") - + @staticmethod - def evaluate_ontology_kind_with_media_type(ontology_kind, - media_type: Optional[MediaType]) -> Union[MediaType, None]: - + def evaluate_ontology_kind_with_media_type( + ontology_kind, + media_type: Optional[MediaType]) -> Union[MediaType, None]: + ontology_to_media = { - OntologyKind.ModelEvaluation: (MediaType.Conversational, "For chat evaluation, media_type must be Conversational."), - OntologyKind.ResponseCreation: (MediaType.Text, "For response creation, media_type must be Text.") + OntologyKind.ModelEvaluation: + (MediaType.Conversational, + "For chat evaluation, media_type must be Conversational."), + OntologyKind.ResponseCreation: + (MediaType.Text, + "For response creation, media_type must be Text.") } if ontology_kind in ontology_to_media: - expected_media_type, error_message = ontology_to_media[ontology_kind] + expected_media_type, error_message = ontology_to_media[ + ontology_kind] if media_type is None or media_type == expected_media_type: media_type = expected_media_type else: raise ValueError(error_message) - + return media_type diff --git a/libs/labelbox/src/labelbox/utils.py b/libs/labelbox/src/labelbox/utils.py index f606932c7..dc285e60b 100644 --- a/libs/labelbox/src/labelbox/utils.py +++ b/libs/labelbox/src/labelbox/utils.py @@ -39,6 +39,15 @@ def snake_case(s): return _convert(s, "_", lambda i: False) +def sentence_case(s: str) -> str: + """ Converts a string in [snake|camel|title]case to Sentence case. """ + # Replace underscores with spaces and convert to lower case + sentence_str = s.replace("_", " ").lower() + # Capitalize the first letter of each word + sentence_str = sentence_str.capitalize() + return sentence_str + + def is_exactly_one_set(*args): return sum([bool(arg) for arg in args]) == 1 diff --git a/libs/labelbox/tests/unit/test_utils.py b/libs/labelbox/tests/unit/test_utils.py index 129edcd72..dfd72c335 100644 --- a/libs/labelbox/tests/unit/test_utils.py +++ b/libs/labelbox/tests/unit/test_utils.py @@ -1,5 +1,5 @@ import pytest -from labelbox.utils import format_iso_datetime, format_iso_from_string +from labelbox.utils import format_iso_datetime, format_iso_from_string, sentence_case @pytest.mark.parametrize('datetime_str, expected_datetime_str', @@ -11,3 +11,11 @@ def test_datetime_parsing(datetime_str, expected_datetime_str): # NOTE I would normally not take 'expected' using another function from sdk code, but in this case this is exactly the usage in _validate_parse_datetime assert format_iso_datetime( format_iso_from_string(datetime_str)) == expected_datetime_str + + +@pytest.mark.parametrize( + 'str, expected_str', + [('AUDIO', 'Audio'), + ('LLM_PROMPT_RESPONSE_CREATION', 'Llm prompt response creation')]) +def test_sentence_case(str, expected_str): + assert sentence_case(str) == expected_str From 5093d42e94d452211605bbd11699640ceeb9ed91 Mon Sep 17 00:00:00 2001 From: Val Brodsky Date: Mon, 19 Aug 2024 15:14:40 -0700 Subject: [PATCH 2/5] Add all other attributes --- .../schema/labeling_service_dashboard.py | 31 ++++++++++++------- .../src/labelbox/schema/media_type.py | 8 ++--- 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/libs/labelbox/src/labelbox/schema/labeling_service_dashboard.py b/libs/labelbox/src/labelbox/schema/labeling_service_dashboard.py index d97bf72bf..5d9c322ab 100644 --- a/libs/labelbox/src/labelbox/schema/labeling_service_dashboard.py +++ b/libs/labelbox/src/labelbox/schema/labeling_service_dashboard.py @@ -15,10 +15,9 @@ GRAPHQL_QUERY_SELECTIONS = """ id name - # serviceType - # createdAt - # updatedAt - # createdById + boostRequestedAt + boostUpdatedAt + boostRequestedBy boostStatus dataRowsCount dataRowsInReviewCount @@ -46,15 +45,13 @@ class LabelingServiceDashboard(BaseModel): created_at: Optional[datetime] = Field(frozen=True, default=None) updated_at: Optional[datetime] = Field(frozen=True, default=None) created_by_id: Optional[str] = Field(frozen=True, default=None) - status: LabelingServiceStatus = Field(frozen=True, - default=LabelingServiceStatus.Missing) + status: LabelingServiceStatus = Field(frozen=True, default=None) data_rows_count: int = Field(frozen=True) data_rows_in_review_count: int = Field(frozen=True) data_rows_in_rework_count: int = Field(frozen=True) data_rows_done_count: int = Field(frozen=True) - media_type: MediaType = Field(frozen=True, default=MediaType.Unknown) - editor_task_type: EditorTaskType = Field(frozen=True, - default=EditorTaskType.Missing) + media_type: Optional[MediaType] = Field(frozen=True, default=None) + editor_task_type: EditorTaskType = Field(frozen=True, default=None) client: Any # type Any to avoid circular import from client @@ -74,6 +71,9 @@ def tasks_remaining(self): @property def service_type(self): + if self.media_type is None: + return None + if self.editor_task_type is None: return sentence_case(self.media_type.value) @@ -86,7 +86,7 @@ def service_type(self): if self.editor_task_type == EditorTaskType.ResponseCreation and self.media_type == MediaType.Text: return "Response creation" - if media_type == MediaType.LLMPromptCreation or media_type == MediaType.LLMPromptResponseCreation: + if self.media_type == MediaType.LLMPromptCreation or self.media_type == MediaType.LLMPromptResponseCreation: return "Prompt response creation" return sentence_case(self.media_type.value) @@ -167,8 +167,17 @@ def convert_to_labeling_service_dashboard(client, data): ) @root_validator(pre=True) - def convert_boost_status_to_enum(cls, data): + def convert_boost_data(cls, data): if 'boostStatus' in data: data['status'] = LabelingServiceStatus(data.pop('boostStatus')) + if 'boostRequestedAt' in data: + data['created_at'] = data.pop('boostRequestedAt') + + if 'boostUpdatedAt' in data: + data['updated_at'] = data.pop('boostUpdatedAt') + + if 'boostRequestedBy' in data: + data['created_by_id'] = data.pop('boostRequestedBy') + return data diff --git a/libs/labelbox/src/labelbox/schema/media_type.py b/libs/labelbox/src/labelbox/schema/media_type.py index 22a66ff5e..d3231e8e8 100644 --- a/libs/labelbox/src/labelbox/schema/media_type.py +++ b/libs/labelbox/src/labelbox/schema/media_type.py @@ -23,17 +23,17 @@ class MediaType(Enum): LLM = "LLM" @classmethod - def _missing_(cls, name): + def _missing_(cls, value: str): """Handle missing null data types for projects created without setting allowedMediaType Handle upper case names for compatibility with the GraphQL""" - if name is None: + if value is None: return cls.Unknown - for member in cls.__members__: - if member.name == name.upper(): + for name, member in cls.__members__.items(): + if name.upper() == value.upper(): return member @classmethod From 8441789303c6f28c92ce522237dcdfc3da40072a Mon Sep 17 00:00:00 2001 From: Val Brodsky Date: Mon, 19 Aug 2024 21:07:50 -0700 Subject: [PATCH 3/5] Fix media_type conversions --- libs/labelbox/src/labelbox/schema/media_type.py | 16 ++++++++++++++-- .../tests/integration/test_labeling_dashboard.py | 3 +++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/libs/labelbox/src/labelbox/schema/media_type.py b/libs/labelbox/src/labelbox/schema/media_type.py index d3231e8e8..cfb00d372 100644 --- a/libs/labelbox/src/labelbox/schema/media_type.py +++ b/libs/labelbox/src/labelbox/schema/media_type.py @@ -1,5 +1,7 @@ from enum import Enum +from labelbox.utils import camel_case + class MediaType(Enum): Audio = "AUDIO" @@ -23,7 +25,7 @@ class MediaType(Enum): LLM = "LLM" @classmethod - def _missing_(cls, value: str): + def _missing_(cls, value): """Handle missing null data types for projects created without setting allowedMediaType Handle upper case names for compatibility with @@ -32,8 +34,18 @@ def _missing_(cls, value: str): if value is None: return cls.Unknown + def matches(value, name): + value_upper = value.upper() + name_upper = name.upper() + value_underscore = value.replace("-", "_") + camel_case_value = camel_case(value_underscore) + + return (value_upper == name_upper or + value_underscore.upper() == name_upper or + camel_case_value.upper() == name_upper) + for name, member in cls.__members__.items(): - if name.upper() == value.upper(): + if matches(value, name): return member @classmethod diff --git a/libs/labelbox/tests/integration/test_labeling_dashboard.py b/libs/labelbox/tests/integration/test_labeling_dashboard.py index 8289be90d..9dc59ae50 100644 --- a/libs/labelbox/tests/integration/test_labeling_dashboard.py +++ b/libs/labelbox/tests/integration/test_labeling_dashboard.py @@ -82,3 +82,6 @@ def test_request_labeling_service_dashboard_filters(requested_labeling_service): ] assert len(labeling_service_dashboard) == 0 assert labeling_service_dashboard == [] + labeling_service_dashboard = project.client.get_labeling_service_dashboards( + ).get_one() + assert labeling_service_dashboard From 66e8cb92f5831d016083f5775ea6cc421cd2744c Mon Sep 17 00:00:00 2001 From: Val Brodsky Date: Tue, 20 Aug 2024 12:05:00 -0700 Subject: [PATCH 4/5] Update tests --- libs/labelbox/tests/integration/test_labeling_dashboard.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/libs/labelbox/tests/integration/test_labeling_dashboard.py b/libs/labelbox/tests/integration/test_labeling_dashboard.py index 9dc59ae50..1c2e2b417 100644 --- a/libs/labelbox/tests/integration/test_labeling_dashboard.py +++ b/libs/labelbox/tests/integration/test_labeling_dashboard.py @@ -1,6 +1,8 @@ from datetime import datetime, timedelta from labelbox.schema.labeling_service import LabelingServiceStatus from labelbox.schema.search_filters import DateOperator, DateRange, DateRangeOperator, DateRangeValue, DateValue, IdOperator, OperationType, OrganizationFilter, WorkforceRequestedDateFilter, WorkforceRequestedDateRangeFilter, WorkspaceFilter +from labelbox.schema.ontology_kind import EditorTaskType +from labelbox.schema.media_type import MediaType def test_request_labeling_service_dashboard(rand_gen, @@ -18,6 +20,9 @@ def test_request_labeling_service_dashboard(rand_gen, assert labeling_service_dashboard.status == LabelingServiceStatus.Missing assert labeling_service_dashboard.tasks_completed == 0 assert labeling_service_dashboard.tasks_remaining == 0 + assert labeling_service_dashboard.media_type == MediaType.Conversational + assert labeling_service_dashboard.editor_task_type == EditorTaskType.OfflineModelChatEvaluation + assert labeling_service_dashboard.service_type == "Offline chat evaluation" labeling_service_dashboard = [ ld for ld in project.client.get_labeling_service_dashboards() From 028f449ea71b1a20466936d2278270a5466a6041 Mon Sep 17 00:00:00 2001 From: Val Brodsky Date: Tue, 20 Aug 2024 12:18:30 -0700 Subject: [PATCH 5/5] Add docstrings --- .../src/labelbox/schema/labeling_service_dashboard.py | 9 +++++++++ libs/labelbox/src/labelbox/schema/media_type.py | 6 ++++++ 2 files changed, 15 insertions(+) diff --git a/libs/labelbox/src/labelbox/schema/labeling_service_dashboard.py b/libs/labelbox/src/labelbox/schema/labeling_service_dashboard.py index 5d9c322ab..b7147e98a 100644 --- a/libs/labelbox/src/labelbox/schema/labeling_service_dashboard.py +++ b/libs/labelbox/src/labelbox/schema/labeling_service_dashboard.py @@ -63,14 +63,23 @@ def __init__(self, **kwargs): @property def tasks_completed(self): + """ + Count how many data rows have been completed (i.e. in the Done queue) + """ return self.data_rows_done_count @property def tasks_remaining(self): + """ + Count how many data rows have not been completed + """ return self.data_rows_count - self.data_rows_done_count @property def service_type(self): + """ + Descriptive labeling service definition by media type and editor task type + """ if self.media_type is None: return None diff --git a/libs/labelbox/src/labelbox/schema/media_type.py b/libs/labelbox/src/labelbox/schema/media_type.py index cfb00d372..266e2a0e3 100644 --- a/libs/labelbox/src/labelbox/schema/media_type.py +++ b/libs/labelbox/src/labelbox/schema/media_type.py @@ -35,6 +35,12 @@ def _missing_(cls, value): return cls.Unknown def matches(value, name): + """ + This will convert string values (from api) to match enum values + Some string values come as snake case (i.e. llm-prompt-creation) + Some string values come as camel case (i.e. llmPromptCreation) + etc depending on which api returns the value + """ value_upper = value.upper() name_upper = name.upper() value_underscore = value.replace("-", "_")