diff --git a/docs/labelbox/exceptions.rst b/docs/labelbox/exceptions.rst index 3082bc081..96ea0f6d5 100644 --- a/docs/labelbox/exceptions.rst +++ b/docs/labelbox/exceptions.rst @@ -1,6 +1,6 @@ Exceptions =============================================================================================== -.. automodule:: labelbox.exceptions +.. automodule:: lbox.exceptions :members: :show-inheritance: \ No newline at end of file diff --git a/docs/labelbox/index.rst b/docs/labelbox/index.rst index 35118f56f..fcfa2f6b5 100644 --- a/docs/labelbox/index.rst +++ b/docs/labelbox/index.rst @@ -41,6 +41,7 @@ Labelbox Python SDK Documentation project project-model-config quality-mode + request-client resource-tag review search-filters diff --git a/docs/labelbox/request-client.rst b/docs/labelbox/request-client.rst new file mode 100644 index 000000000..fcfea7f97 --- /dev/null +++ b/docs/labelbox/request-client.rst @@ -0,0 +1,6 @@ +Request Client +=============================================================================================== + +.. automodule:: lbox.request_client + :members: + :show-inheritance: \ No newline at end of file diff --git a/libs/labelbox/mypy.ini b/libs/labelbox/mypy.ini index 48135600e..5129fbe0f 100644 --- a/libs/labelbox/mypy.ini +++ b/libs/labelbox/mypy.ini @@ -7,4 +7,7 @@ ignore_missing_imports = True ignore_errors = True [mypy-labelbox] -ignore_errors = True \ No newline at end of file +ignore_errors = True + +[mypy-lbox.exceptions] +ignore_missing_imports = True diff --git a/libs/labelbox/pyproject.toml b/libs/labelbox/pyproject.toml index f4c24af59..0663082af 100644 --- a/libs/labelbox/pyproject.toml +++ b/libs/labelbox/pyproject.toml @@ -12,6 +12,7 @@ dependencies = [ "tqdm>=4.66.2", "geojson>=3.1.0", "mypy==1.10.1", + "lbox-clients==1.0.0", ] readme = "README.md" requires-python = ">=3.8" @@ -90,6 +91,7 @@ unit = "pytest tests/unit" integration = { cmd = "pytest tests/integration" } data = { cmd = "pytest tests/data" } rye-fmt-check = "rye fmt --check" +MYPYPATH="../lbox-clients/src/" mypy-lint = "mypy src --pretty --show-error-codes --non-interactive --install-types" lint = { chain = ["mypy-lint", "rye-fmt-check"] } test = { chain = ["lint", "unit", "integration"] } diff --git a/libs/labelbox/src/labelbox/adv_client.py b/libs/labelbox/src/labelbox/adv_client.py index 626ac0279..20766a9a4 100644 --- a/libs/labelbox/src/labelbox/adv_client.py +++ b/libs/labelbox/src/labelbox/adv_client.py @@ -1,12 +1,12 @@ import io import json import logging -from typing import Dict, Any, Optional, List, Callable +from typing import Any, Callable, Dict, List, Optional from urllib.parse import urlparse -from labelbox.exceptions import LabelboxError import requests -from requests import Session, Response +from lbox.exceptions import LabelboxError +from requests import Response, Session logger = logging.getLogger(__name__) diff --git a/libs/labelbox/src/labelbox/client.py b/libs/labelbox/src/labelbox/client.py index b0b5a1407..f0d58849a 100644 --- a/libs/labelbox/src/labelbox/client.py +++ b/libs/labelbox/src/labelbox/client.py @@ -4,20 +4,18 @@ import mimetypes import os import random -import sys import time import urllib.parse from collections import defaultdict -from datetime import datetime, timezone -from typing import Any, List, Dict, Union, Optional, overload, Callable from types import MappingProxyType +from typing import Any, Dict, List, Optional, Union, overload -from labelbox.schema.search_filters import SearchFilter +import lbox.exceptions import requests import requests.exceptions from google.api_core import retry +from lbox.request_client import RequestClient -import labelbox.exceptions from labelbox import __version__ as SDK_VERSION from labelbox import utils from labelbox.adv_client import AdvClient @@ -26,20 +24,18 @@ from labelbox.orm.model import Entity, Field from labelbox.pagination import PaginatedCollection from labelbox.schema import role -from labelbox.schema.conflict_resolution_strategy import ( - ConflictResolutionStrategy, -) -from labelbox.schema.data_row import DataRow from labelbox.schema.catalog import Catalog +from labelbox.schema.data_row import DataRow from labelbox.schema.data_row_metadata import DataRowMetadataOntology from labelbox.schema.dataset import Dataset from labelbox.schema.embedding import Embedding from labelbox.schema.enums import CollectionJobStatus from labelbox.schema.foundry.foundry_client import FoundryClient from labelbox.schema.iam_integration import IAMIntegration -from labelbox.schema.identifiables import DataRowIds -from labelbox.schema.identifiables import GlobalKeys +from labelbox.schema.identifiables import DataRowIds, GlobalKeys +from labelbox.schema.label_score import LabelScore from labelbox.schema.labeling_frontend import LabelingFrontend +from labelbox.schema.labeling_service_dashboard import LabelingServiceDashboard from labelbox.schema.media_type import ( MediaType, get_media_type_validation_error, @@ -47,57 +43,48 @@ from labelbox.schema.model import Model from labelbox.schema.model_config import ModelConfig from labelbox.schema.model_run import ModelRun -from labelbox.schema.ontology import Ontology, DeleteFeatureFromOntologyResult from labelbox.schema.ontology import ( - Tool, Classification, + DeleteFeatureFromOntologyResult, FeatureSchema, + Ontology, PromptResponseClassification, + Tool, +) +from labelbox.schema.ontology_kind import ( + EditorTaskType, + EditorTaskTypeMapper, + OntologyKind, ) from labelbox.schema.organization import Organization from labelbox.schema.project import Project from labelbox.schema.quality_mode import ( - QualityMode, BENCHMARK_AUTO_AUDIT_NUMBER_OF_LABELS, BENCHMARK_AUTO_AUDIT_PERCENTAGE, CONSENSUS_AUTO_AUDIT_NUMBER_OF_LABELS, CONSENSUS_AUTO_AUDIT_PERCENTAGE, + QualityMode, ) from labelbox.schema.queue_mode import QueueMode from labelbox.schema.role import Role +from labelbox.schema.search_filters import SearchFilter from labelbox.schema.send_to_annotate_params import ( SendToAnnotateFromCatalogParams, + build_annotations_input, build_destination_task_queue_input, build_predictions_input, - build_annotations_input, ) from labelbox.schema.slice import CatalogSlice, ModelSlice -from labelbox.schema.task import Task, DataUpsertTask +from labelbox.schema.task import DataUpsertTask, Task from labelbox.schema.user import User -from labelbox.schema.label_score import LabelScore -from labelbox.schema.ontology_kind import ( - OntologyKind, - EditorTaskTypeMapper, - EditorTaskType, -) -from labelbox.schema.labeling_service_dashboard import LabelingServiceDashboard logger = logging.getLogger(__name__) -_LABELBOX_API_KEY = "LABELBOX_API_KEY" - - -def python_version_info(): - version_info = sys.version_info - - return f"{version_info.major}.{version_info.minor}.{version_info.micro}-{version_info.releaselevel}" - class Client: """A Labelbox client. - Contains info necessary for connecting to a Labelbox server (URL, - authentication key). Provides functions for querying and creating + Provides functions for querying and creating top-level data objects (Projects, Datasets). """ @@ -123,57 +110,45 @@ def __init__( enable_experimental (bool): Indicates whether or not to use experimental features app_url (str) : host url for all links to the web app Raises: - labelbox.exceptions.AuthenticationError: If no `api_key` + lbox.exceptions.AuthenticationError: If no `api_key` is provided as an argument or via the environment variable. """ - if api_key is None: - if _LABELBOX_API_KEY not in os.environ: - raise labelbox.exceptions.AuthenticationError( - "Labelbox API key not provided" - ) - api_key = os.environ[_LABELBOX_API_KEY] - self.api_key = api_key - - self.enable_experimental = enable_experimental - if enable_experimental: - logger.info("Experimental features have been enabled") - - logger.info("Initializing Labelbox client at '%s'", endpoint) - self.app_url = app_url - self.endpoint = endpoint - self.rest_endpoint = rest_endpoint self._data_row_metadata_ontology = None + self._request_client = RequestClient( + sdk_version=SDK_VERSION, + api_key=api_key, + endpoint=endpoint, + enable_experimental=enable_experimental, + app_url=app_url, + rest_endpoint=rest_endpoint, + ) self._adv_client = AdvClient.factory(rest_endpoint, api_key) - self._connection: requests.Session = self._init_connection() - def _init_connection(self) -> requests.Session: - connection = ( - requests.Session() - ) # using default connection pool size of 10 - connection.headers.update(self._default_headers()) + @property + def headers(self) -> MappingProxyType: + return self._request_client.headers - return connection + @property + def connection(self) -> requests.Session: + return self._request_client._connection @property - def headers(self) -> MappingProxyType: - return self._connection.headers - - def _default_headers(self): - return { - "Authorization": "Bearer %s" % self.api_key, - "Accept": "application/json", - "Content-Type": "application/json", - "X-User-Agent": f"python-sdk {SDK_VERSION}", - "X-Python-Version": f"{python_version_info()}", - } + def endpoint(self) -> str: + return self._request_client.endpoint + + @property + def rest_endpoint(self) -> str: + return self._request_client.rest_endpoint + + @property + def enable_experimental(self) -> bool: + return self._request_client.enable_experimental + + @property + def app_url(self) -> str: + return self._request_client.app_url - @retry.Retry( - predicate=retry.if_exception_type( - labelbox.exceptions.InternalServerError, - labelbox.exceptions.TimeoutError, - ) - ) def execute( self, query=None, @@ -184,258 +159,28 @@ def execute( experimental=False, error_log_key="message", raise_return_resource_not_found=False, - ): - """Sends a request to the server for the execution of the - given query. - - Checks the response for errors and wraps errors - in appropriate `labelbox.exceptions.LabelboxError` subtypes. + ) -> Dict[str, Any]: + """Executes a GraphQL query. Args: query (str): The query to execute. - params (dict): Query parameters referenced within the query. - data (str): json string containing the query to execute - files (dict): file arguments for request - timeout (float): Max allowed time for query execution, - in seconds. - Returns: - dict, parsed JSON response. - Raises: - labelbox.exceptions.AuthenticationError: If authentication - failed. - labelbox.exceptions.InvalidQueryError: If `query` is not - syntactically or semantically valid (checked server-side). - labelbox.exceptions.ApiLimitError: If the server API limit was - exceeded. See "How to import data" in the online documentation - to see API limits. - labelbox.exceptions.TimeoutError: If response was not received - in `timeout` seconds. - labelbox.exceptions.NetworkError: If an unknown error occurred - most likely due to connection issues. - labelbox.exceptions.LabelboxError: If an unknown error of any - kind occurred. - ValueError: If query and data are both None. - """ - logger.debug("Query: %s, params: %r, data %r", query, params, data) - - # Convert datetimes to UTC strings. - def convert_value(value): - if isinstance(value, datetime): - value = value.astimezone(timezone.utc) - value = value.strftime("%Y-%m-%dT%H:%M:%SZ") - return value - - if query is not None: - if params is not None: - params = { - key: convert_value(value) for key, value in params.items() - } - data = json.dumps({"query": query, "variables": params}).encode( - "utf-8" - ) - elif data is None: - raise ValueError("query and data cannot both be none") - - endpoint = ( - self.endpoint - if not experimental - else self.endpoint.replace("/graphql", "/_gql") - ) - - try: - headers = self._connection.headers.copy() - if files: - del headers["Content-Type"] - del headers["Accept"] - request = requests.Request( - "POST", - endpoint, - headers=headers, - data=data, - files=files if files else None, - ) - - prepped: requests.PreparedRequest = request.prepare() + variables (dict): Variables to pass to the query. + raise_return_resource_not_found (bool): If True, raise a + ResourceNotFoundError if the query returns None. - response = self._connection.send(prepped, timeout=timeout) - logger.debug("Response: %s", response.text) - except requests.exceptions.Timeout as e: - raise labelbox.exceptions.TimeoutError(str(e)) - except requests.exceptions.RequestException as e: - logger.error("Unknown error: %s", str(e)) - raise labelbox.exceptions.NetworkError(e) - except Exception as e: - raise labelbox.exceptions.LabelboxError( - "Unknown error during Client.query(): " + str(e), e - ) - - if ( - 200 <= response.status_code < 300 - or response.status_code < 500 - or response.status_code >= 600 - ): - try: - r_json = response.json() - except Exception: - raise labelbox.exceptions.LabelboxError( - "Failed to parse response as JSON: %s" % response.text - ) - else: - if ( - "upstream connect error or disconnect/reset before headers" - in response.text - ): - raise labelbox.exceptions.InternalServerError( - "Connection reset" - ) - elif response.status_code == 502: - error_502 = "502 Bad Gateway" - raise labelbox.exceptions.InternalServerError(error_502) - elif 500 <= response.status_code < 600: - error_500 = f"Internal server http error {response.status_code}" - raise labelbox.exceptions.InternalServerError(error_500) - - errors = r_json.get("errors", []) - - def check_errors(keywords, *path): - """Helper that looks for any of the given `keywords` in any of - current errors on paths (like error[path][component][to][keyword]). - """ - for error in errors: - obj = error - for path_elem in path: - obj = obj.get(path_elem, {}) - if obj in keywords: - return error - return None - - def get_error_status_code(error: dict) -> int: - try: - return int(error["extensions"].get("exception").get("status")) - except: - return 500 - - if ( - check_errors(["AUTHENTICATION_ERROR"], "extensions", "code") - is not None - ): - raise labelbox.exceptions.AuthenticationError("Invalid API key") - - authorization_error = check_errors( - ["AUTHORIZATION_ERROR"], "extensions", "code" - ) - if authorization_error is not None: - raise labelbox.exceptions.AuthorizationError( - authorization_error["message"] - ) - - validation_error = check_errors( - ["GRAPHQL_VALIDATION_FAILED"], "extensions", "code" - ) - - if validation_error is not None: - message = validation_error["message"] - if message == "Query complexity limit exceeded": - raise labelbox.exceptions.ValidationFailedError(message) - else: - raise labelbox.exceptions.InvalidQueryError(message) - - graphql_error = check_errors( - ["GRAPHQL_PARSE_FAILED"], "extensions", "code" - ) - if graphql_error is not None: - raise labelbox.exceptions.InvalidQueryError( - graphql_error["message"] - ) - - # Check if API limit was exceeded - response_msg = r_json.get("message", "") - - if response_msg.startswith("You have exceeded"): - raise labelbox.exceptions.ApiLimitError(response_msg) - - resource_not_found_error = check_errors( - ["RESOURCE_NOT_FOUND"], "extensions", "code" - ) - if resource_not_found_error is not None: - if raise_return_resource_not_found: - raise labelbox.exceptions.ResourceNotFoundError( - message=resource_not_found_error["message"] - ) - else: - # Return None and let the caller methods raise an exception - # as they already know which resource type and ID was requested - return None - - resource_conflict_error = check_errors( - ["RESOURCE_CONFLICT"], "extensions", "code" - ) - if resource_conflict_error is not None: - raise labelbox.exceptions.ResourceConflict( - resource_conflict_error["message"] - ) - - malformed_request_error = check_errors( - ["MALFORMED_REQUEST"], "extensions", "code" - ) - if malformed_request_error is not None: - raise labelbox.exceptions.MalformedQueryException( - malformed_request_error[error_log_key] - ) - - # A lot of different error situations are now labeled serverside - # as INTERNAL_SERVER_ERROR, when they are actually client errors. - # TODO: fix this in the server API - internal_server_error = check_errors( - ["INTERNAL_SERVER_ERROR"], "extensions", "code" - ) - if internal_server_error is not None: - message = internal_server_error.get("message") - error_status_code = get_error_status_code(internal_server_error) - if error_status_code == 400: - raise labelbox.exceptions.InvalidQueryError(message) - elif error_status_code == 422: - raise labelbox.exceptions.UnprocessableEntityError(message) - elif error_status_code == 426: - raise labelbox.exceptions.OperationNotAllowedException(message) - elif error_status_code == 500: - raise labelbox.exceptions.LabelboxError(message) - else: - raise labelbox.exceptions.InternalServerError(message) - - not_allowed_error = check_errors( - ["OPERATION_NOT_ALLOWED"], "extensions", "code" + Returns: + dict: The response from the server. + """ + return self._request_client.execute( + query, + params, + data=data, + files=files, + timeout=timeout, + experimental=experimental, + error_log_key=error_log_key, + raise_return_resource_not_found=raise_return_resource_not_found, ) - if not_allowed_error is not None: - message = not_allowed_error.get("message") - raise labelbox.exceptions.OperationNotAllowedException(message) - - if len(errors) > 0: - logger.warning("Unparsed errors on query execution: %r", errors) - messages = list( - map( - lambda x: { - "message": x["message"], - "code": x["extensions"]["code"], - }, - errors, - ) - ) - raise labelbox.exceptions.LabelboxError( - "Unknown error: %s" % str(messages) - ) - - # if we do return a proper error code, and didn't catch this above - # reraise - # this mainly catches a 401 for API access disabled for free tier - # TODO: need to unify API errors to handle things more uniformly - # in the SDK - if response.status_code != requests.codes.ok: - message = f"{response.status_code} {response.reason}" - cause = r_json.get("message") - raise labelbox.exceptions.LabelboxError(message, cause) - - return r_json["data"] def upload_file(self, path: str) -> str: """Uploads given path to local file. @@ -447,7 +192,7 @@ def upload_file(self, path: str) -> str: Returns: str, the URL of uploaded data. Raises: - labelbox.exceptions.LabelboxError: If upload failed. + lbox.exceptions.LabelboxError: If upload failed. """ content_type, _ = mimetypes.guess_type(path) filename = os.path.basename(path) @@ -457,9 +202,7 @@ def upload_file(self, path: str) -> str: ) @retry.Retry( - predicate=retry.if_exception_type( - labelbox.exceptions.InternalServerError - ) + predicate=retry.if_exception_type(lbox.exceptions.InternalServerError) ) def upload_data( self, @@ -480,7 +223,7 @@ def upload_data( str, the URL of uploaded data. Raises: - labelbox.exceptions.LabelboxError: If upload failed. + lbox.exceptions.LabelboxError: If upload failed. """ request_data = { @@ -505,7 +248,7 @@ def upload_data( if (filename and content_type) else content } - headers = self._connection.headers.copy() + headers = self.connection.headers.copy() headers.pop("Content-Type", None) request = requests.Request( "POST", @@ -517,20 +260,20 @@ def upload_data( prepped: requests.PreparedRequest = request.prepare() - response = self._connection.send(prepped) + response = self.connection.send(prepped) if response.status_code == 502: error_502 = "502 Bad Gateway" - raise labelbox.exceptions.InternalServerError(error_502) + raise lbox.exceptions.InternalServerError(error_502) elif response.status_code == 503: - raise labelbox.exceptions.InternalServerError(response.text) + raise lbox.exceptions.InternalServerError(response.text) elif response.status_code == 520: - raise labelbox.exceptions.InternalServerError(response.text) + raise lbox.exceptions.InternalServerError(response.text) try: file_data = response.json().get("data", None) except ValueError as e: # response is not valid JSON - raise labelbox.exceptions.LabelboxError( + raise lbox.exceptions.LabelboxError( "Failed to upload, unknown cause", e ) @@ -540,9 +283,9 @@ def upload_data( error_msg = next(iter(errors), {}).get( "message", "Unknown error" ) - except Exception as e: + except Exception: error_msg = "Unknown error" - raise labelbox.exceptions.LabelboxError( + raise lbox.exceptions.LabelboxError( "Failed to upload, message: %s" % error_msg ) @@ -557,7 +300,7 @@ def _get_single(self, db_object_type, uid): Returns: Object of `db_object_type`. Raises: - labelbox.exceptions.ResourceNotFoundError: If there is no object + lbox.exceptions.ResourceNotFoundError: If there is no object of the given type for the given ID. """ query_str, params = query.get_single(db_object_type, uid) @@ -565,9 +308,7 @@ def _get_single(self, db_object_type, uid): res = self.execute(query_str, params) res = res and res.get(utils.camel_case(db_object_type.type_name())) if res is None: - raise labelbox.exceptions.ResourceNotFoundError( - db_object_type, params - ) + raise lbox.exceptions.ResourceNotFoundError(db_object_type, params) else: return db_object_type(self, res) @@ -581,7 +322,7 @@ def get_project(self, project_id) -> Project: Returns: The sought Project. Raises: - labelbox.exceptions.ResourceNotFoundError: If there is no + lbox.exceptions.ResourceNotFoundError: If there is no Project with the given ID. """ return self._get_single(Entity.Project, project_id) @@ -596,7 +337,7 @@ def get_dataset(self, dataset_id) -> Dataset: Returns: The sought Dataset. Raises: - labelbox.exceptions.ResourceNotFoundError: If there is no + lbox.exceptions.ResourceNotFoundError: If there is no Dataset with the given ID. """ return self._get_single(Entity.Dataset, dataset_id) @@ -722,7 +463,7 @@ def _create(self, db_object_type, data, extra_params={}): ) if not res: - raise labelbox.exceptions.LabelboxError( + raise lbox.exceptions.LabelboxError( "Failed to create %s" % db_object_type.type_name() ) res = res["create%s" % db_object_type.type_name()] @@ -780,7 +521,7 @@ def delete_model_config(self, id: str) -> bool: params = {"id": id} result = self.execute(query, params) if not result: - raise labelbox.exceptions.ResourceNotFoundError( + raise lbox.exceptions.ResourceNotFoundError( Entity.ModelConfig, params ) return result["deleteModelConfig"]["success"] @@ -841,8 +582,8 @@ def create_dataset( ) if not validation_result["validateDataset"]["valid"]: - raise labelbox.exceptions.LabelboxError( - f"IAMIntegration was not successfully added to the dataset." + raise lbox.exceptions.LabelboxError( + "IAMIntegration was not successfully added to the dataset." ) except Exception as e: dataset.delete() @@ -1207,7 +948,7 @@ def get_data_row_by_global_key(self, global_key: str) -> DataRow: """ res = self.get_data_row_ids_for_global_keys([global_key]) if res["status"] != "SUCCESS": - raise labelbox.exceptions.ResourceNotFoundError( + raise lbox.exceptions.ResourceNotFoundError( Entity.DataRow, {global_key: global_key} ) data_row_id = res["results"][0] @@ -1235,7 +976,7 @@ def get_model(self, model_id) -> Model: Returns: The sought Model. Raises: - labelbox.exceptions.ResourceNotFoundError: If there is no + lbox.exceptions.ResourceNotFoundError: If there is no Model with the given ID. """ return self._get_single(Entity.Model, model_id) @@ -1478,10 +1219,10 @@ def delete_unused_feature_schema(self, feature_schema_id: str) -> None: + "/feature-schemas/" + urllib.parse.quote(feature_schema_id) ) - response = self._connection.delete(endpoint) + response = self.connection.delete(endpoint) if response.status_code != requests.codes.no_content: - raise labelbox.exceptions.LabelboxError( + raise lbox.exceptions.LabelboxError( "Failed to delete the feature schema, message: " + str(response.json()["message"]) ) @@ -1499,10 +1240,10 @@ def delete_unused_ontology(self, ontology_id: str) -> None: + "/ontologies/" + urllib.parse.quote(ontology_id) ) - response = self._connection.delete(endpoint) + response = self.connection.delete(endpoint) if response.status_code != requests.codes.no_content: - raise labelbox.exceptions.LabelboxError( + raise lbox.exceptions.LabelboxError( "Failed to delete the ontology, message: " + str(response.json()["message"]) ) @@ -1527,12 +1268,12 @@ def update_feature_schema_title( + urllib.parse.quote(feature_schema_id) + "/definition" ) - response = self._connection.patch(endpoint, json={"title": title}) + response = self.connection.patch(endpoint, json={"title": title}) if response.status_code == requests.codes.ok: return self.get_feature_schema(feature_schema_id) else: - raise labelbox.exceptions.LabelboxError( + raise lbox.exceptions.LabelboxError( "Failed to update the feature schema, message: " + str(response.json()["message"]) ) @@ -1561,14 +1302,14 @@ def upsert_feature_schema(self, feature_schema: Dict) -> FeatureSchema: + "/feature-schemas/" + urllib.parse.quote(feature_schema_id) ) - response = self._connection.put( + response = self.connection.put( endpoint, json={"normalized": json.dumps(feature_schema)} ) if response.status_code == requests.codes.ok: return self.get_feature_schema(response.json()["schemaId"]) else: - raise labelbox.exceptions.LabelboxError( + raise lbox.exceptions.LabelboxError( "Failed to upsert the feature schema, message: " + str(response.json()["message"]) ) @@ -1594,9 +1335,9 @@ def insert_feature_schema_into_ontology( + "/feature-schemas/" + urllib.parse.quote(feature_schema_id) ) - response = self._connection.post(endpoint, json={"position": position}) + response = self.connection.post(endpoint, json={"position": position}) if response.status_code != requests.codes.created: - raise labelbox.exceptions.LabelboxError( + raise lbox.exceptions.LabelboxError( "Failed to insert the feature schema into the ontology, message: " + str(response.json()["message"]) ) @@ -1616,12 +1357,12 @@ def get_unused_ontologies(self, after: str = None) -> List[str]: """ endpoint = self.rest_endpoint + "/ontologies/unused" - response = self._connection.get(endpoint, json={"after": after}) + response = self.connection.get(endpoint, json={"after": after}) if response.status_code == requests.codes.ok: return response.json() else: - raise labelbox.exceptions.LabelboxError( + raise lbox.exceptions.LabelboxError( "Failed to get unused ontologies, message: " + str(response.json()["message"]) ) @@ -1641,12 +1382,12 @@ def get_unused_feature_schemas(self, after: str = None) -> List[str]: """ endpoint = self.rest_endpoint + "/feature-schemas/unused" - response = self._connection.get(endpoint, json={"after": after}) + response = self.connection.get(endpoint, json={"after": after}) if response.status_code == requests.codes.ok: return response.json() else: - raise labelbox.exceptions.LabelboxError( + raise lbox.exceptions.LabelboxError( "Failed to get unused feature schemas, message: " + str(response.json()["message"]) ) @@ -1942,12 +1683,12 @@ def _format_failed_rows( elif ( res["assignGlobalKeysToDataRowsResult"]["jobStatus"] == "FAILED" ): - raise labelbox.exceptions.LabelboxError( + raise lbox.exceptions.LabelboxError( "Job assign_global_keys_to_data_rows failed." ) current_time = time.time() if current_time - start_time > timeout_seconds: - raise labelbox.exceptions.TimeoutError( + raise lbox.exceptions.TimeoutError( "Timed out waiting for assign_global_keys_to_data_rows job to complete." ) time.sleep(sleep_time) @@ -2051,12 +1792,12 @@ def _format_failed_rows( return {"status": status, "results": results, "errors": errors} elif res["dataRowsForGlobalKeysResult"]["jobStatus"] == "FAILED": - raise labelbox.exceptions.LabelboxError( + raise lbox.exceptions.LabelboxError( "Job dataRowsForGlobalKeys failed." ) current_time = time.time() if current_time - start_time > timeout_seconds: - raise labelbox.exceptions.TimeoutError( + raise lbox.exceptions.TimeoutError( "Timed out waiting for get_data_rows_for_global_keys job to complete." ) time.sleep(sleep_time) @@ -2155,12 +1896,12 @@ def _format_failed_rows( return {"status": status, "results": results, "errors": errors} elif res["clearGlobalKeysResult"]["jobStatus"] == "FAILED": - raise labelbox.exceptions.LabelboxError( + raise lbox.exceptions.LabelboxError( "Job clearGlobalKeys failed." ) current_time = time.time() if current_time - start_time > timeout_seconds: - raise labelbox.exceptions.TimeoutError( + raise lbox.exceptions.TimeoutError( "Timed out waiting for clear_global_keys job to complete." ) time.sleep(sleep_time) @@ -2209,7 +1950,7 @@ def is_feature_schema_archived( + "/ontologies/" + urllib.parse.quote(ontology_id) ) - response = self._connection.get(ontology_endpoint) + response = self.connection.get(ontology_endpoint) if response.status_code == requests.codes.ok: feature_schema_nodes = response.json()["featureSchemaNodes"] @@ -2225,16 +1966,14 @@ def is_feature_schema_archived( if filtered_feature_schema_nodes: return bool(filtered_feature_schema_nodes[0]["archived"]) else: - raise labelbox.exceptions.LabelboxError( + raise lbox.exceptions.LabelboxError( "The specified feature schema was not in the ontology." ) elif response.status_code == 404: - raise labelbox.exceptions.ResourceNotFoundError( - Ontology, ontology_id - ) + raise lbox.exceptions.ResourceNotFoundError(Ontology, ontology_id) else: - raise labelbox.exceptions.LabelboxError( + raise lbox.exceptions.LabelboxError( "Failed to get the feature schema archived status." ) @@ -2261,9 +2000,7 @@ def get_model_slice(self, slice_id) -> ModelSlice: """ res = self.execute(query_str, {"id": slice_id}) if res is None or res["getSavedQuery"] is None: - raise labelbox.exceptions.ResourceNotFoundError( - ModelSlice, slice_id - ) + raise lbox.exceptions.ResourceNotFoundError(ModelSlice, slice_id) return Entity.ModelSlice(self, res["getSavedQuery"]) @@ -2293,7 +2030,7 @@ def delete_feature_schema_from_ontology( + "/feature-schemas/" + urllib.parse.quote(feature_schema_id) ) - response = self._connection.delete(ontology_endpoint) + response = self.connection.delete(ontology_endpoint) if response.status_code == requests.codes.ok: response_json = response.json() @@ -2310,7 +2047,7 @@ def delete_feature_schema_from_ontology( result.deleted = bool(response_json["deleted"]) return result else: - raise labelbox.exceptions.LabelboxError( + raise lbox.exceptions.LabelboxError( "Failed to remove feature schema from ontology, message: " + str(response.json()["message"]) ) @@ -2335,14 +2072,14 @@ def unarchive_feature_schema_node( + urllib.parse.quote(root_feature_schema_id) + "/unarchive" ) - response = self._connection.patch(ontology_endpoint) + response = self.connection.patch(ontology_endpoint) if response.status_code == requests.codes.ok: if not bool(response.json()["unarchived"]): - raise labelbox.exceptions.LabelboxError( + raise lbox.exceptions.LabelboxError( "Failed unarchive the feature schema." ) else: - raise labelbox.exceptions.LabelboxError( + raise lbox.exceptions.LabelboxError( "Failed unarchive the feature schema node, message: ", response.text, ) @@ -2571,9 +2308,7 @@ def get_embedding_by_name(self, name: str) -> Embedding: for e in embeddings: if e.name == name: return e - raise labelbox.exceptions.ResourceNotFoundError( - Embedding, dict(name=name) - ) + raise lbox.exceptions.ResourceNotFoundError(Embedding, dict(name=name)) def upsert_label_feedback( self, label_id: str, feedback: str, scores: Dict[str, float] @@ -2620,8 +2355,7 @@ def upsert_label_feedback( scores_raw = res["upsertAutoQaLabelFeedback"]["scores"] return [ - labelbox.LabelScore(name=x["name"], score=x["score"]) - for x in scores_raw + LabelScore(name=x["name"], score=x["score"]) for x in scores_raw ] def get_labeling_service_dashboards( @@ -2697,7 +2431,7 @@ def get_task_by_id(self, task_id: str) -> Union[Task, DataUpsertTask]: result = self.execute(query, {"userId": user.uid, "taskId": task_id}) data = result.get("user", {}).get("createdTasks", []) if not data: - raise labelbox.exceptions.ResourceNotFoundError( + raise lbox.exceptions.ResourceNotFoundError( message=f"The task {task_id} does not exist." ) task_data = data[0] diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/raster.py b/libs/labelbox/src/labelbox/data/annotation_types/data/raster.py index ba4c6485f..2debf50e5 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/data/raster.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/data/raster.py @@ -1,18 +1,18 @@ from abc import ABC from io import BytesIO from typing import Callable, Optional, Union -from typing_extensions import Literal -from PIL import Image +import numpy as np +import requests from google.api_core import retry +from lbox.exceptions import InternalServerError +from PIL import Image +from pydantic import BaseModel, ConfigDict, model_validator from requests.exceptions import ConnectTimeout -import requests -import numpy as np +from typing_extensions import Literal -from pydantic import BaseModel, model_validator, ConfigDict -from labelbox.exceptions import InternalServerError -from .base_data import BaseData from ..types import TypedArray +from .base_data import BaseData class RasterData(BaseModel, ABC): diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/text.py b/libs/labelbox/src/labelbox/data/annotation_types/data/text.py index fe4c222d3..6f7a4ff6f 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/data/text.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/data/text.py @@ -1,13 +1,14 @@ from typing import Callable, Optional import requests -from requests.exceptions import ConnectTimeout from google.api_core import retry - +from lbox.exceptions import InternalServerError from pydantic import ConfigDict, model_validator -from labelbox.exceptions import InternalServerError +from requests.exceptions import ConnectTimeout + from labelbox.typing_imports import Literal from labelbox.utils import _NoCoercionMixin + from .base_data import BaseData diff --git a/libs/labelbox/src/labelbox/data/mixins.py b/libs/labelbox/src/labelbox/data/mixins.py index 4440c8a72..fd5ecaa42 100644 --- a/libs/labelbox/src/labelbox/data/mixins.py +++ b/libs/labelbox/src/labelbox/data/mixins.py @@ -1,13 +1,10 @@ -from typing import Optional, List +from typing import List, Optional -from pydantic import BaseModel, field_validator, model_serializer - -from labelbox.exceptions import ( +from lbox.exceptions import ( ConfidenceNotSupportedException, CustomMetricsNotSupportedException, ) - -from warnings import warn +from pydantic import BaseModel, field_validator class ConfidenceMixin(BaseModel): diff --git a/libs/labelbox/src/labelbox/orm/db_object.py b/libs/labelbox/src/labelbox/orm/db_object.py index b210a8a5b..d08aaa719 100644 --- a/libs/labelbox/src/labelbox/orm/db_object.py +++ b/libs/labelbox/src/labelbox/orm/db_object.py @@ -1,17 +1,17 @@ -from dataclasses import dataclass +import json +import logging from datetime import datetime, timezone from functools import wraps -import logging -import json -from labelbox import utils -from labelbox.exceptions import ( - InvalidQueryError, +from lbox.exceptions import ( InvalidAttributeError, + InvalidQueryError, OperationNotSupportedException, ) + +from labelbox import utils from labelbox.orm import query -from labelbox.orm.model import Field, Relationship, Entity +from labelbox.orm.model import Entity, Field, Relationship from labelbox.pagination import PaginatedCollection logger = logging.getLogger(__name__) diff --git a/libs/labelbox/src/labelbox/orm/model.py b/libs/labelbox/src/labelbox/orm/model.py index 1f3ee1d86..535ab0f7d 100644 --- a/libs/labelbox/src/labelbox/orm/model.py +++ b/libs/labelbox/src/labelbox/orm/model.py @@ -1,10 +1,11 @@ from dataclasses import dataclass from enum import Enum, auto -from typing import Dict, List, Union, Any, Type, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, List, Type, Union + +from lbox.exceptions import InvalidAttributeError import labelbox from labelbox import utils -from labelbox.exceptions import InvalidAttributeError from labelbox.orm.comparison import Comparison """ Defines Field, Relationship and Entity. These classes are building diff --git a/libs/labelbox/src/labelbox/orm/query.py b/libs/labelbox/src/labelbox/orm/query.py index 8fa9fea00..8c019fb5e 100644 --- a/libs/labelbox/src/labelbox/orm/query.py +++ b/libs/labelbox/src/labelbox/orm/query.py @@ -1,14 +1,15 @@ from itertools import chain from typing import Any, Dict -from labelbox import utils -from labelbox.exceptions import ( - InvalidQueryError, +from lbox.exceptions import ( InvalidAttributeError, + InvalidQueryError, MalformedQueryException, ) -from labelbox.orm.comparison import LogicalExpression, Comparison -from labelbox.orm.model import Field, Relationship, Entity + +from labelbox import utils +from labelbox.orm.comparison import Comparison, LogicalExpression +from labelbox.orm.model import Entity, Field, Relationship """ Common query creation functionality. """ diff --git a/libs/labelbox/src/labelbox/schema/annotation_import.py b/libs/labelbox/src/labelbox/schema/annotation_import.py index df7f272a3..497ac899d 100644 --- a/libs/labelbox/src/labelbox/schema/annotation_import.py +++ b/libs/labelbox/src/labelbox/schema/annotation_import.py @@ -3,33 +3,34 @@ import logging import os import time +from collections import defaultdict from typing import ( + TYPE_CHECKING, Any, BinaryIO, Dict, List, Optional, Union, - TYPE_CHECKING, cast, ) -from collections import defaultdict -from google.api_core import retry -from labelbox import parser import requests +from google.api_core import retry +from lbox.exceptions import ApiLimitError, NetworkError, ResourceNotFoundError from tqdm import tqdm # type: ignore import labelbox +from labelbox import parser from labelbox.orm import query from labelbox.orm.db_object import DbObject from labelbox.orm.model import Field, Relationship -from labelbox.utils import is_exactly_one_set from labelbox.schema.confidence_presence_checker import ( LabelsConfidencePresenceChecker, ) from labelbox.schema.enums import AnnotationImportState from labelbox.schema.serialization import serialize_labels +from labelbox.utils import is_exactly_one_set if TYPE_CHECKING: from labelbox.types import Label @@ -140,9 +141,9 @@ def wait_until_done( @retry.Retry( predicate=retry.if_exception_type( - labelbox.exceptions.ApiLimitError, - labelbox.exceptions.TimeoutError, - labelbox.exceptions.NetworkError, + ApiLimitError, + TimeoutError, + NetworkError, ) ) def __backoff_refresh(self) -> None: @@ -435,9 +436,7 @@ def from_name( } response = client.execute(query_str, params) if response is None: - raise labelbox.exceptions.ResourceNotFoundError( - MEAPredictionImport, params - ) + raise ResourceNotFoundError(MEAPredictionImport, params) response = response["modelErrorAnalysisPredictionImport"] if as_json: return response @@ -560,9 +559,7 @@ def from_name( } response = client.execute(query_str, params) if response is None: - raise labelbox.exceptions.ResourceNotFoundError( - MALPredictionImport, params - ) + raise ResourceNotFoundError(MALPredictionImport, params) response = response["meaToMalPredictionImport"] if as_json: return response @@ -709,9 +706,7 @@ def from_name( } response = client.execute(query_str, params) if response is None: - raise labelbox.exceptions.ResourceNotFoundError( - MALPredictionImport, params - ) + raise ResourceNotFoundError(MALPredictionImport, params) response = response["modelAssistedLabelingPredictionImport"] if as_json: return response @@ -885,7 +880,7 @@ def from_name( } response = client.execute(query_str, params) if response is None: - raise labelbox.exceptions.ResourceNotFoundError(LabelImport, params) + raise ResourceNotFoundError(LabelImport, params) response = response["labelImport"] if as_json: return response diff --git a/libs/labelbox/src/labelbox/schema/batch.py b/libs/labelbox/src/labelbox/schema/batch.py index 316732a67..99e4c4908 100644 --- a/libs/labelbox/src/labelbox/schema/batch.py +++ b/libs/labelbox/src/labelbox/schema/batch.py @@ -1,15 +1,11 @@ -from typing import Generator, TYPE_CHECKING +import logging +from typing import TYPE_CHECKING + +from lbox.exceptions import ResourceNotFoundError -from labelbox.orm.db_object import DbObject, experimental from labelbox.orm import query +from labelbox.orm.db_object import DbObject from labelbox.orm.model import Entity, Field, Relationship -from labelbox.exceptions import LabelboxError, ResourceNotFoundError -from io import StringIO -from labelbox import parser -import requests -import logging -import time -import warnings if TYPE_CHECKING: from labelbox import Project diff --git a/libs/labelbox/src/labelbox/schema/dataset.py b/libs/labelbox/src/labelbox/schema/dataset.py index 04877c885..6d879e767 100644 --- a/libs/labelbox/src/labelbox/schema/dataset.py +++ b/libs/labelbox/src/labelbox/schema/dataset.py @@ -1,57 +1,44 @@ -from datetime import datetime -from typing import Dict, Generator, List, Optional, Any, Final, Tuple, Union -import os import json import logging -from collections.abc import Iterable -from string import Template -import time +import os import warnings - -from labelbox import parser -from itertools import islice - from concurrent.futures import ThreadPoolExecutor, as_completed -from io import StringIO -import requests +from itertools import islice +from string import Template +from typing import Any, Dict, List, Optional, Tuple, Union -from labelbox.exceptions import ( +from lbox.exceptions import ( InvalidQueryError, LabelboxError, - ResourceNotFoundError, ResourceCreationError, -) + ResourceNotFoundError, +) # type: ignore + +import labelbox.schema.internal.data_row_uploader as data_row_uploader +from labelbox.orm import query from labelbox.orm.comparison import Comparison -from labelbox.orm.db_object import DbObject, Updateable, Deletable, experimental +from labelbox.orm.db_object import DbObject, Deletable, Updateable from labelbox.orm.model import Entity, Field, Relationship -from labelbox.orm import query -from labelbox.exceptions import MalformedQueryException from labelbox.pagination import PaginatedCollection from labelbox.schema.data_row import DataRow -from labelbox.schema.embedding import EmbeddingVector from labelbox.schema.export_filters import DatasetExportFilters, build_filters from labelbox.schema.export_params import ( CatalogExportParams, validate_catalog_export_params, ) from labelbox.schema.export_task import ExportTask -from labelbox.schema.identifiable import UniqueId, GlobalKey -from labelbox.schema.task import Task, DataUpsertTask -from labelbox.schema.user import User from labelbox.schema.iam_integration import IAMIntegration +from labelbox.schema.identifiable import GlobalKey, UniqueId from labelbox.schema.internal.data_row_upsert_item import ( + DataRowCreateItem, DataRowItemBase, DataRowUpsertItem, - DataRowCreateItem, -) -import labelbox.schema.internal.data_row_uploader as data_row_uploader -from labelbox.schema.internal.descriptor_file_creator import ( - DescriptorFileCreator, ) from labelbox.schema.internal.datarow_upload_constants import ( FILE_UPLOAD_THREAD_COUNT, UPSERT_CHUNK_SIZE_BYTES, ) +from labelbox.schema.task import DataUpsertTask, Task logger = logging.getLogger(__name__) @@ -324,7 +311,7 @@ def data_rows_for_external_id( A list of `DataRow` with the given ID. Raises: - labelbox.exceptions.ResourceNotFoundError: If there is no `DataRow` + lbox.exceptions.ResourceNotFoundError: If there is no `DataRow` in this `DataSet` with the given external ID, or if there are multiple `DataRows` for it. """ @@ -350,7 +337,7 @@ def data_row_for_external_id(self, external_id) -> "DataRow": A single `DataRow` with the given ID. Raises: - labelbox.exceptions.ResourceNotFoundError: If there is no `DataRow` + lbox.exceptions.ResourceNotFoundError: If there is no `DataRow` in this `DataSet` with the given external ID, or if there are multiple `DataRows` for it. """ @@ -359,7 +346,7 @@ def data_row_for_external_id(self, external_id) -> "DataRow": ) if len(data_rows) > 1: logger.warning( - f"More than one data_row has the provided external_id : `%s`. Use function data_rows_for_external_id to fetch all", + "More than one data_row has the provided external_id : `%s`. Use function data_rows_for_external_id to fetch all", external_id, ) return data_rows[0] diff --git a/libs/labelbox/src/labelbox/schema/foundry/foundry_client.py b/libs/labelbox/src/labelbox/schema/foundry/foundry_client.py index 914a363c7..315fc831c 100644 --- a/libs/labelbox/src/labelbox/schema/foundry/foundry_client.py +++ b/libs/labelbox/src/labelbox/schema/foundry/foundry_client.py @@ -1,6 +1,8 @@ from typing import Union -from labelbox import exceptions -from labelbox.schema.foundry.app import App, APP_FIELD_NAMES + +from lbox import exceptions # type: ignore + +from labelbox.schema.foundry.app import APP_FIELD_NAMES, App from labelbox.schema.identifiables import DataRowIds, GlobalKeys, IdType from labelbox.schema.task import Task @@ -51,7 +53,7 @@ def _get_app(self, id: str) -> App: try: response = self.client.execute(query_str, params) - except exceptions.InvalidQueryError as e: + except exceptions.InvalidQueryError: raise exceptions.ResourceNotFoundError(App, params) except Exception as e: raise exceptions.LabelboxError(f"Unable to get app with id {id}", e) diff --git a/libs/labelbox/src/labelbox/schema/internal/descriptor_file_creator.py b/libs/labelbox/src/labelbox/schema/internal/descriptor_file_creator.py index ce3ce4b35..a9c1c250c 100644 --- a/libs/labelbox/src/labelbox/schema/internal/descriptor_file_creator.py +++ b/libs/labelbox/src/labelbox/schema/internal/descriptor_file_creator.py @@ -1,25 +1,21 @@ import json import os -import sys from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import TYPE_CHECKING, Generator, Iterable, List -from typing import Iterable, List, Generator +from lbox.exceptions import ( + InvalidAttributeError, + InvalidQueryError, +) # type: ignore -from labelbox.exceptions import InvalidQueryError -from labelbox.exceptions import InvalidAttributeError -from labelbox.exceptions import MalformedQueryException -from labelbox.orm.model import Entity -from labelbox.orm.model import Field +from labelbox.orm.model import Entity, Field from labelbox.schema.embedding import EmbeddingVector -from labelbox.schema.internal.datarow_upload_constants import ( - FILE_UPLOAD_THREAD_COUNT, -) from labelbox.schema.internal.data_row_upsert_item import ( DataRowItemBase, - DataRowUpsertItem, ) - -from typing import TYPE_CHECKING +from labelbox.schema.internal.datarow_upload_constants import ( + FILE_UPLOAD_THREAD_COUNT, +) if TYPE_CHECKING: from labelbox import Client @@ -161,7 +157,7 @@ def check_message_keys(message): ] ) for key in message.keys(): - if not key in accepted_message_keys: + if key not in accepted_message_keys: raise KeyError( f"Invalid {key} key found! Accepted keys in messages list is {accepted_message_keys}" ) diff --git a/libs/labelbox/src/labelbox/schema/labeling_service.py b/libs/labelbox/src/labelbox/schema/labeling_service.py index 0fbd15bd1..571ef8099 100644 --- a/libs/labelbox/src/labelbox/schema/labeling_service.py +++ b/libs/labelbox/src/labelbox/schema/labeling_service.py @@ -1,13 +1,11 @@ from datetime import datetime from typing import Any -from typing_extensions import Annotated -from pydantic import BaseModel, Field +from lbox.exceptions import ResourceNotFoundError -from labelbox.exceptions import ResourceNotFoundError -from labelbox.utils import _CamelCaseMixin from labelbox.schema.labeling_service_dashboard import LabelingServiceDashboard from labelbox.schema.labeling_service_status import LabelingServiceStatus +from labelbox.utils import _CamelCaseMixin from ..annotated_types import Cuid diff --git a/libs/labelbox/src/labelbox/schema/labeling_service_dashboard.py b/libs/labelbox/src/labelbox/schema/labeling_service_dashboard.py index c5e1fa11e..e2c6fa26b 100644 --- a/libs/labelbox/src/labelbox/schema/labeling_service_dashboard.py +++ b/libs/labelbox/src/labelbox/schema/labeling_service_dashboard.py @@ -1,16 +1,17 @@ -from string import Template from datetime import datetime +from string import Template from typing import Any, Dict, List, Optional, Union -from labelbox.exceptions import ResourceNotFoundError +from lbox.exceptions import ResourceNotFoundError +from pydantic import BaseModel, Field, model_validator + from labelbox.pagination import PaginatedCollection -from pydantic import BaseModel, model_validator, Field +from labelbox.schema.labeling_service_status import LabelingServiceStatus +from labelbox.schema.media_type import MediaType from labelbox.schema.search_filters import SearchFilter, build_search_filter -from labelbox.utils import _CamelCaseMixin +from labelbox.utils import _CamelCaseMixin, sentence_case + from .ontology_kind import EditorTaskType -from labelbox.schema.media_type import MediaType -from labelbox.schema.labeling_service_status import LabelingServiceStatus -from labelbox.utils import sentence_case GRAPHQL_QUERY_SELECTIONS = """ id diff --git a/libs/labelbox/src/labelbox/schema/ontology.py b/libs/labelbox/src/labelbox/schema/ontology.py index efe32611b..1726b514c 100644 --- a/libs/labelbox/src/labelbox/schema/ontology.py +++ b/libs/labelbox/src/labelbox/schema/ontology.py @@ -1,17 +1,18 @@ # type: ignore import colorsys +import json +import warnings from dataclasses import dataclass, field from enum import Enum -from typing import Any, Dict, List, Optional, Union, Type +from typing import Any, Dict, List, Optional, Type, Union + +from lbox.exceptions import InconsistentOntologyException +from pydantic import StringConstraints from typing_extensions import Annotated -import warnings -from labelbox.exceptions import InconsistentOntologyException from labelbox.orm.db_object import DbObject from labelbox.orm.model import Field, Relationship -import json -from pydantic import StringConstraints FeatureSchemaId: Type[str] = Annotated[ str, StringConstraints(min_length=25, max_length=25) diff --git a/libs/labelbox/src/labelbox/schema/organization.py b/libs/labelbox/src/labelbox/schema/organization.py index 71e715f11..bd416e997 100644 --- a/libs/labelbox/src/labelbox/schema/organization.py +++ b/libs/labelbox/src/labelbox/schema/organization.py @@ -1,21 +1,21 @@ -import json -from typing import TYPE_CHECKING, List, Optional, Dict +from typing import TYPE_CHECKING, Dict, List, Optional + +from lbox.exceptions import LabelboxError -from labelbox.exceptions import LabelboxError from labelbox import utils -from labelbox.orm.db_object import DbObject, query, Entity +from labelbox.orm.db_object import DbObject, Entity, query from labelbox.orm.model import Field, Relationship from labelbox.schema.invite import InviteLimit from labelbox.schema.resource_tag import ResourceTag if TYPE_CHECKING: from labelbox import ( - Role, - User, - ProjectRole, + IAMIntegration, Invite, InviteLimit, - IAMIntegration, + ProjectRole, + Role, + User, ) diff --git a/libs/labelbox/src/labelbox/schema/project.py b/libs/labelbox/src/labelbox/schema/project.py index f2de4db5e..5746e8011 100644 --- a/libs/labelbox/src/labelbox/schema/project.py +++ b/libs/labelbox/src/labelbox/schema/project.py @@ -1,10 +1,10 @@ import json import logging -from string import Template import time import warnings from collections import namedtuple from datetime import datetime, timezone +from string import Template from typing import ( TYPE_CHECKING, Any, @@ -16,20 +16,15 @@ overload, ) -from labelbox.schema.labeling_service import ( - LabelingService, - LabelingServiceStatus, -) -from labelbox.schema.labeling_service_dashboard import LabelingServiceDashboard - -from labelbox import utils -from labelbox.exceptions import error_message_for_unparsed_graphql_error -from labelbox.exceptions import ( +from lbox.exceptions import ( InvalidQueryError, LabelboxError, ProcessingWaitTimeout, ResourceNotFoundError, -) + error_message_for_unparsed_graphql_error, +) # type: ignore + +from labelbox import utils from labelbox.orm import query from labelbox.orm.db_object import DbObject, Deletable, Updateable, experimental from labelbox.orm.model import Entity, Field, Relationship @@ -46,21 +41,26 @@ from labelbox.schema.id_type import IdType from labelbox.schema.identifiable import DataRowIdentifier, GlobalKey, UniqueId from labelbox.schema.identifiables import DataRowIdentifiers, UniqueIds +from labelbox.schema.labeling_service import ( + LabelingService, + LabelingServiceStatus, +) +from labelbox.schema.labeling_service_dashboard import LabelingServiceDashboard from labelbox.schema.media_type import MediaType from labelbox.schema.model_config import ModelConfig -from labelbox.schema.project_model_config import ProjectModelConfig -from labelbox.schema.queue_mode import QueueMode -from labelbox.schema.resource_tag import ResourceTag -from labelbox.schema.task import Task -from labelbox.schema.task_queue import TaskQueue from labelbox.schema.ontology_kind import ( EditorTaskType, UploadType, ) +from labelbox.schema.project_model_config import ProjectModelConfig from labelbox.schema.project_overview import ( ProjectOverview, ProjectOverviewDetailed, ) +from labelbox.schema.queue_mode import QueueMode +from labelbox.schema.resource_tag import ResourceTag +from labelbox.schema.task import Task +from labelbox.schema.task_queue import TaskQueue if TYPE_CHECKING: pass @@ -773,7 +773,7 @@ def create_batch( Returns: the created batch Raises: - labelbox.exceptions.ValueError if a project is not batch mode, if the project is auto data generation, if the batch exceeds 100k data rows + lbox.exceptions.ValueError if a project is not batch mode, if the project is auto data generation, if the batch exceeds 100k data rows """ # @TODO: make this automatic? if self.queue_mode != QueueMode.Batch: diff --git a/libs/labelbox/src/labelbox/schema/project_model_config.py b/libs/labelbox/src/labelbox/schema/project_model_config.py index 9b6d8a0bb..c8773abf9 100644 --- a/libs/labelbox/src/labelbox/schema/project_model_config.py +++ b/libs/labelbox/src/labelbox/schema/project_model_config.py @@ -1,10 +1,11 @@ -from labelbox.orm.db_object import DbObject -from labelbox.orm.model import Field, Relationship -from labelbox.exceptions import ( +from lbox.exceptions import ( LabelboxError, error_message_for_unparsed_graphql_error, ) +from labelbox.orm.db_object import DbObject +from labelbox.orm.model import Field, Relationship + class ProjectModelConfig(DbObject): """A ProjectModelConfig represents an association between a project and a single model config. diff --git a/libs/labelbox/src/labelbox/schema/task.py b/libs/labelbox/src/labelbox/schema/task.py index 9d7a26e1d..f996ae05d 100644 --- a/libs/labelbox/src/labelbox/schema/task.py +++ b/libs/labelbox/src/labelbox/schema/task.py @@ -1,14 +1,14 @@ import json import logging -import requests import time -from typing import TYPE_CHECKING, Callable, Optional, Dict, Any, List, Union -from labelbox import parser +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union -from labelbox.exceptions import ResourceNotFoundError -from labelbox.orm.db_object import DbObject -from labelbox.orm.model import Field, Relationship, Entity +import requests +from lbox.exceptions import ResourceNotFoundError +from labelbox import parser +from labelbox.orm.db_object import DbObject +from labelbox.orm.model import Entity, Field, Relationship from labelbox.pagination import PaginatedCollection from labelbox.schema.internal.datarow_upload_constants import ( DOWNLOAD_RESULT_PAGE_SIZE, diff --git a/libs/labelbox/src/labelbox/schema/user_group.py b/libs/labelbox/src/labelbox/schema/user_group.py index 9d506bf92..5f1fb6eed 100644 --- a/libs/labelbox/src/labelbox/schema/user_group.py +++ b/libs/labelbox/src/labelbox/schema/user_group.py @@ -1,21 +1,22 @@ -from enum import Enum -from typing import Set, Iterator from collections import defaultdict +from enum import Enum +from typing import Iterator, Set -from labelbox import Client -from labelbox.exceptions import ResourceCreationError -from labelbox.schema.user import User -from labelbox.schema.project import Project -from labelbox.exceptions import ( - UnprocessableEntityError, +from lbox.exceptions import ( MalformedQueryException, + ResourceCreationError, ResourceNotFoundError, + UnprocessableEntityError, ) -from labelbox.schema.queue_mode import QueueMode -from labelbox.schema.ontology_kind import EditorTaskType -from labelbox.schema.media_type import MediaType from pydantic import BaseModel, ConfigDict +from labelbox import Client +from labelbox.schema.media_type import MediaType +from labelbox.schema.ontology_kind import EditorTaskType +from labelbox.schema.project import Project +from labelbox.schema.queue_mode import QueueMode +from labelbox.schema.user import User + class UserGroupColor(Enum): """ diff --git a/libs/labelbox/tests/conftest.py b/libs/labelbox/tests/conftest.py index 6d13a8d83..d25544034 100644 --- a/libs/labelbox/tests/conftest.py +++ b/libs/labelbox/tests/conftest.py @@ -1,35 +1,38 @@ -from datetime import datetime -from random import randint -from string import ascii_letters - import json import os import re -import uuid import time -from labelbox.schema.project import Project -import requests -from labelbox.schema.ontology import Ontology -import pytest -from types import SimpleNamespace -from typing import Type +import uuid +from datetime import datetime from enum import Enum -from typing import Tuple +from random import randint +from string import ascii_letters +from types import SimpleNamespace +from typing import Tuple, Type + +import pytest +import requests -from labelbox import Dataset, DataRow -from labelbox import MediaType +from labelbox import ( + Classification, + Client, + DataRow, + Dataset, + LabelingFrontend, + MediaType, + OntologyBuilder, + Option, + Tool, +) from labelbox.orm import query from labelbox.pagination import PaginatedCollection +from labelbox.schema.annotation_import import LabelImport +from labelbox.schema.enums import AnnotationImportState from labelbox.schema.invite import Invite +from labelbox.schema.ontology import Ontology +from labelbox.schema.project import Project from labelbox.schema.quality_mode import QualityMode from labelbox.schema.queue_mode import QueueMode -from labelbox import Client - -from labelbox import LabelingFrontend -from labelbox import OntologyBuilder, Tool, Option, Classification -from labelbox.schema.annotation_import import LabelImport -from labelbox.schema.enums import AnnotationImportState -from labelbox.exceptions import LabelboxError IMG_URL = "https://picsum.photos/200/300.jpg" MASKABLE_IMG_URL = "https://storage.googleapis.com/labelbox-datasets/image_sample_data/2560px-Kitano_Street_Kobe01s5s4110.jpeg" diff --git a/libs/labelbox/tests/data/annotation_import/test_model.py b/libs/labelbox/tests/data/annotation_import/test_model.py index dcfe9ef2c..56b2e07c4 100644 --- a/libs/labelbox/tests/data/annotation_import/test_model.py +++ b/libs/labelbox/tests/data/annotation_import/test_model.py @@ -1,7 +1,7 @@ import pytest +from lbox.exceptions import ResourceNotFoundError from labelbox import Model -from labelbox.exceptions import ResourceNotFoundError def test_model(client, configured_project, rand_gen): diff --git a/libs/labelbox/tests/data/annotation_types/test_annotation.py b/libs/labelbox/tests/data/annotation_types/test_annotation.py index 8cdeac9ba..01547bd56 100644 --- a/libs/labelbox/tests/data/annotation_types/test_annotation.py +++ b/libs/labelbox/tests/data/annotation_types/test_annotation.py @@ -1,18 +1,20 @@ import pytest +from lbox.exceptions import ConfidenceNotSupportedException +from pydantic import ValidationError from labelbox.data.annotation_types import ( - Text, - Point, - Line, ClassificationAnnotation, + Line, ObjectAnnotation, + Point, + Text, TextEntity, ) -from labelbox.data.annotation_types.video import VideoObjectAnnotation from labelbox.data.annotation_types.geometry.rectangle import Rectangle -from labelbox.data.annotation_types.video import VideoClassificationAnnotation -from labelbox.exceptions import ConfidenceNotSupportedException -from pydantic import ValidationError +from labelbox.data.annotation_types.video import ( + VideoClassificationAnnotation, + VideoObjectAnnotation, +) def test_annotation(): diff --git a/libs/labelbox/tests/data/test_data_row_metadata.py b/libs/labelbox/tests/data/test_data_row_metadata.py index 891cab9be..b32077f0b 100644 --- a/libs/labelbox/tests/data/test_data_row_metadata.py +++ b/libs/labelbox/tests/data/test_data_row_metadata.py @@ -1,18 +1,18 @@ +import uuid from datetime import datetime import pytest -import uuid +from lbox.exceptions import MalformedQueryException from labelbox import Dataset -from labelbox.exceptions import MalformedQueryException -from labelbox.schema.identifiables import GlobalKeys, UniqueIds from labelbox.schema.data_row_metadata import ( - DataRowMetadataField, DataRowMetadata, + DataRowMetadataField, DataRowMetadataKind, DataRowMetadataOntology, _parse_metadata_schema, ) +from labelbox.schema.identifiables import GlobalKeys, UniqueIds INVALID_SCHEMA_ID = "1" * 25 FAKE_SCHEMA_ID = "0" * 25 diff --git a/libs/labelbox/tests/integration/schema/test_user_group.py b/libs/labelbox/tests/integration/schema/test_user_group.py index 6aebd4e89..60645452b 100644 --- a/libs/labelbox/tests/integration/schema/test_user_group.py +++ b/libs/labelbox/tests/integration/schema/test_user_group.py @@ -1,14 +1,14 @@ -import pytest -import faker from uuid import uuid4 -from labelbox import Client -from labelbox.schema.user_group import UserGroup, UserGroupColor -from labelbox.exceptions import ( - ResourceNotFoundError, + +import faker +import pytest +from lbox.exceptions import ( ResourceCreationError, - UnprocessableEntityError, + ResourceNotFoundError, ) +from labelbox.schema.user_group import UserGroup, UserGroupColor + data = faker.Faker() diff --git a/libs/labelbox/tests/integration/test_batch.py b/libs/labelbox/tests/integration/test_batch.py index 8a8707175..f63b4c3d9 100644 --- a/libs/labelbox/tests/integration/test_batch.py +++ b/libs/labelbox/tests/integration/test_batch.py @@ -1,16 +1,16 @@ -import time from typing import List from uuid import uuid4 -import pytest -from labelbox import Dataset, Project -from labelbox.exceptions import ( - ProcessingWaitTimeout, +import pytest +from lbox.exceptions import ( + LabelboxError, MalformedQueryException, + ProcessingWaitTimeout, ResourceConflict, - LabelboxError, ) +from labelbox import Dataset, Project + def get_data_row_ids(ds: Dataset): return [dr.uid for dr in list(ds.data_rows())] diff --git a/libs/labelbox/tests/integration/test_chat_evaluation_ontology_project.py b/libs/labelbox/tests/integration/test_chat_evaluation_ontology_project.py index 47e39e2cf..3e462d677 100644 --- a/libs/labelbox/tests/integration/test_chat_evaluation_ontology_project.py +++ b/libs/labelbox/tests/integration/test_chat_evaluation_ontology_project.py @@ -1,9 +1,9 @@ -import pytest from unittest.mock import patch +import pytest + from labelbox import MediaType from labelbox.schema.ontology_kind import OntologyKind -from labelbox.exceptions import MalformedQueryException def test_create_chat_evaluation_ontology_project( diff --git a/libs/labelbox/tests/integration/test_client_errors.py b/libs/labelbox/tests/integration/test_client_errors.py index 64b8fb626..c8721dfc4 100644 --- a/libs/labelbox/tests/integration/test_client_errors.py +++ b/libs/labelbox/tests/integration/test_client_errors.py @@ -1,44 +1,45 @@ -from multiprocessing.dummy import Pool import os import time +from multiprocessing.dummy import Pool + +import lbox.exceptions import pytest from google.api_core.exceptions import RetryError -from labelbox import Project, Dataset, User import labelbox.client -import labelbox.exceptions +from labelbox import Project, User def test_missing_api_key(): - key = os.environ.get(labelbox.client._LABELBOX_API_KEY, None) + key = os.environ.get(lbox.request_client._LABELBOX_API_KEY, None) if key is not None: - del os.environ[labelbox.client._LABELBOX_API_KEY] + del os.environ[lbox.request_client._LABELBOX_API_KEY] - with pytest.raises(labelbox.exceptions.AuthenticationError) as excinfo: + with pytest.raises(lbox.exceptions.AuthenticationError) as excinfo: labelbox.client.Client() assert excinfo.value.message == "Labelbox API key not provided" if key is not None: - os.environ[labelbox.client._LABELBOX_API_KEY] = key + os.environ[lbox.request_client._LABELBOX_API_KEY] = key def test_bad_key(rand_gen): bad_key = "BAD_KEY_" + rand_gen(str) client = labelbox.client.Client(api_key=bad_key) - with pytest.raises(labelbox.exceptions.AuthenticationError) as excinfo: + with pytest.raises(lbox.exceptions.AuthenticationError) as excinfo: client.create_project(name=rand_gen(str)) def test_syntax_error(client): - with pytest.raises(labelbox.exceptions.InvalidQueryError) as excinfo: + with pytest.raises(lbox.exceptions.InvalidQueryError) as excinfo: client.execute("asda", check_naming=False) assert excinfo.value.message.startswith("Syntax Error:") def test_semantic_error(client): - with pytest.raises(labelbox.exceptions.InvalidQueryError) as excinfo: + with pytest.raises(lbox.exceptions.InvalidQueryError) as excinfo: client.execute("query {bbb {id}}", check_naming=False) assert excinfo.value.message.startswith('Cannot query field "bbb"') @@ -58,7 +59,7 @@ def test_timeout_error(client, project): def test_query_complexity_error(client): - with pytest.raises(labelbox.exceptions.ValidationFailedError) as excinfo: + with pytest.raises(lbox.exceptions.ValidationFailedError) as excinfo: client.execute( "{projects {datasets {dataRows {labels {id}}}}}", check_naming=False ) @@ -66,16 +67,16 @@ def test_query_complexity_error(client): def test_resource_not_found_error(client): - with pytest.raises(labelbox.exceptions.ResourceNotFoundError): + with pytest.raises(lbox.exceptions.ResourceNotFoundError): client.get_project("invalid project ID") def test_network_error(client): client = labelbox.client.Client( - api_key=client.api_key, endpoint="not_a_valid_URL" + api_key=client._request_client.api_key, endpoint="not_a_valid_URL" ) - with pytest.raises(labelbox.exceptions.NetworkError) as excinfo: + with pytest.raises(lbox.exceptions.NetworkError) as excinfo: client.create_project(name="Project name") @@ -84,20 +85,20 @@ def test_invalid_attribute_error( rand_gen, ): # Creation - with pytest.raises(labelbox.exceptions.InvalidAttributeError) as excinfo: + with pytest.raises(lbox.exceptions.InvalidAttributeError) as excinfo: client.create_project(name="Name", invalid_field="Whatever") assert excinfo.value.db_object_type == Project assert excinfo.value.field == "invalid_field" # Update project = client.create_project(name=rand_gen(str)) - with pytest.raises(labelbox.exceptions.InvalidAttributeError) as excinfo: + with pytest.raises(lbox.exceptions.InvalidAttributeError) as excinfo: project.update(invalid_field="Whatever") assert excinfo.value.db_object_type == Project assert excinfo.value.field == "invalid_field" # Top-level-fetch - with pytest.raises(labelbox.exceptions.InvalidAttributeError) as excinfo: + with pytest.raises(lbox.exceptions.InvalidAttributeError) as excinfo: client.get_projects(where=User.email == "email") assert excinfo.value.db_object_type == Project assert excinfo.value.field == {User.email} @@ -108,7 +109,7 @@ def test_api_limit_error(client): def get(arg): try: return client.get_user() - except labelbox.exceptions.ApiLimitError as e: + except lbox.exceptions.ApiLimitError as e: return e # Rate limited at 1500 + buffer @@ -120,7 +121,7 @@ def get(arg): elapsed = time.time() - start assert elapsed < 60, "Didn't finish fast enough" - assert labelbox.exceptions.ApiLimitError in {type(r) for r in results} + assert lbox.exceptions.ApiLimitError in {type(r) for r in results} # Sleep at the end of this test to allow other tests to execute. time.sleep(60) diff --git a/libs/labelbox/tests/integration/test_data_row_delete_metadata.py b/libs/labelbox/tests/integration/test_data_row_delete_metadata.py index 2df860181..a2ffd31ba 100644 --- a/libs/labelbox/tests/integration/test_data_row_delete_metadata.py +++ b/libs/labelbox/tests/integration/test_data_row_delete_metadata.py @@ -1,13 +1,13 @@ -from datetime import datetime, timezone import uuid +from datetime import datetime, timezone import pytest +from lbox.exceptions import MalformedQueryException -from labelbox import DataRow, Dataset, Client, DataRowMetadataOntology -from labelbox.exceptions import MalformedQueryException +from labelbox import Client, DataRow, DataRowMetadataOntology, Dataset from labelbox.schema.data_row_metadata import ( - DataRowMetadataField, DataRowMetadata, + DataRowMetadataField, DataRowMetadataKind, DeleteDataRowMetadata, ) diff --git a/libs/labelbox/tests/integration/test_data_rows.py b/libs/labelbox/tests/integration/test_data_rows.py index 7f69c2995..9b92ae146 100644 --- a/libs/labelbox/tests/integration/test_data_rows.py +++ b/libs/labelbox/tests/integration/test_data_rows.py @@ -1,25 +1,25 @@ -from tempfile import NamedTemporaryFile -import uuid -from datetime import datetime import json -import requests import os - +import uuid +from datetime import datetime +from tempfile import NamedTemporaryFile from unittest.mock import patch -import pytest -from labelbox.schema.media_type import MediaType -from labelbox import DataRow, AssetAttachment -from labelbox.exceptions import ( +import pytest +import requests +from lbox.exceptions import ( + InvalidQueryError, MalformedQueryException, ResourceCreationError, - InvalidQueryError, ) -from labelbox.schema.task import Task, DataUpsertTask + +from labelbox import AssetAttachment, DataRow from labelbox.schema.data_row_metadata import ( DataRowMetadataField, DataRowMetadataKind, ) +from labelbox.schema.media_type import MediaType +from labelbox.schema.task import Task SPLIT_SCHEMA_ID = "cko8sbczn0002h2dkdaxb5kal" TEST_SPLIT_ID = "cko8scbz70005h2dkastwhgqt" diff --git a/libs/labelbox/tests/integration/test_dataset.py b/libs/labelbox/tests/integration/test_dataset.py index 89210d6c9..a32c5541d 100644 --- a/libs/labelbox/tests/integration/test_dataset.py +++ b/libs/labelbox/tests/integration/test_dataset.py @@ -1,9 +1,10 @@ +from unittest.mock import MagicMock + import pytest import requests -from unittest.mock import MagicMock -from labelbox import Dataset -from labelbox.exceptions import ResourceNotFoundError, ResourceCreationError +from lbox.exceptions import ResourceCreationError, ResourceNotFoundError +from labelbox import Dataset from labelbox.schema.internal.descriptor_file_creator import ( DescriptorFileCreator, ) diff --git a/libs/labelbox/tests/integration/test_embedding.py b/libs/labelbox/tests/integration/test_embedding.py index 1b54ab81c..41f7ed3de 100644 --- a/libs/labelbox/tests/integration/test_embedding.py +++ b/libs/labelbox/tests/integration/test_embedding.py @@ -2,12 +2,12 @@ import random import threading from tempfile import NamedTemporaryFile -from typing import List, Dict, Any +from typing import Any, Dict, List +import lbox.exceptions import pytest -import labelbox.exceptions -from labelbox import Client, Dataset, DataRow +from labelbox import Client, DataRow, Dataset from labelbox.schema.embedding import Embedding @@ -23,7 +23,7 @@ def test_get_embedding_by_id(client: Client, embedding: Embedding): def test_get_embedding_by_name_not_found(client: Client): - with pytest.raises(labelbox.exceptions.ResourceNotFoundError): + with pytest.raises(lbox.exceptions.ResourceNotFoundError): client.get_embedding_by_name("does-not-exist") diff --git a/libs/labelbox/tests/integration/test_filtering.py b/libs/labelbox/tests/integration/test_filtering.py index 1c37227b2..cb6f11baa 100644 --- a/libs/labelbox/tests/integration/test_filtering.py +++ b/libs/labelbox/tests/integration/test_filtering.py @@ -1,7 +1,7 @@ import pytest +from lbox.exceptions import InvalidQueryError from labelbox import Project -from labelbox.exceptions import InvalidQueryError from labelbox.schema.queue_mode import QueueMode diff --git a/libs/labelbox/tests/integration/test_foundry.py b/libs/labelbox/tests/integration/test_foundry.py index 83c4effc5..b9fd1b6f3 100644 --- a/libs/labelbox/tests/integration/test_foundry.py +++ b/libs/labelbox/tests/integration/test_foundry.py @@ -1,7 +1,8 @@ -import labelbox as lb import pytest -from labelbox.schema.foundry.app import App +from lbox.exceptions import LabelboxError, ResourceNotFoundError +import labelbox as lb +from labelbox.schema.foundry.app import App from labelbox.schema.foundry.foundry_client import FoundryClient # Yolo object detection model id @@ -97,7 +98,7 @@ def test_get_app(foundry_client, app): def test_get_app_with_invalid_id(foundry_client): - with pytest.raises(lb.exceptions.ResourceNotFoundError): + with pytest.raises(ResourceNotFoundError): foundry_client._get_app("invalid-id") @@ -144,7 +145,7 @@ def test_run_foundry_app_returns_model_run_id( def test_run_foundry_with_invalid_data_row_id(foundry_client, app, random_str): invalid_datarow_id = "invalid-global-key" data_rows = lb.GlobalKeys([invalid_datarow_id]) - with pytest.raises(lb.exceptions.LabelboxError) as exception: + with pytest.raises(LabelboxError) as exception: foundry_client.run_app( model_run_name=f"test-app-with-invalid-datarow-id-{random_str}", data_rows=data_rows, @@ -156,7 +157,7 @@ def test_run_foundry_with_invalid_data_row_id(foundry_client, app, random_str): def test_run_foundry_with_invalid_global_key(foundry_client, app, random_str): invalid_global_key = "invalid-global-key" data_rows = lb.GlobalKeys([invalid_global_key]) - with pytest.raises(lb.exceptions.LabelboxError) as exception: + with pytest.raises(LabelboxError) as exception: foundry_client.run_app( model_run_name=f"test-app-with-invalid-global-key-{random_str}", data_rows=data_rows, diff --git a/libs/labelbox/tests/integration/test_labeling_frontend.py b/libs/labelbox/tests/integration/test_labeling_frontend.py index d6ea1aac9..9a72fed47 100644 --- a/libs/labelbox/tests/integration/test_labeling_frontend.py +++ b/libs/labelbox/tests/integration/test_labeling_frontend.py @@ -1,7 +1,7 @@ import pytest +from lbox.exceptions import OperationNotSupportedException from labelbox import LabelingFrontend -from labelbox.exceptions import OperationNotSupportedException def test_get_labeling_frontends(client): diff --git a/libs/labelbox/tests/integration/test_labeling_service.py b/libs/labelbox/tests/integration/test_labeling_service.py index 09b5c24a1..bba8cef78 100644 --- a/libs/labelbox/tests/integration/test_labeling_service.py +++ b/libs/labelbox/tests/integration/test_labeling_service.py @@ -1,6 +1,6 @@ import pytest +from lbox.exceptions import LabelboxError, ResourceNotFoundError -from labelbox.exceptions import LabelboxError, ResourceNotFoundError from labelbox.schema.labeling_service import LabelingServiceStatus diff --git a/libs/labelbox/tests/integration/test_model_config.py b/libs/labelbox/tests/integration/test_model_config.py index 7a060b917..66912e8d9 100644 --- a/libs/labelbox/tests/integration/test_model_config.py +++ b/libs/labelbox/tests/integration/test_model_config.py @@ -1,5 +1,5 @@ import pytest -from labelbox.exceptions import ResourceNotFoundError +from lbox.exceptions import ResourceNotFoundError def test_create_model_config(client, valid_model_id): diff --git a/libs/labelbox/tests/integration/test_project.py b/libs/labelbox/tests/integration/test_project.py index a38fa2b5d..6f0f74e35 100644 --- a/libs/labelbox/tests/integration/test_project.py +++ b/libs/labelbox/tests/integration/test_project.py @@ -1,11 +1,12 @@ -import time import os +import time import uuid + import pytest import requests +from lbox.exceptions import InvalidQueryError -from labelbox import Project, LabelingFrontend, Dataset -from labelbox.exceptions import InvalidQueryError +from labelbox import Dataset, LabelingFrontend, Project from labelbox.schema.media_type import MediaType from labelbox.schema.quality_mode import QualityMode from labelbox.schema.queue_mode import QueueMode diff --git a/libs/labelbox/tests/integration/test_project_model_config.py b/libs/labelbox/tests/integration/test_project_model_config.py index 2d783f62b..f86bbb38e 100644 --- a/libs/labelbox/tests/integration/test_project_model_config.py +++ b/libs/labelbox/tests/integration/test_project_model_config.py @@ -1,5 +1,5 @@ import pytest -from labelbox.exceptions import ResourceNotFoundError +from lbox.exceptions import ResourceNotFoundError def test_add_single_model_config( diff --git a/libs/labelbox/tests/integration/test_project_set_model_setup_complete.py b/libs/labelbox/tests/integration/test_project_set_model_setup_complete.py index 1c3e68c9a..8872a27f4 100644 --- a/libs/labelbox/tests/integration/test_project_set_model_setup_complete.py +++ b/libs/labelbox/tests/integration/test_project_set_model_setup_complete.py @@ -1,6 +1,5 @@ import pytest - -from labelbox.exceptions import LabelboxError, OperationNotAllowedException +from lbox.exceptions import LabelboxError, OperationNotAllowedException def test_live_chat_evaluation_project( diff --git a/libs/labelbox/tests/integration/test_project_setup.py b/libs/labelbox/tests/integration/test_project_setup.py index faadea228..a09d0469d 100644 --- a/libs/labelbox/tests/integration/test_project_setup.py +++ b/libs/labelbox/tests/integration/test_project_setup.py @@ -1,11 +1,11 @@ -from datetime import datetime, timedelta, timezone import json import time +from datetime import datetime, timedelta, timezone import pytest +from lbox.exceptions import InvalidQueryError from labelbox import LabelingFrontend -from labelbox.exceptions import InvalidQueryError, ResourceConflict def simple_ontology(): diff --git a/libs/labelbox/tests/integration/test_prompt_response_generation_project.py b/libs/labelbox/tests/integration/test_prompt_response_generation_project.py index 1373ee470..f5003f061 100644 --- a/libs/labelbox/tests/integration/test_prompt_response_generation_project.py +++ b/libs/labelbox/tests/integration/test_prompt_response_generation_project.py @@ -1,9 +1,8 @@ -import pytest from unittest.mock import patch +import pytest + from labelbox import MediaType -from labelbox.schema.ontology_kind import OntologyKind -from labelbox.exceptions import MalformedQueryException @pytest.mark.parametrize( diff --git a/libs/labelbox/tests/unit/schema/test_user_group.py b/libs/labelbox/tests/unit/schema/test_user_group.py index 65584f8ef..4d78f096e 100644 --- a/libs/labelbox/tests/unit/schema/test_user_group.py +++ b/libs/labelbox/tests/unit/schema/test_user_group.py @@ -1,20 +1,22 @@ -import pytest from collections import defaultdict from unittest.mock import MagicMock -from labelbox import Client -from labelbox.exceptions import ( + +import pytest +from lbox.exceptions import ( + MalformedQueryException, ResourceConflict, ResourceCreationError, ResourceNotFoundError, - MalformedQueryException, UnprocessableEntityError, ) + +from labelbox import Client +from labelbox.schema.media_type import MediaType +from labelbox.schema.ontology_kind import EditorTaskType from labelbox.schema.project import Project +from labelbox.schema.queue_mode import QueueMode from labelbox.schema.user import User from labelbox.schema.user_group import UserGroup, UserGroupColor -from labelbox.schema.queue_mode import QueueMode -from labelbox.schema.ontology_kind import EditorTaskType -from labelbox.schema.media_type import MediaType @pytest.fixture diff --git a/libs/labelbox/tests/unit/test_exceptions.py b/libs/labelbox/tests/unit/test_exceptions.py index 4602fb984..074a735f2 100644 --- a/libs/labelbox/tests/unit/test_exceptions.py +++ b/libs/labelbox/tests/unit/test_exceptions.py @@ -1,6 +1,5 @@ import pytest - -from labelbox.exceptions import error_message_for_unparsed_graphql_error +from lbox.exceptions import error_message_for_unparsed_graphql_error @pytest.mark.parametrize( diff --git a/libs/labelbox/tests/unit/test_unit_ontology.py b/libs/labelbox/tests/unit/test_unit_ontology.py index 0566ad623..fc4c7797b 100644 --- a/libs/labelbox/tests/unit/test_unit_ontology.py +++ b/libs/labelbox/tests/unit/test_unit_ontology.py @@ -1,8 +1,9 @@ +from itertools import product + import pytest +from lbox.exceptions import InconsistentOntologyException -from labelbox.exceptions import InconsistentOntologyException -from labelbox import Tool, Classification, Option, OntologyBuilder -from itertools import product +from labelbox import Classification, OntologyBuilder, Option, Tool _SAMPLE_ONTOLOGY = { "tools": [ diff --git a/libs/lbox-clients/.gitignore b/libs/lbox-clients/.gitignore new file mode 100644 index 000000000..ae8554dec --- /dev/null +++ b/libs/lbox-clients/.gitignore @@ -0,0 +1,10 @@ +# python generated files +__pycache__/ +*.py[oc] +build/ +dist/ +wheels/ +*.egg-info + +# venv +.venv diff --git a/libs/lbox-clients/Dockerfile b/libs/lbox-clients/Dockerfile new file mode 100644 index 000000000..2ee61ab7e --- /dev/null +++ b/libs/lbox-clients/Dockerfile @@ -0,0 +1,44 @@ +# https://github.com/ucyo/python-package-template/blob/master/Dockerfile +FROM python:3.8-slim as rye + +ENV LANG="C.UTF-8" \ + LC_ALL="C.UTF-8" \ + PATH="/home/python/.local/bin:/home/python/.rye/shims:$PATH" \ + PIP_NO_CACHE_DIR="false" \ + RYE_VERSION="0.34.0" \ + RYE_INSTALL_OPTION="--yes" \ + LABELBOX_TEST_ENVIRON="prod" + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + ca-certificates \ + curl \ + inotify-tools \ + make \ + # cv2 + libsm6 \ + libxext6 \ + ffmpeg \ + libfontconfig1 \ + libxrender1 \ + libgl1-mesa-glx \ + libgeos-dev \ + gcc \ + && rm -rf /var/lib/apt/lists/* + +RUN groupadd --gid 1000 python && \ + useradd --uid 1000 --gid python --shell /bin/bash --create-home python + +USER 1000 +WORKDIR /home/python/ + +RUN curl -sSf https://rye.astral.sh/get | bash - + +COPY --chown=python:python . /home/python/labelbox-python/ +WORKDIR /home/python/labelbox-python + +RUN rye config --set-bool behavior.global-python=true && \ + rye config --set-bool behavior.use-uv=true && \ + rye pin 3.8 && \ + rye sync + +CMD rye run unit && rye integration \ No newline at end of file diff --git a/libs/lbox-clients/README.md b/libs/lbox-clients/README.md new file mode 100644 index 000000000..a7bf3e94f --- /dev/null +++ b/libs/lbox-clients/README.md @@ -0,0 +1,9 @@ +# lbox-example + +This is an example module which can be cloned and reused to develop modules under the `lbox` namespace. + +## Module Status: Experimental + +**TLDR: This module may be removed or altered at any given time and there is no offical support.** + +Please see [here](https://docs.labelbox.com/docs/product-release-phases) for the formal definition of `Experimental`. \ No newline at end of file diff --git a/libs/lbox-clients/pyproject.toml b/libs/lbox-clients/pyproject.toml new file mode 100644 index 000000000..1ad53b2c9 --- /dev/null +++ b/libs/lbox-clients/pyproject.toml @@ -0,0 +1,62 @@ +[project] +name = "lbox-clients" +version = "1.0.0" +description = "This module contains client sdk uses to conntect to the Labelbox API and backends" +authors = [ + { name = "Labelbox", email = "engineering@labelbox.com" } +] +dependencies = [ + "requests>=2.22.0", + "google-api-core>=1.22.1", +] +readme = "README.md" +requires-python = ">= 3.8" + +classifiers=[ + # How mature is this project? + "Development Status :: 5 - Production/Stable", + # Indicate who your project is intended for + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development :: Libraries", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "Intended Audience :: Education", + # Pick your license as you wish + "License :: OSI Approved :: Apache Software License", + # Specify the Python versions you support here. + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", +] +keywords = ["ml", "ai", "labelbox", "labeling", "llm", "machinelearning", "edu"] + +[project.urls] +Homepage = "https://labelbox.com/" +Documentation = "https://labelbox-python.readthedocs.io/en/latest/" +Repository = "https://github.com/Labelbox/labelbox-python" +Issues = "https://github.com/Labelbox/labelbox-python/issues" +Changelog = "https://github.com/Labelbox/labelbox-python/blob/develop/libs/labelbox/CHANGELOG.md" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.rye] +managed = true +dev-dependencies = [] + +[tool.rye.scripts] +unit = "pytest tests/unit" +integration = "python -c \"import sys; sys.exit(0)\"" + +[tool.hatch.metadata] +allow-direct-references = true + +[tool.hatch.build.targets.wheel] +packages = ["src/lbox"] + +[tool.pytest.ini_options] +addopts = "-rP -vvv --durations=20 --cov=lbox.example --import-mode=importlib" \ No newline at end of file diff --git a/libs/labelbox/src/labelbox/exceptions.py b/libs/lbox-clients/src/lbox/exceptions.py similarity index 96% rename from libs/labelbox/src/labelbox/exceptions.py rename to libs/lbox-clients/src/lbox/exceptions.py index 34cfeaf4d..493e9cb09 100644 --- a/libs/labelbox/src/labelbox/exceptions.py +++ b/libs/lbox-clients/src/lbox/exceptions.py @@ -16,7 +16,10 @@ def __init__(self, message, cause=None): self.cause = cause def __str__(self): - return self.message + str(self.args) + exception_message = self.message + if self.cause is not None: + exception_message += " (caused by: %s)" % self.cause + return exception_message class AuthenticationError(LabelboxError): diff --git a/libs/lbox-clients/src/lbox/request_client.py b/libs/lbox-clients/src/lbox/request_client.py new file mode 100644 index 000000000..35464f812 --- /dev/null +++ b/libs/lbox-clients/src/lbox/request_client.py @@ -0,0 +1,368 @@ +import json +import logging +import os +import sys +from datetime import datetime, timezone +from types import MappingProxyType +from typing import Callable, Dict, Optional + +import requests +import requests.exceptions +from google.api_core import retry +from lbox import exceptions # type: ignore + +logger = logging.getLogger(__name__) + +_LABELBOX_API_KEY = "LABELBOX_API_KEY" + + +def python_version_info(): + version_info = sys.version_info + + return f"{version_info.major}.{version_info.minor}.{version_info.micro}-{version_info.releaselevel}" + + +class RequestClient: + """A Labelbox request client. + + Contains info necessary for connecting to a Labelbox server (URL, + authentication key). + """ + + def __init__( + self, + sdk_version, + api_key=None, + endpoint="https://api.labelbox.com/graphql", + enable_experimental=False, + app_url="https://app.labelbox.com", + rest_endpoint="https://api.labelbox.com/api/v1", + ): + """Creates and initializes a RequestClient. + This class executes graphql and rest requests to the Labelbox server. + + Args: + api_key (str): API key. If None, the key is obtained from the "LABELBOX_API_KEY" environment variable. + endpoint (str): URL of the Labelbox server to connect to. + enable_experimental (bool): Indicates whether or not to use experimental features + app_url (str) : host url for all links to the web app + Raises: + exceptions.AuthenticationError: If no `api_key` + is provided as an argument or via the environment + variable. + """ + if api_key is None: + if _LABELBOX_API_KEY not in os.environ: + raise exceptions.AuthenticationError("Labelbox API key not provided") + api_key = os.environ[_LABELBOX_API_KEY] + self.api_key = api_key + + self.enable_experimental = enable_experimental + if enable_experimental: + logger.info("Experimental features have been enabled") + + logger.info("Initializing Labelbox client at '%s'", endpoint) + self.app_url = app_url + self.endpoint = endpoint + self.rest_endpoint = rest_endpoint + self.sdk_version = sdk_version + self._connection: requests.Session = self._init_connection() + + def _init_connection(self) -> requests.Session: + connection = requests.Session() # using default connection pool size of 10 + connection.headers.update(self._default_headers()) + + return connection + + @property + def headers(self) -> MappingProxyType: + return self._connection.headers + + def _default_headers(self): + return { + "Authorization": "Bearer %s" % self.api_key, + "Accept": "application/json", + "Content-Type": "application/json", + "X-User-Agent": f"python-sdk {self.sdk_version}", + "X-Python-Version": f"{python_version_info()}", + } + + @retry.Retry( + predicate=retry.if_exception_type( + exceptions.InternalServerError, + exceptions.TimeoutError, + ) + ) + def execute( + self, + query=None, + params=None, + data=None, + files=None, + timeout=60.0, + experimental=False, + error_log_key="message", + raise_return_resource_not_found=False, + error_handlers: Optional[ + Dict[str, Callable[[requests.models.Response], None]] + ] = None, + ): + """Sends a request to the server for the execution of the + given query. + + Checks the response for errors and wraps errors + in appropriate `exceptions.LabelboxError` subtypes. + + Args: + query (str): The query to execute. + params (dict): Query parameters referenced within the query. + data (str): json string containing the query to execute + files (dict): file arguments for request + timeout (float): Max allowed time for query execution, + in seconds. + raise_return_resource_not_found: By default the client relies on the caller to raise the correct exception when a resource is not found. + If this is set to True, the client will raise a ResourceNotFoundError exception automatically. + This simplifies processing. + We recommend to use it only of api returns a clear and well-formed error when a resource not found for a given query. + error_handlers (dict): A dictionary mapping graphql error code to handler functions. + Allows a caller to handle specific errors reporting in a custom way or produce more user-friendly readable messages. + + Example - custom error handler: + >>> def _raise_readable_errors(self, response): + >>> errors = response.json().get('errors', []) + >>> if errors: + >>> message = errors[0].get( + >>> 'message', json.dumps([{ + >>> "errorMessage": "Unknown error" + >>> }])) + >>> errors = json.loads(message) + >>> error_messages = [error['errorMessage'] for error in errors] + >>> else: + >>> error_messages = ["Uknown error"] + >>> raise LabelboxError(". ".join(error_messages)) + + Returns: + dict, parsed JSON response. + Raises: + exceptions.AuthenticationError: If authentication + failed. + exceptions.InvalidQueryError: If `query` is not + syntactically or semantically valid (checked server-side). + exceptions.ApiLimitError: If the server API limit was + exceeded. See "How to import data" in the online documentation + to see API limits. + exceptions.TimeoutError: If response was not received + in `timeout` seconds. + exceptions.NetworkError: If an unknown error occurred + most likely due to connection issues. + exceptions.LabelboxError: If an unknown error of any + kind occurred. + ValueError: If query and data are both None. + """ + logger.debug("Query: %s, params: %r, data %r", query, params, data) + + # Convert datetimes to UTC strings. + def convert_value(value): + if isinstance(value, datetime): + value = value.astimezone(timezone.utc) + value = value.strftime("%Y-%m-%dT%H:%M:%SZ") + return value + + if query is not None: + if params is not None: + params = {key: convert_value(value) for key, value in params.items()} + data = json.dumps({"query": query, "variables": params}).encode("utf-8") + elif data is None: + raise ValueError("query and data cannot both be none") + + endpoint = ( + self.endpoint + if not experimental + else self.endpoint.replace("/graphql", "/_gql") + ) + + try: + headers = self._connection.headers.copy() + if files: + del headers["Content-Type"] + del headers["Accept"] + request = requests.Request( + "POST", + endpoint, + headers=headers, + data=data, + files=files if files else None, + ) + + prepped: requests.PreparedRequest = request.prepare() + + response = self._connection.send(prepped, timeout=timeout) + logger.debug("Response: %s", response.text) + except requests.exceptions.Timeout as e: + raise exceptions.TimeoutError(str(e)) + except requests.exceptions.RequestException as e: + logger.error("Unknown error: %s", str(e)) + raise exceptions.NetworkError(e) + except Exception as e: + raise exceptions.LabelboxError( + "Unknown error during Client.query(): " + str(e), e + ) + + if ( + 200 <= response.status_code < 300 + or response.status_code < 500 + or response.status_code >= 600 + ): + try: + r_json = response.json() + except Exception: + raise exceptions.LabelboxError( + "Failed to parse response as JSON: %s" % response.text + ) + else: + if ( + "upstream connect error or disconnect/reset before headers" + in response.text + ): + raise exceptions.InternalServerError("Connection reset") + elif response.status_code == 502: + error_502 = "502 Bad Gateway" + raise exceptions.InternalServerError(error_502) + elif 500 <= response.status_code < 600: + error_500 = f"Internal server http error {response.status_code}" + raise exceptions.InternalServerError(error_500) + + errors = r_json.get("errors", []) + + def check_errors(keywords, *path): + """Helper that looks for any of the given `keywords` in any of + current errors on paths (like error[path][component][to][keyword]). + """ + for error in errors: + obj = error + for path_elem in path: + obj = obj.get(path_elem, {}) + if obj in keywords: + return error + return None + + def get_error_status_code(error: dict) -> int: + try: + return int(error["extensions"].get("exception").get("status")) + except Exception: + return 500 + + if check_errors(["AUTHENTICATION_ERROR"], "extensions", "code") is not None: + raise exceptions.AuthenticationError("Invalid API key") + + authorization_error = check_errors( + ["AUTHORIZATION_ERROR"], "extensions", "code" + ) + if authorization_error is not None: + raise exceptions.AuthorizationError(authorization_error["message"]) + + validation_error = check_errors( + ["GRAPHQL_VALIDATION_FAILED"], "extensions", "code" + ) + + if validation_error is not None: + message = validation_error["message"] + if message == "Query complexity limit exceeded": + raise exceptions.ValidationFailedError(message) + else: + raise exceptions.InvalidQueryError(message) + + graphql_error = check_errors(["GRAPHQL_PARSE_FAILED"], "extensions", "code") + if graphql_error is not None: + raise exceptions.InvalidQueryError(graphql_error["message"]) + + # Check if API limit was exceeded + response_msg = r_json.get("message", "") + + if response_msg.startswith("You have exceeded"): + raise exceptions.ApiLimitError(response_msg) + + resource_not_found_error = check_errors( + ["RESOURCE_NOT_FOUND"], "extensions", "code" + ) + if resource_not_found_error is not None: + if raise_return_resource_not_found: + raise exceptions.ResourceNotFoundError( + message=resource_not_found_error["message"] + ) + else: + # Return None and let the caller methods raise an exception + # as they already know which resource type and ID was requested + return None + + resource_conflict_error = check_errors( + ["RESOURCE_CONFLICT"], "extensions", "code" + ) + if resource_conflict_error is not None: + raise exceptions.ResourceConflict(resource_conflict_error["message"]) + + malformed_request_error = check_errors( + ["MALFORMED_REQUEST"], "extensions", "code" + ) + if malformed_request_error is not None: + raise exceptions.MalformedQueryException( + malformed_request_error[error_log_key] + ) + + # A lot of different error situations are now labeled serverside + # as INTERNAL_SERVER_ERROR, when they are actually client errors. + # TODO: fix this in the server API + internal_server_error = check_errors( + ["INTERNAL_SERVER_ERROR"], "extensions", "code" + ) + error_code = "INTERNAL_SERVER_ERROR" + + if internal_server_error is not None: + if error_handlers and error_code in error_handlers: + handler = error_handlers[error_code] + handler(response) + return None + message = internal_server_error.get("message") + error_status_code = get_error_status_code(internal_server_error) + if error_status_code == 400: + raise exceptions.InvalidQueryError(message) + elif error_status_code == 422: + raise exceptions.UnprocessableEntityError(message) + elif error_status_code == 426: + raise exceptions.OperationNotAllowedException(message) + elif error_status_code == 500: + raise exceptions.LabelboxError(message) + else: + raise exceptions.InternalServerError(message) + + not_allowed_error = check_errors( + ["OPERATION_NOT_ALLOWED"], "extensions", "code" + ) + if not_allowed_error is not None: + message = not_allowed_error.get("message") + raise exceptions.OperationNotAllowedException(message) + + if len(errors) > 0: + logger.warning("Unparsed errors on query execution: %r", errors) + messages = list( + map( + lambda x: { + "message": x["message"], + "code": x["extensions"]["code"], + }, + errors, + ) + ) + raise exceptions.LabelboxError("Unknown error: %s" % str(messages)) + + # if we do return a proper error code, and didn't catch this above + # reraise + # this mainly catches a 401 for API access disabled for free tier + # TODO: need to unify API errors to handle things more uniformly + # in the SDK + if response.status_code != requests.codes.ok: + message = f"{response.status_code} {response.reason}" + cause = r_json.get("message") + raise exceptions.LabelboxError(message, cause) + + return r_json["data"] diff --git a/libs/lbox-clients/tests/unit/lbox/test_client.py b/libs/lbox-clients/tests/unit/lbox/test_client.py new file mode 100644 index 000000000..42b141d33 --- /dev/null +++ b/libs/lbox-clients/tests/unit/lbox/test_client.py @@ -0,0 +1,46 @@ +from unittest.mock import MagicMock + +from lbox.request_client import RequestClient + + +# @patch.dict(os.environ, {'LABELBOX_API_KEY': 'bar'}) +def test_headers(): + client = RequestClient( + sdk_version="foo", api_key="api_key", endpoint="http://localhost:8080/_gql" + ) + assert client.headers + assert client.headers["Authorization"] == "Bearer api_key" + assert client.headers["Content-Type"] == "application/json" + assert client.headers["User-Agent"] + assert client.headers["X-Python-Version"] + + +def test_custom_error_handling(): + mock_raise_error = MagicMock() + + response_dict = { + "errors": [ + { + "message": "Internal server error", + "extensions": {"code": "INTERNAL_SERVER_ERROR"}, + } + ], + } + response = MagicMock() + response.json.return_value = response_dict + response.status_code = 200 + + client = RequestClient( + sdk_version="foo", api_key="api_key", endpoint="http://localhost:8080/_gql" + ) + connection_mock = MagicMock() + connection_mock.send.return_value = response + client._connection = connection_mock + + client.execute( + "query_str", + {"projectId": "project_id"}, + raise_return_resource_not_found=True, + error_handlers={"INTERNAL_SERVER_ERROR": mock_raise_error}, + ) + mock_raise_error.assert_called_once_with(response) diff --git a/requirements-dev.lock b/requirements-dev.lock index 05ca1683a..a51fa0dcf 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -6,10 +6,10 @@ # features: [] # all-features: true # with-sources: false -# generate-hashes: false -# universal: false -e file:libs/labelbox +-e file:libs/lbox-clients + # via labelbox -e file:libs/lbox-example alabaster==0.7.13 # via sphinx @@ -74,6 +74,7 @@ gitpython==3.1.43 # via databooks google-api-core==2.19.1 # via labelbox + # via lbox-clients google-auth==2.31.0 # via google-api-core googleapis-common-protos==1.63.2 @@ -125,6 +126,7 @@ matplotlib-inline==0.1.7 mistune==3.0.2 # via nbconvert mypy==1.10.1 + # via labelbox mypy-extensions==1.0.0 # via black # via mypy @@ -233,6 +235,7 @@ regex==2024.5.15 requests==2.32.3 # via google-api-core # via labelbox + # via lbox-clients # via sphinx rich==12.6.0 # via databooks diff --git a/requirements.lock b/requirements.lock index 07e026d59..16ed91c80 100644 --- a/requirements.lock +++ b/requirements.lock @@ -6,10 +6,10 @@ # features: [] # all-features: true # with-sources: false -# generate-hashes: false -# universal: false -e file:libs/labelbox +-e file:libs/lbox-clients + # via labelbox -e file:libs/lbox-example alabaster==0.7.13 # via sphinx @@ -33,6 +33,7 @@ geojson==3.1.0 # via labelbox google-api-core==2.19.1 # via labelbox + # via lbox-clients google-auth==2.31.0 # via google-api-core googleapis-common-protos==1.63.2 @@ -87,6 +88,7 @@ pytz==2024.1 requests==2.32.3 # via google-api-core # via labelbox + # via lbox-clients # via sphinx rsa==4.9 # via google-auth