diff --git a/libs/labelbox/src/labelbox/client.py b/libs/labelbox/src/labelbox/client.py index 86c2f86e2..1b0ea866f 100644 --- a/libs/labelbox/src/labelbox/client.py +++ b/libs/labelbox/src/labelbox/client.py @@ -9,7 +9,7 @@ import urllib.parse from collections import defaultdict from datetime import datetime, timezone -from typing import Any, List, Dict, Union, Optional, overload +from typing import Any, List, Dict, Union, Optional, overload, Callable import requests import requests.exceptions @@ -138,15 +138,19 @@ def _default_headers(self): @retry.Retry(predicate=retry.if_exception_type( labelbox.exceptions.InternalServerError, labelbox.exceptions.TimeoutError)) - def execute(self, - query=None, - params=None, - data=None, - files=None, - timeout=60.0, - experimental=False, - error_log_key="message", - raise_return_resource_not_found=False): + def execute( + self, + query=None, + params=None, + data=None, + files=None, + timeout=60.0, + experimental=False, + error_log_key="message", + raise_return_resource_not_found=False, + error_handlers: Optional[Dict[str, Callable[[requests.models.Response], + None]]] = None + ) -> Dict[str, Any]: """ Sends a request to the server for the execution of the given query. @@ -160,6 +164,13 @@ def execute(self, files (dict): file arguments for request timeout (float): Max allowed time for query execution, in seconds. + raise_return_resource_not_found: By default the client relies on the caller to raise the correct exception when a resource is not found. + If this is set to True, the client will raise a ResourceNotFoundError exception automatically. + This simplifies processing. + We recommend to use it only of api returns a clear and well-formed error when a resource not found for a given query. + error_handlers (dict): A dictionary mapping graphql error code to handler functions. + Allows a caller to handle specific errors reporting in a custom way or produce more user-friendly readable messages. + Returns: dict, parsed JSON response. Raises: @@ -323,7 +334,12 @@ def get_error_status_code(error: dict) -> int: # TODO: fix this in the server API internal_server_error = check_errors(["INTERNAL_SERVER_ERROR"], "extensions", "code") + error_code = "INTERNAL_SERVER_ERROR" + if internal_server_error is not None: + if error_handlers and error_code in error_handlers: + handler = error_handlers[error_code] + handler(response) message = internal_server_error.get("message") error_status_code = get_error_status_code(internal_server_error) if error_status_code == 400: diff --git a/libs/labelbox/src/labelbox/exceptions.py b/libs/labelbox/src/labelbox/exceptions.py index 048ca0757..4abde6526 100644 --- a/libs/labelbox/src/labelbox/exceptions.py +++ b/libs/labelbox/src/labelbox/exceptions.py @@ -16,7 +16,10 @@ def __init__(self, message, cause=None): self.cause = cause def __str__(self): - return self.message + str(self.args) + exception_message = self.message + if self.cause is not None: + exception_message += " (caused by: %s)" % self.cause + return exception_message class AuthenticationError(LabelboxError): diff --git a/libs/labelbox/src/labelbox/schema/labeling_service.py b/libs/labelbox/src/labelbox/schema/labeling_service.py index eed31202c..cbc65232c 100644 --- a/libs/labelbox/src/labelbox/schema/labeling_service.py +++ b/libs/labelbox/src/labelbox/schema/labeling_service.py @@ -1,9 +1,10 @@ from datetime import datetime from enum import Enum +import json from typing import Any from typing_extensions import Annotated -from labelbox.exceptions import ResourceNotFoundError +from labelbox.exceptions import LabelboxError, ResourceNotFoundError from labelbox.pydantic_compat import BaseModel, Field from labelbox.utils import _CamelCaseMixin @@ -87,12 +88,29 @@ def request(self) -> 'LabelingService': } """ result = self.client.execute(query_str, {"projectId": self.project_id}, - raise_return_resource_not_found=True) + raise_return_resource_not_found=True, + error_handlers={ + "INTERNAL_SERVER_ERROR": + self._raise_readable_errors + }) success = result["validateAndRequestProjectBoostWorkforce"]["success"] if not success: raise Exception("Failed to start labeling service") return LabelingService.get(self.client, self.project_id) + def _raise_readable_errors(self, response): + errors = response.json().get('errors', []) + if errors: + message = errors[0].get( + 'message', json.dumps([{ + "errorMessage": "Unknown error" + }])) + errors = json.loads(message) + error_messages = [error['errorMessage'] for error in errors] + else: + error_messages = ["Uknown error"] + raise LabelboxError(". ".join(error_messages)) + @classmethod def getOrCreate(cls, client, project_id: Cuid) -> 'LabelingService': """ diff --git a/libs/labelbox/tests/integration/test_labeling_service.py b/libs/labelbox/tests/integration/test_labeling_service.py index be0b8a6ee..611a20f98 100644 --- a/libs/labelbox/tests/integration/test_labeling_service.py +++ b/libs/labelbox/tests/integration/test_labeling_service.py @@ -42,11 +42,8 @@ def test_request_labeling_service_moe_project( project.upsert_instructions('tests/integration/media/sample_pdf.pdf') labeling_service = project.get_labeling_service() - with pytest.raises( - LabelboxError, - match= - '[{"errorType":"PROJECT_MODEL_CONFIG","errorMessage":"Project model config is not completed"}]' - ): + with pytest.raises(LabelboxError, + match='Project model config is not completed'): labeling_service.request() project.add_model_config(model_config.uid) project.set_project_model_setup_complete() @@ -64,5 +61,8 @@ def test_request_labeling_service_incomplete_requirements(ontology, project): ): # No labeling service by default labeling_service.request() project.connect_ontology(ontology) - with pytest.raises(LabelboxError): + with pytest.raises( + LabelboxError, + match= + "['Data is missing', 'Ontology instructions are not completed']"): labeling_service.request()