From adfb74434c6ccbc91ddff82b408c747117198f9e Mon Sep 17 00:00:00 2001 From: Val Brodsky Date: Mon, 1 Jul 2024 15:59:30 -0400 Subject: [PATCH] Add upload_type to Project --- libs/labelbox/src/labelbox/client.py | 8 +++--- .../src/labelbox/schema/ontology_kind.py | 27 ++++++++++++++++--- libs/labelbox/src/labelbox/schema/project.py | 7 ++--- libs/labelbox/tests/unit/test_project.py | 1 + 4 files changed, 33 insertions(+), 10 deletions(-) diff --git a/libs/labelbox/src/labelbox/client.py b/libs/labelbox/src/labelbox/client.py index 7522ba7d4..6dbe9d60e 100644 --- a/libs/labelbox/src/labelbox/client.py +++ b/libs/labelbox/src/labelbox/client.py @@ -555,7 +555,7 @@ def get_users(self, where=None) -> PaginatedCollection: An iterable of Users (typically a PaginatedCollection). """ return self._get_all(Entity.User, where, filter_deleted=False) - + def get_datasets(self, where=None) -> PaginatedCollection: """ Fetches one or more datasets. @@ -609,6 +609,7 @@ def _create(self, db_object_type, data, extra_params={}): data = {**data, **extra_params} query_string, params = query.create(db_object_type, data) res = self.execute(query_string, params) + if not res: raise labelbox.exceptions.LabelboxError("Failed to create %s" % db_object_type.type_name()) @@ -2231,9 +2232,8 @@ def get_embedding_by_name(self, name: str) -> Embedding: raise labelbox.exceptions.ResourceNotFoundError(Embedding, dict(name=name)) - def upsert_label_feedback( - self, label_id: str, feedback: str, - scores: Dict[str, float]) -> List[LabelScore]: + def upsert_label_feedback(self, label_id: str, feedback: str, + scores: Dict[str, float]) -> List[LabelScore]: """ Submits the label feedback which is a free-form text and numeric label scores. diff --git a/libs/labelbox/src/labelbox/schema/ontology_kind.py b/libs/labelbox/src/labelbox/schema/ontology_kind.py index e33e7cef3..d31feda12 100644 --- a/libs/labelbox/src/labelbox/schema/ontology_kind.py +++ b/libs/labelbox/src/labelbox/schema/ontology_kind.py @@ -34,16 +34,16 @@ def is_supported(cls, value): return isinstance(value, cls) @classmethod - def _missing_(cls, name) -> 'EditorTaskType': + def _missing_(cls, value) -> 'EditorTaskType': """Handle missing null new task types Handle upper case names for compatibility with the GraphQL""" - if name is None: + if value is None: return cls.Missing for name, member in cls.__members__.items(): - if name == name.upper(): + if value == name.upper(): return member return cls.Missing @@ -71,3 +71,24 @@ def map_to_editor_task_type(onotology_kind: OntologyKind, return EditorTaskType.ModelChatEvaluation else: return EditorTaskType.Missing + + +class UploadType(Enum): + Auto = 'AUTO', + Manual = 'MANUAL', + Missing = None + + @classmethod + def is_supported(cls, value): + return isinstance(value, cls) + + @classmethod + def _missing_(cls, value: object) -> 'UploadType': + if value is None: + return cls.Missing + + for name, member in cls.__members__.items(): + if value == name.upper(): + return member + + return cls.Missing diff --git a/libs/labelbox/src/labelbox/schema/project.py b/libs/labelbox/src/labelbox/schema/project.py index c5f9ef0b0..d141c89cb 100644 --- a/libs/labelbox/src/labelbox/schema/project.py +++ b/libs/labelbox/src/labelbox/schema/project.py @@ -37,7 +37,8 @@ from labelbox.schema.resource_tag import ResourceTag from labelbox.schema.task import Task from labelbox.schema.task_queue import TaskQueue -from labelbox.schema.ontology_kind import (EditorTaskType, OntologyKind) +from labelbox.schema.ontology_kind import (EditorTaskType, OntologyKind, + UploadType) from labelbox.schema.project_overview import ProjectOverview, ProjectOverviewDetailed if TYPE_CHECKING: @@ -121,6 +122,7 @@ class Project(DbObject, Updateable, Deletable): editor_task_type = Field.Enum(EditorTaskType, "editor_task_type") data_row_count = Field.Int("data_row_count") model_setup_complete: Field = Field.Boolean("model_setup_complete") + upload_type: Field = Field.Enum(UploadType, "upload_type") # Relationships created_by = Relationship.ToOne("User", False, "created_by") @@ -145,8 +147,7 @@ def is_chat_evaluation(self) -> bool: return self.media_type == MediaType.Conversational and self.editor_task_type == EditorTaskType.ModelChatEvaluation def is_auto_data_generation(self) -> bool: - return self.media_type == MediaType.LLMPromptCreation or self.media_type == MediaType.LLMPromptResponseCreation or self.is_chat_evaluation( - ) + return (self.upload_type == UploadType.Auto) # type: ignore def project_model_configs(self): query_str = """query ProjectModelConfigsPyApi($id: ID!) { diff --git a/libs/labelbox/tests/unit/test_project.py b/libs/labelbox/tests/unit/test_project.py index c2c818291..5a6754aa3 100644 --- a/libs/labelbox/tests/unit/test_project.py +++ b/libs/labelbox/tests/unit/test_project.py @@ -32,6 +32,7 @@ def test_project_editor_task_type(api_editor_task_type, "queueMode": "BATCH", "setupComplete": "2021-06-01T00:00:00.000Z", "modelSetupComplete": None, + "uploadType": "Auto", }) assert project.editor_task_type == expected_editor_task_type