diff --git a/.github/actions/lbox-matrix/index.js b/.github/actions/lbox-matrix/index.js index ac3975488..574a56381 100644 --- a/.github/actions/lbox-matrix/index.js +++ b/.github/actions/lbox-matrix/index.js @@ -26811,11 +26811,6 @@ const core = __nccwpck_require__(8611); try { const files = JSON.parse(core.getInput('files-changed')); const startingMatrix = [ - { - "python-version": "3.8", - "api-key": "STAGING_LABELBOX_API_KEY_2", - "da-test-key": "DA_GCP_LABELBOX_API_KEY" - }, { "python-version": "3.9", "api-key": "STAGING_LABELBOX_API_KEY_3", diff --git a/.github/workflows/lbox-develop.yml b/.github/workflows/lbox-develop.yml index ba1e4f34e..309ea2969 100644 --- a/.github/workflows/lbox-develop.yml +++ b/.github/workflows/lbox-develop.yml @@ -90,7 +90,7 @@ jobs: - uses: ./.github/actions/python-package-shared-setup with: rye-version: ${{ vars.RYE_VERSION }} - python-version: '3.8' + python-version: '3.9' - name: Create build id: create-build working-directory: libs/${{ matrix.package }} diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index d64123c60..4fcb10257 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -71,9 +71,6 @@ jobs: fail-fast: false matrix: include: - - python-version: 3.8 - prod-key: PROD_LABELBOX_API_KEY_2 - da-test-key: DA_GCP_LABELBOX_API_KEY - python-version: 3.9 prod-key: PROD_LABELBOX_API_KEY_3 da-test-key: DA_GCP_LABELBOX_API_KEY diff --git a/.github/workflows/python-package-develop.yml b/.github/workflows/python-package-develop.yml index 05eff5dc4..65dd02872 100644 --- a/.github/workflows/python-package-develop.yml +++ b/.github/workflows/python-package-develop.yml @@ -44,7 +44,7 @@ jobs: - name: Get Latest SDK versions id: get_sdk_versions run: | - sdk_versions=$(git tag --list --sort=-version:refname "v.*" | head -n 4 | jq -R -s -c 'split("\n")[:-1]') + sdk_versions=$(git tag --list --sort=-version:refname "v.*" | head -n 3 | jq -R -s -c 'split("\n")[:-1]') if [ -z 
"$sdk_versions" ]; then echo "No tags found" exit 1 @@ -58,10 +58,6 @@ jobs: fail-fast: false matrix: include: - - python-version: 3.8 - api-key: STAGING_LABELBOX_API_KEY_2 - da-test-key: DA_GCP_LABELBOX_API_KEY - sdk-version: ${{ fromJson(needs.get_sdk_versions.outputs.sdk_versions)[3] }} - python-version: 3.9 api-key: STAGING_LABELBOX_API_KEY_3 da-test-key: DA_GCP_LABELBOX_API_KEY @@ -103,7 +99,7 @@ jobs: - uses: ./.github/actions/python-package-shared-setup with: rye-version: ${{ vars.RYE_VERSION }} - python-version: '3.8' + python-version: '3.9' - name: Create build id: create-build working-directory: libs/labelbox diff --git a/.github/workflows/python-package-prod.yml b/.github/workflows/python-package-prod.yml index 1936af132..c0e24536f 100644 --- a/.github/workflows/python-package-prod.yml +++ b/.github/workflows/python-package-prod.yml @@ -13,9 +13,6 @@ jobs: fail-fast: false matrix: include: - - python-version: 3.8 - api-key: PROD_LABELBOX_API_KEY_2 - da-test-key: DA_GCP_LABELBOX_API_KEY - python-version: 3.9 api-key: PROD_LABELBOX_API_KEY_3 da-test-key: DA_GCP_LABELBOX_API_KEY diff --git a/.python-version b/.python-version index 9ad6380c1..43077b246 100644 --- a/.python-version +++ b/.python-version @@ -1 +1 @@ -3.8.18 +3.9.18 diff --git a/docs/labelbox/exceptions.rst b/docs/labelbox/exceptions.rst index 3082bc081..96ea0f6d5 100644 --- a/docs/labelbox/exceptions.rst +++ b/docs/labelbox/exceptions.rst @@ -1,6 +1,6 @@ Exceptions =============================================================================================== -.. automodule:: labelbox.exceptions +.. 
automodule:: lbox.exceptions :members: :show-inheritance: \ No newline at end of file diff --git a/docs/labelbox/index.rst b/docs/labelbox/index.rst index 35118f56f..fcfa2f6b5 100644 --- a/docs/labelbox/index.rst +++ b/docs/labelbox/index.rst @@ -41,6 +41,7 @@ Labelbox Python SDK Documentation project project-model-config quality-mode + request-client resource-tag review search-filters diff --git a/docs/labelbox/request-client.rst b/docs/labelbox/request-client.rst new file mode 100644 index 000000000..fcfea7f97 --- /dev/null +++ b/docs/labelbox/request-client.rst @@ -0,0 +1,6 @@ +Request Client +=============================================================================================== + +.. automodule:: lbox.request_client + :members: + :show-inheritance: \ No newline at end of file diff --git a/examples/README.md b/examples/README.md index 1bc102947..e02b7c64f 100644 --- a/examples/README.md +++ b/examples/README.md @@ -16,16 +16,6 @@ - - Basics - Open In Github - Open In Colab - - - Projects - Open In Github - Open In Colab - Ontologies Open In Github @@ -42,25 +32,35 @@ Open In Colab - Data Row Metadata - Open In Github - Open In Colab - - - User Management - Open In Github - Open In Colab + Basics + Open In Github + Open In Colab Batches Open In Github Open In Colab + + Projects + Open In Github + Open In Colab + + + Data Row Metadata + Open In Github + Open In Colab + Custom Embeddings Open In Github Open In Colab + + User Management + Open In Github + Open In Colab + @@ -75,26 +75,26 @@ - - Exporting to CSV - Open In Github - Open In Colab - - - Export Data - Open In Github - Open In Colab - Export V1 to V2 Migration Support Open In Github Open In Colab + + Exporting to CSV + Open In Github + Open In Colab + Composite Mask Export Open In Github Open In Colab + + Export Data + Open In Github + Open In Colab + @@ -109,16 +109,16 @@ - - Project Setup - Open In Github - Open In Colab - Queue Management Open In Github Open In Colab + + Project Setup + Open 
In Github + Open In Colab + Webhooks Open In Github @@ -143,30 +143,20 @@ + + DICOM + Open In Github + Open In Colab + Tiled Open In Github Open In Colab - Conversational LLM - Open In Github - Open In Colab - - - HTML - Open In Github - Open In Colab - - - Conversational LLM Data Generation - Open In Github - Open In Colab - - - Image - Open In Github - Open In Colab + Text + Open In Github + Open In Colab PDF @@ -174,14 +164,9 @@ Open In Colab - DICOM - Open In Github - Open In Colab - - - Text - Open In Github - Open In Colab + Video + Open In Github + Open In Colab Audio @@ -194,9 +179,24 @@ Open In Colab - Video - Open In Github - Open In Colab + HTML + Open In Github + Open In Colab + + + Conversational LLM Data Generation + Open In Github + Open In Colab + + + Image + Open In Github + Open In Colab + + + Conversational LLM + Open In Github + Open In Colab @@ -213,9 +213,9 @@ - Import YOLOv8 Annotations - Open In Github - Open In Colab + Langchain + Open In Github + Open In Colab Meta SAM Video @@ -228,9 +228,9 @@ Open In Colab - Langchain - Open In Github - Open In Colab + Import YOLOv8 Annotations + Open In Github + Open In Colab Huggingface Custom Embeddings @@ -251,25 +251,25 @@ + + Model Predictions to Project + Open In Github + Open In Colab + Custom Metrics Demo Open In Github Open In Colab - - Model Slices - Open In Github - Open In Colab - Custom Metrics Basics Open In Github Open In Colab - Model Predictions to Project - Open In Github - Open In Colab + Model Slices + Open In Github + Open In Colab @@ -285,20 +285,20 @@ - - Video Predictions - Open In Github - Open In Colab - HTML Predictions Open In Github Open In Colab - Geospatial Predictions - Open In Github - Open In Colab + Text Predictions + Open In Github + Open In Colab + + + Video Predictions + Open In Github + Open In Colab Conversational Predictions @@ -306,14 +306,9 @@ Open In Colab - Text Predictions - Open In Github - Open In Colab - - - Conversational LLM Predictions - Open In Github 
- Open In Colab + Geospatial Predictions + Open In Github + Open In Colab PDF Predictions @@ -325,6 +320,11 @@ Open In Github Open In Colab + + Conversational LLM Predictions + Open In Github + Open In Colab + diff --git a/examples/pyproject.toml b/examples/pyproject.toml index 969454ecc..b620680e0 100644 --- a/examples/pyproject.toml +++ b/examples/pyproject.toml @@ -5,7 +5,7 @@ description = "Labelbox Python Example Notebooks" authors = [{ name = "Labelbox", email = "docs@labelbox.com" }] readme = "README.md" # Python version matches labelbox SDK -requires-python = ">=3.8" +requires-python = ">=3.9" dependencies = [] [project.urls] @@ -22,8 +22,8 @@ dev-dependencies = [ "yapf>=0.40.2", "black[jupyter]>=24.4.2", "databooks>=1.3.10", - # higher versions dont support python 3.8 - "pandas>=2.0.3", + # higher versions dont support python 3.9 + "pandas>=2.2.3", ] [tool.rye.scripts] diff --git a/libs/labelbox/Dockerfile b/libs/labelbox/Dockerfile index ddf451b14..386bdbc63 100644 --- a/libs/labelbox/Dockerfile +++ b/libs/labelbox/Dockerfile @@ -1,5 +1,5 @@ # https://github.com/ucyo/python-package-template/blob/master/Dockerfile -FROM python:3.8-slim as rye +FROM python:3.9-slim as rye ENV LANG="C.UTF-8" \ LC_ALL="C.UTF-8" \ @@ -38,7 +38,7 @@ WORKDIR /home/python/labelbox-python RUN rye config --set-bool behavior.global-python=true && \ rye config --set-bool behavior.use-uv=true && \ - rye pin 3.8 && \ + rye pin 3.9 && \ rye sync CMD cd libs/labelbox && rye run integration && rye sync -f --features labelbox/data && rye run unit && rye run data diff --git a/libs/labelbox/mypy.ini b/libs/labelbox/mypy.ini index 48135600e..5129fbe0f 100644 --- a/libs/labelbox/mypy.ini +++ b/libs/labelbox/mypy.ini @@ -7,4 +7,7 @@ ignore_missing_imports = True ignore_errors = True [mypy-labelbox] -ignore_errors = True \ No newline at end of file +ignore_errors = True + +[mypy-lbox.exceptions] +ignore_missing_imports = True diff --git a/libs/labelbox/pyproject.toml 
b/libs/labelbox/pyproject.toml index ee2f9b859..f58dba890 100644 --- a/libs/labelbox/pyproject.toml +++ b/libs/labelbox/pyproject.toml @@ -12,9 +12,10 @@ dependencies = [ "tqdm>=4.66.2", "geojson>=3.1.0", "mypy==1.10.1", + "lbox-clients==1.1.0", ] readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.9,<3.13" classifiers = [ # How mature is this project? "Development Status :: 5 - Production/Stable", @@ -28,7 +29,6 @@ classifiers = [ "License :: OSI Approved :: Apache Software License", # Specify the Python versions you support here. "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", @@ -46,8 +46,7 @@ Changelog = "https://github.com/Labelbox/labelbox-python/blob/develop/libs/label [project.optional-dependencies] data = [ "shapely>=2.0.3", - # numpy v2 breaks package since it only supports python >3.9 - "numpy>=1.24.4, <2.0.0", + "numpy>=1.25.0", "pillow>=10.2.0", "typeguard>=4.1.5", "imagesize>=1.4.1", @@ -74,6 +73,10 @@ dev-dependencies = [ [tool.ruff] line-length = 80 +[tool.ruff.lint] +ignore = ["F", "E722"] +exclude = ["**/__init__.py"] + [tool.rye.scripts] unit = "pytest tests/unit" # https://github.com/Labelbox/labelbox-python/blob/7c84fdffbc14fd1f69d2a6abdcc0087dc557fa4e/Makefile @@ -89,9 +92,11 @@ unit = "pytest tests/unit" # LABELBOX_TEST_BASE_URL="http://host.docker.internal:8080" \ integration = { cmd = "pytest tests/integration" } data = { cmd = "pytest tests/data" } +rye-lint = "rye lint" rye-fmt-check = "rye fmt --check" +MYPYPATH = "../lbox-clients/src/" mypy-lint = "mypy src --pretty --show-error-codes --non-interactive --install-types" -lint = { chain = ["mypy-lint", "rye-fmt-check"] } +lint = { chain = ["rye-fmt-check", "mypy-lint", "rye-lint"] } test = { chain = ["lint", "unit", "integration"] } [tool.hatch.metadata] diff --git a/libs/labelbox/src/labelbox/__init__.py 
b/libs/labelbox/src/labelbox/__init__.py index b0f4ebe11..9b5d6f715 100644 --- a/libs/labelbox/src/labelbox/__init__.py +++ b/libs/labelbox/src/labelbox/__init__.py @@ -12,7 +12,6 @@ from labelbox.schema.asset_attachment import AssetAttachment from labelbox.schema.batch import Batch from labelbox.schema.benchmark import Benchmark -from labelbox.schema.bulk_import_request import BulkImportRequest from labelbox.schema.catalog import Catalog from labelbox.schema.data_row import DataRow from labelbox.schema.data_row_metadata import ( @@ -26,10 +25,7 @@ from labelbox.schema.export_task import ( BufferedJsonConverterOutput, ExportTask, - FileConverter, - FileConverterOutput, - JsonConverter, - JsonConverterOutput, + BufferedJsonConverterOutput, StreamType, ) from labelbox.schema.iam_integration import IAMIntegration @@ -59,6 +55,28 @@ ResponseOption, Tool, ) +from labelbox.schema.ontology import PromptResponseClassification +from labelbox.schema.ontology import ResponseOption +from labelbox.schema.role import Role, ProjectRole +from labelbox.schema.invite import Invite, InviteLimit +from labelbox.schema.data_row_metadata import ( + DataRowMetadataOntology, + DataRowMetadataField, + DataRowMetadata, + DeleteDataRowMetadata, +) +from labelbox.schema.model_run import ModelRun, DataSplit +from labelbox.schema.benchmark import Benchmark +from labelbox.schema.iam_integration import IAMIntegration +from labelbox.schema.resource_tag import ResourceTag +from labelbox.schema.project_model_config import ProjectModelConfig +from labelbox.schema.project_resource_tag import ProjectResourceTag +from labelbox.schema.media_type import MediaType +from labelbox.schema.slice import Slice, CatalogSlice, ModelSlice +from labelbox.schema.task_queue import TaskQueue +from labelbox.schema.label_score import LabelScore +from labelbox.schema.identifiables import UniqueIds, GlobalKeys, DataRowIds +from labelbox.schema.identifiable import UniqueId, GlobalKey from labelbox.schema.ontology_kind import 
OntologyKind from labelbox.schema.organization import Organization from labelbox.schema.project import Project @@ -68,7 +86,6 @@ ProjectOverviewDetailed, ) from labelbox.schema.project_resource_tag import ProjectResourceTag -from labelbox.schema.queue_mode import QueueMode from labelbox.schema.resource_tag import ResourceTag from labelbox.schema.review import Review from labelbox.schema.role import ProjectRole, Role diff --git a/libs/labelbox/src/labelbox/adv_client.py b/libs/labelbox/src/labelbox/adv_client.py index 626ac0279..20766a9a4 100644 --- a/libs/labelbox/src/labelbox/adv_client.py +++ b/libs/labelbox/src/labelbox/adv_client.py @@ -1,12 +1,12 @@ import io import json import logging -from typing import Dict, Any, Optional, List, Callable +from typing import Any, Callable, Dict, List, Optional from urllib.parse import urlparse -from labelbox.exceptions import LabelboxError import requests -from requests import Session, Response +from lbox.exceptions import LabelboxError +from requests import Response, Session logger = logging.getLogger(__name__) diff --git a/libs/labelbox/src/labelbox/annotated_types.py b/libs/labelbox/src/labelbox/annotated_types.py index d0fe93ef8..de2e1e01f 100644 --- a/libs/labelbox/src/labelbox/annotated_types.py +++ b/libs/labelbox/src/labelbox/annotated_types.py @@ -1,6 +1,5 @@ -from typing_extensions import Annotated +from typing import Annotated from pydantic import Field - Cuid = Annotated[str, Field(min_length=25, max_length=25)] diff --git a/libs/labelbox/src/labelbox/client.py b/libs/labelbox/src/labelbox/client.py index 671f8b8cc..bcf29665e 100644 --- a/libs/labelbox/src/labelbox/client.py +++ b/libs/labelbox/src/labelbox/client.py @@ -4,20 +4,25 @@ import mimetypes import os import random -import sys import time import urllib.parse import warnings from collections import defaultdict from datetime import datetime, timezone from types import MappingProxyType -from typing import Any, Dict, List, Optional, Union, overload +from 
typing import Any, Callable, Dict, List, Optional, Set, Union, overload import requests import requests.exceptions from google.api_core import retry +from lbox.exceptions import ( + InternalServerError, + LabelboxError, + ResourceNotFoundError, + TimeoutError, +) +from lbox.request_client import RequestClient -import labelbox.exceptions from labelbox import __version__ as SDK_VERSION from labelbox import utils from labelbox.adv_client import AdvClient @@ -25,6 +30,7 @@ from labelbox.orm.db_object import DbObject from labelbox.orm.model import Entity, Field from labelbox.pagination import PaginatedCollection +from labelbox.project_validation import _CoreProjectInput from labelbox.schema import role from labelbox.schema.catalog import Catalog from labelbox.schema.data_row import DataRow @@ -67,7 +73,6 @@ CONSENSUS_AUTO_AUDIT_PERCENTAGE, QualityMode, ) -from labelbox.schema.queue_mode import QueueMode from labelbox.schema.role import Role from labelbox.schema.search_filters import SearchFilter from labelbox.schema.send_to_annotate_params import ( @@ -82,20 +87,11 @@ logger = logging.getLogger(__name__) -_LABELBOX_API_KEY = "LABELBOX_API_KEY" - - -def python_version_info(): - version_info = sys.version_info - - return f"{version_info.major}.{version_info.minor}.{version_info.micro}-{version_info.releaselevel}" - class Client: """A Labelbox client. - Contains info necessary for connecting to a Labelbox server (URL, - authentication key). Provides functions for querying and creating + Provides functions for querying and creating top-level data objects (Projects, Datasets). """ @@ -121,57 +117,45 @@ def __init__( enable_experimental (bool): Indicates whether or not to use experimental features app_url (str) : host url for all links to the web app Raises: - labelbox.exceptions.AuthenticationError: If no `api_key` + AuthenticationError: If no `api_key` is provided as an argument or via the environment variable. 
""" - if api_key is None: - if _LABELBOX_API_KEY not in os.environ: - raise labelbox.exceptions.AuthenticationError( - "Labelbox API key not provided" - ) - api_key = os.environ[_LABELBOX_API_KEY] - self.api_key = api_key - - self.enable_experimental = enable_experimental - if enable_experimental: - logger.info("Experimental features have been enabled") - - logger.info("Initializing Labelbox client at '%s'", endpoint) - self.app_url = app_url - self.endpoint = endpoint - self.rest_endpoint = rest_endpoint self._data_row_metadata_ontology = None + self._request_client = RequestClient( + sdk_version=SDK_VERSION, + api_key=api_key, + endpoint=endpoint, + enable_experimental=enable_experimental, + app_url=app_url, + rest_endpoint=rest_endpoint, + ) self._adv_client = AdvClient.factory(rest_endpoint, api_key) - self._connection: requests.Session = self._init_connection() - def _init_connection(self) -> requests.Session: - connection = ( - requests.Session() - ) # using default connection pool size of 10 - connection.headers.update(self._default_headers()) + @property + def headers(self) -> MappingProxyType: + return self._request_client.headers - return connection + @property + def connection(self) -> requests.Session: + return self._request_client._connection @property - def headers(self) -> MappingProxyType: - return self._connection.headers - - def _default_headers(self): - return { - "Authorization": "Bearer %s" % self.api_key, - "Accept": "application/json", - "Content-Type": "application/json", - "X-User-Agent": f"python-sdk {SDK_VERSION}", - "X-Python-Version": f"{python_version_info()}", - } + def endpoint(self) -> str: + return self._request_client.endpoint + + @property + def rest_endpoint(self) -> str: + return self._request_client.rest_endpoint + + @property + def enable_experimental(self) -> bool: + return self._request_client.enable_experimental + + @property + def app_url(self) -> str: + return self._request_client.app_url - @retry.Retry( - 
predicate=retry.if_exception_type( - labelbox.exceptions.InternalServerError, - labelbox.exceptions.TimeoutError, - ) - ) def execute( self, query=None, @@ -182,264 +166,34 @@ def execute( experimental=False, error_log_key="message", raise_return_resource_not_found=False, - ): - """Sends a request to the server for the execution of the - given query. - - Checks the response for errors and wraps errors - in appropriate `labelbox.exceptions.LabelboxError` subtypes. + error_handlers: Optional[ + Dict[str, Callable[[requests.models.Response], None]] + ] = None, + ) -> Dict[str, Any]: + """Executes a GraphQL query. Args: query (str): The query to execute. - params (dict): Query parameters referenced within the query. - data (str): json string containing the query to execute - files (dict): file arguments for request - timeout (float): Max allowed time for query execution, - in seconds. - Returns: - dict, parsed JSON response. - Raises: - labelbox.exceptions.AuthenticationError: If authentication - failed. - labelbox.exceptions.InvalidQueryError: If `query` is not - syntactically or semantically valid (checked server-side). - labelbox.exceptions.ApiLimitError: If the server API limit was - exceeded. See "How to import data" in the online documentation - to see API limits. - labelbox.exceptions.TimeoutError: If response was not received - in `timeout` seconds. - labelbox.exceptions.NetworkError: If an unknown error occurred - most likely due to connection issues. - labelbox.exceptions.LabelboxError: If an unknown error of any - kind occurred. - ValueError: If query and data are both None. - """ - logger.debug("Query: %s, params: %r, data %r", query, params, data) - - # Convert datetimes to UTC strings. 
- def convert_value(value): - if isinstance(value, datetime): - value = value.astimezone(timezone.utc) - value = value.strftime("%Y-%m-%dT%H:%M:%SZ") - return value - - if query is not None: - if params is not None: - params = { - key: convert_value(value) for key, value in params.items() - } - data = json.dumps({"query": query, "variables": params}).encode( - "utf-8" - ) - elif data is None: - raise ValueError("query and data cannot both be none") - - endpoint = ( - self.endpoint - if not experimental - else self.endpoint.replace("/graphql", "/_gql") - ) - - try: - headers = self._connection.headers.copy() - if files: - del headers["Content-Type"] - del headers["Accept"] - request = requests.Request( - "POST", - endpoint, - headers=headers, - data=data, - files=files if files else None, - ) - - prepped: requests.PreparedRequest = request.prepare() - - settings = self._connection.merge_environment_settings( - prepped.url, {}, None, None, None - ) - - response = self._connection.send( - prepped, timeout=timeout, **settings - ) - logger.debug("Response: %s", response.text) - except requests.exceptions.Timeout as e: - raise labelbox.exceptions.TimeoutError(str(e)) - except requests.exceptions.RequestException as e: - logger.error("Unknown error: %s", str(e)) - raise labelbox.exceptions.NetworkError(e) - except Exception as e: - raise labelbox.exceptions.LabelboxError( - "Unknown error during Client.query(): " + str(e), e - ) - - if ( - 200 <= response.status_code < 300 - or response.status_code < 500 - or response.status_code >= 600 - ): - try: - r_json = response.json() - except Exception: - raise labelbox.exceptions.LabelboxError( - "Failed to parse response as JSON: %s" % response.text - ) - else: - if ( - "upstream connect error or disconnect/reset before headers" - in response.text - ): - raise labelbox.exceptions.InternalServerError( - "Connection reset" - ) - elif response.status_code == 502: - error_502 = "502 Bad Gateway" - raise 
labelbox.exceptions.InternalServerError(error_502) - elif 500 <= response.status_code < 600: - error_500 = f"Internal server http error {response.status_code}" - raise labelbox.exceptions.InternalServerError(error_500) - - errors = r_json.get("errors", []) - - def check_errors(keywords, *path): - """Helper that looks for any of the given `keywords` in any of - current errors on paths (like error[path][component][to][keyword]). - """ - for error in errors: - obj = error - for path_elem in path: - obj = obj.get(path_elem, {}) - if obj in keywords: - return error - return None - - def get_error_status_code(error: dict) -> int: - try: - return int(error["extensions"].get("exception").get("status")) - except: - return 500 - - if ( - check_errors(["AUTHENTICATION_ERROR"], "extensions", "code") - is not None - ): - raise labelbox.exceptions.AuthenticationError("Invalid API key") - - authorization_error = check_errors( - ["AUTHORIZATION_ERROR"], "extensions", "code" - ) - if authorization_error is not None: - raise labelbox.exceptions.AuthorizationError( - authorization_error["message"] - ) - - validation_error = check_errors( - ["GRAPHQL_VALIDATION_FAILED"], "extensions", "code" - ) - - if validation_error is not None: - message = validation_error["message"] - if message == "Query complexity limit exceeded": - raise labelbox.exceptions.ValidationFailedError(message) - else: - raise labelbox.exceptions.InvalidQueryError(message) - - graphql_error = check_errors( - ["GRAPHQL_PARSE_FAILED"], "extensions", "code" - ) - if graphql_error is not None: - raise labelbox.exceptions.InvalidQueryError( - graphql_error["message"] - ) - - # Check if API limit was exceeded - response_msg = r_json.get("message", "") - - if response_msg.startswith("You have exceeded"): - raise labelbox.exceptions.ApiLimitError(response_msg) - - resource_not_found_error = check_errors( - ["RESOURCE_NOT_FOUND"], "extensions", "code" - ) - if resource_not_found_error is not None: - if 
raise_return_resource_not_found: - raise labelbox.exceptions.ResourceNotFoundError( - message=resource_not_found_error["message"] - ) - else: - # Return None and let the caller methods raise an exception - # as they already know which resource type and ID was requested - return None - - resource_conflict_error = check_errors( - ["RESOURCE_CONFLICT"], "extensions", "code" - ) - if resource_conflict_error is not None: - raise labelbox.exceptions.ResourceConflict( - resource_conflict_error["message"] - ) + variables (dict): Variables to pass to the query. + raise_return_resource_not_found (bool): If True, raise a + ResourceNotFoundError if the query returns None. + error_handlers (dict): A dictionary mapping graphql error code to handler functions. + Allows a caller to handle specific errors reporting in a custom way or produce more user-friendly readable messages - malformed_request_error = check_errors( - ["MALFORMED_REQUEST"], "extensions", "code" - ) - if malformed_request_error is not None: - raise labelbox.exceptions.MalformedQueryException( - malformed_request_error[error_log_key] - ) - - # A lot of different error situations are now labeled serverside - # as INTERNAL_SERVER_ERROR, when they are actually client errors. 
- # TODO: fix this in the server API - internal_server_error = check_errors( - ["INTERNAL_SERVER_ERROR"], "extensions", "code" - ) - if internal_server_error is not None: - message = internal_server_error.get("message") - error_status_code = get_error_status_code(internal_server_error) - if error_status_code == 400: - raise labelbox.exceptions.InvalidQueryError(message) - elif error_status_code == 422: - raise labelbox.exceptions.UnprocessableEntityError(message) - elif error_status_code == 426: - raise labelbox.exceptions.OperationNotAllowedException(message) - elif error_status_code == 500: - raise labelbox.exceptions.LabelboxError(message) - else: - raise labelbox.exceptions.InternalServerError(message) - - not_allowed_error = check_errors( - ["OPERATION_NOT_ALLOWED"], "extensions", "code" + Returns: + dict: The response from the server. + """ + return self._request_client.execute( + query, + params, + data=data, + files=files, + timeout=timeout, + experimental=experimental, + error_log_key=error_log_key, + raise_return_resource_not_found=raise_return_resource_not_found, + error_handlers=error_handlers, ) - if not_allowed_error is not None: - message = not_allowed_error.get("message") - raise labelbox.exceptions.OperationNotAllowedException(message) - - if len(errors) > 0: - logger.warning("Unparsed errors on query execution: %r", errors) - messages = list( - map( - lambda x: { - "message": x["message"], - "code": x["extensions"]["code"], - }, - errors, - ) - ) - raise labelbox.exceptions.LabelboxError( - "Unknown error: %s" % str(messages) - ) - - # if we do return a proper error code, and didn't catch this above - # reraise - # this mainly catches a 401 for API access disabled for free tier - # TODO: need to unify API errors to handle things more uniformly - # in the SDK - if response.status_code != requests.codes.ok: - message = f"{response.status_code} {response.reason}" - cause = r_json.get("message") - raise labelbox.exceptions.LabelboxError(message, 
cause) - - return r_json["data"] def upload_file(self, path: str) -> str: """Uploads given path to local file. @@ -451,7 +205,7 @@ def upload_file(self, path: str) -> str: Returns: str, the URL of uploaded data. Raises: - labelbox.exceptions.LabelboxError: If upload failed. + LabelboxError: If upload failed. """ content_type, _ = mimetypes.guess_type(path) filename = os.path.basename(path) @@ -460,11 +214,7 @@ def upload_file(self, path: str) -> str: content=f.read(), filename=filename, content_type=content_type ) - @retry.Retry( - predicate=retry.if_exception_type( - labelbox.exceptions.InternalServerError - ) - ) + @retry.Retry(predicate=retry.if_exception_type(InternalServerError)) def upload_data( self, content: bytes, @@ -484,7 +234,7 @@ def upload_data( str, the URL of uploaded data. Raises: - labelbox.exceptions.LabelboxError: If upload failed. + LabelboxError: If upload failed. """ request_data = { @@ -509,7 +259,7 @@ def upload_data( if (filename and content_type) else content } - headers = self._connection.headers.copy() + headers = self.connection.headers.copy() headers.pop("Content-Type", None) request = requests.Request( "POST", @@ -521,22 +271,20 @@ def upload_data( prepped: requests.PreparedRequest = request.prepare() - response = self._connection.send(prepped) + response = self.connection.send(prepped) if response.status_code == 502: error_502 = "502 Bad Gateway" - raise labelbox.exceptions.InternalServerError(error_502) + raise InternalServerError(error_502) elif response.status_code == 503: - raise labelbox.exceptions.InternalServerError(response.text) + raise InternalServerError(response.text) elif response.status_code == 520: - raise labelbox.exceptions.InternalServerError(response.text) + raise InternalServerError(response.text) try: file_data = response.json().get("data", None) except ValueError as e: # response is not valid JSON - raise labelbox.exceptions.LabelboxError( - "Failed to upload, unknown cause", e - ) + raise LabelboxError("Failed 
to upload, unknown cause", e) if not file_data or not file_data.get("uploadFile", None): try: @@ -546,9 +294,7 @@ def upload_data( ) except Exception: error_msg = "Unknown error" - raise labelbox.exceptions.LabelboxError( - "Failed to upload, message: %s" % error_msg - ) + raise LabelboxError("Failed to upload, message: %s" % error_msg) return file_data["uploadFile"]["url"] @@ -561,7 +307,7 @@ def _get_single(self, db_object_type, uid): Returns: Object of `db_object_type`. Raises: - labelbox.exceptions.ResourceNotFoundError: If there is no object + ResourceNotFoundError: If there is no object of the given type for the given ID. """ query_str, params = query.get_single(db_object_type, uid) @@ -569,9 +315,7 @@ def _get_single(self, db_object_type, uid): res = self.execute(query_str, params) res = res and res.get(utils.camel_case(db_object_type.type_name())) if res is None: - raise labelbox.exceptions.ResourceNotFoundError( - db_object_type, params - ) + raise ResourceNotFoundError(db_object_type, params) else: return db_object_type(self, res) @@ -585,7 +329,7 @@ def get_project(self, project_id) -> Project: Returns: The sought Project. Raises: - labelbox.exceptions.ResourceNotFoundError: If there is no + ResourceNotFoundError: If there is no Project with the given ID. """ return self._get_single(Entity.Project, project_id) @@ -600,7 +344,7 @@ def get_dataset(self, dataset_id) -> Dataset: Returns: The sought Dataset. Raises: - labelbox.exceptions.ResourceNotFoundError: If there is no + ResourceNotFoundError: If there is no Dataset with the given ID. """ return self._get_single(Entity.Dataset, dataset_id) @@ -630,7 +374,7 @@ def _get_all(self, db_object_type, where, filter_deleted=True): An iterable of `db_object_type` instances. 
""" if filter_deleted: - not_deleted = db_object_type.deleted == False + not_deleted = db_object_type.deleted == False # noqa: E712 Needed for bit operator to combine comparisons where = not_deleted if where is None else where & not_deleted query_str, params = query.get_all(db_object_type, where) @@ -724,13 +468,11 @@ def _create(self, db_object_type, data, extra_params={}): res = self.execute( query_string, params, raise_return_resource_not_found=True ) - if not res: - raise labelbox.exceptions.LabelboxError( + raise LabelboxError( "Failed to create %s" % db_object_type.type_name() ) res = res["create%s" % db_object_type.type_name()] - return db_object_type(self, res) def create_model_config( @@ -784,9 +526,7 @@ def delete_model_config(self, id: str) -> bool: params = {"id": id} result = self.execute(query, params) if not result: - raise labelbox.exceptions.ResourceNotFoundError( - Entity.ModelConfig, params - ) + raise ResourceNotFoundError(Entity.ModelConfig, params) return result["deleteModelConfig"]["success"] def create_dataset( @@ -853,7 +593,18 @@ def create_dataset( raise e return dataset - def create_project(self, **kwargs) -> Project: + def create_project( + self, + name: str, + media_type: MediaType, + description: Optional[str] = None, + quality_modes: Optional[Set[QualityMode]] = { + QualityMode.Benchmark, + QualityMode.Consensus, + }, + is_benchmark_enabled: Optional[bool] = None, + is_consensus_enabled: Optional[bool] = None, + ) -> Project: """Creates a Project object on the server. Attribute values are passed as keyword arguments. 
@@ -862,71 +613,44 @@ def create_project(self, **kwargs) -> Project: name="", description="", media_type=MediaType.Image, - queue_mode=QueueMode.Batch ) Args: name (str): A name for the project description (str): A short summary for the project media_type (MediaType): The type of assets that this project will accept - queue_mode (Optional[QueueMode]): The queue mode to use - quality_mode (Optional[QualityMode]): The quality mode to use (e.g. Benchmark, Consensus). Defaults to - Benchmark quality_modes (Optional[List[QualityMode]]): The quality modes to use (e.g. Benchmark, Consensus). Defaults to Benchmark. + is_benchmark_enabled (Optional[bool]): Whether the project supports benchmark. Defaults to None. + is_consensus_enabled (Optional[bool]): Whether the project supports consensus. Defaults to None. Returns: A new Project object. Raises: - InvalidAttributeError: If the Project type does not contain - any of the attribute names given in kwargs. - - NOTE: the following attributes are used only in chat model evaluation projects: - dataset_name_or_id, append_to_existing_dataset, data_row_count, editor_task_type - They are not used for general projects and not supported in this method + ValueError: If inputs are invalid. 
""" - # The following arguments are not supported for general projects, only for chat model evaluation projects - kwargs.pop("dataset_name_or_id", None) - kwargs.pop("append_to_existing_dataset", None) - kwargs.pop("data_row_count", None) - kwargs.pop("editor_task_type", None) - return self._create_project(**kwargs) - - @overload - def create_model_evaluation_project( - self, - dataset_name: str, - dataset_id: str = None, - data_row_count: int = 100, - **kwargs, - ) -> Project: - pass - - @overload - def create_model_evaluation_project( - self, - dataset_id: str, - dataset_name: str = None, - data_row_count: int = 100, - **kwargs, - ) -> Project: - pass - - @overload - def create_model_evaluation_project( - self, - dataset_id: Optional[str] = None, - dataset_name: Optional[str] = None, - data_row_count: Optional[int] = None, - **kwargs, - ) -> Project: - pass + input = { + "name": name, + "description": description, + "media_type": media_type, + "quality_modes": quality_modes, + "is_benchmark_enabled": is_benchmark_enabled, + "is_consensus_enabled": is_consensus_enabled, + } + return self._create_project(_CoreProjectInput(**input)) def create_model_evaluation_project( self, + name: str, + description: Optional[str] = None, + quality_modes: Optional[Set[QualityMode]] = { + QualityMode.Benchmark, + QualityMode.Consensus, + }, + is_benchmark_enabled: Optional[bool] = None, + is_consensus_enabled: Optional[bool] = None, dataset_id: Optional[str] = None, dataset_name: Optional[str] = None, data_row_count: Optional[int] = None, - **kwargs, ) -> Project: """ Use this method exclusively to create a chat model evaluation project. 
@@ -934,12 +658,12 @@ def create_model_evaluation_project( dataset_name: When creating a new dataset, pass the name dataset_id: When using an existing dataset, pass the id data_row_count: The number of data row assets to use for the project - **kwargs: Additional parameters to pass to the the create_project method + See create_project for additional parameters Returns: Project: The created project Examples: - >>> client.create_model_evaluation_project(name=project_name, dataset_name="new data set") + >>> client.create_model_evaluation_project(name=project_name, dataset_name="new data set") >>> This creates a new dataset with a default number of rows (100), creates new project and assigns a batch of the newly created datarows to the project. >>> client.create_model_evaluation_project(name=project_name, dataset_name="new data set", data_row_count=10) @@ -959,51 +683,75 @@ def create_model_evaluation_project( append_to_existing_dataset = bool(dataset_id) if dataset_name_or_id: - kwargs["dataset_name_or_id"] = dataset_name_or_id - kwargs["append_to_existing_dataset"] = append_to_existing_dataset if data_row_count is None: data_row_count = 100 - if data_row_count < 0: - raise ValueError("data_row_count must be a positive integer.") - kwargs["data_row_count"] = data_row_count warnings.warn( "Automatic generation of data rows of live model evaluation projects is deprecated. 
dataset_name_or_id, append_to_existing_dataset, data_row_count will be removed in a future version.", DeprecationWarning, ) - kwargs["media_type"] = MediaType.Conversational - kwargs["editor_task_type"] = EditorTaskType.ModelChatEvaluation.value + media_type = MediaType.Conversational + editor_task_type = EditorTaskType.ModelChatEvaluation - return self._create_project(**kwargs) + input = { + "name": name, + "description": description, + "media_type": media_type, + "quality_modes": quality_modes, + "is_benchmark_enabled": is_benchmark_enabled, + "is_consensus_enabled": is_consensus_enabled, + "dataset_name_or_id": dataset_name_or_id, + "append_to_existing_dataset": append_to_existing_dataset, + "data_row_count": data_row_count, + "editor_task_type": editor_task_type, + } + return self._create_project(_CoreProjectInput(**input)) - def create_offline_model_evaluation_project(self, **kwargs) -> Project: + def create_offline_model_evaluation_project( + self, + name: str, + description: Optional[str] = None, + quality_modes: Optional[Set[QualityMode]] = { + QualityMode.Benchmark, + QualityMode.Consensus, + }, + is_benchmark_enabled: Optional[bool] = None, + is_consensus_enabled: Optional[bool] = None, + ) -> Project: """ Creates a project for offline model evaluation. 
Args: - **kwargs: Additional parameters to pass see the create_project method + See create_project for parameters Returns: Project: The created project """ - kwargs["media_type"] = ( - MediaType.Conversational - ) # Only Conversational is supported - kwargs["editor_task_type"] = ( - EditorTaskType.OfflineModelChatEvaluation.value - ) # Special editor task type for offline model evaluation - - # The following arguments are not supported for offline model evaluation - kwargs.pop("dataset_name_or_id", None) - kwargs.pop("append_to_existing_dataset", None) - kwargs.pop("data_row_count", None) - - return self._create_project(**kwargs) + input = { + "name": name, + "description": description, + "media_type": MediaType.Conversational, + "quality_modes": quality_modes, + "is_benchmark_enabled": is_benchmark_enabled, + "is_consensus_enabled": is_consensus_enabled, + "editor_task_type": EditorTaskType.OfflineModelChatEvaluation, + } + return self._create_project(_CoreProjectInput(**input)) def create_prompt_response_generation_project( self, + name: str, + media_type: MediaType, + description: Optional[str] = None, + auto_audit_percentage: Optional[float] = None, + auto_audit_number_of_labels: Optional[int] = None, + quality_modes: Optional[Set[QualityMode]] = { + QualityMode.Benchmark, + QualityMode.Consensus, + }, + is_benchmark_enabled: Optional[bool] = None, + is_consensus_enabled: Optional[bool] = None, dataset_id: Optional[str] = None, dataset_name: Optional[str] = None, data_row_count: int = 100, - **kwargs, ) -> Project: """ Use this method exclusively to create a prompt and response generation project. 
@@ -1012,7 +760,8 @@ def create_prompt_response_generation_project( dataset_name: When creating a new dataset, pass the name dataset_id: When using an existing dataset, pass the id data_row_count: The number of data row assets to use for the project - **kwargs: Additional parameters to pass see the create_project method + media_type: The type of assets that this project will accept. Limited to LLMPromptCreation and LLMPromptResponseCreation + See create_project for additional parameters Returns: Project: The created project @@ -1042,9 +791,6 @@ def create_prompt_response_generation_project( "Only provide a dataset_name or dataset_id, not both." ) - if data_row_count <= 0: - raise ValueError("data_row_count must be a positive integer.") - if dataset_id: append_to_existing_dataset = True dataset_name_or_id = dataset_id @@ -1052,7 +798,7 @@ def create_prompt_response_generation_project( append_to_existing_dataset = False dataset_name_or_id = dataset_name - if "media_type" in kwargs and kwargs.get("media_type") not in [ + if media_type not in [ MediaType.LLMPromptCreation, MediaType.LLMPromptResponseCreation, ]: @@ -1060,133 +806,52 @@ def create_prompt_response_generation_project( "media_type must be either LLMPromptCreation or LLMPromptResponseCreation" ) - kwargs["dataset_name_or_id"] = dataset_name_or_id - kwargs["append_to_existing_dataset"] = append_to_existing_dataset - kwargs["data_row_count"] = data_row_count - - kwargs.pop("editor_task_type", None) - - return self._create_project(**kwargs) + input = { + "name": name, + "description": description, + "media_type": media_type, + "auto_audit_percentage": auto_audit_percentage, + "auto_audit_number_of_labels": auto_audit_number_of_labels, + "quality_modes": quality_modes, + "is_benchmark_enabled": is_benchmark_enabled, + "is_consensus_enabled": is_consensus_enabled, + "dataset_name_or_id": dataset_name_or_id, + "append_to_existing_dataset": append_to_existing_dataset, + "data_row_count": data_row_count, + } + 
return self._create_project(_CoreProjectInput(**input)) - def create_response_creation_project(self, **kwargs) -> Project: + def create_response_creation_project( + self, + name: str, + description: Optional[str] = None, + quality_modes: Optional[Set[QualityMode]] = { + QualityMode.Benchmark, + QualityMode.Consensus, + }, + is_benchmark_enabled: Optional[bool] = None, + is_consensus_enabled: Optional[bool] = None, + ) -> Project: """ Creates a project for response creation. Args: - **kwargs: Additional parameters to pass see the create_project method + See create_project for parameters Returns: Project: The created project """ - kwargs["media_type"] = MediaType.Text # Only Text is supported - kwargs["editor_task_type"] = ( - EditorTaskType.ResponseCreation.value - ) # Special editor task type for response creation projects - - # The following arguments are not supported for response creation projects - kwargs.pop("dataset_name_or_id", None) - kwargs.pop("append_to_existing_dataset", None) - kwargs.pop("data_row_count", None) - - return self._create_project(**kwargs) - - def _create_project(self, **kwargs) -> Project: - auto_audit_percentage = kwargs.get("auto_audit_percentage") - auto_audit_number_of_labels = kwargs.get("auto_audit_number_of_labels") - if ( - auto_audit_percentage is not None - or auto_audit_number_of_labels is not None - ): - raise ValueError( - "quality_modes must be set instead of auto_audit_percentage or auto_audit_number_of_labels." - ) - - name = kwargs.get("name") - if name is None or not name.strip(): - raise ValueError("project name must be a valid string.") - - queue_mode = kwargs.get("queue_mode") - if queue_mode is QueueMode.Dataset: - raise ValueError( - "Dataset queue mode is deprecated. Please prefer Batch queue mode." - ) - elif queue_mode is QueueMode.Batch: - logger.warning( - "Passing a queue mode of batch is redundant and will soon no longer be supported." 
- ) - - media_type = kwargs.get("media_type") - if media_type and MediaType.is_supported(media_type): - media_type_value = media_type.value - elif media_type: - raise TypeError( - f"{media_type} is not a valid media type. Use" - f" any of {MediaType.get_supported_members()}" - " from MediaType. Example: MediaType.Image." - ) - else: - logger.warning( - "Creating a project without specifying media_type" - " through this method will soon no longer be supported." - ) - media_type_value = None - - quality_modes = kwargs.get("quality_modes") - quality_mode = kwargs.get("quality_mode") - if quality_mode: - logger.warning( - "Passing quality_mode is deprecated and will soon no longer be supported. Use quality_modes instead." - ) - - if quality_modes and quality_mode: - raise ValueError( - "Cannot use both quality_modes and quality_mode at the same time. Use one or the other." - ) - - if not quality_modes and not quality_mode: - logger.info("Defaulting quality modes to Benchmark and Consensus.") - - data = kwargs - data.pop("quality_modes", None) - data.pop("quality_mode", None) - - # check if quality_modes is a set, if not, convert to set - quality_modes_set = quality_modes - if quality_modes and not isinstance(quality_modes, set): - quality_modes_set = set(quality_modes) - if quality_mode: - quality_modes_set = {quality_mode} - - if ( - quality_modes_set is None - or len(quality_modes_set) == 0 - or quality_modes_set - == {QualityMode.Benchmark, QualityMode.Consensus} - ): - data["auto_audit_number_of_labels"] = ( - CONSENSUS_AUTO_AUDIT_NUMBER_OF_LABELS - ) - data["auto_audit_percentage"] = CONSENSUS_AUTO_AUDIT_PERCENTAGE - data["is_benchmark_enabled"] = True - data["is_consensus_enabled"] = True - elif quality_modes_set == {QualityMode.Benchmark}: - data["auto_audit_number_of_labels"] = ( - BENCHMARK_AUTO_AUDIT_NUMBER_OF_LABELS - ) - data["auto_audit_percentage"] = BENCHMARK_AUTO_AUDIT_PERCENTAGE - data["is_benchmark_enabled"] = True - elif quality_modes_set == 
{QualityMode.Consensus}: - data["auto_audit_number_of_labels"] = ( - CONSENSUS_AUTO_AUDIT_NUMBER_OF_LABELS - ) - data["auto_audit_percentage"] = CONSENSUS_AUTO_AUDIT_PERCENTAGE - data["is_consensus_enabled"] = True - else: - raise ValueError( - f"{quality_modes_set} is not a valid quality modes set. Allowed values are [Benchmark, Consensus]" - ) + input = { + "name": name, + "description": description, + "media_type": MediaType.Text, # Only Text is supported + "quality_modes": quality_modes, + "is_benchmark_enabled": is_benchmark_enabled, + "is_consensus_enabled": is_consensus_enabled, + "editor_task_type": EditorTaskType.ResponseCreation.value, # Special editor task type for response creation projects + } + return self._create_project(_CoreProjectInput(**input)) - params = {**data} - if media_type_value: - params["media_type"] = media_type_value + def _create_project(self, input: _CoreProjectInput) -> Project: + params = input.model_dump(exclude_none=True) extra_params = { Field.String("dataset_name_or_id"): params.pop( @@ -1222,7 +887,7 @@ def get_data_row_by_global_key(self, global_key: str) -> DataRow: """ res = self.get_data_row_ids_for_global_keys([global_key]) if res["status"] != "SUCCESS": - raise labelbox.exceptions.ResourceNotFoundError( + raise ResourceNotFoundError( Entity.DataRow, {global_key: global_key} ) data_row_id = res["results"][0] @@ -1250,7 +915,7 @@ def get_model(self, model_id) -> Model: Returns: The sought Model. Raises: - labelbox.exceptions.ResourceNotFoundError: If there is no + ResourceNotFoundError: If there is no Model with the given ID. 
""" return self._get_single(Entity.Model, model_id) @@ -1493,10 +1158,10 @@ def delete_unused_feature_schema(self, feature_schema_id: str) -> None: + "/feature-schemas/" + urllib.parse.quote(feature_schema_id) ) - response = self._connection.delete(endpoint) + response = self.connection.delete(endpoint) if response.status_code != requests.codes.no_content: - raise labelbox.exceptions.LabelboxError( + raise LabelboxError( "Failed to delete the feature schema, message: " + str(response.json()["message"]) ) @@ -1514,10 +1179,10 @@ def delete_unused_ontology(self, ontology_id: str) -> None: + "/ontologies/" + urllib.parse.quote(ontology_id) ) - response = self._connection.delete(endpoint) + response = self.connection.delete(endpoint) if response.status_code != requests.codes.no_content: - raise labelbox.exceptions.LabelboxError( + raise LabelboxError( "Failed to delete the ontology, message: " + str(response.json()["message"]) ) @@ -1542,12 +1207,12 @@ def update_feature_schema_title( + urllib.parse.quote(feature_schema_id) + "/definition" ) - response = self._connection.patch(endpoint, json={"title": title}) + response = self.connection.patch(endpoint, json={"title": title}) if response.status_code == requests.codes.ok: return self.get_feature_schema(feature_schema_id) else: - raise labelbox.exceptions.LabelboxError( + raise LabelboxError( "Failed to update the feature schema, message: " + str(response.json()["message"]) ) @@ -1576,14 +1241,14 @@ def upsert_feature_schema(self, feature_schema: Dict) -> FeatureSchema: + "/feature-schemas/" + urllib.parse.quote(feature_schema_id) ) - response = self._connection.put( + response = self.connection.put( endpoint, json={"normalized": json.dumps(feature_schema)} ) if response.status_code == requests.codes.ok: return self.get_feature_schema(response.json()["schemaId"]) else: - raise labelbox.exceptions.LabelboxError( + raise LabelboxError( "Failed to upsert the feature schema, message: " + str(response.json()["message"]) ) @@ 
-1609,9 +1274,9 @@ def insert_feature_schema_into_ontology( + "/feature-schemas/" + urllib.parse.quote(feature_schema_id) ) - response = self._connection.post(endpoint, json={"position": position}) + response = self.connection.post(endpoint, json={"position": position}) if response.status_code != requests.codes.created: - raise labelbox.exceptions.LabelboxError( + raise LabelboxError( "Failed to insert the feature schema into the ontology, message: " + str(response.json()["message"]) ) @@ -1631,12 +1296,12 @@ def get_unused_ontologies(self, after: str = None) -> List[str]: """ endpoint = self.rest_endpoint + "/ontologies/unused" - response = self._connection.get(endpoint, json={"after": after}) + response = self.connection.get(endpoint, json={"after": after}) if response.status_code == requests.codes.ok: return response.json() else: - raise labelbox.exceptions.LabelboxError( + raise LabelboxError( "Failed to get unused ontologies, message: " + str(response.json()["message"]) ) @@ -1656,12 +1321,12 @@ def get_unused_feature_schemas(self, after: str = None) -> List[str]: """ endpoint = self.rest_endpoint + "/feature-schemas/unused" - response = self._connection.get(endpoint, json={"after": after}) + response = self.connection.get(endpoint, json={"after": after}) if response.status_code == requests.codes.ok: return response.json() else: - raise labelbox.exceptions.LabelboxError( + raise LabelboxError( "Failed to get unused feature schemas, message: " + str(response.json()["message"]) ) @@ -1957,12 +1622,12 @@ def _format_failed_rows( elif ( res["assignGlobalKeysToDataRowsResult"]["jobStatus"] == "FAILED" ): - raise labelbox.exceptions.LabelboxError( + raise LabelboxError( "Job assign_global_keys_to_data_rows failed." ) current_time = time.time() if current_time - start_time > timeout_seconds: - raise labelbox.exceptions.TimeoutError( + raise TimeoutError( "Timed out waiting for assign_global_keys_to_data_rows job to complete." 
) time.sleep(sleep_time) @@ -1973,10 +1638,6 @@ def get_data_row_ids_for_global_keys( """ Gets data row ids for a list of global keys. - Deprecation Notice: This function will soon no longer return 'Deleted Data Rows' - as part of the 'results'. Global keys for deleted data rows will soon be placed - under 'Data Row not found' portion. - Args: A list of global keys Returns: @@ -2066,12 +1727,10 @@ def _format_failed_rows( return {"status": status, "results": results, "errors": errors} elif res["dataRowsForGlobalKeysResult"]["jobStatus"] == "FAILED": - raise labelbox.exceptions.LabelboxError( - "Job dataRowsForGlobalKeys failed." - ) + raise LabelboxError("Job dataRowsForGlobalKeys failed.") current_time = time.time() if current_time - start_time > timeout_seconds: - raise labelbox.exceptions.TimeoutError( + raise TimeoutError( "Timed out waiting for get_data_rows_for_global_keys job to complete." ) time.sleep(sleep_time) @@ -2170,12 +1829,10 @@ def _format_failed_rows( return {"status": status, "results": results, "errors": errors} elif res["clearGlobalKeysResult"]["jobStatus"] == "FAILED": - raise labelbox.exceptions.LabelboxError( - "Job clearGlobalKeys failed." - ) + raise LabelboxError("Job clearGlobalKeys failed.") current_time = time.time() if current_time - start_time > timeout_seconds: - raise labelbox.exceptions.TimeoutError( + raise TimeoutError( "Timed out waiting for clear_global_keys job to complete." 
) time.sleep(sleep_time) @@ -2224,7 +1881,7 @@ def is_feature_schema_archived( + "/ontologies/" + urllib.parse.quote(ontology_id) ) - response = self._connection.get(ontology_endpoint) + response = self.connection.get(ontology_endpoint) if response.status_code == requests.codes.ok: feature_schema_nodes = response.json()["featureSchemaNodes"] @@ -2240,16 +1897,14 @@ def is_feature_schema_archived( if filtered_feature_schema_nodes: return bool(filtered_feature_schema_nodes[0]["archived"]) else: - raise labelbox.exceptions.LabelboxError( + raise LabelboxError( "The specified feature schema was not in the ontology." ) elif response.status_code == 404: - raise labelbox.exceptions.ResourceNotFoundError( - Ontology, ontology_id - ) + raise ResourceNotFoundError(Ontology, ontology_id) else: - raise labelbox.exceptions.LabelboxError( + raise LabelboxError( "Failed to get the feature schema archived status." ) @@ -2276,9 +1931,7 @@ def get_model_slice(self, slice_id) -> ModelSlice: """ res = self.execute(query_str, {"id": slice_id}) if res is None or res["getSavedQuery"] is None: - raise labelbox.exceptions.ResourceNotFoundError( - ModelSlice, slice_id - ) + raise ResourceNotFoundError(ModelSlice, slice_id) return Entity.ModelSlice(self, res["getSavedQuery"]) @@ -2308,15 +1961,15 @@ def delete_feature_schema_from_ontology( + "/feature-schemas/" + urllib.parse.quote(feature_schema_id) ) - response = self._connection.delete(ontology_endpoint) + response = self.connection.delete(ontology_endpoint) if response.status_code == requests.codes.ok: response_json = response.json() - if response_json["archived"] == True: + if response_json["archived"] is True: logger.info( "Feature schema was archived from the ontology because it had associated labels." 
) - elif response_json["deleted"] == True: + elif response_json["deleted"] is True: logger.info( "Feature schema was successfully removed from the ontology" ) @@ -2325,7 +1978,7 @@ def delete_feature_schema_from_ontology( result.deleted = bool(response_json["deleted"]) return result else: - raise labelbox.exceptions.LabelboxError( + raise LabelboxError( "Failed to remove feature schema from ontology, message: " + str(response.json()["message"]) ) @@ -2350,14 +2003,12 @@ def unarchive_feature_schema_node( + urllib.parse.quote(root_feature_schema_id) + "/unarchive" ) - response = self._connection.patch(ontology_endpoint) + response = self.connection.patch(ontology_endpoint) if response.status_code == requests.codes.ok: if not bool(response.json()["unarchived"]): - raise labelbox.exceptions.LabelboxError( - "Failed unarchive the feature schema." - ) + raise LabelboxError("Failed unarchive the feature schema.") else: - raise labelbox.exceptions.LabelboxError( + raise LabelboxError( "Failed unarchive the feature schema node, message: ", response.text, ) @@ -2586,9 +2237,7 @@ def get_embedding_by_name(self, name: str) -> Embedding: for e in embeddings: if e.name == name: return e - raise labelbox.exceptions.ResourceNotFoundError( - Embedding, dict(name=name) - ) + raise ResourceNotFoundError(Embedding, dict(name=name)) def upsert_label_feedback( self, label_id: str, feedback: str, scores: Dict[str, float] @@ -2635,8 +2284,7 @@ def upsert_label_feedback( scores_raw = res["upsertAutoQaLabelFeedback"]["scores"] return [ - labelbox.LabelScore(name=x["name"], score=x["score"]) - for x in scores_raw + LabelScore(name=x["name"], score=x["score"]) for x in scores_raw ] def get_labeling_service_dashboards( @@ -2712,7 +2360,7 @@ def get_task_by_id(self, task_id: str) -> Union[Task, DataUpsertTask]: result = self.execute(query, {"userId": user.uid, "taskId": task_id}) data = result.get("user", {}).get("createdTasks", []) if not data: - raise 
labelbox.exceptions.ResourceNotFoundError( + raise ResourceNotFoundError( message=f"The task {task_id} does not exist." ) task_data = data[0] diff --git a/libs/labelbox/src/labelbox/data/annotation_types/__init__.py b/libs/labelbox/src/labelbox/data/annotation_types/__init__.py index 7908bc242..1a78127e1 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/__init__.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/__init__.py @@ -32,18 +32,8 @@ from .classification import Radio from .classification import Text -from .data import AudioData -from .data import ConversationData -from .data import DicomData -from .data import DocumentData -from .data import HTMLData -from .data import ImageData +from .data import GenericDataRowData from .data import MaskData -from .data import TextData -from .data import VideoData -from .data import LlmPromptResponseCreationData -from .data import LlmPromptCreationData -from .data import LlmResponseCreationData from .label import Label from .collection import LabelGenerator diff --git a/libs/labelbox/src/labelbox/data/annotation_types/collection.py b/libs/labelbox/src/labelbox/data/annotation_types/collection.py index a636a3b3a..42d2a1184 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/collection.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/collection.py @@ -21,62 +21,6 @@ def __init__(self, data: Generator[Label, None, None], *args, **kwargs): self._fns = {} super().__init__(data, *args, **kwargs) - def assign_feature_schema_ids( - self, ontology_builder: "ontology.OntologyBuilder" - ) -> "LabelGenerator": - def _assign_ids(label: Label): - label.assign_feature_schema_ids(ontology_builder) - return label - - warnings.warn( - "This method is deprecated and will be " - "removed in a future release. Feature schema ids" - " are no longer required for importing." 
- ) - self._fns["assign_feature_schema_ids"] = _assign_ids - return self - - def add_url_to_data( - self, signer: Callable[[bytes], str] - ) -> "LabelGenerator": - """ - Creates signed urls for the data - Only uploads url if one doesn't already exist. - - Args: - signer: A function that accepts bytes and returns a signed url. - Returns: - LabelGenerator that signs urls as data is accessed - """ - - def _add_url_to_data(label: Label): - label.add_url_to_data(signer) - return label - - self._fns["add_url_to_data"] = _add_url_to_data - return self - - def add_to_dataset( - self, dataset: "Entity.Dataset", signer: Callable[[bytes], str] - ) -> "LabelGenerator": - """ - Creates data rows from each labels data object and attaches the data to the given dataset. - Updates the label's data object to have the same external_id and uid as the data row. - - Args: - dataset: labelbox dataset object to add the new data row to - signer: A function that accepts bytes and returns a signed url. - Returns: - LabelGenerator that updates references to the new data rows as data is accessed - """ - - def _add_to_dataset(label: Label): - label.create_data_row(dataset, signer) - return label - - self._fns["assign_datarow_ids"] = _add_to_dataset - return self - def add_url_to_masks( self, signer: Callable[[bytes], str] ) -> "LabelGenerator": diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/__init__.py b/libs/labelbox/src/labelbox/data/annotation_types/data/__init__.py index 2522b2741..8d5e7289b 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/data/__init__.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/data/__init__.py @@ -1,12 +1,2 @@ -from .audio import AudioData -from .conversation import ConversationData -from .dicom import DicomData -from .document import DocumentData -from .html import HTMLData -from .raster import ImageData from .raster import MaskData -from .text import TextData -from .video import VideoData -from 
.llm_prompt_response_creation import LlmPromptResponseCreationData -from .llm_prompt_creation import LlmPromptCreationData -from .llm_response_creation import LlmResponseCreationData +from .generic_data_row_data import GenericDataRowData diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/audio.py b/libs/labelbox/src/labelbox/data/annotation_types/data/audio.py deleted file mode 100644 index 916fca99d..000000000 --- a/libs/labelbox/src/labelbox/data/annotation_types/data/audio.py +++ /dev/null @@ -1,7 +0,0 @@ -from labelbox.typing_imports import Literal -from labelbox.utils import _NoCoercionMixin -from .base_data import BaseData - - -class AudioData(BaseData, _NoCoercionMixin): - class_name: Literal["AudioData"] = "AudioData" diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/conversation.py b/libs/labelbox/src/labelbox/data/annotation_types/data/conversation.py deleted file mode 100644 index ef6507dca..000000000 --- a/libs/labelbox/src/labelbox/data/annotation_types/data/conversation.py +++ /dev/null @@ -1,7 +0,0 @@ -from labelbox.typing_imports import Literal -from labelbox.utils import _NoCoercionMixin -from .base_data import BaseData - - -class ConversationData(BaseData, _NoCoercionMixin): - class_name: Literal["ConversationData"] = "ConversationData" diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/dicom.py b/libs/labelbox/src/labelbox/data/annotation_types/data/dicom.py deleted file mode 100644 index ae4c377dc..000000000 --- a/libs/labelbox/src/labelbox/data/annotation_types/data/dicom.py +++ /dev/null @@ -1,7 +0,0 @@ -from labelbox.typing_imports import Literal -from labelbox.utils import _NoCoercionMixin -from .base_data import BaseData - - -class DicomData(BaseData, _NoCoercionMixin): - class_name: Literal["DicomData"] = "DicomData" diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/document.py b/libs/labelbox/src/labelbox/data/annotation_types/data/document.py deleted file mode 100644 index 
810a3ed3e..000000000 --- a/libs/labelbox/src/labelbox/data/annotation_types/data/document.py +++ /dev/null @@ -1,7 +0,0 @@ -from labelbox.typing_imports import Literal -from labelbox.utils import _NoCoercionMixin -from .base_data import BaseData - - -class DocumentData(BaseData, _NoCoercionMixin): - class_name: Literal["DocumentData"] = "DocumentData" diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/html.py b/libs/labelbox/src/labelbox/data/annotation_types/data/html.py deleted file mode 100644 index 7a78fcb7b..000000000 --- a/libs/labelbox/src/labelbox/data/annotation_types/data/html.py +++ /dev/null @@ -1,7 +0,0 @@ -from labelbox.typing_imports import Literal -from labelbox.utils import _NoCoercionMixin -from .base_data import BaseData - - -class HTMLData(BaseData, _NoCoercionMixin): - class_name: Literal["HTMLData"] = "HTMLData" diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/llm_prompt_creation.py b/libs/labelbox/src/labelbox/data/annotation_types/data/llm_prompt_creation.py deleted file mode 100644 index a1b0450bc..000000000 --- a/libs/labelbox/src/labelbox/data/annotation_types/data/llm_prompt_creation.py +++ /dev/null @@ -1,7 +0,0 @@ -from labelbox.typing_imports import Literal -from labelbox.utils import _NoCoercionMixin -from .base_data import BaseData - - -class LlmPromptCreationData(BaseData, _NoCoercionMixin): - class_name: Literal["LlmPromptCreationData"] = "LlmPromptCreationData" diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/llm_prompt_response_creation.py b/libs/labelbox/src/labelbox/data/annotation_types/data/llm_prompt_response_creation.py deleted file mode 100644 index a8dfce894..000000000 --- a/libs/labelbox/src/labelbox/data/annotation_types/data/llm_prompt_response_creation.py +++ /dev/null @@ -1,9 +0,0 @@ -from labelbox.typing_imports import Literal -from labelbox.utils import _NoCoercionMixin -from .base_data import BaseData - - -class LlmPromptResponseCreationData(BaseData, 
_NoCoercionMixin): - class_name: Literal["LlmPromptResponseCreationData"] = ( - "LlmPromptResponseCreationData" - ) diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/llm_response_creation.py b/libs/labelbox/src/labelbox/data/annotation_types/data/llm_response_creation.py deleted file mode 100644 index a8963ed3f..000000000 --- a/libs/labelbox/src/labelbox/data/annotation_types/data/llm_response_creation.py +++ /dev/null @@ -1,7 +0,0 @@ -from labelbox.typing_imports import Literal -from labelbox.utils import _NoCoercionMixin -from .base_data import BaseData - - -class LlmResponseCreationData(BaseData, _NoCoercionMixin): - class_name: Literal["LlmResponseCreationData"] = "LlmResponseCreationData" diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/raster.py b/libs/labelbox/src/labelbox/data/annotation_types/data/raster.py index ba4c6485f..8702b8aeb 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/data/raster.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/data/raster.py @@ -1,17 +1,15 @@ from abc import ABC from io import BytesIO -from typing import Callable, Optional, Union -from typing_extensions import Literal +from typing import Callable, Literal, Optional, Union -from PIL import Image +import numpy as np +import requests from google.api_core import retry +from lbox.exceptions import InternalServerError +from PIL import Image +from pydantic import BaseModel, ConfigDict, model_validator from requests.exceptions import ConnectTimeout -import requests -import numpy as np -from pydantic import BaseModel, model_validator, ConfigDict -from labelbox.exceptions import InternalServerError -from .base_data import BaseData from ..types import TypedArray @@ -172,7 +170,7 @@ def validate_args(self, values): uid = self.uid global_key = self.global_key if ( - uid == file_path == im_bytes == url == global_key == None + uid == file_path == im_bytes == url == global_key is None and arr is None ): raise ValueError( @@ -191,7 
+189,9 @@ def validate_args(self, values): return self def __repr__(self) -> str: - symbol_or_none = lambda data: "..." if data is not None else None + def symbol_or_none(data): + return "..." if data is not None else None + return ( f"{self.__class__.__name__}(im_bytes={symbol_or_none(self.im_bytes)}," f"file_path={self.file_path}," @@ -220,6 +220,3 @@ class MaskData(RasterData): url: Optional[str] = None arr: Optional[TypedArray[Literal['uint8']]] = None """ - - -class ImageData(RasterData, BaseData): ... diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/text.py b/libs/labelbox/src/labelbox/data/annotation_types/data/text.py deleted file mode 100644 index fe4c222d3..000000000 --- a/libs/labelbox/src/labelbox/data/annotation_types/data/text.py +++ /dev/null @@ -1,115 +0,0 @@ -from typing import Callable, Optional - -import requests -from requests.exceptions import ConnectTimeout -from google.api_core import retry - -from pydantic import ConfigDict, model_validator -from labelbox.exceptions import InternalServerError -from labelbox.typing_imports import Literal -from labelbox.utils import _NoCoercionMixin -from .base_data import BaseData - - -class TextData(BaseData, _NoCoercionMixin): - """ - Represents text data. Requires arg file_path, text, or url - - >>> TextData(text="") - - Args: - file_path (str) - text (str) - url (str) - """ - - class_name: Literal["TextData"] = "TextData" - file_path: Optional[str] = None - text: Optional[str] = None - url: Optional[str] = None - model_config = ConfigDict(extra="forbid") - - @property - def value(self) -> str: - """ - Property that unifies the data access pattern for all references to the text. 
- - Returns: - string representation of the text - """ - if self.text: - return self.text - elif self.file_path: - with open(self.file_path, "r") as file: - text = file.read() - self.text = text - return text - elif self.url: - text = self.fetch_remote() - self.text = text - return text - else: - raise ValueError("Must set either url, file_path or im_bytes") - - def set_fetch_fn(self, fn): - object.__setattr__(self, "fetch_remote", lambda: fn(self)) - - @retry.Retry( - deadline=15.0, - predicate=retry.if_exception_type(ConnectTimeout, InternalServerError), - ) - def fetch_remote(self) -> str: - """ - Method for accessing url. - - If url is not publicly accessible or requires another access pattern - simply override this function - """ - response = requests.get(self.url) - if response.status_code in [500, 502, 503, 504]: - raise InternalServerError(response.text) - response.raise_for_status() - return response.text - - @retry.Retry(deadline=15.0) - def create_url(self, signer: Callable[[bytes], str]) -> None: - """ - Utility for creating a url from any of the other text references. - - Args: - signer: A function that accepts bytes and returns a signed url. - Returns: - url for the text - """ - if self.url is not None: - return self.url - elif self.file_path is not None: - with open(self.file_path, "rb") as file: - self.url = signer(file.read()) - elif self.text is not None: - self.url = signer(self.text.encode()) - else: - raise ValueError( - "One of url, im_bytes, file_path, numpy must not be None." - ) - return self.url - - @model_validator(mode="after") - def validate_date(self, values): - file_path = self.file_path - text = self.text - url = self.url - uid = self.uid - global_key = self.global_key - if uid == file_path == text == url == global_key == None: - raise ValueError( - "One of `file_path`, `text`, `uid`, `global_key` or `url` required." 
- ) - return self - - def __repr__(self) -> str: - return ( - f"TextData(file_path={self.file_path}," - f"text={self.text[:30] + '...' if self.text is not None else None}," - f"url={self.url})" - ) diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/video.py b/libs/labelbox/src/labelbox/data/annotation_types/data/video.py deleted file mode 100644 index 581801036..000000000 --- a/libs/labelbox/src/labelbox/data/annotation_types/data/video.py +++ /dev/null @@ -1,173 +0,0 @@ -import logging -import os -import urllib.request -from typing import Callable, Dict, Generator, Optional, Tuple -from typing_extensions import Literal -from uuid import uuid4 - -import cv2 -import numpy as np -from google.api_core import retry - -from .base_data import BaseData -from ..types import TypedArray - -from pydantic import ConfigDict, model_validator - -logger = logging.getLogger(__name__) - - -class VideoData(BaseData): - """ - Represents video - """ - - file_path: Optional[str] = None - url: Optional[str] = None - frames: Optional[Dict[int, TypedArray[Literal["uint8"]]]] = None - # Required for discriminating between data types - model_config = ConfigDict(extra="forbid") - - def load_frames(self, overwrite: bool = False) -> None: - """ - Loads all frames into memory at once in order to access in non-sequential order. - This will use a lot of memory, especially for longer videos - - Args: - overwrite: Replace existing frames - """ - if self.frames and not overwrite: - return - - for count, frame in self.frame_generator(): - if self.frames is None: - self.frames = {} - self.frames[count] = frame - - @property - def value(self): - return self.frame_generator() - - def frame_generator( - self, cache_frames=False, download_dir="/tmp" - ) -> Generator[Tuple[int, np.ndarray], None, None]: - """ - A generator for accessing individual frames in a video. - - Args: - cache_frames (bool): Whether or not to cache frames while iterating through the video. 
- download_dir (str): Directory to save the video to. Defaults to `/tmp` dir - """ - if self.frames is not None: - for idx, frame in self.frames.items(): - yield idx, frame - return - elif self.url and not self.file_path: - file_path = os.path.join(download_dir, f"{uuid4()}.mp4") - logger.info("Downloading the video locally to %s", file_path) - self.fetch_remote(file_path) - self.file_path = file_path - - vidcap = cv2.VideoCapture(self.file_path) - - success, frame = vidcap.read() - count = 0 - if cache_frames: - self.frames = {} - while success: - frame = frame[:, :, ::-1] - yield count, frame - if cache_frames: - self.frames[count] = frame - success, frame = vidcap.read() - count += 1 - - def __getitem__(self, idx: int) -> np.ndarray: - if self.frames is None: - raise ValueError( - "Cannot select by index without iterating over the entire video or loading all frames." - ) - return self.frames[idx] - - def set_fetch_fn(self, fn): - object.__setattr__(self, "fetch_remote", lambda: fn(self)) - - @retry.Retry(deadline=15.0) - def fetch_remote(self, local_path) -> None: - """ - Method for downloading data from self.url - - If url is not publicly accessible or requires another access pattern - simply override this function - - Args: - local_path: Where to save the thing too. - """ - urllib.request.urlretrieve(self.url, local_path) - - @retry.Retry(deadline=15.0) - def create_url(self, signer: Callable[[bytes], str]) -> None: - """ - Utility for creating a url from any of the other video references. - - Args: - signer: A function that accepts bytes and returns a signed url. 
- Returns: - url for the video - """ - if self.url is not None: - return self.url - elif self.file_path is not None: - with open(self.file_path, "rb") as file: - self.url = signer(file.read()) - elif self.frames is not None: - self.file_path = self.frames_to_video(self.frames) - self.url = self.create_url(signer) - else: - raise ValueError("One of url, file_path, frames must not be None.") - return self.url - - def frames_to_video( - self, frames: Dict[int, np.ndarray], fps=20, save_dir="/tmp" - ) -> str: - """ - Compresses the data by converting a set of individual frames to a single video. - - """ - file_path = os.path.join(save_dir, f"{uuid4()}.mp4") - out = None - for key in frames.keys(): - frame = frames[key] - if out is None: - out = cv2.VideoWriter( - file_path, - cv2.VideoWriter_fourcc(*"MP4V"), - fps, - frame.shape[:2], - ) - out.write(frame) - if out is None: - return - out.release() - return file_path - - @model_validator(mode="after") - def validate_data(self): - file_path = self.file_path - url = self.url - frames = self.frames - uid = self.uid - global_key = self.global_key - - if uid == file_path == frames == url == global_key == None: - raise ValueError( - "One of `file_path`, `frames`, `uid`, `global_key` or `url` required." - ) - return self - - def __repr__(self) -> str: - return ( - f"VideoData(file_path={self.file_path}," - f"frames={'...' 
if self.frames is not None else None}," - f"url={self.url})" - ) diff --git a/libs/labelbox/src/labelbox/data/annotation_types/geometry/mask.py b/libs/labelbox/src/labelbox/data/annotation_types/geometry/mask.py index 0d870f24f..03e1dd62c 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/geometry/mask.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/geometry/mask.py @@ -1,15 +1,13 @@ -from typing import Callable, Optional, Tuple, Union, Dict, List +from typing import Callable, Dict, List, Optional, Tuple, Union -import numpy as np import cv2 - +import numpy as np +from pydantic import field_validator from shapely.geometry import MultiPolygon, Polygon from ..data import MaskData from .geometry import Geometry -from pydantic import field_validator - class Mask(Geometry): """Mask used to represent a single class in a larger segmentation mask @@ -91,7 +89,7 @@ def draw( as opposed to the mask that this object references which might have multiple objects determined by colors """ mask = self.mask.value - mask = np.alltrue(mask == self.color, axis=2).astype(np.uint8) + mask = np.all(mask == self.color, axis=2).astype(np.uint8) if height is not None or width is not None: mask = cv2.resize( diff --git a/libs/labelbox/src/labelbox/data/annotation_types/label.py b/libs/labelbox/src/labelbox/data/annotation_types/label.py index 7eef43f31..2f835b23c 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/label.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/label.py @@ -3,10 +3,7 @@ import warnings import labelbox -from labelbox.data.annotation_types.data.generic_data_row_data import ( - GenericDataRowData, -) -from labelbox.data.annotation_types.data.tiled_image import TiledImageData +from labelbox.data.annotation_types.data import GenericDataRowData, MaskData from labelbox.schema import ontology from ...annotated_types import Cuid @@ -14,42 +11,13 @@ from .relationship import RelationshipAnnotation from .llm_prompt_response.prompt import 
PromptClassificationAnnotation from .classification import ClassificationAnswer -from .data import ( - AudioData, - ConversationData, - DicomData, - DocumentData, - HTMLData, - ImageData, - TextData, - VideoData, - LlmPromptCreationData, - LlmPromptResponseCreationData, - LlmResponseCreationData, -) from .geometry import Mask from .metrics import ScalarMetric, ConfusionMatrixMetric from .video import VideoClassificationAnnotation from .video import VideoObjectAnnotation, VideoMaskAnnotation from .mmc import MessageEvaluationTaskAnnotation from ..ontology import get_feature_schema_lookup -from pydantic import BaseModel, field_validator, model_serializer - -DataType = Union[ - VideoData, - ImageData, - TextData, - TiledImageData, - AudioData, - ConversationData, - DicomData, - DocumentData, - HTMLData, - LlmPromptCreationData, - LlmPromptResponseCreationData, - LlmResponseCreationData, - GenericDataRowData, -] +from pydantic import BaseModel, field_validator class Label(BaseModel): @@ -67,14 +35,13 @@ class Label(BaseModel): Args: uid: Optional Label Id in Labelbox - data: Data of Label, Image, Video, Text or dict with a single key uid | global_key | external_id. - Note use of classes as data is deprecated. Use GenericDataRowData or dict with a single key instead. + data: GenericDataRowData or dict with a single key uid | global_key | external_id. annotations: List of Annotations in the label extra: additional context """ uid: Optional[Cuid] = None - data: DataType + data: Union[GenericDataRowData, MaskData] annotations: List[ Union[ ClassificationAnnotation, @@ -94,13 +61,6 @@ class Label(BaseModel): def validate_data(cls, data): if isinstance(data, Dict): return GenericDataRowData(**data) - elif isinstance(data, GenericDataRowData): - return data - else: - warnings.warn( - f"Using {type(data).__name__} class for label.data is deprecated. " - "Use a dict or an instance of GenericDataRowData instead." 
- ) return data def object_annotations(self) -> List[ObjectAnnotation]: @@ -128,19 +88,6 @@ def frame_annotations( frame_dict[annotation.frame].append(annotation) return frame_dict - def add_url_to_data(self, signer) -> "Label": - """ - Creates signed urls for the data - Only uploads url if one doesn't already exist. - - Args: - signer: A function that accepts bytes and returns a signed url. - Returns: - Label with updated references to new data url - """ - self.data.create_url(signer) - return self - def add_url_to_masks(self, signer) -> "Label": """ Creates signed urls for all masks in the Label. @@ -189,42 +136,6 @@ def create_data_row( self.data.external_id = data_row.external_id return self - def assign_feature_schema_ids( - self, ontology_builder: ontology.OntologyBuilder - ) -> "Label": - """ - Adds schema ids to all FeatureSchema objects in the Labels. - - Args: - ontology_builder: The ontology that matches the feature names assigned to objects in this dataset - Returns: - Label. useful for chaining these modifying functions - - Note: You can now import annotations using names directly without having to lookup schema_ids - """ - warnings.warn( - "This method is deprecated and will be " - "removed in a future release. Feature schema ids" - " are no longer required for importing." - ) - tool_lookup, classification_lookup = get_feature_schema_lookup( - ontology_builder - ) - for annotation in self.annotations: - if isinstance(annotation, ClassificationAnnotation): - self._assign_or_raise(annotation, classification_lookup) - self._assign_option(annotation, classification_lookup) - elif isinstance(annotation, ObjectAnnotation): - self._assign_or_raise(annotation, tool_lookup) - for classification in annotation.classifications: - self._assign_or_raise(classification, classification_lookup) - self._assign_option(classification, classification_lookup) - else: - raise TypeError( - f"Unexpected type found for annotation. 
{type(annotation)}" - ) - return self - def _assign_or_raise(self, annotation, lookup: Dict[str, str]) -> None: if annotation.feature_schema_id is not None: return diff --git a/libs/labelbox/src/labelbox/data/annotation_types/metrics/scalar.py b/libs/labelbox/src/labelbox/data/annotation_types/metrics/scalar.py index 13d0e9748..1434be427 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/metrics/scalar.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/metrics/scalar.py @@ -1,11 +1,10 @@ -from typing import Dict, Optional, Union -from typing_extensions import Annotated from enum import Enum +from typing import Annotated, Dict, Optional, Union from pydantic import field_validator from pydantic.types import confloat -from .base import ConfidenceValue, BaseMetric +from .base import BaseMetric, ConfidenceValue ScalarMetricValue = Annotated[float, confloat(ge=0, le=100_000_000)] ScalarMetricConfidenceValue = Dict[ConfidenceValue, ScalarMetricValue] diff --git a/libs/labelbox/src/labelbox/data/annotation_types/types.py b/libs/labelbox/src/labelbox/data/annotation_types/types.py index 9bb86a4b9..4a9198fea 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/types.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/types.py @@ -1,43 +1,20 @@ -import sys -from typing import Generic, TypeVar, Any - -from typing_extensions import Annotated -from packaging import version +from typing import Generic, TypeVar import numpy as np - -from pydantic import StringConstraints, Field +from pydantic_core import core_schema DType = TypeVar("DType") DShape = TypeVar("DShape") -class _TypedArray(np.ndarray, Generic[DType, DShape]): +class TypedArray(np.ndarray, Generic[DType, DShape]): @classmethod - def __get_validators__(cls): - yield cls.validate + def __get_pydantic_core_schema__( + cls, _source_type: type, _model: type + ) -> core_schema.CoreSchema: + return core_schema.no_info_plain_validator_function(cls.validate) @classmethod - def validate(cls, val, 
field: Field): + def validate(cls, val): if not isinstance(val, np.ndarray): raise TypeError(f"Expected numpy array. Found {type(val)}") return val - - -if version.parse(np.__version__) >= version.parse("1.25.0"): - from typing import GenericAlias - - TypedArray = GenericAlias(_TypedArray, (Any, DType)) -elif version.parse(np.__version__) >= version.parse("1.23.0"): - from numpy._typing import _GenericAlias - - TypedArray = _GenericAlias(_TypedArray, (Any, DType)) -elif ( - version.parse("1.22.0") - <= version.parse(np.__version__) - < version.parse("1.23.0") -): - from numpy.typing import _GenericAlias - - TypedArray = _GenericAlias(_TypedArray, (Any, DType)) -else: - TypedArray = _TypedArray[Any, DType] diff --git a/libs/labelbox/src/labelbox/data/annotation_types/video.py b/libs/labelbox/src/labelbox/data/annotation_types/video.py index cfebd7a1f..5a93704c8 100644 --- a/libs/labelbox/src/labelbox/data/annotation_types/video.py +++ b/libs/labelbox/src/labelbox/data/annotation_types/video.py @@ -125,7 +125,7 @@ class MaskFrame(_CamelCaseMixin, BaseModel): def validate_args(self, values): im_bytes = self.im_bytes instance_uri = self.instance_uri - if im_bytes == instance_uri == None: + if im_bytes == instance_uri is None: raise ValueError("One of `instance_uri`, `im_bytes` required.") return self diff --git a/libs/labelbox/src/labelbox/data/metrics/confusion_matrix/calculation.py b/libs/labelbox/src/labelbox/data/metrics/confusion_matrix/calculation.py index 938e17f65..83410a540 100644 --- a/libs/labelbox/src/labelbox/data/metrics/confusion_matrix/calculation.py +++ b/libs/labelbox/src/labelbox/data/metrics/confusion_matrix/calculation.py @@ -130,7 +130,7 @@ def classification_confusion_matrix( prediction, ground_truth = predictions[0], ground_truths[0] - if type(prediction) != type(ground_truth): + if type(prediction) is not type(ground_truth): raise TypeError( "Classification features must be the same type to compute agreement. 
" f"Found `{type(prediction)}` and `{type(ground_truth)}`" diff --git a/libs/labelbox/src/labelbox/data/metrics/group.py b/libs/labelbox/src/labelbox/data/metrics/group.py index 88f4eae8b..9c4104c29 100644 --- a/libs/labelbox/src/labelbox/data/metrics/group.py +++ b/libs/labelbox/src/labelbox/data/metrics/group.py @@ -16,10 +16,10 @@ try: from typing import Literal except ImportError: - from typing_extensions import Literal + from typing import Literal +from ..annotation_types import ClassificationAnnotation, Label, ObjectAnnotation from ..annotation_types.feature import FeatureSchema -from ..annotation_types import ObjectAnnotation, ClassificationAnnotation, Label def get_identifying_key( diff --git a/libs/labelbox/src/labelbox/data/metrics/iou/calculation.py b/libs/labelbox/src/labelbox/data/metrics/iou/calculation.py index 2a376d3fe..d0237963c 100644 --- a/libs/labelbox/src/labelbox/data/metrics/iou/calculation.py +++ b/libs/labelbox/src/labelbox/data/metrics/iou/calculation.py @@ -209,7 +209,7 @@ def classification_miou( prediction, ground_truth = predictions[0], ground_truths[0] - if type(prediction) != type(ground_truth): + if type(prediction) is not type(ground_truth): raise TypeError( "Classification features must be the same type to compute agreement. 
" f"Found `{type(prediction)}` and `{type(ground_truth)}`" diff --git a/libs/labelbox/src/labelbox/data/mixins.py b/libs/labelbox/src/labelbox/data/mixins.py index 4440c8a72..fd5ecaa42 100644 --- a/libs/labelbox/src/labelbox/data/mixins.py +++ b/libs/labelbox/src/labelbox/data/mixins.py @@ -1,13 +1,10 @@ -from typing import Optional, List +from typing import List, Optional -from pydantic import BaseModel, field_validator, model_serializer - -from labelbox.exceptions import ( +from lbox.exceptions import ( ConfidenceNotSupportedException, CustomMetricsNotSupportedException, ) - -from warnings import warn +from pydantic import BaseModel, field_validator class ConfidenceMixin(BaseModel): diff --git a/libs/labelbox/src/labelbox/data/serialization/__init__.py b/libs/labelbox/src/labelbox/data/serialization/__init__.py index 71a9b3443..38cb5edff 100644 --- a/libs/labelbox/src/labelbox/data/serialization/__init__.py +++ b/libs/labelbox/src/labelbox/data/serialization/__init__.py @@ -1,2 +1 @@ from .ndjson import NDJsonConverter -from .coco import COCOConverter diff --git a/libs/labelbox/src/labelbox/data/serialization/coco/__init__.py b/libs/labelbox/src/labelbox/data/serialization/coco/__init__.py deleted file mode 100644 index 4511e89ee..000000000 --- a/libs/labelbox/src/labelbox/data/serialization/coco/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .converter import COCOConverter diff --git a/libs/labelbox/src/labelbox/data/serialization/coco/annotation.py b/libs/labelbox/src/labelbox/data/serialization/coco/annotation.py deleted file mode 100644 index e387cb7d9..000000000 --- a/libs/labelbox/src/labelbox/data/serialization/coco/annotation.py +++ /dev/null @@ -1,78 +0,0 @@ -from typing import Any, Tuple, List, Union -from pathlib import Path -from collections import defaultdict -import warnings - -from ...annotation_types.relationship import RelationshipAnnotation -from ...annotation_types.metrics.confusion_matrix import ConfusionMatrixMetric -from 
...annotation_types.metrics.scalar import ScalarMetric -from ...annotation_types.video import VideoMaskAnnotation -from ...annotation_types.annotation import ObjectAnnotation -from ...annotation_types.classification.classification import ( - ClassificationAnnotation, -) - -import numpy as np - -from .path import PathSerializerMixin -from pydantic import BaseModel - - -def rle_decoding(rle_arr: List[int], w: int, h: int) -> np.ndarray: - indices = [] - for idx, cnt in zip(rle_arr[0::2], rle_arr[1::2]): - indices.extend( - list(range(idx - 1, idx + cnt - 1)) - ) # RLE is 1-based index - mask = np.zeros(h * w, dtype=np.uint8) - mask[indices] = 1 - return mask.reshape((w, h)).T - - -def get_annotation_lookup(annotations): - """Get annotations from Label.annotations objects - - Args: - annotations (Label.annotations): Annotations attached to labelbox Label object used as private method - """ - annotation_lookup = defaultdict(list) - for annotation in annotations: - # Provide a default value of None if the attribute doesn't exist - attribute_value = getattr(annotation, "image_id", None) or getattr( - annotation, "name", None - ) - annotation_lookup[attribute_value].append(annotation) - return annotation_lookup - - -class SegmentInfo(BaseModel): - id: int - category_id: int - area: Union[float, int] - bbox: Tuple[float, float, float, float] # [x,y,w,h], - iscrowd: int = 0 - - -class RLE(BaseModel): - counts: List[int] - size: Tuple[int, int] # h,w or w,h? - - -class COCOObjectAnnotation(BaseModel): - # All segmentations for a particular class in an image... - # So each image will have one of these for each class present in the image.. - # Annotations only exist if there is data.. 
- id: int - image_id: int - category_id: int - segmentation: Union[RLE, List[List[float]]] # [[x1,y1,x2,y2,x3,y3...]] - area: float - bbox: Tuple[float, float, float, float] # [x,y,w,h], - iscrowd: int = 0 - - -class PanopticAnnotation(PathSerializerMixin): - # One to one relationship between image and panoptic annotation - image_id: int - file_name: Path - segments_info: List[SegmentInfo] diff --git a/libs/labelbox/src/labelbox/data/serialization/coco/categories.py b/libs/labelbox/src/labelbox/data/serialization/coco/categories.py deleted file mode 100644 index 60ba30fce..000000000 --- a/libs/labelbox/src/labelbox/data/serialization/coco/categories.py +++ /dev/null @@ -1,17 +0,0 @@ -import sys -from hashlib import md5 - -from pydantic import BaseModel - - -class Categories(BaseModel): - id: int - name: str - supercategory: str - isthing: int = 1 - - -def hash_category_name(name: str) -> int: - return int.from_bytes( - md5(name.encode("utf-8")).hexdigest().encode("utf-8"), "little" - ) diff --git a/libs/labelbox/src/labelbox/data/serialization/coco/converter.py b/libs/labelbox/src/labelbox/data/serialization/coco/converter.py deleted file mode 100644 index e270b7573..000000000 --- a/libs/labelbox/src/labelbox/data/serialization/coco/converter.py +++ /dev/null @@ -1,170 +0,0 @@ -from typing import Dict, Any, Union -from pathlib import Path -import os -import warnings - -from ...annotation_types.collection import LabelCollection, LabelGenerator -from ...serialization.coco.instance_dataset import CocoInstanceDataset -from ...serialization.coco.panoptic_dataset import CocoPanopticDataset - - -def create_path_if_not_exists( - path: Union[Path, str], ignore_existing_data=False -): - path = Path(path) - if not path.exists(): - path.mkdir(parents=True, exist_ok=True) - elif not ignore_existing_data and os.listdir(path): - raise ValueError( - f"Directory `{path}`` must be empty. 
Or set `ignore_existing_data=True`" - ) - return path - - -def validate_path(path: Union[Path, str], name: str): - path = Path(path) - if not path.exists(): - raise ValueError(f"{name} `{path}` must exist") - return path - - -class COCOConverter: - """ - Class for converting between coco and labelbox formats - Note that this class is only compatible with image data. - - Subclasses are currently ignored. - To use subclasses, manually flatten them before using the converter. - """ - - @staticmethod - def serialize_instances( - labels: LabelCollection, - image_root: Union[Path, str], - ignore_existing_data=False, - max_workers=8, - ) -> Dict[str, Any]: - """ - Convert a Labelbox LabelCollection into an mscoco dataset. - This function will only convert masks, polygons, and rectangles. - Masks will be converted into individual instances. - Use deserialize_panoptic to prevent masks from being split apart. - - Args: - labels: A collection of labels to convert - image_root: Where to save images to - ignore_existing_data: Whether or not to raise an exception if images already exist. - This exists only to support detectons panoptic fpn model which requires two mscoco payloads for the same images. - max_workers : Number of workers to process dataset with. A value of 0 will process all data in the main process - Returns: - A dictionary containing labels in the coco object format. 
- """ - - warnings.warn( - "You are currently utilizing COCOconverter for this action, which will be deprecated in a later release.", - DeprecationWarning, - stacklevel=2, - ) - - image_root = create_path_if_not_exists(image_root, ignore_existing_data) - return CocoInstanceDataset.from_common( - labels=labels, image_root=image_root, max_workers=max_workers - ).model_dump() - - @staticmethod - def serialize_panoptic( - labels: LabelCollection, - image_root: Union[Path, str], - mask_root: Union[Path, str], - all_stuff: bool = False, - ignore_existing_data=False, - max_workers: int = 8, - ) -> Dict[str, Any]: - """ - Convert a Labelbox LabelCollection into an mscoco dataset. - This function will only convert masks, polygons, and rectangles. - Masks will be converted into individual instances. - Use deserialize_panoptic to prevent masks from being split apart. - - Args: - labels: A collection of labels to convert - image_root: Where to save images to - mask_root: Where to save segmentation masks to - all_stuff: If rectangle or polygon annotations are encountered, they will be treated as instances. - To convert them to stuff class set `all_stuff=True`. - ignore_existing_data: Whether or not to raise an exception if images already exist. - This exists only to support detectons panoptic fpn model which requires two mscoco payloads for the same images. - max_workers : Number of workers to process dataset with. A value of 0 will process all data in the main process. - Returns: - A dictionary containing labels in the coco panoptic format. 
- """ - - warnings.warn( - "You are currently utilizing COCOconverter for this action, which will be deprecated in a later release.", - DeprecationWarning, - stacklevel=2, - ) - - image_root = create_path_if_not_exists(image_root, ignore_existing_data) - mask_root = create_path_if_not_exists(mask_root, ignore_existing_data) - return CocoPanopticDataset.from_common( - labels=labels, - image_root=image_root, - mask_root=mask_root, - all_stuff=all_stuff, - max_workers=max_workers, - ).model_dump() - - @staticmethod - def deserialize_panoptic( - json_data: Dict[str, Any], - image_root: Union[Path, str], - mask_root: Union[Path, str], - ) -> LabelGenerator: - """ - Convert coco panoptic data into the labelbox format (as a LabelGenerator). - - Args: - json_data: panoptic data as a dict - image_root: Path to local images that are referenced by the panoptic json - mask_root: Path to local segmentation masks that are referenced by the panoptic json - Returns: - LabelGenerator - """ - - warnings.warn( - "You are currently utilizing COCOconverter for this action, which will be deprecated in a later release.", - DeprecationWarning, - stacklevel=2, - ) - - image_root = validate_path(image_root, "image_root") - mask_root = validate_path(mask_root, "mask_root") - objs = CocoPanopticDataset(**json_data) - gen = objs.to_common(image_root, mask_root) - return LabelGenerator(data=gen) - - @staticmethod - def deserialize_instances( - json_data: Dict[str, Any], image_root: Path - ) -> LabelGenerator: - """ - Convert coco object data into the labelbox format (as a LabelGenerator). 
- - Args: - json_data: coco object data as a dict - image_root: Path to local images that are referenced by the coco object json - Returns: - LabelGenerator - """ - - warnings.warn( - "You are currently utilizing COCOconverter for this action, which will be deprecated in a later release.", - DeprecationWarning, - stacklevel=2, - ) - - image_root = validate_path(image_root, "image_root") - objs = CocoInstanceDataset(**json_data) - gen = objs.to_common(image_root) - return LabelGenerator(data=gen) diff --git a/libs/labelbox/src/labelbox/data/serialization/coco/image.py b/libs/labelbox/src/labelbox/data/serialization/coco/image.py deleted file mode 100644 index cef173377..000000000 --- a/libs/labelbox/src/labelbox/data/serialization/coco/image.py +++ /dev/null @@ -1,52 +0,0 @@ -from pathlib import Path - -from typing import Optional, Tuple -from PIL import Image -import imagesize - -from .path import PathSerializerMixin -from ...annotation_types import Label - - -class CocoImage(PathSerializerMixin): - id: int - width: int - height: int - file_name: Path - license: Optional[int] = None - flickr_url: Optional[str] = None - coco_url: Optional[str] = None - - -def get_image_id(label: Label, idx: int) -> int: - if label.data.file_path is not None: - file_name = label.data.file_path.replace(".jpg", "") - if file_name.isdecimal(): - return file_name - return idx - - -def get_image(label: Label, image_root: Path, image_id: str) -> CocoImage: - path = Path(image_root, f"{image_id}.jpg") - if not path.exists(): - im = Image.fromarray(label.data.value) - im.save(path) - w, h = im.size - else: - w, h = imagesize.get(str(path)) - return CocoImage(id=image_id, width=w, height=h, file_name=Path(path.name)) - - -def id_to_rgb(id: int) -> Tuple[int, int, int]: - digits = [] - for _ in range(3): - digits.append(id % 256) - id //= 256 - return digits - - -def rgb_to_id(red: int, green: int, blue: int) -> int: - id = blue * 256 * 256 - id += green * 256 - id += red - return id diff 
--git a/libs/labelbox/src/labelbox/data/serialization/coco/instance_dataset.py b/libs/labelbox/src/labelbox/data/serialization/coco/instance_dataset.py deleted file mode 100644 index 5241e596f..000000000 --- a/libs/labelbox/src/labelbox/data/serialization/coco/instance_dataset.py +++ /dev/null @@ -1,266 +0,0 @@ -# https://cocodataset.org/#format-data - -from concurrent.futures import ProcessPoolExecutor, as_completed -from typing import Any, Dict, List, Tuple, Optional -from pathlib import Path - -import numpy as np -from tqdm import tqdm - -from ...annotation_types import ( - ImageData, - MaskData, - Mask, - ObjectAnnotation, - Label, - Polygon, - Point, - Rectangle, -) -from ...annotation_types.collection import LabelCollection -from .categories import Categories, hash_category_name -from .annotation import ( - COCOObjectAnnotation, - RLE, - get_annotation_lookup, - rle_decoding, -) -from .image import CocoImage, get_image, get_image_id -from pydantic import BaseModel - - -def mask_to_coco_object_annotation( - annotation: ObjectAnnotation, - annot_idx: int, - image_id: int, - category_id: int, -) -> Optional[COCOObjectAnnotation]: - # This is going to fill any holes into the multipolygon - # If you need to support holes use the panoptic data format - shapely = annotation.value.shapely.simplify(1).buffer(0) - if shapely.is_empty: - return - - xmin, ymin, xmax, ymax = shapely.bounds - # Iterate over polygon once or multiple polygon for each item - area = shapely.area - - return COCOObjectAnnotation( - id=annot_idx, - image_id=image_id, - category_id=category_id, - segmentation=[ - np.array(s.exterior.coords).ravel().tolist() - for s in ([shapely] if shapely.type == "Polygon" else shapely.geoms) - ], - area=area, - bbox=[xmin, ymin, xmax - xmin, ymax - ymin], - iscrowd=0, - ) - - -def vector_to_coco_object_annotation( - annotation: ObjectAnnotation, - annot_idx: int, - image_id: int, - category_id: int, -) -> COCOObjectAnnotation: - shapely = 
annotation.value.shapely - xmin, ymin, xmax, ymax = shapely.bounds - segmentation = [] - if isinstance(annotation.value, Polygon): - for point in annotation.value.points: - segmentation.extend([point.x, point.y]) - else: - box = annotation.value - segmentation.extend( - [ - box.start.x, - box.start.y, - box.end.x, - box.start.y, - box.end.x, - box.end.y, - box.start.x, - box.end.y, - ] - ) - - return COCOObjectAnnotation( - id=annot_idx, - image_id=image_id, - category_id=category_id, - segmentation=[segmentation], - area=shapely.area, - bbox=[xmin, ymin, xmax - xmin, ymax - ymin], - iscrowd=0, - ) - - -def rle_to_common( - class_annotations: COCOObjectAnnotation, class_name: str -) -> ObjectAnnotation: - mask = rle_decoding( - class_annotations.segmentation.counts, - *class_annotations.segmentation.size[::-1], - ) - return ObjectAnnotation( - name=class_name, - value=Mask(mask=MaskData.from_2D_arr(mask), color=[1, 1, 1]), - ) - - -def segmentations_to_common( - class_annotations: COCOObjectAnnotation, class_name: str -) -> List[ObjectAnnotation]: - # Technically it is polygons. But the key in coco is called segmentations.. 
- annotations = [] - for points in class_annotations.segmentation: - annotations.append( - ObjectAnnotation( - name=class_name, - value=Polygon( - points=[ - Point(x=points[i], y=points[i + 1]) - for i in range(0, len(points), 2) - ] - ), - ) - ) - return annotations - - -def object_annotation_to_coco( - annotation: ObjectAnnotation, - annot_idx: int, - image_id: int, - category_id: int, -) -> Optional[COCOObjectAnnotation]: - if isinstance(annotation.value, Mask): - return mask_to_coco_object_annotation( - annotation, annot_idx, image_id, category_id - ) - elif isinstance(annotation.value, (Polygon, Rectangle)): - return vector_to_coco_object_annotation( - annotation, annot_idx, image_id, category_id - ) - else: - return None - - -def process_label( - label: Label, idx: int, image_root: str, max_annotations_per_image=10000 -) -> Tuple[np.ndarray, List[COCOObjectAnnotation], Dict[str, str]]: - annot_idx = idx * max_annotations_per_image - image_id = get_image_id(label, idx) - image = get_image(label, image_root, image_id) - coco_annotations = [] - annotation_lookup = get_annotation_lookup(label.annotations) - categories = {} - for class_name in annotation_lookup: - for annotation in annotation_lookup[class_name]: - category_id = categories.get(annotation.name) or hash_category_name( - annotation.name - ) - coco_annotation = object_annotation_to_coco( - annotation, annot_idx, image_id, category_id - ) - if coco_annotation is not None: - coco_annotations.append(coco_annotation) - if annotation.name not in categories: - categories[annotation.name] = category_id - annot_idx += 1 - - return image, coco_annotations, categories - - -class CocoInstanceDataset(BaseModel): - info: Dict[str, Any] = {} - images: List[CocoImage] - annotations: List[COCOObjectAnnotation] - categories: List[Categories] - - @classmethod - def from_common( - cls, labels: LabelCollection, image_root: Path, max_workers=8 - ): - all_coco_annotations = [] - categories = {} - images = [] - futures = [] 
- coco_categories = {} - - if max_workers: - with ProcessPoolExecutor(max_workers=max_workers) as exc: - futures = [ - exc.submit(process_label, label, idx, image_root) - for idx, label in enumerate(labels) - ] - results = [ - future.result() for future in tqdm(as_completed(futures)) - ] - else: - results = [ - process_label(label, idx, image_root) - for idx, label in enumerate(labels) - ] - - for result in results: - images.append(result[0]) - all_coco_annotations.extend(result[1]) - coco_categories.update(result[2]) - - category_mapping = { - category_id: idx + 1 - for idx, category_id in enumerate(coco_categories.values()) - } - categories = [ - Categories( - id=category_mapping[idx], - name=name, - supercategory="all", - isthing=1, - ) - for name, idx in coco_categories.items() - ] - for annot in all_coco_annotations: - annot.category_id = category_mapping[annot.category_id] - - return CocoInstanceDataset( - info={"image_root": image_root}, - images=images, - annotations=all_coco_annotations, - categories=categories, - ) - - def to_common(self, image_root): - category_lookup = { - category.id: category for category in self.categories - } - annotation_lookup = get_annotation_lookup(self.annotations) - - for image in self.images: - im_path = Path(image_root, image.file_name) - if not im_path.exists(): - raise ValueError( - f"Cannot find file {im_path}. 
Make sure `image_root` is set properly" - ) - - data = ImageData(file_path=str(im_path)) - annotations = [] - for class_annotations in annotation_lookup[image.id]: - if isinstance(class_annotations.segmentation, RLE): - annotations.append( - rle_to_common( - class_annotations, - category_lookup[class_annotations.category_id].name, - ) - ) - elif isinstance(class_annotations.segmentation, list): - annotations.extend( - segmentations_to_common( - class_annotations, - category_lookup[class_annotations.category_id].name, - ) - ) - yield Label(data=data, annotations=annotations) diff --git a/libs/labelbox/src/labelbox/data/serialization/coco/panoptic_dataset.py b/libs/labelbox/src/labelbox/data/serialization/coco/panoptic_dataset.py deleted file mode 100644 index cbb410548..000000000 --- a/libs/labelbox/src/labelbox/data/serialization/coco/panoptic_dataset.py +++ /dev/null @@ -1,242 +0,0 @@ -from concurrent.futures import ProcessPoolExecutor, as_completed -from typing import Dict, Any, List, Union -from pathlib import Path - -from tqdm import tqdm -import numpy as np -from PIL import Image - -from ...annotation_types.geometry import Polygon, Rectangle -from ...annotation_types import Label -from ...annotation_types.geometry.mask import Mask -from ...annotation_types.annotation import ObjectAnnotation -from ...annotation_types.data.raster import MaskData, ImageData -from ...annotation_types.collection import LabelCollection -from .categories import Categories, hash_category_name -from .image import CocoImage, get_image, get_image_id, id_to_rgb -from .annotation import PanopticAnnotation, SegmentInfo, get_annotation_lookup -from pydantic import BaseModel - - -def vector_to_coco_segment_info( - canvas: np.ndarray, - annotation: ObjectAnnotation, - annotation_idx: int, - image: CocoImage, - category_id: int, -): - shapely = annotation.value.shapely - if shapely.is_empty: - return - - xmin, ymin, xmax, ymax = shapely.bounds - canvas = annotation.value.draw( - 
height=image.height, - width=image.width, - canvas=canvas, - color=id_to_rgb(annotation_idx), - ) - - return SegmentInfo( - id=annotation_idx, - category_id=category_id, - area=shapely.area, - bbox=[xmin, ymin, xmax - xmin, ymax - ymin], - ), canvas - - -def mask_to_coco_segment_info( - canvas: np.ndarray, annotation, annotation_idx: int, category_id -): - color = id_to_rgb(annotation_idx) - mask = annotation.value.draw(color=color) - shapely = annotation.value.shapely - if shapely.is_empty: - return - - xmin, ymin, xmax, ymax = shapely.bounds - canvas = np.where(canvas == (0, 0, 0), mask, canvas) - return SegmentInfo( - id=annotation_idx, - category_id=category_id, - area=shapely.area, - bbox=[xmin, ymin, xmax - xmin, ymax - ymin], - ), canvas - - -def process_label( - label: Label, idx: Union[int, str], image_root, mask_root, all_stuff=False -): - """ - Masks become stuff - Polygon and rectangle become thing - """ - annotations = get_annotation_lookup(label.annotations) - image_id = get_image_id(label, idx) - image = get_image(label, image_root, image_id) - canvas = np.zeros((image.height, image.width, 3)) - - segments = [] - categories = {} - is_thing = {} - - for class_idx, class_name in enumerate(annotations): - for annotation_idx, annotation in enumerate(annotations[class_name]): - categories[annotation.name] = hash_category_name(annotation.name) - if isinstance(annotation.value, Mask): - coco_segment_info = mask_to_coco_segment_info( - canvas, - annotation, - class_idx + 1, - categories[annotation.name], - ) - - if coco_segment_info is None: - # Filter out empty masks - continue - - segment, canvas = coco_segment_info - segments.append(segment) - is_thing[annotation.name] = 0 - - elif isinstance(annotation.value, (Polygon, Rectangle)): - coco_vector_info = vector_to_coco_segment_info( - canvas, - annotation, - annotation_idx=(class_idx if all_stuff else annotation_idx) - + 1, - image=image, - category_id=categories[annotation.name], - ) - - if 
coco_vector_info is None: - # Filter out empty annotations - continue - - segment, canvas = coco_vector_info - segments.append(segment) - is_thing[annotation.name] = 1 - int(all_stuff) - - mask_file = str(image.file_name).replace(".jpg", ".png") - mask_file = Path(mask_root, mask_file) - Image.fromarray(canvas.astype(np.uint8)).save(mask_file) - return ( - image, - PanopticAnnotation( - image_id=image_id, - file_name=Path(mask_file.name), - segments_info=segments, - ), - categories, - is_thing, - ) - - -class CocoPanopticDataset(BaseModel): - info: Dict[str, Any] = {} - images: List[CocoImage] - annotations: List[PanopticAnnotation] - categories: List[Categories] - - @classmethod - def from_common( - cls, - labels: LabelCollection, - image_root, - mask_root, - all_stuff, - max_workers=8, - ): - all_coco_annotations = [] - coco_categories = {} - coco_things = {} - images = [] - - if max_workers: - with ProcessPoolExecutor(max_workers=max_workers) as exc: - futures = [ - exc.submit( - process_label, - label, - idx, - image_root, - mask_root, - all_stuff, - ) - for idx, label in enumerate(labels) - ] - results = [ - future.result() for future in tqdm(as_completed(futures)) - ] - else: - results = [ - process_label(label, idx, image_root, mask_root, all_stuff) - for idx, label in enumerate(labels) - ] - - for result in results: - images.append(result[0]) - all_coco_annotations.append(result[1]) - coco_categories.update(result[2]) - coco_things.update(result[3]) - - category_mapping = { - category_id: idx + 1 - for idx, category_id in enumerate(coco_categories.values()) - } - categories = [ - Categories( - id=category_mapping[idx], - name=name, - supercategory="all", - isthing=coco_things.get(name, 1), - ) - for name, idx in coco_categories.items() - ] - - for annot in all_coco_annotations: - for segment in annot.segments_info: - segment.category_id = category_mapping[segment.category_id] - - return CocoPanopticDataset( - info={"image_root": image_root, "mask_root": 
mask_root}, - images=images, - annotations=all_coco_annotations, - categories=categories, - ) - - def to_common(self, image_root: Path, mask_root: Path): - category_lookup = { - category.id: category for category in self.categories - } - annotation_lookup = { - annotation.image_id: annotation for annotation in self.annotations - } - for image in self.images: - annotations = [] - annotation = annotation_lookup[image.id] - - im_path = Path(image_root, image.file_name) - if not im_path.exists(): - raise ValueError( - f"Cannot find file {im_path}. Make sure `image_root` is set properly" - ) - if not str(annotation.file_name).endswith(".png"): - raise ValueError( - f"COCO masks must be stored as png files and their extension must be `.png`. Found {annotation.file_name}" - ) - mask = MaskData( - file_path=str(Path(mask_root, annotation.file_name)) - ) - - for segmentation in annotation.segments_info: - category = category_lookup[segmentation.category_id] - annotations.append( - ObjectAnnotation( - name=category.name, - value=Mask(mask=mask, color=id_to_rgb(segmentation.id)), - ) - ) - data = ImageData(file_path=str(im_path)) - yield Label(data=data, annotations=annotations) - del annotation_lookup[image.id] diff --git a/libs/labelbox/src/labelbox/data/serialization/coco/path.py b/libs/labelbox/src/labelbox/data/serialization/coco/path.py deleted file mode 100644 index c3be84f31..000000000 --- a/libs/labelbox/src/labelbox/data/serialization/coco/path.py +++ /dev/null @@ -1,9 +0,0 @@ -from pathlib import Path -from pydantic import BaseModel, model_serializer - - -class PathSerializerMixin(BaseModel): - @model_serializer(mode="wrap") - def serialize_model(self, handler): - res = handler(self) - return {k: str(v) if isinstance(v, Path) else v for k, v in res.items()} diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/base.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/base.py index 75ebdc100..d8d8cd36f 100644 --- 
a/libs/labelbox/src/labelbox/data/serialization/ndjson/base.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/base.py @@ -8,18 +8,6 @@ from ....annotated_types import Cuid -subclass_registry = {} - - -class _SubclassRegistryBase(BaseModel): - model_config = ConfigDict(extra="allow") - - def __init_subclass__(cls, **kwargs): - super().__init_subclass__(**kwargs) - if cls.__name__ != "NDAnnotation": - with threading.Lock(): - subclass_registry[cls.__name__] = cls - class DataRow(_CamelCaseMixin): id: Optional[str] = None diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py index b127c4a90..fedf4d91b 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List, Union, Optional -from labelbox.data.annotation_types import ImageData, TextData, VideoData +from labelbox.data.annotation_types import GenericDataRowData from labelbox.data.mixins import ( ConfidenceMixin, CustomMetric, @@ -30,7 +30,6 @@ model_serializer, ) from pydantic.alias_generators import to_camel -from .base import _SubclassRegistryBase class NDAnswer(ConfidenceMixin, CustomMetricsMixin): @@ -224,7 +223,7 @@ def from_common( # ====== End of subclasses -class NDText(NDAnnotation, NDTextSubclass, _SubclassRegistryBase): +class NDText(NDAnnotation, NDTextSubclass): @classmethod def from_common( cls, @@ -233,7 +232,7 @@ def from_common( name: str, feature_schema_id: Cuid, extra: Dict[str, Any], - data: Union[TextData, ImageData], + data: GenericDataRowData, message_id: str, confidence: Optional[float] = None, ) -> "NDText": @@ -249,9 +248,7 @@ def from_common( ) -class NDChecklist( - NDAnnotation, NDChecklistSubclass, VideoSupported, _SubclassRegistryBase -): +class NDChecklist(NDAnnotation, NDChecklistSubclass, VideoSupported): 
@model_serializer(mode="wrap") def serialize_model(self, handler): res = handler(self) @@ -267,7 +264,7 @@ def from_common( name: str, feature_schema_id: Cuid, extra: Dict[str, Any], - data: Union[VideoData, TextData, ImageData], + data: GenericDataRowData, message_id: str, confidence: Optional[float] = None, custom_metrics: Optional[List[CustomMetric]] = None, @@ -298,9 +295,7 @@ def from_common( ) -class NDRadio( - NDAnnotation, NDRadioSubclass, VideoSupported, _SubclassRegistryBase -): +class NDRadio(NDAnnotation, NDRadioSubclass, VideoSupported): @classmethod def from_common( cls, @@ -309,7 +304,7 @@ def from_common( name: str, feature_schema_id: Cuid, extra: Dict[str, Any], - data: Union[VideoData, TextData, ImageData], + data: GenericDataRowData, message_id: str, confidence: Optional[float] = None, ) -> "NDRadio": @@ -343,7 +338,7 @@ def serialize_model(self, handler): return res -class NDPromptText(NDAnnotation, NDPromptTextSubclass, _SubclassRegistryBase): +class NDPromptText(NDAnnotation, NDPromptTextSubclass): @classmethod def from_common( cls, @@ -432,7 +427,7 @@ def from_common( annotation: Union[ ClassificationAnnotation, VideoClassificationAnnotation ], - data: Union[VideoData, TextData, ImageData], + data: GenericDataRowData, ) -> Union[NDTextSubclass, NDChecklistSubclass, NDRadioSubclass]: classify_obj = cls.lookup_classification(annotation) if classify_obj is None: @@ -480,7 +475,7 @@ def to_common( def from_common( cls, annotation: Union[PromptClassificationAnnotation], - data: Union[VideoData, TextData, ImageData], + data: GenericDataRowData, ) -> Union[NDTextSubclass, NDChecklistSubclass, NDRadioSubclass]: return NDPromptText.from_common( str(annotation._uuid), diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/converter.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/converter.py index 01ab8454a..8176d7862 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/converter.py +++ 
b/libs/labelbox/src/labelbox/data/serialization/ndjson/converter.py @@ -26,20 +26,6 @@ class NDJsonConverter: - @staticmethod - def deserialize(json_data: Iterable[Dict[str, Any]]) -> LabelGenerator: - """ - Converts ndjson data (prediction import format) into the common labelbox format. - - Args: - json_data: An iterable representing the ndjson data - Returns: - LabelGenerator containing the ndjson data. - """ - data = NDLabel(**{"annotations": copy.copy(json_data)}) - res = data.to_common() - return res - @staticmethod def serialize( labels: LabelCollection, diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py index 18134a228..ffaefb4d7 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py @@ -14,7 +14,6 @@ ) from ...annotation_types.video import VideoObjectAnnotation, VideoMaskAnnotation from ...annotation_types.collection import LabelCollection, LabelGenerator -from ...annotation_types.data import DicomData, ImageData, TextData, VideoData from ...annotation_types.data.generic_data_row_data import GenericDataRowData from ...annotation_types.label import Label from ...annotation_types.ner import TextEntity, ConversationEntity @@ -46,7 +45,6 @@ from .relationship import NDRelationship from .base import DataRow from pydantic import BaseModel, ValidationError -from .base import subclass_registry, _SubclassRegistryBase from pydantic_core import PydanticUndefined from contextlib import suppress @@ -67,68 +65,7 @@ class NDLabel(BaseModel): - annotations: List[_SubclassRegistryBase] - - def __init__(self, **kwargs): - # NOTE: Deserialization of subclasses in pydantic is difficult, see here https://blog.devgenius.io/deserialize-child-classes-with-pydantic-that-gonna-work-784230e1cf83 - # Below implements the subclass registry as mentioned in the article. 
The python dicts we pass in can be missing certain fields - # we essentially have to infer the type against all sub classes that have the _SubclasssRegistryBase inheritance. - # It works by checking if the keys of our annotations we are missing in matches any required subclass. - # More keys are prioritized over less keys (closer match). This is used when importing json to our base models not a lot of customer workflows - # depend on this method but this works for all our existing tests with the bonus of added validation. (no subclass found it throws an error) - - for index, annotation in enumerate(kwargs["annotations"]): - if isinstance(annotation, dict): - item_annotation_keys = annotation.keys() - key_subclass_combos = defaultdict(list) - for subclass in subclass_registry.values(): - # Get all required keys from subclass - annotation_keys = [] - for k, field in subclass.model_fields.items(): - if field.default == PydanticUndefined and k != "uuid": - if ( - hasattr(field, "alias") - and field.alias in item_annotation_keys - ): - annotation_keys.append(field.alias) - elif ( - hasattr(field, "validation_alias") - and field.validation_alias - in item_annotation_keys - ): - annotation_keys.append(field.validation_alias) - else: - annotation_keys.append(k) - - key_subclass_combos[subclass].extend(annotation_keys) - - # Sort by subclass that has the most keys i.e. 
the one with the most keys that matches is most likely our subclass - key_subclass_combos = dict( - sorted( - key_subclass_combos.items(), - key=lambda x: len(x[1]), - reverse=True, - ) - ) - - for subclass, key_subclass_combo in key_subclass_combos.items(): - # Choose the keys from our dict we supplied that matches the required keys of a subclass - check_required_keys = all( - key in list(item_annotation_keys) - for key in key_subclass_combo - ) - if check_required_keys: - # Keep trying subclasses until we find one that has valid values (does not throw an validation error) - with suppress(ValidationError): - annotation = subclass(**annotation) - break - if isinstance(annotation, dict): - raise ValueError( - f"Could not find subclass for fields: {item_annotation_keys}" - ) - - kwargs["annotations"][index] = annotation - super().__init__(**kwargs) + annotations: AnnotationType class _Relationship(BaseModel): """This object holds information about the relationship""" @@ -276,46 +213,9 @@ def _generate_annotations( yield Label( annotations=annotations, - data=self._infer_media_type(group.data_row, annotations), + data=GenericDataRowData, ) - def _infer_media_type( - self, - data_row: DataRow, - annotations: List[ - Union[ - TextEntity, - ConversationEntity, - VideoClassificationAnnotation, - DICOMObjectAnnotation, - VideoObjectAnnotation, - ObjectAnnotation, - ClassificationAnnotation, - ScalarMetric, - ConfusionMatrixMetric, - ] - ], - ) -> Union[TextData, VideoData, ImageData]: - if len(annotations) == 0: - raise ValueError("Missing annotations while inferring media type") - - types = {type(annotation) for annotation in annotations} - data = GenericDataRowData - if (TextEntity in types) or (ConversationEntity in types): - data = TextData - elif ( - VideoClassificationAnnotation in types - or VideoObjectAnnotation in types - ): - data = VideoData - elif DICOMObjectAnnotation in types: - data = DicomData - - if data_row.id: - return data(uid=data_row.id) - else: - 
return data(global_key=data_row.global_key) - @staticmethod def _get_consecutive_frames( frames_indices: List[int], diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/metric.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/metric.py index 60d538b19..f8b522ab5 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/metric.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/metric.py @@ -1,6 +1,6 @@ from typing import Optional, Union, Type -from labelbox.data.annotation_types.data import ImageData, TextData +from labelbox.data.annotation_types.data import GenericDataRowData from labelbox.data.serialization.ndjson.base import DataRow, NDJsonBase from labelbox.data.annotation_types.metrics.scalar import ( ScalarMetric, @@ -15,7 +15,6 @@ ConfusionMatrixMetricConfidenceValue, ) from pydantic import ConfigDict, model_serializer -from .base import _SubclassRegistryBase class BaseNDMetric(NDJsonBase): @@ -33,7 +32,7 @@ def serialize_model(self, handler): return res -class NDConfusionMatrixMetric(BaseNDMetric, _SubclassRegistryBase): +class NDConfusionMatrixMetric(BaseNDMetric): metric_value: Union[ ConfusionMatrixMetricValue, ConfusionMatrixMetricConfidenceValue ] @@ -52,7 +51,7 @@ def to_common(self) -> ConfusionMatrixMetric: @classmethod def from_common( - cls, metric: ConfusionMatrixMetric, data: Union[TextData, ImageData] + cls, metric: ConfusionMatrixMetric, data: GenericDataRowData ) -> "NDConfusionMatrixMetric": return cls( uuid=metric.extra.get("uuid"), @@ -65,7 +64,7 @@ def from_common( ) -class NDScalarMetric(BaseNDMetric, _SubclassRegistryBase): +class NDScalarMetric(BaseNDMetric): metric_value: Union[ScalarMetricValue, ScalarMetricConfidenceValue] metric_name: Optional[str] = None aggregation: Optional[ScalarMetricAggregation] = ( @@ -84,7 +83,7 @@ def to_common(self) -> ScalarMetric: @classmethod def from_common( - cls, metric: ScalarMetric, data: Union[TextData, ImageData] + cls, metric: ScalarMetric, data: 
GenericDataRowData ) -> "NDScalarMetric": return cls( uuid=metric.extra.get("uuid"), @@ -108,7 +107,7 @@ def to_common( def from_common( cls, annotation: Union[ScalarMetric, ConfusionMatrixMetric], - data: Union[TextData, ImageData], + data: GenericDataRowData, ) -> Union[NDScalarMetric, NDConfusionMatrixMetric]: obj = cls.lookup_object(annotation) return obj.from_common(annotation, data) diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/mmc.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/mmc.py index 4be24f683..b2dcfb5b4 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/mmc.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/mmc.py @@ -2,13 +2,14 @@ from labelbox.utils import _CamelCaseMixin -from .base import _SubclassRegistryBase, DataRow, NDAnnotation +from .base import DataRow, NDAnnotation from ...annotation_types.mmc import ( MessageSingleSelectionTask, MessageMultiSelectionTask, MessageRankingTask, MessageEvaluationTaskAnnotation, ) +from ...annotation_types import GenericDataRowData class MessageTaskData(_CamelCaseMixin): @@ -20,7 +21,7 @@ class MessageTaskData(_CamelCaseMixin): ] -class NDMessageTask(NDAnnotation, _SubclassRegistryBase): +class NDMessageTask(NDAnnotation): message_evaluation_task: MessageTaskData def to_common(self) -> MessageEvaluationTaskAnnotation: @@ -35,7 +36,7 @@ def to_common(self) -> MessageEvaluationTaskAnnotation: def from_common( cls, annotation: MessageEvaluationTaskAnnotation, - data: Any, # Union[ImageData, TextData], + data: GenericDataRowData, ) -> "NDMessageTask": return cls( uuid=str(annotation._uuid), diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/objects.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/objects.py index a1465fa06..1bcba7a89 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/objects.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/objects.py @@ -2,6 +2,7 @@ from typing import Any, Dict, List, 
Tuple, Union, Optional import base64 +from labelbox.data.annotation_types.data.raster import MaskData from labelbox.data.annotation_types.ner.conversation_entity import ( ConversationEntity, ) @@ -21,9 +22,9 @@ from PIL import Image from labelbox.data.annotation_types import feature -from labelbox.data.annotation_types.data.video import VideoData +from labelbox.data.annotation_types.data import GenericDataRowData -from ...annotation_types.data import ImageData, TextData, MaskData +from ...annotation_types.data import GenericDataRowData from ...annotation_types.ner import ( DocumentEntity, DocumentTextSelection, @@ -52,7 +53,7 @@ NDSubclassification, NDSubclassificationType, ) -from .base import DataRow, NDAnnotation, NDJsonBase, _SubclassRegistryBase +from .base import DataRow, NDAnnotation, NDJsonBase from pydantic import BaseModel @@ -81,9 +82,7 @@ class Bbox(BaseModel): width: float -class NDPoint( - NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase -): +class NDPoint(NDBaseObject, ConfidenceMixin, CustomMetricsMixin): point: _Point def to_common(self) -> Point: @@ -98,7 +97,7 @@ def from_common( name: str, feature_schema_id: Cuid, extra: Dict[str, Any], - data: Union[ImageData, TextData], + data: GenericDataRowData, confidence: Optional[float] = None, custom_metrics: Optional[List[CustomMetric]] = None, ) -> "NDPoint": @@ -114,7 +113,7 @@ def from_common( ) -class NDFramePoint(VideoSupported, _SubclassRegistryBase): +class NDFramePoint(VideoSupported): point: _Point classifications: List[NDSubclassificationType] = [] @@ -148,9 +147,7 @@ def from_common( ) -class NDLine( - NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase -): +class NDLine(NDBaseObject, ConfidenceMixin, CustomMetricsMixin): line: List[_Point] def to_common(self) -> Line: @@ -165,7 +162,7 @@ def from_common( name: str, feature_schema_id: Cuid, extra: Dict[str, Any], - data: Union[ImageData, TextData], + data: GenericDataRowData, confidence: 
Optional[float] = None, custom_metrics: Optional[List[CustomMetric]] = None, ) -> "NDLine": @@ -181,7 +178,7 @@ def from_common( ) -class NDFrameLine(VideoSupported, _SubclassRegistryBase): +class NDFrameLine(VideoSupported): line: List[_Point] classifications: List[NDSubclassificationType] = [] @@ -215,7 +212,7 @@ def from_common( ) -class NDDicomLine(NDFrameLine, _SubclassRegistryBase): +class NDDicomLine(NDFrameLine): def to_common( self, name: str, @@ -234,9 +231,7 @@ def to_common( ) -class NDPolygon( - NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase -): +class NDPolygon(NDBaseObject, ConfidenceMixin, CustomMetricsMixin): polygon: List[_Point] def to_common(self) -> Polygon: @@ -251,7 +246,7 @@ def from_common( name: str, feature_schema_id: Cuid, extra: Dict[str, Any], - data: Union[ImageData, TextData], + data: GenericDataRowData, confidence: Optional[float] = None, custom_metrics: Optional[List[CustomMetric]] = None, ) -> "NDPolygon": @@ -267,9 +262,7 @@ def from_common( ) -class NDRectangle( - NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase -): +class NDRectangle(NDBaseObject, ConfidenceMixin, CustomMetricsMixin): bbox: Bbox def to_common(self) -> Rectangle: @@ -290,7 +283,7 @@ def from_common( name: str, feature_schema_id: Cuid, extra: Dict[str, Any], - data: Union[ImageData, TextData], + data: GenericDataRowData, confidence: Optional[float] = None, custom_metrics: Optional[List[CustomMetric]] = None, ) -> "NDRectangle": @@ -313,7 +306,7 @@ def from_common( ) -class NDDocumentRectangle(NDRectangle, _SubclassRegistryBase): +class NDDocumentRectangle(NDRectangle): page: int unit: str @@ -337,7 +330,7 @@ def from_common( name: str, feature_schema_id: Cuid, extra: Dict[str, Any], - data: Union[ImageData, TextData], + data: GenericDataRowData, confidence: Optional[float] = None, custom_metrics: Optional[List[CustomMetric]] = None, ) -> "NDRectangle": @@ -360,7 +353,7 @@ def from_common( ) -class 
NDFrameRectangle(VideoSupported, _SubclassRegistryBase): +class NDFrameRectangle(VideoSupported): bbox: Bbox classifications: List[NDSubclassificationType] = [] @@ -496,7 +489,7 @@ def to_common( ] -class NDSegments(NDBaseObject, _SubclassRegistryBase): +class NDSegments(NDBaseObject): segments: List[NDSegment] def to_common(self, name: str, feature_schema_id: Cuid): @@ -516,7 +509,7 @@ def to_common(self, name: str, feature_schema_id: Cuid): def from_common( cls, segments: List[VideoObjectAnnotation], - data: VideoData, + data: GenericDataRowData, name: str, feature_schema_id: Cuid, extra: Dict[str, Any], @@ -532,7 +525,7 @@ def from_common( ) -class NDDicomSegments(NDBaseObject, DicomSupported, _SubclassRegistryBase): +class NDDicomSegments(NDBaseObject, DicomSupported): segments: List[NDDicomSegment] def to_common(self, name: str, feature_schema_id: Cuid): @@ -553,7 +546,7 @@ def to_common(self, name: str, feature_schema_id: Cuid): def from_common( cls, segments: List[DICOMObjectAnnotation], - data: VideoData, + data: GenericDataRowData, name: str, feature_schema_id: Cuid, extra: Dict[str, Any], @@ -580,9 +573,7 @@ class _PNGMask(BaseModel): png: str -class NDMask( - NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase -): +class NDMask(NDBaseObject, ConfidenceMixin, CustomMetricsMixin): mask: Union[_URIMask, _PNGMask] def to_common(self) -> Mask: @@ -611,7 +602,7 @@ def from_common( name: str, feature_schema_id: Cuid, extra: Dict[str, Any], - data: Union[ImageData, TextData], + data: GenericDataRowData, confidence: Optional[float] = None, custom_metrics: Optional[List[CustomMetric]] = None, ) -> "NDMask": @@ -646,7 +637,6 @@ class NDVideoMasks( NDJsonBase, ConfidenceMixin, CustomMetricsNotSupportedMixin, - _SubclassRegistryBase, ): masks: NDVideoMasksFramesInstances @@ -678,7 +668,7 @@ def from_common(cls, annotation, data): ) -class NDDicomMasks(NDVideoMasks, DicomSupported, _SubclassRegistryBase): +class NDDicomMasks(NDVideoMasks, 
DicomSupported): def to_common(self) -> DICOMMaskAnnotation: return DICOMMaskAnnotation( frames=self.masks.frames, @@ -702,9 +692,7 @@ class Location(BaseModel): end: int -class NDTextEntity( - NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase -): +class NDTextEntity(NDBaseObject, ConfidenceMixin, CustomMetricsMixin): location: Location def to_common(self) -> TextEntity: @@ -719,7 +707,7 @@ def from_common( name: str, feature_schema_id: Cuid, extra: Dict[str, Any], - data: Union[ImageData, TextData], + data: GenericDataRowData, confidence: Optional[float] = None, custom_metrics: Optional[List[CustomMetric]] = None, ) -> "NDTextEntity": @@ -738,9 +726,7 @@ def from_common( ) -class NDDocumentEntity( - NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase -): +class NDDocumentEntity(NDBaseObject, ConfidenceMixin, CustomMetricsMixin): name: str text_selections: List[DocumentTextSelection] @@ -758,7 +744,7 @@ def from_common( name: str, feature_schema_id: Cuid, extra: Dict[str, Any], - data: Union[ImageData, TextData], + data: GenericDataRowData, confidence: Optional[float] = None, custom_metrics: Optional[List[CustomMetric]] = None, ) -> "NDDocumentEntity": @@ -774,7 +760,7 @@ def from_common( ) -class NDConversationEntity(NDTextEntity, _SubclassRegistryBase): +class NDConversationEntity(NDTextEntity): message_id: str def to_common(self) -> ConversationEntity: @@ -793,7 +779,7 @@ def from_common( name: str, feature_schema_id: Cuid, extra: Dict[str, Any], - data: Union[ImageData, TextData], + data: GenericDataRowData, confidence: Optional[float] = None, custom_metrics: Optional[List[CustomMetric]] = None, ) -> "NDConversationEntity": @@ -851,7 +837,7 @@ def from_common( List[List[VideoObjectAnnotation]], VideoMaskAnnotation, ], - data: Union[ImageData, TextData], + data: GenericDataRowData, ) -> Union[ NDLine, NDPoint, diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/relationship.py 
b/libs/labelbox/src/labelbox/data/serialization/ndjson/relationship.py index fbea7e477..d558ac244 100644 --- a/libs/labelbox/src/labelbox/data/serialization/ndjson/relationship.py +++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/relationship.py @@ -1,11 +1,11 @@ from typing import Union from pydantic import BaseModel from .base import NDAnnotation, DataRow -from ...annotation_types.data import ImageData, TextData +from ...annotation_types.data import GenericDataRowData from ...annotation_types.relationship import RelationshipAnnotation from ...annotation_types.relationship import Relationship from .objects import NDObjectType -from .base import DataRow, _SubclassRegistryBase +from .base import DataRow SUPPORTED_ANNOTATIONS = NDObjectType @@ -16,7 +16,7 @@ class _Relationship(BaseModel): type: str -class NDRelationship(NDAnnotation, _SubclassRegistryBase): +class NDRelationship(NDAnnotation): relationship: _Relationship @staticmethod @@ -40,7 +40,7 @@ def to_common( def from_common( cls, annotation: RelationshipAnnotation, - data: Union[ImageData, TextData], + data: GenericDataRowData, ) -> "NDRelationship": relationship = annotation.value return cls( diff --git a/libs/labelbox/src/labelbox/orm/db_object.py b/libs/labelbox/src/labelbox/orm/db_object.py index b210a8a5b..a1c2bde38 100644 --- a/libs/labelbox/src/labelbox/orm/db_object.py +++ b/libs/labelbox/src/labelbox/orm/db_object.py @@ -1,17 +1,17 @@ -from dataclasses import dataclass +import json +import logging from datetime import datetime, timezone from functools import wraps -import logging -import json -from labelbox import utils -from labelbox.exceptions import ( - InvalidQueryError, +from lbox.exceptions import ( InvalidAttributeError, + InvalidQueryError, OperationNotSupportedException, ) + +from labelbox import utils from labelbox.orm import query -from labelbox.orm.model import Field, Relationship, Entity +from labelbox.orm.model import Entity, Field, Relationship from labelbox.pagination 
import PaginatedCollection logger = logging.getLogger(__name__) @@ -177,7 +177,7 @@ def _to_many(self, where=None, order_by=None): ) if rel.filter_deleted: - not_deleted = rel.destination_type.deleted == False + not_deleted = rel.destination_type.deleted == False # noqa: E712 Needed for bit operator to combine comparisons where = not_deleted if where is None else where & not_deleted query_string, params = query.relationship( diff --git a/libs/labelbox/src/labelbox/orm/model.py b/libs/labelbox/src/labelbox/orm/model.py index 84dcac774..535ab0f7d 100644 --- a/libs/labelbox/src/labelbox/orm/model.py +++ b/libs/labelbox/src/labelbox/orm/model.py @@ -1,10 +1,11 @@ from dataclasses import dataclass from enum import Enum, auto -from typing import Dict, List, Union, Any, Type, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, List, Type, Union + +from lbox.exceptions import InvalidAttributeError import labelbox from labelbox import utils -from labelbox.exceptions import InvalidAttributeError from labelbox.orm.comparison import Comparison """ Defines Field, Relationship and Entity. 
These classes are building @@ -386,7 +387,6 @@ class Entity(metaclass=EntityMeta): Review: Type[labelbox.Review] User: Type[labelbox.User] LabelingFrontend: Type[labelbox.LabelingFrontend] - BulkImportRequest: Type[labelbox.BulkImportRequest] Benchmark: Type[labelbox.Benchmark] IAMIntegration: Type[labelbox.IAMIntegration] LabelingFrontendOptions: Type[labelbox.LabelingFrontendOptions] diff --git a/libs/labelbox/src/labelbox/orm/query.py b/libs/labelbox/src/labelbox/orm/query.py index 8fa9fea00..8c019fb5e 100644 --- a/libs/labelbox/src/labelbox/orm/query.py +++ b/libs/labelbox/src/labelbox/orm/query.py @@ -1,14 +1,15 @@ from itertools import chain from typing import Any, Dict -from labelbox import utils -from labelbox.exceptions import ( - InvalidQueryError, +from lbox.exceptions import ( InvalidAttributeError, + InvalidQueryError, MalformedQueryException, ) -from labelbox.orm.comparison import LogicalExpression, Comparison -from labelbox.orm.model import Field, Relationship, Entity + +from labelbox import utils +from labelbox.orm.comparison import Comparison, LogicalExpression +from labelbox.orm.model import Entity, Field, Relationship """ Common query creation functionality. 
""" diff --git a/libs/labelbox/src/labelbox/project_validation.py b/libs/labelbox/src/labelbox/project_validation.py new file mode 100644 index 000000000..2a6db9e2a --- /dev/null +++ b/libs/labelbox/src/labelbox/project_validation.py @@ -0,0 +1,87 @@ +from typing import Annotated, Optional, Set + +from pydantic import BaseModel, ConfigDict, Field, model_validator + +from labelbox.schema.media_type import MediaType +from labelbox.schema.ontology_kind import EditorTaskType +from labelbox.schema.quality_mode import ( + BENCHMARK_AUTO_AUDIT_NUMBER_OF_LABELS, + BENCHMARK_AUTO_AUDIT_PERCENTAGE, + CONSENSUS_AUTO_AUDIT_NUMBER_OF_LABELS, + CONSENSUS_AUTO_AUDIT_PERCENTAGE, + QualityMode, +) + +PositiveInt = Annotated[int, Field(gt=0)] + + +class _CoreProjectInput(BaseModel): + name: str + description: Optional[str] = None + media_type: MediaType + auto_audit_percentage: Optional[float] = None + auto_audit_number_of_labels: Optional[int] = None + quality_modes: Optional[Set[QualityMode]] = Field( + default={QualityMode.Benchmark, QualityMode.Consensus}, exclude=True + ) + is_benchmark_enabled: Optional[bool] = None + is_consensus_enabled: Optional[bool] = None + dataset_name_or_id: Optional[str] = None + append_to_existing_dataset: Optional[bool] = None + data_row_count: Optional[PositiveInt] = None + editor_task_type: Optional[EditorTaskType] = None + + model_config = ConfigDict(extra="forbid", use_enum_values=True) + + @model_validator(mode="after") + def validate_fields(self): + if ( + self.auto_audit_percentage is not None + or self.auto_audit_number_of_labels is not None + ): + raise ValueError( + "quality_modes must be set instead of auto_audit_percentage or auto_audit_number_of_labels." 
+ ) + + if not self.name.strip(): + raise ValueError("project name must be a valid string.") + + if self.quality_modes == { + QualityMode.Benchmark, + QualityMode.Consensus, + }: + self._set_quality_mode_attributes( + CONSENSUS_AUTO_AUDIT_NUMBER_OF_LABELS, + CONSENSUS_AUTO_AUDIT_PERCENTAGE, + is_benchmark_enabled=True, + is_consensus_enabled=True, + ) + elif self.quality_modes == {QualityMode.Benchmark}: + self._set_quality_mode_attributes( + BENCHMARK_AUTO_AUDIT_NUMBER_OF_LABELS, + BENCHMARK_AUTO_AUDIT_PERCENTAGE, + is_benchmark_enabled=True, + ) + elif self.quality_modes == {QualityMode.Consensus}: + self._set_quality_mode_attributes( + number_of_labels=CONSENSUS_AUTO_AUDIT_NUMBER_OF_LABELS, + percentage=CONSENSUS_AUTO_AUDIT_PERCENTAGE, + is_consensus_enabled=True, + ) + + if self.data_row_count is not None and self.data_row_count < 0: + raise ValueError("data_row_count must be a positive integer.") + + return self + + def _set_quality_mode_attributes( + self, + number_of_labels, + percentage, + is_benchmark_enabled=False, + is_consensus_enabled=False, + ): + self.auto_audit_number_of_labels = number_of_labels + self.auto_audit_percentage = percentage + self.is_benchmark_enabled = is_benchmark_enabled + self.is_consensus_enabled = is_consensus_enabled diff --git a/libs/labelbox/src/labelbox/schema/__init__.py b/libs/labelbox/src/labelbox/schema/__init__.py index 03327e0d1..d6b74de68 100644 --- a/libs/labelbox/src/labelbox/schema/__init__.py +++ b/libs/labelbox/src/labelbox/schema/__init__.py @@ -1,5 +1,4 @@ import labelbox.schema.asset_attachment -import labelbox.schema.bulk_import_request import labelbox.schema.annotation_import import labelbox.schema.benchmark import labelbox.schema.data_row diff --git a/libs/labelbox/src/labelbox/schema/annotation_import.py b/libs/labelbox/src/labelbox/schema/annotation_import.py index df7f272a3..497ac899d 100644 --- a/libs/labelbox/src/labelbox/schema/annotation_import.py +++ 
b/libs/labelbox/src/labelbox/schema/annotation_import.py @@ -3,33 +3,34 @@ import logging import os import time +from collections import defaultdict from typing import ( + TYPE_CHECKING, Any, BinaryIO, Dict, List, Optional, Union, - TYPE_CHECKING, cast, ) -from collections import defaultdict -from google.api_core import retry -from labelbox import parser import requests +from google.api_core import retry +from lbox.exceptions import ApiLimitError, NetworkError, ResourceNotFoundError from tqdm import tqdm # type: ignore import labelbox +from labelbox import parser from labelbox.orm import query from labelbox.orm.db_object import DbObject from labelbox.orm.model import Field, Relationship -from labelbox.utils import is_exactly_one_set from labelbox.schema.confidence_presence_checker import ( LabelsConfidencePresenceChecker, ) from labelbox.schema.enums import AnnotationImportState from labelbox.schema.serialization import serialize_labels +from labelbox.utils import is_exactly_one_set if TYPE_CHECKING: from labelbox.types import Label @@ -140,9 +141,9 @@ def wait_until_done( @retry.Retry( predicate=retry.if_exception_type( - labelbox.exceptions.ApiLimitError, - labelbox.exceptions.TimeoutError, - labelbox.exceptions.NetworkError, + ApiLimitError, + TimeoutError, + NetworkError, ) ) def __backoff_refresh(self) -> None: @@ -435,9 +436,7 @@ def from_name( } response = client.execute(query_str, params) if response is None: - raise labelbox.exceptions.ResourceNotFoundError( - MEAPredictionImport, params - ) + raise ResourceNotFoundError(MEAPredictionImport, params) response = response["modelErrorAnalysisPredictionImport"] if as_json: return response @@ -560,9 +559,7 @@ def from_name( } response = client.execute(query_str, params) if response is None: - raise labelbox.exceptions.ResourceNotFoundError( - MALPredictionImport, params - ) + raise ResourceNotFoundError(MALPredictionImport, params) response = response["meaToMalPredictionImport"] if as_json: return response @@ 
-709,9 +706,7 @@ def from_name( } response = client.execute(query_str, params) if response is None: - raise labelbox.exceptions.ResourceNotFoundError( - MALPredictionImport, params - ) + raise ResourceNotFoundError(MALPredictionImport, params) response = response["modelAssistedLabelingPredictionImport"] if as_json: return response @@ -885,7 +880,7 @@ def from_name( } response = client.execute(query_str, params) if response is None: - raise labelbox.exceptions.ResourceNotFoundError(LabelImport, params) + raise ResourceNotFoundError(LabelImport, params) response = response["labelImport"] if as_json: return response diff --git a/libs/labelbox/src/labelbox/schema/asset_attachment.py b/libs/labelbox/src/labelbox/schema/asset_attachment.py index 0d5598c84..9a56dbb72 100644 --- a/libs/labelbox/src/labelbox/schema/asset_attachment.py +++ b/libs/labelbox/src/labelbox/schema/asset_attachment.py @@ -7,15 +7,6 @@ class AttachmentType(str, Enum): - @classmethod - def __missing__(cls, value: object): - if str(value) == "TEXT": - warnings.warn( - "The TEXT attachment type is deprecated. Use RAW_TEXT instead." - ) - return cls.RAW_TEXT - return value - VIDEO = "VIDEO" IMAGE = "IMAGE" IMAGE_OVERLAY = "IMAGE_OVERLAY" @@ -30,7 +21,7 @@ class AssetAttachment(DbObject): """Asset attachment provides extra context about an asset while labeling. Attributes: - attachment_type (str): IMAGE, VIDEO, IMAGE_OVERLAY, HTML, RAW_TEXT, TEXT_URL, or PDF_URL. TEXT attachment type is deprecated. + attachment_type (str): IMAGE, VIDEO, IMAGE_OVERLAY, HTML, RAW_TEXT, TEXT_URL, or PDF_URL. 
attachment_value (str): URL to an external file or a string of text attachment_name (str): The name of the attachment """ diff --git a/libs/labelbox/src/labelbox/schema/batch.py b/libs/labelbox/src/labelbox/schema/batch.py index 316732a67..99e4c4908 100644 --- a/libs/labelbox/src/labelbox/schema/batch.py +++ b/libs/labelbox/src/labelbox/schema/batch.py @@ -1,15 +1,11 @@ -from typing import Generator, TYPE_CHECKING +import logging +from typing import TYPE_CHECKING + +from lbox.exceptions import ResourceNotFoundError -from labelbox.orm.db_object import DbObject, experimental from labelbox.orm import query +from labelbox.orm.db_object import DbObject from labelbox.orm.model import Entity, Field, Relationship -from labelbox.exceptions import LabelboxError, ResourceNotFoundError -from io import StringIO -from labelbox import parser -import requests -import logging -import time -import warnings if TYPE_CHECKING: from labelbox import Project diff --git a/libs/labelbox/src/labelbox/schema/bulk_import_request.py b/libs/labelbox/src/labelbox/schema/bulk_import_request.py deleted file mode 100644 index 8e11f3261..000000000 --- a/libs/labelbox/src/labelbox/schema/bulk_import_request.py +++ /dev/null @@ -1,1004 +0,0 @@ -import json -import time -from uuid import UUID, uuid4 -import functools - -import logging -from pathlib import Path -from google.api_core import retry -from labelbox import parser -import requests -from pydantic import ( - ValidationError, - BaseModel, - Field, - field_validator, - model_validator, - ConfigDict, - StringConstraints, -) -from typing_extensions import Literal, Annotated -from typing import ( - Any, - List, - Optional, - BinaryIO, - Dict, - Iterable, - Tuple, - Union, - Type, - Set, - TYPE_CHECKING, -) - -from labelbox import exceptions as lb_exceptions -from labelbox import utils -from labelbox.orm import query -from labelbox.orm.db_object import DbObject -from labelbox.orm.model import Relationship -from labelbox.schema.enums import 
BulkImportRequestState -from labelbox.schema.serialization import serialize_labels -from labelbox.orm.model import Field as lb_Field - -if TYPE_CHECKING: - from labelbox import Project - from labelbox.types import Label - -NDJSON_MIME_TYPE = "application/x-ndjson" -logger = logging.getLogger(__name__) - -# TODO: Deprecate this library in place of labelimport and malprediction import library. - - -def _determinants(parent_cls: Any) -> List[str]: - return [ - k - for k, v in parent_cls.model_fields.items() - if v.json_schema_extra and "determinant" in v.json_schema_extra - ] - - -def _make_file_name(project_id: str, name: str) -> str: - return f"{project_id}__{name}.ndjson" - - -# TODO(gszpak): move it to client.py -def _make_request_data( - project_id: str, name: str, content_length: int, file_name: str -) -> dict: - query_str = """mutation createBulkImportRequestFromFilePyApi( - $projectId: ID!, $name: String!, $file: Upload!, $contentLength: Int!) { - createBulkImportRequest(data: { - projectId: $projectId, - name: $name, - filePayload: { - file: $file, - contentLength: $contentLength - } - }) { - %s - } - } - """ % query.results_query_part(BulkImportRequest) - variables = { - "projectId": project_id, - "name": name, - "file": None, - "contentLength": content_length, - } - operations = json.dumps({"variables": variables, "query": query_str}) - - return { - "operations": operations, - "map": (None, json.dumps({file_name: ["variables.file"]})), - } - - -def _send_create_file_command( - client, - request_data: dict, - file_name: str, - file_data: Tuple[str, Union[bytes, BinaryIO], str], -) -> dict: - response = client.execute(data=request_data, files={file_name: file_data}) - - if not response.get("createBulkImportRequest", None): - raise lb_exceptions.LabelboxError( - "Failed to create BulkImportRequest, message: %s" - % response.get("errors", None) - or response.get("error", None) - ) - - return response - - -class BulkImportRequest(DbObject): - """Represents the 
import job when importing annotations. - - Attributes: - name (str) - state (Enum): FAILED, RUNNING, or FINISHED (Refers to the whole import job) - input_file_url (str): URL to your web-hosted NDJSON file - error_file_url (str): NDJSON that contains error messages for failed annotations - status_file_url (str): NDJSON that contains status for each annotation - created_at (datetime): UTC timestamp for date BulkImportRequest was created - - project (Relationship): `ToOne` relationship to Project - created_by (Relationship): `ToOne` relationship to User - """ - - name = lb_Field.String("name") - state = lb_Field.Enum(BulkImportRequestState, "state") - input_file_url = lb_Field.String("input_file_url") - error_file_url = lb_Field.String("error_file_url") - status_file_url = lb_Field.String("status_file_url") - created_at = lb_Field.DateTime("created_at") - - project = Relationship.ToOne("Project") - created_by = Relationship.ToOne("User", False, "created_by") - - @property - def inputs(self) -> List[Dict[str, Any]]: - """ - Inputs for each individual annotation uploaded. - This should match the ndjson annotations that you have uploaded. - - Returns: - Uploaded ndjson. - - * This information will expire after 24 hours. - """ - return self._fetch_remote_ndjson(self.input_file_url) - - @property - def errors(self) -> List[Dict[str, Any]]: - """ - Errors for each individual annotation uploaded. This is a subset of statuses - - Returns: - List of dicts containing error messages. Empty list means there were no errors - See `BulkImportRequest.statuses` for more details. - - * This information will expire after 24 hours. - """ - self.wait_until_done() - return self._fetch_remote_ndjson(self.error_file_url) - - @property - def statuses(self) -> List[Dict[str, Any]]: - """ - Status for each individual annotation uploaded. - - Returns: - A status for each annotation if the upload is done running. - See below table for more details - - .. 
list-table:: - :widths: 15 150 - :header-rows: 1 - - * - Field - - Description - * - uuid - - Specifies the annotation for the status row. - * - dataRow - - JSON object containing the Labelbox data row ID for the annotation. - * - status - - Indicates SUCCESS or FAILURE. - * - errors - - An array of error messages included when status is FAILURE. Each error has a name, message and optional (key might not exist) additional_info. - - * This information will expire after 24 hours. - """ - self.wait_until_done() - return self._fetch_remote_ndjson(self.status_file_url) - - @functools.lru_cache() - def _fetch_remote_ndjson(self, url: str) -> List[Dict[str, Any]]: - """ - Fetches the remote ndjson file and caches the results. - - Args: - url (str): Can be any url pointing to an ndjson file. - Returns: - ndjson as a list of dicts. - """ - response = requests.get(url) - response.raise_for_status() - return parser.loads(response.text) - - def refresh(self) -> None: - """Synchronizes values of all fields with the database.""" - query_str, params = query.get_single(BulkImportRequest, self.uid) - res = self.client.execute(query_str, params) - res = res[utils.camel_case(BulkImportRequest.type_name())] - self._set_field_values(res) - - def wait_till_done(self, sleep_time_seconds: int = 5) -> None: - self.wait_until_done(sleep_time_seconds) - - def wait_until_done(self, sleep_time_seconds: int = 5) -> None: - """Blocks import job until certain conditions are met. - - Blocks until the BulkImportRequest.state changes either to - `BulkImportRequestState.FINISHED` or `BulkImportRequestState.FAILED`, - periodically refreshing object's state. 
- - Args: - sleep_time_seconds (str): a time to block between subsequent API calls - """ - while self.state == BulkImportRequestState.RUNNING: - logger.info(f"Sleeping for {sleep_time_seconds} seconds...") - time.sleep(sleep_time_seconds) - self.__exponential_backoff_refresh() - - @retry.Retry( - predicate=retry.if_exception_type( - lb_exceptions.ApiLimitError, - lb_exceptions.TimeoutError, - lb_exceptions.NetworkError, - ) - ) - def __exponential_backoff_refresh(self) -> None: - self.refresh() - - @classmethod - def from_name( - cls, client, project_id: str, name: str - ) -> "BulkImportRequest": - """Fetches existing BulkImportRequest. - - Args: - client (Client): a Labelbox client - project_id (str): BulkImportRequest's project id - name (str): name of BulkImportRequest - Returns: - BulkImportRequest object - - """ - query_str = """query getBulkImportRequestPyApi( - $projectId: ID!, $name: String!) { - bulkImportRequest(where: { - projectId: $projectId, - name: $name - }) { - %s - } - } - """ % query.results_query_part(cls) - params = {"projectId": project_id, "name": name} - response = client.execute(query_str, params=params) - return cls(client, response["bulkImportRequest"]) - - @classmethod - def create_from_url( - cls, client, project_id: str, name: str, url: str, validate=True - ) -> "BulkImportRequest": - """ - Creates a BulkImportRequest from a publicly accessible URL - to an ndjson file with predictions. - - Args: - client (Client): a Labelbox client - project_id (str): id of project for which predictions will be imported - name (str): name of BulkImportRequest - url (str): publicly accessible URL pointing to ndjson file containing predictions - validate (bool): a flag indicating if there should be a validation - if `url` is valid ndjson - Returns: - BulkImportRequest object - """ - if validate: - logger.warn( - "Validation is turned on. The file will be downloaded locally and processed before uploading." 
- ) - res = requests.get(url) - data = parser.loads(res.text) - _validate_ndjson(data, client.get_project(project_id)) - - query_str = """mutation createBulkImportRequestPyApi( - $projectId: ID!, $name: String!, $fileUrl: String!) { - createBulkImportRequest(data: { - projectId: $projectId, - name: $name, - fileUrl: $fileUrl - }) { - %s - } - } - """ % query.results_query_part(cls) - params = {"projectId": project_id, "name": name, "fileUrl": url} - bulk_import_request_response = client.execute(query_str, params=params) - return cls( - client, bulk_import_request_response["createBulkImportRequest"] - ) - - @classmethod - def create_from_objects( - cls, - client, - project_id: str, - name: str, - predictions: Union[Iterable[Dict], Iterable["Label"]], - validate=True, - ) -> "BulkImportRequest": - """ - Creates a `BulkImportRequest` from an iterable of dictionaries. - - Conforms to JSON predictions format, e.g.: - ``{ - "uuid": "9fd9a92e-2560-4e77-81d4-b2e955800092", - "schemaId": "ckappz7d700gn0zbocmqkwd9i", - "dataRow": { - "id": "ck1s02fqxm8fi0757f0e6qtdc" - }, - "bbox": { - "top": 48, - "left": 58, - "height": 865, - "width": 1512 - } - }`` - - Args: - client (Client): a Labelbox client - project_id (str): id of project for which predictions will be imported - name (str): name of BulkImportRequest - predictions (Iterable[dict]): iterable of dictionaries representing predictions - validate (bool): a flag indicating if there should be a validation - if `predictions` is valid ndjson - Returns: - BulkImportRequest object - """ - if not isinstance(predictions, list): - raise TypeError( - f"annotations must be in a form of Iterable. 
Found {type(predictions)}" - ) - ndjson_predictions = serialize_labels(predictions) - - if validate: - _validate_ndjson(ndjson_predictions, client.get_project(project_id)) - - data_str = parser.dumps(ndjson_predictions) - if not data_str: - raise ValueError("annotations cannot be empty") - - data = data_str.encode("utf-8") - file_name = _make_file_name(project_id, name) - request_data = _make_request_data( - project_id, name, len(data_str), file_name - ) - file_data = (file_name, data, NDJSON_MIME_TYPE) - response_data = _send_create_file_command( - client, - request_data=request_data, - file_name=file_name, - file_data=file_data, - ) - - return cls(client, response_data["createBulkImportRequest"]) - - @classmethod - def create_from_local_file( - cls, client, project_id: str, name: str, file: Path, validate_file=True - ) -> "BulkImportRequest": - """ - Creates a BulkImportRequest from a local ndjson file with predictions. - - Args: - client (Client): a Labelbox client - project_id (str): id of project for which predictions will be imported - name (str): name of BulkImportRequest - file (Path): local ndjson file with predictions - validate_file (bool): a flag indicating if there should be a validation - if `file` is a valid ndjson file - Returns: - BulkImportRequest object - - """ - file_name = _make_file_name(project_id, name) - content_length = file.stat().st_size - request_data = _make_request_data( - project_id, name, content_length, file_name - ) - - with file.open("rb") as f: - if validate_file: - reader = parser.reader(f) - # ensure that the underlying json load call is valid - # https://github.com/rhgrant10/ndjson/blob/ff2f03c56b21f28f7271b27da35ca4a8bf9a05d0/ndjson/api.py#L53 - # by iterating through the file so we only store - # each line in memory rather than the entire file - try: - _validate_ndjson(reader, client.get_project(project_id)) - except ValueError: - raise ValueError(f"{file} is not a valid ndjson file") - else: - f.seek(0) - file_data = 
(file.name, f, NDJSON_MIME_TYPE) - response_data = _send_create_file_command( - client, request_data, file_name, file_data - ) - return cls(client, response_data["createBulkImportRequest"]) - - def delete(self) -> None: - """Deletes the import job and also any annotations created by this import. - - Returns: - None - """ - id_param = "bulk_request_id" - query_str = """mutation deleteBulkImportRequestPyApi($%s: ID!) { - deleteBulkImportRequest(where: {id: $%s}) { - id - name - } - }""" % (id_param, id_param) - self.client.execute(query_str, {id_param: self.uid}) - - -def _validate_ndjson( - lines: Iterable[Dict[str, Any]], project: "Project" -) -> None: - """ - Client side validation of an ndjson object. - - Does not guarentee that an upload will succeed for the following reasons: - * We are not checking the data row types which will cause the following errors to slip through - * Missing frame indices will not causes an error for videos - * Uploaded annotations for the wrong data type will pass (Eg. entity on images) - * We are not checking bounds of an asset (Eg. frame index, image height, text location) - - Args: - lines (Iterable[Dict[str,Any]]): An iterable of ndjson lines - project (Project): id of project for which predictions will be imported - - Raises: - MALValidationError: Raise for invalid NDJson - UuidError: Duplicate UUID in upload - """ - feature_schemas_by_id, feature_schemas_by_name = get_mal_schemas( - project.ontology() - ) - uids: Set[str] = set() - for idx, line in enumerate(lines): - try: - annotation = NDAnnotation(**line) - annotation.validate_instance( - feature_schemas_by_id, feature_schemas_by_name - ) - uuid = str(annotation.uuid) - if uuid in uids: - raise lb_exceptions.UuidError( - f"{uuid} already used in this import job, " - "must be unique for the project." 
- ) - uids.add(uuid) - except (ValidationError, ValueError, TypeError, KeyError) as e: - raise lb_exceptions.MALValidationError( - f"Invalid NDJson on line {idx}" - ) from e - - -# The rest of this file contains objects for MAL validation -def parse_classification(tool): - """ - Parses a classification from an ontology. Only radio, checklist, and text are supported for mal - - Args: - tool (dict) - - Returns: - dict - """ - if tool["type"] in ["radio", "checklist"]: - option_schema_ids = [r["featureSchemaId"] for r in tool["options"]] - option_names = [r["value"] for r in tool["options"]] - return { - "tool": tool["type"], - "featureSchemaId": tool["featureSchemaId"], - "name": tool["name"], - "options": [*option_schema_ids, *option_names], - } - elif tool["type"] == "text": - return { - "tool": tool["type"], - "name": tool["name"], - "featureSchemaId": tool["featureSchemaId"], - } - - -def get_mal_schemas(ontology): - """ - Converts a project ontology to a dict for easier lookup during ndjson validation - - Args: - ontology (Ontology) - Returns: - Dict, Dict : Useful for looking up a tool from a given feature schema id or name - """ - - valid_feature_schemas_by_schema_id = {} - valid_feature_schemas_by_name = {} - for tool in ontology.normalized["tools"]: - classifications = [ - parse_classification(classification_tool) - for classification_tool in tool["classifications"] - ] - classifications_by_schema_id = { - v["featureSchemaId"]: v for v in classifications - } - classifications_by_name = {v["name"]: v for v in classifications} - valid_feature_schemas_by_schema_id[tool["featureSchemaId"]] = { - "tool": tool["tool"], - "classificationsBySchemaId": classifications_by_schema_id, - "classificationsByName": classifications_by_name, - "name": tool["name"], - } - valid_feature_schemas_by_name[tool["name"]] = { - "tool": tool["tool"], - "classificationsBySchemaId": classifications_by_schema_id, - "classificationsByName": classifications_by_name, - "name": tool["name"], 
- } - for tool in ontology.normalized["classifications"]: - valid_feature_schemas_by_schema_id[tool["featureSchemaId"]] = ( - parse_classification(tool) - ) - valid_feature_schemas_by_name[tool["name"]] = parse_classification(tool) - return valid_feature_schemas_by_schema_id, valid_feature_schemas_by_name - - -class Bbox(BaseModel): - top: float - left: float - height: float - width: float - - -class Point(BaseModel): - x: float - y: float - - -class FrameLocation(BaseModel): - end: int - start: int - - -class VideoSupported(BaseModel): - # Note that frames are only allowed as top level inferences for video - frames: Optional[List[FrameLocation]] = None - - -# Base class for a special kind of union. -class SpecialUnion: - def __new__(cls, **kwargs): - return cls.build(kwargs) - - @classmethod - def __get_validators__(cls): - yield cls.build - - @classmethod - def get_union_types(cls): - if not issubclass(cls, SpecialUnion): - raise TypeError("{} must be a subclass of SpecialUnion") - - union_types = [x for x in cls.__orig_bases__ if hasattr(x, "__args__")] - if len(union_types) < 1: - raise TypeError( - "Class {cls} should inherit from a union of objects to build" - ) - if len(union_types) > 1: - raise TypeError( - f"Class {cls} should inherit from exactly one union of objects to build. Found {union_types}" - ) - return union_types[0].__args__[0].__args__ - - @classmethod - def build(cls: Any, data: Union[dict, BaseModel]) -> "NDBase": - """ - Checks through all objects in the union to see which matches the input data. 
- Args: - data (Union[dict, BaseModel]) : The data for constructing one of the objects in the union - raises: - KeyError: data does not contain the determinant fields for any of the types supported by this SpecialUnion - ValidationError: Error while trying to construct a specific object in the union - - """ - if isinstance(data, BaseModel): - data = data.model_dump() - - top_level_fields = [] - max_match = 0 - matched = None - - for type_ in cls.get_union_types(): - determinate_fields = _determinants(type_) - top_level_fields.append(determinate_fields) - matches = sum([val in determinate_fields for val in data]) - if matches == len(determinate_fields) and matches > max_match: - max_match = matches - matched = type_ - - if matched is not None: - # These two have the exact same top level keys - if matched in [NDRadio, NDText]: - if isinstance(data["answer"], dict): - matched = NDRadio - elif isinstance(data["answer"], str): - matched = NDText - else: - raise TypeError( - f"Unexpected type for answer field. Found {data['answer']}. Expected a string or a dict" - ) - return matched(**data) - else: - raise KeyError( - f"Invalid annotation. Must have one of the following keys : {top_level_fields}. Found {data}." 
- ) - - @classmethod - def schema(cls): - results = {"definitions": {}} - for cl in cls.get_union_types(): - schema = cl.schema() - results["definitions"].update(schema.pop("definitions")) - results[cl.__name__] = schema - return results - - -class DataRow(BaseModel): - id: str - - -class NDFeatureSchema(BaseModel): - schemaId: Optional[str] = None - name: Optional[str] = None - - @model_validator(mode="after") - def most_set_one(self): - if self.schemaId is None and self.name is None: - raise ValueError( - "Must set either schemaId or name for all feature schemas" - ) - return self - - -class NDBase(NDFeatureSchema): - ontology_type: str - uuid: UUID - dataRow: DataRow - model_config = ConfigDict(extra="forbid") - - def validate_feature_schemas( - self, valid_feature_schemas_by_id, valid_feature_schemas_by_name - ): - if self.name: - if self.name not in valid_feature_schemas_by_name: - raise ValueError( - f"Name {self.name} is not valid for the provided project's ontology." - ) - - if ( - self.ontology_type - != valid_feature_schemas_by_name[self.name]["tool"] - ): - raise ValueError( - f"Name {self.name} does not map to the assigned tool {valid_feature_schemas_by_name[self.name]['tool']}" - ) - - if self.schemaId: - if self.schemaId not in valid_feature_schemas_by_id: - raise ValueError( - f"Schema id {self.schemaId} is not valid for the provided project's ontology." 
- ) - - if ( - self.ontology_type - != valid_feature_schemas_by_id[self.schemaId]["tool"] - ): - raise ValueError( - f"Schema id {self.schemaId} does not map to the assigned tool {valid_feature_schemas_by_id[self.schemaId]['tool']}" - ) - - def validate_instance( - self, valid_feature_schemas_by_id, valid_feature_schemas_by_name - ): - self.validate_feature_schemas( - valid_feature_schemas_by_id, valid_feature_schemas_by_name - ) - - -###### Classifications ###### - - -class NDText(NDBase): - ontology_type: Literal["text"] = "text" - answer: str = Field(json_schema_extra={"determinant": True}) - # No feature schema to check - - -class NDChecklist(VideoSupported, NDBase): - ontology_type: Literal["checklist"] = "checklist" - answers: List[NDFeatureSchema] = Field( - json_schema_extra={"determinant": True} - ) - - @field_validator("answers", mode="before") - def validate_answers(cls, value, field): - # constr not working with mypy. - if not len(value): - raise ValueError("Checklist answers should not be empty") - return value - - def validate_feature_schemas( - self, valid_feature_schemas_by_id, valid_feature_schemas_by_name - ): - # Test top level feature schema for this tool - super(NDChecklist, self).validate_feature_schemas( - valid_feature_schemas_by_id, valid_feature_schemas_by_name - ) - # Test the feature schemas provided to the answer field - if len( - set([answer.name or answer.schemaId for answer in self.answers]) - ) != len(self.answers): - raise ValueError( - f"Duplicated featureSchema found for checklist {self.uuid}" - ) - for answer in self.answers: - options = ( - valid_feature_schemas_by_name[self.name]["options"] - if self.name - else valid_feature_schemas_by_id[self.schemaId]["options"] - ) - if answer.name not in options and answer.schemaId not in options: - raise ValueError( - f"Feature schema provided to {self.ontology_type} invalid. Expected on of {options}. 
Found {answer}" - ) - - -class NDRadio(VideoSupported, NDBase): - ontology_type: Literal["radio"] = "radio" - answer: NDFeatureSchema = Field(json_schema_extra={"determinant": True}) - - def validate_feature_schemas( - self, valid_feature_schemas_by_id, valid_feature_schemas_by_name - ): - super(NDRadio, self).validate_feature_schemas( - valid_feature_schemas_by_id, valid_feature_schemas_by_name - ) - options = ( - valid_feature_schemas_by_name[self.name]["options"] - if self.name - else valid_feature_schemas_by_id[self.schemaId]["options"] - ) - if ( - self.answer.name not in options - and self.answer.schemaId not in options - ): - raise ValueError( - f"Feature schema provided to {self.ontology_type} invalid. Expected on of {options}. Found {self.answer.name or self.answer.schemaId}" - ) - - -# A union with custom construction logic to improve error messages -class NDClassification( - SpecialUnion, - Type[Union[NDText, NDRadio, NDChecklist]], # type: ignore -): ... - - -###### Tools ###### - - -class NDBaseTool(NDBase): - classifications: List[NDClassification] = [] - - # This is indepdent of our problem - def validate_feature_schemas( - self, valid_feature_schemas_by_id, valid_feature_schemas_by_name - ): - super(NDBaseTool, self).validate_feature_schemas( - valid_feature_schemas_by_id, valid_feature_schemas_by_name - ) - for classification in self.classifications: - classification.validate_feature_schemas( - valid_feature_schemas_by_name[self.name][ - "classificationsBySchemaId" - ] - if self.name - else valid_feature_schemas_by_id[self.schemaId][ - "classificationsBySchemaId" - ], - valid_feature_schemas_by_name[self.name][ - "classificationsByName" - ] - if self.name - else valid_feature_schemas_by_id[self.schemaId][ - "classificationsByName" - ], - ) - - @field_validator("classifications", mode="before") - def validate_subclasses(cls, value, field): - # Create uuid and datarow id so we don't have to define classification objects twice - # This is caused by 
the fact that we require these ids for top level classifications but not for subclasses - results = [] - dummy_id = "child".center(25, "_") - for row in value: - results.append( - {**row, "dataRow": {"id": dummy_id}, "uuid": str(uuid4())} - ) - return results - - -class NDPolygon(NDBaseTool): - ontology_type: Literal["polygon"] = "polygon" - polygon: List[Point] = Field(json_schema_extra={"determinant": True}) - - @field_validator("polygon") - def is_geom_valid(cls, v): - if len(v) < 3: - raise ValueError( - f"A polygon must have at least 3 points to be valid. Found {v}" - ) - return v - - -class NDPolyline(NDBaseTool): - ontology_type: Literal["line"] = "line" - line: List[Point] = Field(json_schema_extra={"determinant": True}) - - @field_validator("line") - def is_geom_valid(cls, v): - if len(v) < 2: - raise ValueError( - f"A line must have at least 2 points to be valid. Found {v}" - ) - return v - - -class NDRectangle(NDBaseTool): - ontology_type: Literal["rectangle"] = "rectangle" - bbox: Bbox = Field(json_schema_extra={"determinant": True}) - # Could check if points are positive - - -class NDPoint(NDBaseTool): - ontology_type: Literal["point"] = "point" - point: Point = Field(json_schema_extra={"determinant": True}) - # Could check if points are positive - - -class EntityLocation(BaseModel): - start: int - end: int - - -class NDTextEntity(NDBaseTool): - ontology_type: Literal["named-entity"] = "named-entity" - location: EntityLocation = Field(json_schema_extra={"determinant": True}) - - @field_validator("location") - def is_valid_location(cls, v): - if isinstance(v, BaseModel): - v = v.model_dump() - - if len(v) < 2: - raise ValueError( - f"A line must have at least 2 points to be valid. Found {v}" - ) - if v["start"] < 0: - raise ValueError(f"Text location must be positive. Found {v}") - if v["start"] > v["end"]: - raise ValueError( - f"Text start location must be less or equal than end. 
Found {v}" - ) - return v - - -class RLEMaskFeatures(BaseModel): - counts: List[int] - size: List[int] - - @field_validator("counts") - def validate_counts(cls, counts): - if not all([count >= 0 for count in counts]): - raise ValueError( - "Found negative value for counts. They should all be zero or positive" - ) - return counts - - @field_validator("size") - def validate_size(cls, size): - if len(size) != 2: - raise ValueError( - f"Mask `size` should have two ints representing height and with. Found : {size}" - ) - if not all([count > 0 for count in size]): - raise ValueError( - f"Mask `size` should be a postitive int. Found : {size}" - ) - return size - - -class PNGMaskFeatures(BaseModel): - # base64 encoded png bytes - png: str - - -class URIMaskFeatures(BaseModel): - instanceURI: str - colorRGB: Union[List[int], Tuple[int, int, int]] - - @field_validator("colorRGB") - def validate_color(cls, colorRGB): - # Does the dtype matter? Can it be a float? - if not isinstance(colorRGB, (tuple, list)): - raise ValueError( - f"Received color that is not a list or tuple. Found : {colorRGB}" - ) - elif len(colorRGB) != 3: - raise ValueError( - f"Must provide RGB values for segmentation colors. Found : {colorRGB}" - ) - elif not all([0 <= color <= 255 for color in colorRGB]): - raise ValueError( - f"All rgb colors must be between 0 and 255. Found : {colorRGB}" - ) - return colorRGB - - -class NDMask(NDBaseTool): - ontology_type: Literal["superpixel"] = "superpixel" - mask: Union[URIMaskFeatures, PNGMaskFeatures, RLEMaskFeatures] = Field( - json_schema_extra={"determinant": True} - ) - - -# A union with custom construction logic to improve error messages -class NDTool( - SpecialUnion, - Type[ # type: ignore - Union[ - NDMask, - NDTextEntity, - NDPoint, - NDRectangle, - NDPolyline, - NDPolygon, - ] - ], -): ... 
- - -class NDAnnotation( - SpecialUnion, - Type[Union[NDTool, NDClassification]], # type: ignore -): - @classmethod - def build(cls: Any, data) -> "NDBase": - if not isinstance(data, dict): - raise ValueError("value must be dict") - errors = [] - for cl in cls.get_union_types(): - try: - return cl(**data) - except KeyError as e: - errors.append(f"{cl.__name__}: {e}") - - raise ValueError( - "Unable to construct any annotation.\n{}".format("\n".join(errors)) - ) - - @classmethod - def schema(cls): - data = {"definitions": {}} - for type_ in cls.get_union_types(): - schema_ = type_.schema() - data["definitions"].update(schema_.pop("definitions")) - data[type_.__name__] = schema_ - return data diff --git a/libs/labelbox/src/labelbox/schema/catalog.py b/libs/labelbox/src/labelbox/schema/catalog.py index 567bbd777..8d9646779 100644 --- a/libs/labelbox/src/labelbox/schema/catalog.py +++ b/libs/labelbox/src/labelbox/schema/catalog.py @@ -1,4 +1,5 @@ from typing import Any, Dict, List, Optional, Tuple, Union +import warnings from labelbox.orm.db_object import experimental from labelbox.schema.export_filters import CatalogExportFilters, build_filters @@ -45,6 +46,13 @@ def export_v2( >>> task.wait_till_done() >>> task.result """ + + warnings.warn( + "You are currently utilizing export_v2 for this action, which will be removed in 7.0. Please refer to our docs for export alternatives. 
https://docs.labelbox.com/reference/export-overview#export-methods", + DeprecationWarning, + stacklevel=2, + ) + task, is_streamable = self._export(task_name, filters, params) if is_streamable: return ExportTask(task, True) diff --git a/libs/labelbox/src/labelbox/schema/data_row.py b/libs/labelbox/src/labelbox/schema/data_row.py index 8987a00f0..cb0e99b22 100644 --- a/libs/labelbox/src/labelbox/schema/data_row.py +++ b/libs/labelbox/src/labelbox/schema/data_row.py @@ -2,6 +2,7 @@ from enum import Enum from typing import TYPE_CHECKING, List, Optional, Tuple, Union, Any import json +import warnings from labelbox.orm import query from labelbox.orm.db_object import ( @@ -277,6 +278,13 @@ def export_v2( >>> task.wait_till_done() >>> task.result """ + + warnings.warn( + "You are currently utilizing export_v2 for this action, which will be removed in 7.0. Please refer to our docs for export alternatives. https://docs.labelbox.com/reference/export-overview#export-methods", + DeprecationWarning, + stacklevel=2, + ) + task, is_streamable = DataRow._export( client, data_rows, global_keys, task_name, params ) diff --git a/libs/labelbox/src/labelbox/schema/data_row_metadata.py b/libs/labelbox/src/labelbox/schema/data_row_metadata.py index 288459a89..cb45ef57f 100644 --- a/libs/labelbox/src/labelbox/schema/data_row_metadata.py +++ b/libs/labelbox/src/labelbox/schema/data_row_metadata.py @@ -1,34 +1,35 @@ # type: ignore -from datetime import datetime +import warnings from copy import deepcopy +from datetime import datetime from enum import Enum from itertools import chain -import warnings - from typing import ( + Annotated, + Any, + Callable, + Dict, + Generator, List, Optional, - Dict, - Union, - Callable, Type, - Any, - Generator, + Union, overload, ) -from typing_extensions import Annotated -from labelbox.schema.identifiables import DataRowIdentifiers, UniqueIds -from labelbox.schema.identifiable import UniqueId, GlobalKey from pydantic import ( BaseModel, + BeforeValidator, + 
ConfigDict, Field, StringConstraints, conlist, - ConfigDict, model_serializer, ) +from typing_extensions import Annotated +from labelbox.schema.identifiable import GlobalKey, UniqueId +from labelbox.schema.identifiables import DataRowIdentifiers, UniqueIds from labelbox.schema.ontology import SchemaId from labelbox.utils import ( _CamelCaseMixin, @@ -36,6 +37,12 @@ format_iso_from_string, ) +Name = Annotated[ + str, + BeforeValidator(lambda x: str.strip(str(x))), + Field(min_length=1, max_length=100), +] + class DataRowMetadataKind(Enum): number = "CustomMetadataNumber" @@ -49,7 +56,7 @@ class DataRowMetadataKind(Enum): # Metadata schema class DataRowMetadataSchema(BaseModel): uid: SchemaId - name: str = Field(strip_whitespace=True, min_length=1, max_length=100) + name: Name reserved: bool kind: DataRowMetadataKind options: Optional[List["DataRowMetadataSchema"]] = None @@ -417,7 +424,7 @@ def update_enum_option( schema = self._validate_custom_schema_by_name(name) if schema.kind != DataRowMetadataKind.enum: raise ValueError( - f"Updating Enum option is only supported for Enum metadata schema" + "Updating Enum option is only supported for Enum metadata schema" ) valid_options: List[str] = [o.name for o in schema.options] @@ -666,10 +673,8 @@ def bulk_delete( if not len(deletes): raise ValueError("The 'deletes' list cannot be empty.") - passed_strings = False for i, delete in enumerate(deletes): if isinstance(delete.data_row_id, str): - passed_strings = True deletes[i] = DeleteDataRowMetadata( data_row_id=UniqueId(delete.data_row_id), fields=delete.fields, @@ -683,12 +688,6 @@ def bulk_delete( f"Invalid data row identifier type '{type(delete.data_row_id)}' for '{delete.data_row_id}'" ) - if passed_strings: - warnings.warn( - "Using string for data row id will be deprecated. Please use " - "UniqueId instead." 
- ) - def _batch_delete( deletes: List[_DeleteBatchDataRowMetadata], ) -> List[DataRowMetadataBatchResponse]: @@ -751,10 +750,6 @@ def bulk_export(self, data_row_ids) -> List[DataRowMetadata]: and isinstance(data_row_ids[0], str) ): data_row_ids = UniqueIds(data_row_ids) - warnings.warn( - "Using data row ids will be deprecated. Please use " - "UniqueIds or GlobalKeys instead." - ) def _bulk_export( _data_row_ids: DataRowIdentifiers, @@ -803,13 +798,13 @@ def _convert_metadata_field(metadata_field): if isinstance(metadata_field, DataRowMetadataField): return metadata_field elif isinstance(metadata_field, dict): - if not "value" in metadata_field: + if "value" not in metadata_field: raise ValueError( f"Custom metadata field '{metadata_field}' must have a 'value' key" ) if ( - not "schema_id" in metadata_field - and not "name" in metadata_field + "schema_id" not in metadata_field + and "name" not in metadata_field ): raise ValueError( f"Custom metadata field '{metadata_field}' must have either 'schema_id' or 'name' key" @@ -954,9 +949,8 @@ def _validate_custom_schema_by_name( def _batch_items(iterable: List[Any], size: int) -> Generator[Any, None, None]: - l = len(iterable) - for ndx in range(0, l, size): - yield iterable[ndx : min(ndx + size, l)] + for ndx in range(0, len(iterable), size): + yield iterable[ndx : min(ndx + size, len(iterable))] def _batch_operations( diff --git a/libs/labelbox/src/labelbox/schema/dataset.py b/libs/labelbox/src/labelbox/schema/dataset.py index 16c993dfa..107f3f50b 100644 --- a/libs/labelbox/src/labelbox/schema/dataset.py +++ b/libs/labelbox/src/labelbox/schema/dataset.py @@ -7,13 +7,14 @@ from string import Template from typing import Any, Dict, List, Optional, Tuple, Union -import labelbox.schema.internal.data_row_uploader as data_row_uploader -from labelbox.exceptions import ( +from lbox.exceptions import ( InvalidQueryError, LabelboxError, ResourceCreationError, ResourceNotFoundError, -) +) # type: ignore + +import 
labelbox.schema.internal.data_row_uploader as data_row_uploader from labelbox.orm import query from labelbox.orm.comparison import Comparison from labelbox.orm.db_object import DbObject, Deletable, Updateable @@ -165,49 +166,9 @@ def create_data_row(self, items=None, **kwargs) -> "DataRow": return self.client.get_data_row(res[0]["id"]) - def create_data_rows_sync( - self, items, file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT - ) -> None: - """Synchronously bulk upload data rows. - - Use this instead of `Dataset.create_data_rows` for smaller batches of data rows that need to be uploaded quickly. - Cannot use this for uploads containing more than 1000 data rows. - Each data row is also limited to 5 attachments. - - Args: - items (iterable of (dict or str)): - See the docstring for `Dataset._create_descriptor_file` for more information. - Returns: - None. If the function doesn't raise an exception then the import was successful. - - Raises: - ResourceCreationError: If the `items` parameter does not conform to - the specification in Dataset._create_descriptor_file or if the server did not accept the - DataRow creation request (unknown reason). - InvalidAttributeError: If there are fields in `items` not valid for - a DataRow. - ValueError: When the upload parameters are invalid - """ - warnings.warn( - "This method is deprecated and will be " - "removed in a future release. Please use create_data_rows instead." - ) - - self._create_data_rows_sync( - items, file_upload_thread_count=file_upload_thread_count - ) - - return None # Return None if no exception is raised - def _create_data_rows_sync( self, items, file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT ) -> "DataUpsertTask": - max_data_rows_supported = 1000 - if len(items) > max_data_rows_supported: - raise ValueError( - f"Dataset.create_data_rows_sync() supports a max of {max_data_rows_supported} data rows." 
- " For larger imports use the async function Dataset.create_data_rows()" - ) if file_upload_thread_count < 1: raise ValueError( "file_upload_thread_count must be a positive integer" @@ -234,8 +195,6 @@ def create_data_rows( ) -> "DataUpsertTask": """Asynchronously bulk upload data rows - Use this instead of `Dataset.create_data_rows_sync` uploads for batches that contain more than 1000 data rows. - Args: items (iterable of (dict or str)) @@ -310,7 +269,7 @@ def data_rows_for_external_id( A list of `DataRow` with the given ID. Raises: - labelbox.exceptions.ResourceNotFoundError: If there is no `DataRow` + lbox.exceptions.ResourceNotFoundError: If there is no `DataRow` in this `DataSet` with the given external ID, or if there are multiple `DataRows` for it. """ @@ -336,7 +295,7 @@ def data_row_for_external_id(self, external_id) -> "DataRow": A single `DataRow` with the given ID. Raises: - labelbox.exceptions.ResourceNotFoundError: If there is no `DataRow` + lbox.exceptions.ResourceNotFoundError: If there is no `DataRow` in this `DataSet` with the given external ID, or if there are multiple `DataRows` for it. """ @@ -399,6 +358,13 @@ def export_v2( >>> task.wait_till_done() >>> task.result """ + + warnings.warn( + "You are currently utilizing export_v2 for this action, which will be removed in 7.0. Please refer to our docs for export alternatives. https://docs.labelbox.com/reference/export-overview#export-methods", + DeprecationWarning, + stacklevel=2, + ) + task, is_streamable = self._export(task_name, filters, params) if is_streamable: return ExportTask(task, True) diff --git a/libs/labelbox/src/labelbox/schema/enums.py b/libs/labelbox/src/labelbox/schema/enums.py index 6f8aebc58..dfc87c8a4 100644 --- a/libs/labelbox/src/labelbox/schema/enums.py +++ b/libs/labelbox/src/labelbox/schema/enums.py @@ -1,31 +1,6 @@ from enum import Enum -class BulkImportRequestState(Enum): - """State of the import job when importing annotations (RUNNING, FAILED, or FINISHED). 
- - If you are not usinig MEA continue using BulkImportRequest. - AnnotationImports are in beta and will change soon. - - .. list-table:: - :widths: 15 150 - :header-rows: 1 - - * - State - - Description - * - RUNNING - - Indicates that the import job is not done yet. - * - FAILED - - Indicates the import job failed. Check `BulkImportRequest.errors` for more information - * - FINISHED - - Indicates the import job is no longer running. Check `BulkImportRequest.statuses` for more information - """ - - RUNNING = "RUNNING" - FAILED = "FAILED" - FINISHED = "FINISHED" - - class AnnotationImportState(Enum): """State of the import job when importing annotations (RUNNING, FAILED, or FINISHED). diff --git a/libs/labelbox/src/labelbox/schema/export_filters.py b/libs/labelbox/src/labelbox/schema/export_filters.py index 641adc011..b0d0284c4 100644 --- a/libs/labelbox/src/labelbox/schema/export_filters.py +++ b/libs/labelbox/src/labelbox/schema/export_filters.py @@ -1,13 +1,5 @@ -import sys - from datetime import datetime, timezone -from typing import Collection, Dict, Tuple, List, Optional -from labelbox.typing_imports import Literal - -if sys.version_info >= (3, 8): - from typing import TypedDict -else: - from typing_extensions import TypedDict +from typing import Collection, Dict, List, Literal, Optional, Tuple, TypedDict SEARCH_LIMIT_PER_EXPORT_V2 = 2_000 ISO_8061_FORMAT = "%Y-%m-%dT%H:%M:%S%z" diff --git a/libs/labelbox/src/labelbox/schema/export_params.py b/libs/labelbox/src/labelbox/schema/export_params.py index b15bc2828..f921cdd31 100644 --- a/libs/labelbox/src/labelbox/schema/export_params.py +++ b/libs/labelbox/src/labelbox/schema/export_params.py @@ -1,15 +1,8 @@ -import sys - -from typing import Optional, List - -EXPORT_LIMIT = 30 +from typing import List, Optional, TypedDict from labelbox.schema.media_type import MediaType -if sys.version_info >= (3, 8): - from typing import TypedDict -else: - from typing_extensions import TypedDict +EXPORT_LIMIT = 30 class 
DataRowParams(TypedDict): diff --git a/libs/labelbox/src/labelbox/schema/export_task.py b/libs/labelbox/src/labelbox/schema/export_task.py index 76fd8a739..7e78fc3e9 100644 --- a/libs/labelbox/src/labelbox/schema/export_task.py +++ b/libs/labelbox/src/labelbox/schema/export_task.py @@ -6,20 +6,16 @@ from dataclasses import dataclass from enum import Enum from functools import lru_cache -from io import TextIOWrapper -from pathlib import Path from typing import ( TYPE_CHECKING, Any, Callable, Generic, Iterator, - List, Optional, Tuple, TypeVar, Union, - overload, ) import requests @@ -100,122 +96,6 @@ def convert(self, input_args: ConverterInputArgs) -> Iterator[OutputT]: """ -@dataclass -class JsonConverterOutput: - """Output with the JSON string.""" - - current_offset: int - current_line: int - json_str: str - - -class JsonConverter(Converter[JsonConverterOutput]): # pylint: disable=too-few-public-methods - """Converts JSON data. - - Deprecated: This converter is deprecated and will be removed in a future release. 
- """ - - def __init__(self) -> None: - warnings.warn( - "JSON converter is deprecated and will be removed in a future release" - ) - super().__init__() - - def _find_json_object_offsets(self, data: str) -> List[Tuple[int, int]]: - object_offsets: List[Tuple[int, int]] = [] - stack = [] - current_object_start = None - - for index, char in enumerate(data): - if char == "{": - stack.append(char) - if len(stack) == 1: - current_object_start = index - # we need to account for scenarios where data lands in the middle of an object - # and the object is not the last one in the data - if ( - index > 0 - and data[index - 1] == "\n" - and not object_offsets - ): - object_offsets.append((0, index - 1)) - elif char == "}" and stack: - stack.pop() - # this covers cases where the last object is either followed by a newline or - # it is missing - if ( - len(stack) == 0 - and (len(data) == index + 1 or data[index + 1] == "\n") - and current_object_start is not None - ): - object_offsets.append((current_object_start, index + 1)) - current_object_start = None - - # we also need to account for scenarios where data lands in the middle of the last object - return object_offsets if object_offsets else [(0, len(data) - 1)] - - def convert( - self, input_args: Converter.ConverterInputArgs - ) -> Iterator[JsonConverterOutput]: - current_offset, current_line, raw_data = ( - input_args.file_info.offsets.start, - input_args.file_info.lines.start, - input_args.raw_data, - ) - offsets = self._find_json_object_offsets(raw_data) - for line, (offset_start, offset_end) in enumerate(offsets): - yield JsonConverterOutput( - current_offset=current_offset + offset_start, - current_line=current_line + line, - json_str=raw_data[offset_start : offset_end + 1].strip(), - ) - - -@dataclass -class FileConverterOutput: - """Output with statistics about the written file.""" - - file_path: Path - total_size: int - total_lines: int - current_offset: int - current_line: int - bytes_written: int - - -class 
FileConverter(Converter[FileConverterOutput]): - """Converts data to a file.""" - - def __init__(self, file_path: str) -> None: - super().__init__() - self._file: Optional[TextIOWrapper] = None - self._file_path = file_path - - def __enter__(self): - self._file = open(self._file_path, "w", encoding="utf-8") - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - if self._file: - self._file.close() - return False - - def convert( - self, input_args: Converter.ConverterInputArgs - ) -> Iterator[FileConverterOutput]: - # appends data to the file - assert self._file is not None - self._file.write(input_args.raw_data) - yield FileConverterOutput( - file_path=Path(self._file_path), - total_size=input_args.ctx.metadata_header.total_size, - total_lines=input_args.ctx.metadata_header.total_lines, - current_offset=input_args.file_info.offsets.start, - current_line=input_args.file_info.lines.start, - bytes_written=len(input_args.raw_data), - ) - - class FileRetrieverStrategy(ABC): # pylint: disable=too-few-public-methods """Abstract class for retrieving files.""" @@ -252,147 +132,6 @@ def _get_file_content( return file_info, response.text -class FileRetrieverByOffset(FileRetrieverStrategy): # pylint: disable=too-few-public-methods - """Retrieves files by offset.""" - - def __init__( - self, - ctx: _TaskContext, - offset: int, - ) -> None: - super().__init__(ctx) - self._current_offset = offset - self._current_line: Optional[int] = None - if self._current_offset >= self._ctx.metadata_header.total_size: - raise ValueError( - f"offset is out of range, max offset is {self._ctx.metadata_header.total_size - 1}" - ) - - def _find_line_at_offset( - self, file_content: str, target_offset: int - ) -> int: - # TODO: Remove this, incorrect parsing of JSON to find braces - stack = [] - line_number = 0 - - for index, char in enumerate(file_content): - if char == "{": - stack.append(char) - if len(stack) == 1 and index > 0: - line_number += 1 - elif char == "}" and stack: - 
stack.pop() - - if index == target_offset: - break - - return line_number - - def get_next_chunk(self) -> Optional[Tuple[_MetadataFileInfo, str]]: - if self._current_offset >= self._ctx.metadata_header.total_size: - return None - query = ( - f"query GetExportFileFromOffsetPyApi" - f"($where: WhereUniqueIdInput, $streamType: TaskStreamType!, $offset: UInt64!)" - f"{{task(where: $where)" - f"{{{'exportFileFromOffset'}(streamType: $streamType, offset: $offset)" - f"{{offsets {{start end}} lines {{start end}} file}}" - f"}}}}" - ) - variables = { - "where": {"id": self._ctx.task_id}, - "streamType": self._ctx.stream_type.value, - "offset": str(self._current_offset), - } - file_info, file_content = self._get_file_content( - query, variables, "exportFileFromOffset" - ) - if self._current_line is None: - self._current_line = self._find_line_at_offset( - file_content, self._current_offset - file_info.offsets.start - ) - self._current_line += file_info.lines.start - file_content = file_content[ - self._current_offset - file_info.offsets.start : - ] - file_info.offsets.start = self._current_offset - file_info.lines.start = self._current_line - self._current_offset = file_info.offsets.end + 1 - self._current_line = file_info.lines.end + 1 - return file_info, file_content - - -class FileRetrieverByLine(FileRetrieverStrategy): # pylint: disable=too-few-public-methods - """Retrieves files by line.""" - - def __init__( - self, - ctx: _TaskContext, - line: int, - ) -> None: - super().__init__(ctx) - self._current_line = line - self._current_offset: Optional[int] = None - if self._current_line >= self._ctx.metadata_header.total_lines: - raise ValueError( - f"line is out of range, max line is {self._ctx.metadata_header.total_lines - 1}" - ) - - def _find_offset_of_line(self, file_content: str, target_line: int): - # TODO: Remove this, incorrect parsing of JSON to find braces - start_offset = None - stack = [] - line_number = 0 - - for index, char in enumerate(file_content): - if 
char == "{": - stack.append(char) - if len(stack) == 1: - if line_number == target_line: - start_offset = index - line_number += 1 - elif char == "}" and stack: - stack.pop() - - if line_number > target_line: - break - - return start_offset - - def get_next_chunk(self) -> Optional[Tuple[_MetadataFileInfo, str]]: - if self._current_line >= self._ctx.metadata_header.total_lines: - return None - query = ( - f"query GetExportFileFromLinePyApi" - f"($where: WhereUniqueIdInput, $streamType: TaskStreamType!, $line: UInt64!)" - f"{{task(where: $where)" - f"{{{'exportFileFromLine'}(streamType: $streamType, line: $line)" - f"{{offsets {{start end}} lines {{start end}} file}}" - f"}}}}" - ) - variables = { - "where": {"id": self._ctx.task_id}, - "streamType": self._ctx.stream_type.value, - "line": self._current_line, - } - file_info, file_content = self._get_file_content( - query, variables, "exportFileFromLine" - ) - if self._current_offset is None: - self._current_offset = self._find_offset_of_line( - file_content, self._current_line - file_info.lines.start - ) - self._current_offset += file_info.offsets.start - file_content = file_content[ - self._current_offset - file_info.offsets.start : - ] - file_info.offsets.start = self._current_offset - file_info.lines.start = self._current_line - self._current_offset = file_info.offsets.end + 1 - self._current_line = file_info.lines.end + 1 - return file_info, file_content - - class _Reader(ABC): # pylint: disable=too-few-public-methods """Abstract class for reading data from a source.""" @@ -405,94 +144,6 @@ def read(self) -> Iterator[Tuple[_MetadataFileInfo, str]]: """Reads data from the source.""" -class _MultiGCSFileReader(_Reader): # pylint: disable=too-few-public-methods - """Reads data from multiple GCS files in a seamless way. - - Deprecated: This reader is deprecated and will be removed in a future release. 
- """ - - def __init__(self): - warnings.warn( - "_MultiGCSFileReader is deprecated and will be removed in a future release" - ) - super().__init__() - self._retrieval_strategy = None - - def set_retrieval_strategy(self, strategy: FileRetrieverStrategy) -> None: - """Sets the retrieval strategy.""" - self._retrieval_strategy = strategy - - def read(self) -> Iterator[Tuple[_MetadataFileInfo, str]]: - if not self._retrieval_strategy: - raise ValueError("retrieval strategy not set") - result = self._retrieval_strategy.get_next_chunk() - while result: - file_info, raw_data = result - yield file_info, raw_data - result = self._retrieval_strategy.get_next_chunk() - - -class Stream(Generic[OutputT]): - """Streams data from a Reader.""" - - def __init__( - self, - ctx: _TaskContext, - reader: _Reader, - converter: Converter, - ): - self._ctx = ctx - self._reader = reader - self._converter = converter - # default strategy is to retrieve files by offset, starting from 0 - self.with_offset(0) - - def __iter__(self): - yield from self._fetch() - - def _fetch( - self, - ) -> Iterator[OutputT]: - """Fetches the result data. - Returns an iterator that yields the offset and the data. 
- """ - if self._ctx.metadata_header.total_size is None: - return - - stream = self._reader.read() - with self._converter as converter: - for file_info, raw_data in stream: - for output in converter.convert( - Converter.ConverterInputArgs(self._ctx, file_info, raw_data) - ): - yield output - - def with_offset(self, offset: int) -> "Stream[OutputT]": - """Sets the offset for the stream.""" - self._reader.set_retrieval_strategy( - FileRetrieverByOffset(self._ctx, offset) - ) - return self - - def with_line(self, line: int) -> "Stream[OutputT]": - """Sets the line number for the stream.""" - self._reader.set_retrieval_strategy( - FileRetrieverByLine(self._ctx, line) - ) - return self - - def start( - self, stream_handler: Optional[Callable[[OutputT], None]] = None - ) -> None: - """Starts streaming the result data. - Calls the stream_handler for each result. - """ - # this calls the __iter__ method, which in turn calls the _fetch method - for output in self: - if stream_handler: - stream_handler(output) - - class _BufferedFileRetrieverByOffset(FileRetrieverStrategy): # pylint: disable=too-few-public-methods """Retrieves files by offset.""" @@ -926,53 +577,6 @@ def get_buffered_stream( ), ) - @overload - def get_stream( - self, - converter: JsonConverter, - stream_type: StreamType = StreamType.RESULT, - ) -> Stream[JsonConverterOutput]: - """Overload for getting the right typing hints when using a JsonConverter.""" - - @overload - def get_stream( - self, - converter: FileConverter, - stream_type: StreamType = StreamType.RESULT, - ) -> Stream[FileConverterOutput]: - """Overload for getting the right typing hints when using a FileConverter.""" - - def get_stream( - self, - converter: Optional[Converter] = None, - stream_type: StreamType = StreamType.RESULT, - ) -> Stream: - warnings.warn( - "get_stream is deprecated and will be removed in a future release, use get_buffered_stream" - ) - if converter is None: - converter = JsonConverter() - """Returns the result of the 
task.""" - if self._task.status == "FAILED": - raise ExportTask.ExportTaskException("Task failed") - if self._task.status != "COMPLETE": - raise ExportTask.ExportTaskException("Task is not ready yet") - - metadata_header = self._get_metadata_header( - self._task.client, self._task.uid, stream_type - ) - if metadata_header is None: - raise ValueError( - f"Task {self._task.uid} does not have a {stream_type.value} stream" - ) - return Stream( - _TaskContext( - self._task.client, self._task.uid, stream_type, metadata_header - ), - _MultiGCSFileReader(), - converter, - ) - @staticmethod def get_task(client, task_id): """Returns the task with the given id.""" diff --git a/libs/labelbox/src/labelbox/schema/foundry/foundry_client.py b/libs/labelbox/src/labelbox/schema/foundry/foundry_client.py index 914a363c7..315fc831c 100644 --- a/libs/labelbox/src/labelbox/schema/foundry/foundry_client.py +++ b/libs/labelbox/src/labelbox/schema/foundry/foundry_client.py @@ -1,6 +1,8 @@ from typing import Union -from labelbox import exceptions -from labelbox.schema.foundry.app import App, APP_FIELD_NAMES + +from lbox import exceptions # type: ignore + +from labelbox.schema.foundry.app import APP_FIELD_NAMES, App from labelbox.schema.identifiables import DataRowIds, GlobalKeys, IdType from labelbox.schema.task import Task @@ -51,7 +53,7 @@ def _get_app(self, id: str) -> App: try: response = self.client.execute(query_str, params) - except exceptions.InvalidQueryError as e: + except exceptions.InvalidQueryError: raise exceptions.ResourceNotFoundError(App, params) except Exception as e: raise exceptions.LabelboxError(f"Unable to get app with id {id}", e) diff --git a/libs/labelbox/src/labelbox/schema/internal/descriptor_file_creator.py b/libs/labelbox/src/labelbox/schema/internal/descriptor_file_creator.py index ce3ce4b35..a9c1c250c 100644 --- a/libs/labelbox/src/labelbox/schema/internal/descriptor_file_creator.py +++ b/libs/labelbox/src/labelbox/schema/internal/descriptor_file_creator.py @@ 
-1,25 +1,21 @@ import json import os -import sys from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import TYPE_CHECKING, Generator, Iterable, List -from typing import Iterable, List, Generator +from lbox.exceptions import ( + InvalidAttributeError, + InvalidQueryError, +) # type: ignore -from labelbox.exceptions import InvalidQueryError -from labelbox.exceptions import InvalidAttributeError -from labelbox.exceptions import MalformedQueryException -from labelbox.orm.model import Entity -from labelbox.orm.model import Field +from labelbox.orm.model import Entity, Field from labelbox.schema.embedding import EmbeddingVector -from labelbox.schema.internal.datarow_upload_constants import ( - FILE_UPLOAD_THREAD_COUNT, -) from labelbox.schema.internal.data_row_upsert_item import ( DataRowItemBase, - DataRowUpsertItem, ) - -from typing import TYPE_CHECKING +from labelbox.schema.internal.datarow_upload_constants import ( + FILE_UPLOAD_THREAD_COUNT, +) if TYPE_CHECKING: from labelbox import Client @@ -161,7 +157,7 @@ def check_message_keys(message): ] ) for key in message.keys(): - if not key in accepted_message_keys: + if key not in accepted_message_keys: raise KeyError( f"Invalid {key} key found! 
Accepted keys in messages list is {accepted_message_keys}" ) diff --git a/libs/labelbox/src/labelbox/schema/labeling_service.py b/libs/labelbox/src/labelbox/schema/labeling_service.py index 0fbd15bd1..0b7dab6bd 100644 --- a/libs/labelbox/src/labelbox/schema/labeling_service.py +++ b/libs/labelbox/src/labelbox/schema/labeling_service.py @@ -1,13 +1,12 @@ +import json from datetime import datetime from typing import Any -from typing_extensions import Annotated -from pydantic import BaseModel, Field +from lbox.exceptions import LabelboxError, ResourceNotFoundError -from labelbox.exceptions import ResourceNotFoundError -from labelbox.utils import _CamelCaseMixin from labelbox.schema.labeling_service_dashboard import LabelingServiceDashboard from labelbox.schema.labeling_service_status import LabelingServiceStatus +from labelbox.utils import _CamelCaseMixin from ..annotated_types import Cuid @@ -107,12 +106,24 @@ def request(self) -> "LabelingService": query_str, {"projectId": self.project_id}, raise_return_resource_not_found=True, + error_handlers={"MALFORMED_REQUEST": self._raise_readable_errors}, ) success = result["validateAndRequestProjectBoostWorkforce"]["success"] if not success: raise Exception("Failed to start labeling service") return LabelingService.get(self.client, self.project_id) + def _raise_readable_errors(self, response): + errors = response.json().get("errors", []) + if errors: + message = errors[0].get( + "errors", json.dumps([{"error": "Unknown error"}]) + ) + error_messages = [error["error"] for error in message] + else: + error_messages = ["Uknown error"] + raise LabelboxError(". 
".join(error_messages)) + @classmethod def getOrCreate(cls, client, project_id: Cuid) -> "LabelingService": """ diff --git a/libs/labelbox/src/labelbox/schema/labeling_service_dashboard.py b/libs/labelbox/src/labelbox/schema/labeling_service_dashboard.py index c5e1fa11e..2f91af7af 100644 --- a/libs/labelbox/src/labelbox/schema/labeling_service_dashboard.py +++ b/libs/labelbox/src/labelbox/schema/labeling_service_dashboard.py @@ -1,16 +1,17 @@ -from string import Template from datetime import datetime +from string import Template from typing import Any, Dict, List, Optional, Union -from labelbox.exceptions import ResourceNotFoundError +from lbox.exceptions import ResourceNotFoundError +from pydantic import BaseModel, Field, model_validator, model_serializer + from labelbox.pagination import PaginatedCollection -from pydantic import BaseModel, model_validator, Field +from labelbox.schema.labeling_service_status import LabelingServiceStatus +from labelbox.schema.media_type import MediaType from labelbox.schema.search_filters import SearchFilter, build_search_filter -from labelbox.utils import _CamelCaseMixin +from labelbox.utils import _CamelCaseMixin, sentence_case + from .ontology_kind import EditorTaskType -from labelbox.schema.media_type import MediaType -from labelbox.schema.labeling_service_status import LabelingServiceStatus -from labelbox.utils import sentence_case GRAPHQL_QUERY_SELECTIONS = """ id @@ -49,7 +50,7 @@ class LabelingServiceDashboard(_CamelCaseMixin): Represent labeling service data for a project NOTE on tasks vs data rows. A task is a unit of work that is assigned to a user. A data row is a unit of data that needs to be labeled. - In the current implementation a task reprsents a single data row. However tasks only exists when a labeler start labeling a data row. + In the current implementation a task represents a single data row. However tasks only exists when a labeler start labeling a data row. 
So if a data row is not labeled, it will not have a task associated with it. Therefore the number of tasks can be less than the number of data rows. Attributes: @@ -220,8 +221,9 @@ def convert_boost_data(cls, data): return data - def dict(self, *args, **kwargs): - row = super().dict(*args, **kwargs) + @model_serializer() + def ser_model(self): + row = self row.pop("client") row["service_type"] = self.service_type return row diff --git a/libs/labelbox/src/labelbox/schema/model_run.py b/libs/labelbox/src/labelbox/schema/model_run.py index bc9971174..dcdbdf0e8 100644 --- a/libs/labelbox/src/labelbox/schema/model_run.py +++ b/libs/labelbox/src/labelbox/schema/model_run.py @@ -539,6 +539,12 @@ def export_v2( >>> export_task = export_v2("my_export_task", params={"media_attributes": True}) """ + + warnings.warn( + "You are currently utilizing export_v2 for this action, which will be removed in 7.0. Please refer to our docs for export alternatives. https://docs.labelbox.com/reference/export-overview#export-methods", + DeprecationWarning, + stacklevel=2, + ) task, is_streamable = self._export(task_name, params) if is_streamable: return ExportTask(task, True) diff --git a/libs/labelbox/src/labelbox/schema/ontology.py b/libs/labelbox/src/labelbox/schema/ontology.py index efe32611b..a3b388ef2 100644 --- a/libs/labelbox/src/labelbox/schema/ontology.py +++ b/libs/labelbox/src/labelbox/schema/ontology.py @@ -1,17 +1,17 @@ # type: ignore import colorsys +import json +import warnings from dataclasses import dataclass, field from enum import Enum -from typing import Any, Dict, List, Optional, Union, Type -from typing_extensions import Annotated -import warnings +from typing import Annotated, Any, Dict, List, Optional, Type, Union + +from lbox.exceptions import InconsistentOntologyException +from pydantic import StringConstraints -from labelbox.exceptions import InconsistentOntologyException from labelbox.orm.db_object import DbObject from labelbox.orm.model import Field, 
Relationship -import json -from pydantic import StringConstraints FeatureSchemaId: Type[str] = Annotated[ str, StringConstraints(min_length=25, max_length=25) @@ -561,14 +561,18 @@ class OntologyBuilder: There are no required instantiation arguments. To create an ontology, use the asdict() method after fully building your - ontology within this class, and inserting it into project.setup() as the - "labeling_frontend_options" parameter. + ontology within this class, and inserting it into client.create_ontology() as the + "normalized" parameter. Example: - builder = OntologyBuilder() - ... - frontend = list(client.get_labeling_frontends())[0] - project.setup(frontend, builder.asdict()) + >>> builder = OntologyBuilder() + >>> ... + >>> ontology = client.create_ontology( + >>> "Ontology from new features", + >>> ontology_builder.asdict(), + >>> media_type=lb.MediaType.Image, + >>> ) + >>> project.connect_ontology(ontology) attributes: tools: (list) diff --git a/libs/labelbox/src/labelbox/schema/ontology_kind.py b/libs/labelbox/src/labelbox/schema/ontology_kind.py index 3171b811e..79ef7d7a3 100644 --- a/libs/labelbox/src/labelbox/schema/ontology_kind.py +++ b/libs/labelbox/src/labelbox/schema/ontology_kind.py @@ -53,7 +53,7 @@ def evaluate_ontology_kind_with_media_type( return media_type -class EditorTaskType(Enum): +class EditorTaskType(str, Enum): ModelChatEvaluation = "MODEL_CHAT_EVALUATION" ResponseCreation = "RESPONSE_CREATION" OfflineModelChatEvaluation = "OFFLINE_MODEL_CHAT_EVALUATION" diff --git a/libs/labelbox/src/labelbox/schema/organization.py b/libs/labelbox/src/labelbox/schema/organization.py index 71e715f11..bd416e997 100644 --- a/libs/labelbox/src/labelbox/schema/organization.py +++ b/libs/labelbox/src/labelbox/schema/organization.py @@ -1,21 +1,21 @@ -import json -from typing import TYPE_CHECKING, List, Optional, Dict +from typing import TYPE_CHECKING, Dict, List, Optional + +from lbox.exceptions import LabelboxError -from labelbox.exceptions import 
LabelboxError from labelbox import utils -from labelbox.orm.db_object import DbObject, query, Entity +from labelbox.orm.db_object import DbObject, Entity, query from labelbox.orm.model import Field, Relationship from labelbox.schema.invite import InviteLimit from labelbox.schema.resource_tag import ResourceTag if TYPE_CHECKING: from labelbox import ( - Role, - User, - ProjectRole, + IAMIntegration, Invite, InviteLimit, - IAMIntegration, + ProjectRole, + Role, + User, ) diff --git a/libs/labelbox/src/labelbox/schema/project.py b/libs/labelbox/src/labelbox/schema/project.py index 3d5f8ca92..a6f2dfe28 100644 --- a/libs/labelbox/src/labelbox/schema/project.py +++ b/libs/labelbox/src/labelbox/schema/project.py @@ -4,29 +4,27 @@ import warnings from collections import namedtuple from datetime import datetime, timezone -from pathlib import Path from string import Template from typing import ( TYPE_CHECKING, Any, Dict, - Iterable, List, Optional, Tuple, Union, - overload, + get_args, ) -from urllib.parse import urlparse -from labelbox import utils -from labelbox.exceptions import ( +from lbox.exceptions import ( InvalidQueryError, LabelboxError, ProcessingWaitTimeout, ResourceNotFoundError, error_message_for_unparsed_graphql_error, -) +) # type: ignore + +from labelbox import utils from labelbox.orm import query from labelbox.orm.db_object import DbObject, Deletable, Updateable, experimental from labelbox.orm.model import Entity, Field, Relationship @@ -42,7 +40,11 @@ from labelbox.schema.export_task import ExportTask from labelbox.schema.id_type import IdType from labelbox.schema.identifiable import DataRowIdentifier, GlobalKey, UniqueId -from labelbox.schema.identifiables import DataRowIdentifiers, UniqueIds +from labelbox.schema.identifiables import ( + DataRowIdentifiers, + GlobalKeys, + UniqueIds, +) from labelbox.schema.labeling_service import ( LabelingService, LabelingServiceStatus, @@ -59,19 +61,16 @@ ProjectOverview, ProjectOverviewDetailed, ) -from 
labelbox.schema.queue_mode import QueueMode from labelbox.schema.resource_tag import ResourceTag from labelbox.schema.task import Task from labelbox.schema.task_queue import TaskQueue if TYPE_CHECKING: - from labelbox import BulkImportRequest + pass DataRowPriority = int -LabelingParameterOverrideInput = Tuple[ - Union[DataRow, DataRowIdentifier], DataRowPriority -] +LabelingParameterOverrideInput = Tuple[DataRowIdentifier, DataRowPriority] logger = logging.getLogger(__name__) MAX_SYNC_BATCH_ROW_COUNT = 1_000 @@ -81,23 +80,18 @@ def validate_labeling_parameter_overrides( data: List[LabelingParameterOverrideInput], ) -> None: for idx, row in enumerate(data): - if len(row) < 2: - raise TypeError( - f"Data must be a list of tuples each containing two elements: a DataRow or a DataRowIdentifier and priority (int). Found {len(row)} items. Index: {idx}" - ) data_row_identifier = row[0] priority = row[1] - valid_types = (Entity.DataRow, UniqueId, GlobalKey) - if not isinstance(data_row_identifier, valid_types): + if not isinstance(data_row_identifier, get_args(DataRowIdentifier)): raise TypeError( - f"Data row identifier should be be of type DataRow, UniqueId or GlobalKey. Found {type(data_row_identifier)} for data_row_identifier {data_row_identifier}" + f"Data row identifier should be of type DataRowIdentifier. Found {type(data_row_identifier)}." + ) + if len(row) < 2: + raise TypeError( + f"Data must be a list of tuples each containing two elements: a DataRowIdentifier and priority (int). Found {len(row)} items. Index: {idx}" ) - if not isinstance(priority, int): - if isinstance(data_row_identifier, Entity.DataRow): - id = data_row_identifier.uid - else: - id = data_row_identifier + id = data_row_identifier.key raise TypeError( f"Priority must be an int. 
Found {type(priority)} for data_row_identifier {id}" ) @@ -114,7 +108,6 @@ class Project(DbObject, Updateable, Deletable): created_at (datetime) setup_complete (datetime) last_activity_time (datetime) - queue_mode (string) auto_audit_number_of_labels (int) auto_audit_percentage (float) is_benchmark_enabled (bool) @@ -137,7 +130,6 @@ class Project(DbObject, Updateable, Deletable): created_at = Field.DateTime("created_at") setup_complete = Field.DateTime("setup_complete") last_activity_time = Field.DateTime("last_activity_time") - queue_mode = Field.Enum(QueueMode, "queue_mode") auto_audit_number_of_labels = Field.Int("auto_audit_number_of_labels") auto_audit_percentage = Field.Float("auto_audit_percentage") # Bind data_type and allowedMediaTYpe using the GraphQL type MediaType @@ -423,6 +415,13 @@ def export_v2( >>> task.wait_till_done() >>> task.result """ + + warnings.warn( + "You are currently utilizing export_v2 for this action, which will be removed in 7.0. Please refer to our docs for export alternatives. https://docs.labelbox.com/reference/export-overview#export-methods", + DeprecationWarning, + stacklevel=2, + ) + task, is_streamable = self._export(task_name, filters, params) if is_streamable: return ExportTask(task, True) @@ -659,21 +658,11 @@ def review_metrics(self, net_score) -> int: res = self.client.execute(query_str, {id_param: self.uid}) return res["project"]["reviewMetrics"]["labelAggregate"]["count"] - def setup_editor(self, ontology) -> None: - """ - Sets up the project using the Pictor editor. - - Args: - ontology (Ontology): The ontology to attach to the project - """ - warnings.warn("This method is deprecated use connect_ontology instead.") - self.connect_ontology(ontology) - def connect_ontology(self, ontology) -> None: """ Connects the ontology to the project. If an editor is not setup, it will be connected as well. 
- Note: For live chat model evaluation projects, the editor setup is skipped becase it is automatically setup when the project is created. + Note: For live chat model evaluation projects, the editor setup is skipped because it is automatically setup when the project is created. Args: ontology (Ontology): The ontology to attach to the project @@ -696,34 +685,6 @@ def connect_ontology(self, ontology) -> None: timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") self.update(setup_complete=timestamp) - def setup(self, labeling_frontend, labeling_frontend_options) -> None: - """This method will associate default labeling frontend with the project and create an ontology based on labeling_frontend_options. - - Args: - labeling_frontend (LabelingFrontend): Do not use, this parameter is deprecated. We now associate the default labeling frontend with the project. - labeling_frontend_options (dict or str): Labeling frontend options, - a.k.a. project ontology. If given a `dict` it will be converted - to `str` using `json.dumps`. - """ - - warnings.warn("This method is deprecated use connect_ontology instead.") - if labeling_frontend is not None: - warnings.warn( - "labeling_frontend parameter will not be used to create a new labeling frontend." - ) - - if self.is_chat_evaluation() or self.is_prompt_response(): - warnings.warn(""" - This project is a live chat evaluation project or prompt and response generation project. - Editor was setup automatically. 
- """) - return - - self._connect_default_labeling_front_end(labeling_frontend_options) - - timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ") - self.update(setup_complete=timestamp) - def _connect_default_labeling_front_end(self, ontology_as_dict: dict): labeling_frontend = self.labeling_frontend() if ( @@ -775,11 +736,8 @@ def create_batch( Returns: the created batch Raises: - labelbox.exceptions.ValueError if a project is not batch mode, if the project is auto data generation, if the batch exceeds 100k data rows + lbox.exceptions.ValueError if a project is not batch mode, if the project is auto data generation, if the batch exceeds 100k data rows """ - # @TODO: make this automatic? - if self.queue_mode != QueueMode.Batch: - raise ValueError("Project must be in batch mode") if ( self.is_auto_data_generation() and not self.is_chat_evaluation() @@ -861,9 +819,6 @@ def create_batches( Returns: a task for the created batches """ - if self.queue_mode != QueueMode.Batch: - raise ValueError("Project must be in batch mode") - dr_ids = [] if data_rows is not None: for dr in data_rows: @@ -946,9 +901,6 @@ def create_batches_from_dataset( Returns: a task for the created batches """ - if self.queue_mode != QueueMode.Batch: - raise ValueError("Project must be in batch mode") - if consensus_settings: consensus_settings = ConsensusSettings( **consensus_settings @@ -1088,57 +1040,6 @@ def _create_batch_async( return self.client.get_batch(self.uid, batch_id) - def _update_queue_mode(self, mode: "QueueMode") -> "QueueMode": - """ - Updates the queueing mode of this project. - - Deprecation notice: This method is deprecated. Going forward, projects must - go through a migration to have the queue mode changed. Users should specify the - queue mode for a project during creation if a non-default mode is desired. 
- - For more information, visit https://docs.labelbox.com/reference/migrating-to-workflows#upcoming-changes - - Args: - mode: the specified queue mode - - Returns: the updated queueing mode of this project - - """ - - logger.warning( - "Updating the queue_mode for a project will soon no longer be supported." - ) - - if self.queue_mode == mode: - return mode - - if mode == QueueMode.Batch: - status = "ENABLED" - elif mode == QueueMode.Dataset: - status = "DISABLED" - else: - raise ValueError( - "Must provide either `BATCH` or `DATASET` as a mode" - ) - - query_str = ( - """mutation %s($projectId: ID!, $status: TagSetStatusInput!) { - project(where: {id: $projectId}) { - setTagSetStatus(input: {tagSetStatus: $status}) { - tagSetStatus - } - } - } - """ - % "setTagSetStatusPyApi" - ) - - self.client.execute( - query_str, {"projectId": self.uid, "status": status} - ) - - return mode - def get_label_count(self) -> int: """ Returns: the total number of labels in this project. @@ -1153,46 +1054,6 @@ def get_label_count(self) -> int: res = self.client.execute(query_str, {"projectId": self.uid}) return res["project"]["labelCount"] - def get_queue_mode(self) -> "QueueMode": - """ - Provides the queue mode used for this project. - - Deprecation notice: This method is deprecated and will be removed in - a future version. To obtain the queue mode of a project, simply refer - to the queue_mode attribute of a Project. - - For more information, visit https://docs.labelbox.com/reference/migrating-to-workflows#upcoming-changes - - Returns: the QueueMode for this project - - """ - - logger.warning( - "Obtaining the queue_mode for a project through this method will soon" - " no longer be supported." - ) - - query_str = ( - """query %s($projectId: ID!) 
{ - project(where: {id: $projectId}) { - tagSetStatus - } - } - """ - % "GetTagSetStatusPyApi" - ) - - status = self.client.execute(query_str, {"projectId": self.uid})[ - "project" - ]["tagSetStatus"] - - if status == "ENABLED": - return QueueMode.Batch - elif status == "DISABLED": - return QueueMode.Dataset - else: - raise ValueError("Status not known") - def add_model_config(self, model_config_id: str) -> str: """Adds a model config to this project. @@ -1285,18 +1146,13 @@ def set_labeling_parameter_overrides( See information on priority here: https://docs.labelbox.com/en/configure-editor/queue-system#reservation-system - >>> project.set_labeling_parameter_overrides([ - >>> (data_row_id1, 2), (data_row_id2, 1)]) - or >>> project.set_labeling_parameter_overrides([ >>> (data_row_gk1, 2), (data_row_gk2, 1)]) Args: data (iterable): An iterable of tuples. Each tuple must contain - either (DataRow, DataRowPriority) - or (DataRowIdentifier, priority) for the new override. + (DataRowIdentifier, priority) for the new override. DataRowIdentifier is an object representing a data row id or a global key. A DataIdentifier object can be a UniqueIds or GlobalKeys class. - NOTE - passing whole DatRow is deprecated. Please use a DataRowIdentifier instead. Priority: * Data will be labeled in priority order. @@ -1325,16 +1181,7 @@ def set_labeling_parameter_overrides( data_rows_with_identifiers = "" for data_row, priority in data: - if isinstance(data_row, DataRow): - data_rows_with_identifiers += f'{{dataRowIdentifier: {{id: "{data_row.uid}", idType: {IdType.DataRowId}}}, priority: {priority}}},' - elif isinstance(data_row, UniqueId) or isinstance( - data_row, GlobalKey - ): - data_rows_with_identifiers += f'{{dataRowIdentifier: {{id: "{data_row.key}", idType: {data_row.id_type}}}, priority: {priority}}},' - else: - raise TypeError( - f"Data row identifier should be be of type DataRow or Data Row Identifier. Found {type(data_row)}." 
- ) + data_rows_with_identifiers += f'{{dataRowIdentifier: {{id: "{data_row.key}", idType: {data_row.id_type}}}, priority: {priority}}},' query_str = template.substitute( dataWithDataRowIdentifiers=data_rows_with_identifiers @@ -1342,26 +1189,10 @@ def set_labeling_parameter_overrides( res = self.client.execute(query_str, {"projectId": self.uid}) return res["project"]["setLabelingParameterOverrides"]["success"] - @overload def update_data_row_labeling_priority( self, data_rows: DataRowIdentifiers, priority: int, - ) -> bool: - pass - - @overload - def update_data_row_labeling_priority( - self, - data_rows: List[str], - priority: int, - ) -> bool: - pass - - def update_data_row_labeling_priority( - self, - data_rows, - priority: int, ) -> bool: """ Updates labeling parameter overrides to this project in bulk. This method allows up to 1 million data rows to be @@ -1371,7 +1202,7 @@ def update_data_row_labeling_priority( https://docs.labelbox.com/en/configure-editor/queue-system#reservation-system Args: - data_rows: a list of data row ids to update priorities for. This can be a list of strings or a DataRowIdentifiers object + data_rows: data row identifiers object to update priorities. DataRowIdentifier objects are lists of ids or global keys. A DataIdentifier object can be a UniqueIds or GlobalKeys class. priority (int): Priority for the new override. See above for more information. @@ -1379,12 +1210,8 @@ def update_data_row_labeling_priority( bool, indicates if the operation was a success. """ - if isinstance(data_rows, list): - data_rows = UniqueIds(data_rows) - warnings.warn( - "Using data row ids will be deprecated. Please use " - "UniqueIds or GlobalKeys instead." 
- ) + if not isinstance(data_rows, get_args(DataRowIdentifiers)): + raise TypeError("data_rows must be a DataRowIdentifiers object") method = "createQueuePriorityUpdateTask" priority_param = "priority" @@ -1482,33 +1309,6 @@ def enable_model_assisted_labeling(self, toggle: bool = True) -> bool: "showingPredictionsToLabelers" ] - def bulk_import_requests(self) -> PaginatedCollection: - """Returns bulk import request objects which are used in model-assisted labeling. - These are returned with the oldest first, and most recent last. - """ - - id_param = "project_id" - query_str = """query ListAllImportRequestsPyApi($%s: ID!) { - bulkImportRequests ( - where: { projectId: $%s } - skip: %%d - first: %%d - ) { - %s - } - }""" % ( - id_param, - id_param, - query.results_query_part(Entity.BulkImportRequest), - ) - return PaginatedCollection( - self.client, - query_str, - {id_param: str(self.uid)}, - ["bulkImportRequests"], - Entity.BulkImportRequest, - ) - def batches(self) -> PaginatedCollection: """Fetch all batches that belong to this project @@ -1554,25 +1354,15 @@ def task_queues(self) -> List[TaskQueue]: for field_values in task_queue_values ] - @overload def move_data_rows_to_task_queue( self, data_row_ids: DataRowIdentifiers, task_queue_id: str ): - pass - - @overload - def move_data_rows_to_task_queue( - self, data_row_ids: List[str], task_queue_id: str - ): - pass - - def move_data_rows_to_task_queue(self, data_row_ids, task_queue_id: str): """ Moves data rows to the specified task queue. Args: - data_row_ids: a list of data row ids to be moved. This can be a list of strings or a DataRowIdentifiers object + data_row_ids: a list of data row ids to be moved. This should be a DataRowIdentifiers object DataRowIdentifier objects are lists of ids or global keys. A DataIdentifier object can be a UniqueIds or GlobalKeys class. 
task_queue_id: the task queue id to be moved to, or None to specify the "Done" queue @@ -1580,12 +1370,9 @@ def move_data_rows_to_task_queue(self, data_row_ids, task_queue_id: str): None if successful, or a raised error on failure """ - if isinstance(data_row_ids, list): - data_row_ids = UniqueIds(data_row_ids) - warnings.warn( - "Using data row ids will be deprecated. Please use " - "UniqueIds or GlobalKeys instead." - ) + + if not isinstance(data_row_ids, get_args(DataRowIdentifiers)): + raise TypeError("data_rows must be a DataRowIdentifiers object") method = "createBulkAddRowsToQueueTask" query_str = ( @@ -1633,77 +1420,6 @@ def _wait_for_task(self, task_id: str) -> Task: return task - def upload_annotations( - self, - name: str, - annotations: Union[str, Path, Iterable[Dict]], - validate: bool = False, - ) -> "BulkImportRequest": # type: ignore - """Uploads annotations to a new Editor project. - - Args: - name (str): name of the BulkImportRequest job - annotations (str or Path or Iterable): - url that is publicly accessible by Labelbox containing an - ndjson file - OR local path to an ndjson file - OR iterable of annotation rows - validate (bool): - Whether or not to validate the payload before uploading. - Returns: - BulkImportRequest - """ - - if isinstance(annotations, str) or isinstance(annotations, Path): - - def _is_url_valid(url: Union[str, Path]) -> bool: - """Verifies that the given string is a valid url. 
- - Args: - url: string to be checked - Returns: - True if the given url is valid otherwise False - - """ - if isinstance(url, Path): - return False - parsed = urlparse(url) - return bool(parsed.scheme) and bool(parsed.netloc) - - if _is_url_valid(annotations): - return Entity.BulkImportRequest.create_from_url( - client=self.client, - project_id=self.uid, - name=name, - url=str(annotations), - validate=validate, - ) - else: - path = Path(annotations) - if not path.exists(): - raise FileNotFoundError( - f"{annotations} is not a valid url nor existing local file" - ) - return Entity.BulkImportRequest.create_from_local_file( - client=self.client, - project_id=self.uid, - name=name, - file=path, - validate_file=validate, - ) - elif isinstance(annotations, Iterable): - return Entity.BulkImportRequest.create_from_objects( - client=self.client, - project_id=self.uid, - name=name, - predictions=annotations, # type: ignore - validate=validate, - ) - else: - raise ValueError( - f"Invalid annotations given of type: {type(annotations)}" - ) - def _wait_until_data_rows_are_processed( self, data_row_ids: Optional[List[str]] = None, diff --git a/libs/labelbox/src/labelbox/schema/project_model_config.py b/libs/labelbox/src/labelbox/schema/project_model_config.py index 9b6d8a0bb..c8773abf9 100644 --- a/libs/labelbox/src/labelbox/schema/project_model_config.py +++ b/libs/labelbox/src/labelbox/schema/project_model_config.py @@ -1,10 +1,11 @@ -from labelbox.orm.db_object import DbObject -from labelbox.orm.model import Field, Relationship -from labelbox.exceptions import ( +from lbox.exceptions import ( LabelboxError, error_message_for_unparsed_graphql_error, ) +from labelbox.orm.db_object import DbObject +from labelbox.orm.model import Field, Relationship + class ProjectModelConfig(DbObject): """A ProjectModelConfig represents an association between a project and a single model config. 
diff --git a/libs/labelbox/src/labelbox/schema/project_overview.py b/libs/labelbox/src/labelbox/schema/project_overview.py index cee195c10..8be20fbda 100644 --- a/libs/labelbox/src/labelbox/schema/project_overview.py +++ b/libs/labelbox/src/labelbox/schema/project_overview.py @@ -1,6 +1,7 @@ from typing import Dict, List -from typing_extensions import TypedDict + from pydantic import BaseModel +from typing_extensions import TypedDict class ProjectOverview(BaseModel): diff --git a/libs/labelbox/src/labelbox/schema/queue_mode.py b/libs/labelbox/src/labelbox/schema/queue_mode.py deleted file mode 100644 index 333e92987..000000000 --- a/libs/labelbox/src/labelbox/schema/queue_mode.py +++ /dev/null @@ -1,12 +0,0 @@ -from enum import Enum - - -class QueueMode(str, Enum): - Batch = "BATCH" - Dataset = "DATA_SET" - - @classmethod - def _missing_(cls, value): - # Parses the deprecated "CATALOG" value back to QueueMode.Batch. - if value == "CATALOG": - return QueueMode.Batch diff --git a/libs/labelbox/src/labelbox/schema/search_filters.py b/libs/labelbox/src/labelbox/schema/search_filters.py index 13b158678..52330492d 100644 --- a/libs/labelbox/src/labelbox/schema/search_filters.py +++ b/libs/labelbox/src/labelbox/schema/search_filters.py @@ -1,11 +1,15 @@ import datetime from enum import Enum -from typing import List, Union -from pydantic import PlainSerializer, BaseModel, Field +from typing import Annotated, List, Union -from typing_extensions import Annotated +from pydantic import ( + BaseModel, + ConfigDict, + Field, + PlainSerializer, + field_validator, +) -from pydantic import BaseModel, Field, field_validator from labelbox.schema.labeling_service_status import LabelingServiceStatus from labelbox.utils import format_iso_datetime @@ -15,8 +19,7 @@ class BaseSearchFilter(BaseModel): Shared code for all search filters """ - class Config: - use_enum_values = True + model_config = ConfigDict(use_enum_values=True) class OperationTypeEnum(Enum): diff --git 
a/libs/labelbox/src/labelbox/schema/send_to_annotate_params.py b/libs/labelbox/src/labelbox/schema/send_to_annotate_params.py index 18bd26637..fe09cf4c0 100644 --- a/libs/labelbox/src/labelbox/schema/send_to_annotate_params.py +++ b/libs/labelbox/src/labelbox/schema/send_to_annotate_params.py @@ -1,18 +1,11 @@ -import sys +from typing import Dict, Optional, TypedDict -from typing import Optional, Dict +from pydantic import BaseModel, model_validator from labelbox.schema.conflict_resolution_strategy import ( ConflictResolutionStrategy, ) -if sys.version_info >= (3, 8): - from typing import TypedDict -else: - from typing_extensions import TypedDict - -from pydantic import BaseModel, model_validator - class SendToAnnotateFromCatalogParams(BaseModel): """ diff --git a/libs/labelbox/src/labelbox/schema/slice.py b/libs/labelbox/src/labelbox/schema/slice.py index 624731024..a640ebc1d 100644 --- a/libs/labelbox/src/labelbox/schema/slice.py +++ b/libs/labelbox/src/labelbox/schema/slice.py @@ -53,43 +53,6 @@ class CatalogSlice(Slice): Represents a Slice used for filtering data rows in Catalog. """ - def get_data_row_ids(self) -> PaginatedCollection: - """ - Fetches all data row ids that match this Slice - - Returns: - A PaginatedCollection of mapping of data row ids to global keys - """ - - warnings.warn( - "get_data_row_ids will be deprecated. Use get_data_row_identifiers instead" - ) - - query_str = """ - query getDataRowIdsBySavedQueryPyApi($id: ID!, $from: String, $first: Int!) 
{ - getDataRowIdsBySavedQuery(input: { - savedQueryId: $id, - after: $from - first: $first - }) { - totalCount - nodes - pageInfo { - endCursor - hasNextPage - } - } - } - """ - return PaginatedCollection( - client=self.client, - query=query_str, - params={"id": str(self.uid)}, - dereferencing=["getDataRowIdsBySavedQuery", "nodes"], - obj_class=lambda _, data_row_id: data_row_id, - cursor_path=["getDataRowIdsBySavedQuery", "pageInfo", "endCursor"], - ) - def get_data_row_identifiers(self) -> PaginatedCollection: """ Fetches all data row ids and global keys (where defined) that match this Slice @@ -164,6 +127,13 @@ def export_v2( >>> task.wait_till_done() >>> task.result """ + + warnings.warn( + "You are currently utilizing export_v2 for this action, which will be removed in 7.0. Please refer to our docs for export alternatives. https://docs.labelbox.com/reference/export-overview#export-methods", + DeprecationWarning, + stacklevel=2, + ) + task, is_streamable = self._export(task_name, params) if is_streamable: return ExportTask(task, True) diff --git a/libs/labelbox/src/labelbox/schema/task.py b/libs/labelbox/src/labelbox/schema/task.py index 9d7a26e1d..f996ae05d 100644 --- a/libs/labelbox/src/labelbox/schema/task.py +++ b/libs/labelbox/src/labelbox/schema/task.py @@ -1,14 +1,14 @@ import json import logging -import requests import time -from typing import TYPE_CHECKING, Callable, Optional, Dict, Any, List, Union -from labelbox import parser +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union -from labelbox.exceptions import ResourceNotFoundError -from labelbox.orm.db_object import DbObject -from labelbox.orm.model import Field, Relationship, Entity +import requests +from lbox.exceptions import ResourceNotFoundError +from labelbox import parser +from labelbox.orm.db_object import DbObject +from labelbox.orm.model import Entity, Field, Relationship from labelbox.pagination import PaginatedCollection from 
labelbox.schema.internal.datarow_upload_constants import ( DOWNLOAD_RESULT_PAGE_SIZE, diff --git a/libs/labelbox/src/labelbox/schema/user_group.py b/libs/labelbox/src/labelbox/schema/user_group.py index 9d506bf92..2e93b4376 100644 --- a/libs/labelbox/src/labelbox/schema/user_group.py +++ b/libs/labelbox/src/labelbox/schema/user_group.py @@ -1,21 +1,21 @@ -from enum import Enum -from typing import Set, Iterator from collections import defaultdict +from enum import Enum +from typing import Iterator, Set -from labelbox import Client -from labelbox.exceptions import ResourceCreationError -from labelbox.schema.user import User -from labelbox.schema.project import Project -from labelbox.exceptions import ( - UnprocessableEntityError, +from lbox.exceptions import ( MalformedQueryException, + ResourceCreationError, ResourceNotFoundError, + UnprocessableEntityError, ) -from labelbox.schema.queue_mode import QueueMode -from labelbox.schema.ontology_kind import EditorTaskType -from labelbox.schema.media_type import MediaType from pydantic import BaseModel, ConfigDict +from labelbox import Client +from labelbox.schema.media_type import MediaType +from labelbox.schema.ontology_kind import EditorTaskType +from labelbox.schema.project import Project +from labelbox.schema.user import User + class UserGroupColor(Enum): """ @@ -92,9 +92,6 @@ def __init__( color (UserGroupColor, optional): The color of the user group. Defaults to UserGroupColor.BLUE. users (Set[User], optional): The set of users in the user group. Defaults to an empty set. projects (Set[Project], optional): The set of projects associated with the user group. Defaults to an empty set. - - Raises: - RuntimeError: If the experimental feature is not enabled in the client. 
""" super().__init__( client=client, @@ -104,10 +101,6 @@ def __init__( users=users, projects=projects, ) - if not self.client.enable_experimental: - raise RuntimeError( - "Please enable experimental in client to use UserGroups" - ) def get(self) -> "UserGroup": """ @@ -284,7 +277,7 @@ def create(self) -> "UserGroup": except Exception as e: error = e if not result or error: - # this is client side only, server doesn't have an equivalent error + # This is client side only, server doesn't have an equivalent error raise ResourceCreationError( f"Failed to create user group, either user group name is in use currently, or provided user or projects don't exist server error: {error}" ) @@ -417,7 +410,6 @@ def _get_projects_set(self, project_nodes): project_values = defaultdict(lambda: None) project_values["id"] = project["id"] project_values["name"] = project["name"] - project_values["queueMode"] = QueueMode.Batch.value project_values["editorTaskType"] = EditorTaskType.Missing.value project_values["mediaType"] = MediaType.Image.value projects.add(Project(self.client, project_values)) diff --git a/libs/labelbox/src/labelbox/typing_imports.py b/libs/labelbox/src/labelbox/typing_imports.py deleted file mode 100644 index 6edfb9bef..000000000 --- a/libs/labelbox/src/labelbox/typing_imports.py +++ /dev/null @@ -1,11 +0,0 @@ -""" -This module imports types that differ across python versions, so other modules -don't have to worry about where they should be imported from. -""" - -import sys - -if sys.version_info >= (3, 8): - from typing import Literal -else: - from typing_extensions import Literal diff --git a/libs/labelbox/src/labelbox/utils.py b/libs/labelbox/src/labelbox/utils.py index c76ce188f..dcf51be82 100644 --- a/libs/labelbox/src/labelbox/utils.py +++ b/libs/labelbox/src/labelbox/utils.py @@ -87,8 +87,8 @@ class _NoCoercionMixin: when serializing the object. 
Example: - class ConversationData(BaseData, _NoCoercionMixin): - class_name: Literal["ConversationData"] = "ConversationData" + class GenericDataRowData(BaseData, _NoCoercionMixin): + class_name: Literal["GenericDataRowData"] = "GenericDataRowData" """ diff --git a/libs/labelbox/tests/conftest.py b/libs/labelbox/tests/conftest.py index 49eab165d..3e9f0b491 100644 --- a/libs/labelbox/tests/conftest.py +++ b/libs/labelbox/tests/conftest.py @@ -12,6 +12,7 @@ import pytest import requests +from lbox.exceptions import LabelboxError from labelbox import ( Classification, @@ -24,7 +25,6 @@ Option, Tool, ) -from labelbox.exceptions import LabelboxError from labelbox.orm import query from labelbox.pagination import PaginatedCollection from labelbox.schema.annotation_import import LabelImport @@ -33,7 +33,6 @@ from labelbox.schema.ontology import Ontology from labelbox.schema.project import Project from labelbox.schema.quality_mode import QualityMode -from labelbox.schema.queue_mode import QueueMode IMG_URL = "https://picsum.photos/200/300.jpg" MASKABLE_IMG_URL = "https://storage.googleapis.com/labelbox-datasets/image_sample_data/2560px-Kitano_Street_Kobe01s5s4110.jpeg" @@ -445,7 +444,6 @@ def conversation_entity_data_row(client, rand_gen): def project(client, rand_gen): project = client.create_project( name=rand_gen(str), - queue_mode=QueueMode.Batch, media_type=MediaType.Image, ) yield project @@ -456,8 +454,7 @@ def project(client, rand_gen): def consensus_project(client, rand_gen): project = client.create_project( name=rand_gen(str), - quality_mode=QualityMode.Consensus, - queue_mode=QueueMode.Batch, + quality_modes={QualityMode.Consensus}, media_type=MediaType.Image, ) yield project @@ -647,7 +644,6 @@ def configured_project_with_label( """ project = client.create_project( name=rand_gen(str), - queue_mode=QueueMode.Batch, media_type=MediaType.Image, ) project._wait_until_data_rows_are_processed( @@ -661,7 +657,7 @@ def configured_project_with_label( [data_row.uid], # 
sample of data row objects 5, # priority between 1(Highest) - 5(lowest) ) - ontology = _setup_ontology(project) + ontology = _setup_ontology(project, client) label = _create_label( project, data_row, ontology, wait_for_label_processing ) @@ -704,20 +700,19 @@ def create_label(): return label -def _setup_ontology(project): - editor = list( - project.client.get_labeling_frontends( - where=LabelingFrontend.name == "editor" - ) - )[0] +def _setup_ontology(project: Project, client: Client): ontology_builder = OntologyBuilder( tools=[ Tool(tool=Tool.Type.BBOX, name="test-bbox-class"), ] ) - project.setup(editor, ontology_builder.asdict()) - # TODO: ontology may not be synchronous after setup. remove sleep when api is more consistent - time.sleep(2) + ontology = client.create_ontology( + name="ontology with features", + media_type=MediaType.Image, + normalized=ontology_builder.asdict(), + ) + project.connect_ontology(ontology) + return OntologyBuilder.from_project(project) @@ -750,7 +745,6 @@ def configured_batch_project_with_label( """ project = client.create_project( name=rand_gen(str), - queue_mode=QueueMode.Batch, media_type=MediaType.Image, ) data_rows = [dr.uid for dr in list(dataset.data_rows())] @@ -760,7 +754,7 @@ def configured_batch_project_with_label( project.create_batch("test-batch", data_rows) project.data_row_ids = data_rows - ontology = _setup_ontology(project) + ontology = _setup_ontology(project, client) label = _create_label( project, data_row, ontology, wait_for_label_processing ) @@ -785,7 +779,6 @@ def configured_batch_project_with_multiple_datarows( """ project = client.create_project( name=rand_gen(str), - queue_mode=QueueMode.Batch, media_type=MediaType.Image, ) global_keys = [dr.global_key for dr in data_rows] @@ -793,7 +786,7 @@ def configured_batch_project_with_multiple_datarows( batch_name = f"batch {uuid.uuid4()}" project.create_batch(batch_name, global_keys=global_keys) - ontology = _setup_ontology(project) + ontology = 
_setup_ontology(project, client) for datarow in data_rows: _create_label(project, datarow, ontology, wait_for_label_processing) @@ -1030,11 +1023,11 @@ def _upload_invalid_data_rows_for_dataset(dataset: Dataset): @pytest.fixture def configured_project( - project_with_empty_ontology, initial_dataset, rand_gen, image_url + project_with_one_feature_ontology, initial_dataset, rand_gen, image_url ): dataset = initial_dataset data_row_id = dataset.create_data_row(row_data=image_url).uid - project = project_with_empty_ontology + project = project_with_one_feature_ontology batch = project.create_batch( rand_gen(str), @@ -1049,24 +1042,24 @@ def configured_project( @pytest.fixture -def project_with_empty_ontology(project): - editor = list( - project.client.get_labeling_frontends( - where=LabelingFrontend.name == "editor" - ) - )[0] - empty_ontology = {"tools": [], "classifications": []} - project.setup(editor, empty_ontology) +def project_with_one_feature_ontology(project, client: Client): + tools = [ + Tool(tool=Tool.Type.BBOX, name="test-bbox-class").asdict(), + ] + empty_ontology = {"tools": tools, "classifications": []} + ontology = client.create_ontology( + "empty ontology", empty_ontology, MediaType.Image + ) + project.connect_ontology(ontology) yield project @pytest.fixture def configured_project_with_complex_ontology( - client, initial_dataset, rand_gen, image_url, teardown_helpers + client: Client, initial_dataset, rand_gen, image_url, teardown_helpers ): project = client.create_project( name=rand_gen(str), - queue_mode=QueueMode.Batch, media_type=MediaType.Image, ) dataset = initial_dataset @@ -1080,19 +1073,12 @@ def configured_project_with_complex_ontology( ) project.data_row_ids = data_row_ids - editor = list( - project.client.get_labeling_frontends( - where=LabelingFrontend.name == "editor" - ) - )[0] - ontology = OntologyBuilder() tools = [ Tool(tool=Tool.Type.BBOX, name="test-bbox-class"), Tool(tool=Tool.Type.LINE, name="test-line-class"), 
Tool(tool=Tool.Type.POINT, name="test-point-class"), Tool(tool=Tool.Type.POLYGON, name="test-polygon-class"), - Tool(tool=Tool.Type.NER, name="test-ner-class"), ] options = [ @@ -1124,7 +1110,11 @@ def configured_project_with_complex_ontology( for c in classifications: ontology.add_classification(c) - project.setup(editor, ontology.asdict()) + ontology = client.create_ontology( + "complex image ontology", ontology.asdict(), MediaType.Image + ) + + project.connect_ontology(ontology) yield [project, data_row] teardown_helpers.teardown_project_labels_ontology_feature_schemas(project) diff --git a/libs/labelbox/tests/data/annotation_import/test_bulk_import_request.py b/libs/labelbox/tests/data/annotation_import/test_bulk_import_request.py deleted file mode 100644 index 9abae1422..000000000 --- a/libs/labelbox/tests/data/annotation_import/test_bulk_import_request.py +++ /dev/null @@ -1,258 +0,0 @@ -from unittest.mock import patch -import uuid -from labelbox import parser, Project -from labelbox.data.annotation_types.data.generic_data_row_data import ( - GenericDataRowData, -) -import pytest -import random -from labelbox.data.annotation_types.annotation import ObjectAnnotation -from labelbox.data.annotation_types.classification.classification import ( - Checklist, - ClassificationAnnotation, - ClassificationAnswer, - Radio, -) -from labelbox.data.annotation_types.data.video import VideoData -from labelbox.data.annotation_types.geometry.point import Point -from labelbox.data.annotation_types.geometry.rectangle import ( - Rectangle, - RectangleUnit, -) -from labelbox.data.annotation_types.label import Label -from labelbox.data.annotation_types.data.text import TextData -from labelbox.data.annotation_types.ner import ( - DocumentEntity, - DocumentTextSelection, -) -from labelbox.data.annotation_types.video import VideoObjectAnnotation - -from labelbox.data.serialization import NDJsonConverter -from labelbox.exceptions import MALValidationError, UuidError -from 
labelbox.schema.bulk_import_request import BulkImportRequest -from labelbox.schema.enums import BulkImportRequestState -from labelbox.schema.annotation_import import LabelImport, MALPredictionImport -from labelbox.schema.media_type import MediaType - -""" -- Here we only want to check that the uploads are calling the validation -- Then with unit tests we can check the types of errors raised -""" -# TODO: remove library once bulk import requests are removed - - -@pytest.mark.order(1) -def test_create_from_url(module_project): - name = str(uuid.uuid4()) - url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson" - - bulk_import_request = module_project.upload_annotations( - name=name, annotations=url, validate=False - ) - - assert bulk_import_request.project() == module_project - assert bulk_import_request.name == name - assert bulk_import_request.input_file_url == url - assert bulk_import_request.error_file_url is None - assert bulk_import_request.status_file_url is None - assert bulk_import_request.state == BulkImportRequestState.RUNNING - - -def test_validate_file(module_project): - name = str(uuid.uuid4()) - url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson" - with pytest.raises(MALValidationError): - module_project.upload_annotations( - name=name, annotations=url, validate=True - ) - # Schema ids shouldn't match - - -def test_create_from_objects( - module_project: Project, predictions, annotation_import_test_helpers -): - name = str(uuid.uuid4()) - - bulk_import_request = module_project.upload_annotations( - name=name, annotations=predictions - ) - - assert bulk_import_request.project() == module_project - assert bulk_import_request.name == name - assert bulk_import_request.error_file_url is None - assert bulk_import_request.status_file_url is None - assert bulk_import_request.state == BulkImportRequestState.RUNNING - annotation_import_test_helpers.assert_file_content( - 
bulk_import_request.input_file_url, predictions - ) - - -def test_create_from_label_objects( - module_project, predictions, annotation_import_test_helpers -): - name = str(uuid.uuid4()) - - labels = list(NDJsonConverter.deserialize(predictions)) - bulk_import_request = module_project.upload_annotations( - name=name, annotations=labels - ) - - assert bulk_import_request.project() == module_project - assert bulk_import_request.name == name - assert bulk_import_request.error_file_url is None - assert bulk_import_request.status_file_url is None - assert bulk_import_request.state == BulkImportRequestState.RUNNING - normalized_predictions = list(NDJsonConverter.serialize(labels)) - annotation_import_test_helpers.assert_file_content( - bulk_import_request.input_file_url, normalized_predictions - ) - - -def test_create_from_local_file( - tmp_path, predictions, module_project, annotation_import_test_helpers -): - name = str(uuid.uuid4()) - file_name = f"{name}.ndjson" - file_path = tmp_path / file_name - with file_path.open("w") as f: - parser.dump(predictions, f) - - bulk_import_request = module_project.upload_annotations( - name=name, annotations=str(file_path), validate=False - ) - - assert bulk_import_request.project() == module_project - assert bulk_import_request.name == name - assert bulk_import_request.error_file_url is None - assert bulk_import_request.status_file_url is None - assert bulk_import_request.state == BulkImportRequestState.RUNNING - annotation_import_test_helpers.assert_file_content( - bulk_import_request.input_file_url, predictions - ) - - -def test_get(client, module_project): - name = str(uuid.uuid4()) - url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson" - module_project.upload_annotations( - name=name, annotations=url, validate=False - ) - - bulk_import_request = BulkImportRequest.from_name( - client, project_id=module_project.uid, name=name - ) - - assert bulk_import_request.project() == module_project - 
assert bulk_import_request.name == name - assert bulk_import_request.input_file_url == url - assert bulk_import_request.error_file_url is None - assert bulk_import_request.status_file_url is None - assert bulk_import_request.state == BulkImportRequestState.RUNNING - - -def test_validate_ndjson(tmp_path, module_project): - file_name = f"broken.ndjson" - file_path = tmp_path / file_name - with file_path.open("w") as f: - f.write("test") - - with pytest.raises(ValueError): - module_project.upload_annotations( - name="name", validate=True, annotations=str(file_path) - ) - - -def test_validate_ndjson_uuid(tmp_path, module_project, predictions): - file_name = f"repeat_uuid.ndjson" - file_path = tmp_path / file_name - repeat_uuid = predictions.copy() - uid = str(uuid.uuid4()) - repeat_uuid[0]["uuid"] = uid - repeat_uuid[1]["uuid"] = uid - - with file_path.open("w") as f: - parser.dump(repeat_uuid, f) - - with pytest.raises(UuidError): - module_project.upload_annotations( - name="name", validate=True, annotations=str(file_path) - ) - - with pytest.raises(UuidError): - module_project.upload_annotations( - name="name", validate=True, annotations=repeat_uuid - ) - - -@pytest.mark.skip( - "Slow test and uses a deprecated api endpoint for annotation imports" -) -def test_wait_till_done(rectangle_inference, project): - name = str(uuid.uuid4()) - url = project.client.upload_data( - content=parser.dumps(rectangle_inference), sign=True - ) - bulk_import_request = project.upload_annotations( - name=name, annotations=url, validate=False - ) - - assert len(bulk_import_request.inputs) == 1 - bulk_import_request.wait_until_done() - assert bulk_import_request.state == BulkImportRequestState.FINISHED - - # Check that the status files are being returned as expected - assert len(bulk_import_request.errors) == 0 - assert len(bulk_import_request.inputs) == 1 - assert bulk_import_request.inputs[0]["uuid"] == rectangle_inference["uuid"] - assert len(bulk_import_request.statuses) == 1 - assert 
bulk_import_request.statuses[0]["status"] == "SUCCESS" - assert ( - bulk_import_request.statuses[0]["uuid"] == rectangle_inference["uuid"] - ) - - -def test_project_bulk_import_requests(module_project, predictions): - result = module_project.bulk_import_requests() - assert len(list(result)) == 0 - - name = str(uuid.uuid4()) - bulk_import_request = module_project.upload_annotations( - name=name, annotations=predictions - ) - bulk_import_request.wait_until_done() - - name = str(uuid.uuid4()) - bulk_import_request = module_project.upload_annotations( - name=name, annotations=predictions - ) - bulk_import_request.wait_until_done() - - name = str(uuid.uuid4()) - bulk_import_request = module_project.upload_annotations( - name=name, annotations=predictions - ) - bulk_import_request.wait_until_done() - - result = module_project.bulk_import_requests() - assert len(list(result)) == 3 - - -def test_delete(module_project, predictions): - name = str(uuid.uuid4()) - - bulk_import_requests = module_project.bulk_import_requests() - [ - bulk_import_request.delete() - for bulk_import_request in bulk_import_requests - ] - - bulk_import_request = module_project.upload_annotations( - name=name, annotations=predictions - ) - bulk_import_request.wait_until_done() - all_import_requests = module_project.bulk_import_requests() - assert len(list(all_import_requests)) == 1 - - bulk_import_request.delete() - all_import_requests = module_project.bulk_import_requests() - assert len(list(all_import_requests)) == 0 diff --git a/libs/labelbox/tests/data/annotation_import/test_data_types.py b/libs/labelbox/tests/data/annotation_import/test_data_types.py deleted file mode 100644 index 1e45295ef..000000000 --- a/libs/labelbox/tests/data/annotation_import/test_data_types.py +++ /dev/null @@ -1,83 +0,0 @@ -import pytest - -from labelbox.data.annotation_types.data import ( - AudioData, - ConversationData, - DocumentData, - HTMLData, - ImageData, - TextData, -) -from labelbox.data.serialization import 
NDJsonConverter -from labelbox.data.annotation_types.data.video import VideoData - -import labelbox.types as lb_types -from labelbox.schema.media_type import MediaType - -# Unit test for label based on data type. -# TODO: Dicom removed it is unstable when you deserialize and serialize on label import. If we intend to keep this library this needs add generic data types tests work with this data type. -# TODO: add MediaType.LLMPromptResponseCreation(data gen) once supported and llm human preference once media type is added - - -@pytest.mark.parametrize( - "media_type, data_type_class", - [ - (MediaType.Audio, AudioData), - (MediaType.Html, HTMLData), - (MediaType.Image, ImageData), - (MediaType.Text, TextData), - (MediaType.Video, VideoData), - (MediaType.Conversational, ConversationData), - (MediaType.Document, DocumentData), - ], -) -def test_data_row_type_by_data_row_id( - media_type, - data_type_class, - annotations_by_media_type, - hardcoded_datarow_id, -): - annotations_ndjson = annotations_by_media_type[media_type] - annotations_ndjson = [annotation[0] for annotation in annotations_ndjson] - - label = list(NDJsonConverter.deserialize(annotations_ndjson))[0] - - data_label = lb_types.Label( - data=data_type_class(uid=hardcoded_datarow_id()), - annotations=label.annotations, - ) - - assert data_label.data.uid == label.data.uid - assert label.annotations == data_label.annotations - - -@pytest.mark.parametrize( - "media_type, data_type_class", - [ - (MediaType.Audio, AudioData), - (MediaType.Html, HTMLData), - (MediaType.Image, ImageData), - (MediaType.Text, TextData), - (MediaType.Video, VideoData), - (MediaType.Conversational, ConversationData), - (MediaType.Document, DocumentData), - ], -) -def test_data_row_type_by_global_key( - media_type, - data_type_class, - annotations_by_media_type, - hardcoded_global_key, -): - annotations_ndjson = annotations_by_media_type[media_type] - annotations_ndjson = [annotation[0] for annotation in annotations_ndjson] - - label 
= list(NDJsonConverter.deserialize(annotations_ndjson))[0] - - data_label = lb_types.Label( - data=data_type_class(global_key=hardcoded_global_key()), - annotations=label.annotations, - ) - - assert data_label.data.global_key == label.data.global_key - assert label.annotations == data_label.annotations diff --git a/libs/labelbox/tests/data/annotation_import/test_generic_data_types.py b/libs/labelbox/tests/data/annotation_import/test_generic_data_types.py index 1cc5538d9..921e98c9d 100644 --- a/libs/labelbox/tests/data/annotation_import/test_generic_data_types.py +++ b/libs/labelbox/tests/data/annotation_import/test_generic_data_types.py @@ -29,78 +29,6 @@ def validate_iso_format(date_string: str): assert parsed_t.second is not None -@pytest.mark.parametrize( - "media_type, data_type_class", - [ - (MediaType.Audio, GenericDataRowData), - (MediaType.Html, GenericDataRowData), - (MediaType.Image, GenericDataRowData), - (MediaType.Text, GenericDataRowData), - (MediaType.Video, GenericDataRowData), - (MediaType.Conversational, GenericDataRowData), - (MediaType.Document, GenericDataRowData), - (MediaType.LLMPromptResponseCreation, GenericDataRowData), - (MediaType.LLMPromptCreation, GenericDataRowData), - (OntologyKind.ResponseCreation, GenericDataRowData), - (OntologyKind.ModelEvaluation, GenericDataRowData), - ], -) -def test_generic_data_row_type_by_data_row_id( - media_type, - data_type_class, - annotations_by_media_type, - hardcoded_datarow_id, -): - annotations_ndjson = annotations_by_media_type[media_type] - annotations_ndjson = [annotation[0] for annotation in annotations_ndjson] - - label = list(NDJsonConverter.deserialize(annotations_ndjson))[0] - - data_label = Label( - data=data_type_class(uid=hardcoded_datarow_id()), - annotations=label.annotations, - ) - - assert data_label.data.uid == label.data.uid - assert label.annotations == data_label.annotations - - -@pytest.mark.parametrize( - "media_type, data_type_class", - [ - (MediaType.Audio, 
GenericDataRowData), - (MediaType.Html, GenericDataRowData), - (MediaType.Image, GenericDataRowData), - (MediaType.Text, GenericDataRowData), - (MediaType.Video, GenericDataRowData), - (MediaType.Conversational, GenericDataRowData), - (MediaType.Document, GenericDataRowData), - # (MediaType.LLMPromptResponseCreation, GenericDataRowData), - # (MediaType.LLMPromptCreation, GenericDataRowData), - (OntologyKind.ResponseCreation, GenericDataRowData), - (OntologyKind.ModelEvaluation, GenericDataRowData), - ], -) -def test_generic_data_row_type_by_global_key( - media_type, - data_type_class, - annotations_by_media_type, - hardcoded_global_key, -): - annotations_ndjson = annotations_by_media_type[media_type] - annotations_ndjson = [annotation[0] for annotation in annotations_ndjson] - - label = list(NDJsonConverter.deserialize(annotations_ndjson))[0] - - data_label = Label( - data=data_type_class(global_key=hardcoded_global_key()), - annotations=label.annotations, - ) - - assert data_label.data.global_key == label.data.global_key - assert label.annotations == data_label.annotations - - @pytest.mark.parametrize( "configured_project, media_type", [ diff --git a/libs/labelbox/tests/data/annotation_import/test_mea_prediction_import.py b/libs/labelbox/tests/data/annotation_import/test_mea_prediction_import.py index fccca2a3f..5f47975ad 100644 --- a/libs/labelbox/tests/data/annotation_import/test_mea_prediction_import.py +++ b/libs/labelbox/tests/data/annotation_import/test_mea_prediction_import.py @@ -1,5 +1,19 @@ import uuid from labelbox import parser +from labelbox.data.annotation_types.annotation import ObjectAnnotation +from labelbox.data.annotation_types.classification.classification import ( + ClassificationAnnotation, + ClassificationAnswer, + Radio, +) +from labelbox.data.annotation_types.data.generic_data_row_data import ( + GenericDataRowData, +) +from labelbox.data.annotation_types.geometry.line import Line +from labelbox.data.annotation_types.geometry.point import 
Point +from labelbox.data.annotation_types.geometry.polygon import Polygon +from labelbox.data.annotation_types.geometry.rectangle import Rectangle +from labelbox.data.annotation_types.label import Label import pytest from labelbox import ModelRun @@ -193,14 +207,60 @@ def test_create_from_label_objects( annotation_import_test_helpers, ): name = str(uuid.uuid4()) - use_data_row_ids = [ + use_data_row_id = [ p["dataRow"]["id"] for p in object_predictions_for_annotation_import ] - model_run_with_data_rows.upsert_data_rows(use_data_row_ids) - predictions = list( - NDJsonConverter.deserialize(object_predictions_for_annotation_import) - ) + model_run_with_data_rows.upsert_data_rows(use_data_row_id) + + predictions = [] + for data_row_id in use_data_row_id: + predictions.append( + Label( + data=GenericDataRowData( + uid=data_row_id, + ), + annotations=[ + ObjectAnnotation( + name="polygon", + extra={ + "uuid": "6d10fa30-3ea0-4e6c-bbb1-63f5c29fe3e4", + }, + value=Polygon( + points=[ + Point(x=147.692, y=118.154), + Point(x=142.769, y=104.923), + Point(x=57.846, y=118.769), + Point(x=28.308, y=169.846), + Point(x=147.692, y=118.154), + ], + ), + ), + ObjectAnnotation( + name="bbox", + extra={ + "uuid": "15b7138f-4bbc-42c5-ae79-45d87b0a3b2a", + }, + value=Rectangle( + start=Point(x=58.0, y=48.0), + end=Point(x=70.0, y=113.0), + ), + ), + ObjectAnnotation( + name="polyline", + extra={ + "uuid": "cf4c6df9-c39c-4fbc-9541-470f6622978a", + }, + value=Line( + points=[ + Point(x=147.692, y=118.154), + Point(x=150.692, y=160.154), + ], + ), + ), + ], + ), + ) annotation_import = model_run_with_data_rows.add_predictions( name=name, predictions=predictions diff --git a/libs/labelbox/tests/data/annotation_import/test_model.py b/libs/labelbox/tests/data/annotation_import/test_model.py index dcfe9ef2c..56b2e07c4 100644 --- a/libs/labelbox/tests/data/annotation_import/test_model.py +++ b/libs/labelbox/tests/data/annotation_import/test_model.py @@ -1,7 +1,7 @@ import pytest +from 
lbox.exceptions import ResourceNotFoundError from labelbox import Model -from labelbox.exceptions import ResourceNotFoundError def test_model(client, configured_project, rand_gen): diff --git a/libs/labelbox/tests/data/annotation_import/test_ndjson_validation.py b/libs/labelbox/tests/data/annotation_import/test_ndjson_validation.py deleted file mode 100644 index a0df559fc..000000000 --- a/libs/labelbox/tests/data/annotation_import/test_ndjson_validation.py +++ /dev/null @@ -1,230 +0,0 @@ -from labelbox.schema.media_type import MediaType -from labelbox.schema.project import Project -import pytest - -from labelbox import parser -from pytest_cases import parametrize, fixture_ref - -from labelbox.exceptions import MALValidationError -from labelbox.schema.bulk_import_request import ( - NDChecklist, - NDClassification, - NDMask, - NDPolygon, - NDPolyline, - NDRadio, - NDRectangle, - NDText, - NDTextEntity, - NDTool, - _validate_ndjson, -) - -""" -- These NDlabels are apart of bulkImportReqeust and should be removed once bulk import request is removed -""" - - -def test_classification_construction(checklist_inference, text_inference): - checklist = NDClassification.build(checklist_inference[0]) - assert isinstance(checklist, NDChecklist) - text = NDClassification.build(text_inference[0]) - assert isinstance(text, NDText) - - -@parametrize( - "inference, expected_type", - [ - (fixture_ref("polygon_inference"), NDPolygon), - (fixture_ref("rectangle_inference"), NDRectangle), - (fixture_ref("line_inference"), NDPolyline), - (fixture_ref("entity_inference"), NDTextEntity), - (fixture_ref("segmentation_inference"), NDMask), - (fixture_ref("segmentation_inference_rle"), NDMask), - (fixture_ref("segmentation_inference_png"), NDMask), - ], -) -def test_tool_construction(inference, expected_type): - assert isinstance(NDTool.build(inference[0]), expected_type) - - -def no_tool(text_inference, module_project): - pred = text_inference[0].copy() - # Missing key - del pred["answer"] - 
with pytest.raises(MALValidationError): - _validate_ndjson([pred], module_project) - - -@pytest.mark.parametrize("configured_project", [MediaType.Text], indirect=True) -def test_invalid_text(text_inference, configured_project): - # and if it is not a string - pred = text_inference[0].copy() - # Extra and wrong key - del pred["answer"] - pred["answers"] = [] - with pytest.raises(MALValidationError): - _validate_ndjson([pred], configured_project) - del pred["answers"] - - # Invalid type - pred["answer"] = [] - with pytest.raises(MALValidationError): - _validate_ndjson([pred], configured_project) - - # Invalid type - pred["answer"] = None - with pytest.raises(MALValidationError): - _validate_ndjson([pred], configured_project) - - -def test_invalid_checklist_item(checklist_inference, module_project): - # Only two points - pred = checklist_inference[0].copy() - pred["answers"] = [pred["answers"][0], pred["answers"][0]] - # Duplicate schema ids - with pytest.raises(MALValidationError): - _validate_ndjson([pred], module_project) - - pred["answers"] = [{"name": "asdfg"}] - with pytest.raises(MALValidationError): - _validate_ndjson([pred], module_project) - - pred["answers"] = [{"schemaId": "1232132132"}] - with pytest.raises(MALValidationError): - _validate_ndjson([pred], module_project) - - pred["answers"] = [{}] - with pytest.raises(MALValidationError): - _validate_ndjson([pred], module_project) - - pred["answers"] = [] - with pytest.raises(MALValidationError): - _validate_ndjson([pred], module_project) - - del pred["answers"] - with pytest.raises(MALValidationError): - _validate_ndjson([pred], module_project) - - -def test_invalid_polygon(polygon_inference, module_project): - # Only two points - pred = polygon_inference[0].copy() - pred["polygon"] = [{"x": 100, "y": 100}, {"x": 200, "y": 200}] - with pytest.raises(MALValidationError): - _validate_ndjson([pred], module_project) - - -@pytest.mark.parametrize("configured_project", [MediaType.Text], indirect=True) -def 
test_incorrect_entity(entity_inference, configured_project): - entity = entity_inference[0].copy() - # Location cannot be a list - entity["location"] = [0, 10] - with pytest.raises(MALValidationError): - _validate_ndjson([entity], configured_project) - - entity["location"] = {"start": -1, "end": 5} - with pytest.raises(MALValidationError): - _validate_ndjson([entity], configured_project) - - entity["location"] = {"start": 15, "end": 5} - with pytest.raises(MALValidationError): - _validate_ndjson([entity], configured_project) - - -@pytest.mark.skip( - "Test wont work/fails randomly since projects have to have a media type and could be missing features from prediction list" -) -def test_all_validate_json(module_project, predictions): - # Predictions contains one of each type of prediction. - # These should be properly formatted and pass. - _validate_ndjson(predictions[0], module_project) - - -def test_incorrect_line(line_inference, module_project): - line = line_inference[0].copy() - line["line"] = [line["line"][0]] # Just one point - with pytest.raises(MALValidationError): - _validate_ndjson([line], module_project) - - -def test_incorrect_rectangle(rectangle_inference, module_project): - del rectangle_inference[0]["bbox"]["top"] - with pytest.raises(MALValidationError): - _validate_ndjson([rectangle_inference], module_project) - - -def test_duplicate_tools(rectangle_inference, module_project): - pred = rectangle_inference[0].copy() - pred["polygon"] = [{"x": 100, "y": 100}, {"x": 200, "y": 200}] - with pytest.raises(MALValidationError): - _validate_ndjson([pred], module_project) - - -def test_invalid_feature_schema(module_project, rectangle_inference): - pred = rectangle_inference[0].copy() - pred["schemaId"] = "blahblah" - with pytest.raises(MALValidationError): - _validate_ndjson([pred], module_project) - - -def test_name_only_feature_schema(module_project, rectangle_inference): - pred = rectangle_inference[0].copy() - _validate_ndjson([pred], module_project) - - 
-def test_schema_id_only_feature_schema(module_project, rectangle_inference): - pred = rectangle_inference[0].copy() - del pred["name"] - ontology = module_project.ontology().normalized["tools"] - for tool in ontology: - if tool["name"] == "bbox": - feature_schema_id = tool["featureSchemaId"] - pred["schemaId"] = feature_schema_id - _validate_ndjson([pred], module_project) - - -def test_missing_feature_schema(module_project, rectangle_inference): - pred = rectangle_inference[0].copy() - del pred["name"] - with pytest.raises(MALValidationError): - _validate_ndjson([pred], module_project) - - -def test_validate_ndjson(tmp_path, configured_project): - file_name = f"broken.ndjson" - file_path = tmp_path / file_name - with file_path.open("w") as f: - f.write("test") - - with pytest.raises(ValueError): - configured_project.upload_annotations( - name="name", annotations=str(file_path), validate=True - ) - - -def test_validate_ndjson_uuid(tmp_path, configured_project, predictions): - file_name = f"repeat_uuid.ndjson" - file_path = tmp_path / file_name - repeat_uuid = predictions.copy() - repeat_uuid[0]["uuid"] = "test_uuid" - repeat_uuid[1]["uuid"] = "test_uuid" - - with file_path.open("w") as f: - parser.dump(repeat_uuid, f) - - with pytest.raises(MALValidationError): - configured_project.upload_annotations( - name="name", validate=True, annotations=str(file_path) - ) - - with pytest.raises(MALValidationError): - configured_project.upload_annotations( - name="name", validate=True, annotations=repeat_uuid - ) - - -@pytest.mark.parametrize("configured_project", [MediaType.Video], indirect=True) -def test_video_upload(video_checklist_inference, configured_project): - pred = video_checklist_inference[0].copy() - _validate_ndjson([pred], configured_project) diff --git a/libs/labelbox/tests/data/annotation_types/data/test_raster.py b/libs/labelbox/tests/data/annotation_types/data/test_raster.py index 6bc8f2bbf..209419aed 100644 --- 
a/libs/labelbox/tests/data/annotation_types/data/test_raster.py +++ b/libs/labelbox/tests/data/annotation_types/data/test_raster.py @@ -5,34 +5,28 @@ import pytest from PIL import Image -from labelbox.data.annotation_types.data import ImageData +from labelbox.data.annotation_types.data import GenericDataRowData, MaskData from pydantic import ValidationError def test_validate_schema(): with pytest.raises(ValidationError): - data = ImageData() + MaskData() def test_im_bytes(): data = (np.random.random((32, 32, 3)) * 255).astype(np.uint8) im_bytes = BytesIO() Image.fromarray(data).save(im_bytes, format="PNG") - raster_data = ImageData(im_bytes=im_bytes.getvalue()) + raster_data = MaskData(im_bytes=im_bytes.getvalue()) data_ = raster_data.value assert np.all(data == data_) def test_im_url(): - raster_data = ImageData(url="https://picsum.photos/id/829/200/300") - data_ = raster_data.value - assert data_.shape == (300, 200, 3) - - -def test_im_path(): - img_path = "/tmp/img.jpg" - urllib.request.urlretrieve("https://picsum.photos/id/829/200/300", img_path) - raster_data = ImageData(file_path=img_path) + raster_data = MaskData( + uid="test", url="https://picsum.photos/id/829/200/300" + ) data_ = raster_data.value assert data_.shape == (300, 200, 3) @@ -42,14 +36,11 @@ def test_ref(): uid = "uid" metadata = [] media_attributes = {} - data = ImageData( - im_bytes=b"", - external_id=external_id, + data = GenericDataRowData( uid=uid, metadata=metadata, media_attributes=media_attributes, ) - assert data.external_id == external_id assert data.uid == uid assert data.media_attributes == media_attributes assert data.metadata == metadata diff --git a/libs/labelbox/tests/data/annotation_types/data/test_text.py b/libs/labelbox/tests/data/annotation_types/data/test_text.py deleted file mode 100644 index 865f93e65..000000000 --- a/libs/labelbox/tests/data/annotation_types/data/test_text.py +++ /dev/null @@ -1,55 +0,0 @@ -import os - -import pytest - -from labelbox.data.annotation_types 
import TextData -from pydantic import ValidationError - - -def test_validate_schema(): - with pytest.raises(ValidationError): - data = TextData() - - -def test_text(): - text = "hello world" - metadata = [] - media_attributes = {} - text_data = TextData( - text=text, metadata=metadata, media_attributes=media_attributes - ) - assert text_data.text == text - - -def test_url(): - url = "https://storage.googleapis.com/lb-artifacts-testing-public/sdk_integration_test/sample3.txt" - text_data = TextData(url=url) - text = text_data.value - assert len(text) == 3541 - - -def test_file(tmpdir): - content = "foo bar baz" - file = "hello.txt" - dir = tmpdir.mkdir("data") - dir.join(file).write(content) - text_data = TextData(file_path=os.path.join(dir.strpath, file)) - assert len(text_data.value) == len(content) - - -def test_ref(): - external_id = "external_id" - uid = "uid" - metadata = [] - media_attributes = {} - data = TextData( - text="hello world", - external_id=external_id, - uid=uid, - metadata=metadata, - media_attributes=media_attributes, - ) - assert data.external_id == external_id - assert data.uid == uid - assert data.media_attributes == media_attributes - assert data.metadata == metadata diff --git a/libs/labelbox/tests/data/annotation_types/data/test_video.py b/libs/labelbox/tests/data/annotation_types/data/test_video.py deleted file mode 100644 index 5fd77c2c8..000000000 --- a/libs/labelbox/tests/data/annotation_types/data/test_video.py +++ /dev/null @@ -1,73 +0,0 @@ -import numpy as np -import pytest - -from labelbox.data.annotation_types import VideoData -from pydantic import ValidationError - - -def test_validate_schema(): - with pytest.raises(ValidationError): - data = VideoData() - - -def test_frames(): - data = { - x: (np.random.random((32, 32, 3)) * 255).astype(np.uint8) - for x in range(5) - } - video_data = VideoData(frames=data) - for idx, frame in video_data.frame_generator(): - assert idx in data - assert np.all(frame == data[idx]) - - -def 
test_file_path(): - path = "tests/integration/media/cat.mp4" - raster_data = VideoData(file_path=path) - - with pytest.raises(ValueError): - raster_data[0] - - raster_data.load_frames() - raster_data[0] - - frame_indices = list(raster_data.frames.keys()) - # 29 frames - assert set(frame_indices) == set(list(range(28))) - - -def test_file_url(): - url = "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/ForBiggerMeltdowns.mp4" - raster_data = VideoData(url=url) - - with pytest.raises(ValueError): - raster_data[0] - - raster_data.load_frames() - raster_data[0] - - frame_indices = list(raster_data.frames.keys()) - # 362 frames - assert set(frame_indices) == set(list(range(361))) - - -def test_ref(): - external_id = "external_id" - uid = "uid" - data = { - x: (np.random.random((32, 32, 3)) * 255).astype(np.uint8) - for x in range(5) - } - metadata = [] - media_attributes = {} - data = VideoData( - frames=data, - external_id=external_id, - uid=uid, - metadata=metadata, - media_attributes=media_attributes, - ) - assert data.external_id == external_id - assert data.uid == uid - assert data.media_attributes == media_attributes - assert data.metadata == metadata diff --git a/libs/labelbox/tests/data/annotation_types/test_annotation.py b/libs/labelbox/tests/data/annotation_types/test_annotation.py index 8cdeac9ba..01547bd56 100644 --- a/libs/labelbox/tests/data/annotation_types/test_annotation.py +++ b/libs/labelbox/tests/data/annotation_types/test_annotation.py @@ -1,18 +1,20 @@ import pytest +from lbox.exceptions import ConfidenceNotSupportedException +from pydantic import ValidationError from labelbox.data.annotation_types import ( - Text, - Point, - Line, ClassificationAnnotation, + Line, ObjectAnnotation, + Point, + Text, TextEntity, ) -from labelbox.data.annotation_types.video import VideoObjectAnnotation from labelbox.data.annotation_types.geometry.rectangle import Rectangle -from labelbox.data.annotation_types.video import VideoClassificationAnnotation 
-from labelbox.exceptions import ConfidenceNotSupportedException -from pydantic import ValidationError +from labelbox.data.annotation_types.video import ( + VideoClassificationAnnotation, + VideoObjectAnnotation, +) def test_annotation(): diff --git a/libs/labelbox/tests/data/annotation_types/test_collection.py b/libs/labelbox/tests/data/annotation_types/test_collection.py index 9deddc3c8..e0fa7bd53 100644 --- a/libs/labelbox/tests/data/annotation_types/test_collection.py +++ b/libs/labelbox/tests/data/annotation_types/test_collection.py @@ -7,19 +7,21 @@ from labelbox.data.annotation_types import ( LabelGenerator, ObjectAnnotation, - ImageData, - MaskData, Line, Mask, Point, Label, + GenericDataRowData, + MaskData, ) from labelbox import OntologyBuilder, Tool @pytest.fixture def list_of_labels(): - return [Label(data=ImageData(url="http://someurl")) for _ in range(5)] + return [ + Label(data=GenericDataRowData(uid="http://someurl")) for _ in range(5) + ] @pytest.fixture @@ -70,58 +72,9 @@ def test_conversion(list_of_labels): assert [x for x in label_collection] == list_of_labels -def test_adding_schema_ids(): - name = "line_feature" - label = Label( - data=ImageData(arr=np.ones((32, 32, 3), dtype=np.uint8)), - annotations=[ - ObjectAnnotation( - value=Line(points=[Point(x=1, y=2), Point(x=2, y=2)]), - name=name, - ) - ], - ) - feature_schema_id = "expected_id" - ontology = OntologyBuilder( - tools=[ - Tool(Tool.Type.LINE, name=name, feature_schema_id=feature_schema_id) - ] - ) - generator = LabelGenerator([label]).assign_feature_schema_ids(ontology) - assert next(generator).annotations[0].feature_schema_id == feature_schema_id - - -def test_adding_urls(signer): - label = Label( - data=ImageData(arr=np.random.random((32, 32, 3)).astype(np.uint8)), - annotations=[], - ) - uuid = str(uuid4()) - generator = LabelGenerator([label]).add_url_to_data(signer(uuid)) - assert label.data.url != uuid - assert next(generator).data.url == uuid - assert label.data.url == uuid - - 
-def test_adding_to_dataset(signer): - dataset = FakeDataset() - label = Label( - data=ImageData(arr=np.random.random((32, 32, 3)).astype(np.uint8)), - annotations=[], - ) - uuid = str(uuid4()) - generator = LabelGenerator([label]).add_to_dataset(dataset, signer(uuid)) - assert label.data.url != uuid - generated_label = next(generator) - assert generated_label.data.url == uuid - assert generated_label.data.external_id != None - assert generated_label.data.uid == dataset.uid - assert label.data.url == uuid - - def test_adding_to_masks(signer): label = Label( - data=ImageData(arr=np.random.random((32, 32, 3)).astype(np.uint8)), + data=GenericDataRowData(uid="12345"), annotations=[ ObjectAnnotation( name="1234", diff --git a/libs/labelbox/tests/data/annotation_types/test_label.py b/libs/labelbox/tests/data/annotation_types/test_label.py index 5bdfb6bde..9cd992b0c 100644 --- a/libs/labelbox/tests/data/annotation_types/test_label.py +++ b/libs/labelbox/tests/data/annotation_types/test_label.py @@ -17,220 +17,16 @@ ObjectAnnotation, Point, Line, - ImageData, + MaskData, Label, ) import pytest -def test_schema_assignment_geometry(): - name = "line_feature" - label = Label( - data=ImageData(arr=np.ones((32, 32, 3), dtype=np.uint8)), - annotations=[ - ObjectAnnotation( - value=Line(points=[Point(x=1, y=2), Point(x=2, y=2)]), - name=name, - ) - ], - ) - feature_schema_id = "expected_id" - ontology = OntologyBuilder( - tools=[ - Tool(Tool.Type.LINE, name=name, feature_schema_id=feature_schema_id) - ] - ) - label.assign_feature_schema_ids(ontology) - - assert label.annotations[0].feature_schema_id == feature_schema_id - - -def test_schema_assignment_classification(): - radio_name = "radio_name" - text_name = "text_name" - option_name = "my_option" - - label = Label( - data=ImageData(arr=np.ones((32, 32, 3), dtype=np.uint8)), - annotations=[ - ClassificationAnnotation( - value=Radio(answer=ClassificationAnswer(name=option_name)), - name=radio_name, - ), - 
ClassificationAnnotation( - value=Text(answer="some text"), name=text_name - ), - ], - ) - radio_schema_id = "radio_schema_id" - text_schema_id = "text_schema_id" - option_schema_id = "option_schema_id" - ontology = OntologyBuilder( - tools=[], - classifications=[ - OClassification( - class_type=OClassification.Type.RADIO, - name=radio_name, - feature_schema_id=radio_schema_id, - options=[ - Option( - value=option_name, feature_schema_id=option_schema_id - ) - ], - ), - OClassification( - class_type=OClassification.Type.TEXT, - name=text_name, - feature_schema_id=text_schema_id, - ), - ], - ) - label.assign_feature_schema_ids(ontology) - assert label.annotations[0].feature_schema_id == radio_schema_id - assert label.annotations[1].feature_schema_id == text_schema_id - assert ( - label.annotations[0].value.answer.feature_schema_id == option_schema_id - ) - - -def test_schema_assignment_subclass(): - name = "line_feature" - radio_name = "radio_name" - option_name = "my_option" - classification = ClassificationAnnotation( - name=radio_name, - value=Radio(answer=ClassificationAnswer(name=option_name)), - ) - label = Label( - data=ImageData(arr=np.ones((32, 32, 3), dtype=np.uint8)), - annotations=[ - ObjectAnnotation( - value=Line(points=[Point(x=1, y=2), Point(x=2, y=2)]), - name=name, - classifications=[classification], - ) - ], - ) - feature_schema_id = "expected_id" - classification_schema_id = "classification_id" - option_schema_id = "option_schema_id" - ontology = OntologyBuilder( - tools=[ - Tool( - Tool.Type.LINE, - name=name, - feature_schema_id=feature_schema_id, - classifications=[ - OClassification( - class_type=OClassification.Type.RADIO, - name=radio_name, - feature_schema_id=classification_schema_id, - options=[ - Option( - value=option_name, - feature_schema_id=option_schema_id, - ) - ], - ) - ], - ) - ] - ) - label.assign_feature_schema_ids(ontology) - assert label.annotations[0].feature_schema_id == feature_schema_id - assert ( - 
label.annotations[0].classifications[0].feature_schema_id - == classification_schema_id - ) - assert ( - label.annotations[0].classifications[0].value.answer.feature_schema_id - == option_schema_id - ) - - -def test_highly_nested(): - name = "line_feature" - radio_name = "radio_name" - nested_name = "nested_name" - option_name = "my_option" - nested_option_name = "nested_option_name" - classification = ClassificationAnnotation( - name=radio_name, - value=Radio(answer=ClassificationAnswer(name=option_name)), - classifications=[ - ClassificationAnnotation( - value=Radio( - answer=ClassificationAnswer(name=nested_option_name) - ), - name=nested_name, - ) - ], - ) - label = Label( - data=ImageData(arr=np.ones((32, 32, 3), dtype=np.uint8)), - annotations=[ - ObjectAnnotation( - value=Line(points=[Point(x=1, y=2), Point(x=2, y=2)]), - name=name, - classifications=[classification], - ) - ], - ) - feature_schema_id = "expected_id" - classification_schema_id = "classification_id" - nested_classification_schema_id = "nested_classification_schema_id" - option_schema_id = "option_schema_id" - ontology = OntologyBuilder( - tools=[ - Tool( - Tool.Type.LINE, - name=name, - feature_schema_id=feature_schema_id, - classifications=[ - OClassification( - class_type=OClassification.Type.RADIO, - name=radio_name, - feature_schema_id=classification_schema_id, - options=[ - Option( - value=option_name, - feature_schema_id=option_schema_id, - options=[ - OClassification( - class_type=OClassification.Type.RADIO, - name=nested_name, - feature_schema_id=nested_classification_schema_id, - options=[ - Option( - value=nested_option_name, - feature_schema_id=nested_classification_schema_id, - ) - ], - ) - ], - ) - ], - ) - ], - ) - ] - ) - label.assign_feature_schema_ids(ontology) - assert label.annotations[0].feature_schema_id == feature_schema_id - assert ( - label.annotations[0].classifications[0].feature_schema_id - == classification_schema_id - ) - assert ( - 
label.annotations[0].classifications[0].value.answer.feature_schema_id - == option_schema_id - ) - - def test_schema_assignment_confidence(): name = "line_feature" label = Label( - data=ImageData(arr=np.ones((32, 32, 3), dtype=np.uint8)), + data=MaskData(arr=np.ones((32, 32, 3), dtype=np.uint8), uid="test"), annotations=[ ObjectAnnotation( value=Line( @@ -252,10 +48,10 @@ def test_initialize_label_no_coercion(): value=lb_types.ConversationEntity(start=0, end=8, message_id="4"), ) label = Label( - data=lb_types.ConversationData(global_key=global_key), + data=lb_types.GenericDataRowData(global_key=global_key), annotations=[ner_annotation], ) - assert isinstance(label.data, lb_types.ConversationData) + assert isinstance(label.data, lb_types.GenericDataRowData) assert label.data.global_key == global_key diff --git a/libs/labelbox/tests/data/annotation_types/test_metrics.py b/libs/labelbox/tests/data/annotation_types/test_metrics.py index 94c9521a5..4e9355573 100644 --- a/libs/labelbox/tests/data/annotation_types/test_metrics.py +++ b/libs/labelbox/tests/data/annotation_types/test_metrics.py @@ -8,7 +8,11 @@ ConfusionMatrixMetric, ScalarMetric, ) -from labelbox.data.annotation_types import ScalarMetric, Label, ImageData +from labelbox.data.annotation_types import ( + ScalarMetric, + Label, + GenericDataRowData, +) from labelbox.data.annotation_types.metrics.scalar import RESERVED_METRIC_NAMES from pydantic import ValidationError @@ -19,7 +23,8 @@ def test_legacy_scalar_metric(): assert metric.value == value label = Label( - data=ImageData(uid="ckrmd9q8g000009mg6vej7hzg"), annotations=[metric] + data=GenericDataRowData(uid="ckrmd9q8g000009mg6vej7hzg"), + annotations=[metric], ) expected = { "data": { @@ -72,7 +77,8 @@ def test_custom_scalar_metric(feature_name, subclass_name, aggregation, value): assert metric.value == value label = Label( - data=ImageData(uid="ckrmd9q8g000009mg6vej7hzg"), annotations=[metric] + data=GenericDataRowData(uid="ckrmd9q8g000009mg6vej7hzg"), + 
annotations=[metric], ) expected = { "data": { @@ -134,7 +140,8 @@ def test_custom_confusison_matrix_metric( assert metric.value == value label = Label( - data=ImageData(uid="ckrmd9q8g000009mg6vej7hzg"), annotations=[metric] + data=GenericDataRowData(uid="ckrmd9q8g000009mg6vej7hzg"), + annotations=[metric], ) expected = { "data": { diff --git a/libs/labelbox/tests/data/assets/ndjson/classification_import_global_key.json b/libs/labelbox/tests/data/assets/ndjson/classification_import_global_key.json deleted file mode 100644 index 4de15e217..000000000 --- a/libs/labelbox/tests/data/assets/ndjson/classification_import_global_key.json +++ /dev/null @@ -1,54 +0,0 @@ -[ - { - "answer": { - "schemaId": "ckrb1sfl8099g0y91cxbd5ftb", - "confidence": 0.8, - "customMetrics": [ - { - "name": "customMetric1", - "value": 0.5 - }, - { - "name": "customMetric2", - "value": 0.3 - } - ] - }, - "schemaId": "ckrb1sfl8099g0y91cxbd5ftb", - "dataRow": { - "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d" - }, - "uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673" - }, - { - "answer": [ - { - "schemaId": "ckrb1sfl8099e0y919v260awv", - "confidence": 0.82, - "customMetrics": [ - { - "name": "customMetric1", - "value": 0.5 - }, - { - "name": "customMetric2", - "value": 0.3 - } - ] - } - ], - "schemaId": "ckrb1sfkn099c0y910wbo0p1a", - "dataRow": { - "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d" - }, - "uuid": "d009925d-91a3-4f67-abd9-753453f5a584" - }, - { - "answer": "a value", - "schemaId": "ckrb1sfkn099c0y910wbo0p1a", - "dataRow": { - "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d" - }, - "uuid": "ee70fd88-9f88-48dd-b760-7469ff479b71" - } -] \ No newline at end of file diff --git a/libs/labelbox/tests/data/assets/ndjson/conversation_entity_import_global_key.json b/libs/labelbox/tests/data/assets/ndjson/conversation_entity_import_global_key.json deleted file mode 100644 index 83a95e5bf..000000000 --- a/libs/labelbox/tests/data/assets/ndjson/conversation_entity_import_global_key.json +++ 
/dev/null @@ -1,25 +0,0 @@ -[{ - "location": { - "start": 67, - "end": 128 - }, - "messageId": "some-message-id", - "uuid": "5ad9c52f-058d-49c8-a749-3f20b84f8cd4", - "dataRow": { - "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d" - }, - "name": "some-text-entity", - "schemaId": "cl6xnuwt95lqq07330tbb3mfd", - "classifications": [], - "confidence": 0.53, - "customMetrics": [ - { - "name": "customMetric1", - "value": 0.5 - }, - { - "name": "customMetric2", - "value": 0.3 - } - ] -}] diff --git a/libs/labelbox/tests/data/assets/ndjson/image_import.json b/libs/labelbox/tests/data/assets/ndjson/image_import.json index 91563b8ae..75fe36e44 100644 --- a/libs/labelbox/tests/data/assets/ndjson/image_import.json +++ b/libs/labelbox/tests/data/assets/ndjson/image_import.json @@ -8,16 +8,17 @@ "confidence": 0.851, "customMetrics": [ { - "name": "customMetric1", - "value": 0.4 + "name": "customMetric1", + "value": 0.4 } ], "bbox": { - "top": 1352, - "left": 2275, - "height": 350, - "width": 139 - } + "top": 1352.0, + "left": 2275.0, + "height": 350.0, + "width": 139.0 + }, + "classifications": [] }, { "uuid": "751fc725-f7b6-48ed-89b0-dd7d94d08af6", @@ -28,20 +29,17 @@ "confidence": 0.834, "customMetrics": [ { - "name": "customMetric1", - "value": 0.3 + "name": "customMetric1", + "value": 0.3 } ], "mask": { - "instanceURI": "https://storage.labelbox.com/ckqcx1czn06830y61gh9v02cs%2F3e729327-f038-f66c-186e-45e921ef9717-1?Expires=1626806874672&KeyName=labelbox-assets-key-3&Signature=YsUOGKrsqmAZ68vT9BlPJOaRyLY", - "colorRGB": [ - 255, - 0, - 0 - ] - } + "instanceURI": "https://storage.labelbox.com/ckqcx1czn06830y61gh9v02cs%2F3e729327-f038-f66c-186e-45e921ef9717-1?Expires=1626806874672&KeyName=labelbox-assets-key-3&Signature=YsUOGKrsqmAZ68vT9BlPJOaRyLY" + }, + "classifications": [] }, { + "classifications": [], "uuid": "43d719ac-5d7f-4aea-be00-2ebfca0900fd", "schemaId": "ckrazcuec16oi0z66dzrd8pfl", "dataRow": { @@ -50,762 +48,39 @@ "confidence": 0.986, "customMetrics": [ { - 
"name": "customMetric1", - "value": 0.9 + "name": "customMetric1", + "value": 0.9 } ], "polygon": [ { - "x": 1118, - "y": 935 - }, - { - "x": 1117, - "y": 935 - }, - { - "x": 1116, - "y": 935 - }, - { - "x": 1115, - "y": 935 - }, - { - "x": 1114, - "y": 935 - }, - { - "x": 1113, - "y": 935 - }, - { - "x": 1112, - "y": 935 - }, - { - "x": 1111, - "y": 935 - }, - { - "x": 1110, - "y": 935 - }, - { - "x": 1109, - "y": 935 - }, - { - "x": 1108, - "y": 935 - }, - { - "x": 1108, - "y": 934 - }, - { - "x": 1107, - "y": 934 - }, - { - "x": 1106, - "y": 934 - }, - { - "x": 1105, - "y": 934 - }, - { - "x": 1105, - "y": 933 - }, - { - "x": 1104, - "y": 933 - }, - { - "x": 1103, - "y": 933 - }, - { - "x": 1103, - "y": 932 - }, - { - "x": 1102, - "y": 932 - }, - { - "x": 1101, - "y": 932 - }, - { - "x": 1100, - "y": 932 - }, - { - "x": 1099, - "y": 932 - }, - { - "x": 1098, - "y": 932 - }, - { - "x": 1097, - "y": 932 - }, - { - "x": 1097, - "y": 931 - }, - { - "x": 1096, - "y": 931 - }, - { - "x": 1095, - "y": 931 - }, - { - "x": 1094, - "y": 931 - }, - { - "x": 1093, - "y": 931 - }, - { - "x": 1092, - "y": 931 - }, - { - "x": 1091, - "y": 931 - }, - { - "x": 1090, - "y": 931 - }, - { - "x": 1090, - "y": 930 - }, - { - "x": 1089, - "y": 930 - }, - { - "x": 1088, - "y": 930 - }, - { - "x": 1087, - "y": 930 - }, - { - "x": 1087, - "y": 929 - }, - { - "x": 1086, - "y": 929 - }, - { - "x": 1085, - "y": 929 - }, - { - "x": 1084, - "y": 929 - }, - { - "x": 1084, - "y": 928 - }, - { - "x": 1083, - "y": 928 - }, - { - "x": 1083, - "y": 927 - }, - { - "x": 1082, - "y": 927 - }, - { - "x": 1081, - "y": 927 - }, - { - "x": 1081, - "y": 926 - }, - { - "x": 1080, - "y": 926 - }, - { - "x": 1080, - "y": 925 - }, - { - "x": 1079, - "y": 925 - }, - { - "x": 1078, - "y": 925 - }, - { - "x": 1078, - "y": 924 - }, - { - "x": 1077, - "y": 924 - }, - { - "x": 1076, - "y": 924 - }, - { - "x": 1076, - "y": 923 - }, - { - "x": 1075, - "y": 923 - }, - { - "x": 1074, - "y": 923 - }, - { - "x": 1073, - 
"y": 923 - }, - { - "x": 1073, - "y": 922 - }, - { - "x": 1072, - "y": 922 - }, - { - "x": 1071, - "y": 922 - }, - { - "x": 1070, - "y": 922 - }, - { - "x": 1070, - "y": 921 - }, - { - "x": 1069, - "y": 921 - }, - { - "x": 1068, - "y": 921 - }, - { - "x": 1067, - "y": 921 - }, - { - "x": 1066, - "y": 921 - }, - { - "x": 1065, - "y": 921 - }, - { - "x": 1064, - "y": 921 - }, - { - "x": 1063, - "y": 921 - }, - { - "x": 1062, - "y": 921 - }, - { - "x": 1061, - "y": 921 - }, - { - "x": 1060, - "y": 921 - }, - { - "x": 1059, - "y": 921 - }, - { - "x": 1058, - "y": 921 - }, - { - "x": 1058, - "y": 920 - }, - { - "x": 1057, - "y": 920 - }, - { - "x": 1057, - "y": 919 - }, - { - "x": 1056, - "y": 919 - }, - { - "x": 1057, - "y": 918 - }, - { - "x": 1057, - "y": 918 - }, - { - "x": 1057, - "y": 917 - }, - { - "x": 1058, - "y": 916 - }, - { - "x": 1058, - "y": 916 - }, - { - "x": 1059, - "y": 915 - }, - { - "x": 1059, - "y": 915 - }, - { - "x": 1060, - "y": 914 - }, - { - "x": 1060, - "y": 914 - }, - { - "x": 1061, - "y": 913 - }, - { - "x": 1061, - "y": 913 - }, - { - "x": 1062, - "y": 912 - }, - { - "x": 1063, - "y": 912 - }, - { - "x": 1063, - "y": 912 - }, - { - "x": 1064, - "y": 911 - }, - { - "x": 1064, - "y": 911 - }, - { - "x": 1065, - "y": 910 - }, - { - "x": 1066, - "y": 910 - }, - { - "x": 1066, - "y": 910 - }, - { - "x": 1067, - "y": 909 - }, - { - "x": 1068, - "y": 909 - }, - { - "x": 1068, - "y": 909 - }, - { - "x": 1069, - "y": 908 - }, - { - "x": 1070, - "y": 908 - }, - { - "x": 1071, - "y": 908 - }, - { - "x": 1072, - "y": 908 - }, - { - "x": 1072, - "y": 908 - }, - { - "x": 1073, - "y": 907 - }, - { - "x": 1074, - "y": 907 - }, - { - "x": 1075, - "y": 907 - }, - { - "x": 1076, - "y": 907 - }, - { - "x": 1077, - "y": 907 - }, - { - "x": 1078, - "y": 907 - }, - { - "x": 1079, - "y": 907 - }, - { - "x": 1080, - "y": 907 - }, - { - "x": 1081, - "y": 907 - }, - { - "x": 1082, - "y": 907 - }, - { - "x": 1083, - "y": 907 - }, - { - "x": 1084, - "y": 907 - }, - { - 
"x": 1085, - "y": 907 - }, - { - "x": 1086, - "y": 907 - }, - { - "x": 1087, - "y": 907 - }, - { - "x": 1088, - "y": 907 - }, - { - "x": 1089, - "y": 907 - }, - { - "x": 1090, - "y": 907 - }, - { - "x": 1091, - "y": 907 - }, - { - "x": 1091, - "y": 908 - }, - { - "x": 1092, - "y": 908 - }, - { - "x": 1093, - "y": 908 - }, - { - "x": 1094, - "y": 908 - }, - { - "x": 1095, - "y": 908 - }, - { - "x": 1095, - "y": 909 - }, - { - "x": 1096, - "y": 909 - }, - { - "x": 1097, - "y": 909 - }, - { - "x": 1097, - "y": 910 - }, - { - "x": 1098, - "y": 910 - }, - { - "x": 1099, - "y": 910 - }, - { - "x": 1099, - "y": 911 - }, - { - "x": 1100, - "y": 911 - }, - { - "x": 1101, - "y": 911 - }, - { - "x": 1101, - "y": 912 - }, - { - "x": 1102, - "y": 912 - }, - { - "x": 1103, - "y": 912 - }, - { - "x": 1103, - "y": 913 - }, - { - "x": 1104, - "y": 913 - }, - { - "x": 1104, - "y": 914 - }, - { - "x": 1105, - "y": 914 - }, - { - "x": 1105, - "y": 915 - }, - { - "x": 1106, - "y": 915 - }, - { - "x": 1107, - "y": 915 - }, - { - "x": 1107, - "y": 916 - }, - { - "x": 1108, - "y": 916 - }, - { - "x": 1108, - "y": 917 - }, - { - "x": 1109, - "y": 917 - }, - { - "x": 1109, - "y": 918 - }, - { - "x": 1110, - "y": 918 - }, - { - "x": 1110, - "y": 919 - }, - { - "x": 1111, - "y": 919 - }, - { - "x": 1111, - "y": 920 - }, - { - "x": 1112, - "y": 920 - }, - { - "x": 1112, - "y": 921 - }, - { - "x": 1113, - "y": 921 - }, - { - "x": 1113, - "y": 922 - }, - { - "x": 1114, - "y": 922 - }, - { - "x": 1114, - "y": 923 - }, - { - "x": 1115, - "y": 923 - }, - { - "x": 1115, - "y": 924 - }, - { - "x": 1115, - "y": 925 - }, - { - "x": 1116, - "y": 925 - }, - { - "x": 1116, - "y": 926 - }, - { - "x": 1117, - "y": 926 - }, - { - "x": 1117, - "y": 927 - }, - { - "x": 1117, - "y": 928 - }, - { - "x": 1118, - "y": 928 - }, - { - "x": 1118, - "y": 929 - }, - { - "x": 1119, - "y": 929 - }, - { - "x": 1119, - "y": 930 - }, - { - "x": 1120, - "y": 930 - }, - { - "x": 1120, - "y": 931 - }, - { - "x": 1120, - "y": 
932 - }, - { - "x": 1120, - "y": 932 - }, - { - "x": 1119, - "y": 933 - }, - { - "x": 1119, - "y": 934 + "x": 10.0, + "y": 20.0 }, { - "x": 1119, - "y": 934 + "x": 15.0, + "y": 20.0 }, { - "x": 1118, - "y": 935 + "x": 20.0, + "y": 25.0 }, { - "x": 1118, - "y": 935 + "x": 10.0, + "y": 20.0 } ] }, { + "classifications": [], "uuid": "b98f3a45-3328-41a0-9077-373a8177ebf2", "schemaId": "ckrazcuec16om0z66bhhh4tp7", "dataRow": { "id": "ckrazctum0z8a0ybc0b0o0g0v" }, "point": { - "x": 2122, - "y": 1457 + "x": 2122.0, + "y": 1457.0 } } ] \ No newline at end of file diff --git a/libs/labelbox/tests/data/assets/ndjson/image_import_global_key.json b/libs/labelbox/tests/data/assets/ndjson/image_import_global_key.json deleted file mode 100644 index 591e40cf6..000000000 --- a/libs/labelbox/tests/data/assets/ndjson/image_import_global_key.json +++ /dev/null @@ -1,823 +0,0 @@ -[ - { - "uuid": "b862c586-8614-483c-b5e6-82810f70cac0", - "schemaId": "ckrazcueb16og0z6609jj7y3y", - "dataRow": { - "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d" - }, - "confidence": 0.851, - "bbox": { - "top": 1352, - "left": 2275, - "height": 350, - "width": 139 - }, - "customMetrics": [ - { - "name": "customMetric1", - "value": 0.5 - }, - { - "name": "customMetric2", - "value": 0.3 - } - ] - }, - { - "uuid": "751fc725-f7b6-48ed-89b0-dd7d94d08af6", - "schemaId": "ckrazcuec16ok0z66f956apb7", - "dataRow": { - "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d" - }, - "confidence": 0.834, - "customMetrics": [ - { - "name": "customMetric1", - "value": 0.5 - }, - { - "name": "customMetric2", - "value": 0.3 - } - ], - "mask": { - "instanceURI": "https://storage.labelbox.com/ckqcx1czn06830y61gh9v02cs%2F3e729327-f038-f66c-186e-45e921ef9717-1?Expires=1626806874672&KeyName=labelbox-assets-key-3&Signature=YsUOGKrsqmAZ68vT9BlPJOaRyLY", - "colorRGB": [ - 255, - 0, - 0 - ] - } - }, - { - "uuid": "43d719ac-5d7f-4aea-be00-2ebfca0900fd", - "schemaId": "ckrazcuec16oi0z66dzrd8pfl", - "dataRow": { - "globalKey": 
"05e8ee85-072e-4eb2-b30a-501dee9b0d9d" - }, - "confidence": 0.986, - "customMetrics": [ - { - "name": "customMetric1", - "value": 0.5 - }, - { - "name": "customMetric2", - "value": 0.3 - } - ], - "polygon": [ - { - "x": 1118, - "y": 935 - }, - { - "x": 1117, - "y": 935 - }, - { - "x": 1116, - "y": 935 - }, - { - "x": 1115, - "y": 935 - }, - { - "x": 1114, - "y": 935 - }, - { - "x": 1113, - "y": 935 - }, - { - "x": 1112, - "y": 935 - }, - { - "x": 1111, - "y": 935 - }, - { - "x": 1110, - "y": 935 - }, - { - "x": 1109, - "y": 935 - }, - { - "x": 1108, - "y": 935 - }, - { - "x": 1108, - "y": 934 - }, - { - "x": 1107, - "y": 934 - }, - { - "x": 1106, - "y": 934 - }, - { - "x": 1105, - "y": 934 - }, - { - "x": 1105, - "y": 933 - }, - { - "x": 1104, - "y": 933 - }, - { - "x": 1103, - "y": 933 - }, - { - "x": 1103, - "y": 932 - }, - { - "x": 1102, - "y": 932 - }, - { - "x": 1101, - "y": 932 - }, - { - "x": 1100, - "y": 932 - }, - { - "x": 1099, - "y": 932 - }, - { - "x": 1098, - "y": 932 - }, - { - "x": 1097, - "y": 932 - }, - { - "x": 1097, - "y": 931 - }, - { - "x": 1096, - "y": 931 - }, - { - "x": 1095, - "y": 931 - }, - { - "x": 1094, - "y": 931 - }, - { - "x": 1093, - "y": 931 - }, - { - "x": 1092, - "y": 931 - }, - { - "x": 1091, - "y": 931 - }, - { - "x": 1090, - "y": 931 - }, - { - "x": 1090, - "y": 930 - }, - { - "x": 1089, - "y": 930 - }, - { - "x": 1088, - "y": 930 - }, - { - "x": 1087, - "y": 930 - }, - { - "x": 1087, - "y": 929 - }, - { - "x": 1086, - "y": 929 - }, - { - "x": 1085, - "y": 929 - }, - { - "x": 1084, - "y": 929 - }, - { - "x": 1084, - "y": 928 - }, - { - "x": 1083, - "y": 928 - }, - { - "x": 1083, - "y": 927 - }, - { - "x": 1082, - "y": 927 - }, - { - "x": 1081, - "y": 927 - }, - { - "x": 1081, - "y": 926 - }, - { - "x": 1080, - "y": 926 - }, - { - "x": 1080, - "y": 925 - }, - { - "x": 1079, - "y": 925 - }, - { - "x": 1078, - "y": 925 - }, - { - "x": 1078, - "y": 924 - }, - { - "x": 1077, - "y": 924 - }, - { - "x": 1076, - "y": 924 - }, - { - 
"x": 1076, - "y": 923 - }, - { - "x": 1075, - "y": 923 - }, - { - "x": 1074, - "y": 923 - }, - { - "x": 1073, - "y": 923 - }, - { - "x": 1073, - "y": 922 - }, - { - "x": 1072, - "y": 922 - }, - { - "x": 1071, - "y": 922 - }, - { - "x": 1070, - "y": 922 - }, - { - "x": 1070, - "y": 921 - }, - { - "x": 1069, - "y": 921 - }, - { - "x": 1068, - "y": 921 - }, - { - "x": 1067, - "y": 921 - }, - { - "x": 1066, - "y": 921 - }, - { - "x": 1065, - "y": 921 - }, - { - "x": 1064, - "y": 921 - }, - { - "x": 1063, - "y": 921 - }, - { - "x": 1062, - "y": 921 - }, - { - "x": 1061, - "y": 921 - }, - { - "x": 1060, - "y": 921 - }, - { - "x": 1059, - "y": 921 - }, - { - "x": 1058, - "y": 921 - }, - { - "x": 1058, - "y": 920 - }, - { - "x": 1057, - "y": 920 - }, - { - "x": 1057, - "y": 919 - }, - { - "x": 1056, - "y": 919 - }, - { - "x": 1057, - "y": 918 - }, - { - "x": 1057, - "y": 918 - }, - { - "x": 1057, - "y": 917 - }, - { - "x": 1058, - "y": 916 - }, - { - "x": 1058, - "y": 916 - }, - { - "x": 1059, - "y": 915 - }, - { - "x": 1059, - "y": 915 - }, - { - "x": 1060, - "y": 914 - }, - { - "x": 1060, - "y": 914 - }, - { - "x": 1061, - "y": 913 - }, - { - "x": 1061, - "y": 913 - }, - { - "x": 1062, - "y": 912 - }, - { - "x": 1063, - "y": 912 - }, - { - "x": 1063, - "y": 912 - }, - { - "x": 1064, - "y": 911 - }, - { - "x": 1064, - "y": 911 - }, - { - "x": 1065, - "y": 910 - }, - { - "x": 1066, - "y": 910 - }, - { - "x": 1066, - "y": 910 - }, - { - "x": 1067, - "y": 909 - }, - { - "x": 1068, - "y": 909 - }, - { - "x": 1068, - "y": 909 - }, - { - "x": 1069, - "y": 908 - }, - { - "x": 1070, - "y": 908 - }, - { - "x": 1071, - "y": 908 - }, - { - "x": 1072, - "y": 908 - }, - { - "x": 1072, - "y": 908 - }, - { - "x": 1073, - "y": 907 - }, - { - "x": 1074, - "y": 907 - }, - { - "x": 1075, - "y": 907 - }, - { - "x": 1076, - "y": 907 - }, - { - "x": 1077, - "y": 907 - }, - { - "x": 1078, - "y": 907 - }, - { - "x": 1079, - "y": 907 - }, - { - "x": 1080, - "y": 907 - }, - { - "x": 1081, - "y": 
907 - }, - { - "x": 1082, - "y": 907 - }, - { - "x": 1083, - "y": 907 - }, - { - "x": 1084, - "y": 907 - }, - { - "x": 1085, - "y": 907 - }, - { - "x": 1086, - "y": 907 - }, - { - "x": 1087, - "y": 907 - }, - { - "x": 1088, - "y": 907 - }, - { - "x": 1089, - "y": 907 - }, - { - "x": 1090, - "y": 907 - }, - { - "x": 1091, - "y": 907 - }, - { - "x": 1091, - "y": 908 - }, - { - "x": 1092, - "y": 908 - }, - { - "x": 1093, - "y": 908 - }, - { - "x": 1094, - "y": 908 - }, - { - "x": 1095, - "y": 908 - }, - { - "x": 1095, - "y": 909 - }, - { - "x": 1096, - "y": 909 - }, - { - "x": 1097, - "y": 909 - }, - { - "x": 1097, - "y": 910 - }, - { - "x": 1098, - "y": 910 - }, - { - "x": 1099, - "y": 910 - }, - { - "x": 1099, - "y": 911 - }, - { - "x": 1100, - "y": 911 - }, - { - "x": 1101, - "y": 911 - }, - { - "x": 1101, - "y": 912 - }, - { - "x": 1102, - "y": 912 - }, - { - "x": 1103, - "y": 912 - }, - { - "x": 1103, - "y": 913 - }, - { - "x": 1104, - "y": 913 - }, - { - "x": 1104, - "y": 914 - }, - { - "x": 1105, - "y": 914 - }, - { - "x": 1105, - "y": 915 - }, - { - "x": 1106, - "y": 915 - }, - { - "x": 1107, - "y": 915 - }, - { - "x": 1107, - "y": 916 - }, - { - "x": 1108, - "y": 916 - }, - { - "x": 1108, - "y": 917 - }, - { - "x": 1109, - "y": 917 - }, - { - "x": 1109, - "y": 918 - }, - { - "x": 1110, - "y": 918 - }, - { - "x": 1110, - "y": 919 - }, - { - "x": 1111, - "y": 919 - }, - { - "x": 1111, - "y": 920 - }, - { - "x": 1112, - "y": 920 - }, - { - "x": 1112, - "y": 921 - }, - { - "x": 1113, - "y": 921 - }, - { - "x": 1113, - "y": 922 - }, - { - "x": 1114, - "y": 922 - }, - { - "x": 1114, - "y": 923 - }, - { - "x": 1115, - "y": 923 - }, - { - "x": 1115, - "y": 924 - }, - { - "x": 1115, - "y": 925 - }, - { - "x": 1116, - "y": 925 - }, - { - "x": 1116, - "y": 926 - }, - { - "x": 1117, - "y": 926 - }, - { - "x": 1117, - "y": 927 - }, - { - "x": 1117, - "y": 928 - }, - { - "x": 1118, - "y": 928 - }, - { - "x": 1118, - "y": 929 - }, - { - "x": 1119, - "y": 929 - }, - { - "x": 
1119, - "y": 930 - }, - { - "x": 1120, - "y": 930 - }, - { - "x": 1120, - "y": 931 - }, - { - "x": 1120, - "y": 932 - }, - { - "x": 1120, - "y": 932 - }, - { - "x": 1119, - "y": 933 - }, - { - "x": 1119, - "y": 934 - }, - { - "x": 1119, - "y": 934 - }, - { - "x": 1118, - "y": 935 - }, - { - "x": 1118, - "y": 935 - } - ] - }, - { - "uuid": "b98f3a45-3328-41a0-9077-373a8177ebf2", - "schemaId": "ckrazcuec16om0z66bhhh4tp7", - "dataRow": { - "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d" - }, - "point": { - "x": 2122, - "y": 1457 - } - } -] \ No newline at end of file diff --git a/libs/labelbox/tests/data/assets/ndjson/image_import_name_only.json b/libs/labelbox/tests/data/assets/ndjson/image_import_name_only.json index 82be4cdab..466a03594 100644 --- a/libs/labelbox/tests/data/assets/ndjson/image_import_name_only.json +++ b/libs/labelbox/tests/data/assets/ndjson/image_import_name_only.json @@ -1,826 +1,86 @@ [ { "uuid": "b862c586-8614-483c-b5e6-82810f70cac0", - "name": "box a", + "name": "ckrazcueb16og0z6609jj7y3y", "dataRow": { "id": "ckrazctum0z8a0ybc0b0o0g0v" }, - "bbox": { - "top": 1352, - "left": 2275, - "height": 350, - "width": 139 - }, - "confidence": 0.854, + "classifications": [], + "confidence": 0.851, "customMetrics": [ { "name": "customMetric1", - "value": 0.5 - }, - { - "name": "customMetric2", - "value": 0.7 + "value": 0.4 } - ] + ], + "bbox": { + "top": 1352.0, + "left": 2275.0, + "height": 350.0, + "width": 139.0 + } }, { "uuid": "751fc725-f7b6-48ed-89b0-dd7d94d08af6", - "name": "mask a", + "name": "ckrazcuec16ok0z66f956apb7", "dataRow": { "id": "ckrazctum0z8a0ybc0b0o0g0v" }, - "mask": { - "instanceURI": "https://storage.labelbox.com/ckqcx1czn06830y61gh9v02cs%2F3e729327-f038-f66c-186e-45e921ef9717-1?Expires=1626806874672&KeyName=labelbox-assets-key-3&Signature=YsUOGKrsqmAZ68vT9BlPJOaRyLY", - "colorRGB": [ - 255, - 0, - 0 - ] - }, - "confidence": 0.685, + "classifications": [], + "confidence": 0.834, "customMetrics": [ { "name": "customMetric1", - 
"value": 0.4 - }, - { - "name": "customMetric2", - "value": 0.9 + "value": 0.3 } - ] + ], + "mask": { + "instanceURI": "https://storage.labelbox.com/ckqcx1czn06830y61gh9v02cs%2F3e729327-f038-f66c-186e-45e921ef9717-1?Expires=1626806874672&KeyName=labelbox-assets-key-3&Signature=YsUOGKrsqmAZ68vT9BlPJOaRyLY" + } }, { + "classifications": [], "uuid": "43d719ac-5d7f-4aea-be00-2ebfca0900fd", - "name": "polygon a", + "name": "ckrazcuec16oi0z66dzrd8pfl", "dataRow": { "id": "ckrazctum0z8a0ybc0b0o0g0v" }, - "confidence": 0.71, + "confidence": 0.986, "customMetrics": [ { "name": "customMetric1", - "value": 0.1 + "value": 0.9 } ], "polygon": [ { - "x": 1118, - "y": 935 - }, - { - "x": 1117, - "y": 935 - }, - { - "x": 1116, - "y": 935 - }, - { - "x": 1115, - "y": 935 - }, - { - "x": 1114, - "y": 935 - }, - { - "x": 1113, - "y": 935 - }, - { - "x": 1112, - "y": 935 - }, - { - "x": 1111, - "y": 935 - }, - { - "x": 1110, - "y": 935 - }, - { - "x": 1109, - "y": 935 - }, - { - "x": 1108, - "y": 935 - }, - { - "x": 1108, - "y": 934 - }, - { - "x": 1107, - "y": 934 - }, - { - "x": 1106, - "y": 934 - }, - { - "x": 1105, - "y": 934 - }, - { - "x": 1105, - "y": 933 - }, - { - "x": 1104, - "y": 933 - }, - { - "x": 1103, - "y": 933 - }, - { - "x": 1103, - "y": 932 - }, - { - "x": 1102, - "y": 932 - }, - { - "x": 1101, - "y": 932 - }, - { - "x": 1100, - "y": 932 - }, - { - "x": 1099, - "y": 932 - }, - { - "x": 1098, - "y": 932 - }, - { - "x": 1097, - "y": 932 - }, - { - "x": 1097, - "y": 931 - }, - { - "x": 1096, - "y": 931 - }, - { - "x": 1095, - "y": 931 - }, - { - "x": 1094, - "y": 931 - }, - { - "x": 1093, - "y": 931 - }, - { - "x": 1092, - "y": 931 - }, - { - "x": 1091, - "y": 931 - }, - { - "x": 1090, - "y": 931 - }, - { - "x": 1090, - "y": 930 - }, - { - "x": 1089, - "y": 930 - }, - { - "x": 1088, - "y": 930 - }, - { - "x": 1087, - "y": 930 - }, - { - "x": 1087, - "y": 929 - }, - { - "x": 1086, - "y": 929 - }, - { - "x": 1085, - "y": 929 - }, - { - "x": 1084, - "y": 929 - }, - { - 
"x": 1084, - "y": 928 - }, - { - "x": 1083, - "y": 928 - }, - { - "x": 1083, - "y": 927 - }, - { - "x": 1082, - "y": 927 - }, - { - "x": 1081, - "y": 927 - }, - { - "x": 1081, - "y": 926 - }, - { - "x": 1080, - "y": 926 - }, - { - "x": 1080, - "y": 925 - }, - { - "x": 1079, - "y": 925 - }, - { - "x": 1078, - "y": 925 - }, - { - "x": 1078, - "y": 924 - }, - { - "x": 1077, - "y": 924 - }, - { - "x": 1076, - "y": 924 - }, - { - "x": 1076, - "y": 923 - }, - { - "x": 1075, - "y": 923 - }, - { - "x": 1074, - "y": 923 - }, - { - "x": 1073, - "y": 923 - }, - { - "x": 1073, - "y": 922 - }, - { - "x": 1072, - "y": 922 - }, - { - "x": 1071, - "y": 922 - }, - { - "x": 1070, - "y": 922 - }, - { - "x": 1070, - "y": 921 - }, - { - "x": 1069, - "y": 921 - }, - { - "x": 1068, - "y": 921 - }, - { - "x": 1067, - "y": 921 - }, - { - "x": 1066, - "y": 921 - }, - { - "x": 1065, - "y": 921 - }, - { - "x": 1064, - "y": 921 - }, - { - "x": 1063, - "y": 921 - }, - { - "x": 1062, - "y": 921 - }, - { - "x": 1061, - "y": 921 - }, - { - "x": 1060, - "y": 921 - }, - { - "x": 1059, - "y": 921 - }, - { - "x": 1058, - "y": 921 - }, - { - "x": 1058, - "y": 920 - }, - { - "x": 1057, - "y": 920 - }, - { - "x": 1057, - "y": 919 - }, - { - "x": 1056, - "y": 919 - }, - { - "x": 1057, - "y": 918 - }, - { - "x": 1057, - "y": 918 - }, - { - "x": 1057, - "y": 917 - }, - { - "x": 1058, - "y": 916 - }, - { - "x": 1058, - "y": 916 - }, - { - "x": 1059, - "y": 915 - }, - { - "x": 1059, - "y": 915 - }, - { - "x": 1060, - "y": 914 - }, - { - "x": 1060, - "y": 914 - }, - { - "x": 1061, - "y": 913 - }, - { - "x": 1061, - "y": 913 - }, - { - "x": 1062, - "y": 912 - }, - { - "x": 1063, - "y": 912 - }, - { - "x": 1063, - "y": 912 - }, - { - "x": 1064, - "y": 911 - }, - { - "x": 1064, - "y": 911 - }, - { - "x": 1065, - "y": 910 - }, - { - "x": 1066, - "y": 910 - }, - { - "x": 1066, - "y": 910 - }, - { - "x": 1067, - "y": 909 - }, - { - "x": 1068, - "y": 909 - }, - { - "x": 1068, - "y": 909 - }, - { - "x": 1069, - "y": 
908 - }, - { - "x": 1070, - "y": 908 - }, - { - "x": 1071, - "y": 908 - }, - { - "x": 1072, - "y": 908 - }, - { - "x": 1072, - "y": 908 - }, - { - "x": 1073, - "y": 907 - }, - { - "x": 1074, - "y": 907 - }, - { - "x": 1075, - "y": 907 - }, - { - "x": 1076, - "y": 907 - }, - { - "x": 1077, - "y": 907 - }, - { - "x": 1078, - "y": 907 - }, - { - "x": 1079, - "y": 907 - }, - { - "x": 1080, - "y": 907 - }, - { - "x": 1081, - "y": 907 - }, - { - "x": 1082, - "y": 907 - }, - { - "x": 1083, - "y": 907 - }, - { - "x": 1084, - "y": 907 - }, - { - "x": 1085, - "y": 907 - }, - { - "x": 1086, - "y": 907 - }, - { - "x": 1087, - "y": 907 - }, - { - "x": 1088, - "y": 907 - }, - { - "x": 1089, - "y": 907 - }, - { - "x": 1090, - "y": 907 - }, - { - "x": 1091, - "y": 907 - }, - { - "x": 1091, - "y": 908 - }, - { - "x": 1092, - "y": 908 - }, - { - "x": 1093, - "y": 908 - }, - { - "x": 1094, - "y": 908 - }, - { - "x": 1095, - "y": 908 - }, - { - "x": 1095, - "y": 909 - }, - { - "x": 1096, - "y": 909 - }, - { - "x": 1097, - "y": 909 - }, - { - "x": 1097, - "y": 910 - }, - { - "x": 1098, - "y": 910 - }, - { - "x": 1099, - "y": 910 + "x": 10.0, + "y": 20.0 }, { - "x": 1099, - "y": 911 + "x": 15.0, + "y": 20.0 }, { - "x": 1100, - "y": 911 + "x": 20.0, + "y": 25.0 }, { - "x": 1101, - "y": 911 - }, - { - "x": 1101, - "y": 912 - }, - { - "x": 1102, - "y": 912 - }, - { - "x": 1103, - "y": 912 - }, - { - "x": 1103, - "y": 913 - }, - { - "x": 1104, - "y": 913 - }, - { - "x": 1104, - "y": 914 - }, - { - "x": 1105, - "y": 914 - }, - { - "x": 1105, - "y": 915 - }, - { - "x": 1106, - "y": 915 - }, - { - "x": 1107, - "y": 915 - }, - { - "x": 1107, - "y": 916 - }, - { - "x": 1108, - "y": 916 - }, - { - "x": 1108, - "y": 917 - }, - { - "x": 1109, - "y": 917 - }, - { - "x": 1109, - "y": 918 - }, - { - "x": 1110, - "y": 918 - }, - { - "x": 1110, - "y": 919 - }, - { - "x": 1111, - "y": 919 - }, - { - "x": 1111, - "y": 920 - }, - { - "x": 1112, - "y": 920 - }, - { - "x": 1112, - "y": 921 - }, - { - "x": 
1113, - "y": 921 - }, - { - "x": 1113, - "y": 922 - }, - { - "x": 1114, - "y": 922 - }, - { - "x": 1114, - "y": 923 - }, - { - "x": 1115, - "y": 923 - }, - { - "x": 1115, - "y": 924 - }, - { - "x": 1115, - "y": 925 - }, - { - "x": 1116, - "y": 925 - }, - { - "x": 1116, - "y": 926 - }, - { - "x": 1117, - "y": 926 - }, - { - "x": 1117, - "y": 927 - }, - { - "x": 1117, - "y": 928 - }, - { - "x": 1118, - "y": 928 - }, - { - "x": 1118, - "y": 929 - }, - { - "x": 1119, - "y": 929 - }, - { - "x": 1119, - "y": 930 - }, - { - "x": 1120, - "y": 930 - }, - { - "x": 1120, - "y": 931 - }, - { - "x": 1120, - "y": 932 - }, - { - "x": 1120, - "y": 932 - }, - { - "x": 1119, - "y": 933 - }, - { - "x": 1119, - "y": 934 - }, - { - "x": 1119, - "y": 934 - }, - { - "x": 1118, - "y": 935 - }, - { - "x": 1118, - "y": 935 + "x": 10.0, + "y": 20.0 } ] }, { + "classifications": [], "uuid": "b98f3a45-3328-41a0-9077-373a8177ebf2", - "name": "point a", + "name": "ckrazcuec16om0z66bhhh4tp7", "dataRow": { "id": "ckrazctum0z8a0ybc0b0o0g0v" }, - "confidence": 0.77, - "customMetrics": [ - { - "name": "customMetric2", - "value": 1.2 - } - ], "point": { - "x": 2122, - "y": 1457 + "x": 2122.0, + "y": 1457.0 } } ] \ No newline at end of file diff --git a/libs/labelbox/tests/data/assets/ndjson/metric_import_global_key.json b/libs/labelbox/tests/data/assets/ndjson/metric_import_global_key.json deleted file mode 100644 index 31be5a4c7..000000000 --- a/libs/labelbox/tests/data/assets/ndjson/metric_import_global_key.json +++ /dev/null @@ -1,10 +0,0 @@ -[ - { - "uuid": "a22bbf6e-b2da-4abe-9a11-df84759f7672", - "aggregation": "ARITHMETIC_MEAN", - "dataRow": { - "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d" - }, - "metricValue": 0.1 - } -] \ No newline at end of file diff --git a/libs/labelbox/tests/data/assets/ndjson/pdf_import_global_key.json b/libs/labelbox/tests/data/assets/ndjson/pdf_import_global_key.json deleted file mode 100644 index f4b4894f6..000000000 --- 
a/libs/labelbox/tests/data/assets/ndjson/pdf_import_global_key.json +++ /dev/null @@ -1,155 +0,0 @@ -[{ - "uuid": "5ad9c52f-058d-49c8-a749-3f20b84f8cd4", - "dataRow": { - "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d" - }, - "name": "boxy", - "schemaId": "cl6xnuwt95lqq07330tbb3mfd", - "classifications": [], - "page": 4, - "unit": "POINTS", - "confidence": 0.53, - "customMetrics": [ - { - "name": "customMetric1", - "value": 0.5 - }, - { - "name": "customMetric2", - "value": 0.3 - } - ], - "bbox": { - "top": 162.73, - "left": 32.45, - "height": 388.16999999999996, - "width": 101.66000000000001 - } -}, { - "uuid": "20eeef88-0294-49b4-a815-86588476bc6f", - "dataRow": { - "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d" - }, - "name": "boxy", - "schemaId": "cl6xnuwt95lqq07330tbb3mfd", - "classifications": [], - "page": 7, - "unit": "POINTS", - "bbox": { - "top": 223.26, - "left": 251.42, - "height": 457.03999999999996, - "width": 186.78 - } -}, { - "uuid": "641a8944-3938-409c-b4eb-dea354ed06e5", - "dataRow": { - "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d" - }, - "name": "boxy", - "schemaId": "cl6xnuwt95lqq07330tbb3mfd", - "classifications": [], - "page": 6, - "unit": "POINTS", - "confidence": 0.99, - "customMetrics": [ - { - "name": "customMetric1", - "value": 0.5 - }, - { - "name": "customMetric2", - "value": 0.3 - } - ], - "bbox": { - "top": 32.52, - "left": 218.17, - "height": 231.73, - "width": 110.56000000000003 - } -}, { - "uuid": "ebe4da7d-08b3-480a-8d15-26552b7f011c", - "dataRow": { - "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d" - }, - "name": "boxy", - "schemaId": "cl6xnuwt95lqq07330tbb3mfd", - "classifications": [], - "page": 7, - "unit": "POINTS", - "confidence": 0.89, - "customMetrics": [ - { - "name": "customMetric1", - "value": 0.5 - }, - { - "name": "customMetric2", - "value": 0.3 - } - ], - "bbox": { - "top": 117.39, - "left": 4.25, - "height": 456.9200000000001, - "width": 164.83 - } -}, { - "uuid": 
"35c41855-575f-42cc-a2f9-1f06237e9b63", - "dataRow": { - "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d" - }, - "name": "boxy", - "schemaId": "cl6xnuwt95lqq07330tbb3mfd", - "classifications": [], - "page": 8, - "unit": "POINTS", - "bbox": { - "top": 82.13, - "left": 217.28, - "height": 279.76, - "width": 82.43000000000004 - } -}, { - "uuid": "1b009654-bc17-42a2-8a71-160e7808c403", - "dataRow": { - "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d" - }, - "name": "boxy", - "schemaId": "cl6xnuwt95lqq07330tbb3mfd", - "classifications": [], - "page": 3, - "unit": "POINTS", - "bbox": { - "top": 298.12, - "left": 83.34, - "height": 203.83000000000004, - "width": 0.37999999999999545 - } -}, -{ - "uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673", - "dataRow": { - "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d" - }, - "name": "named_entity", - "classifications": [], - "textSelections": [ - { - "groupId": "2f4336f4-a07e-4e0a-a9e1-5629b03b719b", - "tokenIds": [ - "3f984bf3-1d61-44f5-b59a-9658a2e3440f", - "3bf00b56-ff12-4e52-8cc1-08dbddb3c3b8", - "6e1c3420-d4b7-4c5a-8fd6-ead43bf73d80", - "87a43d32-af76-4a1d-b262-5c5f4d5ace3a", - "e8606e8a-dfd9-4c49-a635-ad5c879c75d0", - "67c7c19e-4654-425d-bf17-2adb8cf02c30", - "149c5e80-3e07-49a7-ab2d-29ddfe6a38fa", - "b0e94071-2187-461e-8e76-96c58738a52c" - ], - "page": 1 - } - ] -} -] \ No newline at end of file diff --git a/libs/labelbox/tests/data/assets/ndjson/polyline_import_global_key.json b/libs/labelbox/tests/data/assets/ndjson/polyline_import_global_key.json deleted file mode 100644 index d6a9eecbd..000000000 --- a/libs/labelbox/tests/data/assets/ndjson/polyline_import_global_key.json +++ /dev/null @@ -1,36 +0,0 @@ -[ - { - "line": [ - { - "x": 2534.353, - "y": 249.471 - }, - { - "x": 2429.492, - "y": 182.092 - }, - { - "x": 2294.322, - "y": 221.962 - } - ], - "uuid": "5ad9c52f-058d-49c8-a749-3f20b84f8cd4", - "dataRow": { - "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d" - }, - "name": "some-line", - "schemaId": 
"cl6xnuwt95lqq07330tbb3mfd", - "classifications": [], - "confidence": 0.58, - "customMetrics": [ - { - "name": "customMetric1", - "value": 0.5 - }, - { - "name": "customMetric2", - "value": 0.3 - } - ] - } -] \ No newline at end of file diff --git a/libs/labelbox/tests/data/assets/ndjson/text_entity_import_global_key.json b/libs/labelbox/tests/data/assets/ndjson/text_entity_import_global_key.json deleted file mode 100644 index 1f26d8dc8..000000000 --- a/libs/labelbox/tests/data/assets/ndjson/text_entity_import_global_key.json +++ /dev/null @@ -1,26 +0,0 @@ -[ - { - "location": { - "start": 67, - "end": 128 - }, - "uuid": "5ad9c52f-058d-49c8-a749-3f20b84f8cd4", - "dataRow": { - "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d" - }, - "name": "some-text-entity", - "schemaId": "cl6xnuwt95lqq07330tbb3mfd", - "classifications": [], - "confidence": 0.53, - "customMetrics": [ - { - "name": "customMetric1", - "value": 0.5 - }, - { - "name": "customMetric2", - "value": 0.3 - } - ] - } -] \ No newline at end of file diff --git a/libs/labelbox/tests/data/assets/ndjson/video_import_global_key.json b/libs/labelbox/tests/data/assets/ndjson/video_import_global_key.json deleted file mode 100644 index 11e0753d9..000000000 --- a/libs/labelbox/tests/data/assets/ndjson/video_import_global_key.json +++ /dev/null @@ -1,166 +0,0 @@ -[{ - "answer": { - "schemaId": "ckrb1sfl8099g0y91cxbd5ftb" - }, - "schemaId": "ckrb1sfjx099a0y914hl319ie", - "dataRow": { - "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d" - }, - "uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673", - "frames": [{ - "start": 30, - "end": 35 - }, { - "start": 50, - "end": 51 - }] -}, { - "answer": [{ - "schemaId": "ckrb1sfl8099e0y919v260awv" - }], - "schemaId": "ckrb1sfkn099c0y910wbo0p1a", - "dataRow": { - "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d" - }, - "uuid": "d009925d-91a3-4f67-abd9-753453f5a584", - "frames": [{ - "start": 0, - "end": 5 - }] -}, { - "answer": "a value", - "schemaId": "ckrb1sfkn099c0y910wbo0p1a", 
- "dataRow": { - "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d" - }, - "uuid": "3b302706-37ec-4f72-ab2e-757d8bd302b9" -}, { - "classifications": [], - "schemaId": - "cl5islwg200gfci6g0oitaypu", - "dataRow": { - "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d" - }, - "uuid": - "6f7c835a-0139-4896-b73f-66a6baa89e94", - "segments": [{ - "keyframes": [{ - "frame": 1, - "line": [{ - "x": 10.0, - "y": 10.0 - }, { - "x": 100.0, - "y": 100.0 - }, { - "x": 50.0, - "y": 30.0 - }], - "classifications": [] - }, { - "frame": 5, - "line": [{ - "x": 15.0, - "y": 10.0 - }, { - "x": 50.0, - "y": 100.0 - }, { - "x": 50.0, - "y": 30.0 - }], - "classifications": [] - }] - }, { - "keyframes": [{ - "frame": 8, - "line": [{ - "x": 100.0, - "y": 10.0 - }, { - "x": 50.0, - "y": 100.0 - }, { - "x": 50.0, - "y": 30.0 - }], - "classifications": [] - }] - }] -}, { - "classifications": [], - "schemaId": - "cl5it7ktp00i5ci6gf80b1ysd", - "dataRow": { - "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d" - }, - "uuid": - "f963be22-227b-4efe-9be4-2738ed822216", - "segments": [{ - "keyframes": [{ - "frame": 1, - "point": { - "x": 10.0, - "y": 10.0 - }, - "classifications": [] - }] - }, { - "keyframes": [{ - "frame": 5, - "point": { - "x": 50.0, - "y": 50.0 - }, - "classifications": [] - }, { - "frame": 10, - "point": { - "x": 10.0, - "y": 50.0 - }, - "classifications": [] - }] - }] -}, { - "classifications": [], - "schemaId": - "cl5iw0roz00lwci6g5jni62vs", - "dataRow": { - "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d" - }, - "uuid": - "13b2ee0e-2355-4336-8b83-d74d09e3b1e7", - "segments": [{ - "keyframes": [{ - "frame": 1, - "bbox": { - "top": 10.0, - "left": 5.0, - "height": 100.0, - "width": 150.0 - }, - "classifications": [] - }, { - "frame": 5, - "bbox": { - "top": 30.0, - "left": 5.0, - "height": 50.0, - "width": 150.0 - }, - "classifications": [] - }] - }, { - "keyframes": [{ - "frame": 10, - "bbox": { - "top": 300.0, - "left": 200.0, - "height": 400.0, - "width": 150.0 - }, - 
"classifications": [] - }] - }] -}] diff --git a/libs/labelbox/tests/data/export/conftest.py b/libs/labelbox/tests/data/export/conftest.py index 0a62f39c8..4a59b6966 100644 --- a/libs/labelbox/tests/data/export/conftest.py +++ b/libs/labelbox/tests/data/export/conftest.py @@ -1,13 +1,16 @@ -import uuid import time +import uuid + import pytest -from labelbox.schema.queue_mode import QueueMode + +from labelbox import Client, MediaType +from labelbox.schema.annotation_import import AnnotationImportState, LabelImport from labelbox.schema.labeling_frontend import LabelingFrontend -from labelbox.schema.annotation_import import LabelImport, AnnotationImportState +from labelbox.schema.media_type import MediaType @pytest.fixture -def ontology(): +def ontology(client: Client): bbox_tool_with_nested_text = { "required": False, "name": "bbox_tool_with_nested_text", @@ -116,18 +119,174 @@ def ontology(): "color": "#008941", "classifications": [], } - entity_tool = { + raster_segmentation_tool = { "required": False, - "name": "entity--", - "tool": "named-entity", - "color": "#006FA6", + "name": "segmentation_mask", + "tool": "raster-segmentation", + "color": "#ff0000", "classifications": [], } - segmentation_tool = { + checklist = { + "required": False, + "instructions": "checklist", + "name": "checklist", + "type": "checklist", + "options": [ + {"label": "option1", "value": "option1"}, + {"label": "option2", "value": "option2"}, + {"label": "optionN", "value": "optionn"}, + ], + } + + free_form_text = { + "required": False, + "instructions": "text", + "name": "text", + "type": "text", + "options": [], + } + + radio = { "required": False, - "name": "segmentation--", - "tool": "superpixel", - "color": "#A30059", + "instructions": "radio", + "name": "radio", + "type": "radio", + "options": [ + { + "label": "first_radio_answer", + "value": "first_radio_answer", + "options": [], + }, + { + "label": "second_radio_answer", + "value": "second_radio_answer", + "options": [], + }, + ], + 
} + + tools = [ + bbox_tool, + bbox_tool_with_nested_text, + polygon_tool, + polyline_tool, + point_tool, + raster_segmentation_tool, + ] + classifications = [ + checklist, + free_form_text, + radio, + ] + ontology = client.create_ontology( + "image ontology", + {"tools": tools, "classifications": classifications}, + MediaType.Image, + ) + return ontology + + +@pytest.fixture +def video_ontology(client: Client): + bbox_tool_with_nested_text = { + "required": False, + "name": "bbox_tool_with_nested_text", + "tool": "rectangle", + "color": "#a23030", + "classifications": [ + { + "required": False, + "instructions": "nested", + "name": "nested", + "type": "radio", + "options": [ + { + "label": "radio_option_1", + "value": "radio_value_1", + "options": [ + { + "required": False, + "instructions": "nested_checkbox", + "name": "nested_checkbox", + "type": "checklist", + "options": [ + { + "label": "nested_checkbox_option_1", + "value": "nested_checkbox_value_1", + "options": [], + }, + { + "label": "nested_checkbox_option_2", + "value": "nested_checkbox_value_2", + }, + ], + }, + { + "required": False, + "instructions": "nested_text", + "name": "nested_text", + "type": "text", + "options": [], + }, + ], + }, + ], + } + ], + } + + bbox_tool = { + "required": False, + "name": "bbox", + "tool": "rectangle", + "color": "#a23030", + "classifications": [ + { + "required": False, + "instructions": "nested", + "name": "nested", + "type": "radio", + "options": [ + { + "label": "radio_option_1", + "value": "radio_value_1", + "options": [ + { + "required": False, + "instructions": "nested_checkbox", + "name": "nested_checkbox", + "type": "checklist", + "options": [ + { + "label": "nested_checkbox_option_1", + "value": "nested_checkbox_value_1", + "options": [], + }, + { + "label": "nested_checkbox_option_2", + "value": "nested_checkbox_value_2", + }, + ], + } + ], + }, + ], + } + ], + } + + polyline_tool = { + "required": False, + "name": "polyline", + "tool": "line", + "color": 
"#FF4A46", + "classifications": [], + } + point_tool = { + "required": False, + "name": "point--", + "tool": "point", + "color": "#008941", "classifications": [], } raster_segmentation_tool = { @@ -160,6 +319,7 @@ def ontology(): {"label": "optionN_index", "value": "optionn_index"}, ], } + free_form_text = { "required": False, "instructions": "text", @@ -167,14 +327,7 @@ def ontology(): "type": "text", "options": [], } - free_form_text_index = { - "required": False, - "instructions": "text_index", - "name": "text_index", - "type": "text", - "scope": "index", - "options": [], - } + radio = { "required": False, "instructions": "radio", @@ -193,33 +346,26 @@ def ontology(): }, ], } - named_entity = { - "tool": "named-entity", - "name": "named-entity", - "required": False, - "color": "#A30059", - "classifications": [], - } tools = [ bbox_tool, bbox_tool_with_nested_text, - polygon_tool, polyline_tool, point_tool, - entity_tool, - segmentation_tool, raster_segmentation_tool, - named_entity, ] classifications = [ - checklist, checklist_index, + checklist, free_form_text, - free_form_text_index, radio, ] - return {"tools": tools, "classifications": classifications} + ontology = client.create_ontology( + "image ontology", + {"tools": tools, "classifications": classifications}, + MediaType.Video, + ) + return ontology @pytest.fixture @@ -246,15 +392,17 @@ def configured_project_with_ontology( dataset = initial_dataset project = client.create_project( name=rand_gen(str), - queue_mode=QueueMode.Batch, + media_type=MediaType.Image, ) - editor = list( - client.get_labeling_frontends(where=LabelingFrontend.name == "editor") - )[0] - project.setup(editor, ontology) + project.connect_ontology(ontology) data_row_ids = [] - for _ in range(len(ontology["tools"]) + len(ontology["classifications"])): + normalized_ontology = ontology.normalized + + for _ in range( + len(normalized_ontology["tools"]) + + len(normalized_ontology["classifications"]) + ): 
data_row_ids.append(dataset.create_data_row(row_data=image_url).uid) project.create_batch( rand_gen(str), @@ -273,12 +421,25 @@ def configured_project_without_data_rows( project = client.create_project( name=rand_gen(str), description=rand_gen(str), - queue_mode=QueueMode.Batch, + media_type=MediaType.Image, ) - editor = list( - client.get_labeling_frontends(where=LabelingFrontend.name == "editor") - )[0] - project.setup(editor, ontology) + + project.connect_ontology(ontology) + yield project + teardown_helpers.teardown_project_labels_ontology_feature_schemas(project) + + +@pytest.fixture +def configured_video_project_without_data_rows( + client, video_ontology, rand_gen, teardown_helpers +): + project = client.create_project( + name=rand_gen(str), + description=rand_gen(str), + media_type=MediaType.Video, + ) + + project.connect_ontology(video_ontology) yield project teardown_helpers.teardown_project_labels_ontology_feature_schemas(project) diff --git a/libs/labelbox/tests/data/export/streamable/test_export_data_rows_streamable.py b/libs/labelbox/tests/data/export/streamable/test_export_data_rows_streamable.py index 3e4efbc46..8dfa1df72 100644 --- a/libs/labelbox/tests/data/export/streamable/test_export_data_rows_streamable.py +++ b/libs/labelbox/tests/data/export/streamable/test_export_data_rows_streamable.py @@ -1,7 +1,5 @@ -import json import time -import pytest from labelbox import DataRow, ExportTask, StreamType @@ -27,9 +25,7 @@ def test_with_data_row_object( ) assert export_task.get_total_lines(stream_type=StreamType.RESULT) == 1 assert ( - json.loads(list(export_task.get_stream())[0].json_str)["data_row"][ - "id" - ] + list(export_task.get_buffered_stream())[0].json["data_row"]["id"] == data_row.uid ) @@ -75,9 +71,7 @@ def test_with_id(self, client, data_row, wait_for_data_row_processing): ) assert export_task.get_total_lines(stream_type=StreamType.RESULT) == 1 assert ( - json.loads(list(export_task.get_stream())[0].json_str)["data_row"][ - "id" - ] + 
list(export_task.get_buffered_stream())[0].json["data_row"]["id"] == data_row.uid ) @@ -101,9 +95,7 @@ def test_with_global_key( ) assert export_task.get_total_lines(stream_type=StreamType.RESULT) == 1 assert ( - json.loads(list(export_task.get_stream())[0].json_str)["data_row"][ - "id" - ] + list(export_task.get_buffered_stream())[0].json["data_row"]["id"] == data_row.uid ) diff --git a/libs/labelbox/tests/data/export/streamable/test_export_dataset_streamable.py b/libs/labelbox/tests/data/export/streamable/test_export_dataset_streamable.py index 57f617a00..0d34a40b1 100644 --- a/libs/labelbox/tests/data/export/streamable/test_export_dataset_streamable.py +++ b/libs/labelbox/tests/data/export/streamable/test_export_dataset_streamable.py @@ -25,8 +25,8 @@ def test_export(self, dataset, data_rows): ) == len(expected_data_row_ids) data_row_ids = list( map( - lambda x: json.loads(x.json_str)["data_row"]["id"], - export_task.get_stream(), + lambda x: x.json["data_row"]["id"], + export_task.get_buffered_stream(), ) ) assert data_row_ids.sort() == expected_data_row_ids.sort() @@ -58,8 +58,8 @@ def test_with_data_row_filter(self, dataset, data_rows): ) data_row_ids = list( map( - lambda x: json.loads(x.json_str)["data_row"]["id"], - export_task.get_stream(), + lambda x: x.json["data_row"]["id"], + export_task.get_buffered_stream(), ) ) assert data_row_ids.sort() == expected_data_row_ids.sort() @@ -91,8 +91,8 @@ def test_with_global_key_filter(self, dataset, data_rows): ) global_keys = list( map( - lambda x: json.loads(x.json_str)["data_row"]["global_key"], - export_task.get_stream(), + lambda x: x.json["data_row"]["global_key"], + export_task.get_buffered_stream(), ) ) assert global_keys.sort() == expected_global_keys.sort() diff --git a/libs/labelbox/tests/data/export/streamable/test_export_embeddings_streamable.py b/libs/labelbox/tests/data/export/streamable/test_export_embeddings_streamable.py index 071acbb5b..803b5994a 100644 --- 
a/libs/labelbox/tests/data/export/streamable/test_export_embeddings_streamable.py +++ b/libs/labelbox/tests/data/export/streamable/test_export_embeddings_streamable.py @@ -1,7 +1,7 @@ import json import random -from labelbox import StreamType, JsonConverter +from labelbox import StreamType class TestExportEmbeddings: @@ -23,12 +23,8 @@ def test_export_embeddings_precomputed( assert export_task.has_errors() is False results = [] - export_task.get_stream( - converter=JsonConverter(), stream_type=StreamType.RESULT - ).start( - stream_handler=lambda output: results.append( - json.loads(output.json_str) - ) + export_task.get_buffered_stream(stream_type=StreamType.RESULT).start( + stream_handler=lambda output: results.append(output.json) ) assert len(results) == len(data_row_specs) @@ -69,12 +65,8 @@ def test_export_embeddings_custom( assert export_task.has_errors() is False results = [] - export_task.get_stream( - converter=JsonConverter(), stream_type=StreamType.RESULT - ).start( - stream_handler=lambda output: results.append( - json.loads(output.json_str) - ) + export_task.get_buffered_stream(stream_type=StreamType.RESULT).start( + stream_handler=lambda output: results.append(output.json) ) assert len(results) == 1 @@ -86,6 +78,6 @@ def test_export_embeddings_custom( if emb["id"] == embedding.id: assert emb["name"] == embedding.name assert emb["dimensions"] == embedding.dims - assert emb["is_custom"] == True + assert emb["is_custom"] is True assert len(emb["values"]) == 1 assert emb["values"][0]["value"] == vector diff --git a/libs/labelbox/tests/data/export/streamable/test_export_model_run_streamable.py b/libs/labelbox/tests/data/export/streamable/test_export_model_run_streamable.py index ada493fc3..7a583198b 100644 --- a/libs/labelbox/tests/data/export/streamable/test_export_model_run_streamable.py +++ b/libs/labelbox/tests/data/export/streamable/test_export_model_run_streamable.py @@ -27,8 +27,8 @@ def test_export(self, model_run_with_data_rows): 
stream_type=StreamType.RESULT ) == len(expected_data_rows) - for data in export_task.get_stream(): - obj = json.loads(data.json_str) + for data in export_task.get_buffered_stream(): + obj = data.json assert ( "media_attributes" in obj and obj["media_attributes"] is not None diff --git a/libs/labelbox/tests/data/export/streamable/test_export_project_streamable.py b/libs/labelbox/tests/data/export/streamable/test_export_project_streamable.py index 818a0178c..597c529aa 100644 --- a/libs/labelbox/tests/data/export/streamable/test_export_project_streamable.py +++ b/libs/labelbox/tests/data/export/streamable/test_export_project_streamable.py @@ -9,6 +9,7 @@ from labelbox import Project, Dataset from labelbox.schema.data_row import DataRow from labelbox.schema.label import Label +from labelbox import UniqueIds IMAGE_URL = "https://storage.googleapis.com/lb-artifacts-testing-public/sdk_integration_test/potato.jpeg" @@ -56,8 +57,8 @@ def test_export( ) assert export_task.get_total_lines(stream_type=StreamType.RESULT) > 0 - for data in export_task.get_stream(): - obj = json.loads(data.json_str) + for data in export_task.get_buffered_stream(): + obj = data.json task_media_attributes = obj["media_attributes"] task_project = obj["projects"][project.uid] task_project_label_ids_set = set( @@ -128,7 +129,9 @@ def test_with_date_filters( review_queue = next( tq for tq in task_queues if tq.queue_type == "MANUAL_REVIEW_QUEUE" ) - project.move_data_rows_to_task_queue([data_row.uid], review_queue.uid) + project.move_data_rows_to_task_queue( + UniqueIds([data_row.uid]), review_queue.uid + ) export_task = project_export( project, task_name, filters=filters, params=params ) @@ -139,8 +142,8 @@ def test_with_date_filters( ) assert export_task.get_total_lines(stream_type=StreamType.RESULT) > 0 - for data in export_task.get_stream(): - obj = json.loads(data.json_str) + for data in export_task.get_buffered_stream(): + obj = data.json task_project = obj["projects"][project.uid] 
task_project_label_ids_set = set( map(lambda prediction: prediction["id"], task_project["labels"]) @@ -181,9 +184,9 @@ def test_with_iso_date_filters( assert export_task.get_total_lines(stream_type=StreamType.RESULT) > 0 assert ( label_id - == json.loads(list(export_task.get_stream())[0].json_str)[ - "projects" - ][project.uid]["labels"][0]["id"] + == list(export_task.get_buffered_stream())[0].json["projects"][ + project.uid + ]["labels"][0]["id"] ) def test_with_iso_date_filters_no_start_date( @@ -207,9 +210,9 @@ def test_with_iso_date_filters_no_start_date( assert export_task.get_total_lines(stream_type=StreamType.RESULT) > 0 assert ( label_id - == json.loads(list(export_task.get_stream())[0].json_str)[ - "projects" - ][project.uid]["labels"][0]["id"] + == list(export_task.get_buffered_stream())[0].json["projects"][ + project.uid + ]["labels"][0]["id"] ) def test_with_iso_date_filters_and_future_start_date( @@ -270,8 +273,8 @@ def test_with_data_row_filter( ) data_row_ids = list( map( - lambda x: json.loads(x.json_str)["data_row"]["id"], - export_task.get_stream(), + lambda x: x.json["data_row"]["id"], + export_task.get_buffered_stream(), ) ) assert data_row_ids.sort() == expected_data_row_ids.sort() @@ -310,8 +313,8 @@ def test_with_global_key_filter( ) global_keys = list( map( - lambda x: json.loads(x.json_str)["data_row"]["global_key"], - export_task.get_stream(), + lambda x: x.json["data_row"]["global_key"], + export_task.get_buffered_stream(), ) ) assert global_keys.sort() == expected_global_keys.sort() diff --git a/libs/labelbox/tests/data/export/streamable/test_export_video_streamable.py b/libs/labelbox/tests/data/export/streamable/test_export_video_streamable.py index 115194a58..da99414ee 100644 --- a/libs/labelbox/tests/data/export/streamable/test_export_video_streamable.py +++ b/libs/labelbox/tests/data/export/streamable/test_export_video_streamable.py @@ -4,7 +4,7 @@ import labelbox as lb import labelbox.types as lb_types -from 
labelbox.data.annotation_types.data.video import VideoData +from labelbox.data.annotation_types.data import GenericDataRowData from labelbox.schema.annotation_import import AnnotationImportState from labelbox.schema.export_task import ExportTask, StreamType @@ -21,13 +21,13 @@ def org_id(self, client): def test_export( self, client, - configured_project_without_data_rows, + configured_video_project_without_data_rows, video_data, video_data_row, bbox_video_annotation_objects, rand_gen, ): - project = configured_project_without_data_rows + project = configured_video_project_without_data_rows project_id = project.uid labels = [] @@ -41,7 +41,7 @@ def test_export( for data_row_uid in data_row_uids: labels = [ lb_types.Label( - data=VideoData(uid=data_row_uid), + data=GenericDataRowData(uid=data_row_uid), annotations=bbox_video_annotation_objects, ) ] @@ -71,7 +71,8 @@ def test_export( export_task.get_total_file_size(stream_type=StreamType.RESULT) > 0 ) - export_data = json.loads(list(export_task.get_stream())[0].json_str) + export_data = list(export_task.get_buffered_stream())[0].json + data_row_export = export_data["data_row"] assert data_row_export["global_key"] == video_data_row["global_key"] assert data_row_export["row_data"] == video_data_row["row_data"] diff --git a/libs/labelbox/tests/data/serialization/coco/test_coco.py b/libs/labelbox/tests/data/serialization/coco/test_coco.py deleted file mode 100644 index a7c733ce5..000000000 --- a/libs/labelbox/tests/data/serialization/coco/test_coco.py +++ /dev/null @@ -1,38 +0,0 @@ -import json -from pathlib import Path - -from labelbox.data.serialization.coco import COCOConverter - -COCO_ASSETS_DIR = "tests/data/assets/coco" - - -def run_instances(tmpdir): - instance_json = json.load(open(Path(COCO_ASSETS_DIR, "instances.json"))) - res = COCOConverter.deserialize_instances( - instance_json, Path(COCO_ASSETS_DIR, "images") - ) - back = COCOConverter.serialize_instances( - res, - Path(tmpdir), - ) - - -def 
test_rle_objects(tmpdir): - rle_json = json.load(open(Path(COCO_ASSETS_DIR, "rle.json"))) - res = COCOConverter.deserialize_instances( - rle_json, Path(COCO_ASSETS_DIR, "images") - ) - back = COCOConverter.serialize_instances(res, tmpdir) - - -def test_panoptic(tmpdir): - panoptic_json = json.load(open(Path(COCO_ASSETS_DIR, "panoptic.json"))) - image_dir, mask_dir = [ - Path(COCO_ASSETS_DIR, dir_name) for dir_name in ["images", "masks"] - ] - res = COCOConverter.deserialize_panoptic(panoptic_json, image_dir, mask_dir) - back = COCOConverter.serialize_panoptic( - res, - Path(f"/{tmpdir}/images_panoptic"), - Path(f"/{tmpdir}/masks_panoptic"), - ) diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_checklist.py b/libs/labelbox/tests/data/serialization/ndjson/test_checklist.py index 0bc3c8924..fb78916f4 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_checklist.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_checklist.py @@ -4,7 +4,7 @@ ClassificationAnswer, Radio, ) -from labelbox.data.annotation_types.data.text import TextData +from labelbox.data.annotation_types.data import GenericDataRowData from labelbox.data.annotation_types.label import Label from labelbox.data.serialization.ndjson.converter import NDJsonConverter @@ -13,9 +13,8 @@ def test_serialization_min(): label = Label( uid="ckj7z2q0b0000jx6x0q2q7q0d", - data=TextData( + data=GenericDataRowData( uid="bkj7z2q0b0000jx6x0q2q7q0d", - text="This is a test", ), annotations=[ ClassificationAnnotation( @@ -37,20 +36,12 @@ def test_serialization_min(): res.pop("uuid") assert res == expected - deserialized = NDJsonConverter.deserialize([res]) - res = next(deserialized) - for i, annotation in enumerate(res.annotations): - annotation.extra.pop("uuid") - assert annotation.value == label.annotations[i].value - assert annotation.name == label.annotations[i].name - def test_serialization_with_classification(): label = Label( uid="ckj7z2q0b0000jx6x0q2q7q0d", - data=TextData( + 
data=GenericDataRowData( uid="bkj7z2q0b0000jx6x0q2q7q0d", - text="This is a test", ), annotations=[ ClassificationAnnotation( @@ -134,19 +125,12 @@ def test_serialization_with_classification(): res.pop("uuid") assert res == expected - deserialized = NDJsonConverter.deserialize([res]) - res = next(deserialized) - assert label.model_dump(exclude_none=True) == label.model_dump( - exclude_none=True - ) - def test_serialization_with_classification_double_nested(): label = Label( uid="ckj7z2q0b0000jx6x0q2q7q0d", - data=TextData( + data=GenericDataRowData( uid="bkj7z2q0b0000jx6x0q2q7q0d", - text="This is a test", ), annotations=[ ClassificationAnnotation( @@ -233,20 +217,12 @@ def test_serialization_with_classification_double_nested(): res.pop("uuid") assert res == expected - deserialized = NDJsonConverter.deserialize([res]) - res = next(deserialized) - res.annotations[0].extra.pop("uuid") - assert label.model_dump(exclude_none=True) == label.model_dump( - exclude_none=True - ) - def test_serialization_with_classification_double_nested_2(): label = Label( uid="ckj7z2q0b0000jx6x0q2q7q0d", - data=TextData( + data=GenericDataRowData( uid="bkj7z2q0b0000jx6x0q2q7q0d", - text="This is a test", ), annotations=[ ClassificationAnnotation( @@ -330,9 +306,3 @@ def test_serialization_with_classification_double_nested_2(): res = next(serialized) res.pop("uuid") assert res == expected - - deserialized = NDJsonConverter.deserialize([res]) - res = next(deserialized) - assert label.model_dump(exclude_none=True) == label.model_dump( - exclude_none=True - ) diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_classification.py b/libs/labelbox/tests/data/serialization/ndjson/test_classification.py index 8dcb17f0b..82adce99c 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_classification.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_classification.py @@ -1,15 +1,73 @@ import json +from labelbox.data.annotation_types.classification.classification import 
( + Checklist, + Radio, + Text, +) +from labelbox.data.annotation_types.data.generic_data_row_data import ( + GenericDataRowData, +) from labelbox.data.serialization.ndjson.converter import NDJsonConverter +from labelbox.types import ( + Label, + ClassificationAnnotation, + ClassificationAnswer, +) +from labelbox.data.mixins import CustomMetric + def test_classification(): with open( "tests/data/assets/ndjson/classification_import.json", "r" ) as file: data = json.load(file) - res = list(NDJsonConverter.deserialize(data)) - res = list(NDJsonConverter.serialize(res)) + + label = Label( + data=GenericDataRowData( + uid="ckrb1sf1i1g7i0ybcdc6oc8ct", + ), + annotations=[ + ClassificationAnnotation( + feature_schema_id="ckrb1sfjx099a0y914hl319ie", + extra={"uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673"}, + value=Radio( + answer=ClassificationAnswer( + custom_metrics=[ + CustomMetric(name="customMetric1", value=0.5), + CustomMetric(name="customMetric2", value=0.3), + ], + confidence=0.8, + feature_schema_id="ckrb1sfl8099g0y91cxbd5ftb", + ), + ), + ), + ClassificationAnnotation( + feature_schema_id="ckrb1sfkn099c0y910wbo0p1a", + extra={"uuid": "d009925d-91a3-4f67-abd9-753453f5a584"}, + value=Checklist( + answer=[ + ClassificationAnswer( + custom_metrics=[ + CustomMetric(name="customMetric1", value=0.5), + CustomMetric(name="customMetric2", value=0.3), + ], + confidence=0.82, + feature_schema_id="ckrb1sfl8099e0y919v260awv", + ) + ], + ), + ), + ClassificationAnnotation( + feature_schema_id="ckrb1sfkn099c0y910wbo0p1a", + extra={"uuid": "78ff6a23-bebe-475c-8f67-4c456909648f"}, + value=Text(answer="a value"), + ), + ], + ) + + res = list(NDJsonConverter.serialize([label])) assert res == data @@ -18,6 +76,48 @@ def test_classification_with_name(): "tests/data/assets/ndjson/classification_import_name_only.json", "r" ) as file: data = json.load(file) - res = list(NDJsonConverter.deserialize(data)) - res = list(NDJsonConverter.serialize(res)) + label = Label( + 
data=GenericDataRowData( + uid="ckrb1sf1i1g7i0ybcdc6oc8ct", + ), + annotations=[ + ClassificationAnnotation( + name="classification a", + extra={"uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673"}, + value=Radio( + answer=ClassificationAnswer( + custom_metrics=[ + CustomMetric(name="customMetric1", value=0.5), + CustomMetric(name="customMetric2", value=0.3), + ], + confidence=0.99, + name="choice 1", + ), + ), + ), + ClassificationAnnotation( + name="classification b", + extra={"uuid": "d009925d-91a3-4f67-abd9-753453f5a584"}, + value=Checklist( + answer=[ + ClassificationAnswer( + custom_metrics=[ + CustomMetric(name="customMetric1", value=0.5), + CustomMetric(name="customMetric2", value=0.3), + ], + confidence=0.945, + name="choice 2", + ) + ], + ), + ), + ClassificationAnnotation( + name="classification c", + extra={"uuid": "150d60de-30af-44e4-be20-55201c533312"}, + value=Text(answer="a value"), + ), + ], + ) + + res = list(NDJsonConverter.serialize([label])) assert res == data diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_conversation.py b/libs/labelbox/tests/data/serialization/ndjson/test_conversation.py index f7da9181b..5aa7285e2 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_conversation.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_conversation.py @@ -1,8 +1,12 @@ import json +from labelbox.data.annotation_types.data.generic_data_row_data import ( + GenericDataRowData, +) import pytest import labelbox.types as lb_types from labelbox.data.serialization.ndjson.converter import NDJsonConverter +from labelbox.data.mixins import CustomMetric radio_ndjson = [ { @@ -15,7 +19,7 @@ radio_label = [ lb_types.Label( - data=lb_types.ConversationData(global_key="my_global_key"), + data=lb_types.GenericDataRowData(global_key="my_global_key"), annotations=[ lb_types.ClassificationAnnotation( name="radio", @@ -44,7 +48,7 @@ checklist_label = [ lb_types.Label( - data=lb_types.ConversationData(global_key="my_global_key"), + 
data=lb_types.GenericDataRowData(global_key="my_global_key"), annotations=[ lb_types.ClassificationAnnotation( name="checklist", @@ -74,7 +78,7 @@ ] free_text_label = [ lb_types.Label( - data=lb_types.ConversationData(global_key="my_global_key"), + data=lb_types.GenericDataRowData(global_key="my_global_key"), annotations=[ lb_types.ClassificationAnnotation( name="free_text", @@ -99,31 +103,68 @@ def test_message_based_radio_classification(label, ndjson): serialized_label[0].pop("uuid") assert serialized_label == ndjson - deserialized_label = list(NDJsonConverter().deserialize(ndjson)) - deserialized_label[0].annotations[0].extra.pop("uuid") - assert deserialized_label[0].model_dump(exclude_none=True) == label[ - 0 - ].model_dump(exclude_none=True) +def test_conversation_entity_import(): + with open( + "tests/data/assets/ndjson/conversation_entity_import.json", "r" + ) as file: + data = json.load(file) -@pytest.mark.parametrize( - "filename", - [ - "tests/data/assets/ndjson/conversation_entity_import.json", + label = lb_types.Label( + data=GenericDataRowData( + uid="cl6xnv9h61fv0085yhtoq06ht", + ), + annotations=[ + lb_types.ObjectAnnotation( + custom_metrics=[ + CustomMetric(name="customMetric1", value=0.5), + CustomMetric(name="customMetric2", value=0.3), + ], + confidence=0.53, + name="some-text-entity", + feature_schema_id="cl6xnuwt95lqq07330tbb3mfd", + extra={"uuid": "5ad9c52f-058d-49c8-a749-3f20b84f8cd4"}, + value=lb_types.ConversationEntity( + start=67, end=128, message_id="some-message-id" + ), + ) + ], + ) + + res = list(NDJsonConverter.serialize([label])) + assert res == data + + +def test_conversation_entity_import_without_confidence(): + with open( "tests/data/assets/ndjson/conversation_entity_without_confidence_import.json", - ], -) -def test_conversation_entity_import(filename: str): - with open(filename, "r") as file: + "r", + ) as file: data = json.load(file) - res = list(NDJsonConverter.deserialize(data)) - res = list(NDJsonConverter.serialize(res)) 
+ label = lb_types.Label( + uid=None, + data=GenericDataRowData( + uid="cl6xnv9h61fv0085yhtoq06ht", + ), + annotations=[ + lb_types.ObjectAnnotation( + name="some-text-entity", + feature_schema_id="cl6xnuwt95lqq07330tbb3mfd", + extra={"uuid": "5ad9c52f-058d-49c8-a749-3f20b84f8cd4"}, + value=lb_types.ConversationEntity( + start=67, end=128, extra={}, message_id="some-message-id" + ), + ) + ], + ) + + res = list(NDJsonConverter.serialize([label])) assert res == data def test_benchmark_reference_label_flag_enabled(): label = lb_types.Label( - data=lb_types.ConversationData(global_key="my_global_key"), + data=lb_types.GenericDataRowData(global_key="my_global_key"), annotations=[ lb_types.ClassificationAnnotation( name="free_text", @@ -140,7 +181,7 @@ def test_benchmark_reference_label_flag_enabled(): def test_benchmark_reference_label_flag_disabled(): label = lb_types.Label( - data=lb_types.ConversationData(global_key="my_global_key"), + data=lb_types.GenericDataRowData(global_key="my_global_key"), annotations=[ lb_types.ClassificationAnnotation( name="free_text", diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_data_gen.py b/libs/labelbox/tests/data/serialization/ndjson/test_data_gen.py index 333c00250..999e1bda5 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_data_gen.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_data_gen.py @@ -1,67 +1,29 @@ -from copy import copy -import pytest import labelbox.types as lb_types from labelbox.data.serialization import NDJsonConverter -from labelbox.data.serialization.ndjson.objects import ( - NDDicomSegments, - NDDicomSegment, - NDDicomLine, -) - -""" -Data gen prompt test data -""" - -prompt_text_annotation = lb_types.PromptClassificationAnnotation( - feature_schema_id="ckrb1sfkn099c0y910wbo0p1a", - name="test", - value=lb_types.PromptText( - answer="the answer to the text questions right here" - ), -) - -prompt_text_ndjson = { - "answer": "the answer to the text questions right here", - 
"name": "test", - "schemaId": "ckrb1sfkn099c0y910wbo0p1a", - "dataRow": {"id": "ckrb1sf1i1g7i0ybcdc6oc8ct"}, -} - -data_gen_label = lb_types.Label( - data={"uid": "ckrb1sf1i1g7i0ybcdc6oc8ct"}, - annotations=[prompt_text_annotation], -) - -""" -Prompt annotation test -""" def test_serialize_label(): - serialized_label = next(NDJsonConverter().serialize([data_gen_label])) - # Remove uuid field since this is a random value that can not be specified also meant for relationships - del serialized_label["uuid"] - assert serialized_label == prompt_text_ndjson - - -def test_deserialize_label(): - deserialized_label = next( - NDJsonConverter().deserialize([prompt_text_ndjson]) + prompt_text_annotation = lb_types.PromptClassificationAnnotation( + feature_schema_id="ckrb1sfkn099c0y910wbo0p1a", + name="test", + extra={"uuid": "test"}, + value=lb_types.PromptText( + answer="the answer to the text questions right here" + ), ) - if hasattr(deserialized_label.annotations[0], "extra"): - # Extra fields are added to deserialized label by default need removed to match - deserialized_label.annotations[0].extra = {} - assert deserialized_label.model_dump( - exclude_none=True - ) == data_gen_label.model_dump(exclude_none=True) + prompt_text_ndjson = { + "answer": "the answer to the text questions right here", + "name": "test", + "schemaId": "ckrb1sfkn099c0y910wbo0p1a", + "dataRow": {"id": "ckrb1sf1i1g7i0ybcdc6oc8ct"}, + "uuid": "test", + } + + data_gen_label = lb_types.Label( + data={"uid": "ckrb1sf1i1g7i0ybcdc6oc8ct"}, + annotations=[prompt_text_annotation], + ) + serialized_label = next(NDJsonConverter().serialize([data_gen_label])) -def test_serialize_deserialize_label(): - serialized = list(NDJsonConverter.serialize([data_gen_label])) - deserialized = next(NDJsonConverter.deserialize(serialized)) - if hasattr(deserialized.annotations[0], "extra"): - # Extra fields are added to deserialized label by default need removed to match - deserialized.annotations[0].extra = {} - assert 
deserialized.model_dump( - exclude_none=True - ) == data_gen_label.model_dump(exclude_none=True) + assert serialized_label == prompt_text_ndjson diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_dicom.py b/libs/labelbox/tests/data/serialization/ndjson/test_dicom.py index 633214367..6a00fa871 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_dicom.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_dicom.py @@ -1,6 +1,5 @@ from copy import copy import pytest -import base64 import labelbox.types as lb_types from labelbox.data.serialization import NDJsonConverter from labelbox.data.serialization.ndjson.objects import ( @@ -32,7 +31,7 @@ ] polyline_label = lb_types.Label( - data=lb_types.DicomData(uid="test-uid"), + data=lb_types.GenericDataRowData(uid="test-uid"), annotations=dicom_polyline_annotations, ) @@ -59,7 +58,7 @@ } polyline_with_global_key = lb_types.Label( - data=lb_types.DicomData(global_key="test-global-key"), + data=lb_types.GenericDataRowData(global_key="test-global-key"), annotations=dicom_polyline_annotations, ) @@ -110,11 +109,12 @@ } video_mask_label = lb_types.Label( - data=lb_types.VideoData(uid="test-uid"), annotations=[video_mask_annotation] + data=lb_types.GenericDataRowData(uid="test-uid"), + annotations=[video_mask_annotation], ) video_mask_label_with_global_key = lb_types.Label( - data=lb_types.VideoData(global_key="test-global-key"), + data=lb_types.GenericDataRowData(global_key="test-global-key"), annotations=[video_mask_annotation], ) """ @@ -129,11 +129,12 @@ ) dicom_mask_label = lb_types.Label( - data=lb_types.DicomData(uid="test-uid"), annotations=[dicom_mask_annotation] + data=lb_types.GenericDataRowData(uid="test-uid"), + annotations=[dicom_mask_annotation], ) dicom_mask_label_with_global_key = lb_types.Label( - data=lb_types.DicomData(global_key="test-global-key"), + data=lb_types.GenericDataRowData(global_key="test-global-key"), annotations=[dicom_mask_annotation], ) @@ -181,28 +182,3 @@ def 
test_serialize_label(label, ndjson): if "uuid" in serialized_label: serialized_label.pop("uuid") assert serialized_label == ndjson - - -@pytest.mark.parametrize("label, ndjson", labels_ndjsons) -def test_deserialize_label(label, ndjson): - deserialized_label = next(NDJsonConverter().deserialize([ndjson])) - if hasattr(deserialized_label.annotations[0], "extra"): - deserialized_label.annotations[0].extra = {} - for i, annotation in enumerate(deserialized_label.annotations): - if hasattr(annotation, "frames"): - assert annotation.frames == label.annotations[i].frames - if hasattr(annotation, "value"): - assert annotation.value == label.annotations[i].value - - -@pytest.mark.parametrize("label", labels) -def test_serialize_deserialize_label(label): - serialized = list(NDJsonConverter.serialize([label])) - deserialized = list(NDJsonConverter.deserialize(serialized)) - if hasattr(deserialized[0].annotations[0], "extra"): - deserialized[0].annotations[0].extra = {} - for i, annotation in enumerate(deserialized[0].annotations): - if hasattr(annotation, "frames"): - assert annotation.frames == label.annotations[i].frames - if hasattr(annotation, "value"): - assert annotation.value == label.annotations[i].value diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_document.py b/libs/labelbox/tests/data/serialization/ndjson/test_document.py index 5fe6a9789..fcdf4368b 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_document.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_document.py @@ -1,6 +1,19 @@ import json +from labelbox.data.annotation_types.data.generic_data_row_data import ( + GenericDataRowData, +) +from labelbox.data.mixins import CustomMetric import labelbox.types as lb_types from labelbox.data.serialization.ndjson.converter import NDJsonConverter +from labelbox.types import ( + Label, + ObjectAnnotation, + RectangleUnit, + Point, + DocumentRectangle, + DocumentEntity, + DocumentTextSelection, +) bbox_annotation = 
lb_types.ObjectAnnotation( name="bounding_box", # must match your ontology feature's name @@ -13,7 +26,7 @@ ) bbox_labels = [ lb_types.Label( - data=lb_types.DocumentData(global_key="test-global-key"), + data=lb_types.GenericDataRowData(global_key="test-global-key"), annotations=[bbox_annotation], ) ] @@ -53,10 +66,144 @@ def test_pdf(): """ with open("tests/data/assets/ndjson/pdf_import.json", "r") as f: data = json.load(f) - res = list(NDJsonConverter.deserialize(data)) - res = list(NDJsonConverter.serialize(res)) + labels = [ + Label( + uid=None, + data=GenericDataRowData( + uid="cl6xnv9h61fv0085yhtoq06ht", + ), + annotations=[ + ObjectAnnotation( + custom_metrics=[ + CustomMetric(name="customMetric1", value=0.5), + CustomMetric(name="customMetric2", value=0.3), + ], + confidence=0.53, + name="boxy", + feature_schema_id="cl6xnuwt95lqq07330tbb3mfd", + extra={ + "uuid": "5ad9c52f-058d-49c8-a749-3f20b84f8cd4", + }, + value=DocumentRectangle( + start=Point(x=32.45, y=162.73), + end=Point(x=134.11, y=550.9), + page=4, + unit=RectangleUnit.POINTS, + ), + ), + ObjectAnnotation( + name="boxy", + feature_schema_id="cl6xnuwt95lqq07330tbb3mfd", + extra={ + "uuid": "20eeef88-0294-49b4-a815-86588476bc6f", + }, + value=DocumentRectangle( + start=Point(x=251.42, y=223.26), + end=Point(x=438.2, y=680.3), + page=7, + unit=RectangleUnit.POINTS, + ), + ), + ObjectAnnotation( + custom_metrics=[ + CustomMetric(name="customMetric1", value=0.5), + CustomMetric(name="customMetric2", value=0.3), + ], + confidence=0.99, + name="boxy", + feature_schema_id="cl6xnuwt95lqq07330tbb3mfd", + extra={ + "uuid": "641a8944-3938-409c-b4eb-dea354ed06e5", + }, + value=DocumentRectangle( + start=Point(x=218.17, y=32.52), + end=Point(x=328.73, y=264.25), + page=6, + unit=RectangleUnit.POINTS, + ), + ), + ObjectAnnotation( + custom_metrics=[ + CustomMetric(name="customMetric1", value=0.5), + CustomMetric(name="customMetric2", value=0.3), + ], + confidence=0.89, + name="boxy", + 
feature_schema_id="cl6xnuwt95lqq07330tbb3mfd", + extra={ + "uuid": "ebe4da7d-08b3-480a-8d15-26552b7f011c", + }, + value=DocumentRectangle( + start=Point(x=4.25, y=117.39), + end=Point(x=169.08, y=574.3100000000001), + page=7, + unit=RectangleUnit.POINTS, + ), + ), + ObjectAnnotation( + name="boxy", + feature_schema_id="cl6xnuwt95lqq07330tbb3mfd", + extra={ + "uuid": "35c41855-575f-42cc-a2f9-1f06237e9b63", + }, + value=DocumentRectangle( + start=Point(x=217.28, y=82.13), + end=Point(x=299.71000000000004, y=361.89), + page=8, + unit=RectangleUnit.POINTS, + ), + ), + ObjectAnnotation( + name="boxy", + feature_schema_id="cl6xnuwt95lqq07330tbb3mfd", + extra={ + "uuid": "1b009654-bc17-42a2-8a71-160e7808c403", + }, + value=DocumentRectangle( + start=Point(x=83.34, y=298.12), + end=Point(x=83.72, y=501.95000000000005), + page=3, + unit=RectangleUnit.POINTS, + ), + ), + ], + ), + Label( + data=GenericDataRowData( + uid="ckrb1sf1i1g7i0ybcdc6oc8ct", + ), + annotations=[ + ObjectAnnotation( + name="named_entity", + feature_schema_id="cl6xnuwt95lqq07330tbb3mfd", + extra={ + "uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673", + }, + value=DocumentEntity( + text_selections=[ + DocumentTextSelection( + token_ids=[ + "3f984bf3-1d61-44f5-b59a-9658a2e3440f", + "3bf00b56-ff12-4e52-8cc1-08dbddb3c3b8", + "6e1c3420-d4b7-4c5a-8fd6-ead43bf73d80", + "87a43d32-af76-4a1d-b262-5c5f4d5ace3a", + "e8606e8a-dfd9-4c49-a635-ad5c879c75d0", + "67c7c19e-4654-425d-bf17-2adb8cf02c30", + "149c5e80-3e07-49a7-ab2d-29ddfe6a38fa", + "b0e94071-2187-461e-8e76-96c58738a52c", + ], + group_id="2f4336f4-a07e-4e0a-a9e1-5629b03b719b", + page=1, + ) + ] + ), + ) + ], + ), + ] + + res = list(NDJsonConverter.serialize(labels)) assert [round_dict(x) for x in res] == [round_dict(x) for x in data] - f.close() def test_pdf_with_name_only(): @@ -65,26 +212,135 @@ def test_pdf_with_name_only(): """ with open("tests/data/assets/ndjson/pdf_import_name_only.json", "r") as f: data = json.load(f) - res = 
list(NDJsonConverter.deserialize(data)) - res = list(NDJsonConverter.serialize(res)) + + labels = [ + Label( + data=GenericDataRowData( + uid="cl6xnv9h61fv0085yhtoq06ht", + ), + annotations=[ + ObjectAnnotation( + custom_metrics=[ + CustomMetric(name="customMetric1", value=0.5), + CustomMetric(name="customMetric2", value=0.3), + ], + confidence=0.99, + name="boxy", + feature_schema_id=None, + extra={ + "uuid": "5ad9c52f-058d-49c8-a749-3f20b84f8cd4", + }, + value=DocumentRectangle( + start=Point(x=32.45, y=162.73), + end=Point(x=134.11, y=550.9), + page=4, + unit=RectangleUnit.POINTS, + ), + ), + ObjectAnnotation( + name="boxy", + extra={ + "uuid": "20eeef88-0294-49b4-a815-86588476bc6f", + }, + value=DocumentRectangle( + start=Point(x=251.42, y=223.26), + end=Point(x=438.2, y=680.3), + page=7, + unit=RectangleUnit.POINTS, + ), + ), + ObjectAnnotation( + name="boxy", + extra={ + "uuid": "641a8944-3938-409c-b4eb-dea354ed06e5", + }, + value=DocumentRectangle( + start=Point(x=218.17, y=32.52), + end=Point(x=328.73, y=264.25), + page=6, + unit=RectangleUnit.POINTS, + ), + ), + ObjectAnnotation( + custom_metrics=[ + CustomMetric(name="customMetric1", value=0.5), + CustomMetric(name="customMetric2", value=0.3), + ], + confidence=0.74, + name="boxy", + extra={ + "uuid": "ebe4da7d-08b3-480a-8d15-26552b7f011c", + }, + value=DocumentRectangle( + start=Point(x=4.25, y=117.39), + end=Point(x=169.08, y=574.3100000000001), + page=7, + unit=RectangleUnit.POINTS, + ), + ), + ObjectAnnotation( + name="boxy", + extra={ + "uuid": "35c41855-575f-42cc-a2f9-1f06237e9b63", + }, + value=DocumentRectangle( + start=Point(x=217.28, y=82.13), + end=Point(x=299.71000000000004, y=361.89), + page=8, + unit=RectangleUnit.POINTS, + ), + ), + ObjectAnnotation( + name="boxy", + extra={ + "uuid": "1b009654-bc17-42a2-8a71-160e7808c403", + }, + value=DocumentRectangle( + start=Point(x=83.34, y=298.12), + end=Point(x=83.72, y=501.95000000000005), + page=3, + unit=RectangleUnit.POINTS, + ), + ), + ], + ), 
+ Label( + data=GenericDataRowData( + uid="ckrb1sf1i1g7i0ybcdc6oc8ct", + ), + annotations=[ + ObjectAnnotation( + name="named_entity", + extra={ + "uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673", + }, + value=DocumentEntity( + text_selections=[ + DocumentTextSelection( + token_ids=[ + "3f984bf3-1d61-44f5-b59a-9658a2e3440f", + "3bf00b56-ff12-4e52-8cc1-08dbddb3c3b8", + "6e1c3420-d4b7-4c5a-8fd6-ead43bf73d80", + "87a43d32-af76-4a1d-b262-5c5f4d5ace3a", + "e8606e8a-dfd9-4c49-a635-ad5c879c75d0", + "67c7c19e-4654-425d-bf17-2adb8cf02c30", + "149c5e80-3e07-49a7-ab2d-29ddfe6a38fa", + "b0e94071-2187-461e-8e76-96c58738a52c", + ], + group_id="2f4336f4-a07e-4e0a-a9e1-5629b03b719b", + page=1, + ) + ] + ), + ) + ], + ), + ] + res = list(NDJsonConverter.serialize(labels)) assert [round_dict(x) for x in res] == [round_dict(x) for x in data] - f.close() def test_pdf_bbox_serialize(): serialized = list(NDJsonConverter.serialize(bbox_labels)) serialized[0].pop("uuid") assert serialized == bbox_ndjson - - -def test_pdf_bbox_deserialize(): - deserialized = list(NDJsonConverter.deserialize(bbox_ndjson)) - deserialized[0].annotations[0].extra = {} - assert ( - deserialized[0].annotations[0].value - == bbox_labels[0].annotations[0].value - ) - assert ( - deserialized[0].annotations[0].name - == bbox_labels[0].annotations[0].name - ) diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_export_video_objects.py b/libs/labelbox/tests/data/serialization/ndjson/test_export_video_objects.py index 4adcd9935..a0cd13e81 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_export_video_objects.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_export_video_objects.py @@ -1,16 +1,14 @@ from labelbox.data.annotation_types import Label, VideoObjectAnnotation from labelbox.data.serialization.ndjson.converter import NDJsonConverter from labelbox.data.annotation_types.geometry import Rectangle, Point -from labelbox.data.annotation_types import VideoData +from 
labelbox.data.annotation_types.data import GenericDataRowData def video_bbox_label(): return Label( uid="cl1z52xwh00050fhcmfgczqvn", - data=VideoData( + data=GenericDataRowData( uid="cklr9mr4m5iao0rb6cvxu4qbn", - file_path=None, - frames=None, url="https://storage.labelbox.com/ckcz6bubudyfi0855o1dt1g9s%2F26403a22-604a-a38c-eeff-c2ed481fb40a-cat.mp4?Expires=1651677421050&KeyName=labelbox-assets-key-3&Signature=vF7gMyfHzgZdfbB8BHgd88Ws-Ms", ), annotations=[ @@ -22,6 +20,7 @@ def video_bbox_label(): "instanceURI": None, "color": "#1CE6FF", "feature_id": "cl1z52xw700000fhcayaqy0ev", + "uuid": "b24e672b-8f79-4d96-bf5e-b552ca0820d5", }, value=Rectangle( extra={}, @@ -588,31 +587,4 @@ def test_serialize_video_objects(): serialized_labels = NDJsonConverter.serialize([label]) label = next(serialized_labels) - manual_label = video_serialized_bbox_label() - - for key in label.keys(): - # ignore uuid because we randomize if there was none - if key != "uuid": - assert label[key] == manual_label[key] - - assert len(label["segments"]) == 2 - assert len(label["segments"][0]["keyframes"]) == 2 - assert len(label["segments"][1]["keyframes"]) == 4 - - # #converts back only the keyframes. 
should be the sum of all prev segments - deserialized_labels = NDJsonConverter.deserialize([label]) - label = next(deserialized_labels) - assert len(label.annotations) == 6 - - -def test_confidence_is_ignored(): - label = video_bbox_label() - serialized_labels = NDJsonConverter.serialize([label]) - label = next(serialized_labels) - label["confidence"] = 0.453 - label["segments"][0]["confidence"] = 0.453 - - deserialized_labels = NDJsonConverter.deserialize([label]) - label = next(deserialized_labels) - for annotation in label.annotations: - assert annotation.confidence is None + assert label == video_serialized_bbox_label() diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_free_text.py b/libs/labelbox/tests/data/serialization/ndjson/test_free_text.py index 84c017497..7b03a8447 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_free_text.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_free_text.py @@ -5,7 +5,7 @@ Radio, Text, ) -from labelbox.data.annotation_types.data.text import TextData +from labelbox.data.annotation_types.data import GenericDataRowData from labelbox.data.annotation_types.label import Label from labelbox.data.serialization.ndjson.converter import NDJsonConverter @@ -14,7 +14,7 @@ def test_serialization(): label = Label( uid="ckj7z2q0b0000jx6x0q2q7q0d", - data=TextData( + data=GenericDataRowData( uid="bkj7z2q0b0000jx6x0q2q7q0d", text="This is a test", ), @@ -34,21 +34,11 @@ def test_serialization(): assert res["answer"] == "text_answer" assert res["dataRow"]["id"] == "bkj7z2q0b0000jx6x0q2q7q0d" - deserialized = NDJsonConverter.deserialize([res]) - res = next(deserialized) - - annotation = res.annotations[0] - - annotation_value = annotation.value - assert type(annotation_value) is Text - assert annotation_value.answer == "text_answer" - assert annotation_value.confidence == 0.5 - def test_nested_serialization(): label = Label( uid="ckj7z2q0b0000jx6x0q2q7q0d", - data=TextData( + data=GenericDataRowData( 
uid="bkj7z2q0b0000jx6x0q2q7q0d", text="This is a test", ), @@ -102,19 +92,3 @@ def test_nested_serialization(): assert sub_classification["name"] == "nested answer" assert sub_classification["answer"] == "nested answer" assert sub_classification["confidence"] == 0.7 - - deserialized = NDJsonConverter.deserialize([res]) - res = next(deserialized) - annotation = res.annotations[0] - answer = annotation.value.answer[0] - assert answer.confidence == 0.9 - assert answer.name == "first_answer" - - classification_answer = answer.classifications[0].value.answer - assert classification_answer.confidence == 0.8 - assert classification_answer.name == "first_sub_radio_answer" - - sub_classification_answer = classification_answer.classifications[0].value - assert type(sub_classification_answer) is Text - assert sub_classification_answer.answer == "nested answer" - assert sub_classification_answer.confidence == 0.7 diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_generic_data_row_data.py b/libs/labelbox/tests/data/serialization/ndjson/test_generic_data_row_data.py new file mode 100644 index 000000000..0dc4c21c0 --- /dev/null +++ b/libs/labelbox/tests/data/serialization/ndjson/test_generic_data_row_data.py @@ -0,0 +1,79 @@ +from labelbox.data.annotation_types.data.generic_data_row_data import ( + GenericDataRowData, +) +from labelbox.data.serialization.ndjson.converter import NDJsonConverter +from labelbox.types import Label, ClassificationAnnotation, Text + + +def test_generic_data_row_global_key(): + label_1 = Label( + data=GenericDataRowData(global_key="test"), + annotations=[ + ClassificationAnnotation( + name="free_text", + value=Text(answer="sample text"), + extra={"uuid": "141c3592-e5f0-4866-9943-d4a21fd47eb0"}, + ) + ], + ) + label_2 = Label( + data={"global_key": "test"}, + annotations=[ + ClassificationAnnotation( + name="free_text", + value=Text(answer="sample text"), + extra={"uuid": "141c3592-e5f0-4866-9943-d4a21fd47eb0"}, + ) + ], + ) + + 
expected_result = [ + { + "answer": "sample text", + "dataRow": {"globalKey": "test"}, + "name": "free_text", + "uuid": "141c3592-e5f0-4866-9943-d4a21fd47eb0", + } + ] + assert ( + list(NDJsonConverter.serialize([label_1])) + == list(NDJsonConverter.serialize([label_2])) + == expected_result + ) + + +def test_generic_data_row_id(): + label_1 = Label( + data=GenericDataRowData(uid="test"), + annotations=[ + ClassificationAnnotation( + name="free_text", + value=Text(answer="sample text"), + extra={"uuid": "141c3592-e5f0-4866-9943-d4a21fd47eb0"}, + ) + ], + ) + label_2 = Label( + data={"uid": "test"}, + annotations=[ + ClassificationAnnotation( + name="free_text", + value=Text(answer="sample text"), + extra={"uuid": "141c3592-e5f0-4866-9943-d4a21fd47eb0"}, + ) + ], + ) + + expected_result = [ + { + "answer": "sample text", + "dataRow": {"id": "test"}, + "name": "free_text", + "uuid": "141c3592-e5f0-4866-9943-d4a21fd47eb0", + } + ] + assert ( + list(NDJsonConverter.serialize([label_1])) + == list(NDJsonConverter.serialize([label_2])) + == expected_result + ) diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_global_key.py b/libs/labelbox/tests/data/serialization/ndjson/test_global_key.py index 2b3fa7f8c..d104a691e 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_global_key.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_global_key.py @@ -1,73 +1,74 @@ -import json -import pytest - -from labelbox.data.serialization.ndjson.classification import NDRadio - +from labelbox.data.annotation_types.data.generic_data_row_data import ( + GenericDataRowData, +) from labelbox.data.serialization.ndjson.converter import NDJsonConverter -from labelbox.data.serialization.ndjson.objects import NDLine - - -def round_dict(data): - if isinstance(data, dict): - for key in data: - if isinstance(data[key], float): - data[key] = int(data[key]) - elif isinstance(data[key], dict): - data[key] = round_dict(data[key]) - elif isinstance(data[key], (list, 
tuple)): - data[key] = [round_dict(r) for r in data[key]] +from labelbox.types import ( + Label, + ClassificationAnnotation, + Radio, + ClassificationAnswer, +) - return data +def test_generic_data_row_global_key_included(): + expected = [ + { + "answer": {"schemaId": "ckrb1sfl8099g0y91cxbd5ftb"}, + "dataRow": {"globalKey": "ckrb1sf1i1g7i0ybcdc6oc8ct"}, + "schemaId": "ckrb1sfjx099a0y914hl319ie", + "uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673", + } + ] -@pytest.mark.parametrize( - "filename", - [ - "tests/data/assets/ndjson/classification_import_global_key.json", - "tests/data/assets/ndjson/metric_import_global_key.json", - "tests/data/assets/ndjson/polyline_import_global_key.json", - "tests/data/assets/ndjson/text_entity_import_global_key.json", - "tests/data/assets/ndjson/conversation_entity_import_global_key.json", - ], -) -def test_many_types(filename: str): - with open(filename, "r") as f: - data = json.load(f) - res = list(NDJsonConverter.deserialize(data)) - res = list(NDJsonConverter.serialize(res)) - assert res == data - f.close() + label = Label( + data=GenericDataRowData( + global_key="ckrb1sf1i1g7i0ybcdc6oc8ct", + ), + annotations=[ + ClassificationAnnotation( + feature_schema_id="ckrb1sfjx099a0y914hl319ie", + extra={"uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673"}, + value=Radio( + answer=ClassificationAnswer( + feature_schema_id="ckrb1sfl8099g0y91cxbd5ftb", + ), + ), + ) + ], + ) + res = list(NDJsonConverter.serialize([label])) -def test_image(): - with open( - "tests/data/assets/ndjson/image_import_global_key.json", "r" - ) as f: - data = json.load(f) - res = list(NDJsonConverter.deserialize(data)) - res = list(NDJsonConverter.serialize(res)) - for r in res: - r.pop("classifications", None) - assert [round_dict(x) for x in res] == [round_dict(x) for x in data] - f.close() + assert res == expected -def test_pdf(): - with open("tests/data/assets/ndjson/pdf_import_global_key.json", "r") as f: - data = json.load(f) - res = 
list(NDJsonConverter.deserialize(data)) - res = list(NDJsonConverter.serialize(res)) - assert [round_dict(x) for x in res] == [round_dict(x) for x in data] - f.close() +def test_dict_data_row_global_key_included(): + expected = [ + { + "answer": {"schemaId": "ckrb1sfl8099g0y91cxbd5ftb"}, + "dataRow": {"globalKey": "ckrb1sf1i1g7i0ybcdc6oc8ct"}, + "schemaId": "ckrb1sfjx099a0y914hl319ie", + "uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673", + } + ] + label = Label( + data={ + "global_key": "ckrb1sf1i1g7i0ybcdc6oc8ct", + }, + annotations=[ + ClassificationAnnotation( + feature_schema_id="ckrb1sfjx099a0y914hl319ie", + extra={"uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673"}, + value=Radio( + answer=ClassificationAnswer( + feature_schema_id="ckrb1sfl8099g0y91cxbd5ftb", + ), + ), + ) + ], + ) -def test_video(): - with open( - "tests/data/assets/ndjson/video_import_global_key.json", "r" - ) as f: - data = json.load(f) + res = list(NDJsonConverter.serialize([label])) - res = list(NDJsonConverter.deserialize(data)) - res = list(NDJsonConverter.serialize(res)) - assert res == [data[2], data[0], data[1], data[3], data[4], data[5]] - f.close() + assert res == expected diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_image.py b/libs/labelbox/tests/data/serialization/ndjson/test_image.py index 1729e1f46..4d615658c 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_image.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_image.py @@ -1,4 +1,8 @@ import json +from labelbox.data.annotation_types.data.generic_data_row_data import ( + GenericDataRowData, +) +from labelbox.data.mixins import CustomMetric import numpy as np import cv2 @@ -7,9 +11,9 @@ Mask, Label, ObjectAnnotation, - ImageData, MaskData, ) +from labelbox.types import Rectangle, Polygon, Point def round_dict(data): @@ -29,12 +33,74 @@ def test_image(): with open("tests/data/assets/ndjson/image_import.json", "r") as file: data = json.load(file) - res = 
list(NDJsonConverter.deserialize(data)) - res = list(NDJsonConverter.serialize(res)) + labels = [ + Label( + data=GenericDataRowData( + uid="ckrazctum0z8a0ybc0b0o0g0v", + ), + annotations=[ + ObjectAnnotation( + custom_metrics=[ + CustomMetric(name="customMetric1", value=0.4) + ], + confidence=0.851, + feature_schema_id="ckrazcueb16og0z6609jj7y3y", + extra={ + "uuid": "b862c586-8614-483c-b5e6-82810f70cac0", + }, + value=Rectangle( + start=Point(extra={}, x=2275.0, y=1352.0), + end=Point(extra={}, x=2414.0, y=1702.0), + ), + ), + ObjectAnnotation( + custom_metrics=[ + CustomMetric(name="customMetric1", value=0.3) + ], + confidence=0.834, + feature_schema_id="ckrazcuec16ok0z66f956apb7", + extra={ + "uuid": "751fc725-f7b6-48ed-89b0-dd7d94d08af6", + }, + value=Mask( + mask=MaskData( + url="https://storage.labelbox.com/ckqcx1czn06830y61gh9v02cs%2F3e729327-f038-f66c-186e-45e921ef9717-1?Expires=1626806874672&KeyName=labelbox-assets-key-3&Signature=YsUOGKrsqmAZ68vT9BlPJOaRyLY", + ), + color=[255, 0, 0], + ), + ), + ObjectAnnotation( + custom_metrics=[ + CustomMetric(name="customMetric1", value=0.9) + ], + confidence=0.986, + feature_schema_id="ckrazcuec16oi0z66dzrd8pfl", + extra={ + "uuid": "43d719ac-5d7f-4aea-be00-2ebfca0900fd", + }, + value=Polygon( + points=[ + Point(x=10.0, y=20.0), + Point(x=15.0, y=20.0), + Point(x=20.0, y=25.0), + Point(x=10.0, y=20.0), + ], + ), + ), + ObjectAnnotation( + feature_schema_id="ckrazcuec16om0z66bhhh4tp7", + extra={ + "uuid": "b98f3a45-3328-41a0-9077-373a8177ebf2", + }, + value=Point(x=2122.0, y=1457.0), + ), + ], + ) + ] - for r in res: - r.pop("classifications", None) - assert [round_dict(x) for x in res] == [round_dict(x) for x in data] + res = list(NDJsonConverter.serialize(labels)) + del res[1]["mask"]["colorRGB"] # JSON does not support tuples + assert res == data def test_image_with_name_only(): @@ -43,11 +109,74 @@ def test_image_with_name_only(): ) as file: data = json.load(file) - res = list(NDJsonConverter.deserialize(data)) 
- res = list(NDJsonConverter.serialize(res)) - for r in res: - r.pop("classifications", None) - assert [round_dict(x) for x in res] == [round_dict(x) for x in data] + labels = [ + Label( + data=GenericDataRowData( + uid="ckrazctum0z8a0ybc0b0o0g0v", + ), + annotations=[ + ObjectAnnotation( + custom_metrics=[ + CustomMetric(name="customMetric1", value=0.4) + ], + confidence=0.851, + name="ckrazcueb16og0z6609jj7y3y", + extra={ + "uuid": "b862c586-8614-483c-b5e6-82810f70cac0", + }, + value=Rectangle( + start=Point(extra={}, x=2275.0, y=1352.0), + end=Point(extra={}, x=2414.0, y=1702.0), + ), + ), + ObjectAnnotation( + custom_metrics=[ + CustomMetric(name="customMetric1", value=0.3) + ], + confidence=0.834, + name="ckrazcuec16ok0z66f956apb7", + extra={ + "uuid": "751fc725-f7b6-48ed-89b0-dd7d94d08af6", + }, + value=Mask( + mask=MaskData( + url="https://storage.labelbox.com/ckqcx1czn06830y61gh9v02cs%2F3e729327-f038-f66c-186e-45e921ef9717-1?Expires=1626806874672&KeyName=labelbox-assets-key-3&Signature=YsUOGKrsqmAZ68vT9BlPJOaRyLY", + ), + color=[255, 0, 0], + ), + ), + ObjectAnnotation( + custom_metrics=[ + CustomMetric(name="customMetric1", value=0.9) + ], + confidence=0.986, + name="ckrazcuec16oi0z66dzrd8pfl", + extra={ + "uuid": "43d719ac-5d7f-4aea-be00-2ebfca0900fd", + }, + value=Polygon( + points=[ + Point(x=10.0, y=20.0), + Point(x=15.0, y=20.0), + Point(x=20.0, y=25.0), + Point(x=10.0, y=20.0), + ], + ), + ), + ObjectAnnotation( + name="ckrazcuec16om0z66bhhh4tp7", + extra={ + "uuid": "b98f3a45-3328-41a0-9077-373a8177ebf2", + }, + value=Point(x=2122.0, y=1457.0), + ), + ], + ) + ] + + res = list(NDJsonConverter.serialize(labels)) + del res[1]["mask"]["colorRGB"] # JSON does not support tuples + assert res == data def test_mask(): @@ -57,10 +186,11 @@ def test_mask(): "schemaId": "ckrazcueb16og0z6609jj7y3y", "dataRow": {"id": "ckrazctum0z8a0ybc0b0o0g0v"}, "mask": { - "png": 
"iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAAAAACoWZBhAAAAMklEQVR4nD3MuQ3AQADDMOqQ/Vd2ijytaSiZLAcYuyLEYYYl9cvrlGftTHvsYl+u/3EDv0QLI8Z7FlwAAAAASUVORK5CYII=" + "png": "iVBORw0KGgoAAAANSUhEUgAAAAIAAAACCAAAAABX3VL4AAAADklEQVR4nGNgYGBkZAAAAAsAA+RRQXwAAAAASUVORK5CYII=" }, "confidence": 0.8, "customMetrics": [{"name": "customMetric1", "value": 0.4}], + "classifications": [], }, { "uuid": "751fc725-f7b6-48ed-89b0-dd7d94d08af6", @@ -68,16 +198,54 @@ def test_mask(): "dataRow": {"id": "ckrazctum0z8a0ybc0b0o0g0v"}, "mask": { "instanceURI": "https://storage.labelbox.com/ckqcx1czn06830y61gh9v02cs%2F3e729327-f038-f66c-186e-45e921ef9717-1?Expires=1626806874672&KeyName=labelbox-assets-key-3&Signature=YsUOGKrsqmAZ68vT9BlPJOaRyLY", - "colorRGB": [255, 0, 0], + "colorRGB": (255, 0, 0), }, + "classifications": [], }, ] - res = list(NDJsonConverter.deserialize(data)) - res = list(NDJsonConverter.serialize(res)) - for r in res: - r.pop("classifications", None) - assert [round_dict(x) for x in res] == [round_dict(x) for x in data] + mask_numpy = np.array([[[1, 1, 0], [1, 0, 1]], [[1, 1, 1], [1, 1, 1]]]) + mask_numpy = mask_numpy.astype(np.uint8) + + labels = [ + Label( + data=GenericDataRowData( + uid="ckrazctum0z8a0ybc0b0o0g0v", + ), + annotations=[ + ObjectAnnotation( + custom_metrics=[ + CustomMetric(name="customMetric1", value=0.4) + ], + confidence=0.8, + feature_schema_id="ckrazcueb16og0z6609jj7y3y", + extra={ + "uuid": "b862c586-8614-483c-b5e6-82810f70cac0", + }, + value=Mask( + mask=MaskData(arr=mask_numpy), + color=(1, 1, 1), + ), + ), + ObjectAnnotation( + feature_schema_id="ckrazcuec16ok0z66f956apb7", + extra={ + "uuid": "751fc725-f7b6-48ed-89b0-dd7d94d08af6", + }, + value=Mask( + extra={}, + mask=MaskData( + url="https://storage.labelbox.com/ckqcx1czn06830y61gh9v02cs%2F3e729327-f038-f66c-186e-45e921ef9717-1?Expires=1626806874672&KeyName=labelbox-assets-key-3&Signature=YsUOGKrsqmAZ68vT9BlPJOaRyLY", + ), + color=(255, 0, 0), + ), + ), + ], + ) + ] + res = 
list(NDJsonConverter.serialize(labels)) + + assert res == data def test_mask_from_arr(): @@ -93,7 +261,7 @@ def test_mask_from_arr(): ), ) ], - data=ImageData(uid="0" * 25), + data=GenericDataRowData(uid="0" * 25), ) res = next(NDJsonConverter.serialize([label])) res.pop("uuid") diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_metric.py b/libs/labelbox/tests/data/serialization/ndjson/test_metric.py index 45c5c67bf..40e098405 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_metric.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_metric.py @@ -1,38 +1,166 @@ import json +from labelbox.data.annotation_types.data.generic_data_row_data import ( + GenericDataRowData, +) +from labelbox.data.annotation_types.metrics.confusion_matrix import ( + ConfusionMatrixMetric, +) from labelbox.data.serialization.ndjson.converter import NDJsonConverter +from labelbox.types import ( + Label, + ScalarMetric, + ScalarMetricAggregation, + ConfusionMatrixAggregation, +) def test_metric(): with open("tests/data/assets/ndjson/metric_import.json", "r") as file: data = json.load(file) - label_list = list(NDJsonConverter.deserialize(data)) - reserialized = list(NDJsonConverter.serialize(label_list)) - assert reserialized == data + labels = [ + Label( + data=GenericDataRowData( + uid="ckrmdnqj4000007msh9p2a27r", + ), + annotations=[ + ScalarMetric( + value=0.1, + extra={"uuid": "a22bbf6e-b2da-4abe-9a11-df84759f7672"}, + aggregation=ScalarMetricAggregation.ARITHMETIC_MEAN, + ) + ], + ) + ] + + res = list(NDJsonConverter.serialize(labels)) + assert res == data def test_custom_scalar_metric(): - with open( - "tests/data/assets/ndjson/custom_scalar_import.json", "r" - ) as file: - data = json.load(file) + data = [ + { + "uuid": "a22bbf6e-b2da-4abe-9a11-df84759f7672", + "dataRow": {"id": "ckrmdnqj4000007msh9p2a27r"}, + "metricValue": 0.1, + "metricName": "custom_iou", + "featureName": "sample_class", + "subclassName": "sample_subclass", + "aggregation": "SUM", 
+ }, + { + "uuid": "a22bbf6e-b2da-4abe-9a11-df84759f7673", + "dataRow": {"id": "ckrmdnqj4000007msh9p2a27r"}, + "metricValue": 0.1, + "metricName": "custom_iou", + "featureName": "sample_class", + "aggregation": "SUM", + }, + { + "uuid": "a22bbf6e-b2da-4abe-9a11-df84759f7674", + "dataRow": {"id": "ckrmdnqj4000007msh9p2a27r"}, + "metricValue": {0.1: 0.1, 0.2: 0.5}, + "metricName": "custom_iou", + "aggregation": "SUM", + }, + ] + + labels = [ + Label( + data=GenericDataRowData( + uid="ckrmdnqj4000007msh9p2a27r", + ), + annotations=[ + ScalarMetric( + value=0.1, + feature_name="sample_class", + subclass_name="sample_subclass", + extra={"uuid": "a22bbf6e-b2da-4abe-9a11-df84759f7672"}, + metric_name="custom_iou", + aggregation=ScalarMetricAggregation.SUM, + ), + ScalarMetric( + value=0.1, + feature_name="sample_class", + extra={"uuid": "a22bbf6e-b2da-4abe-9a11-df84759f7673"}, + metric_name="custom_iou", + aggregation=ScalarMetricAggregation.SUM, + ), + ScalarMetric( + value={"0.1": 0.1, "0.2": 0.5}, + extra={"uuid": "a22bbf6e-b2da-4abe-9a11-df84759f7674"}, + metric_name="custom_iou", + aggregation=ScalarMetricAggregation.SUM, + ), + ], + ) + ] + + res = list(NDJsonConverter.serialize(labels)) - label_list = list(NDJsonConverter.deserialize(data)) - reserialized = list(NDJsonConverter.serialize(label_list)) - assert json.dumps(reserialized, sort_keys=True) == json.dumps( - data, sort_keys=True - ) + assert res == data def test_custom_confusion_matrix_metric(): - with open( - "tests/data/assets/ndjson/custom_confusion_matrix_import.json", "r" - ) as file: - data = json.load(file) + data = [ + { + "uuid": "a22bbf6e-b2da-4abe-9a11-df84759f7672", + "dataRow": {"id": "ckrmdnqj4000007msh9p2a27r"}, + "metricValue": (1, 1, 2, 3), + "metricName": "50%_iou", + "featureName": "sample_class", + "subclassName": "sample_subclass", + "aggregation": "CONFUSION_MATRIX", + }, + { + "uuid": "a22bbf6e-b2da-4abe-9a11-df84759f7673", + "dataRow": {"id": "ckrmdnqj4000007msh9p2a27r"}, + 
"metricValue": (0, 1, 2, 5), + "metricName": "50%_iou", + "featureName": "sample_class", + "aggregation": "CONFUSION_MATRIX", + }, + { + "uuid": "a22bbf6e-b2da-4abe-9a11-df84759f7674", + "dataRow": {"id": "ckrmdnqj4000007msh9p2a27r"}, + "metricValue": {0.1: (0, 1, 2, 3), 0.2: (5, 3, 4, 3)}, + "metricName": "50%_iou", + "aggregation": "CONFUSION_MATRIX", + }, + ] + + labels = [ + Label( + data=GenericDataRowData( + uid="ckrmdnqj4000007msh9p2a27r", + ), + annotations=[ + ConfusionMatrixMetric( + value=(1, 1, 2, 3), + feature_name="sample_class", + subclass_name="sample_subclass", + extra={"uuid": "a22bbf6e-b2da-4abe-9a11-df84759f7672"}, + metric_name="50%_iou", + aggregation=ConfusionMatrixAggregation.CONFUSION_MATRIX, + ), + ConfusionMatrixMetric( + value=(0, 1, 2, 5), + feature_name="sample_class", + extra={"uuid": "a22bbf6e-b2da-4abe-9a11-df84759f7673"}, + metric_name="50%_iou", + aggregation=ConfusionMatrixAggregation.CONFUSION_MATRIX, + ), + ConfusionMatrixMetric( + value={0.1: (0, 1, 2, 3), 0.2: (5, 3, 4, 3)}, + extra={"uuid": "a22bbf6e-b2da-4abe-9a11-df84759f7674"}, + metric_name="50%_iou", + aggregation=ConfusionMatrixAggregation.CONFUSION_MATRIX, + ), + ], + ) + ] + + res = list(NDJsonConverter.serialize(labels)) - label_list = list(NDJsonConverter.deserialize(data)) - reserialized = list(NDJsonConverter.serialize(label_list)) - assert json.dumps(reserialized, sort_keys=True) == json.dumps( - data, sort_keys=True - ) + assert data == res diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_mmc.py b/libs/labelbox/tests/data/serialization/ndjson/test_mmc.py index 69594ff73..202f793fe 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_mmc.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_mmc.py @@ -1,32 +1,125 @@ import json +from labelbox.data.annotation_types.data.generic_data_row_data import ( + GenericDataRowData, +) import pytest from labelbox.data.serialization import NDJsonConverter +from labelbox.types import ( + Label, 
+ MessageEvaluationTaskAnnotation, + MessageSingleSelectionTask, + MessageMultiSelectionTask, + MessageInfo, + OrderedMessageInfo, + MessageRankingTask, +) def test_message_task_annotation_serialization(): with open("tests/data/assets/ndjson/mmc_import.json", "r") as file: data = json.load(file) - deserialized = list(NDJsonConverter.deserialize(data)) - reserialized = list(NDJsonConverter.serialize(deserialized)) + labels = [ + Label( + data=GenericDataRowData( + uid="cnjencjencjfencvj", + ), + annotations=[ + MessageEvaluationTaskAnnotation( + name="single-selection", + extra={"uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72"}, + value=MessageSingleSelectionTask( + message_id="clxfzocbm00083b6v8vczsept", + model_config_name="GPT 5", + parent_message_id="clxfznjb800073b6v43ppx9ca", + ), + ) + ], + ), + Label( + data=GenericDataRowData( + uid="cfcerfvergerfefj", + ), + annotations=[ + MessageEvaluationTaskAnnotation( + name="multi-selection", + extra={"uuid": "gferf3a57-597e-48cb-8d8d-a8526fefe72"}, + value=MessageMultiSelectionTask( + parent_message_id="clxfznjb800073b6v43ppx9ca", + selected_messages=[ + MessageInfo( + message_id="clxfzocbm00083b6v8vczsept", + model_config_name="GPT 5", + ) + ], + ), + ) + ], + ), + Label( + data=GenericDataRowData( + uid="cwefgtrgrthveferfferffr", + ), + annotations=[ + MessageEvaluationTaskAnnotation( + name="ranking", + extra={"uuid": "hybe3a57-5gt7e-48tgrb-8d8d-a852dswqde72"}, + value=MessageRankingTask( + parent_message_id="clxfznjb800073b6v43ppx9ca", + ranked_messages=[ + OrderedMessageInfo( + message_id="clxfzocbm00083b6v8vczsept", + model_config_name="GPT 4 with temperature 0.7", + order=1, + ), + OrderedMessageInfo( + message_id="clxfzocbm00093b6vx4ndisub", + model_config_name="GPT 5", + order=2, + ), + ], + ), + ) + ], + ), + ] - assert data == reserialized + res = list(NDJsonConverter.serialize(labels)) + assert res == data -def test_mesage_ranking_task_wrong_order_serialization(): - with 
open("tests/data/assets/ndjson/mmc_import.json", "r") as file: - data = json.load(file) - - some_ranking_task = next( - task - for task in data - if task["messageEvaluationTask"]["format"] == "message-ranking" - ) - some_ranking_task["messageEvaluationTask"]["data"]["rankedMessages"][0][ - "order" - ] = 3 +def test_mesage_ranking_task_wrong_order_serialization(): with pytest.raises(ValueError): - list(NDJsonConverter.deserialize([some_ranking_task])) + ( + Label( + data=GenericDataRowData( + uid="cwefgtrgrthveferfferffr", + ), + annotations=[ + MessageEvaluationTaskAnnotation( + name="ranking", + extra={ + "uuid": "hybe3a57-5gt7e-48tgrb-8d8d-a852dswqde72" + }, + value=MessageRankingTask( + parent_message_id="clxfznjb800073b6v43ppx9ca", + ranked_messages=[ + OrderedMessageInfo( + message_id="clxfzocbm00093b6vx4ndisub", + model_config_name="GPT 5", + order=1, + ), + OrderedMessageInfo( + message_id="clxfzocbm00083b6v8vczsept", + model_config_name="GPT 4 with temperature 0.7", + order=1, + ), + ], + ), + ) + ], + ), + ) diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_ndlabel_subclass_matching.py b/libs/labelbox/tests/data/serialization/ndjson/test_ndlabel_subclass_matching.py deleted file mode 100644 index 790bd87b3..000000000 --- a/libs/labelbox/tests/data/serialization/ndjson/test_ndlabel_subclass_matching.py +++ /dev/null @@ -1,19 +0,0 @@ -import json -from labelbox.data.serialization.ndjson.label import NDLabel -from labelbox.data.serialization.ndjson.objects import NDDocumentRectangle -import pytest - - -def test_bad_annotation_input(): - data = [{"test": 3}] - with pytest.raises(ValueError): - NDLabel(**{"annotations": data}) - - -def test_correct_annotation_input(): - with open("tests/data/assets/ndjson/pdf_import_name_only.json", "r") as f: - data = json.load(f) - assert isinstance( - NDLabel(**{"annotations": [data[0]]}).annotations[0], - NDDocumentRectangle, - ) diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_nested.py 
b/libs/labelbox/tests/data/serialization/ndjson/test_nested.py index e0f0df0e6..3633c9cbe 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_nested.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_nested.py @@ -1,13 +1,135 @@ import json +from labelbox.data.annotation_types.data.generic_data_row_data import ( + GenericDataRowData, +) +from labelbox.data.mixins import CustomMetric from labelbox.data.serialization.ndjson.converter import NDJsonConverter +from labelbox.types import ( + Label, + ObjectAnnotation, + Rectangle, + Point, + ClassificationAnnotation, + Radio, + ClassificationAnswer, + Text, + Checklist, +) def test_nested(): with open("tests/data/assets/ndjson/nested_import.json", "r") as file: data = json.load(file) - res = list(NDJsonConverter.deserialize(data)) - res = list(NDJsonConverter.serialize(res)) + labels = [ + Label( + data=GenericDataRowData( + uid="ckrb1sf1i1g7i0ybcdc6oc8ct", + ), + annotations=[ + ObjectAnnotation( + feature_schema_id="ckrb1sfjx099a0y914hl319ie", + extra={ + "uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673", + }, + value=Rectangle( + start=Point(x=2275.0, y=1352.0), + end=Point(x=2414.0, y=1702.0), + ), + classifications=[ + ClassificationAnnotation( + feature_schema_id="ckrb1sfkn099c0y910wbo0p1a", + value=Radio( + answer=ClassificationAnswer( + custom_metrics=[ + CustomMetric( + name="customMetric1", value=0.5 + ), + CustomMetric( + name="customMetric2", value=0.3 + ), + ], + confidence=0.34, + feature_schema_id="ckrb1sfl8099g0y91cxbd5ftb", + ), + ), + ) + ], + ), + ObjectAnnotation( + feature_schema_id="ckrb1sfjx099a0y914hl319ie", + extra={ + "uuid": "d009925d-91a3-4f67-abd9-753453f5a584", + }, + value=Rectangle( + start=Point(x=2089.0, y=1251.0), + end=Point(x=2247.0, y=1679.0), + ), + classifications=[ + ClassificationAnnotation( + feature_schema_id="ckrb1sfkn099c0y910wbo0p1a", + value=Radio( + answer=ClassificationAnswer( + feature_schema_id="ckrb1sfl8099e0y919v260awv", + ), + ), + ) + ], + ), + 
ObjectAnnotation( + feature_schema_id="ckrb1sfjx099a0y914hl319ie", + extra={ + "uuid": "5d03213e-4408-456c-9eca-cf0723202961", + }, + value=Rectangle( + start=Point(x=2089.0, y=1251.0), + end=Point(x=2247.0, y=1679.0), + ), + classifications=[ + ClassificationAnnotation( + feature_schema_id="ckrb1sfkn099c0y910wbo0p1a", + value=Checklist( + answer=[ + ClassificationAnswer( + custom_metrics=[ + CustomMetric( + name="customMetric1", value=0.5 + ), + CustomMetric( + name="customMetric2", value=0.3 + ), + ], + confidence=0.894, + feature_schema_id="ckrb1sfl8099e0y919v260awv", + ) + ], + ), + ) + ], + ), + ObjectAnnotation( + feature_schema_id="ckrb1sfjx099a0y914hl319ie", + extra={ + "uuid": "d50812f6-34eb-4f12-b3cb-bbde51a31d83", + }, + value=Rectangle( + start=Point(x=2089.0, y=1251.0), + end=Point(x=2247.0, y=1679.0), + ), + classifications=[ + ClassificationAnnotation( + feature_schema_id="ckrb1sfkn099c0y910wbo0p1a", + extra={}, + value=Text( + answer="a string", + ), + ) + ], + ), + ], + ) + ] + res = list(NDJsonConverter.serialize(labels)) assert res == data @@ -16,6 +138,112 @@ def test_nested_name_only(): "tests/data/assets/ndjson/nested_import_name_only.json", "r" ) as file: data = json.load(file) - res = list(NDJsonConverter.deserialize(data)) - res = list(NDJsonConverter.serialize(res)) + labels = [ + Label( + data=GenericDataRowData( + uid="ckrb1sf1i1g7i0ybcdc6oc8ct", + ), + annotations=[ + ObjectAnnotation( + name="box a", + extra={ + "uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673", + }, + value=Rectangle( + start=Point(x=2275.0, y=1352.0), + end=Point(x=2414.0, y=1702.0), + ), + classifications=[ + ClassificationAnnotation( + name="classification a", + value=Radio( + answer=ClassificationAnswer( + custom_metrics=[ + CustomMetric( + name="customMetric1", value=0.5 + ), + CustomMetric( + name="customMetric2", value=0.3 + ), + ], + confidence=0.811, + name="first answer", + ), + ), + ) + ], + ), + ObjectAnnotation( + name="box b", + extra={ + "uuid": 
"d009925d-91a3-4f67-abd9-753453f5a584", + }, + value=Rectangle( + start=Point(x=2089.0, y=1251.0), + end=Point(x=2247.0, y=1679.0), + ), + classifications=[ + ClassificationAnnotation( + name="classification b", + value=Radio( + answer=ClassificationAnswer( + custom_metrics=[ + CustomMetric( + name="customMetric1", value=0.5 + ), + CustomMetric( + name="customMetric2", value=0.3 + ), + ], + confidence=0.815, + name="second answer", + ), + ), + ) + ], + ), + ObjectAnnotation( + name="box c", + extra={ + "uuid": "8a2b2c43-f0a1-4763-ba96-e322d986ced6", + }, + value=Rectangle( + start=Point(x=2089.0, y=1251.0), + end=Point(x=2247.0, y=1679.0), + ), + classifications=[ + ClassificationAnnotation( + name="classification c", + value=Checklist( + answer=[ + ClassificationAnswer( + name="third answer", + ) + ], + ), + ) + ], + ), + ObjectAnnotation( + name="box c", + extra={ + "uuid": "456dd2c6-9fa0-42f9-9809-acc27b9886a7", + }, + value=Rectangle( + start=Point(x=2089.0, y=1251.0), + end=Point(x=2247.0, y=1679.0), + ), + classifications=[ + ClassificationAnnotation( + name="a string", + value=Text( + answer="a string", + ), + ) + ], + ), + ], + ) + ] + res = list(NDJsonConverter.serialize(labels)) assert res == data diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_polyline.py b/libs/labelbox/tests/data/serialization/ndjson/test_polyline.py index 97d48a14e..cd11d97fe 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_polyline.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_polyline.py @@ -1,18 +1,76 @@ import json -import pytest +from labelbox.data.annotation_types.data.generic_data_row_data import ( + GenericDataRowData, +) +from labelbox.data.mixins import CustomMetric from labelbox.data.serialization.ndjson.converter import NDJsonConverter +from labelbox.types import ObjectAnnotation, Point, Line, Label -@pytest.mark.parametrize( - "filename", - [ - "tests/data/assets/ndjson/polyline_without_confidence_import.json", - 
"tests/data/assets/ndjson/polyline_import.json", - ], -) -def test_polyline_import(filename: str): - with open(filename, "r") as file: +def test_polyline_import_with_confidence(): + with open( + "tests/data/assets/ndjson/polyline_without_confidence_import.json", "r" + ) as file: + data = json.load(file) + labels = [ + Label( + data=GenericDataRowData( + uid="cl6xnv9h61fv0085yhtoq06ht", + ), + annotations=[ + ObjectAnnotation( + name="some-line", + feature_schema_id="cl6xnuwt95lqq07330tbb3mfd", + extra={ + "uuid": "5ad9c52f-058d-49c8-a749-3f20b84f8cd4", + }, + value=Line( + points=[ + Point(x=2534.353, y=249.471), + Point(x=2429.492, y=182.092), + Point(x=2294.322, y=221.962), + ], + ), + ) + ], + ) + ] + res = list(NDJsonConverter.serialize(labels)) + assert res == data + + +def test_polyline_import_without_confidence(): + with open("tests/data/assets/ndjson/polyline_import.json", "r") as file: data = json.load(file) - res = list(NDJsonConverter.deserialize(data)) - res = list(NDJsonConverter.serialize(res)) + + labels = [ + Label( + data=GenericDataRowData( + uid="cl6xnv9h61fv0085yhtoq06ht", + ), + annotations=[ + ObjectAnnotation( + custom_metrics=[ + CustomMetric(name="customMetric1", value=0.5), + CustomMetric(name="customMetric2", value=0.3), + ], + confidence=0.58, + name="some-line", + feature_schema_id="cl6xnuwt95lqq07330tbb3mfd", + extra={ + "uuid": "5ad9c52f-058d-49c8-a749-3f20b84f8cd4", + }, + value=Line( + points=[ + Point(x=2534.353, y=249.471), + Point(x=2429.492, y=182.092), + Point(x=2294.322, y=221.962), + ], + ), + ) + ], + ) + ] + + res = list(NDJsonConverter.serialize(labels)) assert res == data diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_radio.py b/libs/labelbox/tests/data/serialization/ndjson/test_radio.py index bd80f9267..ec57f0528 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_radio.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_radio.py @@ -1,10 +1,9 @@ -import json from 
labelbox.data.annotation_types.annotation import ClassificationAnnotation from labelbox.data.annotation_types.classification.classification import ( ClassificationAnswer, ) from labelbox.data.annotation_types.classification.classification import Radio -from labelbox.data.annotation_types.data.text import TextData +from labelbox.data.annotation_types.data import GenericDataRowData from labelbox.data.annotation_types.label import Label from labelbox.data.serialization.ndjson.converter import NDJsonConverter @@ -13,9 +12,8 @@ def test_serialization_with_radio_min(): label = Label( uid="ckj7z2q0b0000jx6x0q2q7q0d", - data=TextData( + data=GenericDataRowData( uid="bkj7z2q0b0000jx6x0q2q7q0d", - text="This is a test", ), annotations=[ ClassificationAnnotation( @@ -40,21 +38,12 @@ def test_serialization_with_radio_min(): res.pop("uuid") assert res == expected - deserialized = NDJsonConverter.deserialize([res]) - res = next(deserialized) - - for i, annotation in enumerate(res.annotations): - annotation.extra.pop("uuid") - assert annotation.value == label.annotations[i].value - assert annotation.name == label.annotations[i].name - def test_serialization_with_radio_classification(): label = Label( uid="ckj7z2q0b0000jx6x0q2q7q0d", - data=TextData( + data=GenericDataRowData( uid="bkj7z2q0b0000jx6x0q2q7q0d", - text="This is a test", ), annotations=[ ClassificationAnnotation( @@ -101,10 +90,3 @@ def test_serialization_with_radio_classification(): res = next(serialized) res.pop("uuid") assert res == expected - - deserialized = NDJsonConverter.deserialize([res]) - res = next(deserialized) - res.annotations[0].extra.pop("uuid") - assert res.annotations[0].model_dump( - exclude_none=True - ) == label.annotations[0].model_dump(exclude_none=True) diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_rectangle.py b/libs/labelbox/tests/data/serialization/ndjson/test_rectangle.py index 66630dbb5..0e42ab152 100644 --- 
a/libs/labelbox/tests/data/serialization/ndjson/test_rectangle.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_rectangle.py @@ -1,6 +1,10 @@ import json +from labelbox.data.annotation_types.data.generic_data_row_data import ( + GenericDataRowData, +) import labelbox.types as lb_types from labelbox.data.serialization.ndjson.converter import NDJsonConverter +from labelbox.types import Label, ObjectAnnotation, Rectangle, Point DATAROW_ID = "ckrb1sf1i1g7i0ybcdc6oc8ct" @@ -8,8 +12,26 @@ def test_rectangle(): with open("tests/data/assets/ndjson/rectangle_import.json", "r") as file: data = json.load(file) - res = list(NDJsonConverter.deserialize(data)) - res = list(NDJsonConverter.serialize(res)) + labels = [ + Label( + data=GenericDataRowData( + uid="ckrb1sf1i1g7i0ybcdc6oc8ct", + ), + annotations=[ + ObjectAnnotation( + name="bbox", + extra={ + "uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72", + }, + value=Rectangle( + start=Point(x=38.0, y=28.0), + end=Point(x=81.0, y=69.0), + ), + ) + ], + ) + ] + res = list(NDJsonConverter.serialize(labels)) assert res == data @@ -39,8 +61,6 @@ def test_rectangle_inverted_start_end_points(): ), extra={ "uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72", - "page": None, - "unit": None, }, ) @@ -48,8 +68,9 @@ def test_rectangle_inverted_start_end_points(): data={"uid": DATAROW_ID}, annotations=[expected_bbox] ) - res = list(NDJsonConverter.deserialize(res)) - assert res == [label] + data = list(NDJsonConverter.serialize([label])) + + assert res == data def test_rectangle_mixed_start_end_points(): @@ -76,17 +97,13 @@ def test_rectangle_mixed_start_end_points(): start=lb_types.Point(x=38, y=28), end=lb_types.Point(x=81, y=69), ), - extra={ - "uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72", - "page": None, - "unit": None, - }, + extra={"uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72"}, ) label = lb_types.Label(data={"uid": DATAROW_ID}, annotations=[bbox]) - res = list(NDJsonConverter.deserialize(res)) - assert res == [label] + data = 
list(NDJsonConverter.serialize([label])) + assert res == data def test_benchmark_reference_label_flag_enabled(): diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_relationship.py b/libs/labelbox/tests/data/serialization/ndjson/test_relationship.py index f33719035..235b66957 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_relationship.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_relationship.py @@ -1,16 +1,135 @@ import json -from uuid import uuid4 -import pytest +from labelbox.data.annotation_types.data.generic_data_row_data import ( + GenericDataRowData, +) from labelbox.data.serialization.ndjson.converter import NDJsonConverter +from labelbox.types import ( + Label, + ObjectAnnotation, + Point, + Rectangle, + RelationshipAnnotation, + Relationship, +) def test_relationship(): with open("tests/data/assets/ndjson/relationship_import.json", "r") as file: data = json.load(file) - res = list(NDJsonConverter.deserialize(data)) + res = [ + Label( + data=GenericDataRowData( + uid="clf98gj90000qp38ka34yhptl", + ), + annotations=[ + ObjectAnnotation( + name="cat", + extra={ + "uuid": "d8813907-b15d-4374-bbe6-b9877fb42ccd", + }, + value=Rectangle( + start=Point(x=100.0, y=200.0), + end=Point(x=200.0, y=300.0), + ), + ), + ObjectAnnotation( + name="dog", + extra={ + "uuid": "9b1e1249-36b4-4665-b60a-9060e0d18660", + }, + value=Rectangle( + start=Point(x=400.0, y=500.0), + end=Point(x=600.0, y=700.0), + ), + ), + RelationshipAnnotation( + name="is chasing", + extra={"uuid": "0e6354eb-9adb-47e5-8e52-217ed016d948"}, + value=Relationship( + source=ObjectAnnotation( + name="dog", + extra={ + "uuid": "9b1e1249-36b4-4665-b60a-9060e0d18660", + }, + value=Rectangle( + start=Point(x=400.0, y=500.0), + end=Point(x=600.0, y=700.0), + ), + ), + target=ObjectAnnotation( + name="cat", + extra={ + "uuid": "d8813907-b15d-4374-bbe6-b9877fb42ccd", + }, + value=Rectangle( + extra={}, + start=Point(x=100.0, y=200.0), + end=Point(x=200.0, y=300.0), + ), 
+ ), + type=Relationship.Type.UNIDIRECTIONAL, + ), + ), + ], + ), + Label( + data=GenericDataRowData( + uid="clf98gj90000qp38ka34yhptl-DIFFERENT", + ), + annotations=[ + ObjectAnnotation( + name="cat", + extra={ + "uuid": "d8813907-b15d-4374-bbe6-b9877fb42ccd", + }, + value=Rectangle( + start=Point(x=100.0, y=200.0), + end=Point(x=200.0, y=300.0), + ), + ), + ObjectAnnotation( + name="dog", + extra={ + "uuid": "9b1e1249-36b4-4665-b60a-9060e0d18660", + }, + value=Rectangle( + start=Point(x=400.0, y=500.0), + end=Point(x=600.0, y=700.0), + ), + ), + RelationshipAnnotation( + name="is chasing", + extra={"uuid": "0e6354eb-9adb-47e5-8e52-217ed016d948"}, + value=Relationship( + source=ObjectAnnotation( + name="dog", + extra={ + "uuid": "9b1e1249-36b4-4665-b60a-9060e0d18660", + }, + value=Rectangle( + start=Point(x=400.0, y=500.0), + end=Point(x=600.0, y=700.0), + ), + ), + target=ObjectAnnotation( + name="cat", + extra={ + "uuid": "d8813907-b15d-4374-bbe6-b9877fb42ccd", + }, + value=Rectangle( + start=Point(x=100.0, y=200.0), + end=Point(x=200.0, y=300.0), + ), + ), + type=Relationship.Type.UNIDIRECTIONAL, + ), + ), + ], + ), + ] res = list(NDJsonConverter.serialize(res)) assert len(res) == len(data) @@ -44,29 +163,3 @@ def test_relationship(): assert res_relationship_second_annotation["relationship"]["target"] in [ annot["uuid"] for annot in res_source_and_target ] - - -def test_relationship_nonexistent_object(): - with open("tests/data/assets/ndjson/relationship_import.json", "r") as file: - data = json.load(file) - - relationship_annotation = data[2] - source_uuid = relationship_annotation["relationship"]["source"] - target_uuid = str(uuid4()) - relationship_annotation["relationship"]["target"] = target_uuid - error_msg = f"Relationship object refers to nonexistent object with UUID '{source_uuid}' and/or '{target_uuid}'" - - with pytest.raises(ValueError, match=error_msg): - list(NDJsonConverter.deserialize(data)) - - -def test_relationship_duplicate_uuids(): - with 
open("tests/data/assets/ndjson/relationship_import.json", "r") as file: - data = json.load(file) - - source, target = data[0], data[1] - target["uuid"] = source["uuid"] - error_msg = f"UUID '{source['uuid']}' is not unique" - - with pytest.raises(AssertionError, match=error_msg): - list(NDJsonConverter.deserialize(data)) diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_text.py b/libs/labelbox/tests/data/serialization/ndjson/test_text.py index d5e81c51a..28eba07bd 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_text.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_text.py @@ -1,10 +1,8 @@ from labelbox.data.annotation_types.annotation import ClassificationAnnotation from labelbox.data.annotation_types.classification.classification import ( - ClassificationAnswer, - Radio, Text, ) -from labelbox.data.annotation_types.data.text import TextData +from labelbox.data.annotation_types.data import GenericDataRowData from labelbox.data.annotation_types.label import Label from labelbox.data.serialization.ndjson.converter import NDJsonConverter @@ -13,9 +11,8 @@ def test_serialization(): label = Label( uid="ckj7z2q0b0000jx6x0q2q7q0d", - data=TextData( + data=GenericDataRowData( uid="bkj7z2q0b0000jx6x0q2q7q0d", - text="This is a test", ), annotations=[ ClassificationAnnotation( @@ -34,11 +31,3 @@ def test_serialization(): assert res["name"] == "radio_question_geo" assert res["answer"] == "first_radio_answer" assert res["dataRow"]["id"] == "bkj7z2q0b0000jx6x0q2q7q0d" - - deserialized = NDJsonConverter.deserialize([res]) - res = next(deserialized) - annotation = res.annotations[0] - - annotation_value = annotation.value - assert type(annotation_value) is Text - assert annotation_value.answer == "first_radio_answer" diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_text_entity.py b/libs/labelbox/tests/data/serialization/ndjson/test_text_entity.py index 3e856f001..fb93f15d4 100644 --- 
a/libs/labelbox/tests/data/serialization/ndjson/test_text_entity.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_text_entity.py @@ -1,21 +1,68 @@ import json -import pytest +from labelbox.data.annotation_types.data.generic_data_row_data import ( + GenericDataRowData, +) +from labelbox.data.mixins import CustomMetric from labelbox.data.serialization.ndjson.converter import NDJsonConverter +from labelbox.types import Label, ObjectAnnotation, TextEntity + + +def test_text_entity_import(): + with open("tests/data/assets/ndjson/text_entity_import.json", "r") as file: + data = json.load(file) + + labels = [ + Label( + data=GenericDataRowData( + uid="cl6xnv9h61fv0085yhtoq06ht", + ), + annotations=[ + ObjectAnnotation( + custom_metrics=[ + CustomMetric(name="customMetric1", value=0.5), + CustomMetric(name="customMetric2", value=0.3), + ], + confidence=0.53, + name="some-text-entity", + feature_schema_id="cl6xnuwt95lqq07330tbb3mfd", + extra={ + "uuid": "5ad9c52f-058d-49c8-a749-3f20b84f8cd4", + }, + value=TextEntity(start=67, end=128, extra={}), + ) + ], + ) + ] + res = list(NDJsonConverter.serialize(labels)) + assert res == data -@pytest.mark.parametrize( - "filename", - [ - "tests/data/assets/ndjson/text_entity_import.json", +def test_text_entity_import_without_confidence(): + with open( "tests/data/assets/ndjson/text_entity_without_confidence_import.json", - ], -) -def test_text_entity_import(filename: str): - with open(filename, "r") as file: + "r", + ) as file: data = json.load(file) - res = list(NDJsonConverter.deserialize(data)) - res = list(NDJsonConverter.serialize(res)) + labels = [ + Label( + data=GenericDataRowData( + uid="cl6xnv9h61fv0085yhtoq06ht", + ), + annotations=[ + ObjectAnnotation( + name="some-text-entity", + feature_schema_id="cl6xnuwt95lqq07330tbb3mfd", + extra={ + "uuid": "5ad9c52f-058d-49c8-a749-3f20b84f8cd4", + }, + value=TextEntity(start=67, end=128, extra={}), + ) + ], + ) + ] + + res = list(NDJsonConverter.serialize(labels)) assert 
res == data diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_video.py b/libs/labelbox/tests/data/serialization/ndjson/test_video.py index c7a6535c4..6c14343a4 100644 --- a/libs/labelbox/tests/data/serialization/ndjson/test_video.py +++ b/libs/labelbox/tests/data/serialization/ndjson/test_video.py @@ -1,20 +1,21 @@ import json -from labelbox.client import Client from labelbox.data.annotation_types.classification.classification import ( Checklist, ClassificationAnnotation, ClassificationAnswer, Radio, + Text, ) -from labelbox.data.annotation_types.data.video import VideoData +from labelbox.data.annotation_types.data import GenericDataRowData from labelbox.data.annotation_types.geometry.line import Line from labelbox.data.annotation_types.geometry.point import Point from labelbox.data.annotation_types.geometry.rectangle import Rectangle -from labelbox.data.annotation_types.geometry.point import Point from labelbox.data.annotation_types.label import Label -from labelbox.data.annotation_types.video import VideoObjectAnnotation -from labelbox import parser +from labelbox.data.annotation_types.video import ( + VideoClassificationAnnotation, + VideoObjectAnnotation, +) from labelbox.data.serialization.ndjson.converter import NDJsonConverter from operator import itemgetter @@ -24,15 +25,275 @@ def test_video(): with open("tests/data/assets/ndjson/video_import.json", "r") as file: data = json.load(file) - res = list(NDJsonConverter.deserialize(data)) - res = list(NDJsonConverter.serialize(res)) + labels = [ + Label( + data=GenericDataRowData(uid="ckrb1sf1i1g7i0ybcdc6oc8ct"), + annotations=[ + VideoClassificationAnnotation( + feature_schema_id="ckrb1sfjx099a0y914hl319ie", + extra={"uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673"}, + value=Radio( + answer=ClassificationAnswer( + feature_schema_id="ckrb1sfl8099g0y91cxbd5ftb", + ), + ), + frame=30, + ), + VideoClassificationAnnotation( + feature_schema_id="ckrb1sfjx099a0y914hl319ie", + extra={"uuid": 
"f6879f59-d2b5-49c2-aceb-d9e8dc478673"}, + value=Radio( + answer=ClassificationAnswer( + feature_schema_id="ckrb1sfl8099g0y91cxbd5ftb", + ), + ), + frame=31, + ), + VideoClassificationAnnotation( + feature_schema_id="ckrb1sfjx099a0y914hl319ie", + extra={"uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673"}, + value=Radio( + answer=ClassificationAnswer( + feature_schema_id="ckrb1sfl8099g0y91cxbd5ftb", + ), + ), + frame=32, + ), + VideoClassificationAnnotation( + feature_schema_id="ckrb1sfjx099a0y914hl319ie", + extra={"uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673"}, + value=Radio( + answer=ClassificationAnswer( + feature_schema_id="ckrb1sfl8099g0y91cxbd5ftb", + ), + ), + frame=33, + ), + VideoClassificationAnnotation( + feature_schema_id="ckrb1sfjx099a0y914hl319ie", + extra={"uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673"}, + value=Radio( + answer=ClassificationAnswer( + feature_schema_id="ckrb1sfl8099g0y91cxbd5ftb", + ), + ), + frame=34, + ), + VideoClassificationAnnotation( + feature_schema_id="ckrb1sfjx099a0y914hl319ie", + extra={"uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673"}, + value=Radio( + answer=ClassificationAnswer( + feature_schema_id="ckrb1sfl8099g0y91cxbd5ftb", + ), + ), + frame=35, + ), + VideoClassificationAnnotation( + feature_schema_id="ckrb1sfjx099a0y914hl319ie", + extra={"uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673"}, + value=Radio( + answer=ClassificationAnswer( + feature_schema_id="ckrb1sfl8099g0y91cxbd5ftb", + ), + ), + frame=50, + ), + VideoClassificationAnnotation( + feature_schema_id="ckrb1sfjx099a0y914hl319ie", + extra={"uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673"}, + value=Radio( + answer=ClassificationAnswer( + feature_schema_id="ckrb1sfl8099g0y91cxbd5ftb", + ), + ), + frame=51, + ), + VideoClassificationAnnotation( + feature_schema_id="ckrb1sfkn099c0y910wbo0p1a", + extra={"uuid": "d009925d-91a3-4f67-abd9-753453f5a584"}, + value=Checklist( + answer=[ + ClassificationAnswer( + feature_schema_id="ckrb1sfl8099e0y919v260awv", + ) + ], + ), + 
frame=0, + ), + VideoClassificationAnnotation( + feature_schema_id="ckrb1sfkn099c0y910wbo0p1a", + extra={"uuid": "d009925d-91a3-4f67-abd9-753453f5a584"}, + value=Checklist( + answer=[ + ClassificationAnswer( + feature_schema_id="ckrb1sfl8099e0y919v260awv", + ) + ], + ), + frame=1, + ), + VideoClassificationAnnotation( + feature_schema_id="ckrb1sfkn099c0y910wbo0p1a", + extra={"uuid": "d009925d-91a3-4f67-abd9-753453f5a584"}, + value=Checklist( + answer=[ + ClassificationAnswer( + feature_schema_id="ckrb1sfl8099e0y919v260awv", + ) + ], + ), + frame=2, + ), + VideoClassificationAnnotation( + feature_schema_id="ckrb1sfkn099c0y910wbo0p1a", + extra={"uuid": "d009925d-91a3-4f67-abd9-753453f5a584"}, + value=Checklist( + answer=[ + ClassificationAnswer( + feature_schema_id="ckrb1sfl8099e0y919v260awv", + ) + ], + ), + frame=3, + ), + VideoClassificationAnnotation( + feature_schema_id="ckrb1sfkn099c0y910wbo0p1a", + extra={"uuid": "d009925d-91a3-4f67-abd9-753453f5a584"}, + value=Checklist( + answer=[ + ClassificationAnswer( + feature_schema_id="ckrb1sfl8099e0y919v260awv", + ) + ], + ), + frame=4, + ), + VideoClassificationAnnotation( + feature_schema_id="ckrb1sfkn099c0y910wbo0p1a", + extra={"uuid": "d009925d-91a3-4f67-abd9-753453f5a584"}, + value=Checklist( + answer=[ + ClassificationAnswer( + feature_schema_id="ckrb1sfl8099e0y919v260awv", + ) + ], + ), + frame=5, + ), + ClassificationAnnotation( + feature_schema_id="ckrb1sfkn099c0y910wbo0p1a", + extra={"uuid": "90e2ecf7-c19c-47e6-8cdb-8867e1b9d88c"}, + value=Text(answer="a value"), + ), + VideoObjectAnnotation( + feature_schema_id="cl5islwg200gfci6g0oitaypu", + extra={"uuid": "6f7c835a-0139-4896-b73f-66a6baa89e94"}, + value=Line( + points=[ + Point(x=10.0, y=10.0), + Point(x=100.0, y=100.0), + Point(x=50.0, y=30.0), + ], + ), + frame=1, + keyframe=True, + segment_index=0, + ), + VideoObjectAnnotation( + feature_schema_id="cl5islwg200gfci6g0oitaypu", + extra={"uuid": "6f7c835a-0139-4896-b73f-66a6baa89e94"}, + value=Line( + 
points=[ + Point(x=15.0, y=10.0), + Point(x=50.0, y=100.0), + Point(x=50.0, y=30.0), + ], + ), + frame=5, + keyframe=True, + segment_index=0, + ), + VideoObjectAnnotation( + feature_schema_id="cl5islwg200gfci6g0oitaypu", + extra={"uuid": "6f7c835a-0139-4896-b73f-66a6baa89e94"}, + value=Line( + points=[ + Point(x=100.0, y=10.0), + Point(x=50.0, y=100.0), + Point(x=50.0, y=30.0), + ], + ), + frame=8, + keyframe=True, + segment_index=1, + ), + VideoObjectAnnotation( + feature_schema_id="cl5it7ktp00i5ci6gf80b1ysd", + extra={"uuid": "f963be22-227b-4efe-9be4-2738ed822216"}, + value=Point(x=10.0, y=10.0), + frame=1, + keyframe=True, + segment_index=0, + ), + VideoObjectAnnotation( + feature_schema_id="cl5it7ktp00i5ci6gf80b1ysd", + extra={"uuid": "f963be22-227b-4efe-9be4-2738ed822216"}, + value=Point(x=50.0, y=50.0), + frame=5, + keyframe=True, + segment_index=1, + ), + VideoObjectAnnotation( + feature_schema_id="cl5it7ktp00i5ci6gf80b1ysd", + extra={"uuid": "f963be22-227b-4efe-9be4-2738ed822216"}, + value=Point(x=10.0, y=50.0), + frame=10, + keyframe=True, + segment_index=1, + ), + VideoObjectAnnotation( + feature_schema_id="cl5iw0roz00lwci6g5jni62vs", + extra={"uuid": "13b2ee0e-2355-4336-8b83-d74d09e3b1e7"}, + value=Rectangle( + start=Point(x=5.0, y=10.0), + end=Point(x=155.0, y=110.0), + ), + frame=1, + keyframe=True, + segment_index=0, + ), + VideoObjectAnnotation( + feature_schema_id="cl5iw0roz00lwci6g5jni62vs", + extra={"uuid": "13b2ee0e-2355-4336-8b83-d74d09e3b1e7"}, + value=Rectangle( + start=Point(x=5.0, y=30.0), + end=Point(x=155.0, y=80.0), + ), + frame=5, + keyframe=True, + segment_index=0, + ), + VideoObjectAnnotation( + feature_schema_id="cl5iw0roz00lwci6g5jni62vs", + extra={"uuid": "13b2ee0e-2355-4336-8b83-d74d09e3b1e7"}, + value=Rectangle( + start=Point(x=200.0, y=300.0), + end=Point(x=350.0, y=700.0), + ), + frame=10, + keyframe=True, + segment_index=1, + ), + ], + ) + ] + + res = list(NDJsonConverter.serialize(labels)) data = sorted(data, 
key=itemgetter("uuid")) res = sorted(res, key=itemgetter("uuid")) - - pairs = zip(data, res) - for data, res in pairs: - assert data == res + assert data == res def test_video_name_only(): @@ -40,21 +301,279 @@ def test_video_name_only(): "tests/data/assets/ndjson/video_import_name_only.json", "r" ) as file: data = json.load(file) - - res = list(NDJsonConverter.deserialize(data)) - res = list(NDJsonConverter.serialize(res)) - + labels = [ + Label( + data=GenericDataRowData(uid="ckrb1sf1i1g7i0ybcdc6oc8ct"), + annotations=[ + VideoClassificationAnnotation( + name="question 1", + extra={"uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673"}, + value=Radio( + answer=ClassificationAnswer( + name="answer 1", + ), + ), + frame=30, + ), + VideoClassificationAnnotation( + name="question 1", + extra={"uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673"}, + value=Radio( + answer=ClassificationAnswer( + name="answer 1", + ), + ), + frame=31, + ), + VideoClassificationAnnotation( + name="question 1", + extra={"uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673"}, + value=Radio( + answer=ClassificationAnswer( + name="answer 1", + ), + ), + frame=32, + ), + VideoClassificationAnnotation( + name="question 1", + extra={"uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673"}, + value=Radio( + answer=ClassificationAnswer( + name="answer 1", + ), + ), + frame=33, + ), + VideoClassificationAnnotation( + name="question 1", + extra={"uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673"}, + value=Radio( + answer=ClassificationAnswer( + name="answer 1", + ), + ), + frame=34, + ), + VideoClassificationAnnotation( + name="question 1", + extra={"uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673"}, + value=Radio( + answer=ClassificationAnswer( + name="answer 1", + ), + ), + frame=35, + ), + VideoClassificationAnnotation( + name="question 1", + extra={"uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673"}, + value=Radio( + answer=ClassificationAnswer( + name="answer 1", + ), + ), + frame=50, + ), + VideoClassificationAnnotation( + 
name="question 1", + extra={"uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673"}, + value=Radio( + answer=ClassificationAnswer( + name="answer 1", + ), + ), + frame=51, + ), + VideoClassificationAnnotation( + name="question 2", + extra={"uuid": "d009925d-91a3-4f67-abd9-753453f5a584"}, + value=Checklist( + answer=[ + ClassificationAnswer( + name="answer 2", + ) + ], + ), + frame=0, + ), + VideoClassificationAnnotation( + name="question 2", + extra={"uuid": "d009925d-91a3-4f67-abd9-753453f5a584"}, + value=Checklist( + answer=[ + ClassificationAnswer( + name="answer 2", + ) + ], + ), + frame=1, + ), + VideoClassificationAnnotation( + name="question 2", + extra={"uuid": "d009925d-91a3-4f67-abd9-753453f5a584"}, + value=Checklist( + answer=[ + ClassificationAnswer( + name="answer 2", + ) + ], + ), + frame=2, + ), + VideoClassificationAnnotation( + name="question 2", + extra={"uuid": "d009925d-91a3-4f67-abd9-753453f5a584"}, + value=Checklist( + answer=[ + ClassificationAnswer( + name="answer 2", + ) + ], + ), + frame=3, + ), + VideoClassificationAnnotation( + name="question 2", + extra={"uuid": "d009925d-91a3-4f67-abd9-753453f5a584"}, + value=Checklist( + answer=[ + ClassificationAnswer( + name="answer 2", + ) + ], + ), + frame=4, + ), + VideoClassificationAnnotation( + name="question 2", + extra={"uuid": "d009925d-91a3-4f67-abd9-753453f5a584"}, + value=Checklist( + answer=[ + ClassificationAnswer( + name="answer 2", + ) + ], + ), + frame=5, + ), + ClassificationAnnotation( + name="question 3", + extra={"uuid": "e5f32456-bd67-4520-8d3b-cbeb2204bad3"}, + value=Text(answer="a value"), + ), + VideoObjectAnnotation( + name="segment 1", + extra={"uuid": "6f7c835a-0139-4896-b73f-66a6baa89e94"}, + value=Line( + points=[ + Point(x=10.0, y=10.0), + Point(x=100.0, y=100.0), + Point(x=50.0, y=30.0), + ], + ), + frame=1, + keyframe=True, + segment_index=0, + ), + VideoObjectAnnotation( + name="segment 1", + extra={"uuid": "6f7c835a-0139-4896-b73f-66a6baa89e94"}, + value=Line( + points=[ + 
Point(x=15.0, y=10.0), + Point(x=50.0, y=100.0), + Point(x=50.0, y=30.0), + ], + ), + frame=5, + keyframe=True, + segment_index=0, + ), + VideoObjectAnnotation( + name="segment 1", + extra={"uuid": "6f7c835a-0139-4896-b73f-66a6baa89e94"}, + value=Line( + points=[ + Point(x=100.0, y=10.0), + Point(x=50.0, y=100.0), + Point(x=50.0, y=30.0), + ], + ), + frame=8, + keyframe=True, + segment_index=1, + ), + VideoObjectAnnotation( + name="segment 2", + extra={"uuid": "f963be22-227b-4efe-9be4-2738ed822216"}, + value=Point(x=10.0, y=10.0), + frame=1, + keyframe=True, + segment_index=0, + ), + VideoObjectAnnotation( + name="segment 2", + extra={"uuid": "f963be22-227b-4efe-9be4-2738ed822216"}, + value=Point(x=50.0, y=50.0), + frame=5, + keyframe=True, + segment_index=1, + ), + VideoObjectAnnotation( + name="segment 2", + extra={"uuid": "f963be22-227b-4efe-9be4-2738ed822216"}, + value=Point(x=10.0, y=50.0), + frame=10, + keyframe=True, + segment_index=1, + ), + VideoObjectAnnotation( + name="segment 3", + extra={"uuid": "13b2ee0e-2355-4336-8b83-d74d09e3b1e7"}, + value=Rectangle( + start=Point(x=5.0, y=10.0), + end=Point(x=155.0, y=110.0), + ), + frame=1, + keyframe=True, + segment_index=0, + ), + VideoObjectAnnotation( + name="segment 3", + extra={"uuid": "13b2ee0e-2355-4336-8b83-d74d09e3b1e7"}, + value=Rectangle( + start=Point(x=5.0, y=30.0), + end=Point(x=155.0, y=80.0), + ), + frame=5, + keyframe=True, + segment_index=0, + ), + VideoObjectAnnotation( + name="segment 3", + extra={"uuid": "13b2ee0e-2355-4336-8b83-d74d09e3b1e7"}, + value=Rectangle( + start=Point(x=200.0, y=300.0), + end=Point(x=350.0, y=700.0), + ), + frame=10, + keyframe=True, + segment_index=1, + ), + ], + ) + ] + res = list(NDJsonConverter.serialize(labels)) data = sorted(data, key=itemgetter("uuid")) res = sorted(res, key=itemgetter("uuid")) - pairs = zip(data, res) - for data, res in pairs: - assert data == res + assert data == res def test_video_classification_global_subclassifications(): label = Label( 
- data=VideoData( + data=GenericDataRowData( global_key="sample-video-4.mp4", ), annotations=[ @@ -67,7 +586,6 @@ def test_video_classification_global_subclassifications(): ClassificationAnnotation( name="nested_checklist_question", value=Checklist( - name="checklist", answer=[ ClassificationAnswer( name="first_checklist_answer", @@ -94,7 +612,7 @@ def test_video_classification_global_subclassifications(): "dataRow": {"globalKey": "sample-video-4.mp4"}, } - expected_second_annotation = nested_checklist_annotation_ndjson = { + expected_second_annotation = { "name": "nested_checklist_question", "answer": [ { @@ -116,12 +634,6 @@ def test_video_classification_global_subclassifications(): annotations.pop("uuid") assert res == [expected_first_annotation, expected_second_annotation] - deserialized = NDJsonConverter.deserialize(res) - res = next(deserialized) - annotations = res.annotations - for i, annotation in enumerate(annotations): - assert annotation.name == label.annotations[i].name - def test_video_classification_nesting_bbox(): bbox_annotation = [ @@ -277,7 +789,7 @@ def test_video_classification_nesting_bbox(): ] label = Label( - data=VideoData( + data=GenericDataRowData( global_key="sample-video-4.mp4", ), annotations=bbox_annotation, @@ -287,14 +799,6 @@ def test_video_classification_nesting_bbox(): res = [x for x in serialized] assert res == expected - deserialized = NDJsonConverter.deserialize(res) - res = next(deserialized) - annotations = res.annotations - for i, annotation in enumerate(annotations): - annotation.extra.pop("uuid") - assert annotation.value == label.annotations[i].value - assert annotation.name == label.annotations[i].name - def test_video_classification_point(): bbox_annotation = [ @@ -435,7 +939,7 @@ def test_video_classification_point(): ] label = Label( - data=VideoData( + data=GenericDataRowData( global_key="sample-video-4.mp4", ), annotations=bbox_annotation, @@ -445,13 +949,6 @@ def test_video_classification_point(): res = [x for x 
in serialized] assert res == expected - deserialized = NDJsonConverter.deserialize(res) - res = next(deserialized) - annotations = res.annotations - for i, annotation in enumerate(annotations): - annotation.extra.pop("uuid") - assert annotation.value == label.annotations[i].value - def test_video_classification_frameline(): bbox_annotation = [ @@ -610,7 +1107,7 @@ def test_video_classification_frameline(): ] label = Label( - data=VideoData( + data=GenericDataRowData( global_key="sample-video-4.mp4", ), annotations=bbox_annotation, @@ -619,9 +1116,289 @@ def test_video_classification_frameline(): res = [x for x in serialized] assert res == expected - deserialized = NDJsonConverter.deserialize(res) - res = next(deserialized) - annotations = res.annotations - for i, annotation in enumerate(annotations): - annotation.extra.pop("uuid") - assert annotation.value == label.annotations[i].value + +[ + { + "answer": "a value", + "dataRow": {"id": "ckrb1sf1i1g7i0ybcdc6oc8ct"}, + "schemaId": "ckrb1sfkn099c0y910wbo0p1a", + "uuid": "90e2ecf7-c19c-47e6-8cdb-8867e1b9d88c", + }, + { + "answer": {"schemaId": "ckrb1sfl8099g0y91cxbd5ftb"}, + "dataRow": {"id": "ckrb1sf1i1g7i0ybcdc6oc8ct"}, + "frames": [{"end": 35, "start": 30}, {"end": 51, "start": 50}], + "schemaId": "ckrb1sfjx099a0y914hl319ie", + "uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673", + }, + { + "answer": [{"schemaId": "ckrb1sfl8099e0y919v260awv"}], + "dataRow": {"id": "ckrb1sf1i1g7i0ybcdc6oc8ct"}, + "frames": [{"end": 5, "start": 0}], + "schemaId": "ckrb1sfkn099c0y910wbo0p1a", + "uuid": "d009925d-91a3-4f67-abd9-753453f5a584", + }, + { + "classifications": [], + "dataRow": {"id": "ckrb1sf1i1g7i0ybcdc6oc8ct"}, + "schemaId": "cl5islwg200gfci6g0oitaypu", + "segments": [ + { + "keyframes": [ + { + "classifications": [], + "frame": 1, + "line": [ + {"x": 10.0, "y": 10.0}, + {"x": 100.0, "y": 100.0}, + {"x": 50.0, "y": 30.0}, + ], + }, + { + "classifications": [], + "frame": 5, + "line": [ + {"x": 15.0, "y": 10.0}, + {"x": 50.0, 
"y": 100.0}, + {"x": 50.0, "y": 30.0}, + ], + }, + ] + }, + { + "keyframes": [ + { + "classifications": [], + "frame": 8, + "line": [ + {"x": 100.0, "y": 10.0}, + {"x": 50.0, "y": 100.0}, + {"x": 50.0, "y": 30.0}, + ], + } + ] + }, + ], + "uuid": "6f7c835a-0139-4896-b73f-66a6baa89e94", + }, + { + "classifications": [], + "dataRow": {"id": "ckrb1sf1i1g7i0ybcdc6oc8ct"}, + "schemaId": "cl5it7ktp00i5ci6gf80b1ysd", + "segments": [ + { + "keyframes": [ + { + "classifications": [], + "frame": 1, + "point": {"x": 10.0, "y": 10.0}, + } + ] + }, + { + "keyframes": [ + { + "classifications": [], + "frame": 5, + "point": {"x": 50.0, "y": 50.0}, + }, + { + "classifications": [], + "frame": 10, + "point": {"x": 10.0, "y": 50.0}, + }, + ] + }, + ], + "uuid": "f963be22-227b-4efe-9be4-2738ed822216", + }, + { + "classifications": [], + "dataRow": {"id": "ckrb1sf1i1g7i0ybcdc6oc8ct"}, + "schemaId": "cl5iw0roz00lwci6g5jni62vs", + "segments": [ + { + "keyframes": [ + { + "bbox": { + "height": 100.0, + "left": 5.0, + "top": 10.0, + "width": 150.0, + }, + "classifications": [], + "frame": 1, + }, + { + "bbox": { + "height": 50.0, + "left": 5.0, + "top": 30.0, + "width": 150.0, + }, + "classifications": [], + "frame": 5, + }, + ] + }, + { + "keyframes": [ + { + "bbox": { + "height": 400.0, + "left": 200.0, + "top": 300.0, + "width": 150.0, + }, + "classifications": [], + "frame": 10, + } + ] + }, + ], + "uuid": "13b2ee0e-2355-4336-8b83-d74d09e3b1e7", + }, +] + +[ + { + "answer": {"schemaId": "ckrb1sfl8099g0y91cxbd5ftb"}, + "schemaId": "ckrb1sfjx099a0y914hl319ie", + "dataRow": {"id": "ckrb1sf1i1g7i0ybcdc6oc8ct"}, + "uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673", + "frames": [{"start": 30, "end": 35}, {"start": 50, "end": 51}], + }, + { + "answer": [{"schemaId": "ckrb1sfl8099e0y919v260awv"}], + "schemaId": "ckrb1sfkn099c0y910wbo0p1a", + "dataRow": {"id": "ckrb1sf1i1g7i0ybcdc6oc8ct"}, + "uuid": "d009925d-91a3-4f67-abd9-753453f5a584", + "frames": [{"start": 0, "end": 5}], + }, + { + "answer": 
"a value", + "schemaId": "ckrb1sfkn099c0y910wbo0p1a", + "dataRow": {"id": "ckrb1sf1i1g7i0ybcdc6oc8ct"}, + "uuid": "90e2ecf7-c19c-47e6-8cdb-8867e1b9d88c", + }, + { + "classifications": [], + "schemaId": "cl5islwg200gfci6g0oitaypu", + "dataRow": {"id": "ckrb1sf1i1g7i0ybcdc6oc8ct"}, + "uuid": "6f7c835a-0139-4896-b73f-66a6baa89e94", + "segments": [ + { + "keyframes": [ + { + "frame": 1, + "line": [ + {"x": 10.0, "y": 10.0}, + {"x": 100.0, "y": 100.0}, + {"x": 50.0, "y": 30.0}, + ], + "classifications": [], + }, + { + "frame": 5, + "line": [ + {"x": 15.0, "y": 10.0}, + {"x": 50.0, "y": 100.0}, + {"x": 50.0, "y": 30.0}, + ], + "classifications": [], + }, + ] + }, + { + "keyframes": [ + { + "frame": 8, + "line": [ + {"x": 100.0, "y": 10.0}, + {"x": 50.0, "y": 100.0}, + {"x": 50.0, "y": 30.0}, + ], + "classifications": [], + } + ] + }, + ], + }, + { + "classifications": [], + "schemaId": "cl5it7ktp00i5ci6gf80b1ysd", + "dataRow": {"id": "ckrb1sf1i1g7i0ybcdc6oc8ct"}, + "uuid": "f963be22-227b-4efe-9be4-2738ed822216", + "segments": [ + { + "keyframes": [ + { + "frame": 1, + "point": {"x": 10.0, "y": 10.0}, + "classifications": [], + } + ] + }, + { + "keyframes": [ + { + "frame": 5, + "point": {"x": 50.0, "y": 50.0}, + "classifications": [], + }, + { + "frame": 10, + "point": {"x": 10.0, "y": 50.0}, + "classifications": [], + }, + ] + }, + ], + }, + { + "classifications": [], + "schemaId": "cl5iw0roz00lwci6g5jni62vs", + "dataRow": {"id": "ckrb1sf1i1g7i0ybcdc6oc8ct"}, + "uuid": "13b2ee0e-2355-4336-8b83-d74d09e3b1e7", + "segments": [ + { + "keyframes": [ + { + "frame": 1, + "bbox": { + "top": 10.0, + "left": 5.0, + "height": 100.0, + "width": 150.0, + }, + "classifications": [], + }, + { + "frame": 5, + "bbox": { + "top": 30.0, + "left": 5.0, + "height": 50.0, + "width": 150.0, + }, + "classifications": [], + }, + ] + }, + { + "keyframes": [ + { + "frame": 10, + "bbox": { + "top": 300.0, + "left": 200.0, + "height": 400.0, + "width": 150.0, + }, + "classifications": [], + } + ] + 
}, + ], + }, +] diff --git a/libs/labelbox/tests/data/test_data_row_metadata.py b/libs/labelbox/tests/data/test_data_row_metadata.py index 891cab9be..2a455efce 100644 --- a/libs/labelbox/tests/data/test_data_row_metadata.py +++ b/libs/labelbox/tests/data/test_data_row_metadata.py @@ -1,18 +1,18 @@ -from datetime import datetime +import uuid +from datetime import datetime, timezone import pytest -import uuid +from lbox.exceptions import MalformedQueryException from labelbox import Dataset -from labelbox.exceptions import MalformedQueryException -from labelbox.schema.identifiables import GlobalKeys, UniqueIds from labelbox.schema.data_row_metadata import ( - DataRowMetadataField, DataRowMetadata, + DataRowMetadataField, DataRowMetadataKind, DataRowMetadataOntology, _parse_metadata_schema, ) +from labelbox.schema.identifiables import GlobalKeys, UniqueIds INVALID_SCHEMA_ID = "1" * 25 FAKE_SCHEMA_ID = "0" * 25 @@ -61,7 +61,7 @@ def big_dataset(dataset: Dataset, image_url): def make_metadata(dr_id: str = None, gk: str = None) -> DataRowMetadata: msg = "A message" - time = datetime.utcnow() + time = datetime.now(timezone.utc) metadata = DataRowMetadata( global_key=gk, @@ -79,7 +79,7 @@ def make_metadata(dr_id: str = None, gk: str = None) -> DataRowMetadata: def make_named_metadata(dr_id) -> DataRowMetadata: msg = "A message" - time = datetime.utcnow() + time = datetime.now(timezone.utc) metadata = DataRowMetadata( data_row_id=dr_id, @@ -124,7 +124,7 @@ def test_get_datarow_metadata_ontology(mdo): fields=[ DataRowMetadataField( schema_id=mdo.reserved_by_name["captureDateTime"].uid, - value=datetime.utcnow(), + value=datetime.now(timezone.utc), ), DataRowMetadataField(schema_id=split.parent, value=split.uid), DataRowMetadataField( diff --git a/libs/labelbox/tests/integration/conftest.py b/libs/labelbox/tests/integration/conftest.py index e73fef920..a0c2499a5 100644 --- a/libs/labelbox/tests/integration/conftest.py +++ b/libs/labelbox/tests/integration/conftest.py @@ -1,16 
+1,22 @@ +import json import os +import re import sys import time +import uuid from collections import defaultdict from datetime import datetime, timezone +from enum import Enum from itertools import islice -from typing import Type +from types import SimpleNamespace +from typing import List, Tuple, Type import pytest from labelbox import ( Classification, Client, + DataRow, Dataset, LabelingFrontend, MediaType, @@ -20,9 +26,14 @@ ResponseOption, Tool, ) +from labelbox.orm import query +from labelbox.pagination import PaginatedCollection +from labelbox.schema.annotation_import import LabelImport +from labelbox.schema.catalog import Catalog from labelbox.schema.data_row import DataRowMetadataField +from labelbox.schema.enums import AnnotationImportState +from labelbox.schema.invite import Invite from labelbox.schema.ontology_kind import OntologyKind -from labelbox.schema.queue_mode import QueueMode from labelbox.schema.user import User @@ -80,7 +91,6 @@ def project_pack(client): projects = [ client.create_project( name=f"user-proj-{idx}", - queue_mode=QueueMode.Batch, media_type=MediaType.Image, ) for idx in range(2) @@ -91,24 +101,27 @@ def project_pack(client): @pytest.fixture -def project_with_empty_ontology(project): - editor = list( - project.client.get_labeling_frontends( - where=LabelingFrontend.name == "editor" - ) - )[0] - empty_ontology = {"tools": [], "classifications": []} - project.setup(editor, empty_ontology) +def project_with_one_feature_ontology(project, client: Client): + tools = [ + Tool(tool=Tool.Type.BBOX, name="test-bbox-class").asdict(), + ] + empty_ontology = {"tools": tools, "classifications": []} + ontology = client.create_ontology( + "empty ontology", + empty_ontology, + MediaType.Image, + ) + project.connect_ontology(ontology) yield project @pytest.fixture def configured_project( - project_with_empty_ontology, initial_dataset, rand_gen, image_url + project_with_one_feature_ontology, initial_dataset, rand_gen, image_url ): dataset = 
initial_dataset data_row_id = dataset.create_data_row(row_data=image_url).uid - project = project_with_empty_ontology + project = project_with_one_feature_ontology batch = project.create_batch( rand_gen(str), @@ -124,11 +137,10 @@ def configured_project( @pytest.fixture def configured_project_with_complex_ontology( - client, initial_dataset, rand_gen, image_url, teardown_helpers + client: Client, initial_dataset, rand_gen, image_url, teardown_helpers ): project = client.create_project( name=rand_gen(str), - queue_mode=QueueMode.Batch, media_type=MediaType.Image, ) dataset = initial_dataset @@ -142,19 +154,12 @@ def configured_project_with_complex_ontology( ) project.data_row_ids = data_row_ids - editor = list( - project.client.get_labeling_frontends( - where=LabelingFrontend.name == "editor" - ) - )[0] - ontology = OntologyBuilder() tools = [ Tool(tool=Tool.Type.BBOX, name="test-bbox-class"), Tool(tool=Tool.Type.LINE, name="test-line-class"), Tool(tool=Tool.Type.POINT, name="test-point-class"), Tool(tool=Tool.Type.POLYGON, name="test-polygon-class"), - Tool(tool=Tool.Type.NER, name="test-ner-class"), ] options = [ @@ -186,7 +191,12 @@ def configured_project_with_complex_ontology( for c in classifications: ontology.add_classification(c) - project.setup(editor, ontology.asdict()) + ontology = client.create_ontology( + "image ontology", + ontology.asdict(), + MediaType.Image, + ) + project.connect_ontology(ontology) yield [project, data_row] teardown_helpers.teardown_project_labels_ontology_feature_schemas(project) diff --git a/libs/labelbox/tests/integration/schema/test_user_group.py b/libs/labelbox/tests/integration/schema/test_user_group.py index 50aaad4a7..b1443b7e7 100644 --- a/libs/labelbox/tests/integration/schema/test_user_group.py +++ b/libs/labelbox/tests/integration/schema/test_user_group.py @@ -2,11 +2,11 @@ import faker import pytest - -from labelbox.exceptions import ( +from lbox.exceptions import ( ResourceCreationError, ResourceNotFoundError, ) + from 
labelbox.schema.user_group import UserGroup, UserGroupColor data = faker.Faker() diff --git a/libs/labelbox/tests/integration/test_batch.py b/libs/labelbox/tests/integration/test_batch.py index 8a8707175..f63b4c3d9 100644 --- a/libs/labelbox/tests/integration/test_batch.py +++ b/libs/labelbox/tests/integration/test_batch.py @@ -1,16 +1,16 @@ -import time from typing import List from uuid import uuid4 -import pytest -from labelbox import Dataset, Project -from labelbox.exceptions import ( - ProcessingWaitTimeout, +import pytest +from lbox.exceptions import ( + LabelboxError, MalformedQueryException, + ProcessingWaitTimeout, ResourceConflict, - LabelboxError, ) +from labelbox import Dataset, Project + def get_data_row_ids(ds: Dataset): return [dr.uid for dr in list(ds.data_rows())] diff --git a/libs/labelbox/tests/integration/test_benchmark.py b/libs/labelbox/tests/integration/test_benchmark.py index c10542bda..661f83bd7 100644 --- a/libs/labelbox/tests/integration/test_benchmark.py +++ b/libs/labelbox/tests/integration/test_benchmark.py @@ -1,17 +1,17 @@ def test_benchmark(configured_project_with_label): project, _, data_row, label = configured_project_with_label assert set(project.benchmarks()) == set() - assert label.is_benchmark_reference == False + assert label.is_benchmark_reference is False benchmark = label.create_benchmark() assert set(project.benchmarks()) == {benchmark} assert benchmark.reference_label() == label # Refresh label data to check it's benchmark reference label = list(data_row.labels())[0] - assert label.is_benchmark_reference == True + assert label.is_benchmark_reference is True benchmark.delete() assert set(project.benchmarks()) == set() # Refresh label data to check it's benchmark reference label = list(data_row.labels())[0] - assert label.is_benchmark_reference == False + assert label.is_benchmark_reference is False diff --git a/libs/labelbox/tests/integration/test_chat_evaluation_ontology_project.py 
b/libs/labelbox/tests/integration/test_chat_evaluation_ontology_project.py index 2c02b77ac..796cd9859 100644 --- a/libs/labelbox/tests/integration/test_chat_evaluation_ontology_project.py +++ b/libs/labelbox/tests/integration/test_chat_evaluation_ontology_project.py @@ -1,3 +1,5 @@ +from unittest.mock import patch + import pytest from labelbox import MediaType diff --git a/libs/labelbox/tests/integration/test_client_errors.py b/libs/labelbox/tests/integration/test_client_errors.py index 64b8fb626..38022d1d2 100644 --- a/libs/labelbox/tests/integration/test_client_errors.py +++ b/libs/labelbox/tests/integration/test_client_errors.py @@ -1,44 +1,46 @@ -from multiprocessing.dummy import Pool import os import time +from multiprocessing.dummy import Pool + +import lbox.exceptions import pytest from google.api_core.exceptions import RetryError -from labelbox import Project, Dataset, User import labelbox.client -import labelbox.exceptions +from labelbox import Project, User +from labelbox.schema.media_type import MediaType def test_missing_api_key(): - key = os.environ.get(labelbox.client._LABELBOX_API_KEY, None) + key = os.environ.get(lbox.request_client._LABELBOX_API_KEY, None) if key is not None: - del os.environ[labelbox.client._LABELBOX_API_KEY] + del os.environ[lbox.request_client._LABELBOX_API_KEY] - with pytest.raises(labelbox.exceptions.AuthenticationError) as excinfo: + with pytest.raises(lbox.exceptions.AuthenticationError) as excinfo: labelbox.client.Client() assert excinfo.value.message == "Labelbox API key not provided" if key is not None: - os.environ[labelbox.client._LABELBOX_API_KEY] = key + os.environ[lbox.request_client._LABELBOX_API_KEY] = key def test_bad_key(rand_gen): bad_key = "BAD_KEY_" + rand_gen(str) client = labelbox.client.Client(api_key=bad_key) - with pytest.raises(labelbox.exceptions.AuthenticationError) as excinfo: - client.create_project(name=rand_gen(str)) + with pytest.raises(lbox.exceptions.AuthenticationError) as excinfo: + 
client.create_project(name=rand_gen(str), media_type=MediaType.Image) def test_syntax_error(client): - with pytest.raises(labelbox.exceptions.InvalidQueryError) as excinfo: + with pytest.raises(lbox.exceptions.InvalidQueryError) as excinfo: client.execute("asda", check_naming=False) assert excinfo.value.message.startswith("Syntax Error:") def test_semantic_error(client): - with pytest.raises(labelbox.exceptions.InvalidQueryError) as excinfo: + with pytest.raises(lbox.exceptions.InvalidQueryError) as excinfo: client.execute("query {bbb {id}}", check_naming=False) assert excinfo.value.message.startswith('Cannot query field "bbb"') @@ -58,7 +60,7 @@ def test_timeout_error(client, project): def test_query_complexity_error(client): - with pytest.raises(labelbox.exceptions.ValidationFailedError) as excinfo: + with pytest.raises(lbox.exceptions.ValidationFailedError) as excinfo: client.execute( "{projects {datasets {dataRows {labels {id}}}}}", check_naming=False ) @@ -66,41 +68,17 @@ def test_query_complexity_error(client): def test_resource_not_found_error(client): - with pytest.raises(labelbox.exceptions.ResourceNotFoundError): + with pytest.raises(lbox.exceptions.ResourceNotFoundError): client.get_project("invalid project ID") def test_network_error(client): client = labelbox.client.Client( - api_key=client.api_key, endpoint="not_a_valid_URL" + api_key=client._request_client.api_key, endpoint="not_a_valid_URL" ) - with pytest.raises(labelbox.exceptions.NetworkError) as excinfo: - client.create_project(name="Project name") - - -def test_invalid_attribute_error( - client, - rand_gen, -): - # Creation - with pytest.raises(labelbox.exceptions.InvalidAttributeError) as excinfo: - client.create_project(name="Name", invalid_field="Whatever") - assert excinfo.value.db_object_type == Project - assert excinfo.value.field == "invalid_field" - - # Update - project = client.create_project(name=rand_gen(str)) - with pytest.raises(labelbox.exceptions.InvalidAttributeError) as 
excinfo: - project.update(invalid_field="Whatever") - assert excinfo.value.db_object_type == Project - assert excinfo.value.field == "invalid_field" - - # Top-level-fetch - with pytest.raises(labelbox.exceptions.InvalidAttributeError) as excinfo: - client.get_projects(where=User.email == "email") - assert excinfo.value.db_object_type == Project - assert excinfo.value.field == {User.email} + with pytest.raises(lbox.exceptions.NetworkError) as excinfo: + client.create_project(name="Project name", media_type=MediaType.Image) @pytest.mark.skip("timeouts cause failure before rate limit") @@ -108,7 +86,7 @@ def test_api_limit_error(client): def get(arg): try: return client.get_user() - except labelbox.exceptions.ApiLimitError as e: + except lbox.exceptions.ApiLimitError as e: return e # Rate limited at 1500 + buffer @@ -120,7 +98,7 @@ def get(arg): elapsed = time.time() - start assert elapsed < 60, "Didn't finish fast enough" - assert labelbox.exceptions.ApiLimitError in {type(r) for r in results} + assert lbox.exceptions.ApiLimitError in {type(r) for r in results} # Sleep at the end of this test to allow other tests to execute. 
time.sleep(60) diff --git a/libs/labelbox/tests/integration/test_data_row_delete_metadata.py b/libs/labelbox/tests/integration/test_data_row_delete_metadata.py index 2df860181..a2ffd31ba 100644 --- a/libs/labelbox/tests/integration/test_data_row_delete_metadata.py +++ b/libs/labelbox/tests/integration/test_data_row_delete_metadata.py @@ -1,13 +1,13 @@ -from datetime import datetime, timezone import uuid +from datetime import datetime, timezone import pytest +from lbox.exceptions import MalformedQueryException -from labelbox import DataRow, Dataset, Client, DataRowMetadataOntology -from labelbox.exceptions import MalformedQueryException +from labelbox import Client, DataRow, DataRowMetadataOntology, Dataset from labelbox.schema.data_row_metadata import ( - DataRowMetadataField, DataRowMetadata, + DataRowMetadataField, DataRowMetadataKind, DeleteDataRowMetadata, ) diff --git a/libs/labelbox/tests/integration/test_data_rows.py b/libs/labelbox/tests/integration/test_data_rows.py index baa65db69..485719575 100644 --- a/libs/labelbox/tests/integration/test_data_rows.py +++ b/libs/labelbox/tests/integration/test_data_rows.py @@ -7,13 +7,13 @@ import pytest import requests - -from labelbox import AssetAttachment, DataRow -from labelbox.exceptions import ( +from lbox.exceptions import ( InvalidQueryError, MalformedQueryException, ResourceCreationError, ) + +from labelbox import AssetAttachment, DataRow from labelbox.schema.data_row_metadata import ( DataRowMetadataField, DataRowMetadataKind, @@ -105,7 +105,7 @@ def make_metadata_fields_dict(constants): def test_get_data_row_by_global_key(data_row_and_global_key, client, rand_gen): _, global_key = data_row_and_global_key data_row = client.get_data_row_by_global_key(global_key) - assert type(data_row) == DataRow + assert type(data_row) is DataRow assert data_row.global_key == global_key @@ -505,8 +505,6 @@ def test_create_data_rows_with_metadata( [ ("create_data_rows", "class"), ("create_data_rows", "dict"), - 
("create_data_rows_sync", "class"), - ("create_data_rows_sync", "dict"), ("create_data_row", "class"), ("create_data_row", "dict"), ], @@ -546,7 +544,6 @@ def create_data_row(data_rows): CREATION_FUNCTION = { "create_data_rows": dataset.create_data_rows, - "create_data_rows_sync": dataset.create_data_rows_sync, "create_data_row": create_data_row, } data_rows = [METADATA_FIELDS[metadata_obj_type]] @@ -695,9 +692,10 @@ def test_data_row_update( pdf_url = "https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483.pdf" tileLayerUrl = "https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483-lb-textlayer.json" data_row.update(row_data={"pdfUrl": pdf_url, "tileLayerUrl": tileLayerUrl}) - custom_check = ( - lambda data_row: data_row.row_data and "pdfUrl" not in data_row.row_data - ) + + def custom_check(data_row): + return data_row.row_data and "pdfUrl" not in data_row.row_data + data_row = wait_for_data_row_processing( client, data_row, custom_check=custom_check ) @@ -821,49 +819,6 @@ def test_data_row_attachments(dataset, image_url): ) -def test_create_data_rows_sync_attachments(dataset, image_url): - attachments = [ - ("IMAGE", image_url, "image URL"), - ("RAW_TEXT", "test-text", None), - ("IMAGE_OVERLAY", image_url, "Overlay"), - ("HTML", image_url, None), - ] - attachments_per_data_row = 3 - dataset.create_data_rows_sync( - [ - { - "row_data": image_url, - "external_id": "test-id", - "attachments": [ - { - "type": attachment_type, - "value": attachment_value, - "name": attachment_name, - } - for _ in range(attachments_per_data_row) - ], - } - for attachment_type, attachment_value, attachment_name in attachments - ] - ) - data_rows = list(dataset.data_rows()) - assert len(data_rows) == len(attachments) - for data_row in data_rows: - assert len(list(data_row.attachments())) == attachments_per_data_row - - -def test_create_data_rows_sync_mixed_upload(dataset, image_url): - n_local = 100 - n_urls = 
100 - with NamedTemporaryFile() as fp: - fp.write("Test data".encode()) - fp.flush() - dataset.create_data_rows_sync( - [{DataRow.row_data: image_url}] * n_urls + [fp.name] * n_local - ) - assert len(list(dataset.data_rows())) == n_local + n_urls - - def test_create_data_row_attachment(data_row): att = data_row.create_attachment( "IMAGE", "https://example.com/image.jpg", "name" @@ -1047,9 +1002,9 @@ def test_data_row_bulk_creation_with_same_global_keys( task.wait_till_done() assert task.status == "COMPLETE" - assert type(task.failed_data_rows) is list + assert isinstance(task.failed_data_rows, list) assert len(task.failed_data_rows) == 1 - assert type(task.created_data_rows) is list + assert isinstance(task.created_data_rows, list) assert len(task.created_data_rows) == 1 assert ( task.failed_data_rows[0]["message"] @@ -1109,53 +1064,6 @@ def test_data_row_delete_and_create_with_same_global_key( assert task.result[0]["global_key"] == global_key_1 -def test_data_row_bulk_creation_sync_with_unique_global_keys( - dataset, sample_image -): - global_key_1 = str(uuid.uuid4()) - global_key_2 = str(uuid.uuid4()) - global_key_3 = str(uuid.uuid4()) - - dataset.create_data_rows_sync( - [ - {DataRow.row_data: sample_image, DataRow.global_key: global_key_1}, - {DataRow.row_data: sample_image, DataRow.global_key: global_key_2}, - {DataRow.row_data: sample_image, DataRow.global_key: global_key_3}, - ] - ) - - assert {row.global_key for row in dataset.data_rows()} == { - global_key_1, - global_key_2, - global_key_3, - } - - -def test_data_row_bulk_creation_sync_with_same_global_keys( - dataset, sample_image -): - global_key_1 = str(uuid.uuid4()) - - with pytest.raises(ResourceCreationError) as exc_info: - dataset.create_data_rows_sync( - [ - { - DataRow.row_data: sample_image, - DataRow.global_key: global_key_1, - }, - { - DataRow.row_data: sample_image, - DataRow.global_key: global_key_1, - }, - ] - ) - - assert len(list(dataset.data_rows())) == 1 - assert 
list(dataset.data_rows())[0].global_key == global_key_1 - assert "Duplicate global key" in str(exc_info.value) - assert exc_info.value.args[1] # task id - - @pytest.fixture def conversational_data_rows(dataset, conversational_content): examples = [ @@ -1197,7 +1105,7 @@ def test_invalid_media_type(dataset, conversational_content): # TODO: What error kind should this be? It looks like for global key we are # using malformed query. But for invalid contents in FileUploads we use InvalidQueryError with pytest.raises(ResourceCreationError): - dataset.create_data_rows_sync( + dataset._create_data_rows_sync( [{**conversational_content, "media_type": "IMAGE"}] ) @@ -1207,7 +1115,8 @@ def test_create_tiled_layer(dataset, tile_content): {**tile_content, "media_type": "TMS_GEO"}, tile_content, ] - dataset.create_data_rows_sync(examples) + task = dataset.create_data_rows(examples) + task.wait_until_done() data_rows = list(dataset.data_rows()) assert len(data_rows) == len(examples) for data_row in data_rows: diff --git a/libs/labelbox/tests/integration/test_dataset.py b/libs/labelbox/tests/integration/test_dataset.py index 89210d6c9..a32c5541d 100644 --- a/libs/labelbox/tests/integration/test_dataset.py +++ b/libs/labelbox/tests/integration/test_dataset.py @@ -1,9 +1,10 @@ +from unittest.mock import MagicMock + import pytest import requests -from unittest.mock import MagicMock -from labelbox import Dataset -from labelbox.exceptions import ResourceNotFoundError, ResourceCreationError +from lbox.exceptions import ResourceCreationError, ResourceNotFoundError +from labelbox import Dataset from labelbox.schema.internal.descriptor_file_creator import ( DescriptorFileCreator, ) diff --git a/libs/labelbox/tests/integration/test_dates.py b/libs/labelbox/tests/integration/test_dates.py index 7bde0d666..f4e2364c7 100644 --- a/libs/labelbox/tests/integration/test_dates.py +++ b/libs/labelbox/tests/integration/test_dates.py @@ -22,6 +22,8 @@ def test_utc_conversion(project): # Update with a 
datetime with TZ info tz = timezone(timedelta(hours=6)) # +6 timezone - project.update(setup_complete=datetime.utcnow().replace(tzinfo=tz)) - diff = datetime.utcnow() - project.setup_complete.replace(tzinfo=None) + project.update(setup_complete=datetime.now(timezone.utc).replace(tzinfo=tz)) + diff = datetime.now(timezone.utc) - project.setup_complete.replace( + tzinfo=timezone.utc + ) assert diff > timedelta(hours=5, minutes=58) diff --git a/libs/labelbox/tests/integration/test_embedding.py b/libs/labelbox/tests/integration/test_embedding.py index 1b54ab81c..41f7ed3de 100644 --- a/libs/labelbox/tests/integration/test_embedding.py +++ b/libs/labelbox/tests/integration/test_embedding.py @@ -2,12 +2,12 @@ import random import threading from tempfile import NamedTemporaryFile -from typing import List, Dict, Any +from typing import Any, Dict, List +import lbox.exceptions import pytest -import labelbox.exceptions -from labelbox import Client, Dataset, DataRow +from labelbox import Client, DataRow, Dataset from labelbox.schema.embedding import Embedding @@ -23,7 +23,7 @@ def test_get_embedding_by_id(client: Client, embedding: Embedding): def test_get_embedding_by_name_not_found(client: Client): - with pytest.raises(labelbox.exceptions.ResourceNotFoundError): + with pytest.raises(lbox.exceptions.ResourceNotFoundError): client.get_embedding_by_name("does-not-exist") diff --git a/libs/labelbox/tests/integration/test_ephemeral.py b/libs/labelbox/tests/integration/test_ephemeral.py index a23572fdf..3c4fc62e4 100644 --- a/libs/labelbox/tests/integration/test_ephemeral.py +++ b/libs/labelbox/tests/integration/test_ephemeral.py @@ -7,7 +7,7 @@ reason="This test only runs in EPHEMERAL environment", ) def test_org_and_user_setup(client, ephmeral_client): - assert type(client) == ephmeral_client + assert type(client) is ephmeral_client assert client.admin_client assert client.api_key != client.admin_client.api_key @@ -22,4 +22,4 @@ def test_org_and_user_setup(client, 
ephmeral_client): reason="This test does not run in EPHEMERAL environment", ) def test_integration_client(client, integration_client): - assert type(client) == integration_client + assert type(client) is integration_client diff --git a/libs/labelbox/tests/integration/test_filtering.py b/libs/labelbox/tests/integration/test_filtering.py index 1c37227b2..bba483b19 100644 --- a/libs/labelbox/tests/integration/test_filtering.py +++ b/libs/labelbox/tests/integration/test_filtering.py @@ -1,8 +1,8 @@ import pytest +from lbox.exceptions import InvalidQueryError from labelbox import Project -from labelbox.exceptions import InvalidQueryError -from labelbox.schema.queue_mode import QueueMode +from labelbox.schema.media_type import MediaType @pytest.fixture @@ -11,9 +11,9 @@ def project_to_test_where(client, rand_gen): p_b_name = f"b-{rand_gen(str)}" p_c_name = f"c-{rand_gen(str)}" - p_a = client.create_project(name=p_a_name, queue_mode=QueueMode.Batch) - p_b = client.create_project(name=p_b_name, queue_mode=QueueMode.Batch) - p_c = client.create_project(name=p_c_name, queue_mode=QueueMode.Batch) + p_a = client.create_project(name=p_a_name, media_type=MediaType.Image) + p_b = client.create_project(name=p_b_name, media_type=MediaType.Image) + p_c = client.create_project(name=p_c_name, media_type=MediaType.Image) yield p_a, p_b, p_c diff --git a/libs/labelbox/tests/integration/test_foundry.py b/libs/labelbox/tests/integration/test_foundry.py index 83c4effc5..b9fd1b6f3 100644 --- a/libs/labelbox/tests/integration/test_foundry.py +++ b/libs/labelbox/tests/integration/test_foundry.py @@ -1,7 +1,8 @@ -import labelbox as lb import pytest -from labelbox.schema.foundry.app import App +from lbox.exceptions import LabelboxError, ResourceNotFoundError +import labelbox as lb +from labelbox.schema.foundry.app import App from labelbox.schema.foundry.foundry_client import FoundryClient # Yolo object detection model id @@ -97,7 +98,7 @@ def test_get_app(foundry_client, app): def 
test_get_app_with_invalid_id(foundry_client): - with pytest.raises(lb.exceptions.ResourceNotFoundError): + with pytest.raises(ResourceNotFoundError): foundry_client._get_app("invalid-id") @@ -144,7 +145,7 @@ def test_run_foundry_app_returns_model_run_id( def test_run_foundry_with_invalid_data_row_id(foundry_client, app, random_str): invalid_datarow_id = "invalid-global-key" data_rows = lb.GlobalKeys([invalid_datarow_id]) - with pytest.raises(lb.exceptions.LabelboxError) as exception: + with pytest.raises(LabelboxError) as exception: foundry_client.run_app( model_run_name=f"test-app-with-invalid-datarow-id-{random_str}", data_rows=data_rows, @@ -156,7 +157,7 @@ def test_run_foundry_with_invalid_data_row_id(foundry_client, app, random_str): def test_run_foundry_with_invalid_global_key(foundry_client, app, random_str): invalid_global_key = "invalid-global-key" data_rows = lb.GlobalKeys([invalid_global_key]) - with pytest.raises(lb.exceptions.LabelboxError) as exception: + with pytest.raises(LabelboxError) as exception: foundry_client.run_app( model_run_name=f"test-app-with-invalid-global-key-{random_str}", data_rows=data_rows, diff --git a/libs/labelbox/tests/integration/test_labeling_frontend.py b/libs/labelbox/tests/integration/test_labeling_frontend.py index d6ea1aac9..9a72fed47 100644 --- a/libs/labelbox/tests/integration/test_labeling_frontend.py +++ b/libs/labelbox/tests/integration/test_labeling_frontend.py @@ -1,7 +1,7 @@ import pytest +from lbox.exceptions import OperationNotSupportedException from labelbox import LabelingFrontend -from labelbox.exceptions import OperationNotSupportedException def test_get_labeling_frontends(client): diff --git a/libs/labelbox/tests/integration/test_labeling_parameter_overrides.py b/libs/labelbox/tests/integration/test_labeling_parameter_overrides.py index bd14040de..afa038482 100644 --- a/libs/labelbox/tests/integration/test_labeling_parameter_overrides.py +++ 
b/libs/labelbox/tests/integration/test_labeling_parameter_overrides.py @@ -23,7 +23,7 @@ def test_labeling_parameter_overrides(consensus_project_with_batch): data_rows[2].uid, } - data = [(data_rows[0], 4, 2), (data_rows[1], 3)] + data = [(UniqueId(data_rows[0].uid), 4, 2), (UniqueId(data_rows[1].uid), 3)] success = project.set_labeling_parameter_overrides(data) assert success @@ -60,7 +60,7 @@ def test_labeling_parameter_overrides(consensus_project_with_batch): assert {o.priority for o in updated_overrides} == {2, 3, 4} with pytest.raises(TypeError) as exc_info: - data = [(data_rows[2], "a_string", 3)] + data = [(UniqueId(data_rows[2].uid), "a_string", 3)] project.set_labeling_parameter_overrides(data) assert ( str(exc_info.value) @@ -72,7 +72,7 @@ def test_labeling_parameter_overrides(consensus_project_with_batch): project.set_labeling_parameter_overrides(data) assert ( str(exc_info.value) - == f"Data row identifier should be be of type DataRow, UniqueId or GlobalKey. Found for data_row_identifier {data_rows[2].uid}" + == "Data row identifier should be of type DataRowIdentifier. Found ." 
) @@ -85,13 +85,6 @@ def test_set_labeling_priority(consensus_project_with_batch): assert len(init_labeling_parameter_overrides) == 3 assert {o.priority for o in init_labeling_parameter_overrides} == {5, 5, 5} - data = [data_row.uid for data_row in data_rows] - success = project.update_data_row_labeling_priority(data, 1) - lo = list(project.labeling_parameter_overrides()) - assert success - assert len(lo) == 3 - assert {o.priority for o in lo} == {1, 1, 1} - data = [data_row.uid for data_row in data_rows] success = project.update_data_row_labeling_priority(UniqueIds(data), 2) lo = list(project.labeling_parameter_overrides()) diff --git a/libs/labelbox/tests/integration/test_labeling_service.py b/libs/labelbox/tests/integration/test_labeling_service.py index 03a5694a7..e8b3d4cdc 100644 --- a/libs/labelbox/tests/integration/test_labeling_service.py +++ b/libs/labelbox/tests/integration/test_labeling_service.py @@ -1,9 +1,9 @@ import pytest - -from labelbox.exceptions import ( - MalformedQueryException, +from lbox.exceptions import ( + LabelboxError, ResourceNotFoundError, ) + from labelbox.schema.labeling_service import LabelingServiceStatus @@ -53,7 +53,7 @@ def test_request_labeling_service_moe_project( labeling_service = project.get_labeling_service() with pytest.raises( - MalformedQueryException, + LabelboxError, match='[{"errorType":"PROJECT_MODEL_CONFIG","errorMessage":"Project model config is not completed"}]', ): labeling_service.request() @@ -75,5 +75,5 @@ def test_request_labeling_service_incomplete_requirements(ontology, project): ): # No labeling service by default labeling_service.request() project.connect_ontology(ontology) - with pytest.raises(MalformedQueryException): + with pytest.raises(LabelboxError): labeling_service.request() diff --git a/libs/labelbox/tests/integration/test_legacy_project.py b/libs/labelbox/tests/integration/test_legacy_project.py index 320a2191d..3e652f333 100644 --- a/libs/labelbox/tests/integration/test_legacy_project.py +++ 
b/libs/labelbox/tests/integration/test_legacy_project.py @@ -1,40 +1,13 @@ -import pytest - -from labelbox.schema.queue_mode import QueueMode - - -def test_project_dataset(client, rand_gen): - with pytest.raises( - ValueError, - match="Dataset queue mode is deprecated. Please prefer Batch queue mode.", - ): - client.create_project( - name=rand_gen(str), - queue_mode=QueueMode.Dataset, - ) +from os import name +import pytest +from pydantic import ValidationError -def test_project_auto_audit_parameters(client, rand_gen): - with pytest.raises( - ValueError, - match="quality_modes must be set instead of auto_audit_percentage or auto_audit_number_of_labels.", - ): - client.create_project(name=rand_gen(str), auto_audit_percentage=0.5) - - with pytest.raises( - ValueError, - match="quality_modes must be set instead of auto_audit_percentage or auto_audit_number_of_labels.", - ): - client.create_project(name=rand_gen(str), auto_audit_number_of_labels=2) +from labelbox.schema.media_type import MediaType def test_project_name_parameter(client, rand_gen): with pytest.raises( - ValueError, match="project name must be a valid string." - ): - client.create_project() - - with pytest.raises( - ValueError, match="project name must be a valid string." 
+ ValidationError, match="project name must be a valid string" ): - client.create_project(name=" ") + client.create_project(name=" ", media_type=MediaType.Image) diff --git a/libs/labelbox/tests/integration/test_model_config.py b/libs/labelbox/tests/integration/test_model_config.py index 7a060b917..66912e8d9 100644 --- a/libs/labelbox/tests/integration/test_model_config.py +++ b/libs/labelbox/tests/integration/test_model_config.py @@ -1,5 +1,5 @@ import pytest -from labelbox.exceptions import ResourceNotFoundError +from lbox.exceptions import ResourceNotFoundError def test_create_model_config(client, valid_model_id): diff --git a/libs/labelbox/tests/integration/test_ontology.py b/libs/labelbox/tests/integration/test_ontology.py index 91ef74a39..c7c7c270c 100644 --- a/libs/labelbox/tests/integration/test_ontology.py +++ b/libs/labelbox/tests/integration/test_ontology.py @@ -1,11 +1,10 @@ -import pytest - -from labelbox import OntologyBuilder, MediaType, Tool -from labelbox.orm.model import Entity import json import time -from labelbox.schema.queue_mode import QueueMode +import pytest + +from labelbox import MediaType, OntologyBuilder, Tool +from labelbox.orm.model import Entity def test_feature_schema_is_not_archived(client, ontology): @@ -13,7 +12,7 @@ def test_feature_schema_is_not_archived(client, ontology): result = client.is_feature_schema_archived( ontology.uid, feature_schema_to_check["featureSchemaId"] ) - assert result == False + assert result is False def test_feature_schema_is_archived(client, configured_project_with_label): @@ -23,10 +22,10 @@ def test_feature_schema_is_archived(client, configured_project_with_label): result = client.delete_feature_schema_from_ontology( ontology.uid, feature_schema_id ) - assert result.archived == True and result.deleted == False + assert result.archived is True and result.deleted is False assert ( client.is_feature_schema_archived(ontology.uid, feature_schema_id) - == True + is True ) @@ -58,8 +57,8 @@ def 
test_delete_tool_feature_from_ontology(client, ontology): result = client.delete_feature_schema_from_ontology( ontology.uid, feature_schema_to_delete["featureSchemaId"] ) - assert result.deleted == True - assert result.archived == False + assert result.deleted is True + assert result.archived is False updatedOntology = client.get_ontology(ontology.uid) assert len(updatedOntology.normalized["tools"]) == 1 @@ -99,7 +98,6 @@ def test_deletes_an_ontology(client): def test_cant_delete_an_ontology_with_project(client): project = client.create_project( name="test project", - queue_mode=QueueMode.Batch, media_type=MediaType.Image, ) tool = client.upsert_feature_schema(point.asdict()) @@ -187,7 +185,6 @@ def test_does_not_include_used_ontologies(client): ) project = client.create_project( name="test project", - queue_mode=QueueMode.Batch, media_type=MediaType.Image, ) project.connect_ontology(ontology_with_project) @@ -300,7 +297,7 @@ def test_unarchive_feature_schema_node(client, ontology): result = client.unarchive_feature_schema_node( ontology.uid, feature_schema_to_unarchive["featureSchemaId"] ) - assert result == None + assert result is None def test_unarchive_feature_schema_node_for_non_existing_feature_schema( diff --git a/libs/labelbox/tests/integration/test_project.py b/libs/labelbox/tests/integration/test_project.py index a38fa2b5d..ea995c6f6 100644 --- a/libs/labelbox/tests/integration/test_project.py +++ b/libs/labelbox/tests/integration/test_project.py @@ -1,21 +1,22 @@ -import time import os +import time import uuid + +from labelbox.schema.ontology import OntologyBuilder, Tool import pytest import requests +from lbox.exceptions import InvalidQueryError -from labelbox import Project, LabelingFrontend, Dataset -from labelbox.exceptions import InvalidQueryError +from labelbox import Dataset, LabelingFrontend, Project +from labelbox.schema import media_type from labelbox.schema.media_type import MediaType from labelbox.schema.quality_mode import QualityMode -from 
labelbox.schema.queue_mode import QueueMode def test_project(client, rand_gen): data = { "name": rand_gen(str), "description": rand_gen(str), - "queue_mode": QueueMode.Batch.Batch, "media_type": MediaType.Image, } project = client.create_project(**data) @@ -50,7 +51,7 @@ def data_for_project_test(client, rand_gen): def _create_project(name: str = None): if name is None: name = rand_gen(str) - project = client.create_project(name=name) + project = client.create_project(name=name, media_type=MediaType.Image) projects.append(project) return project @@ -139,10 +140,6 @@ def test_extend_reservations(project): project.extend_reservations("InvalidQueueType") -@pytest.mark.skipif( - condition=os.environ["LABELBOX_TEST_ENVIRON"] == "onprem", - reason="new mutation does not work for onprem", -) def test_attach_instructions(client, project): with pytest.raises(ValueError) as execinfo: project.upsert_instructions("tests/integration/media/sample_pdf.pdf") @@ -150,11 +147,17 @@ def test_attach_instructions(client, project): str(execinfo.value) == "Cannot attach instructions to a project that has not been set up." 
) - editor = list( - client.get_labeling_frontends(where=LabelingFrontend.name == "editor") - )[0] - empty_ontology = {"tools": [], "classifications": []} - project.setup(editor, empty_ontology) + ontology_builder = OntologyBuilder( + tools=[ + Tool(tool=Tool.Type.BBOX, name="test-bbox-class"), + ] + ) + ontology = client.create_ontology( + name="ontology with features", + media_type=MediaType.Image, + normalized=ontology_builder.asdict(), + ) + project.connect_ontology(ontology) project.upsert_instructions("tests/integration/media/sample_pdf.pdf") time.sleep(3) @@ -171,24 +174,20 @@ def test_attach_instructions(client, project): condition=os.environ["LABELBOX_TEST_ENVIRON"] == "onprem", reason="new mutation does not work for onprem", ) -def test_html_instructions(project_with_empty_ontology): +def test_html_instructions(project_with_one_feature_ontology): html_file_path = "/tmp/instructions.html" sample_html_str = "" with open(html_file_path, "w") as file: file.write(sample_html_str) - project_with_empty_ontology.upsert_instructions(html_file_path) - updated_ontology = project_with_empty_ontology.ontology().normalized + project_with_one_feature_ontology.upsert_instructions(html_file_path) + updated_ontology = project_with_one_feature_ontology.ontology().normalized instructions = updated_ontology.pop("projectInstructions") assert requests.get(instructions).text == sample_html_str -@pytest.mark.skipif( - condition=os.environ["LABELBOX_TEST_ENVIRON"] == "onprem", - reason="new mutation does not work for onprem", -) def test_same_ontology_after_instructions( configured_project_with_complex_ontology, ): @@ -247,9 +246,11 @@ def test_media_type(client, project: Project, rand_gen): assert isinstance(project.media_type, MediaType) # Update test - project = client.create_project(name=rand_gen(str)) - project.update(media_type=MediaType.Image) - assert project.media_type == MediaType.Image + project = client.create_project( + name=rand_gen(str), media_type=MediaType.Image + 
) + project.update(media_type=MediaType.Text) + assert project.media_type == MediaType.Text project.delete() for media_type in MediaType.get_supported_members(): @@ -270,13 +271,16 @@ def test_media_type(client, project: Project, rand_gen): def test_queue_mode(client, rand_gen): project = client.create_project( - name=rand_gen(str) + name=rand_gen(str), + media_type=MediaType.Image, ) # defaults to benchmark and consensus assert project.auto_audit_number_of_labels == 3 assert project.auto_audit_percentage == 0 project = client.create_project( - name=rand_gen(str), quality_modes=[QualityMode.Benchmark] + name=rand_gen(str), + quality_modes=[QualityMode.Benchmark], + media_type=MediaType.Image, ) assert project.auto_audit_number_of_labels == 1 assert project.auto_audit_percentage == 1 @@ -284,13 +288,16 @@ def test_queue_mode(client, rand_gen): project = client.create_project( name=rand_gen(str), quality_modes=[QualityMode.Benchmark, QualityMode.Consensus], + media_type=MediaType.Image, ) assert project.auto_audit_number_of_labels == 3 assert project.auto_audit_percentage == 0 def test_label_count(client, configured_batch_project_with_label): - project = client.create_project(name="test label count") + project = client.create_project( + name="test label count", media_type=MediaType.Image + ) assert project.get_label_count() == 0 project.delete() @@ -308,7 +315,6 @@ def test_clone(client, project, rand_gen): assert cloned_project.description == project.description assert cloned_project.media_type == project.media_type - assert cloned_project.queue_mode == project.queue_mode assert ( cloned_project.auto_audit_number_of_labels == project.auto_audit_number_of_labels diff --git a/libs/labelbox/tests/integration/test_project_model_config.py b/libs/labelbox/tests/integration/test_project_model_config.py index 975a39afe..f1646dfc0 100644 --- a/libs/labelbox/tests/integration/test_project_model_config.py +++ b/libs/labelbox/tests/integration/test_project_model_config.py @@ 
-1,6 +1,5 @@ import pytest - -from labelbox.exceptions import ResourceNotFoundError +from lbox.exceptions import ResourceNotFoundError def test_add_single_model_config(live_chat_evaluation_project, model_config): diff --git a/libs/labelbox/tests/integration/test_project_set_model_setup_complete.py b/libs/labelbox/tests/integration/test_project_set_model_setup_complete.py index 16a124945..30e179028 100644 --- a/libs/labelbox/tests/integration/test_project_set_model_setup_complete.py +++ b/libs/labelbox/tests/integration/test_project_set_model_setup_complete.py @@ -1,6 +1,5 @@ import pytest - -from labelbox.exceptions import LabelboxError, OperationNotAllowedException +from lbox.exceptions import LabelboxError, OperationNotAllowedException def test_live_chat_evaluation_project( diff --git a/libs/labelbox/tests/integration/test_project_setup.py b/libs/labelbox/tests/integration/test_project_setup.py index faadea228..a6fba03e0 100644 --- a/libs/labelbox/tests/integration/test_project_setup.py +++ b/libs/labelbox/tests/integration/test_project_setup.py @@ -1,11 +1,11 @@ -from datetime import datetime, timedelta, timezone import json import time +from datetime import datetime, timedelta, timezone import pytest +from lbox.exceptions import InvalidQueryError from labelbox import LabelingFrontend -from labelbox.exceptions import InvalidQueryError, ResourceConflict def simple_ontology(): @@ -24,35 +24,6 @@ def simple_ontology(): return {"tools": [], "classifications": classifications} -def test_project_setup(project) -> None: - client = project.client - labeling_frontends = list( - client.get_labeling_frontends(where=LabelingFrontend.name == "Editor") - ) - assert len(labeling_frontends) - labeling_frontend = labeling_frontends[0] - - time.sleep(3) - now = datetime.now().astimezone(timezone.utc) - - project.setup(labeling_frontend, simple_ontology()) - assert now - project.setup_complete <= timedelta(seconds=3) - assert now - project.last_activity_time <= 
timedelta(seconds=3) - - assert project.labeling_frontend() == labeling_frontend - options = list(project.labeling_frontend_options()) - assert len(options) == 1 - options = options[0] - # TODO ensure that LabelingFrontendOptions can be obtaind by ID - with pytest.raises(InvalidQueryError): - assert options.labeling_frontend() == labeling_frontend - assert options.project() == project - assert options.organization() == client.get_organization() - assert options.customization_options == json.dumps(simple_ontology()) - assert project.organization() == client.get_organization() - assert project.created_by() == client.get_user() - - def test_project_editor_setup(client, project, rand_gen): ontology_name = f"test_project_editor_setup_ontology_name-{rand_gen(str)}" ontology = client.create_ontology(ontology_name, simple_ontology()) diff --git a/libs/labelbox/tests/integration/test_prompt_response_generation_project.py b/libs/labelbox/tests/integration/test_prompt_response_generation_project.py index 1373ee470..f5003f061 100644 --- a/libs/labelbox/tests/integration/test_prompt_response_generation_project.py +++ b/libs/labelbox/tests/integration/test_prompt_response_generation_project.py @@ -1,9 +1,8 @@ -import pytest from unittest.mock import patch +import pytest + from labelbox import MediaType -from labelbox.schema.ontology_kind import OntologyKind -from labelbox.exceptions import MalformedQueryException @pytest.mark.parametrize( diff --git a/libs/labelbox/tests/integration/test_task_queue.py b/libs/labelbox/tests/integration/test_task_queue.py index 835f67219..0cd66cb62 100644 --- a/libs/labelbox/tests/integration/test_task_queue.py +++ b/libs/labelbox/tests/integration/test_task_queue.py @@ -68,7 +68,9 @@ def test_move_to_task(configured_batch_project_with_label): review_queue = next( tq for tq in task_queues if tq.queue_type == "MANUAL_REVIEW_QUEUE" ) - project.move_data_rows_to_task_queue([data_row.uid], review_queue.uid) + project.move_data_rows_to_task_queue( + 
UniqueIds([data_row.uid]), review_queue.uid + ) _validate_moved(project, "MANUAL_REVIEW_QUEUE", 1) review_queue = next( diff --git a/libs/labelbox/tests/integration/test_toggle_mal.py b/libs/labelbox/tests/integration/test_toggle_mal.py index 41dfbe395..566c4210c 100644 --- a/libs/labelbox/tests/integration/test_toggle_mal.py +++ b/libs/labelbox/tests/integration/test_toggle_mal.py @@ -1,9 +1,9 @@ def test_enable_model_assisted_labeling(project): response = project.enable_model_assisted_labeling() - assert response == True + assert response is True response = project.enable_model_assisted_labeling(True) - assert response == True + assert response is True response = project.enable_model_assisted_labeling(False) - assert response == False + assert response is False diff --git a/libs/labelbox/tests/unit/export_task/test_unit_file_converter.py b/libs/labelbox/tests/unit/export_task/test_unit_file_converter.py deleted file mode 100644 index 81e9eb60f..000000000 --- a/libs/labelbox/tests/unit/export_task/test_unit_file_converter.py +++ /dev/null @@ -1,77 +0,0 @@ -from unittest.mock import MagicMock - -from labelbox.schema.export_task import ( - Converter, - FileConverter, - Range, - StreamType, - _MetadataFileInfo, - _MetadataHeader, - _TaskContext, -) - - -class TestFileConverter: - def test_with_correct_ndjson(self, tmp_path, generate_random_ndjson): - directory = tmp_path / "file-converter" - directory.mkdir() - line_count = 10 - ndjson = generate_random_ndjson(line_count) - file_content = "\n".join(ndjson) + "\n" - input_args = Converter.ConverterInputArgs( - ctx=_TaskContext( - client=MagicMock(), - task_id="task-id", - stream_type=StreamType.RESULT, - metadata_header=_MetadataHeader( - total_size=len(file_content), total_lines=line_count - ), - ), - file_info=_MetadataFileInfo( - offsets=Range(start=0, end=len(file_content) - 1), - lines=Range(start=0, end=line_count - 1), - file="file.ndjson", - ), - raw_data=file_content, - ) - path = directory / "output.ndjson" 
- with FileConverter(file_path=path) as converter: - for output in converter.convert(input_args): - assert output.current_line == 0 - assert output.current_offset == 0 - assert output.file_path == path - assert output.total_lines == line_count - assert output.total_size == len(file_content) - assert output.bytes_written == len(file_content) - - def test_with_no_newline_at_end(self, tmp_path, generate_random_ndjson): - directory = tmp_path / "file-converter" - directory.mkdir() - line_count = 10 - ndjson = generate_random_ndjson(line_count) - file_content = "\n".join(ndjson) - input_args = Converter.ConverterInputArgs( - ctx=_TaskContext( - client=MagicMock(), - task_id="task-id", - stream_type=StreamType.RESULT, - metadata_header=_MetadataHeader( - total_size=len(file_content), total_lines=line_count - ), - ), - file_info=_MetadataFileInfo( - offsets=Range(start=0, end=len(file_content) - 1), - lines=Range(start=0, end=line_count - 1), - file="file.ndjson", - ), - raw_data=file_content, - ) - path = directory / "output.ndjson" - with FileConverter(file_path=path) as converter: - for output in converter.convert(input_args): - assert output.current_line == 0 - assert output.current_offset == 0 - assert output.file_path == path - assert output.total_lines == line_count - assert output.total_size == len(file_content) - assert output.bytes_written == len(file_content) diff --git a/libs/labelbox/tests/unit/export_task/test_unit_file_retriever_by_line.py b/libs/labelbox/tests/unit/export_task/test_unit_file_retriever_by_line.py deleted file mode 100644 index 37c93647e..000000000 --- a/libs/labelbox/tests/unit/export_task/test_unit_file_retriever_by_line.py +++ /dev/null @@ -1,126 +0,0 @@ -from unittest.mock import MagicMock, patch -from labelbox.schema.export_task import ( - FileRetrieverByLine, - _TaskContext, - _MetadataHeader, - StreamType, -) - - -class TestFileRetrieverByLine: - def test_by_line_from_start(self, generate_random_ndjson, mock_response): - line_count = 
10 - ndjson = generate_random_ndjson(line_count) - file_content = "\n".join(ndjson) + "\n" - - mock_client = MagicMock() - mock_client.execute = MagicMock( - return_value={ - "task": { - "exportFileFromLine": { - "offsets": {"start": "0", "end": len(file_content) - 1}, - "lines": {"start": "0", "end": str(line_count - 1)}, - "file": "http://some-url.com/file.ndjson", - } - } - } - ) - - mock_ctx = _TaskContext( - client=mock_client, - task_id="task-id", - stream_type=StreamType.RESULT, - metadata_header=_MetadataHeader( - total_size=len(file_content), total_lines=line_count - ), - ) - - with patch("requests.get", return_value=mock_response(file_content)): - retriever = FileRetrieverByLine(mock_ctx, 0) - info, content = retriever.get_next_chunk() - assert info.offsets.start == 0 - assert info.offsets.end == len(file_content) - 1 - assert info.lines.start == 0 - assert info.lines.end == line_count - 1 - assert info.file == "http://some-url.com/file.ndjson" - assert content == file_content - - def test_by_line_from_middle(self, generate_random_ndjson, mock_response): - line_count = 10 - ndjson = generate_random_ndjson(line_count) - file_content = "\n".join(ndjson) + "\n" - - mock_client = MagicMock() - mock_client.execute = MagicMock( - return_value={ - "task": { - "exportFileFromLine": { - "offsets": {"start": "0", "end": len(file_content) - 1}, - "lines": {"start": "0", "end": str(line_count - 1)}, - "file": "http://some-url.com/file.ndjson", - } - } - } - ) - - mock_ctx = _TaskContext( - client=mock_client, - task_id="task-id", - stream_type=StreamType.RESULT, - metadata_header=_MetadataHeader( - total_size=len(file_content), total_lines=line_count - ), - ) - - line_start = 5 - current_offset = file_content.find(ndjson[line_start]) - - with patch("requests.get", return_value=mock_response(file_content)): - retriever = FileRetrieverByLine(mock_ctx, line_start) - info, content = retriever.get_next_chunk() - assert info.offsets.start == current_offset - assert 
info.offsets.end == len(file_content) - 1 - assert info.lines.start == line_start - assert info.lines.end == line_count - 1 - assert info.file == "http://some-url.com/file.ndjson" - assert content == file_content[current_offset:] - - def test_by_line_from_last(self, generate_random_ndjson, mock_response): - line_count = 10 - ndjson = generate_random_ndjson(line_count) - file_content = "\n".join(ndjson) + "\n" - - mock_client = MagicMock() - mock_client.execute = MagicMock( - return_value={ - "task": { - "exportFileFromLine": { - "offsets": {"start": "0", "end": len(file_content) - 1}, - "lines": {"start": "0", "end": str(line_count - 1)}, - "file": "http://some-url.com/file.ndjson", - } - } - } - ) - - mock_ctx = _TaskContext( - client=mock_client, - task_id="task-id", - stream_type=StreamType.RESULT, - metadata_header=_MetadataHeader( - total_size=len(file_content), total_lines=line_count - ), - ) - - line_start = 9 - current_offset = file_content.find(ndjson[line_start]) - - with patch("requests.get", return_value=mock_response(file_content)): - retriever = FileRetrieverByLine(mock_ctx, line_start) - info, content = retriever.get_next_chunk() - assert info.offsets.start == current_offset - assert info.offsets.end == len(file_content) - 1 - assert info.lines.start == line_start - assert info.lines.end == line_count - 1 - assert info.file == "http://some-url.com/file.ndjson" - assert content == file_content[current_offset:] diff --git a/libs/labelbox/tests/unit/export_task/test_unit_file_retriever_by_offset.py b/libs/labelbox/tests/unit/export_task/test_unit_file_retriever_by_offset.py deleted file mode 100644 index 870e03307..000000000 --- a/libs/labelbox/tests/unit/export_task/test_unit_file_retriever_by_offset.py +++ /dev/null @@ -1,87 +0,0 @@ -from unittest.mock import MagicMock, patch -from labelbox.schema.export_task import ( - FileRetrieverByOffset, - _TaskContext, - _MetadataHeader, - StreamType, -) - - -class TestFileRetrieverByOffset: - def 
test_by_offset_from_start(self, generate_random_ndjson, mock_response): - line_count = 10 - ndjson = generate_random_ndjson(line_count) - file_content = "\n".join(ndjson) + "\n" - - mock_client = MagicMock() - mock_client.execute = MagicMock( - return_value={ - "task": { - "exportFileFromOffset": { - "offsets": {"start": "0", "end": len(file_content) - 1}, - "lines": {"start": "0", "end": str(line_count - 1)}, - "file": "http://some-url.com/file.ndjson", - } - } - } - ) - - mock_ctx = _TaskContext( - client=mock_client, - task_id="task-id", - stream_type=StreamType.RESULT, - metadata_header=_MetadataHeader( - total_size=len(file_content), total_lines=line_count - ), - ) - - with patch("requests.get", return_value=mock_response(file_content)): - retriever = FileRetrieverByOffset(mock_ctx, 0) - info, content = retriever.get_next_chunk() - assert info.offsets.start == 0 - assert info.offsets.end == len(file_content) - 1 - assert info.lines.start == 0 - assert info.lines.end == line_count - 1 - assert info.file == "http://some-url.com/file.ndjson" - assert content == file_content - - def test_by_offset_from_middle(self, generate_random_ndjson, mock_response): - line_count = 10 - ndjson = generate_random_ndjson(line_count) - file_content = "\n".join(ndjson) + "\n" - - mock_client = MagicMock() - mock_client.execute = MagicMock( - return_value={ - "task": { - "exportFileFromOffset": { - "offsets": {"start": "0", "end": len(file_content) - 1}, - "lines": {"start": "0", "end": str(line_count - 1)}, - "file": "http://some-url.com/file.ndjson", - } - } - } - ) - - mock_ctx = _TaskContext( - client=mock_client, - task_id="task-id", - stream_type=StreamType.RESULT, - metadata_header=_MetadataHeader( - total_size=len(file_content), total_lines=line_count - ), - ) - - line_start = 5 - skipped_bytes = 15 - current_offset = file_content.find(ndjson[line_start]) + skipped_bytes - - with patch("requests.get", return_value=mock_response(file_content)): - retriever = 
FileRetrieverByOffset(mock_ctx, current_offset) - info, content = retriever.get_next_chunk() - assert info.offsets.start == current_offset - assert info.offsets.end == len(file_content) - 1 - assert info.lines.start == 5 - assert info.lines.end == line_count - 1 - assert info.file == "http://some-url.com/file.ndjson" - assert content == file_content[current_offset:] diff --git a/libs/labelbox/tests/unit/export_task/test_unit_json_converter.py b/libs/labelbox/tests/unit/export_task/test_unit_json_converter.py deleted file mode 100644 index f5ccf26fb..000000000 --- a/libs/labelbox/tests/unit/export_task/test_unit_json_converter.py +++ /dev/null @@ -1,112 +0,0 @@ -from unittest.mock import MagicMock - -from labelbox.schema.export_task import ( - Converter, - JsonConverter, - Range, - _MetadataFileInfo, -) - - -class TestJsonConverter: - def test_with_correct_ndjson(self, generate_random_ndjson): - line_count = 10 - ndjson = generate_random_ndjson(line_count) - file_content = "\n".join(ndjson) + "\n" - input_args = Converter.ConverterInputArgs( - ctx=MagicMock(), - file_info=_MetadataFileInfo( - offsets=Range(start=0, end=len(file_content) - 1), - lines=Range(start=0, end=line_count - 1), - file="file.ndjson", - ), - raw_data=file_content, - ) - with JsonConverter() as converter: - current_offset = 0 - for idx, output in enumerate(converter.convert(input_args)): - assert output.current_line == idx - assert output.current_offset == current_offset - assert output.json_str == ndjson[idx] - current_offset += len(output.json_str) + 1 - - def test_with_no_newline_at_end(self, generate_random_ndjson): - line_count = 10 - ndjson = generate_random_ndjson(line_count) - file_content = "\n".join(ndjson) - input_args = Converter.ConverterInputArgs( - ctx=MagicMock(), - file_info=_MetadataFileInfo( - offsets=Range(start=0, end=len(file_content) - 1), - lines=Range(start=0, end=line_count - 1), - file="file.ndjson", - ), - raw_data=file_content, - ) - with JsonConverter() as 
converter: - current_offset = 0 - for idx, output in enumerate(converter.convert(input_args)): - assert output.current_line == idx - assert output.current_offset == current_offset - assert output.json_str == ndjson[idx] - current_offset += len(output.json_str) + 1 - - def test_from_offset(self, generate_random_ndjson): - # testing middle of a JSON string, but not the last line - line_count = 10 - line_start = 5 - ndjson = generate_random_ndjson(line_count) - file_content = "\n".join(ndjson) + "\n" - offset_end = len(file_content) - skipped_bytes = 15 - current_offset = file_content.find(ndjson[line_start]) + skipped_bytes - file_content = file_content[current_offset:] - - input_args = Converter.ConverterInputArgs( - ctx=MagicMock(), - file_info=_MetadataFileInfo( - offsets=Range(start=current_offset, end=offset_end), - lines=Range(start=line_start, end=line_count - 1), - file="file.ndjson", - ), - raw_data=file_content, - ) - with JsonConverter() as converter: - for idx, output in enumerate(converter.convert(input_args)): - assert output.current_line == line_start + idx - assert output.current_offset == current_offset - assert ( - output.json_str == ndjson[line_start + idx][skipped_bytes:] - ) - current_offset += len(output.json_str) + 1 - skipped_bytes = 0 - - def test_from_offset_last_line(self, generate_random_ndjson): - # testing middle of a JSON string, but not the last line - line_count = 10 - line_start = 9 - ndjson = generate_random_ndjson(line_count) - file_content = "\n".join(ndjson) + "\n" - offset_end = len(file_content) - skipped_bytes = 15 - current_offset = file_content.find(ndjson[line_start]) + skipped_bytes - file_content = file_content[current_offset:] - - input_args = Converter.ConverterInputArgs( - ctx=MagicMock(), - file_info=_MetadataFileInfo( - offsets=Range(start=current_offset, end=offset_end), - lines=Range(start=line_start, end=line_count - 1), - file="file.ndjson", - ), - raw_data=file_content, - ) - with JsonConverter() as converter: - 
for idx, output in enumerate(converter.convert(input_args)): - assert output.current_line == line_start + idx - assert output.current_offset == current_offset - assert ( - output.json_str == ndjson[line_start + idx][skipped_bytes:] - ) - current_offset += len(output.json_str) + 1 - skipped_bytes = 0 diff --git a/libs/labelbox/tests/unit/schema/test_user_group.py b/libs/labelbox/tests/unit/schema/test_user_group.py index 65584f8ef..6bc29048d 100644 --- a/libs/labelbox/tests/unit/schema/test_user_group.py +++ b/libs/labelbox/tests/unit/schema/test_user_group.py @@ -1,20 +1,21 @@ -import pytest from collections import defaultdict from unittest.mock import MagicMock -from labelbox import Client -from labelbox.exceptions import ( + +import pytest +from lbox.exceptions import ( + MalformedQueryException, ResourceConflict, ResourceCreationError, ResourceNotFoundError, - MalformedQueryException, UnprocessableEntityError, ) + +from labelbox import Client +from labelbox.schema.media_type import MediaType +from labelbox.schema.ontology_kind import EditorTaskType from labelbox.schema.project import Project from labelbox.schema.user import User from labelbox.schema.user_group import UserGroup, UserGroupColor -from labelbox.schema.queue_mode import QueueMode -from labelbox.schema.ontology_kind import EditorTaskType -from labelbox.schema.media_type import MediaType @pytest.fixture @@ -30,7 +31,6 @@ def group_project(): project_values = defaultdict(lambda: None) project_values["id"] = "project_id" project_values["name"] = "Test Project" - project_values["queueMode"] = QueueMode.Batch.value project_values["editorTaskType"] = EditorTaskType.Missing.value project_values["mediaType"] = MediaType.Image.value return Project(MagicMock(Client), project_values) @@ -55,12 +55,6 @@ def setup_method(self): self.client.enable_experimental = True self.group = UserGroup(client=self.client) - def test_constructor_experimental_needed(self): - client = MagicMock(Client) - client.enable_experimental 
= False - with pytest.raises(RuntimeError): - group = UserGroup(client) - def test_constructor(self): group = UserGroup(self.client) diff --git a/libs/labelbox/tests/unit/test_exceptions.py b/libs/labelbox/tests/unit/test_exceptions.py index 4602fb984..074a735f2 100644 --- a/libs/labelbox/tests/unit/test_exceptions.py +++ b/libs/labelbox/tests/unit/test_exceptions.py @@ -1,6 +1,5 @@ import pytest - -from labelbox.exceptions import error_message_for_unparsed_graphql_error +from lbox.exceptions import error_message_for_unparsed_graphql_error @pytest.mark.parametrize( diff --git a/libs/labelbox/tests/unit/test_label_data_type.py b/libs/labelbox/tests/unit/test_label_data_type.py index 7bc32e37c..611324f78 100644 --- a/libs/labelbox/tests/unit/test_label_data_type.py +++ b/libs/labelbox/tests/unit/test_label_data_type.py @@ -1,11 +1,7 @@ -from email import message import pytest -from pydantic import ValidationError - from labelbox.data.annotation_types.data.generic_data_row_data import ( GenericDataRowData, ) -from labelbox.data.annotation_types.data.video import VideoData from labelbox.data.annotation_types.label import Label @@ -37,20 +33,6 @@ def test_generic_data_type_validations(): Label(data=data) -def test_video_data_type(): - data = { - "global_key": "https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-BEidMVWRmyXjVCnr", - } - with pytest.warns(UserWarning, match="Use a dict"): - label = Label(data=VideoData(**data)) - data = label.data - assert isinstance(data, VideoData) - assert ( - data.global_key - == "https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-BEidMVWRmyXjVCnr" - ) - - def test_generic_data_row(): data = { "global_key": "https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-BEidMVWRmyXjVCnr", diff --git a/libs/labelbox/tests/unit/test_project.py b/libs/labelbox/tests/unit/test_project.py index 5e5f99c57..1bc6fa840 100644 --- 
a/libs/labelbox/tests/unit/test_project.py +++ b/libs/labelbox/tests/unit/test_project.py @@ -21,7 +21,6 @@ def project_entity(): "editorTaskType": "MODEL_CHAT_EVALUATION", "lastActivityTime": "2021-06-01T00:00:00.000Z", "allowedMediaType": "IMAGE", - "queueMode": "BATCH", "setupComplete": "2021-06-01T00:00:00.000Z", "modelSetupComplete": None, "uploadType": "Auto", @@ -62,7 +61,6 @@ def test_project_editor_task_type( "editorTaskType": api_editor_task_type, "lastActivityTime": "2021-06-01T00:00:00.000Z", "allowedMediaType": "IMAGE", - "queueMode": "BATCH", "setupComplete": "2021-06-01T00:00:00.000Z", "modelSetupComplete": None, "uploadType": "Auto", @@ -72,13 +70,3 @@ def test_project_editor_task_type( ) assert project.editor_task_type == expected_editor_task_type - - -def test_setup_editor_using_connect_ontology(project_entity): - project = project_entity - ontology = MagicMock() - project.connect_ontology = MagicMock() - with patch("warnings.warn") as warn: - project.setup_editor(ontology) - warn.assert_called_once() - project.connect_ontology.assert_called_once_with(ontology) diff --git a/libs/labelbox/tests/unit/test_queue_mode.py b/libs/labelbox/tests/unit/test_queue_mode.py deleted file mode 100644 index a07b14a54..000000000 --- a/libs/labelbox/tests/unit/test_queue_mode.py +++ /dev/null @@ -1,20 +0,0 @@ -import pytest - -from labelbox.schema.queue_mode import QueueMode - - -def test_parse_deprecated_catalog(): - assert QueueMode("CATALOG") == QueueMode.Batch - - -def test_parse_batch(): - assert QueueMode("BATCH") == QueueMode.Batch - - -def test_parse_data_set(): - assert QueueMode("DATA_SET") == QueueMode.Dataset - - -def test_fails_for_unknown(): - with pytest.raises(ValueError): - QueueMode("foo") diff --git a/libs/labelbox/tests/unit/test_unit_ontology.py b/libs/labelbox/tests/unit/test_unit_ontology.py index 0566ad623..61c9f523a 100644 --- a/libs/labelbox/tests/unit/test_unit_ontology.py +++ b/libs/labelbox/tests/unit/test_unit_ontology.py @@ -1,8 +1,9 
@@ +from itertools import product + import pytest +from lbox.exceptions import InconsistentOntologyException -from labelbox.exceptions import InconsistentOntologyException -from labelbox import Tool, Classification, Option, OntologyBuilder -from itertools import product +from labelbox import Classification, OntologyBuilder, Option, Tool _SAMPLE_ONTOLOGY = { "tools": [ @@ -197,7 +198,7 @@ def test_add_ontology_tool() -> None: assert len(o.tools) == 2 for tool in o.tools: - assert type(tool) == Tool + assert type(tool) is Tool with pytest.raises(InconsistentOntologyException) as exc: o.add_tool(Tool(tool=Tool.Type.BBOX, name="bounding box")) @@ -217,7 +218,7 @@ def test_add_ontology_classification() -> None: assert len(o.classifications) == 2 for classification in o.classifications: - assert type(classification) == Classification + assert type(classification) is Classification with pytest.raises(InconsistentOntologyException) as exc: o.add_classification( diff --git a/libs/labelbox/tests/unit/test_unit_project_validate_labeling_parameter_overrides.py b/libs/labelbox/tests/unit/test_unit_project_validate_labeling_parameter_overrides.py index 7f6d29d5a..8e6c38559 100644 --- a/libs/labelbox/tests/unit/test_unit_project_validate_labeling_parameter_overrides.py +++ b/libs/labelbox/tests/unit/test_unit_project_validate_labeling_parameter_overrides.py @@ -6,19 +6,6 @@ from labelbox.schema.project import validate_labeling_parameter_overrides -def test_validate_labeling_parameter_overrides_valid_data(): - mock_data_row = MagicMock(spec=DataRow) - mock_data_row.uid = "abc" - data = [(mock_data_row, 1), (UniqueId("efg"), 2), (GlobalKey("hij"), 3)] - validate_labeling_parameter_overrides(data) - - -def test_validate_labeling_parameter_overrides_invalid_data(): - data = [("abc", 1), (UniqueId("efg"), 2), (GlobalKey("hij"), 3)] - with pytest.raises(TypeError): - validate_labeling_parameter_overrides(data) - - def test_validate_labeling_parameter_overrides_invalid_priority(): 
mock_data_row = MagicMock(spec=DataRow) mock_data_row.uid = "abc" diff --git a/libs/lbox-clients/.gitignore b/libs/lbox-clients/.gitignore new file mode 100644 index 000000000..ae8554dec --- /dev/null +++ b/libs/lbox-clients/.gitignore @@ -0,0 +1,10 @@ +# python generated files +__pycache__/ +*.py[oc] +build/ +dist/ +wheels/ +*.egg-info + +# venv +.venv diff --git a/libs/lbox-clients/.python-version b/libs/lbox-clients/.python-version new file mode 100644 index 000000000..43077b246 --- /dev/null +++ b/libs/lbox-clients/.python-version @@ -0,0 +1 @@ +3.9.18 diff --git a/libs/lbox-clients/Dockerfile b/libs/lbox-clients/Dockerfile new file mode 100644 index 000000000..dd3b5147d --- /dev/null +++ b/libs/lbox-clients/Dockerfile @@ -0,0 +1,44 @@ +# https://github.com/ucyo/python-package-template/blob/master/Dockerfile +FROM python:3.9-slim as rye + +ENV LANG="C.UTF-8" \ + LC_ALL="C.UTF-8" \ + PATH="/home/python/.local/bin:/home/python/.rye/shims:$PATH" \ + PIP_NO_CACHE_DIR="false" \ + RYE_VERSION="0.34.0" \ + RYE_INSTALL_OPTION="--yes" \ + LABELBOX_TEST_ENVIRON="prod" + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + ca-certificates \ + curl \ + inotify-tools \ + make \ + # cv2 + libsm6 \ + libxext6 \ + ffmpeg \ + libfontconfig1 \ + libxrender1 \ + libgl1-mesa-glx \ + libgeos-dev \ + gcc \ + && rm -rf /var/lib/apt/lists/* + +RUN groupadd --gid 1000 python && \ + useradd --uid 1000 --gid python --shell /bin/bash --create-home python + +USER 1000 +WORKDIR /home/python/ + +RUN curl -sSf https://rye.astral.sh/get | bash - + +COPY --chown=python:python . 
/home/python/labelbox-python/ +WORKDIR /home/python/labelbox-python + +RUN rye config --set-bool behavior.global-python=true && \ + rye config --set-bool behavior.use-uv=true && \ + rye pin 3.9 && \ + rye sync + +CMD rye run unit && rye integration \ No newline at end of file diff --git a/libs/lbox-clients/README.md b/libs/lbox-clients/README.md new file mode 100644 index 000000000..a7bf3e94f --- /dev/null +++ b/libs/lbox-clients/README.md @@ -0,0 +1,9 @@ +# lbox-clients + +This is an example module which can be cloned and reused to develop modules under the `lbox` namespace. + +## Module Status: Experimental + +**TLDR: This module may be removed or altered at any given time and there is no official support.** + +Please see [here](https://docs.labelbox.com/docs/product-release-phases) for the formal definition of `Experimental`. \ No newline at end of file diff --git a/libs/lbox-clients/pyproject.toml b/libs/lbox-clients/pyproject.toml new file mode 100644 index 000000000..744cf7ee5 --- /dev/null +++ b/libs/lbox-clients/pyproject.toml @@ -0,0 +1,61 @@ +[project] +name = "lbox-clients" +version = "1.1.0" +description = "This module contains the client SDK used to connect to the Labelbox API and backends" +authors = [ + { name = "Labelbox", email = "engineering@labelbox.com" } +] +dependencies = [ + "requests>=2.22.0", + "google-api-core>=1.22.1", +] +readme = "README.md" +requires-python = ">= 3.9" + +classifiers=[ + # How mature is this project? + "Development Status :: 5 - Production/Stable", + # Indicate who your project is intended for + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development :: Libraries", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "Intended Audience :: Education", + # Pick your license as you wish + "License :: OSI Approved :: Apache Software License", + # Specify the Python versions you support here. 
+ "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", +] +keywords = ["ml", "ai", "labelbox", "labeling", "llm", "machinelearning", "edu"] + +[project.urls] +Homepage = "https://labelbox.com/" +Documentation = "https://labelbox-python.readthedocs.io/en/latest/" +Repository = "https://github.com/Labelbox/labelbox-python" +Issues = "https://github.com/Labelbox/labelbox-python/issues" +Changelog = "https://github.com/Labelbox/labelbox-python/blob/develop/libs/labelbox/CHANGELOG.md" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.rye] +managed = true +dev-dependencies = [] + +[tool.rye.scripts] +unit = "pytest tests/unit" +integration = "python -c \"import sys; sys.exit(0)\"" + +[tool.hatch.metadata] +allow-direct-references = true + +[tool.hatch.build.targets.wheel] +packages = ["src/lbox"] + +[tool.pytest.ini_options] +addopts = "-rP -vvv --durations=20 --cov=lbox.example --import-mode=importlib" \ No newline at end of file diff --git a/libs/labelbox/src/labelbox/exceptions.py b/libs/lbox-clients/src/lbox/exceptions.py similarity index 96% rename from libs/labelbox/src/labelbox/exceptions.py rename to libs/lbox-clients/src/lbox/exceptions.py index 34cfeaf4d..493e9cb09 100644 --- a/libs/labelbox/src/labelbox/exceptions.py +++ b/libs/lbox-clients/src/lbox/exceptions.py @@ -16,7 +16,10 @@ def __init__(self, message, cause=None): self.cause = cause def __str__(self): - return self.message + str(self.args) + exception_message = self.message + if self.cause is not None: + exception_message += " (caused by: %s)" % self.cause + return exception_message class AuthenticationError(LabelboxError): diff --git a/libs/lbox-clients/src/lbox/request_client.py b/libs/lbox-clients/src/lbox/request_client.py new file mode 100644 index 000000000..c855413bd --- /dev/null +++ 
b/libs/lbox-clients/src/lbox/request_client.py @@ -0,0 +1,426 @@ +import inspect +import json +import logging +import os +import re +import sys +from datetime import datetime, timezone +from types import MappingProxyType +from typing import Callable, Dict, Optional, TypedDict + +import requests +import requests.exceptions +from google.api_core import retry +from lbox import exceptions # type: ignore + +logger = logging.getLogger(__name__) + +_LABELBOX_API_KEY = "LABELBOX_API_KEY" + + +def python_version_info(): + version_info = sys.version_info + + return f"{version_info.major}.{version_info.minor}.{version_info.micro}-{version_info.releaselevel}" + + +LABELBOX_CALL_PATTERN = re.compile(r"/labelbox/") +TEST_FILE_PATTERN = re.compile(r".*test.*\.py$") + + +class _RequestInfo(TypedDict): + prefix: str + class_name: str + method_name: str + + +def call_info(): + method_name = "Unknown" + prefix = "" + class_name = "" + skip_methods = ["wrapper", "__init__"] + skip_classes = ["PaginatedCollection", "_CursorPagination", "_OffsetPagination"] + + try: + call_info = None + for stack in reversed(inspect.stack()): + if LABELBOX_CALL_PATTERN.search(stack.filename): + call_info = stack + method_name = call_info.function + class_name = call_info.frame.f_locals.get( + "self", None + ).__class__.__name__ + + if method_name not in skip_methods and class_name not in skip_classes: + if TEST_FILE_PATTERN.search(call_info.filename): + prefix = "test:" + else: + if class_name == "NoneType": + class_name = "" + break + + except Exception: + pass + return _RequestInfo(prefix=prefix, class_name=class_name, method_name=method_name) + + +def call_info_as_str(): + info = call_info() + return f"{info['prefix']}{info['class_name']}:{info['method_name']}" + + +class RequestClient: + """A Labelbox request client. + + Contains info necessary for connecting to a Labelbox server (URL, + authentication key). 
+ """ + + def __init__( + self, + sdk_version, + api_key=None, + endpoint="https://api.labelbox.com/graphql", + enable_experimental=False, + app_url="https://app.labelbox.com", + rest_endpoint="https://api.labelbox.com/api/v1", + ): + """Creates and initializes a RequestClient. + This class executes graphql and rest requests to the Labelbox server. + + Args: + api_key (str): API key. If None, the key is obtained from the "LABELBOX_API_KEY" environment variable. + endpoint (str): URL of the Labelbox server to connect to. + enable_experimental (bool): Indicates whether or not to use experimental features + app_url (str) : host url for all links to the web app + Raises: + exceptions.AuthenticationError: If no `api_key` + is provided as an argument or via the environment + variable. + """ + if api_key is None: + if _LABELBOX_API_KEY not in os.environ: + raise exceptions.AuthenticationError("Labelbox API key not provided") + api_key = os.environ[_LABELBOX_API_KEY] + self.api_key = api_key + + self.enable_experimental = enable_experimental + if enable_experimental: + logger.info("Experimental features have been enabled") + + logger.info("Initializing Labelbox client at '%s'", endpoint) + self.app_url = app_url + self.endpoint = endpoint + self.rest_endpoint = rest_endpoint + self.sdk_version = sdk_version + self._connection: requests.Session = self._init_connection() + + def _init_connection(self) -> requests.Session: + connection = requests.Session() # using default connection pool size of 10 + connection.headers.update(self._default_headers()) + + return connection + + @property + def headers(self) -> MappingProxyType: + return self._connection.headers + + def _default_headers(self): + return { + "Authorization": "Bearer %s" % self.api_key, + "Accept": "application/json", + "Content-Type": "application/json", + "X-User-Agent": f"python-sdk {self.sdk_version}", + "X-Python-Version": f"{python_version_info()}", + } + + @retry.Retry( + predicate=retry.if_exception_type( + 
exceptions.InternalServerError, + exceptions.TimeoutError, + ) + ) + def execute( + self, + query=None, + params=None, + data=None, + files=None, + timeout=60.0, + experimental=False, + error_log_key="message", + raise_return_resource_not_found=False, + error_handlers: Optional[ + Dict[str, Callable[[requests.models.Response], None]] + ] = None, + ): + """Sends a request to the server for the execution of the + given query. + + Checks the response for errors and wraps errors + in appropriate `exceptions.LabelboxError` subtypes. + + Args: + query (str): The query to execute. + params (dict): Query parameters referenced within the query. + data (str): json string containing the query to execute + files (dict): file arguments for request + timeout (float): Max allowed time for query execution, + in seconds. + raise_return_resource_not_found: By default the client relies on the caller to raise the correct exception when a resource is not found. + If this is set to True, the client will raise a ResourceNotFoundError exception automatically. + This simplifies processing. + We recommend using it only if the API returns a clear and well-formed error when a resource is not found for a given query. + error_handlers (dict): A dictionary mapping graphql error code to handler functions. + Allows a caller to handle specific errors reporting in a custom way or produce more user-friendly readable messages. + + Example - custom error handler: + >>> def _raise_readable_errors(self, response): + >>> errors = response.json().get('errors', []) + >>> if errors: + >>> message = errors[0].get( + >>> 'message', json.dumps([{ + >>> "errorMessage": "Unknown error" + >>> }])) + >>> errors = json.loads(message) + >>> error_messages = [error['errorMessage'] for error in errors] + >>> else: + >>> error_messages = ["Unknown error"] + >>> raise LabelboxError(". ".join(error_messages)) + + Returns: + dict, parsed JSON response. + Raises: + exceptions.AuthenticationError: If authentication + failed. 
+ exceptions.InvalidQueryError: If `query` is not + syntactically or semantically valid (checked server-side). + exceptions.ApiLimitError: If the server API limit was + exceeded. See "How to import data" in the online documentation + to see API limits. + exceptions.TimeoutError: If response was not received + in `timeout` seconds. + exceptions.NetworkError: If an unknown error occurred + most likely due to connection issues. + exceptions.LabelboxError: If an unknown error of any + kind occurred. + ValueError: If query and data are both None. + """ + logger.debug("Query: %s, params: %r, data %r", query, params, data) + + # Convert datetimes to UTC strings. + def convert_value(value): + if isinstance(value, datetime): + value = value.astimezone(timezone.utc) + value = value.strftime("%Y-%m-%dT%H:%M:%SZ") + return value + + if query is not None: + if params is not None: + params = {key: convert_value(value) for key, value in params.items()} + data = json.dumps({"query": query, "variables": params}).encode("utf-8") + elif data is None: + raise ValueError("query and data cannot both be none") + + endpoint = ( + self.endpoint + if not experimental + else self.endpoint.replace("/graphql", "/_gql") + ) + + try: + headers = self._connection.headers.copy() + if files: + del headers["Content-Type"] + del headers["Accept"] + headers["X-SDK-Method"] = call_info_as_str() + + request = requests.Request( + "POST", + endpoint, + headers=headers, + data=data, + files=files if files else None, + ) + + prepped: requests.PreparedRequest = request.prepare() + + settings = self._connection.merge_environment_settings( + prepped.url, {}, None, None, None + ) + response = self._connection.send(prepped, timeout=timeout, **settings) + logger.debug("Response: %s", response.text) + except requests.exceptions.Timeout as e: + raise exceptions.TimeoutError(str(e)) + except requests.exceptions.RequestException as e: + logger.error("Unknown error: %s", str(e)) + raise exceptions.NetworkError(e) + 
except Exception as e: + raise exceptions.LabelboxError( + "Unknown error during Client.query(): " + str(e), e + ) + + if ( + 200 <= response.status_code < 300 + or response.status_code < 500 + or response.status_code >= 600 + ): + try: + r_json = response.json() + except Exception: + raise exceptions.LabelboxError( + "Failed to parse response as JSON: %s" % response.text + ) + else: + if ( + "upstream connect error or disconnect/reset before headers" + in response.text + ): + raise exceptions.InternalServerError("Connection reset") + elif response.status_code == 502: + error_502 = "502 Bad Gateway" + raise exceptions.InternalServerError(error_502) + elif 500 <= response.status_code < 600: + error_500 = f"Internal server http error {response.status_code}" + raise exceptions.InternalServerError(error_500) + + errors = r_json.get("errors", []) + + def check_errors(keywords, *path): + """Helper that looks for any of the given `keywords` in any of + current errors on paths (like error[path][component][to][keyword]). 
+ """ + for error in errors: + obj = error + for path_elem in path: + obj = obj.get(path_elem, {}) + if obj in keywords: + return error + return None + + def get_error_status_code(error: dict) -> int: + try: + return int(error["extensions"].get("exception").get("status")) + except Exception: + return 500 + + if check_errors(["AUTHENTICATION_ERROR"], "extensions", "code") is not None: + raise exceptions.AuthenticationError("Invalid API key") + + authorization_error = check_errors( + ["AUTHORIZATION_ERROR"], "extensions", "code" + ) + if authorization_error is not None: + raise exceptions.AuthorizationError(authorization_error["message"]) + + validation_error = check_errors( + ["GRAPHQL_VALIDATION_FAILED"], "extensions", "code" + ) + + if validation_error is not None: + message = validation_error["message"] + if message == "Query complexity limit exceeded": + raise exceptions.ValidationFailedError(message) + else: + raise exceptions.InvalidQueryError(message) + + graphql_error = check_errors(["GRAPHQL_PARSE_FAILED"], "extensions", "code") + if graphql_error is not None: + raise exceptions.InvalidQueryError(graphql_error["message"]) + + # Check if API limit was exceeded + response_msg = r_json.get("message", "") + + if response_msg.startswith("You have exceeded"): + raise exceptions.ApiLimitError(response_msg) + + resource_not_found_error = check_errors( + ["RESOURCE_NOT_FOUND"], "extensions", "code" + ) + if resource_not_found_error is not None: + if raise_return_resource_not_found: + raise exceptions.ResourceNotFoundError( + message=resource_not_found_error["message"] + ) + else: + # Return None and let the caller methods raise an exception + # as they already know which resource type and ID was requested + return None + + resource_conflict_error = check_errors( + ["RESOURCE_CONFLICT"], "extensions", "code" + ) + if resource_conflict_error is not None: + raise exceptions.ResourceConflict(resource_conflict_error["message"]) + + malformed_request_error = check_errors( 
+ ["MALFORMED_REQUEST"], "extensions", "code" + ) + + error_code = "MALFORMED_REQUEST" + if malformed_request_error is not None: + if error_handlers and error_code in error_handlers: + handler = error_handlers[error_code] + handler(response) + return None + raise exceptions.MalformedQueryException( + malformed_request_error[error_log_key] + ) + + # A lot of different error situations are now labeled serverside + # as INTERNAL_SERVER_ERROR, when they are actually client errors. + # TODO: fix this in the server API + internal_server_error = check_errors( + ["INTERNAL_SERVER_ERROR"], "extensions", "code" + ) + error_code = "INTERNAL_SERVER_ERROR" + + if internal_server_error is not None: + if error_handlers and error_code in error_handlers: + handler = error_handlers[error_code] + handler(response) + return None + message = internal_server_error.get("message") + error_status_code = get_error_status_code(internal_server_error) + if error_status_code == 400: + raise exceptions.InvalidQueryError(message) + elif error_status_code == 422: + raise exceptions.UnprocessableEntityError(message) + elif error_status_code == 426: + raise exceptions.OperationNotAllowedException(message) + elif error_status_code == 500: + raise exceptions.LabelboxError(message) + else: + raise exceptions.InternalServerError(message) + + not_allowed_error = check_errors( + ["OPERATION_NOT_ALLOWED"], "extensions", "code" + ) + if not_allowed_error is not None: + message = not_allowed_error.get("message") + raise exceptions.OperationNotAllowedException(message) + + if len(errors) > 0: + logger.warning("Unparsed errors on query execution: %r", errors) + messages = list( + map( + lambda x: { + "message": x["message"], + "code": x["extensions"]["code"], + }, + errors, + ) + ) + raise exceptions.LabelboxError("Unknown error: %s" % str(messages)) + + # if we do return a proper error code, and didn't catch this above + # reraise + # this mainly catches a 401 for API access disabled for free tier + # TODO: 
need to unify API errors to handle things more uniformly + # in the SDK + if response.status_code != requests.codes.ok: + message = f"{response.status_code} {response.reason}" + cause = r_json.get("message") + raise exceptions.LabelboxError(message, cause) + + return r_json["data"] diff --git a/libs/lbox-clients/tests/unit/lbox/test_client.py b/libs/lbox-clients/tests/unit/lbox/test_client.py new file mode 100644 index 000000000..42b141d33 --- /dev/null +++ b/libs/lbox-clients/tests/unit/lbox/test_client.py @@ -0,0 +1,46 @@ +from unittest.mock import MagicMock + +from lbox.request_client import RequestClient + + +# @patch.dict(os.environ, {'LABELBOX_API_KEY': 'bar'}) +def test_headers(): + client = RequestClient( + sdk_version="foo", api_key="api_key", endpoint="http://localhost:8080/_gql" + ) + assert client.headers + assert client.headers["Authorization"] == "Bearer api_key" + assert client.headers["Content-Type"] == "application/json" + assert client.headers["User-Agent"] + assert client.headers["X-Python-Version"] + + +def test_custom_error_handling(): + mock_raise_error = MagicMock() + + response_dict = { + "errors": [ + { + "message": "Internal server error", + "extensions": {"code": "INTERNAL_SERVER_ERROR"}, + } + ], + } + response = MagicMock() + response.json.return_value = response_dict + response.status_code = 200 + + client = RequestClient( + sdk_version="foo", api_key="api_key", endpoint="http://localhost:8080/_gql" + ) + connection_mock = MagicMock() + connection_mock.send.return_value = response + client._connection = connection_mock + + client.execute( + "query_str", + {"projectId": "project_id"}, + raise_return_resource_not_found=True, + error_handlers={"INTERNAL_SERVER_ERROR": mock_raise_error}, + ) + mock_raise_error.assert_called_once_with(response) diff --git a/libs/lbox-example/Dockerfile b/libs/lbox-example/Dockerfile index 2ee61ab7e..dd3b5147d 100644 --- a/libs/lbox-example/Dockerfile +++ b/libs/lbox-example/Dockerfile @@ -1,5 +1,5 @@ # 
https://github.com/ucyo/python-package-template/blob/master/Dockerfile -FROM python:3.8-slim as rye +FROM python:3.9-slim as rye ENV LANG="C.UTF-8" \ LC_ALL="C.UTF-8" \ @@ -38,7 +38,7 @@ WORKDIR /home/python/labelbox-python RUN rye config --set-bool behavior.global-python=true && \ rye config --set-bool behavior.use-uv=true && \ - rye pin 3.8 && \ + rye pin 3.9 && \ rye sync CMD rye run unit && rye integration \ No newline at end of file diff --git a/libs/lbox-example/pyproject.toml b/libs/lbox-example/pyproject.toml index a14f5a34f..a3fe0a13e 100644 --- a/libs/lbox-example/pyproject.toml +++ b/libs/lbox-example/pyproject.toml @@ -9,7 +9,7 @@ dependencies = [ "art>=6.2", ] readme = "README.md" -requires-python = ">= 3.8" +requires-python = ">= 3.9" classifiers=[ # How mature is this project? @@ -24,7 +24,6 @@ classifiers=[ "License :: OSI Approved :: Apache Software License", # Specify the Python versions you support here. "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", diff --git a/pyproject.toml b/pyproject.toml index ebce059f5..7ab0a0b79 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ dependencies = [ "sphinx-rtd-theme>=2.0.0", ] readme = "README.md" -requires-python = ">= 3.8" +requires-python = ">= 3.9" [tool.rye] managed = true @@ -28,6 +28,7 @@ dev-dependencies = [ "pytest-timestamper>=0.0.10", "pytest-timeout>=2.3.1", "pytest-order>=1.2.1", + "pyjwt>=2.9.0", ] [tool.rye.workspace] @@ -35,7 +36,7 @@ members = ["libs/*", "examples"] [tool.pytest.ini_options] # https://github.com/pytest-dev/pytest-rerunfailures/issues/99 -addopts = "-rP -vvv --reruns 1 --reruns-delay 5 --durations=20 -n auto --cov=labelbox --import-mode=importlib --order-group-scope=module" +addopts = "-rP -vvv" markers = """ slow: marks tests as slow (deselect with '-m "not slow"') """ diff --git a/requirements-dev.lock 
b/requirements-dev.lock index 05ca1683a..7616ff075 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -6,10 +6,10 @@ # features: [] # all-features: true # with-sources: false -# generate-hashes: false -# universal: false -e file:libs/labelbox +-e file:libs/lbox-clients + # via labelbox -e file:libs/lbox-example alabaster==0.7.13 # via sphinx @@ -74,6 +74,7 @@ gitpython==3.1.43 # via databooks google-api-core==2.19.1 # via labelbox + # via lbox-clients google-auth==2.31.0 # via google-api-core googleapis-common-protos==1.63.2 @@ -89,9 +90,6 @@ importlib-metadata==8.0.0 # via sphinx # via typeguard # via yapf -importlib-resources==6.4.0 - # via jsonschema - # via jsonschema-specifications iniconfig==2.0.0 # via pytest ipython==8.12.3 @@ -125,6 +123,7 @@ matplotlib-inline==0.1.7 mistune==3.0.2 # via nbconvert mypy==1.10.1 + # via labelbox mypy-extensions==1.0.0 # via black # via mypy @@ -161,8 +160,6 @@ pickleshare==0.7.5 # via ipython pillow==10.4.0 # via labelbox -pkgutil-resolve-name==1.3.10 - # via jsonschema platformdirs==4.2.2 # via black # via jupyter-core @@ -198,6 +195,7 @@ pygments==2.18.0 # via nbconvert # via rich # via sphinx +pyjwt==2.9.0 pyproj==3.5.0 # via labelbox pytest==8.2.2 @@ -221,7 +219,6 @@ python-dateutil==2.9.0.post0 # via labelbox # via pandas pytz==2024.1 - # via babel # via pandas pyzmq==26.0.3 # via jupyter-client @@ -233,6 +230,7 @@ regex==2024.5.15 requests==2.32.3 # via google-api-core # via labelbox + # via lbox-clients # via sphinx rich==12.6.0 # via databooks @@ -315,7 +313,6 @@ types-python-dateutil==2.9.0.20240316 types-requests==2.32.0.20240622 types-tqdm==4.66.0.20240417 typing-extensions==4.12.2 - # via annotated-types # via black # via databooks # via ipython @@ -323,7 +320,6 @@ typing-extensions==4.12.2 # via mypy # via pydantic # via pydantic-core - # via rich # via typeguard # via typer tzdata==2024.1 @@ -339,4 +335,3 @@ webencodings==0.5.1 yapf==0.40.2 zipp==3.19.2 # via importlib-metadata - # via 
importlib-resources diff --git a/requirements.lock b/requirements.lock index 07e026d59..36871f22b 100644 --- a/requirements.lock +++ b/requirements.lock @@ -6,10 +6,10 @@ # features: [] # all-features: true # with-sources: false -# generate-hashes: false -# universal: false -e file:libs/labelbox +-e file:libs/lbox-clients + # via labelbox -e file:libs/lbox-example alabaster==0.7.13 # via sphinx @@ -33,6 +33,7 @@ geojson==3.1.0 # via labelbox google-api-core==2.19.1 # via labelbox + # via lbox-clients google-auth==2.31.0 # via google-api-core googleapis-common-protos==1.63.2 @@ -49,6 +50,10 @@ jinja2==3.1.4 # via sphinx markupsafe==2.1.5 # via jinja2 +mypy==1.10.1 + # via labelbox +mypy-extensions==1.0.0 + # via mypy numpy==1.24.4 # via labelbox # via opencv-python-headless @@ -82,11 +87,10 @@ pyproj==3.5.0 # via labelbox python-dateutil==2.9.0.post0 # via labelbox -pytz==2024.1 - # via babel requests==2.32.3 # via google-api-core # via labelbox + # via lbox-clients # via sphinx rsa==4.9 # via google-auth @@ -117,13 +121,15 @@ sphinxcontrib-serializinghtml==1.1.5 # via sphinx strenum==0.4.15 # via labelbox +tomli==2.0.1 + # via mypy tqdm==4.66.4 # via labelbox typeguard==4.3.0 # via labelbox typing-extensions==4.12.2 - # via annotated-types # via labelbox + # via mypy # via pydantic # via pydantic-core # via typeguard