diff --git a/libs/labelbox/src/labelbox/client.py b/libs/labelbox/src/labelbox/client.py
index 019285350..3154c29f9 100644
--- a/libs/labelbox/src/labelbox/client.py
+++ b/libs/labelbox/src/labelbox/client.py
@@ -840,7 +840,7 @@ def create_offline_model_evaluation_project(self, **kwargs) -> Project:
         kwargs.pop("append_to_existing_dataset", None)
         kwargs.pop("data_row_count", None)
 
-        return self.create_project(**kwargs)
+        return self._create_project(**kwargs)
 
     def _create_project(self, **kwargs) -> Project:
         auto_audit_percentage = kwargs.get("auto_audit_percentage")
diff --git a/libs/labelbox/src/labelbox/exceptions.py b/libs/labelbox/src/labelbox/exceptions.py
index 678196f5d..612e7ef58 100644
--- a/libs/labelbox/src/labelbox/exceptions.py
+++ b/libs/labelbox/src/labelbox/exceptions.py
@@ -1,3 +1,6 @@
+import re
+
+
 class LabelboxError(Exception):
     """Base class for exceptions."""
 
@@ -147,3 +150,21 @@ class CustomMetricsNotSupportedException(Exception):
 
 class ProcessingWaitTimeout(Exception):
     """Raised when waiting for the data rows to be processed takes longer than allowed"""
+
+
+def error_message_for_unparsed_graphql_error(error_string: str) -> str:
+    """
+    Since our client only parses certain graphql errors, this function is used to
+    extract the error message from the error string when the error is not
+    parsed by the client.
+    """
+    # Regex to find the message content
+    pattern = r"'message': '([^']+)'"
+    # Search for the pattern in the error string
+    match = re.search(pattern, error_string)
+    if match:
+        error_content = match.group(1)
+    else:
+        error_content = "Unknown error"
+
+    return error_content
diff --git a/libs/labelbox/src/labelbox/schema/project.py b/libs/labelbox/src/labelbox/schema/project.py
index a2142ebc5..d96485ccf 100644
--- a/libs/labelbox/src/labelbox/schema/project.py
+++ b/libs/labelbox/src/labelbox/schema/project.py
@@ -13,13 +13,10 @@
 from labelbox import parser
 from labelbox import utils
-from labelbox.exceptions import (
-    InvalidQueryError,
-    LabelboxError,
-    ProcessingWaitTimeout,
-    ResourceConflict,
-    ResourceNotFoundError
-)
+from labelbox.exceptions import error_message_for_unparsed_graphql_error
+from labelbox.exceptions import (InvalidQueryError, LabelboxError,
+                                 ProcessingWaitTimeout, ResourceConflict,
+                                 ResourceNotFoundError)
 from labelbox.orm import query
 from labelbox.orm.db_object import DbObject, Deletable, Updateable, experimental
 from labelbox.orm.model import Entity, Field, Relationship
 
@@ -122,6 +119,7 @@ class Project(DbObject, Updateable, Deletable):
     media_type = Field.Enum(MediaType, "media_type", "allowedMediaType")
     editor_task_type = Field.Enum(EditorTaskType, "editor_task_type")
     data_row_count = Field.Int("data_row_count")
+    model_setup_complete: Field = Field.Boolean("model_setup_complete")
 
     # Relationships
     created_by = Relationship.ToOne("User", False, "created_by")
@@ -1271,7 +1269,18 @@ def add_model_config(self, model_config_id: str) -> str:
             "projectId": self.uid,
             "modelConfigId": model_config_id,
         }
-        result = self.client.execute(query, params)
+        try:
+            result = self.client.execute(query, params)
+        except LabelboxError as e:
+            if e.message.startswith(
+                    "Unknown error: "
+            ):  # unfortunate hack to handle unparsed graphql errors
+                error_content = error_message_for_unparsed_graphql_error(
+                    e.message)
+            else:
+                error_content = e.message
+            raise LabelboxError(message=error_content) from e
+
         if not result:
             raise ResourceNotFoundError(ModelConfig, params)
         return result["createProjectModelConfig"]["projectModelConfigId"]
@@ -1299,6 +1308,29 @@ def delete_project_model_config(self, project_model_config_id: str) -> bool:
             raise ResourceNotFoundError(ProjectModelConfig, params)
         return result["deleteProjectModelConfig"]["success"]
 
+    def set_project_model_setup_complete(self) -> bool:
+        """
+        Marks the model setup as complete for this project.
+        Once the project is marked as "setup complete", a user cannot add, modify, or delete existing project model configs.
+
+        Returns:
+            bool, indicates if the model setup is complete.
+
+        NOTE: This method should only be used for live model evaluation projects.
+        It will throw an exception for all other types of projects.
+        Use the Project.is_chat_evaluation() method to check whether the project is a live model evaluation project.
+        """
+        query = """mutation SetProjectModelSetupCompletePyApi($projectId: ID!) {
+                setProjectModelSetupComplete(where: {id: $projectId}, data: {modelSetupComplete: true}) {
+                    modelSetupComplete
+                }
+            }"""
+
+        result = self.client.execute(query, {"projectId": self.uid})
+        self.model_setup_complete = result["setProjectModelSetupComplete"][
+            "modelSetupComplete"]
+        return result["setProjectModelSetupComplete"]["modelSetupComplete"]
+
     def set_labeling_parameter_overrides(
             self, data: List[LabelingParameterOverrideInput]) -> bool:
         """ Adds labeling parameter overrides to this project.
@@ -1752,7 +1784,9 @@ def __check_data_rows_have_been_processed(
 
         return response["queryAllDataRowsHaveBeenProcessed"][
             "allDataRowsHaveBeenProcessed"]
 
-    def get_overview(self, details=False) -> Union[ProjectOverview, ProjectOverviewDetailed]:
+    def get_overview(
+            self,
+            details=False) -> Union[ProjectOverview, ProjectOverviewDetailed]:
         """Return the overview of a project.
 
         This method returns the number of data rows per task queue and issues of a project,
@@ -1792,7 +1826,7 @@ def get_overview(self, details=False) -> Union[ProjectOverview, ProjectOverviewD
 
         # Must use experimental to access "issues"
         result = self.client.execute(query, {"projectId": self.uid},
-                                     experimental=True)["project"] 
+                                     experimental=True)["project"]
 
         # Reformat category names
         overview = {
@@ -1805,7 +1839,7 @@ def get_overview(self, details=False) -> Union[ProjectOverview, ProjectOverviewD
 
         # Rename categories
         overview["to_label"] = overview.pop("unlabeled")
-        overview["total_data_rows"] = overview.pop("all") 
+        overview["total_data_rows"] = overview.pop("all")
 
         if not details:
             return ProjectOverview(**overview)
@@ -1813,18 +1847,20 @@
         # Build dictionary for queue details for review and rework queues
         for category in ["rework", "review"]:
             queues = [
-                {tq["name"]: tq.get("dataRowCount")}
+                {
+                    tq["name"]: tq.get("dataRowCount")
+                }
                 for tq in result.get("taskQueues")
                 if tq.get("queueType") == f"MANUAL_{category.upper()}_QUEUE"
             ]
 
-            overview[f"in_{category}"] = { 
+            overview[f"in_{category}"] = {
                 "data": queues,
                 "total": overview[f"in_{category}"]
             }
- 
+
         return ProjectOverviewDetailed(**overview)
- 
+
     def clone(self) -> "Project":
         """
        Clones the current project.
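
For context, a minimal caller-side sketch of the new setup-complete flow (not part of the patch; `API_KEY` and `MODEL_ID` are placeholders, and the calls mirror the fixture and test usage further down):

    from labelbox import Client
    from labelbox.exceptions import LabelboxError

    client = Client(api_key="API_KEY")  # placeholder credential
    project = client.create_model_evaluation_project(
        name="demo-live-eval", dataset_name="demo-live-eval-ds")
    model_config = client.create_model_config(
        "demo-config", "MODEL_ID", {"temperature": 0})  # placeholder model id

    # Model configs can be attached while setup is still incomplete.
    project.add_model_config(model_config.uid)

    # Freeze the configuration; the server rejects further changes.
    project.set_project_model_setup_complete()

    try:
        project.add_model_config(model_config.uid)
    except LabelboxError as e:
        # With the unparsed-error handling above, e.message now carries the
        # clean server message rather than the raw GraphQL error blob.
        print(e.message)
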
diff --git a/libs/labelbox/src/labelbox/schema/project_model_config.py b/libs/labelbox/src/labelbox/schema/project_model_config.py
index 8248a6a7a..9cf6dcbfa 100644
--- a/libs/labelbox/src/labelbox/schema/project_model_config.py
+++ b/libs/labelbox/src/labelbox/schema/project_model_config.py
@@ -1,5 +1,6 @@
-from labelbox.orm.db_object import DbObject, Deletable
+from labelbox.orm.db_object import DbObject
 from labelbox.orm.model import Field, Relationship
+from labelbox.exceptions import LabelboxError, error_message_for_unparsed_graphql_error
 
 
 class ProjectModelConfig(DbObject):
@@ -30,5 +31,17 @@ def delete(self) -> bool:
         params = {
             "id": self.uid,
         }
-        result = self.client.execute(query, params)
+
+        try:
+            result = self.client.execute(query, params)
+        except LabelboxError as e:
+            if e.message.startswith(
+                    "Unknown error: "
+            ):  # unfortunate hack to handle unparsed graphql errors
+                error_content = error_message_for_unparsed_graphql_error(
+                    e.message)
+            else:
+                error_content = e.message
+            raise LabelboxError(message=error_content) from e
+
         return result["deleteProjectModelConfig"]["success"]
diff --git a/libs/labelbox/tests/integration/conftest.py b/libs/labelbox/tests/integration/conftest.py
index 612d98122..b639b20df 100644
--- a/libs/labelbox/tests/integration/conftest.py
+++ b/libs/labelbox/tests/integration/conftest.py
@@ -407,7 +407,7 @@ def chat_evaluation_ontology(client, rand_gen):
 
 
 @pytest.fixture
-def chat_evaluation_project_create_dataset(client, rand_gen):
+def live_chat_evaluation_project_with_new_dataset(client, rand_gen):
     project_name = f"test-model-evaluation-project-{rand_gen(str)}"
     dataset_name = f"test-model-evaluation-dataset-{rand_gen(str)}"
     project = client.create_model_evaluation_project(name=project_name,
diff --git a/libs/labelbox/tests/integration/test_chat_evaluation_ontology_project.py b/libs/labelbox/tests/integration/test_chat_evaluation_ontology_project.py
index 6a41a7a09..f0b58f1d3 100644
--- a/libs/labelbox/tests/integration/test_chat_evaluation_ontology_project.py
+++ b/libs/labelbox/tests/integration/test_chat_evaluation_ontology_project.py
@@ -6,7 +6,7 @@
 
 def test_create_chat_evaluation_ontology_project(
         client, chat_evaluation_ontology,
-        chat_evaluation_project_create_dataset, conversation_data_row,
+        live_chat_evaluation_project_with_new_dataset, conversation_data_row,
         rand_gen):
 
     ontology = chat_evaluation_ontology
@@ -23,7 +23,9 @@ def test_create_chat_evaluation_ontology_project(
     assert classification.schema_id
     assert classification.feature_schema_id
 
-    project = chat_evaluation_project_create_dataset
+    project = live_chat_evaluation_project_with_new_dataset
+    assert project.model_setup_complete is None
+
     project.setup_editor(ontology)
 
     assert project.labeling_frontend().name == "Editor"
diff --git a/libs/labelbox/tests/integration/test_project_model_config.py b/libs/labelbox/tests/integration/test_project_model_config.py
index 1ff5041ef..220995e27 100644
--- a/libs/labelbox/tests/integration/test_project_model_config.py
+++ b/libs/labelbox/tests/integration/test_project_model_config.py
@@ -3,9 +3,9 @@
 from labelbox.exceptions import ResourceNotFoundError
 
 
-def test_add_single_model_config(chat_evaluation_project_create_dataset,
+def test_add_single_model_config(live_chat_evaluation_project_with_new_dataset,
                                  model_config):
-    configured_project = chat_evaluation_project_create_dataset
+    configured_project = live_chat_evaluation_project_with_new_dataset
 
     project_model_config_id = configured_project.add_model_config(
         model_config.uid)
@@ -18,9 +18,9 @@ def test_add_single_model_config(chat_evaluation_project_create_dataset,
 
 def test_add_multiple_model_config(client, rand_gen,
-                                   chat_evaluation_project_create_dataset,
+                                   live_chat_evaluation_project_with_new_dataset,
                                    model_config, valid_model_id):
-    configured_project = chat_evaluation_project_create_dataset
+    configured_project = live_chat_evaluation_project_with_new_dataset
 
     second_model_config = client.create_model_config(rand_gen(str),
                                                      valid_model_id,
                                                      {"param": "value"})
@@ -40,9 +40,9 @@ def test_add_multiple_model_config(client, rand_gen,
         project_model_config_id)
 
 
-def test_delete_project_model_config(chat_evaluation_project_create_dataset,
+def test_delete_project_model_config(live_chat_evaluation_project_with_new_dataset,
                                      model_config):
-    configured_project = chat_evaluation_project_create_dataset
+    configured_project = live_chat_evaluation_project_with_new_dataset
     assert configured_project.delete_project_model_config(
         configured_project.add_model_config(model_config.uid))
     assert not len(configured_project.project_model_configs())
diff --git a/libs/labelbox/tests/integration/test_project_set_model_setup_complete.py b/libs/labelbox/tests/integration/test_project_set_model_setup_complete.py
new file mode 100644
index 000000000..d48514024
--- /dev/null
+++ b/libs/labelbox/tests/integration/test_project_set_model_setup_complete.py
@@ -0,0 +1,66 @@
+import pytest
+
+from labelbox.exceptions import LabelboxError, OperationNotAllowedException
+
+
+def test_live_chat_evaluation_project(
+        live_chat_evaluation_project_with_new_dataset, model_config):
+
+    project = live_chat_evaluation_project_with_new_dataset
+
+    project.set_project_model_setup_complete()
+    assert bool(project.model_setup_complete) is True
+
+    with pytest.raises(
+            expected_exception=LabelboxError,
+            match=
+            "Cannot create model config for project because model setup is complete"
+    ):
+        project.add_model_config(model_config.uid)
+
+
+def test_live_chat_evaluation_project_delete_config(
+        live_chat_evaluation_project_with_new_dataset, model_config):
+
+    project = live_chat_evaluation_project_with_new_dataset
+    project_model_config_id = project.add_model_config(model_config.uid)
+    assert project_model_config_id
+
+    project_model_config = None
+    for pmg in project.project_model_configs():
+        if pmg.uid == project_model_config_id:
+            project_model_config = pmg
+            break
+    assert project_model_config
+
+    project.set_project_model_setup_complete()
+    assert bool(project.model_setup_complete) is True
+
+    with pytest.raises(
+            expected_exception=LabelboxError,
+            match=
+            "Cannot create model config for project because model setup is complete"
+    ):
+        project_model_config.delete()
+
+
+def test_offline_chat_evaluation_project(offline_chat_evaluation_project,
+                                         model_config):
+
+    project = offline_chat_evaluation_project
+
+    with pytest.raises(
+            expected_exception=OperationNotAllowedException,
+            match=
+            "Only live model chat evaluation projects can complete model setup"
+    ):
+        project.set_project_model_setup_complete()
+
+
+def test_any_other_project(project, model_config):
+    with pytest.raises(
+            expected_exception=OperationNotAllowedException,
+            match=
+            "Only live model chat evaluation projects can complete model setup"
+    ):
+        project.set_project_model_setup_complete()
diff --git a/libs/labelbox/tests/unit/test_exceptions.py b/libs/labelbox/tests/unit/test_exceptions.py
new file mode 100644
index 000000000..69bcfbd77
--- /dev/null
+++ b/libs/labelbox/tests/unit/test_exceptions.py
@@ -0,0 +1,13 @@
+import pytest
+
+from labelbox.exceptions import error_message_for_unparsed_graphql_error
+
+
+@pytest.mark.parametrize('exception_message, expected_result', [
+    ("Unparsed errors on query execution: [{'message': 'Cannot create model config for project because model setup is complete'}]",
+     "Cannot create model config for project because model setup is complete"),
+    ("blah blah blah", "Unknown error"),
+])
+def test_client_unparsed_exception_messages(exception_message, expected_result):
+    assert error_message_for_unparsed_graphql_error(
+        exception_message) == expected_result
diff --git a/libs/labelbox/tests/unit/test_project.py b/libs/labelbox/tests/unit/test_project.py
index 9d522e7c0..c2c818291 100644
--- a/libs/labelbox/tests/unit/test_project.py
+++ b/libs/labelbox/tests/unit/test_project.py
@@ -31,6 +31,7 @@ def test_project_editor_task_type(api_editor_task_type,
         "allowedMediaType": "IMAGE",
         "queueMode": "BATCH",
         "setupComplete": "2021-06-01T00:00:00.000Z",
+        "modelSetupComplete": None,
     })
 
     assert project.editor_task_type == expected_editor_task_type
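
To illustrate the helper's contract end to end, a small standalone sketch (inputs taken from the parametrized unit test above):

    from labelbox.exceptions import error_message_for_unparsed_graphql_error

    raw = ("Unparsed errors on query execution: "
           "[{'message': 'Cannot create model config for project "
           "because model setup is complete'}]")

    # The regex pulls the quoted 'message' payload out of the raw error string...
    assert error_message_for_unparsed_graphql_error(raw) == (
        "Cannot create model config for project because model setup is complete")

    # ...and any string without a 'message' entry falls back to "Unknown error".
    assert error_message_for_unparsed_graphql_error("blah blah blah") == "Unknown error"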