diff --git a/.github/workflows/python-package-develop.yml b/.github/workflows/python-package-develop.yml index 05eff5dc4..769d04c74 100644 --- a/.github/workflows/python-package-develop.yml +++ b/.github/workflows/python-package-develop.yml @@ -2,9 +2,9 @@ name: Labelbox Python SDK Staging (Develop) on: push: - branches: [develop] + branches: [develop, v6] pull_request: - branches: [develop] + branches: [develop, v6] concurrency: group: ${{ github.workflow }}-${{ github.ref }} diff --git a/libs/labelbox/src/labelbox/__init__.py b/libs/labelbox/src/labelbox/__init__.py index 1bd0ba967..b22a37a84 100644 --- a/libs/labelbox/src/labelbox/__init__.py +++ b/libs/labelbox/src/labelbox/__init__.py @@ -3,76 +3,78 @@ __version__ = "4.0.0" from labelbox.client import Client -from labelbox.schema.project import Project -from labelbox.schema.model import Model -from labelbox.schema.model_config import ModelConfig +from labelbox.data_uploader import DataUploader +from labelbox.request_client import RequestClient from labelbox.schema.annotation_import import ( + LabelImport, MALPredictionImport, MEAPredictionImport, - LabelImport, MEAToMALPredictionImport, ) -from labelbox.schema.dataset import Dataset -from labelbox.schema.data_row import DataRow +from labelbox.schema.asset_attachment import AssetAttachment +from labelbox.schema.batch import Batch +from labelbox.schema.benchmark import Benchmark from labelbox.schema.catalog import Catalog +from labelbox.schema.data_row import DataRow +from labelbox.schema.data_row_metadata import ( + DataRowMetadata, + DataRowMetadataField, + DataRowMetadataOntology, + DeleteDataRowMetadata, +) +from labelbox.schema.dataset import Dataset from labelbox.schema.enums import AnnotationImportState -from labelbox.schema.label import Label -from labelbox.schema.batch import Batch -from labelbox.schema.review import Review -from labelbox.schema.user import User -from labelbox.schema.organization import Organization -from labelbox.schema.task import Task from labelbox.schema.export_task import ( - StreamType, + BufferedJsonConverterOutput, ExportTask, - JsonConverter, - JsonConverterOutput, FileConverter, FileConverterOutput, - BufferedJsonConverterOutput, + JsonConverter, + JsonConverterOutput, + StreamType, ) +from labelbox.schema.iam_integration import IAMIntegration +from labelbox.schema.identifiable import GlobalKey, UniqueId +from labelbox.schema.identifiables import DataRowIds, GlobalKeys, UniqueIds +from labelbox.schema.invite import Invite, InviteLimit +from labelbox.schema.label import Label +from labelbox.schema.label_score import LabelScore from labelbox.schema.labeling_frontend import ( LabelingFrontend, LabelingFrontendOptions, ) -from labelbox.schema.asset_attachment import AssetAttachment -from labelbox.schema.webhook import Webhook +from labelbox.schema.labeling_service import LabelingService +from labelbox.schema.labeling_service_dashboard import LabelingServiceDashboard +from labelbox.schema.labeling_service_status import LabelingServiceStatus +from labelbox.schema.media_type import MediaType +from labelbox.schema.model import Model +from labelbox.schema.model_config import ModelConfig +from labelbox.schema.model_run import DataSplit, ModelRun from labelbox.schema.ontology import ( + Classification, + FeatureSchema, Ontology, OntologyBuilder, - Classification, Option, + PromptResponseClassification, + ResponseOption, Tool, - FeatureSchema, -) -from labelbox.schema.ontology import PromptResponseClassification -from labelbox.schema.ontology import 
ResponseOption
-from labelbox.schema.role import Role, ProjectRole
-from labelbox.schema.invite import Invite, InviteLimit
-from labelbox.schema.data_row_metadata import (
-    DataRowMetadataOntology,
-    DataRowMetadataField,
-    DataRowMetadata,
-    DeleteDataRowMetadata,
 )
-from labelbox.schema.model_run import ModelRun, DataSplit
-from labelbox.schema.benchmark import Benchmark
-from labelbox.schema.iam_integration import IAMIntegration
-from labelbox.schema.resource_tag import ResourceTag
-from labelbox.schema.project_model_config import ProjectModelConfig
-from labelbox.schema.project_resource_tag import ProjectResourceTag
-from labelbox.schema.media_type import MediaType
-from labelbox.schema.slice import Slice, CatalogSlice, ModelSlice
-from labelbox.schema.queue_mode import QueueMode
-from labelbox.schema.task_queue import TaskQueue
-from labelbox.schema.label_score import LabelScore
-from labelbox.schema.identifiables import UniqueIds, GlobalKeys, DataRowIds
-from labelbox.schema.identifiable import UniqueId, GlobalKey
 from labelbox.schema.ontology_kind import OntologyKind
+from labelbox.schema.organization import Organization
+from labelbox.schema.project import Project
+from labelbox.schema.project_model_config import ProjectModelConfig
 from labelbox.schema.project_overview import (
     ProjectOverview,
     ProjectOverviewDetailed,
 )
-from labelbox.schema.labeling_service import LabelingService
-from labelbox.schema.labeling_service_dashboard import LabelingServiceDashboard
-from labelbox.schema.labeling_service_status import LabelingServiceStatus
+from labelbox.schema.project_resource_tag import ProjectResourceTag
+from labelbox.schema.queue_mode import QueueMode
+from labelbox.schema.resource_tag import ResourceTag
+from labelbox.schema.review import Review
+from labelbox.schema.role import ProjectRole, Role
+from labelbox.schema.slice import CatalogSlice, ModelSlice, Slice
+from labelbox.schema.task import Task
+from labelbox.schema.task_queue import TaskQueue
+from labelbox.schema.user import User
+from labelbox.schema.webhook import Webhook
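The re-export list above is now sorted alphabetically, and two names introduced by this refactor, `DataUploader` and `RequestClient`, join the package root. A minimal sketch of the import surface after this change (assuming the package is installed; nothing here is new API beyond the two added exports):

```python
# Existing imports keep working; the reordering above is purely cosmetic.
from labelbox import Client, DataRow, Dataset, OntologyBuilder, Tool

# Newly re-exported by this change:
from labelbox import DataUploader, RequestClient

import labelbox
print(labelbox.__version__)  # "4.0.0" on this branch
```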
diff --git a/libs/labelbox/src/labelbox/client.py b/libs/labelbox/src/labelbox/client.py
index b0b5a1407..440695b6b 100644
--- a/libs/labelbox/src/labelbox/client.py
+++ b/libs/labelbox/src/labelbox/client.py
@@ -1,45 +1,39 @@
 # type: ignore
 import json
 import logging
-import mimetypes
-import os
 import random
-import sys
 import time
 import urllib.parse
 from collections import defaultdict
-from datetime import datetime, timezone
-from typing import Any, List, Dict, Union, Optional, overload, Callable
+from functools import cache
 from types import MappingProxyType
+from typing import Any, Dict, List, Optional, Union, overload

-from labelbox.schema.search_filters import SearchFilter
 import requests
 import requests.exceptions
 from google.api_core import retry

 import labelbox.exceptions
 from labelbox import __version__ as SDK_VERSION
 from labelbox import utils
 from labelbox.adv_client import AdvClient
 from labelbox.orm import query
 from labelbox.orm.db_object import DbObject
 from labelbox.orm.model import Entity, Field
 from labelbox.pagination import PaginatedCollection
 from labelbox.schema import role
-from labelbox.schema.conflict_resolution_strategy import (
-    ConflictResolutionStrategy,
-)
-from labelbox.schema.data_row import DataRow
 from labelbox.schema.catalog import Catalog
+from labelbox.schema.data_row import DataRow
 from labelbox.schema.data_row_metadata import DataRowMetadataOntology
 from labelbox.schema.dataset import Dataset
 from labelbox.schema.embedding import Embedding
 from labelbox.schema.enums import CollectionJobStatus
 from labelbox.schema.foundry.foundry_client import FoundryClient
 from labelbox.schema.iam_integration import IAMIntegration
-from labelbox.schema.identifiables import DataRowIds
-from labelbox.schema.identifiables import GlobalKeys
+from labelbox.schema.identifiables import DataRowIds, GlobalKeys
+from labelbox.schema.label_score import LabelScore
 from labelbox.schema.labeling_frontend import LabelingFrontend
+from labelbox.schema.labeling_service_dashboard import LabelingServiceDashboard
 from labelbox.schema.media_type import (
     MediaType,
     get_media_type_validation_error,
@@ -47,57 +41,238 @@
 from labelbox.schema.model import Model
 from labelbox.schema.model_config import ModelConfig
 from labelbox.schema.model_run import ModelRun
-from labelbox.schema.ontology import Ontology, DeleteFeatureFromOntologyResult
 from labelbox.schema.ontology import (
-    Tool,
     Classification,
+    DeleteFeatureFromOntologyResult,
     FeatureSchema,
+    Ontology,
     PromptResponseClassification,
+    Tool,
+)
+from labelbox.schema.ontology_kind import (
+    EditorTaskType,
+    EditorTaskTypeMapper,
+    OntologyKind,
 )
 from labelbox.schema.organization import Organization
 from labelbox.schema.project import Project
 from labelbox.schema.quality_mode import (
-    QualityMode,
     BENCHMARK_AUTO_AUDIT_NUMBER_OF_LABELS,
     BENCHMARK_AUTO_AUDIT_PERCENTAGE,
     CONSENSUS_AUTO_AUDIT_NUMBER_OF_LABELS,
     CONSENSUS_AUTO_AUDIT_PERCENTAGE,
+    QualityMode,
 )
 from labelbox.schema.queue_mode import QueueMode
 from labelbox.schema.role import Role
+from labelbox.schema.search_filters import SearchFilter
 from labelbox.schema.send_to_annotate_params import (
     SendToAnnotateFromCatalogParams,
+    build_annotations_input,
     build_destination_task_queue_input,
     build_predictions_input,
-    build_annotations_input,
 )
 from labelbox.schema.slice import CatalogSlice, ModelSlice
-from labelbox.schema.task import Task, DataUpsertTask
+from labelbox.schema.task import DataUpsertTask, Task
 from labelbox.schema.user import User
-from labelbox.schema.label_score import LabelScore
-from labelbox.schema.ontology_kind import (
-    OntologyKind,
-    EditorTaskTypeMapper,
-    EditorTaskType,
-)
-from labelbox.schema.labeling_service_dashboard import LabelingServiceDashboard
+
+from .data_uploader import DataUploader
+from .request_client import RequestClient

 logger = logging.getLogger(__name__)

-_LABELBOX_API_KEY = "LABELBOX_API_KEY"

+@cache
+def get_data_row_metadata_ontology(
+    client: RequestClient,
+) -> DataRowMetadataOntology:
+    """
+
+    Returns:
+        DataRowMetadataOntology: The ontology for Data Row Metadata for an organization
+
+    """
+    return DataRowMetadataOntology(client)
+
+
+def get_single(request_client: RequestClient, db_object_type, uid):
+    """Fetches a single object of the given type, for the given ID.
+
+    Args:
+        db_object_type (type): DbObject subclass.
+        uid (str): Unique ID of the row.
+    Returns:
+        Object of `db_object_type`.
+    Raises:
+        labelbox.exceptions.ResourceNotFoundError: If there is no object
+            of the given type for the given ID.
+ """ + query_str, params = query.get_single(db_object_type, uid) + + res = request_client.execute(query_str, params) + res = res and res.get(utils.camel_case(db_object_type.type_name())) + if res is None: + raise labelbox.exceptions.ResourceNotFoundError(db_object_type, params) + else: + return db_object_type(request_client, res) + + +def get_all(request_client, db_object_type, where, filter_deleted=True): + """Fetches all the objects of the given type the user has access to. + + Args: + db_object_type (type): DbObject subclass. + where (Comparison, LogicalOperation or None): The `where` clause + for filtering. + Returns: + An iterable of `db_object_type` instances. + """ + if filter_deleted: + not_deleted = db_object_type.deleted == False + where = not_deleted if where is None else where & not_deleted + query_str, params = query.get_all(db_object_type, where) + + return PaginatedCollection( + request_client, + query_str, + params, + [utils.camel_case(db_object_type.type_name()) + "s"], + db_object_type, + ) + + +def get_organization(request_client: RequestClient) -> Organization: + """Gets the Organization DB object of the current user. + + >>> organization = client.get_organization() + """ + return get_single(request_client, Entity.Organization, None) + + +def get_data_row(request_client: RequestClient, data_row_id): + """ + + Returns: + DataRow: returns a single data row given the data row id + """ + return get_single(request_client, Entity.DataRow, data_row_id) + + +def get_user(request_client: RequestClient) -> User: + """ + Gets the current User database object. + """ + return get_single(request_client, Entity.User, None) + + +def get_project(request_client: RequestClient, project_id) -> Project: + """Gets a single Project with the given ID. + + Args: + request_client (RequestClient): The request client to use for the query + project_id (str): Unique ID of the Project. + Returns: + The sought Project. + Raises: + labelbox.exceptions.ResourceNotFoundError: If there is no + Project with the given ID. + """ + return get_single(request_client, Entity.Project, project_id) + + +def get_roles(request_client: RequestClient) -> List[Role]: + """ + Returns: + Roles: Provides information on available roles within an organization. + Roles are used for user management. + """ + return role.get_roles(request_client) + + +def create_entity(request_client, db_object_type, data, extra_params={}): + """Creates an object on the server. Attribute values are + passed as keyword arguments: + + Args: + db_object_type (type): A DbObjectType subtype. + data (dict): Keys are attributes or their names (in Python, + snake-case convention) and values are desired attribute values. + extra_params (dict): Additional parameters to pass to GraphQL. + These have to be Field(...): value pairs. + Returns: + A new object of the given DB object type. + Raises: + InvalidAttributeError: If the DB object type does not contain + any of the attribute names given in `data`. + """ + # Convert string attribute names to Field or Relationship objects. + # Also convert Labelbox object values to their UIDs. 
+ data = { + db_object_type.attribute(attr) + if isinstance(attr, str) + else attr: value.uid if isinstance(value, DbObject) else value + for attr, value in data.items() + } + + data = {**data, **extra_params} + query_string, params = query.create(db_object_type, data) + res = request_client.execute( + query_string, params, raise_return_resource_not_found=True + ) + + if not res: + raise labelbox.exceptions.LabelboxError( + "Failed to create %s" % db_object_type.type_name() + ) + res = res["create%s" % db_object_type.type_name()] + + return db_object_type(request_client, res) + + +def get_labeling_frontends(request_client, where=None) -> List[LabelingFrontend]: + """Fetches all the labeling frontends. + + >>> frontend = client.get_labeling_frontends(where=LabelingFrontend.name == "Editor") + + Args: + where (Comparison, LogicalOperation or None): The `where` clause + for filtering. + Returns: + An iterable of LabelingFrontends (typically a PaginatedCollection). + """ + return get_all(request_client, Entity.LabelingFrontend, where) + + +def get_batch(request_client: RequestClient, project_id: str, batch_id: str) -> Entity.Batch: + # obtain batch entity to return + get_batch_str = """query %s($projectId: ID!, $batchId: ID!) { + project(where: {id: $projectId}) { + batches(where: {id: $batchId}) { + nodes { + %s + } + } + } + } + """ % ( + "getProjectBatchPyApi", + query.results_query_part(Entity.Batch), + ) -def python_version_info(): - version_info = sys.version_info + batch = request_client.execute( + get_batch_str, + {"projectId": project_id, "batchId": batch_id}, + timeout=180.0, + experimental=True, + )["project"]["batches"]["nodes"][0] - return f"{version_info.major}.{version_info.minor}.{version_info.micro}-{version_info.releaselevel}" + return Entity.Batch(request_client, project_id, batch) class Client: """A Labelbox client. - Contains info necessary for connecting to a Labelbox server (URL, - authentication key). Provides functions for querying and creating + Provides functions for querying and creating top-level data objects (Projects, Datasets). """ @@ -127,315 +302,29 @@ def __init__( is provided as an argument or via the environment variable. """ - if api_key is None: - if _LABELBOX_API_KEY not in os.environ: - raise labelbox.exceptions.AuthenticationError( - "Labelbox API key not provided" - ) - api_key = os.environ[_LABELBOX_API_KEY] - self.api_key = api_key - - self.enable_experimental = enable_experimental - if enable_experimental: - logger.info("Experimental features have been enabled") - - logger.info("Initializing Labelbox client at '%s'", endpoint) - self.app_url = app_url - self.endpoint = endpoint - self.rest_endpoint = rest_endpoint self._data_row_metadata_ontology = None + self._request_client = RequestClient( + api_key, + sdk_version=SDK_VERSION, + endpoint=endpoint, + enable_experimental=enable_experimental, + app_url=app_url, + rest_endpoint=rest_endpoint, + ) self._adv_client = AdvClient.factory(rest_endpoint, api_key) - self._connection: requests.Session = self._init_connection() - - def _init_connection(self) -> requests.Session: - connection = ( - requests.Session() - ) # using default connection pool size of 10 - connection.headers.update(self._default_headers()) - return connection + """ + Ideally we should not access connection and headers directly as it breaks encapsulation for the RequestClient. 
+    TODO: Remove these properties and provide either client-facing methods or ways to encapsulate this logic in the request client
+    """

     @property
     def headers(self) -> MappingProxyType:
-        return self._connection.headers
-
-    def _default_headers(self):
-        return {
-            "Authorization": "Bearer %s" % self.api_key,
-            "Accept": "application/json",
-            "Content-Type": "application/json",
-            "X-User-Agent": f"python-sdk {SDK_VERSION}",
-            "X-Python-Version": f"{python_version_info()}",
-        }
-
-    @retry.Retry(
-        predicate=retry.if_exception_type(
-            labelbox.exceptions.InternalServerError,
-            labelbox.exceptions.TimeoutError,
-        )
-    )
-    def execute(
-        self,
-        query=None,
-        params=None,
-        data=None,
-        files=None,
-        timeout=60.0,
-        experimental=False,
-        error_log_key="message",
-        raise_return_resource_not_found=False,
-    ):
-        """Sends a request to the server for the execution of the
-        given query.
-
-        Checks the response for errors and wraps errors
-        in appropriate `labelbox.exceptions.LabelboxError` subtypes.
-
-        Args:
-            query (str): The query to execute.
-            params (dict): Query parameters referenced within the query.
-            data (str): json string containing the query to execute
-            files (dict): file arguments for request
-            timeout (float): Max allowed time for query execution,
-                in seconds.
-        Returns:
-            dict, parsed JSON response.
-        Raises:
-            labelbox.exceptions.AuthenticationError: If authentication
-                failed.
-            labelbox.exceptions.InvalidQueryError: If `query` is not
-                syntactically or semantically valid (checked server-side).
-            labelbox.exceptions.ApiLimitError: If the server API limit was
-                exceeded. See "How to import data" in the online documentation
-                to see API limits.
-            labelbox.exceptions.TimeoutError: If response was not received
-                in `timeout` seconds.
-            labelbox.exceptions.NetworkError: If an unknown error occurred
-                most likely due to connection issues.
-            labelbox.exceptions.LabelboxError: If an unknown error of any
-                kind occurred.
-            ValueError: If query and data are both None.
-        """
-        logger.debug("Query: %s, params: %r, data %r", query, params, data)
-
-        # Convert datetimes to UTC strings.
- def convert_value(value): - if isinstance(value, datetime): - value = value.astimezone(timezone.utc) - value = value.strftime("%Y-%m-%dT%H:%M:%SZ") - return value - - if query is not None: - if params is not None: - params = { - key: convert_value(value) for key, value in params.items() - } - data = json.dumps({"query": query, "variables": params}).encode( - "utf-8" - ) - elif data is None: - raise ValueError("query and data cannot both be none") - - endpoint = ( - self.endpoint - if not experimental - else self.endpoint.replace("/graphql", "/_gql") - ) - - try: - headers = self._connection.headers.copy() - if files: - del headers["Content-Type"] - del headers["Accept"] - request = requests.Request( - "POST", - endpoint, - headers=headers, - data=data, - files=files if files else None, - ) - - prepped: requests.PreparedRequest = request.prepare() - - response = self._connection.send(prepped, timeout=timeout) - logger.debug("Response: %s", response.text) - except requests.exceptions.Timeout as e: - raise labelbox.exceptions.TimeoutError(str(e)) - except requests.exceptions.RequestException as e: - logger.error("Unknown error: %s", str(e)) - raise labelbox.exceptions.NetworkError(e) - except Exception as e: - raise labelbox.exceptions.LabelboxError( - "Unknown error during Client.query(): " + str(e), e - ) - - if ( - 200 <= response.status_code < 300 - or response.status_code < 500 - or response.status_code >= 600 - ): - try: - r_json = response.json() - except Exception: - raise labelbox.exceptions.LabelboxError( - "Failed to parse response as JSON: %s" % response.text - ) - else: - if ( - "upstream connect error or disconnect/reset before headers" - in response.text - ): - raise labelbox.exceptions.InternalServerError( - "Connection reset" - ) - elif response.status_code == 502: - error_502 = "502 Bad Gateway" - raise labelbox.exceptions.InternalServerError(error_502) - elif 500 <= response.status_code < 600: - error_500 = f"Internal server http error {response.status_code}" - raise labelbox.exceptions.InternalServerError(error_500) - - errors = r_json.get("errors", []) - - def check_errors(keywords, *path): - """Helper that looks for any of the given `keywords` in any of - current errors on paths (like error[path][component][to][keyword]). 
- """ - for error in errors: - obj = error - for path_elem in path: - obj = obj.get(path_elem, {}) - if obj in keywords: - return error - return None - - def get_error_status_code(error: dict) -> int: - try: - return int(error["extensions"].get("exception").get("status")) - except: - return 500 - - if ( - check_errors(["AUTHENTICATION_ERROR"], "extensions", "code") - is not None - ): - raise labelbox.exceptions.AuthenticationError("Invalid API key") - - authorization_error = check_errors( - ["AUTHORIZATION_ERROR"], "extensions", "code" - ) - if authorization_error is not None: - raise labelbox.exceptions.AuthorizationError( - authorization_error["message"] - ) - - validation_error = check_errors( - ["GRAPHQL_VALIDATION_FAILED"], "extensions", "code" - ) - - if validation_error is not None: - message = validation_error["message"] - if message == "Query complexity limit exceeded": - raise labelbox.exceptions.ValidationFailedError(message) - else: - raise labelbox.exceptions.InvalidQueryError(message) - - graphql_error = check_errors( - ["GRAPHQL_PARSE_FAILED"], "extensions", "code" - ) - if graphql_error is not None: - raise labelbox.exceptions.InvalidQueryError( - graphql_error["message"] - ) - - # Check if API limit was exceeded - response_msg = r_json.get("message", "") - - if response_msg.startswith("You have exceeded"): - raise labelbox.exceptions.ApiLimitError(response_msg) - - resource_not_found_error = check_errors( - ["RESOURCE_NOT_FOUND"], "extensions", "code" - ) - if resource_not_found_error is not None: - if raise_return_resource_not_found: - raise labelbox.exceptions.ResourceNotFoundError( - message=resource_not_found_error["message"] - ) - else: - # Return None and let the caller methods raise an exception - # as they already know which resource type and ID was requested - return None - - resource_conflict_error = check_errors( - ["RESOURCE_CONFLICT"], "extensions", "code" - ) - if resource_conflict_error is not None: - raise labelbox.exceptions.ResourceConflict( - resource_conflict_error["message"] - ) - - malformed_request_error = check_errors( - ["MALFORMED_REQUEST"], "extensions", "code" - ) - if malformed_request_error is not None: - raise labelbox.exceptions.MalformedQueryException( - malformed_request_error[error_log_key] - ) - - # A lot of different error situations are now labeled serverside - # as INTERNAL_SERVER_ERROR, when they are actually client errors. 
- # TODO: fix this in the server API - internal_server_error = check_errors( - ["INTERNAL_SERVER_ERROR"], "extensions", "code" - ) - if internal_server_error is not None: - message = internal_server_error.get("message") - error_status_code = get_error_status_code(internal_server_error) - if error_status_code == 400: - raise labelbox.exceptions.InvalidQueryError(message) - elif error_status_code == 422: - raise labelbox.exceptions.UnprocessableEntityError(message) - elif error_status_code == 426: - raise labelbox.exceptions.OperationNotAllowedException(message) - elif error_status_code == 500: - raise labelbox.exceptions.LabelboxError(message) - else: - raise labelbox.exceptions.InternalServerError(message) - - not_allowed_error = check_errors( - ["OPERATION_NOT_ALLOWED"], "extensions", "code" - ) - if not_allowed_error is not None: - message = not_allowed_error.get("message") - raise labelbox.exceptions.OperationNotAllowedException(message) - - if len(errors) > 0: - logger.warning("Unparsed errors on query execution: %r", errors) - messages = list( - map( - lambda x: { - "message": x["message"], - "code": x["extensions"]["code"], - }, - errors, - ) - ) - raise labelbox.exceptions.LabelboxError( - "Unknown error: %s" % str(messages) - ) + return self._request_client.headers - # if we do return a proper error code, and didn't catch this above - # reraise - # this mainly catches a 401 for API access disabled for free tier - # TODO: need to unify API errors to handle things more uniformly - # in the SDK - if response.status_code != requests.codes.ok: - message = f"{response.status_code} {response.reason}" - cause = r_json.get("message") - raise labelbox.exceptions.LabelboxError(message, cause) - - return r_json["data"] + @property + def connection(self) -> requests.Session: + return self._request_client._connection def upload_file(self, path: str) -> str: """Uploads given path to local file. @@ -449,12 +338,7 @@ def upload_file(self, path: str) -> str: Raises: labelbox.exceptions.LabelboxError: If upload failed. """ - content_type, _ = mimetypes.guess_type(path) - filename = os.path.basename(path) - with open(path, "rb") as f: - return self.upload_data( - content=f.read(), filename=filename, content_type=content_type - ) + return DataUploader(self._request_client).upload_file(path) @retry.Retry( predicate=retry.if_exception_type( @@ -482,72 +366,10 @@ def upload_data( Raises: labelbox.exceptions.LabelboxError: If upload failed. 
""" - - request_data = { - "operations": json.dumps( - { - "variables": { - "file": None, - "contentLength": len(content), - "sign": sign, - }, - "query": """mutation UploadFile($file: Upload!, $contentLength: Int!, - $sign: Boolean) { - uploadFile(file: $file, contentLength: $contentLength, - sign: $sign) {url filename} } """, - } - ), - "map": (None, json.dumps({"1": ["variables.file"]})), - } - - files = { - "1": (filename, content, content_type) - if (filename and content_type) - else content - } - headers = self._connection.headers.copy() - headers.pop("Content-Type", None) - request = requests.Request( - "POST", - self.endpoint, - headers=headers, - data=request_data, - files=files, + return DataUploader(self._request_client).upload_data( + content, filename, content_type, sign ) - prepped: requests.PreparedRequest = request.prepare() - - response = self._connection.send(prepped) - - if response.status_code == 502: - error_502 = "502 Bad Gateway" - raise labelbox.exceptions.InternalServerError(error_502) - elif response.status_code == 503: - raise labelbox.exceptions.InternalServerError(response.text) - elif response.status_code == 520: - raise labelbox.exceptions.InternalServerError(response.text) - - try: - file_data = response.json().get("data", None) - except ValueError as e: # response is not valid JSON - raise labelbox.exceptions.LabelboxError( - "Failed to upload, unknown cause", e - ) - - if not file_data or not file_data.get("uploadFile", None): - try: - errors = response.json().get("errors", []) - error_msg = next(iter(errors), {}).get( - "message", "Unknown error" - ) - except Exception as e: - error_msg = "Unknown error" - raise labelbox.exceptions.LabelboxError( - "Failed to upload, message: %s" % error_msg - ) - - return file_data["uploadFile"]["url"] - def _get_single(self, db_object_type, uid): """Fetches a single object of the given type, for the given ID. @@ -560,16 +382,7 @@ def _get_single(self, db_object_type, uid): labelbox.exceptions.ResourceNotFoundError: If there is no object of the given type for the given ID. """ - query_str, params = query.get_single(db_object_type, uid) - - res = self.execute(query_str, params) - res = res and res.get(utils.camel_case(db_object_type.type_name())) - if res is None: - raise labelbox.exceptions.ResourceNotFoundError( - db_object_type, params - ) - else: - return db_object_type(self, res) + return get_single(self._request_client, db_object_type, uid) def get_project(self, project_id) -> Project: """Gets a single Project with the given ID. @@ -613,7 +426,7 @@ def get_organization(self) -> Organization: >>> organization = client.get_organization() """ - return self._get_single(Entity.Organization, None) + return get_organization(self._request_client) def _get_all(self, db_object_type, where, filter_deleted=True): """Fetches all the objects of the given type the user has access to. 
@@ -631,7 +444,7 @@ def _get_all(self, db_object_type, where, filter_deleted=True): query_str, params = query.get_all(db_object_type, where) return PaginatedCollection( - self, + self._request_client, query_str, params, [utils.camel_case(db_object_type.type_name()) + "s"], @@ -717,7 +530,7 @@ def _create(self, db_object_type, data, extra_params={}): data = {**data, **extra_params} query_string, params = query.create(db_object_type, data) - res = self.execute( + res = self._request_client.execute( query_string, params, raise_return_resource_not_found=True ) @@ -727,7 +540,7 @@ def _create(self, db_object_type, data, extra_params={}): ) res = res["create%s" % db_object_type.type_name()] - return db_object_type(self, res) + return db_object_type(self._request_client, res) def create_model_config( self, name: str, model_id: str, inference_params: dict @@ -759,8 +572,8 @@ def create_model_config( "inferenceParams": inference_params, "name": name, } - result = self.execute(query, params) - return ModelConfig(self, result["createModelConfig"]) + result = self._request_client.execute(query, params) + return ModelConfig(self._request_client, result["createModelConfig"]) def delete_model_config(self, id: str) -> bool: """Deletes an existing model config with the given id @@ -778,7 +591,7 @@ def delete_model_config(self, id: str) -> bool: } }""" params = {"id": id} - result = self.execute(query, params) + result = self._request_client.execute(query, params) if not result: raise labelbox.exceptions.ResourceNotFoundError( Entity.ModelConfig, params @@ -827,13 +640,13 @@ def create_dataset( "Integration is not valid. Please select another." ) - self.execute( + self._request_client.execute( """mutation setSignerForDatasetPyApi($signerId: ID!, $datasetId: ID!) { setSignerForDataset(data: { signerId: $signerId}, where: {id: $datasetId}){id}} """, {"signerId": iam_integration.uid, "datasetId": dataset.uid}, ) - validation_result = self.execute( + validation_result = self._request_client.execute( """mutation validateDatasetPyApi($id: ID!){validateDataset(where: {id : $id}){ valid checks{name, success}}} """, @@ -842,7 +655,7 @@ def create_dataset( if not validation_result["validateDataset"]["valid"]: raise labelbox.exceptions.LabelboxError( - f"IAMIntegration was not successfully added to the dataset." + "IAMIntegration was not successfully added to the dataset." ) except Exception as e: dataset.delete() @@ -1190,7 +1003,7 @@ def get_roles(self) -> List[Role]: Roles: Provides information on available roles within an organization. Roles are used for user management. """ - return role.get_roles(self) + return role.get_roles(self._request_client) def get_data_row(self, data_row_id): """ @@ -1221,9 +1034,7 @@ def get_data_row_metadata_ontology(self) -> DataRowMetadataOntology: DataRowMetadataOntology: The ontology for Data Row Metadata for an organization """ - if self._data_row_metadata_ontology is None: - self._data_row_metadata_ontology = DataRowMetadataOntology(self) - return self._data_row_metadata_ontology + return get_data_row_metadata_ontology(self._request_client) def get_model(self, model_id) -> Model: """Gets a single Model with the given ID. 
@@ -1273,10 +1084,10 @@ def create_model(self, name, ontology_id) -> Model: } }""" % query.results_query_part(Entity.Model) - result = self.execute( + result = self._request_client.execute( query_str, {"name": name, "ontologyId": ontology_id} ) - return Entity.Model(self, result["createModel"]) + return Entity.Model(self._request_client, result["createModel"]) def get_data_row_ids_for_external_ids( self, external_ids: List[str] @@ -1297,7 +1108,7 @@ def get_data_row_ids_for_external_ids( max_ids_per_request = 100 result = defaultdict(list) for i in range(0, len(external_ids), max_ids_per_request): - for row in self.execute( + for row in self._request_client.execute( query_str, {"externalId_in": external_ids[i : i + max_ids_per_request]}, )["externalIdsToDataRowIds"]: @@ -1333,7 +1144,7 @@ def get_ontologies(self, name_contains) -> PaginatedCollection: """ % query.results_query_part(Entity.Ontology) params = {"search": name_contains, "filter": {"status": "ALL"}} return PaginatedCollection( - self, + self._request_client, query_str, params, ["ontologies", "nodes"], @@ -1355,7 +1166,7 @@ def get_feature_schema(self, feature_schema_id): rootSchemaNode(where: $rootSchemaNodeWhere){%s} }""" % query.results_query_part(Entity.FeatureSchema) - res = self.execute( + res = self._request_client.execute( query_str, {"rootSchemaNodeWhere": {"featureSchemaId": feature_schema_id}}, )["rootSchemaNode"] @@ -1389,7 +1200,7 @@ def rootSchemaPayloadToFeatureSchema(client, payload): return Entity.FeatureSchema(client, payload) return PaginatedCollection( - self, + self._request_client, query_str, params, ["rootSchemaNodes", "nodes"], @@ -1474,11 +1285,11 @@ def delete_unused_feature_schema(self, feature_schema_id: str) -> None: """ endpoint = ( - self.rest_endpoint + self._request_client.rest_endpoint + "/feature-schemas/" + urllib.parse.quote(feature_schema_id) ) - response = self._connection.delete(endpoint) + response = self.connection.delete(endpoint) if response.status_code != requests.codes.no_content: raise labelbox.exceptions.LabelboxError( @@ -1495,11 +1306,11 @@ def delete_unused_ontology(self, ontology_id: str) -> None: >>> client.delete_unused_ontology("cleabc1my012ioqvu5anyaabc") """ endpoint = ( - self.rest_endpoint + self._request_client.rest_endpoint + "/ontologies/" + urllib.parse.quote(ontology_id) ) - response = self._connection.delete(endpoint) + response = self.connection.delete(endpoint) if response.status_code != requests.codes.no_content: raise labelbox.exceptions.LabelboxError( @@ -1522,12 +1333,12 @@ def update_feature_schema_title( """ endpoint = ( - self.rest_endpoint + self._request_client.rest_endpoint + "/feature-schemas/" + urllib.parse.quote(feature_schema_id) + "/definition" ) - response = self._connection.patch(endpoint, json={"title": title}) + response = self.connection.patch(endpoint, json={"title": title}) if response.status_code == requests.codes.ok: return self.get_feature_schema(feature_schema_id) @@ -1557,11 +1368,11 @@ def upsert_feature_schema(self, feature_schema: Dict) -> FeatureSchema: feature_schema.get("featureSchemaId") or "new_feature_schema_id" ) endpoint = ( - self.rest_endpoint + self._request_client.rest_endpoint + "/feature-schemas/" + urllib.parse.quote(feature_schema_id) ) - response = self._connection.put( + response = self.connection.put( endpoint, json={"normalized": json.dumps(feature_schema)} ) @@ -1588,13 +1399,13 @@ def insert_feature_schema_into_ontology( """ endpoint = ( - self.rest_endpoint + self._request_client.rest_endpoint + "/ontologies/" 
+ urllib.parse.quote(ontology_id) + "/feature-schemas/" + urllib.parse.quote(feature_schema_id) ) - response = self._connection.post(endpoint, json={"position": position}) + response = self.connection.post(endpoint, json={"position": position}) if response.status_code != requests.codes.created: raise labelbox.exceptions.LabelboxError( "Failed to insert the feature schema into the ontology, message: " @@ -1615,8 +1426,8 @@ def get_unused_ontologies(self, after: str = None) -> List[str]: >>> client.get_unused_ontologies("cleabc1my012ioqvu5anyaabc") """ - endpoint = self.rest_endpoint + "/ontologies/unused" - response = self._connection.get(endpoint, json={"after": after}) + endpoint = self._request_client.rest_endpoint + "/ontologies/unused" + response = self.connection.get(endpoint, json={"after": after}) if response.status_code == requests.codes.ok: return response.json() @@ -1640,8 +1451,10 @@ def get_unused_feature_schemas(self, after: str = None) -> List[str]: >>> client.get_unused_feature_schemas("cleabc1my012ioqvu5anyaabc") """ - endpoint = self.rest_endpoint + "/feature-schemas/unused" - response = self._connection.get(endpoint, json={"after": after}) + endpoint = ( + self._request_client.rest_endpoint + "/feature-schemas/unused" + ) + response = self.connection.get(endpoint, json={"after": after}) if response.status_code == requests.codes.ok: return response.json() @@ -1716,8 +1529,8 @@ def create_ontology( if editor_task_type_value: params["data"]["editorTaskType"] = editor_task_type_value - res = self.execute(query_str, params) - return Entity.Ontology(self, res["upsertOntology"]) + res = self._request_client.execute(query_str, params) + return Entity.Ontology(self._request_client, res["upsertOntology"]) def create_feature_schema(self, normalized): """ @@ -1755,11 +1568,13 @@ def create_feature_schema(self, normalized): } """ % query.results_query_part(Entity.FeatureSchema) normalized = {k: v for k, v in normalized.items() if v} params = {"data": {"normalized": json.dumps(normalized)}} - res = self.execute(query_str, params)["upsertRootSchemaNode"] + res = self._request_client.execute(query_str, params)[ + "upsertRootSchemaNode" + ] # Technically we are querying for a Schema Node. # But the features are the same so we just grab the feature schema id res["id"] = res["normalized"]["featureSchemaId"] - return Entity.FeatureSchema(self, res) + return Entity.FeatureSchema(self._request_client, res) def get_model_run(self, model_run_id: str) -> ModelRun: """Gets a single ModelRun with the given ID. @@ -1855,7 +1670,9 @@ def _format_failed_rows( for input in global_key_to_data_row_inputs ] } - assign_global_keys_to_data_rows_job = self.execute(query_str, params) + assign_global_keys_to_data_rows_job = self._request_client.execute( + query_str, params + ) # Query string for retrieving job status and result, if job is done result_query_str = """query assignGlobalKeysToDataRowsResultPyApi($jobId: ID!) 
{ @@ -1890,7 +1707,7 @@ def _format_failed_rows( sleep_time = 2 start_time = time.time() while True: - res = self.execute(result_query_str, result_params) + res = self._request_client.execute(result_query_str, result_params) if ( res["assignGlobalKeysToDataRowsResult"]["jobStatus"] == "COMPLETE" @@ -1997,7 +1814,9 @@ def _format_failed_rows( dataRowsForGlobalKeys(where: {ids: $globalKeys}) { jobId}} """ params = {"globalKeys": global_keys} - data_rows_for_global_keys_job = self.execute(query_str, params) + data_rows_for_global_keys_job = self._request_client.execute( + query_str, params + ) # Query string for retrieving job status and result, if job is done result_query_str = """query getDataRowsForGlobalKeysResultPyApi($jobId: ID!) { @@ -2017,7 +1836,7 @@ def _format_failed_rows( sleep_time = 2 start_time = time.time() while True: - res = self.execute(result_query_str, result_params) + res = self._request_client.execute(result_query_str, result_params) if res["dataRowsForGlobalKeysResult"]["jobStatus"] == "COMPLETE": data = res["dataRowsForGlobalKeysResult"]["data"] results, errors = [], [] @@ -2099,7 +1918,7 @@ def _format_failed_rows( clearGlobalKeys(where: {ids: $globalKeys}) { jobId}} """ params = {"globalKeys": global_keys} - clear_global_keys_job = self.execute(query_str, params) + clear_global_keys_job = self._request_client.execute(query_str, params) # Query string for retrieving job status and result, if job is done result_query_str = """query clearGlobalKeysResultPyApi($jobId: ID!) { @@ -2117,7 +1936,7 @@ def _format_failed_rows( sleep_time = 2 start_time = time.time() while True: - res = self.execute(result_query_str, result_params) + res = self._request_client.execute(result_query_str, result_params) if res["clearGlobalKeysResult"]["jobStatus"] == "COMPLETE": data = res["clearGlobalKeysResult"]["data"] results, errors = [], [] @@ -2166,7 +1985,7 @@ def _format_failed_rows( time.sleep(sleep_time) def get_catalog(self) -> Catalog: - return Catalog(client=self) + return Catalog(client=self._request_client) def get_catalog_slice(self, slice_id) -> CatalogSlice: """ @@ -2188,8 +2007,8 @@ def get_catalog_slice(self, slice_id) -> CatalogSlice: } } """ - res = self.execute(query_str, {"id": slice_id}) - return Entity.CatalogSlice(self, res["getSavedQuery"]) + res = self._request_client.execute(query_str, {"id": slice_id}) + return Entity.CatalogSlice(self._request_client, res["getSavedQuery"]) def is_feature_schema_archived( self, ontology_id: str, feature_schema_id: str @@ -2205,11 +2024,11 @@ def is_feature_schema_archived( """ ontology_endpoint = ( - self.rest_endpoint + self._request_client.rest_endpoint + "/ontologies/" + urllib.parse.quote(ontology_id) ) - response = self._connection.get(ontology_endpoint) + response = self.connection.get(ontology_endpoint) if response.status_code == requests.codes.ok: feature_schema_nodes = response.json()["featureSchemaNodes"] @@ -2259,13 +2078,13 @@ def get_model_slice(self, slice_id) -> ModelSlice: } } """ - res = self.execute(query_str, {"id": slice_id}) + res = self._request_client.execute(query_str, {"id": slice_id}) if res is None or res["getSavedQuery"] is None: raise labelbox.exceptions.ResourceNotFoundError( ModelSlice, slice_id ) - return Entity.ModelSlice(self, res["getSavedQuery"]) + return Entity.ModelSlice(self._request_client, res["getSavedQuery"]) def delete_feature_schema_from_ontology( self, ontology_id: str, feature_schema_id: str @@ -2287,13 +2106,13 @@ def delete_feature_schema_from_ontology( >>> 
client.delete_feature_schema_from_ontology(<ontology_id>, <feature_schema_id>)
         """
         ontology_endpoint = (
-            self.rest_endpoint
+            self._request_client.rest_endpoint
             + "/ontologies/"
             + urllib.parse.quote(ontology_id)
             + "/feature-schemas/"
             + urllib.parse.quote(feature_schema_id)
         )
-        response = self._connection.delete(ontology_endpoint)
+        response = self.connection.delete(ontology_endpoint)

         if response.status_code == requests.codes.ok:
             response_json = response.json()
@@ -2328,14 +2147,14 @@ def unarchive_feature_schema_node(
             None
         """
         ontology_endpoint = (
-            self.rest_endpoint
+            self._request_client.rest_endpoint
             + "/ontologies/"
             + urllib.parse.quote(ontology_id)
             + "/feature-schemas/"
             + urllib.parse.quote(root_feature_schema_id)
             + "/unarchive"
         )
-        response = self._connection.patch(ontology_endpoint)
+        response = self.connection.patch(ontology_endpoint)
         if response.status_code == requests.codes.ok:
             if not bool(response.json()["unarchived"]):
                 raise labelbox.exceptions.LabelboxError(
@@ -2363,14 +2182,14 @@ def get_batch(self, project_id: str, batch_id: str) -> Entity.Batch:
             query.results_query_part(Entity.Batch),
         )

-        batch = self.execute(
+        batch = self._request_client.execute(
             get_batch_str,
             {"projectId": project_id, "batchId": batch_id},
             timeout=180.0,
             experimental=True,
         )["project"]["batches"]["nodes"][0]

-        return Entity.Batch(self, project_id, batch)
+        return Entity.Batch(self._request_client, project_id, batch)

     def send_to_annotate_from_catalog(
         self,
@@ -2442,7 +2261,7 @@ def send_to_annotate_from_catalog(
             else None
         )

-        res = self.execute(
+        res = self._request_client.execute(
             mutation_str,
             {
                 "input": {
@@ -2468,7 +2287,7 @@
             },
         )["sendToAnnotateFromCatalog"]

-        return Entity.Task.get_task(self, res["taskId"])
+        return Entity.Task.get_task(self._request_client, res["taskId"])

     @staticmethod
     def build_catalog_query(data_rows: Union[DataRowIds, GlobalKeys]):
@@ -2513,7 +2332,7 @@ def run_foundry_app(
             data_rows (DataRowIds or GlobalKeys): Data row identifiers to run predictions on
             app_id (str): Foundry app to run predictions with
         """
-        foundry_client = FoundryClient(self)
+        foundry_client = FoundryClient(self._request_client)
         return foundry_client.run_app(model_run_name, data_rows, app_id)

     def create_embedding(self, name: str, dims: int) -> Embedding:
@@ -2613,7 +2432,7 @@ def upsert_label_feedback(
             }
         }
         """
-        res = self.execute(
+        res = self._request_client.execute(
             mutation_str,
             {"labelId": label_id, "feedback": feedback, "scores": scores},
         )
@@ -2658,7 +2477,9 @@ def get_labeling_service_dashboards(
         See libs/labelbox/src/labelbox/schema/search_filters.py and
         libs/labelbox/tests/unit/test_unit_search_filters.py for more examples.
""" - return LabelingServiceDashboard.get_all(self, search_query=search_query) + return LabelingServiceDashboard.get_all( + self._request_client, search_query=search_query + ) def get_task_by_id(self, task_id: str) -> Union[Task, DataUpsertTask]: """ @@ -2694,7 +2515,9 @@ def get_task_by_id(self, task_id: str) -> Union[Task, DataUpsertTask]: } } """ - result = self.execute(query, {"userId": user.uid, "taskId": task_id}) + result = self._request_client.execute( + query, {"userId": user.uid, "taskId": task_id} + ) data = result.get("user", {}).get("createdTasks", []) if not data: raise labelbox.exceptions.ResourceNotFoundError( @@ -2702,9 +2525,10 @@ def get_task_by_id(self, task_id: str) -> Union[Task, DataUpsertTask]: ) task_data = data[0] if task_data["type"].lower() == "adv-upsert-data-rows": - task = DataUpsertTask(self, task_data) + task = DataUpsertTask(self._request_client, task_data) else: - task = Task(self, task_data) + task = Task(self._request_client, task_data) task._user = user return task + diff --git a/libs/labelbox/src/labelbox/data_uploader.py b/libs/labelbox/src/labelbox/data_uploader.py new file mode 100644 index 000000000..002add61b --- /dev/null +++ b/libs/labelbox/src/labelbox/data_uploader.py @@ -0,0 +1,139 @@ +import json +import logging +import mimetypes +import os +from typing import Optional + +import requests +import requests.exceptions +from google.api_core import retry + +from . import exceptions +from .request_client import RequestClient + +logger = logging.getLogger(__name__) + + +class DataUploader: + """A class to upload data. + + Contains info necessary for connecting to a Labelbox server (URL, + authentication key). + """ + + def __init__( + self, + client: RequestClient, + ) -> None: + self._client = client + + @property + def connection(self) -> requests.Session: + return self._client._connection + + @retry.Retry( + predicate=retry.if_exception_type(exceptions.InternalServerError) + ) + def upload_data( + self, + content: bytes, + filename: Optional[str] = None, + content_type: Optional[str] = None, + sign: bool = False, + ) -> str: + """Uploads the given data (bytes) to Labelbox. + + Args: + content: bytestring to upload + filename: name of the upload + content_type: content type of data uploaded + sign: whether or not to sign the url + + Returns: + str, the URL of uploaded data. + + Raises: + exceptions.LabelboxError: If upload failed. 
+ """ + + request_data = { + "operations": json.dumps( + { + "variables": { + "file": None, + "contentLength": len(content), + "sign": sign, + }, + "query": """mutation UploadFile($file: Upload!, $contentLength: Int!, + $sign: Boolean) { + uploadFile(file: $file, contentLength: $contentLength, + sign: $sign) {url filename} } """, + } + ), + "map": (None, json.dumps({"1": ["variables.file"]})), + } + + files = { + "1": (filename, content, content_type) + if (filename and content_type) + else content + } + headers = self.connection.headers.copy() + headers.pop("Content-Type", None) + request = requests.Request( + "POST", + self._client.endpoint, + headers=headers, + data=request_data, + files=files, + ) + + prepped: requests.PreparedRequest = request.prepare() + + response = self.connection.send(prepped) + + if response.status_code == 502: + error_502 = "502 Bad Gateway" + raise exceptions.InternalServerError(error_502) + elif response.status_code == 503: + raise exceptions.InternalServerError(response.text) + elif response.status_code == 520: + raise exceptions.InternalServerError(response.text) + + try: + file_data = response.json().get("data", None) + except ValueError as e: # response is not valid JSON + raise exceptions.LabelboxError("Failed to upload, unknown cause", e) + + if not file_data or not file_data.get("uploadFile", None): + try: + errors = response.json().get("errors", []) + error_msg = next(iter(errors), {}).get( + "message", "Unknown error" + ) + except Exception: + error_msg = "Unknown error" + raise exceptions.LabelboxError( + "Failed to upload, message: %s" % error_msg + ) + + return file_data["uploadFile"]["url"] + + def upload_file(self, path: str) -> str: + """Uploads given path to local file. + + Also includes best guess at the content type of the file. + + Args: + path (str): path to local file to be uploaded. + Returns: + str, the URL of uploaded data. + Raises: + labelbox.exceptions.LabelboxError: If upload failed. + """ + content_type, _ = mimetypes.guess_type(path) + filename = os.path.basename(path) + with open(path, "rb") as f: + return self.upload_data( + content=f.read(), filename=filename, content_type=content_type + ) diff --git a/libs/labelbox/src/labelbox/orm/db_object.py b/libs/labelbox/src/labelbox/orm/db_object.py index b210a8a5b..3f1c5ca79 100644 --- a/libs/labelbox/src/labelbox/orm/db_object.py +++ b/libs/labelbox/src/labelbox/orm/db_object.py @@ -1,17 +1,17 @@ -from dataclasses import dataclass +import json +import logging from datetime import datetime, timezone from functools import wraps -import logging -import json from labelbox import utils +from labelbox.client import get_data_row_metadata_ontology from labelbox.exceptions import ( - InvalidQueryError, InvalidAttributeError, + InvalidQueryError, OperationNotSupportedException, ) from labelbox.orm import query -from labelbox.orm.model import Field, Relationship, Entity +from labelbox.orm.model import Entity, Field, Relationship from labelbox.pagination import PaginatedCollection logger = logging.getLogger(__name__) @@ -43,7 +43,7 @@ def __init__(self, client, field_values): by library internals and not by the end user. Args: - client (labelbox.Client): the client used for fetching data from DB. + client (labelbox.RequestCliebt): the client used for fetching data from DB. field_values (dict): Data obtained from the DB. Maps database object fields (their graphql_name version) to values. 
""" @@ -85,7 +85,7 @@ def _set_field_values(self, field_values): value = field.field_type.enum_cls(value) elif isinstance(field.field_type, Field.ListType): if field.field_type.list_cls.__name__ == "DataRowMetadataField": - mdo = self.client.get_data_row_metadata_ontology() + mdo = get_data_row_metadata_ontology() try: value = mdo.parse_metadata_fields(value) except ValueError: diff --git a/libs/labelbox/src/labelbox/pagination.py b/libs/labelbox/src/labelbox/pagination.py index a3b170ec7..e49f114a6 100644 --- a/libs/labelbox/src/labelbox/pagination.py +++ b/libs/labelbox/src/labelbox/pagination.py @@ -1,12 +1,20 @@ # Size of a single page in a paginated query. from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union - -from typing import TYPE_CHECKING +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Tuple, + Type, + Union, +) if TYPE_CHECKING: - from labelbox import Client from labelbox.orm.db_object import DbObject + from labelbox.request_client import RequestClient _PAGE_SIZE = 100 @@ -22,7 +30,7 @@ class PaginatedCollection: def __init__( self, - client: "Client", + client: "RequestClient", query: str, params: Dict[str, Union[str, int]], dereferencing: Union[List[str], Dict[str, Any]], @@ -33,7 +41,7 @@ def __init__( """Creates a PaginatedCollection. Args: - client (labelbox.Client): the client used for fetching data from DB. + client (labelbox.RequestClient): the client used for fetching data from DB. query (str): Base query used for pagination. It must contain two '%d' placeholders, the first for pagination 'skip' clause and the second for the 'first' clause. @@ -113,7 +121,7 @@ def get_many(self, n: int): class _Pagination(ABC): def __init__( self, - client: "Client", + client: "RequestClient", obj_class: Type["DbObject"], dereferencing: Dict[str, Any], query: str, diff --git a/libs/labelbox/src/labelbox/request_client.py b/libs/labelbox/src/labelbox/request_client.py new file mode 100644 index 000000000..c3c8cf836 --- /dev/null +++ b/libs/labelbox/src/labelbox/request_client.py @@ -0,0 +1,353 @@ +import json +import logging +import os +import sys +from datetime import datetime, timezone +from types import MappingProxyType + +import requests +import requests.exceptions +from google.api_core import retry + +from . import exceptions + +logger = logging.getLogger(__name__) + +_LABELBOX_API_KEY = "LABELBOX_API_KEY" + + +def python_version_info(): + version_info = sys.version_info + + return f"{version_info.major}.{version_info.minor}.{version_info.micro}-{version_info.releaselevel}" + + +class RequestClient: + """A Labelbox request client. + + Contains info necessary for connecting to a Labelbox server (URL, + authentication key). + """ + + def __init__( + self, + sdk_version, + api_key=None, + endpoint="https://api.labelbox.com/graphql", + enable_experimental=False, + app_url="https://app.labelbox.com", + rest_endpoint="https://api.labelbox.com/api/v1", + ): + """Creates and initializes a RequestClient. + This class executes graphql and rest requests to the Labelbox server. + + Args: + api_key (str): API key. If None, the key is obtained from the "LABELBOX_API_KEY" environment variable. + endpoint (str): URL of the Labelbox server to connect to. 
+ enable_experimental (bool): Indicates whether or not to use experimental features + app_url (str) : host url for all links to the web app + Raises: + exceptions.AuthenticationError: If no `api_key` + is provided as an argument or via the environment + variable. + """ + if api_key is None: + if _LABELBOX_API_KEY not in os.environ: + raise exceptions.AuthenticationError( + "Labelbox API key not provided" + ) + api_key = os.environ[_LABELBOX_API_KEY] + self.api_key = api_key + + self.enable_experimental = enable_experimental + if enable_experimental: + logger.info("Experimental features have been enabled") + + logger.info("Initializing Labelbox client at '%s'", endpoint) + self.app_url = app_url + self.endpoint = endpoint + self.rest_endpoint = rest_endpoint + self.sdk_version = sdk_version + self._connection: requests.Session = self._init_connection() + + def _init_connection(self) -> requests.Session: + connection = ( + requests.Session() + ) # using default connection pool size of 10 + connection.headers.update(self._default_headers()) + + return connection + + @property + def headers(self) -> MappingProxyType: + return self._connection.headers + + def _default_headers(self): + return { + "Authorization": "Bearer %s" % self.api_key, + "Accept": "application/json", + "Content-Type": "application/json", + "X-User-Agent": f"python-sdk {self.sdk_version}", + "X-Python-Version": f"{python_version_info()}", + } + + @retry.Retry( + predicate=retry.if_exception_type( + exceptions.InternalServerError, + exceptions.TimeoutError, + ) + ) + def execute( + self, + query=None, + params=None, + data=None, + files=None, + timeout=60.0, + experimental=False, + error_log_key="message", + raise_return_resource_not_found=False, + ): + """Sends a request to the server for the execution of the + given query. + + Checks the response for errors and wraps errors + in appropriate `exceptions.LabelboxError` subtypes. + + Args: + query (str): The query to execute. + params (dict): Query parameters referenced within the query. + data (str): json string containing the query to execute + files (dict): file arguments for request + timeout (float): Max allowed time for query execution, + in seconds. + Returns: + dict, parsed JSON response. + Raises: + exceptions.AuthenticationError: If authentication + failed. + exceptions.InvalidQueryError: If `query` is not + syntactically or semantically valid (checked server-side). + exceptions.ApiLimitError: If the server API limit was + exceeded. See "How to import data" in the online documentation + to see API limits. + exceptions.TimeoutError: If response was not received + in `timeout` seconds. + exceptions.NetworkError: If an unknown error occurred + most likely due to connection issues. + exceptions.LabelboxError: If an unknown error of any + kind occurred. + ValueError: If query and data are both None. + """ + logger.debug("Query: %s, params: %r, data %r", query, params, data) + + # Convert datetimes to UTC strings. 
+ def convert_value(value): + if isinstance(value, datetime): + value = value.astimezone(timezone.utc) + value = value.strftime("%Y-%m-%dT%H:%M:%SZ") + return value + + if query is not None: + if params is not None: + params = { + key: convert_value(value) for key, value in params.items() + } + data = json.dumps({"query": query, "variables": params}).encode( + "utf-8" + ) + elif data is None: + raise ValueError("query and data cannot both be none") + + endpoint = ( + self.endpoint + if not experimental + else self.endpoint.replace("/graphql", "/_gql") + ) + + try: + headers = self._connection.headers.copy() + if files: + del headers["Content-Type"] + del headers["Accept"] + request = requests.Request( + "POST", + endpoint, + headers=headers, + data=data, + files=files if files else None, + ) + + prepped: requests.PreparedRequest = request.prepare() + + response = self._connection.send(prepped, timeout=timeout) + logger.debug("Response: %s", response.text) + except requests.exceptions.Timeout as e: + raise exceptions.TimeoutError(str(e)) + except requests.exceptions.RequestException as e: + logger.error("Unknown error: %s", str(e)) + raise exceptions.NetworkError(e) + except Exception as e: + raise exceptions.LabelboxError( + "Unknown error during RequestClient.query(): " + str(e), e + ) + + if ( + 200 <= response.status_code < 300 + or response.status_code < 500 + or response.status_code >= 600 + ): + try: + r_json = response.json() + except Exception: + raise exceptions.LabelboxError( + "Failed to parse response as JSON: %s" % response.text + ) + else: + if ( + "upstream connect error or disconnect/reset before headers" + in response.text + ): + raise exceptions.InternalServerError("Connection reset") + elif response.status_code == 502: + error_502 = "502 Bad Gateway" + raise exceptions.InternalServerError(error_502) + elif 500 <= response.status_code < 600: + error_500 = f"Internal server http error {response.status_code}" + raise exceptions.InternalServerError(error_500) + + errors = r_json.get("errors", []) + + def check_errors(keywords, *path): + """Helper that looks for any of the given `keywords` in any of + current errors on paths (like error[path][component][to][keyword]). 
+ """ + for error in errors: + obj = error + for path_elem in path: + obj = obj.get(path_elem, {}) + if obj in keywords: + return error + return None + + def get_error_status_code(error: dict) -> int: + try: + return int(error["extensions"].get("exception").get("status")) + except: + return 500 + + if ( + check_errors(["AUTHENTICATION_ERROR"], "extensions", "code") + is not None + ): + raise exceptions.AuthenticationError("Invalid API key") + + authorization_error = check_errors( + ["AUTHORIZATION_ERROR"], "extensions", "code" + ) + if authorization_error is not None: + raise exceptions.AuthorizationError(authorization_error["message"]) + + validation_error = check_errors( + ["GRAPHQL_VALIDATION_FAILED"], "extensions", "code" + ) + + if validation_error is not None: + message = validation_error["message"] + if message == "Query complexity limit exceeded": + raise exceptions.ValidationFailedError(message) + else: + raise exceptions.InvalidQueryError(message) + + graphql_error = check_errors( + ["GRAPHQL_PARSE_FAILED"], "extensions", "code" + ) + if graphql_error is not None: + raise exceptions.InvalidQueryError(graphql_error["message"]) + + # Check if API limit was exceeded + response_msg = r_json.get("message", "") + + if response_msg.startswith("You have exceeded"): + raise exceptions.ApiLimitError(response_msg) + + resource_not_found_error = check_errors( + ["RESOURCE_NOT_FOUND"], "extensions", "code" + ) + if resource_not_found_error is not None: + if raise_return_resource_not_found: + raise exceptions.ResourceNotFoundError( + message=resource_not_found_error["message"] + ) + else: + # Return None and let the caller methods raise an exception + # as they already know which resource type and ID was requested + return None + + resource_conflict_error = check_errors( + ["RESOURCE_CONFLICT"], "extensions", "code" + ) + if resource_conflict_error is not None: + raise exceptions.ResourceConflict( + resource_conflict_error["message"] + ) + + malformed_request_error = check_errors( + ["MALFORMED_REQUEST"], "extensions", "code" + ) + if malformed_request_error is not None: + raise exceptions.MalformedQueryException( + malformed_request_error[error_log_key] + ) + + # A lot of different error situations are now labeled serverside + # as INTERNAL_SERVER_ERROR, when they are actually client errors. 
+ # TODO: fix this in the server API + internal_server_error = check_errors( + ["INTERNAL_SERVER_ERROR"], "extensions", "code" + ) + if internal_server_error is not None: + message = internal_server_error.get("message") + error_status_code = get_error_status_code(internal_server_error) + if error_status_code == 400: + raise exceptions.InvalidQueryError(message) + elif error_status_code == 422: + raise exceptions.UnprocessableEntityError(message) + elif error_status_code == 426: + raise exceptions.OperationNotAllowedException(message) + elif error_status_code == 500: + raise exceptions.LabelboxError(message) + else: + raise exceptions.InternalServerError(message) + + not_allowed_error = check_errors( + ["OPERATION_NOT_ALLOWED"], "extensions", "code" + ) + if not_allowed_error is not None: + message = not_allowed_error.get("message") + raise exceptions.OperationNotAllowedException(message) + + if len(errors) > 0: + logger.warning("Unparsed errors on query execution: %r", errors) + messages = list( + map( + lambda x: { + "message": x["message"], + "code": x["extensions"]["code"], + }, + errors, + ) + ) + raise exceptions.LabelboxError("Unknown error: %s" % str(messages)) + + # if we do return a proper error code, and didn't catch this above + # reraise + # this mainly catches a 401 for API access disabled for free tier + # TODO: need to unify API errors to handle things more uniformly + # in the SDK + if response.status_code != requests.codes.ok: + message = f"{response.status_code} {response.reason}" + cause = r_json.get("message") + raise exceptions.LabelboxError(message, cause) + + return r_json["data"] diff --git a/libs/labelbox/src/labelbox/schema/__init__.py b/libs/labelbox/src/labelbox/schema/__init__.py index 03327e0d1..e57c04a29 100644 --- a/libs/labelbox/src/labelbox/schema/__init__.py +++ b/libs/labelbox/src/labelbox/schema/__init__.py @@ -1,29 +1,28 @@ -import labelbox.schema.asset_attachment -import labelbox.schema.bulk_import_request import labelbox.schema.annotation_import +import labelbox.schema.asset_attachment +import labelbox.schema.batch import labelbox.schema.benchmark +import labelbox.schema.catalog import labelbox.schema.data_row +import labelbox.schema.data_row_metadata import labelbox.schema.dataset +import labelbox.schema.iam_integration +import labelbox.schema.identifiable +import labelbox.schema.identifiables import labelbox.schema.invite import labelbox.schema.label import labelbox.schema.labeling_frontend import labelbox.schema.labeling_service +import labelbox.schema.media_type import labelbox.schema.model import labelbox.schema.model_run import labelbox.schema.ontology +import labelbox.schema.ontology_kind import labelbox.schema.organization import labelbox.schema.project +import labelbox.schema.project_overview import labelbox.schema.review import labelbox.schema.role import labelbox.schema.task import labelbox.schema.user import labelbox.schema.webhook -import labelbox.schema.data_row_metadata -import labelbox.schema.batch -import labelbox.schema.iam_integration -import labelbox.schema.media_type -import labelbox.schema.identifiables -import labelbox.schema.identifiable -import labelbox.schema.catalog -import labelbox.schema.ontology_kind -import labelbox.schema.project_overview diff --git a/libs/labelbox/src/labelbox/schema/annotation_import.py b/libs/labelbox/src/labelbox/schema/annotation_import.py index df7f272a3..f56f294ef 100644 --- a/libs/labelbox/src/labelbox/schema/annotation_import.py +++ b/libs/labelbox/src/labelbox/schema/annotation_import.py @@ 
-3,33 +3,35 @@ import logging import os import time +from collections import defaultdict from typing import ( + TYPE_CHECKING, Any, BinaryIO, Dict, List, Optional, Union, - TYPE_CHECKING, cast, ) -from collections import defaultdict -from google.api_core import retry -from labelbox import parser import requests +from google.api_core import retry from tqdm import tqdm # type: ignore import labelbox +from labelbox import parser from labelbox.orm import query from labelbox.orm.db_object import DbObject from labelbox.orm.model import Field, Relationship -from labelbox.utils import is_exactly_one_set from labelbox.schema.confidence_presence_checker import ( LabelsConfidencePresenceChecker, ) from labelbox.schema.enums import AnnotationImportState from labelbox.schema.serialization import serialize_labels +from labelbox.utils import is_exactly_one_set + +from ..request_client import RequestClient if TYPE_CHECKING: from labelbox.types import Label @@ -260,7 +262,7 @@ def _validate_data_rows(cls, objects: List[Dict[str, Any]]): @classmethod def from_name( cls, - client: "labelbox.Client", + client: "RequestClient", parent_id: str, name: str, as_json: bool = False, @@ -276,7 +278,7 @@ class CreatableAnnotationImport(AnnotationImport): @classmethod def create( cls, - client: "labelbox.Client", + client: "RequestClient", id: str, name: str, path: Optional[str] = None, @@ -295,20 +297,20 @@ def create( @classmethod def create_from_url( - cls, client: "labelbox.Client", id: str, name: str, url: str + cls, client: "RequestClient", id: str, name: str, url: str ) -> "AnnotationImport": raise NotImplementedError("Inheriting class must override") @classmethod def create_from_file( - cls, client: "labelbox.Client", id: str, name: str, path: str + cls, client: "RequestClient", id: str, name: str, path: str ) -> "AnnotationImport": raise NotImplementedError("Inheriting class must override") @classmethod def create_from_objects( cls, - client: "labelbox.Client", + client: "RequestClient", id: str, name: str, labels: Union[List[Dict[str, Any]], List["Label"]], @@ -328,13 +330,13 @@ def parent_id(self) -> str: @classmethod def create_from_file( - cls, client: "labelbox.Client", model_run_id: str, name: str, path: str + cls, client: "RequestClient", model_run_id: str, name: str, path: str ) -> "MEAPredictionImport": """ Create an MEA prediction import job from a file of annotations Args: - client: Labelbox Client for executing queries + client: Labelbox RequestClient for executing queries model_run_id: Model run to import labels into name: Name of the import job. Can be used to reference the task later path: Path to ndjson file containing annotations @@ -352,7 +354,7 @@ def create_from_file( @classmethod def create_from_objects( cls, - client: "labelbox.Client", + client: "RequestClient", model_run_id: str, name, predictions: Union[List[Dict[str, Any]], List["Label"]], @@ -361,7 +363,7 @@ def create_from_objects( Create an MEA prediction import job from an in memory dictionary Args: - client: Labelbox Client for executing queries + client: Labelbox RequestClient for executing queries model_run_id: Model run to import labels into name: Name of the import job. 
Can be used to reference the task later predictions: List of prediction annotations @@ -376,14 +378,14 @@ def create_from_objects( @classmethod def create_from_url( - cls, client: "labelbox.Client", model_run_id: str, name: str, url: str + cls, client: "RequestClient", model_run_id: str, name: str, url: str ) -> "MEAPredictionImport": """ Create an MEA prediction import job from a url The url must point to a file containing prediction annotations. Args: - client: Labelbox Client for executing queries + client: Labelbox RequestClient for executing queries model_run_id: Model run to import labels into name: Name of the import job. Can be used to reference the task later url: Url pointing to file to upload @@ -409,7 +411,7 @@ def create_from_url( @classmethod def from_name( cls, - client: "labelbox.Client", + client: "RequestClient", model_run_id: str, name: str, as_json: bool = False, @@ -418,7 +420,7 @@ def from_name( Retrieves an MEA import job. Args: - client: Labelbox Client for executing queries + client: Labelbox RequestClient for executing queries model_run_id: ID used for querying import jobs name: Name of the import job. Returns: @@ -464,7 +466,7 @@ def _get_file_mutation(cls) -> str: @classmethod def _create_mea_import_from_bytes( cls, - client: "labelbox.Client", + client: "RequestClient", model_run_id: str, name: str, bytes_data: BinaryIO, @@ -501,7 +503,7 @@ def parent_id(self) -> str: @classmethod def create_for_model_run_data_rows( cls, - client: "labelbox.Client", + client: "RequestClient", model_run_id: str, data_row_ids: List[str], project_id: str, @@ -511,7 +513,7 @@ def create_for_model_run_data_rows( Create an MEA to MAL prediction import job from a list of data row ids of a specific model run Args: - client: Labelbox Client for executing queries + client: Labelbox RequestClient for executing queries data_row_ids: A list of data row ids model_run_id: model run id Returns: @@ -534,7 +536,7 @@ def create_for_model_run_data_rows( @classmethod def from_name( cls, - client: "labelbox.Client", + client: "RequestClient", project_id: str, name: str, as_json: bool = False, @@ -543,7 +545,7 @@ def from_name( Retrieves an MEA to MAL import job. Args: - client: Labelbox Client for executing queries + client: Labelbox RequestClient for executing queries project_id: ID used for querying import jobs name: Name of the import job. Returns: @@ -592,13 +594,13 @@ def parent_id(self) -> str: @classmethod def create_from_file( - cls, client: "labelbox.Client", project_id: str, name: str, path: str + cls, client: "RequestClient", project_id: str, name: str, path: str ) -> "MALPredictionImport": """ Create an MAL prediction import job from a file of annotations Args: - client: Labelbox Client for executing queries + client: Labelbox RequestClient for executing queries project_id: Project to import labels into name: Name of the import job. Can be used to reference the task later path: Path to ndjson file containing annotations @@ -616,7 +618,7 @@ def create_from_file( @classmethod def create_from_objects( cls, - client: "labelbox.Client", + client: "RequestClient", project_id: str, name: str, predictions: Union[List[Dict[str, Any]], List["Label"]], @@ -625,7 +627,7 @@ def create_from_objects( Create an MAL prediction import job from an in memory dictionary Args: - client: Labelbox Client for executing queries + client: Labelbox RequestClient for executing queries project_id: Project to import labels into name: Name of the import job. 
Can be used to reference the task later predictions: List of prediction annotations @@ -650,14 +652,14 @@ def create_from_objects( @classmethod def create_from_url( - cls, client: "labelbox.Client", project_id: str, name: str, url: str + cls, client: "RequestClient", project_id: str, name: str, url: str ) -> "MALPredictionImport": """ Create an MAL prediction import job from a url The url must point to a file containing prediction annotations. Args: - client: Labelbox Client for executing queries + client: Labelbox RequestClient for executing queries project_id: Project to import labels into name: Name of the import job. Can be used to reference the task later url: Url pointing to file to upload @@ -683,7 +685,7 @@ def create_from_url( @classmethod def from_name( cls, - client: "labelbox.Client", + client: "RequestClient", project_id: str, name: str, as_json: bool = False, @@ -692,7 +694,7 @@ def from_name( Retrieves an MAL import job. Args: - client: Labelbox Client for executing queries + client: Labelbox RequestClient for executing queries project_id: ID used for querying import jobs name: Name of the import job. Returns: @@ -738,7 +740,7 @@ def _get_file_mutation(cls) -> str: @classmethod def _create_mal_import_from_bytes( cls, - client: "labelbox.Client", + client: "RequestClient", project_id: str, name: str, bytes_data: BinaryIO, @@ -770,13 +772,13 @@ def parent_id(self) -> str: @classmethod def create_from_file( - cls, client: "labelbox.Client", project_id: str, name: str, path: str + cls, client: "RequestClient", project_id: str, name: str, path: str ) -> "LabelImport": """ Create a label import job from a file of annotations Args: - client: Labelbox Client for executing queries + client: Labelbox RequestClient for executing queries project_id: Project to import labels into name: Name of the import job. Can be used to reference the task later path: Path to ndjson file containing annotations @@ -794,7 +796,7 @@ def create_from_file( @classmethod def create_from_objects( cls, - client: "labelbox.Client", + client: "RequestClient", project_id: str, name: str, labels: Union[List[Dict[str, Any]], List["Label"]], @@ -803,7 +805,7 @@ def create_from_objects( Create a label import job from an in memory dictionary Args: - client: Labelbox Client for executing queries + client: Labelbox RequestClient for executing queries project_id: Project to import labels into name: Name of the import job. Can be used to reference the task later labels: List of labels @@ -826,14 +828,14 @@ def create_from_objects( @classmethod def create_from_url( - cls, client: "labelbox.Client", project_id: str, name: str, url: str + cls, client: "RequestClient", project_id: str, name: str, url: str ) -> "LabelImport": """ Create a label annotation import job from a url The url must point to a file containing label annotations. Args: - client: Labelbox Client for executing queries + client: Labelbox RequestClient for executing queries project_id: Project to import labels into name: Name of the import job. Can be used to reference the task later url: Url pointing to file to upload @@ -859,7 +861,7 @@ def create_from_url( @classmethod def from_name( cls, - client: "labelbox.Client", + client: "RequestClient", project_id: str, name: str, as_json: bool = False, @@ -868,7 +870,7 @@ def from_name( Retrieves an label import job. Args: - client: Labelbox Client for executing queries + client: Labelbox RequestClient for executing queries project_id: ID used for querying import jobs name: Name of the import job. 
Returns: @@ -912,7 +914,7 @@ def _get_file_mutation(cls) -> str: @classmethod def _create_label_import_from_bytes( cls, - client: "labelbox.Client", + client: "RequestClient", project_id: str, name: str, bytes_data: BinaryIO, diff --git a/libs/labelbox/src/labelbox/schema/catalog.py b/libs/labelbox/src/labelbox/schema/catalog.py index 567bbd777..ff52be76d 100644 --- a/libs/labelbox/src/labelbox/schema/catalog.py +++ b/libs/labelbox/src/labelbox/schema/catalog.py @@ -1,7 +1,6 @@ -from typing import Any, Dict, List, Optional, Tuple, Union -from labelbox.orm.db_object import experimental -from labelbox.schema.export_filters import CatalogExportFilters, build_filters +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from labelbox.schema.export_filters import CatalogExportFilters, build_filters from labelbox.schema.export_params import ( CatalogExportParams, validate_catalog_export_params, @@ -9,16 +8,14 @@ from labelbox.schema.export_task import ExportTask from labelbox.schema.task import Task -from typing import TYPE_CHECKING - if TYPE_CHECKING: - from labelbox import Client + from labelbox.request_client import RequestClient class Catalog: - client: "Client" + client: "RequestClient" - def __init__(self, client: "Client"): + def __init__(self, client: "RequestClient"): self.client = client def export_v2( diff --git a/libs/labelbox/src/labelbox/schema/data_row.py b/libs/labelbox/src/labelbox/schema/data_row.py index 8987a00f0..aa7b40944 100644 --- a/libs/labelbox/src/labelbox/schema/data_row.py +++ b/libs/labelbox/src/labelbox/schema/data_row.py @@ -1,18 +1,19 @@ +import json import logging from enum import Enum -from typing import TYPE_CHECKING, List, Optional, Tuple, Union, Any -import json +from typing import TYPE_CHECKING, List, Optional, Tuple, Union from labelbox.orm import query from labelbox.orm.db_object import ( + BulkDeletable, DbObject, Updateable, - BulkDeletable, - experimental, ) from labelbox.orm.model import Entity, Field, Relationship from labelbox.schema.asset_attachment import AttachmentType -from labelbox.schema.data_row_metadata import DataRowMetadataField # type: ignore +from labelbox.schema.data_row_metadata import ( + DataRowMetadataField, # type: ignore +) from labelbox.schema.export_filters import ( DatarowExportFilters, build_filters, @@ -25,8 +26,10 @@ from labelbox.schema.export_task import ExportTask from labelbox.schema.task import Task +from ..request_client import RequestClient + if TYPE_CHECKING: - from labelbox import AssetAttachment, Client + from labelbox import AssetAttachment, RequestClient logger = logging.getLogger(__name__) @@ -213,7 +216,7 @@ def create_attachment( @staticmethod def export( - client: "Client", + client: "RequestClient", data_rows: Optional[List[Union[str, "DataRow"]]] = None, global_keys: Optional[List[str]] = None, task_name: Optional[str] = None, @@ -222,7 +225,7 @@ def export( """ Creates a data rows export task with the given list, params and returns the task. 
Args: - client (Client): client to use to make the export request + client (RequestClient): client to use to make the export request data_rows (list of DataRow or str): list of data row objects or data row ids to export task_name (str): name of remote task params (CatalogExportParams): export params @@ -248,7 +251,7 @@ def export( @staticmethod def export_v2( - client: "Client", + client: "RequestClient", data_rows: Optional[List[Union[str, "DataRow"]]] = None, global_keys: Optional[List[str]] = None, task_name: Optional[str] = None, @@ -257,7 +260,7 @@ def export_v2( """ Creates a data rows export task with the given list, params and returns the task. Args: - client (Client): client to use to make the export request + client (RequestClient): client to use to make the export request data_rows (list of DataRow or str): list of data row objects or data row ids to export task_name (str): name of remote task params (CatalogExportParams): export params @@ -286,7 +289,7 @@ def export_v2( @staticmethod def _export( - client: "Client", + client: "RequestClient", data_rows: Optional[List[Union[str, "DataRow"]]] = None, global_keys: Optional[List[str]] = None, task_name: Optional[str] = None, diff --git a/libs/labelbox/src/labelbox/schema/dataset.py b/libs/labelbox/src/labelbox/schema/dataset.py index 04877c885..a58ac4daa 100644 --- a/libs/labelbox/src/labelbox/schema/dataset.py +++ b/libs/labelbox/src/labelbox/schema/dataset.py @@ -1,57 +1,48 @@ -from datetime import datetime -from typing import Dict, Generator, List, Optional, Any, Final, Tuple, Union -import os import json import logging -from collections.abc import Iterable -from string import Template -import time +import os import warnings - -from labelbox import parser -from itertools import islice - from concurrent.futures import ThreadPoolExecutor, as_completed -from io import StringIO -import requests +from itertools import islice +from string import Template +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +import labelbox.schema.internal.data_row_uploader as data_row_uploader from labelbox.exceptions import ( InvalidQueryError, LabelboxError, - ResourceNotFoundError, ResourceCreationError, + ResourceNotFoundError, ) +from labelbox.orm import query from labelbox.orm.comparison import Comparison -from labelbox.orm.db_object import DbObject, Updateable, Deletable, experimental +from labelbox.orm.db_object import DbObject, Deletable, Updateable from labelbox.orm.model import Entity, Field, Relationship -from labelbox.orm import query -from labelbox.exceptions import MalformedQueryException from labelbox.pagination import PaginatedCollection from labelbox.schema.data_row import DataRow -from labelbox.schema.embedding import EmbeddingVector from labelbox.schema.export_filters import DatasetExportFilters, build_filters from labelbox.schema.export_params import ( CatalogExportParams, validate_catalog_export_params, ) from labelbox.schema.export_task import ExportTask -from labelbox.schema.identifiable import UniqueId, GlobalKey -from labelbox.schema.task import Task, DataUpsertTask -from labelbox.schema.user import User from labelbox.schema.iam_integration import IAMIntegration +from labelbox.schema.identifiable import GlobalKey, UniqueId from labelbox.schema.internal.data_row_upsert_item import ( + DataRowCreateItem, DataRowItemBase, DataRowUpsertItem, - DataRowCreateItem, -) -import labelbox.schema.internal.data_row_uploader as data_row_uploader -from labelbox.schema.internal.descriptor_file_creator import ( - 
DescriptorFileCreator, ) from labelbox.schema.internal.datarow_upload_constants import ( FILE_UPLOAD_THREAD_COUNT, UPSERT_CHUNK_SIZE_BYTES, ) +from labelbox.schema.task import DataUpsertTask, Task + +from ..client import get_data_row, get_user + +if TYPE_CHECKING: + pass logger = logging.getLogger(__name__) @@ -177,7 +168,7 @@ def create_data_row(self, items=None, **kwargs) -> "DataRow": f"Data row upload did not complete, task status {completed_task.status} task id {completed_task.uid}" ) - return self.client.get_data_row(res[0]["id"]) + return get_data_row(self.client, res[0]["id"]) def create_data_rows_sync( self, items, file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT @@ -359,7 +350,7 @@ def data_row_for_external_id(self, external_id) -> "DataRow": ) if len(data_rows) > 1: logger.warning( - f"More than one data_row has the provided external_id : `%s`. Use function data_rows_for_external_id to fetch all", + "More than one data_row has the provided external_id : `%s`. Use function data_rows_for_external_id to fetch all", external_id, ) return data_rows[0] @@ -582,7 +573,7 @@ def _exec_upsert_data_rows( res = self.client.execute(query_str, {"manifestUri": manifest_uri}) res = res["upsertDataRows"] task = DataUpsertTask(self.client, res) - task._user = self.client.get_user() + task._user = get_user(self.client) return task def add_iam_integration( diff --git a/libs/labelbox/src/labelbox/schema/export_task.py b/libs/labelbox/src/labelbox/schema/export_task.py index a144f4c76..ba90092de 100644 --- a/libs/labelbox/src/labelbox/schema/export_task.py +++ b/libs/labelbox/src/labelbox/schema/export_task.py @@ -1,11 +1,16 @@ +import json +import os +import tempfile +import warnings from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum from functools import lru_cache from io import TextIOWrapper -import json from pathlib import Path from typing import ( + TYPE_CHECKING, + Any, Callable, Generic, Iterator, @@ -14,22 +19,20 @@ Tuple, TypeVar, Union, - TYPE_CHECKING, overload, - Any, ) import requests -import warnings -import tempfile -import os +from pydantic import BaseModel from labelbox.schema.task import Task from labelbox.utils import _CamelCaseMixin -from pydantic import BaseModel, Field, AliasChoices + +from ..client import get_organization +from ..request_client import RequestClient if TYPE_CHECKING: - from labelbox import Client + pass OutputT = TypeVar("OutputT") @@ -61,7 +64,7 @@ class _MetadataFileInfo(_CamelCaseMixin, BaseModel): # pylint: disable=too-few- @dataclass class _TaskContext: - client: "Client" + client: "RequestClient" task_id: str stream_type: StreamType metadata_header: _MetadataHeader @@ -733,7 +736,7 @@ def result_url(self): + "/export-results/" + self._task.uid + "/" - + self._task.client.get_organization().uid + + get_organization(self._task.client).uid ) @property diff --git a/libs/labelbox/src/labelbox/schema/internal/descriptor_file_creator.py b/libs/labelbox/src/labelbox/schema/internal/descriptor_file_creator.py index ce3ce4b35..0cd7c8774 100644 --- a/libs/labelbox/src/labelbox/schema/internal/descriptor_file_creator.py +++ b/libs/labelbox/src/labelbox/schema/internal/descriptor_file_creator.py @@ -1,28 +1,22 @@ import json import os -import sys from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import TYPE_CHECKING, Generator, Iterable, List -from typing import Iterable, List, Generator - -from labelbox.exceptions import InvalidQueryError -from labelbox.exceptions import InvalidAttributeError -from 
labelbox.exceptions import MalformedQueryException -from labelbox.orm.model import Entity -from labelbox.orm.model import Field +from labelbox.exceptions import InvalidAttributeError, InvalidQueryError +from labelbox.orm.model import Entity, Field from labelbox.schema.embedding import EmbeddingVector -from labelbox.schema.internal.datarow_upload_constants import ( - FILE_UPLOAD_THREAD_COUNT, -) from labelbox.schema.internal.data_row_upsert_item import ( DataRowItemBase, - DataRowUpsertItem, +) +from labelbox.schema.internal.datarow_upload_constants import ( + FILE_UPLOAD_THREAD_COUNT, ) -from typing import TYPE_CHECKING +from ...data_uploader import DataUploader if TYPE_CHECKING: - from labelbox import Client + from labelbox.request_client import RequestClient class DescriptorFileCreator: @@ -32,12 +26,12 @@ class DescriptorFileCreator: upload the files to gcs in parallel, and return a list of urls Args: - client (Client): The client object + client (RequestClient): The client object max_chunk_size_bytes (int): The maximum size of the file in bytes """ - def __init__(self, client: "Client"): - self.client = client + def __init__(self, client: "RequestClient"): + self.data_uploader = DataUploader(client) def create(self, items, max_chunk_size_bytes=None) -> List[str]: is_upsert = True # This class will only support upsert use cases @@ -46,7 +40,7 @@ def create(self, items, max_chunk_size_bytes=None) -> List[str]: with ThreadPoolExecutor(FILE_UPLOAD_THREAD_COUNT) as executor: futures = [ executor.submit( - self.client.upload_data, + self.data_uploader.upload_data, chunk, "application/json", "json_import.json", @@ -61,7 +55,7 @@ def create_one(self, items) -> List[str]: ) # Prepare and upload the descriptor file data = json.dumps(items) - return self.client.upload_data( + return self.data_uploader.upload_data( data, content_type="application/json", filename="json_import.json" ) @@ -115,7 +109,7 @@ def upload_if_necessary(item): return item row_data = item["row_data"] if isinstance(row_data, str) and os.path.exists(row_data): - item_url = self.client.upload_file(row_data) + item_url = self.data_uploader.upload_file(row_data) item["row_data"] = item_url if "external_id" not in item: # Default `external_id` to local file name @@ -161,7 +155,7 @@ def check_message_keys(message): ] ) for key in message.keys(): - if not key in accepted_message_keys: + if key not in accepted_message_keys: raise KeyError( f"Invalid {key} key found! 
Accepted keys in messages list is {accepted_message_keys}"
                 )
 
@@ -178,7 +172,7 @@ def check_message_keys(message):
     def parse_metadata_fields(item):
         metadata_fields = item.get("metadata_fields")
         if metadata_fields:
-            mdo = self.client.get_data_row_metadata_ontology()
+            mdo = self.data_uploader.get_data_row_metadata_ontology()
             item["metadata_fields"] = mdo.parse_upsert_metadata(
                 metadata_fields
             )
diff --git a/libs/labelbox/src/labelbox/schema/invite.py b/libs/labelbox/src/labelbox/schema/invite.py
index c89a8b08c..130b9373b 100644
--- a/libs/labelbox/src/labelbox/schema/invite.py
+++ b/libs/labelbox/src/labelbox/schema/invite.py
@@ -4,6 +4,8 @@
 from labelbox.orm.model import Field
 from labelbox.schema.role import ProjectRole, format_role
 
+from ..client import get_project, get_roles
+
 
 @dataclass
 class InviteLimit:
@@ -33,8 +35,8 @@ def __init__(self, client, invite_response):
 
         self.project_roles = [
             ProjectRole(
-                project=client.get_project(r["projectId"]),
-                role=client.get_roles()[format_role(r["projectRoleName"])],
+                project=get_project(self.client, r["projectId"]),
+                role=get_roles(self.client)[format_role(r["projectRoleName"])],
             )
             for r in project_roles
         ]
diff --git a/libs/labelbox/src/labelbox/schema/label.py b/libs/labelbox/src/labelbox/schema/label.py
index 371193a13..e31d1e6e3 100644
--- a/libs/labelbox/src/labelbox/schema/label.py
+++ b/libs/labelbox/src/labelbox/schema/label.py
@@ -1,9 +1,11 @@
 from typing import TYPE_CHECKING
 
 from labelbox.orm import query
-from labelbox.orm.db_object import DbObject, Updateable, BulkDeletable
+from labelbox.orm.db_object import BulkDeletable, DbObject, Updateable
 from labelbox.orm.model import Entity, Field, Relationship
 
+from ..client import create_entity
+
 if TYPE_CHECKING:
     from labelbox import Benchmark, Review
 
 """ Client-side object type definitions. """
@@ -61,7 +63,7 @@ def create_review(self, **kwargs) -> "Review":
         """
         kwargs[Entity.Review.label.name] = self
         kwargs[Entity.Review.project.name] = self.project()
-        return self.client._create(Entity.Review, kwargs)
+        return create_entity(self.client, Entity.Review, kwargs)
 
     def create_benchmark(self) -> "Benchmark":
         """Creates a Benchmark for this Label.
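
Note: the hunks above follow the refactoring pattern used throughout this change: `Client` instance methods become module-level helpers in `labelbox.client` that take the client as their first argument. A minimal sketch of the new calling convention, with illustrative IDs and assuming an already-constructed client:

    from labelbox.client import create_entity, get_project, get_roles

    project = get_project(client, "some-project-id")  # formerly client.get_project(...)
    roles = get_roles(client)                         # formerly client.get_roles()

The same convention appears below for get_batch, get_labeling_frontends, and get_organization.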
diff --git a/libs/labelbox/src/labelbox/schema/model_run.py b/libs/labelbox/src/labelbox/schema/model_run.py index bc9971174..f1bff9215 100644 --- a/libs/labelbox/src/labelbox/schema/model_run.py +++ b/libs/labelbox/src/labelbox/schema/model_run.py @@ -2,25 +2,22 @@ import logging import os import time -import warnings from enum import Enum from pathlib import Path from typing import ( TYPE_CHECKING, + Any, Dict, Iterable, - Union, - Tuple, List, Optional, - Any, + Tuple, + Union, ) -import requests - -from labelbox import parser +from labelbox.client import Client from labelbox.orm.db_object import DbObject, experimental -from labelbox.orm.model import Field, Relationship, Entity +from labelbox.orm.model import Entity, Field, Relationship from labelbox.orm.query import results_query_part from labelbox.pagination import PaginatedCollection from labelbox.schema.conflict_resolution_strategy import ( @@ -28,7 +25,7 @@ ) from labelbox.schema.export_params import ModelRunExportParams from labelbox.schema.export_task import ExportTask -from labelbox.schema.identifiables import UniqueIds, GlobalKeys, DataRowIds +from labelbox.schema.identifiables import DataRowIds, GlobalKeys from labelbox.schema.send_to_annotate_params import ( SendToAnnotateFromModelParams, build_destination_task_queue_input, @@ -200,7 +197,7 @@ def _wait_until_done(self, status_fn, timeout_seconds=120, sleep_time=5): if res["status"] == "COMPLETE": return True elif res["status"] == "FAILED": - raise Exception(f"Job failed.") + raise Exception("Job failed.") timeout_seconds -= sleep_time if timeout_seconds <= 0: raise TimeoutError( @@ -633,7 +630,7 @@ def send_to_annotate_from_model( destination_task_queue = build_destination_task_queue_input( task_queue_id ) - data_rows_query = self.client.build_catalog_query(data_rows) + data_rows_query = Client.build_catalog_query(data_rows) predictions_ontology_mapping = params.get( "predictions_ontology_mapping", None diff --git a/libs/labelbox/src/labelbox/schema/project.py b/libs/labelbox/src/labelbox/schema/project.py index 88153e48f..755e25ed2 100644 --- a/libs/labelbox/src/labelbox/schema/project.py +++ b/libs/labelbox/src/labelbox/schema/project.py @@ -1,10 +1,10 @@ import json import logging -from string import Template import time import warnings from collections import namedtuple from datetime import datetime, timezone +from string import Template from typing import ( TYPE_CHECKING, Any, @@ -16,19 +16,13 @@ overload, ) -from labelbox.schema.labeling_service import ( - LabelingService, - LabelingServiceStatus, -) -from labelbox.schema.labeling_service_dashboard import LabelingServiceDashboard - from labelbox import utils -from labelbox.exceptions import error_message_for_unparsed_graphql_error from labelbox.exceptions import ( InvalidQueryError, LabelboxError, ProcessingWaitTimeout, ResourceNotFoundError, + error_message_for_unparsed_graphql_error, ) from labelbox.orm import query from labelbox.orm.db_object import DbObject, Deletable, Updateable, experimental @@ -46,21 +40,28 @@ from labelbox.schema.id_type import IdType from labelbox.schema.identifiable import DataRowIdentifier, GlobalKey, UniqueId from labelbox.schema.identifiables import DataRowIdentifiers, UniqueIds +from labelbox.schema.labeling_service import ( + LabelingService, + LabelingServiceStatus, +) +from labelbox.schema.labeling_service_dashboard import LabelingServiceDashboard from labelbox.schema.media_type import MediaType from labelbox.schema.model_config import ModelConfig -from 
labelbox.schema.project_model_config import ProjectModelConfig -from labelbox.schema.queue_mode import QueueMode -from labelbox.schema.resource_tag import ResourceTag -from labelbox.schema.task import Task -from labelbox.schema.task_queue import TaskQueue from labelbox.schema.ontology_kind import ( EditorTaskType, UploadType, ) +from labelbox.schema.project_model_config import ProjectModelConfig from labelbox.schema.project_overview import ( ProjectOverview, ProjectOverviewDetailed, ) +from labelbox.schema.queue_mode import QueueMode +from labelbox.schema.resource_tag import ResourceTag +from labelbox.schema.task import Task +from labelbox.schema.task_queue import TaskQueue + +from ..client import get_batch, get_labeling_frontends if TYPE_CHECKING: pass @@ -729,7 +730,7 @@ def _connect_default_labeling_front_end(self, ontology_as_dict: dict): ): # Chat evaluation projects are automatically set up via the same api that creates a project warnings.warn("Connecting default labeling editor for the project.") labeling_frontend = next( - self.client.get_labeling_frontends( + get_labeling_frontends(self.client, where=Entity.LabelingFrontend.name == "Editor" ) ) @@ -1083,7 +1084,7 @@ def _create_batch_async( + json.dumps(task.errors) ) - return self.client.get_batch(self.uid, batch_id) + return get_batch(self.client, self.uid, batch_id) def _update_queue_mode(self, mode: "QueueMode") -> "QueueMode": """ diff --git a/libs/labelbox/src/labelbox/schema/role.py b/libs/labelbox/src/labelbox/schema/role.py index 47cd753e9..cc4cc7f86 100644 --- a/libs/labelbox/src/labelbox/schema/role.py +++ b/libs/labelbox/src/labelbox/schema/role.py @@ -1,16 +1,17 @@ from dataclasses import dataclass -from typing import Dict, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Dict, Optional -from labelbox.orm.model import Field, Entity from labelbox.orm.db_object import DbObject +from labelbox.orm.model import Field if TYPE_CHECKING: - from labelbox import Client, Project + from labelbox import Project + from labelbox.request_client import RequestClient _ROLES: Optional[Dict[str, "Role"]] = None -def get_roles(client: "Client") -> Dict[str, "Role"]: +def get_roles(client: "RequestClient") -> Dict[str, "Role"]: global _ROLES if _ROLES is None: query_str = """query GetAvailableUserRolesPyApi { roles { id name } }""" diff --git a/libs/labelbox/src/labelbox/schema/user_group.py b/libs/labelbox/src/labelbox/schema/user_group.py index 9d506bf92..51e709964 100644 --- a/libs/labelbox/src/labelbox/schema/user_group.py +++ b/libs/labelbox/src/labelbox/schema/user_group.py @@ -1,20 +1,22 @@ -from enum import Enum -from typing import Set, Iterator from collections import defaultdict +from enum import Enum +from typing import Iterator, Set + +from pydantic import BaseModel, ConfigDict -from labelbox import Client -from labelbox.exceptions import ResourceCreationError -from labelbox.schema.user import User -from labelbox.schema.project import Project from labelbox.exceptions import ( - UnprocessableEntityError, MalformedQueryException, + ResourceCreationError, ResourceNotFoundError, + UnprocessableEntityError, ) -from labelbox.schema.queue_mode import QueueMode -from labelbox.schema.ontology_kind import EditorTaskType from labelbox.schema.media_type import MediaType -from pydantic import BaseModel, ConfigDict +from labelbox.schema.ontology_kind import EditorTaskType +from labelbox.schema.project import Project +from labelbox.schema.queue_mode import QueueMode +from labelbox.schema.user import User + +from ..request_client 
import RequestClient class UserGroupColor(Enum): @@ -54,15 +56,15 @@ class UserGroup(BaseModel): color (UserGroupColor): The color of the user group. users (Set[UserGroupUser]): The set of users in the user group. projects (Set[UserGroupProject]): The set of projects associated with the user group. - client (Client): The Labelbox client object. + client (RequestClient): The Labelbox client object. Methods: - __init__(self, client: Client) + __init__(self, client: RequestClient) get(self) -> "UserGroup" update(self) -> "UserGroup" create(self) -> "UserGroup" delete(self) -> bool - get_user_groups(client: Client) -> Iterator["UserGroup"] + get_user_groups(client: RequestClient) -> Iterator["UserGroup"] """ id: str @@ -70,12 +72,12 @@ class UserGroup(BaseModel): color: UserGroupColor users: Set[User] projects: Set[Project] - client: Client + client: RequestClient model_config = ConfigDict(arbitrary_types_allowed=True) def __init__( self, - client: Client, + client: RequestClient, id: str = "", name: str = "", color: UserGroupColor = UserGroupColor.BLUE, @@ -86,7 +88,7 @@ def __init__( Initializes a UserGroup object. Args: - client (Client): The Labelbox client object. + client (RequestClient): The Labelbox client object. id (str, optional): The ID of the user group. Defaults to an empty string. name (str, optional): The name of the user group. Defaults to an empty string. color (UserGroupColor, optional): The color of the user group. Defaults to UserGroupColor.BLUE. @@ -329,7 +331,7 @@ def get_user_groups(self) -> Iterator["UserGroup"]: Gets all user groups in Labelbox. Args: - client (Client): The Labelbox client. + client (RequestClient): The Labelbox client. Returns: Iterator[UserGroup]: An iterator over the user groups. diff --git a/libs/labelbox/src/labelbox/schema/webhook.py b/libs/labelbox/src/labelbox/schema/webhook.py index 0eebe157e..b84aee6bd 100644 --- a/libs/labelbox/src/labelbox/schema/webhook.py +++ b/libs/labelbox/src/labelbox/schema/webhook.py @@ -1,6 +1,6 @@ import logging from enum import Enum -from typing import Iterable, List +from typing import Iterable from labelbox.orm import query from labelbox.orm.db_object import DbObject, Updateable @@ -56,7 +56,7 @@ def create(client, topics, url, secret, project) -> "Webhook": """Creates a Webhook. Args: - client (Client): The Labelbox client used to connect + client (RequestClient): The Labelbox client used to connect to the server. topics (list of str): A list of topics this Webhook should get notifications for. 
Must be one of Webhook.Topic
diff --git a/libs/labelbox/tests/data/annotation_import/test_ndjson_validation.py b/libs/labelbox/tests/data/annotation_import/test_ndjson_validation.py
index 0ec742333..9e8963a26 100644
--- a/libs/labelbox/tests/data/annotation_import/test_ndjson_validation.py
+++ b/libs/labelbox/tests/data/annotation_import/test_ndjson_validation.py
@@ -1,21 +1,8 @@
-from labelbox.schema.media_type import MediaType
 import pytest
-
-from pytest_cases import parametrize, fixture_ref
+from pytest_cases import fixture_ref, parametrize
 
 from labelbox.exceptions import MALValidationError
-from labelbox.schema.bulk_import_request import (
-    NDChecklist,
-    NDClassification,
-    NDMask,
-    NDPolygon,
-    NDPolyline,
-    NDRectangle,
-    NDText,
-    NDTextEntity,
-    NDTool,
-    _validate_ndjson,
-)
+from labelbox.schema.media_type import MediaType
 
 """
 - These NDlabels are apart of bulkImportReqeust and should be removed once bulk import request is removed
diff --git a/libs/lbox-clients/.gitignore b/libs/lbox-clients/.gitignore
new file mode 100644
index 000000000..ae8554dec
--- /dev/null
+++ b/libs/lbox-clients/.gitignore
@@ -0,0 +1,10 @@
+# python generated files
+__pycache__/
+*.py[oc]
+build/
+dist/
+wheels/
+*.egg-info
+
+# venv
+.venv
diff --git a/libs/lbox-clients/Dockerfile b/libs/lbox-clients/Dockerfile
new file mode 100644
index 000000000..2ee61ab7e
--- /dev/null
+++ b/libs/lbox-clients/Dockerfile
@@ -0,0 +1,44 @@
+# https://github.com/ucyo/python-package-template/blob/master/Dockerfile
+FROM python:3.8-slim as rye
+
+ENV LANG="C.UTF-8" \
+    LC_ALL="C.UTF-8" \
+    PATH="/home/python/.local/bin:/home/python/.rye/shims:$PATH" \
+    PIP_NO_CACHE_DIR="false" \
+    RYE_VERSION="0.34.0" \
+    RYE_INSTALL_OPTION="--yes" \
+    LABELBOX_TEST_ENVIRON="prod"
+
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+    ca-certificates \
+    curl \
+    inotify-tools \
+    make \
+    # cv2
+    libsm6 \
+    libxext6 \
+    ffmpeg \
+    libfontconfig1 \
+    libxrender1 \
+    libgl1-mesa-glx \
+    libgeos-dev \
+    gcc \
+    && rm -rf /var/lib/apt/lists/*
+
+RUN groupadd --gid 1000 python && \
+    useradd --uid 1000 --gid python --shell /bin/bash --create-home python
+
+USER 1000
+WORKDIR /home/python/
+
+RUN curl -sSf https://rye.astral.sh/get | bash -
+
+COPY --chown=python:python . /home/python/labelbox-python/
+WORKDIR /home/python/labelbox-python
+
+RUN rye config --set-bool behavior.global-python=true && \
+    rye config --set-bool behavior.use-uv=true && \
+    rye pin 3.8 && \
+    rye sync
+
+CMD rye run unit && rye run integration
\ No newline at end of file
diff --git a/libs/lbox-clients/README.md b/libs/lbox-clients/README.md
new file mode 100644
index 000000000..a7bf3e94f
--- /dev/null
+++ b/libs/lbox-clients/README.md
@@ -0,0 +1,9 @@
+# lbox-clients
+
+This module contains the request client used by `lbox` modules to connect to the Labelbox API and backends.
+
+## Module Status: Experimental
+
+**TLDR: This module may be removed or altered at any given time and there is no official support.**
+
+Please see [here](https://docs.labelbox.com/docs/product-release-phases) for the formal definition of `Experimental`.
\ No newline at end of file
diff --git a/libs/lbox-clients/pyproject.toml b/libs/lbox-clients/pyproject.toml
new file mode 100644
index 000000000..9d88584bb
--- /dev/null
+++ b/libs/lbox-clients/pyproject.toml
@@ -0,0 +1,62 @@
+[project]
+name = "lbox-clients"
+version = "0.1.0"
+description = "This module contains the client SDK used to connect to the Labelbox API and backends"
+authors = [
+    { name = "Labelbox", email = "engineering@labelbox.com" }
+]
+dependencies = [
+    "requests>=2.22.0",
+    "google-api-core>=1.22.1",
+]
+readme = "README.md"
+requires-python = ">= 3.8"
+
+classifiers=[
+    # How mature is this project?
+    "Development Status :: 2 - Pre-Alpha",
+    # Indicate who your project is intended for
+    "Topic :: Scientific/Engineering :: Artificial Intelligence",
+    "Topic :: Software Development :: Libraries",
+    "Intended Audience :: Developers",
+    "Intended Audience :: Science/Research",
+    "Intended Audience :: Education",
+    # Pick your license as you wish
+    "License :: OSI Approved :: Apache Software License",
+    # Specify the Python versions you support here.
+    "Programming Language :: Python :: 3",
+    "Programming Language :: Python :: 3.8",
+    "Programming Language :: Python :: 3.9",
+    "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
+    "Programming Language :: Python :: 3.12",
+]
+keywords = ["ml", "ai", "labelbox", "labeling", "llm", "machinelearning", "edu"]
+
+[project.urls]
+Homepage = "https://labelbox.com/"
+Documentation = "https://labelbox-python.readthedocs.io/en/latest/"
+Repository = "https://github.com/Labelbox/labelbox-python"
+Issues = "https://github.com/Labelbox/labelbox-python/issues"
+Changelog = "https://github.com/Labelbox/labelbox-python/blob/develop/libs/labelbox/CHANGELOG.md"
+
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[tool.rye]
+managed = true
+dev-dependencies = []
+
+[tool.rye.scripts]
+unit = "pytest tests/unit"
+integration = "python -c \"import sys; sys.exit(0)\""
+
+[tool.hatch.metadata]
+allow-direct-references = true
+
+[tool.hatch.build.targets.wheel]
+packages = ["src/lbox"]
+
+[tool.pytest.ini_options]
+addopts = "-rP -vvv --durations=20 --cov=lbox --import-mode=importlib"
\ No newline at end of file
diff --git a/libs/lbox-clients/src/lbox/exceptions.py b/libs/lbox-clients/src/lbox/exceptions.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/libs/lbox-clients/src/lbox/request_client.py b/libs/lbox-clients/src/lbox/request_client.py
new file mode 100644
index 000000000..920e8b9b5
--- /dev/null
+++ b/libs/lbox-clients/src/lbox/request_client.py
@@ -0,0 +1,362 @@
+import json
+import logging
+import os
+import sys
+from datetime import datetime, timezone
+from types import MappingProxyType
+
+import labelbox.exceptions
+import requests
+import requests.exceptions
+from google.api_core import retry
+
+logger = logging.getLogger(__name__)
+
+_LABELBOX_API_KEY = "LABELBOX_API_KEY"
+
+
+def python_version_info():
+    version_info = sys.version_info
+
+    return f"{version_info.major}.{version_info.minor}.{version_info.micro}-{version_info.releaselevel}"
+
+
+class RequestClient:
+    """A Labelbox request client.
+
+    Contains info necessary for connecting to a Labelbox server (URL,
+    authentication key).
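+
+    Example (illustrative; assumes an API key in the LABELBOX_API_KEY
+    environment variable, otherwise pass api_key explicitly):
+
+        client = RequestClient(sdk_version="0.1.0")
+        data = client.execute("query GetUserPyApi { user { id email } }")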
+ """ + + def __init__( + self, + sdk_version, + api_key=None, + endpoint="https://api.labelbox.com/graphql", + enable_experimental=False, + app_url="https://app.labelbox.com", + rest_endpoint="https://api.labelbox.com/api/v1", + ): + """Creates and initializes a RequestClient. + This class executes graphql and rest requests to the Labelbox server. + + Args: + api_key (str): API key. If None, the key is obtained from the "LABELBOX_API_KEY" environment variable. + endpoint (str): URL of the Labelbox server to connect to. + enable_experimental (bool): Indicates whether or not to use experimental features + app_url (str) : host url for all links to the web app + Raises: + labelbox.exceptions.AuthenticationError: If no `api_key` + is provided as an argument or via the environment + variable. + """ + if api_key is None: + if _LABELBOX_API_KEY not in os.environ: + raise labelbox.exceptions.AuthenticationError( + "Labelbox API key not provided" + ) + api_key = os.environ[_LABELBOX_API_KEY] + self.api_key = api_key + + self.enable_experimental = enable_experimental + if enable_experimental: + logger.info("Experimental features have been enabled") + + logger.info("Initializing Labelbox client at '%s'", endpoint) + self.app_url = app_url + self.endpoint = endpoint + self.rest_endpoint = rest_endpoint + self._connection: requests.Session = self._init_connection() + self.sdk_version = sdk_version + + def _init_connection(self) -> requests.Session: + connection = ( + requests.Session() + ) # using default connection pool size of 10 + connection.headers.update(self._default_headers()) + + return connection + + @property + def headers(self) -> MappingProxyType: + return self._connection.headers + + def _default_headers(self): + return { + "Authorization": "Bearer %s" % self.api_key, + "Accept": "application/json", + "Content-Type": "application/json", + "X-User-Agent": f"python-sdk {self.sdk_version}", + "X-Python-Version": f"{python_version_info()}", + } + + @retry.Retry( + predicate=retry.if_exception_type( + labelbox.exceptions.InternalServerError, + labelbox.exceptions.TimeoutError, + ) + ) + def execute( + self, + query=None, + params=None, + data=None, + files=None, + timeout=60.0, + experimental=False, + error_log_key="message", + raise_return_resource_not_found=False, + ): + """Sends a request to the server for the execution of the + given query. + + Checks the response for errors and wraps errors + in appropriate `labelbox.exceptions.LabelboxError` subtypes. + + Args: + query (str): The query to execute. + params (dict): Query parameters referenced within the query. + data (str): json string containing the query to execute + files (dict): file arguments for request + timeout (float): Max allowed time for query execution, + in seconds. + Returns: + dict, parsed JSON response. + Raises: + labelbox.exceptions.AuthenticationError: If authentication + failed. + labelbox.exceptions.InvalidQueryError: If `query` is not + syntactically or semantically valid (checked server-side). + labelbox.exceptions.ApiLimitError: If the server API limit was + exceeded. See "How to import data" in the online documentation + to see API limits. + labelbox.exceptions.TimeoutError: If response was not received + in `timeout` seconds. + labelbox.exceptions.NetworkError: If an unknown error occurred + most likely due to connection issues. + labelbox.exceptions.LabelboxError: If an unknown error of any + kind occurred. + ValueError: If query and data are both None. 
+ """ + logger.debug("Query: %s, params: %r, data %r", query, params, data) + + # Convert datetimes to UTC strings. + def convert_value(value): + if isinstance(value, datetime): + value = value.astimezone(timezone.utc) + value = value.strftime("%Y-%m-%dT%H:%M:%SZ") + return value + + if query is not None: + if params is not None: + params = { + key: convert_value(value) for key, value in params.items() + } + data = json.dumps({"query": query, "variables": params}).encode( + "utf-8" + ) + elif data is None: + raise ValueError("query and data cannot both be none") + + endpoint = ( + self.endpoint + if not experimental + else self.endpoint.replace("/graphql", "/_gql") + ) + + try: + headers = self._connection.headers.copy() + if files: + del headers["Content-Type"] + del headers["Accept"] + request = requests.Request( + "POST", + endpoint, + headers=headers, + data=data, + files=files if files else None, + ) + + prepped: requests.PreparedRequest = request.prepare() + + response = self._connection.send(prepped, timeout=timeout) + logger.debug("Response: %s", response.text) + except requests.exceptions.Timeout as e: + raise labelbox.exceptions.TimeoutError(str(e)) + except requests.exceptions.RequestException as e: + logger.error("Unknown error: %s", str(e)) + raise labelbox.exceptions.NetworkError(e) + except Exception as e: + raise labelbox.exceptions.LabelboxError( + "Unknown error during Client.query(): " + str(e), e + ) + + if ( + 200 <= response.status_code < 300 + or response.status_code < 500 + or response.status_code >= 600 + ): + try: + r_json = response.json() + except Exception: + raise labelbox.exceptions.LabelboxError( + "Failed to parse response as JSON: %s" % response.text + ) + else: + if ( + "upstream connect error or disconnect/reset before headers" + in response.text + ): + raise labelbox.exceptions.InternalServerError( + "Connection reset" + ) + elif response.status_code == 502: + error_502 = "502 Bad Gateway" + raise labelbox.exceptions.InternalServerError(error_502) + elif 500 <= response.status_code < 600: + error_500 = f"Internal server http error {response.status_code}" + raise labelbox.exceptions.InternalServerError(error_500) + + errors = r_json.get("errors", []) + + def check_errors(keywords, *path): + """Helper that looks for any of the given `keywords` in any of + current errors on paths (like error[path][component][to][keyword]). 
+ """ + for error in errors: + obj = error + for path_elem in path: + obj = obj.get(path_elem, {}) + if obj in keywords: + return error + return None + + def get_error_status_code(error: dict) -> int: + try: + return int(error["extensions"].get("exception").get("status")) + except: + return 500 + + if ( + check_errors(["AUTHENTICATION_ERROR"], "extensions", "code") + is not None + ): + raise labelbox.exceptions.AuthenticationError("Invalid API key") + + authorization_error = check_errors( + ["AUTHORIZATION_ERROR"], "extensions", "code" + ) + if authorization_error is not None: + raise labelbox.exceptions.AuthorizationError( + authorization_error["message"] + ) + + validation_error = check_errors( + ["GRAPHQL_VALIDATION_FAILED"], "extensions", "code" + ) + + if validation_error is not None: + message = validation_error["message"] + if message == "Query complexity limit exceeded": + raise labelbox.exceptions.ValidationFailedError(message) + else: + raise labelbox.exceptions.InvalidQueryError(message) + + graphql_error = check_errors( + ["GRAPHQL_PARSE_FAILED"], "extensions", "code" + ) + if graphql_error is not None: + raise labelbox.exceptions.InvalidQueryError( + graphql_error["message"] + ) + + # Check if API limit was exceeded + response_msg = r_json.get("message", "") + + if response_msg.startswith("You have exceeded"): + raise labelbox.exceptions.ApiLimitError(response_msg) + + resource_not_found_error = check_errors( + ["RESOURCE_NOT_FOUND"], "extensions", "code" + ) + if resource_not_found_error is not None: + if raise_return_resource_not_found: + raise labelbox.exceptions.ResourceNotFoundError( + message=resource_not_found_error["message"] + ) + else: + # Return None and let the caller methods raise an exception + # as they already know which resource type and ID was requested + return None + + resource_conflict_error = check_errors( + ["RESOURCE_CONFLICT"], "extensions", "code" + ) + if resource_conflict_error is not None: + raise labelbox.exceptions.ResourceConflict( + resource_conflict_error["message"] + ) + + malformed_request_error = check_errors( + ["MALFORMED_REQUEST"], "extensions", "code" + ) + if malformed_request_error is not None: + raise labelbox.exceptions.MalformedQueryException( + malformed_request_error[error_log_key] + ) + + # A lot of different error situations are now labeled serverside + # as INTERNAL_SERVER_ERROR, when they are actually client errors. 
+        # TODO: fix this in the server API
+        internal_server_error = check_errors(
+            ["INTERNAL_SERVER_ERROR"], "extensions", "code"
+        )
+        if internal_server_error is not None:
+            message = internal_server_error.get("message")
+            error_status_code = get_error_status_code(internal_server_error)
+            if error_status_code == 400:
+                raise labelbox.exceptions.InvalidQueryError(message)
+            elif error_status_code == 422:
+                raise labelbox.exceptions.UnprocessableEntityError(message)
+            elif error_status_code == 426:
+                raise labelbox.exceptions.OperationNotAllowedException(message)
+            elif error_status_code == 500:
+                raise labelbox.exceptions.LabelboxError(message)
+            else:
+                raise labelbox.exceptions.InternalServerError(message)
+
+        not_allowed_error = check_errors(
+            ["OPERATION_NOT_ALLOWED"], "extensions", "code"
+        )
+        if not_allowed_error is not None:
+            message = not_allowed_error.get("message")
+            raise labelbox.exceptions.OperationNotAllowedException(message)
+
+        if len(errors) > 0:
+            logger.warning("Unparsed errors on query execution: %r", errors)
+            messages = list(
+                map(
+                    lambda x: {
+                        "message": x["message"],
+                        "code": x["extensions"]["code"],
+                    },
+                    errors,
+                )
+            )
+            raise labelbox.exceptions.LabelboxError(
+                "Unknown error: %s" % str(messages)
+            )
+
+        # if we do return a proper error code, and didn't catch this above
+        # reraise
+        # this mainly catches a 401 for API access disabled for free tier
+        # TODO: need to unify API errors to handle things more uniformly
+        # in the SDK
+        if response.status_code != requests.codes.ok:
+            message = f"{response.status_code} {response.reason}"
+            cause = r_json.get("message")
+            raise labelbox.exceptions.LabelboxError(message, cause)
+
+        return r_json["data"]
diff --git a/libs/lbox-clients/tests/unit/lbox/test_client.py b/libs/lbox-clients/tests/unit/lbox/test_client.py
new file mode 100644
index 000000000..b1f55bca0
--- /dev/null
+++ b/libs/lbox-clients/tests/unit/lbox/test_client.py
@@ -0,0 +1,12 @@
+from lbox.request_client import RequestClient
+
+
+def test_headers():
+    client = RequestClient(
+        sdk_version="test", api_key="api_key", endpoint="http://localhost:8080/_gql"
+    )
+    assert client.headers
+    assert client.headers["Authorization"] == "Bearer api_key"
+    assert client.headers["Content-Type"] == "application/json"
+    assert client.headers["X-User-Agent"]
+    assert client.headers["X-Python-Version"]
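
For reference, a sketch of how the environment-variable fallback in RequestClient.__init__ could be unit-tested alongside the header test above (hypothetical test, not part of this diff; assumes pytest and the import shown in the fixed test):

    import os
    from unittest.mock import patch

    from lbox.request_client import RequestClient

    @patch.dict(os.environ, {"LABELBOX_API_KEY": "bar"})
    def test_api_key_from_environment():
        # No api_key argument: __init__ falls back to the env variable.
        client = RequestClient(sdk_version="test")
        assert client.api_key == "bar"
        assert client.headers["Authorization"] == "Bearer bar"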