diff --git a/.github/actions/lbox-matrix/index.js b/.github/actions/lbox-matrix/index.js
index ac3975488..574a56381 100644
--- a/.github/actions/lbox-matrix/index.js
+++ b/.github/actions/lbox-matrix/index.js
@@ -26811,11 +26811,6 @@ const core = __nccwpck_require__(8611);
try {
const files = JSON.parse(core.getInput('files-changed'));
const startingMatrix = [
- {
- "python-version": "3.8",
- "api-key": "STAGING_LABELBOX_API_KEY_2",
- "da-test-key": "DA_GCP_LABELBOX_API_KEY"
- },
{
"python-version": "3.9",
"api-key": "STAGING_LABELBOX_API_KEY_3",
diff --git a/.github/workflows/lbox-develop.yml b/.github/workflows/lbox-develop.yml
index ba1e4f34e..309ea2969 100644
--- a/.github/workflows/lbox-develop.yml
+++ b/.github/workflows/lbox-develop.yml
@@ -90,7 +90,7 @@ jobs:
- uses: ./.github/actions/python-package-shared-setup
with:
rye-version: ${{ vars.RYE_VERSION }}
- python-version: '3.8'
+ python-version: '3.9'
- name: Create build
id: create-build
working-directory: libs/${{ matrix.package }}
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
index d64123c60..4fcb10257 100644
--- a/.github/workflows/publish.yml
+++ b/.github/workflows/publish.yml
@@ -71,9 +71,6 @@ jobs:
fail-fast: false
matrix:
include:
- - python-version: 3.8
- prod-key: PROD_LABELBOX_API_KEY_2
- da-test-key: DA_GCP_LABELBOX_API_KEY
- python-version: 3.9
prod-key: PROD_LABELBOX_API_KEY_3
da-test-key: DA_GCP_LABELBOX_API_KEY
diff --git a/.github/workflows/python-package-develop.yml b/.github/workflows/python-package-develop.yml
index 05eff5dc4..65dd02872 100644
--- a/.github/workflows/python-package-develop.yml
+++ b/.github/workflows/python-package-develop.yml
@@ -44,7 +44,7 @@ jobs:
- name: Get Latest SDK versions
id: get_sdk_versions
run: |
- sdk_versions=$(git tag --list --sort=-version:refname "v.*" | head -n 4 | jq -R -s -c 'split("\n")[:-1]')
+ sdk_versions=$(git tag --list --sort=-version:refname "v.*" | head -n 3 | jq -R -s -c 'split("\n")[:-1]')
if [ -z "$sdk_versions" ]; then
echo "No tags found"
exit 1
@@ -58,10 +58,6 @@ jobs:
fail-fast: false
matrix:
include:
- - python-version: 3.8
- api-key: STAGING_LABELBOX_API_KEY_2
- da-test-key: DA_GCP_LABELBOX_API_KEY
- sdk-version: ${{ fromJson(needs.get_sdk_versions.outputs.sdk_versions)[3] }}
- python-version: 3.9
api-key: STAGING_LABELBOX_API_KEY_3
da-test-key: DA_GCP_LABELBOX_API_KEY
@@ -103,7 +99,7 @@ jobs:
- uses: ./.github/actions/python-package-shared-setup
with:
rye-version: ${{ vars.RYE_VERSION }}
- python-version: '3.8'
+ python-version: '3.9'
- name: Create build
id: create-build
working-directory: libs/labelbox
diff --git a/.github/workflows/python-package-prod.yml b/.github/workflows/python-package-prod.yml
index 1936af132..c0e24536f 100644
--- a/.github/workflows/python-package-prod.yml
+++ b/.github/workflows/python-package-prod.yml
@@ -13,9 +13,6 @@ jobs:
fail-fast: false
matrix:
include:
- - python-version: 3.8
- api-key: PROD_LABELBOX_API_KEY_2
- da-test-key: DA_GCP_LABELBOX_API_KEY
- python-version: 3.9
api-key: PROD_LABELBOX_API_KEY_3
da-test-key: DA_GCP_LABELBOX_API_KEY
diff --git a/.python-version b/.python-version
index 9ad6380c1..43077b246 100644
--- a/.python-version
+++ b/.python-version
@@ -1 +1 @@
-3.8.18
+3.9.18
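
The interpreter pin moves from 3.8.18 to 3.9.18, matching the CI matrices and Dockerfile elsewhere in this diff. Purely as an illustration of the new floor (this guard is not part of the SDK), a runtime check would look like:

import sys

if sys.version_info < (3, 9):
    raise RuntimeError("labelbox now requires Python 3.9 or newer")
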
diff --git a/docs/labelbox/exceptions.rst b/docs/labelbox/exceptions.rst
index 3082bc081..96ea0f6d5 100644
--- a/docs/labelbox/exceptions.rst
+++ b/docs/labelbox/exceptions.rst
@@ -1,6 +1,6 @@
Exceptions
===============================================================================================
-.. automodule:: labelbox.exceptions
+.. automodule:: lbox.exceptions
:members:
:show-inheritance:
\ No newline at end of file
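
With the exception module relocated into the lbox-clients package (declared as a dependency further down in this diff), downstream error handling only needs a new import path. A small sketch, assuming the exception hierarchy itself is unchanged by the move:

# before: from labelbox.exceptions import LabelboxError, ResourceNotFoundError
from lbox.exceptions import LabelboxError, ResourceNotFoundError

try:
    raise ResourceNotFoundError(message="example lookup failed")  # placeholder error
except LabelboxError as err:  # ResourceNotFoundError still derives from LabelboxError
    print(err.message)
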
diff --git a/docs/labelbox/index.rst b/docs/labelbox/index.rst
index 35118f56f..fcfa2f6b5 100644
--- a/docs/labelbox/index.rst
+++ b/docs/labelbox/index.rst
@@ -41,6 +41,7 @@ Labelbox Python SDK Documentation
project
project-model-config
quality-mode
+ request-client
resource-tag
review
search-filters
diff --git a/docs/labelbox/request-client.rst b/docs/labelbox/request-client.rst
new file mode 100644
index 000000000..fcfea7f97
--- /dev/null
+++ b/docs/labelbox/request-client.rst
@@ -0,0 +1,6 @@
+Request Client
+===============================================================================================
+
+.. automodule:: lbox.request_client
+ :members:
+ :show-inheritance:
\ No newline at end of file
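
The new docs page tracks the extraction of transport code into lbox.request_client.RequestClient. As the client.py changes below show, Client keeps thin read-only properties over it, so existing call sites keep working. A rough sketch (the API key is a placeholder):

from labelbox import Client

client = Client(api_key="<YOUR_API_KEY>")  # placeholder key
print(client.endpoint)        # GraphQL endpoint, proxied from the RequestClient
print(client.rest_endpoint)   # REST endpoint, likewise proxied
session = client.connection   # underlying requests.Session owned by the RequestClient
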
diff --git a/examples/README.md b/examples/README.md
index 1bc102947..e02b7c64f 100644
--- a/examples/README.md
+++ b/examples/README.md
[The notebook index tables in examples/README.md are regenerated in a new row order; only the row labels are reproduced here, and each moved row carries its Github/Colab link cells with it. Rows repositioned in this change: Basics, Projects, Data Row Metadata, User Management, Exporting to CSV, Export Data, Project Setup, DICOM, Text, Video, HTML, Conversational LLM Data Generation, Image, Conversational LLM, Langchain, Import YOLOv8 Annotations, Model Predictions to Project, Model Slices, Text Predictions, Video Predictions, Geospatial Predictions, and Conversational LLM Predictions.]
diff --git a/examples/pyproject.toml b/examples/pyproject.toml
index 969454ecc..b620680e0 100644
--- a/examples/pyproject.toml
+++ b/examples/pyproject.toml
@@ -5,7 +5,7 @@ description = "Labelbox Python Example Notebooks"
authors = [{ name = "Labelbox", email = "docs@labelbox.com" }]
readme = "README.md"
# Python version matches labelbox SDK
-requires-python = ">=3.8"
+requires-python = ">=3.9"
dependencies = []
[project.urls]
@@ -22,8 +22,8 @@ dev-dependencies = [
"yapf>=0.40.2",
"black[jupyter]>=24.4.2",
"databooks>=1.3.10",
- # higher versions dont support python 3.8
- "pandas>=2.0.3",
+ # higher versions don't support python 3.9
+ "pandas>=2.2.3",
]
[tool.rye.scripts]
diff --git a/libs/labelbox/Dockerfile b/libs/labelbox/Dockerfile
index ddf451b14..386bdbc63 100644
--- a/libs/labelbox/Dockerfile
+++ b/libs/labelbox/Dockerfile
@@ -1,5 +1,5 @@
# https://github.com/ucyo/python-package-template/blob/master/Dockerfile
-FROM python:3.8-slim as rye
+FROM python:3.9-slim as rye
ENV LANG="C.UTF-8" \
LC_ALL="C.UTF-8" \
@@ -38,7 +38,7 @@ WORKDIR /home/python/labelbox-python
RUN rye config --set-bool behavior.global-python=true && \
rye config --set-bool behavior.use-uv=true && \
- rye pin 3.8 && \
+ rye pin 3.9 && \
rye sync
CMD cd libs/labelbox && rye run integration && rye sync -f --features labelbox/data && rye run unit && rye run data
diff --git a/libs/labelbox/mypy.ini b/libs/labelbox/mypy.ini
index 48135600e..5129fbe0f 100644
--- a/libs/labelbox/mypy.ini
+++ b/libs/labelbox/mypy.ini
@@ -7,4 +7,7 @@ ignore_missing_imports = True
ignore_errors = True
[mypy-labelbox]
-ignore_errors = True
\ No newline at end of file
+ignore_errors = True
+
+[mypy-lbox.exceptions]
+ignore_missing_imports = True
diff --git a/libs/labelbox/pyproject.toml b/libs/labelbox/pyproject.toml
index ee2f9b859..f58dba890 100644
--- a/libs/labelbox/pyproject.toml
+++ b/libs/labelbox/pyproject.toml
@@ -12,9 +12,10 @@ dependencies = [
"tqdm>=4.66.2",
"geojson>=3.1.0",
"mypy==1.10.1",
+ "lbox-clients==1.1.0",
]
readme = "README.md"
-requires-python = ">=3.8"
+requires-python = ">=3.9,<3.13"
classifiers = [
# How mature is this project?
"Development Status :: 5 - Production/Stable",
@@ -28,7 +29,6 @@ classifiers = [
"License :: OSI Approved :: Apache Software License",
# Specify the Python versions you support here.
"Programming Language :: Python :: 3",
- "Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
@@ -46,8 +46,7 @@ Changelog = "https://github.com/Labelbox/labelbox-python/blob/develop/libs/label
[project.optional-dependencies]
data = [
"shapely>=2.0.3",
- # numpy v2 breaks package since it only supports python >3.9
- "numpy>=1.24.4, <2.0.0",
+ "numpy>=1.25.0",
"pillow>=10.2.0",
"typeguard>=4.1.5",
"imagesize>=1.4.1",
@@ -74,6 +73,10 @@ dev-dependencies = [
[tool.ruff]
line-length = 80
+[tool.ruff.lint]
+ignore = ["F", "E722"]
+exclude = ["**/__init__.py"]
+
[tool.rye.scripts]
unit = "pytest tests/unit"
# https://github.com/Labelbox/labelbox-python/blob/7c84fdffbc14fd1f69d2a6abdcc0087dc557fa4e/Makefile
@@ -89,9 +92,11 @@ unit = "pytest tests/unit"
# LABELBOX_TEST_BASE_URL="http://host.docker.internal:8080" \
integration = { cmd = "pytest tests/integration" }
data = { cmd = "pytest tests/data" }
+rye-lint = "rye lint"
rye-fmt-check = "rye fmt --check"
+MYPYPATH = "../lbox-clients/src/"
mypy-lint = "mypy src --pretty --show-error-codes --non-interactive --install-types"
-lint = { chain = ["mypy-lint", "rye-fmt-check"] }
+lint = { chain = ["rye-fmt-check", "mypy-lint", "rye-lint"] }
test = { chain = ["lint", "unit", "integration"] }
[tool.hatch.metadata]
diff --git a/libs/labelbox/src/labelbox/__init__.py b/libs/labelbox/src/labelbox/__init__.py
index b0f4ebe11..9b5d6f715 100644
--- a/libs/labelbox/src/labelbox/__init__.py
+++ b/libs/labelbox/src/labelbox/__init__.py
@@ -12,7 +12,6 @@
from labelbox.schema.asset_attachment import AssetAttachment
from labelbox.schema.batch import Batch
from labelbox.schema.benchmark import Benchmark
-from labelbox.schema.bulk_import_request import BulkImportRequest
from labelbox.schema.catalog import Catalog
from labelbox.schema.data_row import DataRow
from labelbox.schema.data_row_metadata import (
@@ -26,10 +25,7 @@
from labelbox.schema.export_task import (
BufferedJsonConverterOutput,
ExportTask,
- FileConverter,
- FileConverterOutput,
- JsonConverter,
- JsonConverterOutput,
StreamType,
)
from labelbox.schema.iam_integration import IAMIntegration
@@ -59,6 +55,28 @@
ResponseOption,
Tool,
)
+from labelbox.schema.ontology import PromptResponseClassification
+from labelbox.schema.ontology import ResponseOption
+from labelbox.schema.role import Role, ProjectRole
+from labelbox.schema.invite import Invite, InviteLimit
+from labelbox.schema.data_row_metadata import (
+ DataRowMetadataOntology,
+ DataRowMetadataField,
+ DataRowMetadata,
+ DeleteDataRowMetadata,
+)
+from labelbox.schema.model_run import ModelRun, DataSplit
+from labelbox.schema.benchmark import Benchmark
+from labelbox.schema.iam_integration import IAMIntegration
+from labelbox.schema.resource_tag import ResourceTag
+from labelbox.schema.project_model_config import ProjectModelConfig
+from labelbox.schema.project_resource_tag import ProjectResourceTag
+from labelbox.schema.media_type import MediaType
+from labelbox.schema.slice import Slice, CatalogSlice, ModelSlice
+from labelbox.schema.task_queue import TaskQueue
+from labelbox.schema.label_score import LabelScore
+from labelbox.schema.identifiables import UniqueIds, GlobalKeys, DataRowIds
+from labelbox.schema.identifiable import UniqueId, GlobalKey
from labelbox.schema.ontology_kind import OntologyKind
from labelbox.schema.organization import Organization
from labelbox.schema.project import Project
@@ -68,7 +86,6 @@
ProjectOverviewDetailed,
)
from labelbox.schema.project_resource_tag import ProjectResourceTag
-from labelbox.schema.queue_mode import QueueMode
from labelbox.schema.resource_tag import ResourceTag
from labelbox.schema.review import Review
from labelbox.schema.role import ProjectRole, Role
diff --git a/libs/labelbox/src/labelbox/adv_client.py b/libs/labelbox/src/labelbox/adv_client.py
index 626ac0279..20766a9a4 100644
--- a/libs/labelbox/src/labelbox/adv_client.py
+++ b/libs/labelbox/src/labelbox/adv_client.py
@@ -1,12 +1,12 @@
import io
import json
import logging
-from typing import Dict, Any, Optional, List, Callable
+from typing import Any, Callable, Dict, List, Optional
from urllib.parse import urlparse
-from labelbox.exceptions import LabelboxError
import requests
-from requests import Session, Response
+from lbox.exceptions import LabelboxError
+from requests import Response, Session
logger = logging.getLogger(__name__)
diff --git a/libs/labelbox/src/labelbox/annotated_types.py b/libs/labelbox/src/labelbox/annotated_types.py
index d0fe93ef8..de2e1e01f 100644
--- a/libs/labelbox/src/labelbox/annotated_types.py
+++ b/libs/labelbox/src/labelbox/annotated_types.py
@@ -1,6 +1,5 @@
-from typing_extensions import Annotated
+from typing import Annotated
from pydantic import Field
-
Cuid = Annotated[str, Field(min_length=25, max_length=25)]
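
Raising the floor to 3.9 lets Annotated come from the standard library instead of typing_extensions. A minimal sketch of how such an alias is consumed with pydantic; the ProjectRef model is hypothetical and only for illustration:

from typing import Annotated

from pydantic import BaseModel, Field

Cuid = Annotated[str, Field(min_length=25, max_length=25)]  # same alias as above


class ProjectRef(BaseModel):
    uid: Cuid  # validated to be exactly 25 characters


ProjectRef(uid="c" * 25)  # passes validation
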
diff --git a/libs/labelbox/src/labelbox/client.py b/libs/labelbox/src/labelbox/client.py
index 671f8b8cc..bcf29665e 100644
--- a/libs/labelbox/src/labelbox/client.py
+++ b/libs/labelbox/src/labelbox/client.py
@@ -4,20 +4,25 @@
import mimetypes
import os
import random
-import sys
import time
import urllib.parse
import warnings
from collections import defaultdict
from datetime import datetime, timezone
from types import MappingProxyType
-from typing import Any, Dict, List, Optional, Union, overload
+from typing import Any, Callable, Dict, List, Optional, Set, Union, overload
import requests
import requests.exceptions
from google.api_core import retry
+from lbox.exceptions import (
+ InternalServerError,
+ LabelboxError,
+ ResourceNotFoundError,
+ TimeoutError,
+)
+from lbox.request_client import RequestClient
-import labelbox.exceptions
from labelbox import __version__ as SDK_VERSION
from labelbox import utils
from labelbox.adv_client import AdvClient
@@ -25,6 +30,7 @@
from labelbox.orm.db_object import DbObject
from labelbox.orm.model import Entity, Field
from labelbox.pagination import PaginatedCollection
+from labelbox.project_validation import _CoreProjectInput
from labelbox.schema import role
from labelbox.schema.catalog import Catalog
from labelbox.schema.data_row import DataRow
@@ -67,7 +73,6 @@
CONSENSUS_AUTO_AUDIT_PERCENTAGE,
QualityMode,
)
-from labelbox.schema.queue_mode import QueueMode
from labelbox.schema.role import Role
from labelbox.schema.search_filters import SearchFilter
from labelbox.schema.send_to_annotate_params import (
@@ -82,20 +87,11 @@
logger = logging.getLogger(__name__)
-_LABELBOX_API_KEY = "LABELBOX_API_KEY"
-
-
-def python_version_info():
- version_info = sys.version_info
-
- return f"{version_info.major}.{version_info.minor}.{version_info.micro}-{version_info.releaselevel}"
-
class Client:
"""A Labelbox client.
- Contains info necessary for connecting to a Labelbox server (URL,
- authentication key). Provides functions for querying and creating
+ Provides functions for querying and creating
top-level data objects (Projects, Datasets).
"""
@@ -121,57 +117,45 @@ def __init__(
enable_experimental (bool): Indicates whether or not to use experimental features
app_url (str) : host url for all links to the web app
Raises:
- labelbox.exceptions.AuthenticationError: If no `api_key`
+ AuthenticationError: If no `api_key`
is provided as an argument or via the environment
variable.
"""
- if api_key is None:
- if _LABELBOX_API_KEY not in os.environ:
- raise labelbox.exceptions.AuthenticationError(
- "Labelbox API key not provided"
- )
- api_key = os.environ[_LABELBOX_API_KEY]
- self.api_key = api_key
-
- self.enable_experimental = enable_experimental
- if enable_experimental:
- logger.info("Experimental features have been enabled")
-
- logger.info("Initializing Labelbox client at '%s'", endpoint)
- self.app_url = app_url
- self.endpoint = endpoint
- self.rest_endpoint = rest_endpoint
self._data_row_metadata_ontology = None
+ self._request_client = RequestClient(
+ sdk_version=SDK_VERSION,
+ api_key=api_key,
+ endpoint=endpoint,
+ enable_experimental=enable_experimental,
+ app_url=app_url,
+ rest_endpoint=rest_endpoint,
+ )
self._adv_client = AdvClient.factory(rest_endpoint, api_key)
- self._connection: requests.Session = self._init_connection()
- def _init_connection(self) -> requests.Session:
- connection = (
- requests.Session()
- ) # using default connection pool size of 10
- connection.headers.update(self._default_headers())
+ @property
+ def headers(self) -> MappingProxyType:
+ return self._request_client.headers
- return connection
+ @property
+ def connection(self) -> requests.Session:
+ return self._request_client._connection
@property
- def headers(self) -> MappingProxyType:
- return self._connection.headers
-
- def _default_headers(self):
- return {
- "Authorization": "Bearer %s" % self.api_key,
- "Accept": "application/json",
- "Content-Type": "application/json",
- "X-User-Agent": f"python-sdk {SDK_VERSION}",
- "X-Python-Version": f"{python_version_info()}",
- }
+ def endpoint(self) -> str:
+ return self._request_client.endpoint
+
+ @property
+ def rest_endpoint(self) -> str:
+ return self._request_client.rest_endpoint
+
+ @property
+ def enable_experimental(self) -> bool:
+ return self._request_client.enable_experimental
+
+ @property
+ def app_url(self) -> str:
+ return self._request_client.app_url
- @retry.Retry(
- predicate=retry.if_exception_type(
- labelbox.exceptions.InternalServerError,
- labelbox.exceptions.TimeoutError,
- )
- )
def execute(
self,
query=None,
@@ -182,264 +166,34 @@ def execute(
experimental=False,
error_log_key="message",
raise_return_resource_not_found=False,
- ):
- """Sends a request to the server for the execution of the
- given query.
-
- Checks the response for errors and wraps errors
- in appropriate `labelbox.exceptions.LabelboxError` subtypes.
+ error_handlers: Optional[
+ Dict[str, Callable[[requests.models.Response], None]]
+ ] = None,
+ ) -> Dict[str, Any]:
+ """Executes a GraphQL query.
Args:
query (str): The query to execute.
- params (dict): Query parameters referenced within the query.
- data (str): json string containing the query to execute
- files (dict): file arguments for request
- timeout (float): Max allowed time for query execution,
- in seconds.
- Returns:
- dict, parsed JSON response.
- Raises:
- labelbox.exceptions.AuthenticationError: If authentication
- failed.
- labelbox.exceptions.InvalidQueryError: If `query` is not
- syntactically or semantically valid (checked server-side).
- labelbox.exceptions.ApiLimitError: If the server API limit was
- exceeded. See "How to import data" in the online documentation
- to see API limits.
- labelbox.exceptions.TimeoutError: If response was not received
- in `timeout` seconds.
- labelbox.exceptions.NetworkError: If an unknown error occurred
- most likely due to connection issues.
- labelbox.exceptions.LabelboxError: If an unknown error of any
- kind occurred.
- ValueError: If query and data are both None.
- """
- logger.debug("Query: %s, params: %r, data %r", query, params, data)
-
- # Convert datetimes to UTC strings.
- def convert_value(value):
- if isinstance(value, datetime):
- value = value.astimezone(timezone.utc)
- value = value.strftime("%Y-%m-%dT%H:%M:%SZ")
- return value
-
- if query is not None:
- if params is not None:
- params = {
- key: convert_value(value) for key, value in params.items()
- }
- data = json.dumps({"query": query, "variables": params}).encode(
- "utf-8"
- )
- elif data is None:
- raise ValueError("query and data cannot both be none")
-
- endpoint = (
- self.endpoint
- if not experimental
- else self.endpoint.replace("/graphql", "/_gql")
- )
-
- try:
- headers = self._connection.headers.copy()
- if files:
- del headers["Content-Type"]
- del headers["Accept"]
- request = requests.Request(
- "POST",
- endpoint,
- headers=headers,
- data=data,
- files=files if files else None,
- )
-
- prepped: requests.PreparedRequest = request.prepare()
-
- settings = self._connection.merge_environment_settings(
- prepped.url, {}, None, None, None
- )
-
- response = self._connection.send(
- prepped, timeout=timeout, **settings
- )
- logger.debug("Response: %s", response.text)
- except requests.exceptions.Timeout as e:
- raise labelbox.exceptions.TimeoutError(str(e))
- except requests.exceptions.RequestException as e:
- logger.error("Unknown error: %s", str(e))
- raise labelbox.exceptions.NetworkError(e)
- except Exception as e:
- raise labelbox.exceptions.LabelboxError(
- "Unknown error during Client.query(): " + str(e), e
- )
-
- if (
- 200 <= response.status_code < 300
- or response.status_code < 500
- or response.status_code >= 600
- ):
- try:
- r_json = response.json()
- except Exception:
- raise labelbox.exceptions.LabelboxError(
- "Failed to parse response as JSON: %s" % response.text
- )
- else:
- if (
- "upstream connect error or disconnect/reset before headers"
- in response.text
- ):
- raise labelbox.exceptions.InternalServerError(
- "Connection reset"
- )
- elif response.status_code == 502:
- error_502 = "502 Bad Gateway"
- raise labelbox.exceptions.InternalServerError(error_502)
- elif 500 <= response.status_code < 600:
- error_500 = f"Internal server http error {response.status_code}"
- raise labelbox.exceptions.InternalServerError(error_500)
-
- errors = r_json.get("errors", [])
-
- def check_errors(keywords, *path):
- """Helper that looks for any of the given `keywords` in any of
- current errors on paths (like error[path][component][to][keyword]).
- """
- for error in errors:
- obj = error
- for path_elem in path:
- obj = obj.get(path_elem, {})
- if obj in keywords:
- return error
- return None
-
- def get_error_status_code(error: dict) -> int:
- try:
- return int(error["extensions"].get("exception").get("status"))
- except:
- return 500
-
- if (
- check_errors(["AUTHENTICATION_ERROR"], "extensions", "code")
- is not None
- ):
- raise labelbox.exceptions.AuthenticationError("Invalid API key")
-
- authorization_error = check_errors(
- ["AUTHORIZATION_ERROR"], "extensions", "code"
- )
- if authorization_error is not None:
- raise labelbox.exceptions.AuthorizationError(
- authorization_error["message"]
- )
-
- validation_error = check_errors(
- ["GRAPHQL_VALIDATION_FAILED"], "extensions", "code"
- )
-
- if validation_error is not None:
- message = validation_error["message"]
- if message == "Query complexity limit exceeded":
- raise labelbox.exceptions.ValidationFailedError(message)
- else:
- raise labelbox.exceptions.InvalidQueryError(message)
-
- graphql_error = check_errors(
- ["GRAPHQL_PARSE_FAILED"], "extensions", "code"
- )
- if graphql_error is not None:
- raise labelbox.exceptions.InvalidQueryError(
- graphql_error["message"]
- )
-
- # Check if API limit was exceeded
- response_msg = r_json.get("message", "")
-
- if response_msg.startswith("You have exceeded"):
- raise labelbox.exceptions.ApiLimitError(response_msg)
-
- resource_not_found_error = check_errors(
- ["RESOURCE_NOT_FOUND"], "extensions", "code"
- )
- if resource_not_found_error is not None:
- if raise_return_resource_not_found:
- raise labelbox.exceptions.ResourceNotFoundError(
- message=resource_not_found_error["message"]
- )
- else:
- # Return None and let the caller methods raise an exception
- # as they already know which resource type and ID was requested
- return None
-
- resource_conflict_error = check_errors(
- ["RESOURCE_CONFLICT"], "extensions", "code"
- )
- if resource_conflict_error is not None:
- raise labelbox.exceptions.ResourceConflict(
- resource_conflict_error["message"]
- )
+ params (dict): Query parameters referenced within the query.
+ raise_return_resource_not_found (bool): If True, raise a
+ ResourceNotFoundError if the query returns None.
+ error_handlers (dict): A dictionary mapping GraphQL error codes to handler functions.
+ Allows a caller to handle specific errors in a custom way or produce more user-friendly messages.
- malformed_request_error = check_errors(
- ["MALFORMED_REQUEST"], "extensions", "code"
- )
- if malformed_request_error is not None:
- raise labelbox.exceptions.MalformedQueryException(
- malformed_request_error[error_log_key]
- )
-
- # A lot of different error situations are now labeled serverside
- # as INTERNAL_SERVER_ERROR, when they are actually client errors.
- # TODO: fix this in the server API
- internal_server_error = check_errors(
- ["INTERNAL_SERVER_ERROR"], "extensions", "code"
- )
- if internal_server_error is not None:
- message = internal_server_error.get("message")
- error_status_code = get_error_status_code(internal_server_error)
- if error_status_code == 400:
- raise labelbox.exceptions.InvalidQueryError(message)
- elif error_status_code == 422:
- raise labelbox.exceptions.UnprocessableEntityError(message)
- elif error_status_code == 426:
- raise labelbox.exceptions.OperationNotAllowedException(message)
- elif error_status_code == 500:
- raise labelbox.exceptions.LabelboxError(message)
- else:
- raise labelbox.exceptions.InternalServerError(message)
-
- not_allowed_error = check_errors(
- ["OPERATION_NOT_ALLOWED"], "extensions", "code"
+ Returns:
+ dict: The response from the server.
+ """
+ return self._request_client.execute(
+ query,
+ params,
+ data=data,
+ files=files,
+ timeout=timeout,
+ experimental=experimental,
+ error_log_key=error_log_key,
+ raise_return_resource_not_found=raise_return_resource_not_found,
+ error_handlers=error_handlers,
)
- if not_allowed_error is not None:
- message = not_allowed_error.get("message")
- raise labelbox.exceptions.OperationNotAllowedException(message)
-
- if len(errors) > 0:
- logger.warning("Unparsed errors on query execution: %r", errors)
- messages = list(
- map(
- lambda x: {
- "message": x["message"],
- "code": x["extensions"]["code"],
- },
- errors,
- )
- )
- raise labelbox.exceptions.LabelboxError(
- "Unknown error: %s" % str(messages)
- )
-
- # if we do return a proper error code, and didn't catch this above
- # reraise
- # this mainly catches a 401 for API access disabled for free tier
- # TODO: need to unify API errors to handle things more uniformly
- # in the SDK
- if response.status_code != requests.codes.ok:
- message = f"{response.status_code} {response.reason}"
- cause = r_json.get("message")
- raise labelbox.exceptions.LabelboxError(message, cause)
-
- return r_json["data"]
def upload_file(self, path: str) -> str:
"""Uploads given path to local file.
@@ -451,7 +205,7 @@ def upload_file(self, path: str) -> str:
Returns:
str, the URL of uploaded data.
Raises:
- labelbox.exceptions.LabelboxError: If upload failed.
+ LabelboxError: If upload failed.
"""
content_type, _ = mimetypes.guess_type(path)
filename = os.path.basename(path)
@@ -460,11 +214,7 @@ def upload_file(self, path: str) -> str:
content=f.read(), filename=filename, content_type=content_type
)
- @retry.Retry(
- predicate=retry.if_exception_type(
- labelbox.exceptions.InternalServerError
- )
- )
+ @retry.Retry(predicate=retry.if_exception_type(InternalServerError))
def upload_data(
self,
content: bytes,
@@ -484,7 +234,7 @@ def upload_data(
str, the URL of uploaded data.
Raises:
- labelbox.exceptions.LabelboxError: If upload failed.
+ LabelboxError: If upload failed.
"""
request_data = {
@@ -509,7 +259,7 @@ def upload_data(
if (filename and content_type)
else content
}
- headers = self._connection.headers.copy()
+ headers = self.connection.headers.copy()
headers.pop("Content-Type", None)
request = requests.Request(
"POST",
@@ -521,22 +271,20 @@ def upload_data(
prepped: requests.PreparedRequest = request.prepare()
- response = self._connection.send(prepped)
+ response = self.connection.send(prepped)
if response.status_code == 502:
error_502 = "502 Bad Gateway"
- raise labelbox.exceptions.InternalServerError(error_502)
+ raise InternalServerError(error_502)
elif response.status_code == 503:
- raise labelbox.exceptions.InternalServerError(response.text)
+ raise InternalServerError(response.text)
elif response.status_code == 520:
- raise labelbox.exceptions.InternalServerError(response.text)
+ raise InternalServerError(response.text)
try:
file_data = response.json().get("data", None)
except ValueError as e: # response is not valid JSON
- raise labelbox.exceptions.LabelboxError(
- "Failed to upload, unknown cause", e
- )
+ raise LabelboxError("Failed to upload, unknown cause", e)
if not file_data or not file_data.get("uploadFile", None):
try:
@@ -546,9 +294,7 @@ def upload_data(
)
except Exception:
error_msg = "Unknown error"
- raise labelbox.exceptions.LabelboxError(
- "Failed to upload, message: %s" % error_msg
- )
+ raise LabelboxError("Failed to upload, message: %s" % error_msg)
return file_data["uploadFile"]["url"]
@@ -561,7 +307,7 @@ def _get_single(self, db_object_type, uid):
Returns:
Object of `db_object_type`.
Raises:
- labelbox.exceptions.ResourceNotFoundError: If there is no object
+ ResourceNotFoundError: If there is no object
of the given type for the given ID.
"""
query_str, params = query.get_single(db_object_type, uid)
@@ -569,9 +315,7 @@ def _get_single(self, db_object_type, uid):
res = self.execute(query_str, params)
res = res and res.get(utils.camel_case(db_object_type.type_name()))
if res is None:
- raise labelbox.exceptions.ResourceNotFoundError(
- db_object_type, params
- )
+ raise ResourceNotFoundError(db_object_type, params)
else:
return db_object_type(self, res)
@@ -585,7 +329,7 @@ def get_project(self, project_id) -> Project:
Returns:
The sought Project.
Raises:
- labelbox.exceptions.ResourceNotFoundError: If there is no
+ ResourceNotFoundError: If there is no
Project with the given ID.
"""
return self._get_single(Entity.Project, project_id)
@@ -600,7 +344,7 @@ def get_dataset(self, dataset_id) -> Dataset:
Returns:
The sought Dataset.
Raises:
- labelbox.exceptions.ResourceNotFoundError: If there is no
+ ResourceNotFoundError: If there is no
Dataset with the given ID.
"""
return self._get_single(Entity.Dataset, dataset_id)
@@ -630,7 +374,7 @@ def _get_all(self, db_object_type, where, filter_deleted=True):
An iterable of `db_object_type` instances.
"""
if filter_deleted:
- not_deleted = db_object_type.deleted == False
+ not_deleted = db_object_type.deleted == False # noqa: E712 Needed so the comparison can be combined with the bitwise & operator
where = not_deleted if where is None else where & not_deleted
query_str, params = query.get_all(db_object_type, where)
@@ -724,13 +468,11 @@ def _create(self, db_object_type, data, extra_params={}):
res = self.execute(
query_string, params, raise_return_resource_not_found=True
)
-
if not res:
- raise labelbox.exceptions.LabelboxError(
+ raise LabelboxError(
"Failed to create %s" % db_object_type.type_name()
)
res = res["create%s" % db_object_type.type_name()]
-
return db_object_type(self, res)
def create_model_config(
@@ -784,9 +526,7 @@ def delete_model_config(self, id: str) -> bool:
params = {"id": id}
result = self.execute(query, params)
if not result:
- raise labelbox.exceptions.ResourceNotFoundError(
- Entity.ModelConfig, params
- )
+ raise ResourceNotFoundError(Entity.ModelConfig, params)
return result["deleteModelConfig"]["success"]
def create_dataset(
@@ -853,7 +593,18 @@ def create_dataset(
raise e
return dataset
- def create_project(self, **kwargs) -> Project:
+ def create_project(
+ self,
+ name: str,
+ media_type: MediaType,
+ description: Optional[str] = None,
+ quality_modes: Optional[Set[QualityMode]] = {
+ QualityMode.Benchmark,
+ QualityMode.Consensus,
+ },
+ is_benchmark_enabled: Optional[bool] = None,
+ is_consensus_enabled: Optional[bool] = None,
+ ) -> Project:
"""Creates a Project object on the server.
Attribute values are passed as keyword arguments.
@@ -862,71 +613,44 @@ def create_project(self, **kwargs) -> Project:
name="",
description="",
media_type=MediaType.Image,
- queue_mode=QueueMode.Batch
)
Args:
name (str): A name for the project
description (str): A short summary for the project
media_type (MediaType): The type of assets that this project will accept
- queue_mode (Optional[QueueMode]): The queue mode to use
- quality_mode (Optional[QualityMode]): The quality mode to use (e.g. Benchmark, Consensus). Defaults to
- Benchmark
quality_modes (Optional[List[QualityMode]]): The quality modes to use (e.g. Benchmark, Consensus). Defaults to
Benchmark.
+ is_benchmark_enabled (Optional[bool]): Whether the project supports benchmark. Defaults to None.
+ is_consensus_enabled (Optional[bool]): Whether the project supports consensus. Defaults to None.
Returns:
A new Project object.
Raises:
- InvalidAttributeError: If the Project type does not contain
- any of the attribute names given in kwargs.
-
- NOTE: the following attributes are used only in chat model evaluation projects:
- dataset_name_or_id, append_to_existing_dataset, data_row_count, editor_task_type
- They are not used for general projects and not supported in this method
+ ValueError: If inputs are invalid.
"""
- # The following arguments are not supported for general projects, only for chat model evaluation projects
- kwargs.pop("dataset_name_or_id", None)
- kwargs.pop("append_to_existing_dataset", None)
- kwargs.pop("data_row_count", None)
- kwargs.pop("editor_task_type", None)
- return self._create_project(**kwargs)
-
- @overload
- def create_model_evaluation_project(
- self,
- dataset_name: str,
- dataset_id: str = None,
- data_row_count: int = 100,
- **kwargs,
- ) -> Project:
- pass
-
- @overload
- def create_model_evaluation_project(
- self,
- dataset_id: str,
- dataset_name: str = None,
- data_row_count: int = 100,
- **kwargs,
- ) -> Project:
- pass
-
- @overload
- def create_model_evaluation_project(
- self,
- dataset_id: Optional[str] = None,
- dataset_name: Optional[str] = None,
- data_row_count: Optional[int] = None,
- **kwargs,
- ) -> Project:
- pass
+ input = {
+ "name": name,
+ "description": description,
+ "media_type": media_type,
+ "quality_modes": quality_modes,
+ "is_benchmark_enabled": is_benchmark_enabled,
+ "is_consensus_enabled": is_consensus_enabled,
+ }
+ return self._create_project(_CoreProjectInput(**input))
def create_model_evaluation_project(
self,
+ name: str,
+ description: Optional[str] = None,
+ quality_modes: Optional[Set[QualityMode]] = {
+ QualityMode.Benchmark,
+ QualityMode.Consensus,
+ },
+ is_benchmark_enabled: Optional[bool] = None,
+ is_consensus_enabled: Optional[bool] = None,
dataset_id: Optional[str] = None,
dataset_name: Optional[str] = None,
data_row_count: Optional[int] = None,
- **kwargs,
) -> Project:
"""
Use this method exclusively to create a chat model evaluation project.
@@ -934,12 +658,12 @@ def create_model_evaluation_project(
dataset_name: When creating a new dataset, pass the name
dataset_id: When using an existing dataset, pass the id
data_row_count: The number of data row assets to use for the project
- **kwargs: Additional parameters to pass to the the create_project method
+ See create_project for additional parameters
Returns:
Project: The created project
Examples:
- >>> client.create_model_evaluation_project(name=project_name, dataset_name="new data set")
+ >>> client.create_model_evaluation_project(name=project_name, dataset_name="new data set")
>>> This creates a new dataset with a default number of rows (100), creates new project and assigns a batch of the newly created datarows to the project.
>>> client.create_model_evaluation_project(name=project_name, dataset_name="new data set", data_row_count=10)
@@ -959,51 +683,75 @@ def create_model_evaluation_project(
append_to_existing_dataset = bool(dataset_id)
if dataset_name_or_id:
- kwargs["dataset_name_or_id"] = dataset_name_or_id
- kwargs["append_to_existing_dataset"] = append_to_existing_dataset
if data_row_count is None:
data_row_count = 100
- if data_row_count < 0:
- raise ValueError("data_row_count must be a positive integer.")
- kwargs["data_row_count"] = data_row_count
warnings.warn(
"Automatic generation of data rows of live model evaluation projects is deprecated. dataset_name_or_id, append_to_existing_dataset, data_row_count will be removed in a future version.",
DeprecationWarning,
)
- kwargs["media_type"] = MediaType.Conversational
- kwargs["editor_task_type"] = EditorTaskType.ModelChatEvaluation.value
+ media_type = MediaType.Conversational
+ editor_task_type = EditorTaskType.ModelChatEvaluation
- return self._create_project(**kwargs)
+ input = {
+ "name": name,
+ "description": description,
+ "media_type": media_type,
+ "quality_modes": quality_modes,
+ "is_benchmark_enabled": is_benchmark_enabled,
+ "is_consensus_enabled": is_consensus_enabled,
+ "dataset_name_or_id": dataset_name_or_id,
+ "append_to_existing_dataset": append_to_existing_dataset,
+ "data_row_count": data_row_count,
+ "editor_task_type": editor_task_type,
+ }
+ return self._create_project(_CoreProjectInput(**input))
- def create_offline_model_evaluation_project(self, **kwargs) -> Project:
+ def create_offline_model_evaluation_project(
+ self,
+ name: str,
+ description: Optional[str] = None,
+ quality_modes: Optional[Set[QualityMode]] = {
+ QualityMode.Benchmark,
+ QualityMode.Consensus,
+ },
+ is_benchmark_enabled: Optional[bool] = None,
+ is_consensus_enabled: Optional[bool] = None,
+ ) -> Project:
"""
Creates a project for offline model evaluation.
Args:
- **kwargs: Additional parameters to pass see the create_project method
+ See create_project for parameters
Returns:
Project: The created project
"""
- kwargs["media_type"] = (
- MediaType.Conversational
- ) # Only Conversational is supported
- kwargs["editor_task_type"] = (
- EditorTaskType.OfflineModelChatEvaluation.value
- ) # Special editor task type for offline model evaluation
-
- # The following arguments are not supported for offline model evaluation
- kwargs.pop("dataset_name_or_id", None)
- kwargs.pop("append_to_existing_dataset", None)
- kwargs.pop("data_row_count", None)
-
- return self._create_project(**kwargs)
+ input = {
+ "name": name,
+ "description": description,
+ "media_type": MediaType.Conversational,
+ "quality_modes": quality_modes,
+ "is_benchmark_enabled": is_benchmark_enabled,
+ "is_consensus_enabled": is_consensus_enabled,
+ "editor_task_type": EditorTaskType.OfflineModelChatEvaluation,
+ }
+ return self._create_project(_CoreProjectInput(**input))
def create_prompt_response_generation_project(
self,
+ name: str,
+ media_type: MediaType,
+ description: Optional[str] = None,
+ auto_audit_percentage: Optional[float] = None,
+ auto_audit_number_of_labels: Optional[int] = None,
+ quality_modes: Optional[Set[QualityMode]] = {
+ QualityMode.Benchmark,
+ QualityMode.Consensus,
+ },
+ is_benchmark_enabled: Optional[bool] = None,
+ is_consensus_enabled: Optional[bool] = None,
dataset_id: Optional[str] = None,
dataset_name: Optional[str] = None,
data_row_count: int = 100,
- **kwargs,
) -> Project:
"""
Use this method exclusively to create a prompt and response generation project.
@@ -1012,7 +760,8 @@ def create_prompt_response_generation_project(
dataset_name: When creating a new dataset, pass the name
dataset_id: When using an existing dataset, pass the id
data_row_count: The number of data row assets to use for the project
- **kwargs: Additional parameters to pass see the create_project method
+ media_type: The type of assets that this project will accept. Limited to LLMPromptCreation and LLMPromptResponseCreation
+ See create_project for additional parameters
Returns:
Project: The created project
@@ -1042,9 +791,6 @@ def create_prompt_response_generation_project(
"Only provide a dataset_name or dataset_id, not both."
)
- if data_row_count <= 0:
- raise ValueError("data_row_count must be a positive integer.")
-
if dataset_id:
append_to_existing_dataset = True
dataset_name_or_id = dataset_id
@@ -1052,7 +798,7 @@ def create_prompt_response_generation_project(
append_to_existing_dataset = False
dataset_name_or_id = dataset_name
- if "media_type" in kwargs and kwargs.get("media_type") not in [
+ if media_type not in [
MediaType.LLMPromptCreation,
MediaType.LLMPromptResponseCreation,
]:
@@ -1060,133 +806,52 @@ def create_prompt_response_generation_project(
"media_type must be either LLMPromptCreation or LLMPromptResponseCreation"
)
- kwargs["dataset_name_or_id"] = dataset_name_or_id
- kwargs["append_to_existing_dataset"] = append_to_existing_dataset
- kwargs["data_row_count"] = data_row_count
-
- kwargs.pop("editor_task_type", None)
-
- return self._create_project(**kwargs)
+ input = {
+ "name": name,
+ "description": description,
+ "media_type": media_type,
+ "auto_audit_percentage": auto_audit_percentage,
+ "auto_audit_number_of_labels": auto_audit_number_of_labels,
+ "quality_modes": quality_modes,
+ "is_benchmark_enabled": is_benchmark_enabled,
+ "is_consensus_enabled": is_consensus_enabled,
+ "dataset_name_or_id": dataset_name_or_id,
+ "append_to_existing_dataset": append_to_existing_dataset,
+ "data_row_count": data_row_count,
+ }
+ return self._create_project(_CoreProjectInput(**input))
- def create_response_creation_project(self, **kwargs) -> Project:
+ def create_response_creation_project(
+ self,
+ name: str,
+ description: Optional[str] = None,
+ quality_modes: Optional[Set[QualityMode]] = {
+ QualityMode.Benchmark,
+ QualityMode.Consensus,
+ },
+ is_benchmark_enabled: Optional[bool] = None,
+ is_consensus_enabled: Optional[bool] = None,
+ ) -> Project:
"""
Creates a project for response creation.
Args:
- **kwargs: Additional parameters to pass see the create_project method
+ See create_project for parameters
Returns:
Project: The created project
"""
- kwargs["media_type"] = MediaType.Text # Only Text is supported
- kwargs["editor_task_type"] = (
- EditorTaskType.ResponseCreation.value
- ) # Special editor task type for response creation projects
-
- # The following arguments are not supported for response creation projects
- kwargs.pop("dataset_name_or_id", None)
- kwargs.pop("append_to_existing_dataset", None)
- kwargs.pop("data_row_count", None)
-
- return self._create_project(**kwargs)
-
- def _create_project(self, **kwargs) -> Project:
- auto_audit_percentage = kwargs.get("auto_audit_percentage")
- auto_audit_number_of_labels = kwargs.get("auto_audit_number_of_labels")
- if (
- auto_audit_percentage is not None
- or auto_audit_number_of_labels is not None
- ):
- raise ValueError(
- "quality_modes must be set instead of auto_audit_percentage or auto_audit_number_of_labels."
- )
-
- name = kwargs.get("name")
- if name is None or not name.strip():
- raise ValueError("project name must be a valid string.")
-
- queue_mode = kwargs.get("queue_mode")
- if queue_mode is QueueMode.Dataset:
- raise ValueError(
- "Dataset queue mode is deprecated. Please prefer Batch queue mode."
- )
- elif queue_mode is QueueMode.Batch:
- logger.warning(
- "Passing a queue mode of batch is redundant and will soon no longer be supported."
- )
-
- media_type = kwargs.get("media_type")
- if media_type and MediaType.is_supported(media_type):
- media_type_value = media_type.value
- elif media_type:
- raise TypeError(
- f"{media_type} is not a valid media type. Use"
- f" any of {MediaType.get_supported_members()}"
- " from MediaType. Example: MediaType.Image."
- )
- else:
- logger.warning(
- "Creating a project without specifying media_type"
- " through this method will soon no longer be supported."
- )
- media_type_value = None
-
- quality_modes = kwargs.get("quality_modes")
- quality_mode = kwargs.get("quality_mode")
- if quality_mode:
- logger.warning(
- "Passing quality_mode is deprecated and will soon no longer be supported. Use quality_modes instead."
- )
-
- if quality_modes and quality_mode:
- raise ValueError(
- "Cannot use both quality_modes and quality_mode at the same time. Use one or the other."
- )
-
- if not quality_modes and not quality_mode:
- logger.info("Defaulting quality modes to Benchmark and Consensus.")
-
- data = kwargs
- data.pop("quality_modes", None)
- data.pop("quality_mode", None)
-
- # check if quality_modes is a set, if not, convert to set
- quality_modes_set = quality_modes
- if quality_modes and not isinstance(quality_modes, set):
- quality_modes_set = set(quality_modes)
- if quality_mode:
- quality_modes_set = {quality_mode}
-
- if (
- quality_modes_set is None
- or len(quality_modes_set) == 0
- or quality_modes_set
- == {QualityMode.Benchmark, QualityMode.Consensus}
- ):
- data["auto_audit_number_of_labels"] = (
- CONSENSUS_AUTO_AUDIT_NUMBER_OF_LABELS
- )
- data["auto_audit_percentage"] = CONSENSUS_AUTO_AUDIT_PERCENTAGE
- data["is_benchmark_enabled"] = True
- data["is_consensus_enabled"] = True
- elif quality_modes_set == {QualityMode.Benchmark}:
- data["auto_audit_number_of_labels"] = (
- BENCHMARK_AUTO_AUDIT_NUMBER_OF_LABELS
- )
- data["auto_audit_percentage"] = BENCHMARK_AUTO_AUDIT_PERCENTAGE
- data["is_benchmark_enabled"] = True
- elif quality_modes_set == {QualityMode.Consensus}:
- data["auto_audit_number_of_labels"] = (
- CONSENSUS_AUTO_AUDIT_NUMBER_OF_LABELS
- )
- data["auto_audit_percentage"] = CONSENSUS_AUTO_AUDIT_PERCENTAGE
- data["is_consensus_enabled"] = True
- else:
- raise ValueError(
- f"{quality_modes_set} is not a valid quality modes set. Allowed values are [Benchmark, Consensus]"
- )
+ input = {
+ "name": name,
+ "description": description,
+ "media_type": MediaType.Text, # Only Text is supported
+ "quality_modes": quality_modes,
+ "is_benchmark_enabled": is_benchmark_enabled,
+ "is_consensus_enabled": is_consensus_enabled,
+ "editor_task_type": EditorTaskType.ResponseCreation.value, # Special editor task type for response creation projects
+ }
+ return self._create_project(_CoreProjectInput(**input))
- params = {**data}
- if media_type_value:
- params["media_type"] = media_type_value
+ def _create_project(self, input: _CoreProjectInput) -> Project:
+ params = input.model_dump(exclude_none=True)
extra_params = {
Field.String("dataset_name_or_id"): params.pop(
@@ -1222,7 +887,7 @@ def get_data_row_by_global_key(self, global_key: str) -> DataRow:
"""
res = self.get_data_row_ids_for_global_keys([global_key])
if res["status"] != "SUCCESS":
- raise labelbox.exceptions.ResourceNotFoundError(
+ raise ResourceNotFoundError(
Entity.DataRow, {global_key: global_key}
)
data_row_id = res["results"][0]
@@ -1250,7 +915,7 @@ def get_model(self, model_id) -> Model:
Returns:
The sought Model.
Raises:
- labelbox.exceptions.ResourceNotFoundError: If there is no
+ ResourceNotFoundError: If there is no
Model with the given ID.
"""
return self._get_single(Entity.Model, model_id)
@@ -1493,10 +1158,10 @@ def delete_unused_feature_schema(self, feature_schema_id: str) -> None:
+ "/feature-schemas/"
+ urllib.parse.quote(feature_schema_id)
)
- response = self._connection.delete(endpoint)
+ response = self.connection.delete(endpoint)
if response.status_code != requests.codes.no_content:
- raise labelbox.exceptions.LabelboxError(
+ raise LabelboxError(
"Failed to delete the feature schema, message: "
+ str(response.json()["message"])
)
@@ -1514,10 +1179,10 @@ def delete_unused_ontology(self, ontology_id: str) -> None:
+ "/ontologies/"
+ urllib.parse.quote(ontology_id)
)
- response = self._connection.delete(endpoint)
+ response = self.connection.delete(endpoint)
if response.status_code != requests.codes.no_content:
- raise labelbox.exceptions.LabelboxError(
+ raise LabelboxError(
"Failed to delete the ontology, message: "
+ str(response.json()["message"])
)
@@ -1542,12 +1207,12 @@ def update_feature_schema_title(
+ urllib.parse.quote(feature_schema_id)
+ "/definition"
)
- response = self._connection.patch(endpoint, json={"title": title})
+ response = self.connection.patch(endpoint, json={"title": title})
if response.status_code == requests.codes.ok:
return self.get_feature_schema(feature_schema_id)
else:
- raise labelbox.exceptions.LabelboxError(
+ raise LabelboxError(
"Failed to update the feature schema, message: "
+ str(response.json()["message"])
)
@@ -1576,14 +1241,14 @@ def upsert_feature_schema(self, feature_schema: Dict) -> FeatureSchema:
+ "/feature-schemas/"
+ urllib.parse.quote(feature_schema_id)
)
- response = self._connection.put(
+ response = self.connection.put(
endpoint, json={"normalized": json.dumps(feature_schema)}
)
if response.status_code == requests.codes.ok:
return self.get_feature_schema(response.json()["schemaId"])
else:
- raise labelbox.exceptions.LabelboxError(
+ raise LabelboxError(
"Failed to upsert the feature schema, message: "
+ str(response.json()["message"])
)
@@ -1609,9 +1274,9 @@ def insert_feature_schema_into_ontology(
+ "/feature-schemas/"
+ urllib.parse.quote(feature_schema_id)
)
- response = self._connection.post(endpoint, json={"position": position})
+ response = self.connection.post(endpoint, json={"position": position})
if response.status_code != requests.codes.created:
- raise labelbox.exceptions.LabelboxError(
+ raise LabelboxError(
"Failed to insert the feature schema into the ontology, message: "
+ str(response.json()["message"])
)
@@ -1631,12 +1296,12 @@ def get_unused_ontologies(self, after: str = None) -> List[str]:
"""
endpoint = self.rest_endpoint + "/ontologies/unused"
- response = self._connection.get(endpoint, json={"after": after})
+ response = self.connection.get(endpoint, json={"after": after})
if response.status_code == requests.codes.ok:
return response.json()
else:
- raise labelbox.exceptions.LabelboxError(
+ raise LabelboxError(
"Failed to get unused ontologies, message: "
+ str(response.json()["message"])
)
@@ -1656,12 +1321,12 @@ def get_unused_feature_schemas(self, after: str = None) -> List[str]:
"""
endpoint = self.rest_endpoint + "/feature-schemas/unused"
- response = self._connection.get(endpoint, json={"after": after})
+ response = self.connection.get(endpoint, json={"after": after})
if response.status_code == requests.codes.ok:
return response.json()
else:
- raise labelbox.exceptions.LabelboxError(
+ raise LabelboxError(
"Failed to get unused feature schemas, message: "
+ str(response.json()["message"])
)
@@ -1957,12 +1622,12 @@ def _format_failed_rows(
elif (
res["assignGlobalKeysToDataRowsResult"]["jobStatus"] == "FAILED"
):
- raise labelbox.exceptions.LabelboxError(
+ raise LabelboxError(
"Job assign_global_keys_to_data_rows failed."
)
current_time = time.time()
if current_time - start_time > timeout_seconds:
- raise labelbox.exceptions.TimeoutError(
+ raise TimeoutError(
"Timed out waiting for assign_global_keys_to_data_rows job to complete."
)
time.sleep(sleep_time)
@@ -1973,10 +1638,6 @@ def get_data_row_ids_for_global_keys(
"""
Gets data row ids for a list of global keys.
- Deprecation Notice: This function will soon no longer return 'Deleted Data Rows'
- as part of the 'results'. Global keys for deleted data rows will soon be placed
- under 'Data Row not found' portion.
-
Args:
A list of global keys
Returns:
@@ -2066,12 +1727,10 @@ def _format_failed_rows(
return {"status": status, "results": results, "errors": errors}
elif res["dataRowsForGlobalKeysResult"]["jobStatus"] == "FAILED":
- raise labelbox.exceptions.LabelboxError(
- "Job dataRowsForGlobalKeys failed."
- )
+ raise LabelboxError("Job dataRowsForGlobalKeys failed.")
current_time = time.time()
if current_time - start_time > timeout_seconds:
- raise labelbox.exceptions.TimeoutError(
+ raise TimeoutError(
"Timed out waiting for get_data_rows_for_global_keys job to complete."
)
time.sleep(sleep_time)
@@ -2170,12 +1829,10 @@ def _format_failed_rows(
return {"status": status, "results": results, "errors": errors}
elif res["clearGlobalKeysResult"]["jobStatus"] == "FAILED":
- raise labelbox.exceptions.LabelboxError(
- "Job clearGlobalKeys failed."
- )
+ raise LabelboxError("Job clearGlobalKeys failed.")
current_time = time.time()
if current_time - start_time > timeout_seconds:
- raise labelbox.exceptions.TimeoutError(
+ raise TimeoutError(
"Timed out waiting for clear_global_keys job to complete."
)
time.sleep(sleep_time)
@@ -2224,7 +1881,7 @@ def is_feature_schema_archived(
+ "/ontologies/"
+ urllib.parse.quote(ontology_id)
)
- response = self._connection.get(ontology_endpoint)
+ response = self.connection.get(ontology_endpoint)
if response.status_code == requests.codes.ok:
feature_schema_nodes = response.json()["featureSchemaNodes"]
@@ -2240,16 +1897,14 @@ def is_feature_schema_archived(
if filtered_feature_schema_nodes:
return bool(filtered_feature_schema_nodes[0]["archived"])
else:
- raise labelbox.exceptions.LabelboxError(
+ raise LabelboxError(
"The specified feature schema was not in the ontology."
)
elif response.status_code == 404:
- raise labelbox.exceptions.ResourceNotFoundError(
- Ontology, ontology_id
- )
+ raise ResourceNotFoundError(Ontology, ontology_id)
else:
- raise labelbox.exceptions.LabelboxError(
+ raise LabelboxError(
"Failed to get the feature schema archived status."
)
@@ -2276,9 +1931,7 @@ def get_model_slice(self, slice_id) -> ModelSlice:
"""
res = self.execute(query_str, {"id": slice_id})
if res is None or res["getSavedQuery"] is None:
- raise labelbox.exceptions.ResourceNotFoundError(
- ModelSlice, slice_id
- )
+ raise ResourceNotFoundError(ModelSlice, slice_id)
return Entity.ModelSlice(self, res["getSavedQuery"])
@@ -2308,15 +1961,15 @@ def delete_feature_schema_from_ontology(
+ "/feature-schemas/"
+ urllib.parse.quote(feature_schema_id)
)
- response = self._connection.delete(ontology_endpoint)
+ response = self.connection.delete(ontology_endpoint)
if response.status_code == requests.codes.ok:
response_json = response.json()
- if response_json["archived"] == True:
+ if response_json["archived"] is True:
logger.info(
"Feature schema was archived from the ontology because it had associated labels."
)
- elif response_json["deleted"] == True:
+ elif response_json["deleted"] is True:
logger.info(
"Feature schema was successfully removed from the ontology"
)
@@ -2325,7 +1978,7 @@ def delete_feature_schema_from_ontology(
result.deleted = bool(response_json["deleted"])
return result
else:
- raise labelbox.exceptions.LabelboxError(
+ raise LabelboxError(
"Failed to remove feature schema from ontology, message: "
+ str(response.json()["message"])
)
@@ -2350,14 +2003,12 @@ def unarchive_feature_schema_node(
+ urllib.parse.quote(root_feature_schema_id)
+ "/unarchive"
)
- response = self._connection.patch(ontology_endpoint)
+ response = self.connection.patch(ontology_endpoint)
if response.status_code == requests.codes.ok:
if not bool(response.json()["unarchived"]):
- raise labelbox.exceptions.LabelboxError(
- "Failed unarchive the feature schema."
- )
+ raise LabelboxError("Failed to unarchive the feature schema.")
else:
- raise labelbox.exceptions.LabelboxError(
+ raise LabelboxError(
"Failed unarchive the feature schema node, message: ",
response.text,
)
@@ -2586,9 +2237,7 @@ def get_embedding_by_name(self, name: str) -> Embedding:
for e in embeddings:
if e.name == name:
return e
- raise labelbox.exceptions.ResourceNotFoundError(
- Embedding, dict(name=name)
- )
+ raise ResourceNotFoundError(Embedding, dict(name=name))
def upsert_label_feedback(
self, label_id: str, feedback: str, scores: Dict[str, float]
@@ -2635,8 +2284,7 @@ def upsert_label_feedback(
scores_raw = res["upsertAutoQaLabelFeedback"]["scores"]
return [
- labelbox.LabelScore(name=x["name"], score=x["score"])
- for x in scores_raw
+ LabelScore(name=x["name"], score=x["score"]) for x in scores_raw
]
def get_labeling_service_dashboards(
@@ -2712,7 +2360,7 @@ def get_task_by_id(self, task_id: str) -> Union[Task, DataUpsertTask]:
result = self.execute(query, {"userId": user.uid, "taskId": task_id})
data = result.get("user", {}).get("createdTasks", [])
if not data:
- raise labelbox.exceptions.ResourceNotFoundError(
+ raise ResourceNotFoundError(
message=f"The task {task_id} does not exist."
)
task_data = data[0]
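
For SDK users, the net effect of the client.py rewrite is that create_project takes explicit parameters (media_type is now required, queue_mode is gone) and failures surface as lbox.exceptions types rather than labelbox.exceptions. A rough usage sketch with placeholder values:

from labelbox import Client, MediaType
from lbox.exceptions import LabelboxError

client = Client(api_key="<YOUR_API_KEY>")  # placeholder key
try:
    project = client.create_project(
        name="demo-project",              # placeholder name
        media_type=MediaType.Image,       # required; queue_mode is no longer accepted
        description="created with the 3.9-only SDK",
    )
except LabelboxError as err:
    print(f"project creation failed: {err.message}")
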
diff --git a/libs/labelbox/src/labelbox/data/annotation_types/__init__.py b/libs/labelbox/src/labelbox/data/annotation_types/__init__.py
index 7908bc242..1a78127e1 100644
--- a/libs/labelbox/src/labelbox/data/annotation_types/__init__.py
+++ b/libs/labelbox/src/labelbox/data/annotation_types/__init__.py
@@ -32,18 +32,8 @@
from .classification import Radio
from .classification import Text
-from .data import AudioData
-from .data import ConversationData
-from .data import DicomData
-from .data import DocumentData
-from .data import HTMLData
-from .data import ImageData
+from .data import GenericDataRowData
from .data import MaskData
-from .data import TextData
-from .data import VideoData
-from .data import LlmPromptResponseCreationData
-from .data import LlmPromptCreationData
-from .data import LlmResponseCreationData
from .label import Label
from .collection import LabelGenerator
diff --git a/libs/labelbox/src/labelbox/data/annotation_types/collection.py b/libs/labelbox/src/labelbox/data/annotation_types/collection.py
index a636a3b3a..42d2a1184 100644
--- a/libs/labelbox/src/labelbox/data/annotation_types/collection.py
+++ b/libs/labelbox/src/labelbox/data/annotation_types/collection.py
@@ -21,62 +21,6 @@ def __init__(self, data: Generator[Label, None, None], *args, **kwargs):
self._fns = {}
super().__init__(data, *args, **kwargs)
- def assign_feature_schema_ids(
- self, ontology_builder: "ontology.OntologyBuilder"
- ) -> "LabelGenerator":
- def _assign_ids(label: Label):
- label.assign_feature_schema_ids(ontology_builder)
- return label
-
- warnings.warn(
- "This method is deprecated and will be "
- "removed in a future release. Feature schema ids"
- " are no longer required for importing."
- )
- self._fns["assign_feature_schema_ids"] = _assign_ids
- return self
-
- def add_url_to_data(
- self, signer: Callable[[bytes], str]
- ) -> "LabelGenerator":
- """
- Creates signed urls for the data
- Only uploads url if one doesn't already exist.
-
- Args:
- signer: A function that accepts bytes and returns a signed url.
- Returns:
- LabelGenerator that signs urls as data is accessed
- """
-
- def _add_url_to_data(label: Label):
- label.add_url_to_data(signer)
- return label
-
- self._fns["add_url_to_data"] = _add_url_to_data
- return self
-
- def add_to_dataset(
- self, dataset: "Entity.Dataset", signer: Callable[[bytes], str]
- ) -> "LabelGenerator":
- """
- Creates data rows from each labels data object and attaches the data to the given dataset.
- Updates the label's data object to have the same external_id and uid as the data row.
-
- Args:
- dataset: labelbox dataset object to add the new data row to
- signer: A function that accepts bytes and returns a signed url.
- Returns:
- LabelGenerator that updates references to the new data rows as data is accessed
- """
-
- def _add_to_dataset(label: Label):
- label.create_data_row(dataset, signer)
- return label
-
- self._fns["assign_datarow_ids"] = _add_to_dataset
- return self
-
def add_url_to_masks(
self, signer: Callable[[bytes], str]
) -> "LabelGenerator":
diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/__init__.py b/libs/labelbox/src/labelbox/data/annotation_types/data/__init__.py
index 2522b2741..8d5e7289b 100644
--- a/libs/labelbox/src/labelbox/data/annotation_types/data/__init__.py
+++ b/libs/labelbox/src/labelbox/data/annotation_types/data/__init__.py
@@ -1,12 +1,2 @@
-from .audio import AudioData
-from .conversation import ConversationData
-from .dicom import DicomData
-from .document import DocumentData
-from .html import HTMLData
-from .raster import ImageData
from .raster import MaskData
-from .text import TextData
-from .video import VideoData
-from .llm_prompt_response_creation import LlmPromptResponseCreationData
-from .llm_prompt_creation import LlmPromptCreationData
-from .llm_response_creation import LlmResponseCreationData
+from .generic_data_row_data import GenericDataRowData
diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/audio.py b/libs/labelbox/src/labelbox/data/annotation_types/data/audio.py
deleted file mode 100644
index 916fca99d..000000000
--- a/libs/labelbox/src/labelbox/data/annotation_types/data/audio.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from labelbox.typing_imports import Literal
-from labelbox.utils import _NoCoercionMixin
-from .base_data import BaseData
-
-
-class AudioData(BaseData, _NoCoercionMixin):
- class_name: Literal["AudioData"] = "AudioData"
diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/conversation.py b/libs/labelbox/src/labelbox/data/annotation_types/data/conversation.py
deleted file mode 100644
index ef6507dca..000000000
--- a/libs/labelbox/src/labelbox/data/annotation_types/data/conversation.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from labelbox.typing_imports import Literal
-from labelbox.utils import _NoCoercionMixin
-from .base_data import BaseData
-
-
-class ConversationData(BaseData, _NoCoercionMixin):
- class_name: Literal["ConversationData"] = "ConversationData"
diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/dicom.py b/libs/labelbox/src/labelbox/data/annotation_types/data/dicom.py
deleted file mode 100644
index ae4c377dc..000000000
--- a/libs/labelbox/src/labelbox/data/annotation_types/data/dicom.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from labelbox.typing_imports import Literal
-from labelbox.utils import _NoCoercionMixin
-from .base_data import BaseData
-
-
-class DicomData(BaseData, _NoCoercionMixin):
- class_name: Literal["DicomData"] = "DicomData"
diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/document.py b/libs/labelbox/src/labelbox/data/annotation_types/data/document.py
deleted file mode 100644
index 810a3ed3e..000000000
--- a/libs/labelbox/src/labelbox/data/annotation_types/data/document.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from labelbox.typing_imports import Literal
-from labelbox.utils import _NoCoercionMixin
-from .base_data import BaseData
-
-
-class DocumentData(BaseData, _NoCoercionMixin):
- class_name: Literal["DocumentData"] = "DocumentData"
diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/html.py b/libs/labelbox/src/labelbox/data/annotation_types/data/html.py
deleted file mode 100644
index 7a78fcb7b..000000000
--- a/libs/labelbox/src/labelbox/data/annotation_types/data/html.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from labelbox.typing_imports import Literal
-from labelbox.utils import _NoCoercionMixin
-from .base_data import BaseData
-
-
-class HTMLData(BaseData, _NoCoercionMixin):
- class_name: Literal["HTMLData"] = "HTMLData"
diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/llm_prompt_creation.py b/libs/labelbox/src/labelbox/data/annotation_types/data/llm_prompt_creation.py
deleted file mode 100644
index a1b0450bc..000000000
--- a/libs/labelbox/src/labelbox/data/annotation_types/data/llm_prompt_creation.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from labelbox.typing_imports import Literal
-from labelbox.utils import _NoCoercionMixin
-from .base_data import BaseData
-
-
-class LlmPromptCreationData(BaseData, _NoCoercionMixin):
- class_name: Literal["LlmPromptCreationData"] = "LlmPromptCreationData"
diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/llm_prompt_response_creation.py b/libs/labelbox/src/labelbox/data/annotation_types/data/llm_prompt_response_creation.py
deleted file mode 100644
index a8dfce894..000000000
--- a/libs/labelbox/src/labelbox/data/annotation_types/data/llm_prompt_response_creation.py
+++ /dev/null
@@ -1,9 +0,0 @@
-from labelbox.typing_imports import Literal
-from labelbox.utils import _NoCoercionMixin
-from .base_data import BaseData
-
-
-class LlmPromptResponseCreationData(BaseData, _NoCoercionMixin):
- class_name: Literal["LlmPromptResponseCreationData"] = (
- "LlmPromptResponseCreationData"
- )
diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/llm_response_creation.py b/libs/labelbox/src/labelbox/data/annotation_types/data/llm_response_creation.py
deleted file mode 100644
index a8963ed3f..000000000
--- a/libs/labelbox/src/labelbox/data/annotation_types/data/llm_response_creation.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from labelbox.typing_imports import Literal
-from labelbox.utils import _NoCoercionMixin
-from .base_data import BaseData
-
-
-class LlmResponseCreationData(BaseData, _NoCoercionMixin):
- class_name: Literal["LlmResponseCreationData"] = "LlmResponseCreationData"
diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/raster.py b/libs/labelbox/src/labelbox/data/annotation_types/data/raster.py
index ba4c6485f..8702b8aeb 100644
--- a/libs/labelbox/src/labelbox/data/annotation_types/data/raster.py
+++ b/libs/labelbox/src/labelbox/data/annotation_types/data/raster.py
@@ -1,17 +1,15 @@
from abc import ABC
from io import BytesIO
-from typing import Callable, Optional, Union
-from typing_extensions import Literal
+from typing import Callable, Literal, Optional, Union
-from PIL import Image
+import numpy as np
+import requests
from google.api_core import retry
+from lbox.exceptions import InternalServerError
+from PIL import Image
+from pydantic import BaseModel, ConfigDict, model_validator
from requests.exceptions import ConnectTimeout
-import requests
-import numpy as np
-from pydantic import BaseModel, model_validator, ConfigDict
-from labelbox.exceptions import InternalServerError
-from .base_data import BaseData
from ..types import TypedArray
@@ -172,7 +170,7 @@ def validate_args(self, values):
uid = self.uid
global_key = self.global_key
if (
- uid == file_path == im_bytes == url == global_key == None
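+ # Chained comparison: evaluates True only when uid, file_path, im_bytes, url and global_key are all None.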
+ uid == file_path == im_bytes == url == global_key is None
and arr is None
):
raise ValueError(
@@ -191,7 +189,9 @@ def validate_args(self, values):
return self
def __repr__(self) -> str:
- symbol_or_none = lambda data: "..." if data is not None else None
+ def symbol_or_none(data):
+ return "..." if data is not None else None
+
return (
f"{self.__class__.__name__}(im_bytes={symbol_or_none(self.im_bytes)},"
f"file_path={self.file_path},"
@@ -220,6 +220,3 @@ class MaskData(RasterData):
url: Optional[str] = None
arr: Optional[TypedArray[Literal['uint8']]] = None
"""
-
-
-class ImageData(RasterData, BaseData): ...
diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/text.py b/libs/labelbox/src/labelbox/data/annotation_types/data/text.py
deleted file mode 100644
index fe4c222d3..000000000
--- a/libs/labelbox/src/labelbox/data/annotation_types/data/text.py
+++ /dev/null
@@ -1,115 +0,0 @@
-from typing import Callable, Optional
-
-import requests
-from requests.exceptions import ConnectTimeout
-from google.api_core import retry
-
-from pydantic import ConfigDict, model_validator
-from labelbox.exceptions import InternalServerError
-from labelbox.typing_imports import Literal
-from labelbox.utils import _NoCoercionMixin
-from .base_data import BaseData
-
-
-class TextData(BaseData, _NoCoercionMixin):
- """
- Represents text data. Requires arg file_path, text, or url
-
- >>> TextData(text="")
-
- Args:
- file_path (str)
- text (str)
- url (str)
- """
-
- class_name: Literal["TextData"] = "TextData"
- file_path: Optional[str] = None
- text: Optional[str] = None
- url: Optional[str] = None
- model_config = ConfigDict(extra="forbid")
-
- @property
- def value(self) -> str:
- """
- Property that unifies the data access pattern for all references to the text.
-
- Returns:
- string representation of the text
- """
- if self.text:
- return self.text
- elif self.file_path:
- with open(self.file_path, "r") as file:
- text = file.read()
- self.text = text
- return text
- elif self.url:
- text = self.fetch_remote()
- self.text = text
- return text
- else:
- raise ValueError("Must set either url, file_path or im_bytes")
-
- def set_fetch_fn(self, fn):
- object.__setattr__(self, "fetch_remote", lambda: fn(self))
-
- @retry.Retry(
- deadline=15.0,
- predicate=retry.if_exception_type(ConnectTimeout, InternalServerError),
- )
- def fetch_remote(self) -> str:
- """
- Method for accessing url.
-
- If url is not publicly accessible or requires another access pattern
- simply override this function
- """
- response = requests.get(self.url)
- if response.status_code in [500, 502, 503, 504]:
- raise InternalServerError(response.text)
- response.raise_for_status()
- return response.text
-
- @retry.Retry(deadline=15.0)
- def create_url(self, signer: Callable[[bytes], str]) -> None:
- """
- Utility for creating a url from any of the other text references.
-
- Args:
- signer: A function that accepts bytes and returns a signed url.
- Returns:
- url for the text
- """
- if self.url is not None:
- return self.url
- elif self.file_path is not None:
- with open(self.file_path, "rb") as file:
- self.url = signer(file.read())
- elif self.text is not None:
- self.url = signer(self.text.encode())
- else:
- raise ValueError(
- "One of url, im_bytes, file_path, numpy must not be None."
- )
- return self.url
-
- @model_validator(mode="after")
- def validate_date(self, values):
- file_path = self.file_path
- text = self.text
- url = self.url
- uid = self.uid
- global_key = self.global_key
- if uid == file_path == text == url == global_key == None:
- raise ValueError(
- "One of `file_path`, `text`, `uid`, `global_key` or `url` required."
- )
- return self
-
- def __repr__(self) -> str:
- return (
- f"TextData(file_path={self.file_path},"
- f"text={self.text[:30] + '...' if self.text is not None else None},"
- f"url={self.url})"
- )
diff --git a/libs/labelbox/src/labelbox/data/annotation_types/data/video.py b/libs/labelbox/src/labelbox/data/annotation_types/data/video.py
deleted file mode 100644
index 581801036..000000000
--- a/libs/labelbox/src/labelbox/data/annotation_types/data/video.py
+++ /dev/null
@@ -1,173 +0,0 @@
-import logging
-import os
-import urllib.request
-from typing import Callable, Dict, Generator, Optional, Tuple
-from typing_extensions import Literal
-from uuid import uuid4
-
-import cv2
-import numpy as np
-from google.api_core import retry
-
-from .base_data import BaseData
-from ..types import TypedArray
-
-from pydantic import ConfigDict, model_validator
-
-logger = logging.getLogger(__name__)
-
-
-class VideoData(BaseData):
- """
- Represents video
- """
-
- file_path: Optional[str] = None
- url: Optional[str] = None
- frames: Optional[Dict[int, TypedArray[Literal["uint8"]]]] = None
- # Required for discriminating between data types
- model_config = ConfigDict(extra="forbid")
-
- def load_frames(self, overwrite: bool = False) -> None:
- """
- Loads all frames into memory at once in order to access in non-sequential order.
- This will use a lot of memory, especially for longer videos
-
- Args:
- overwrite: Replace existing frames
- """
- if self.frames and not overwrite:
- return
-
- for count, frame in self.frame_generator():
- if self.frames is None:
- self.frames = {}
- self.frames[count] = frame
-
- @property
- def value(self):
- return self.frame_generator()
-
- def frame_generator(
- self, cache_frames=False, download_dir="/tmp"
- ) -> Generator[Tuple[int, np.ndarray], None, None]:
- """
- A generator for accessing individual frames in a video.
-
- Args:
- cache_frames (bool): Whether or not to cache frames while iterating through the video.
- download_dir (str): Directory to save the video to. Defaults to `/tmp` dir
- """
- if self.frames is not None:
- for idx, frame in self.frames.items():
- yield idx, frame
- return
- elif self.url and not self.file_path:
- file_path = os.path.join(download_dir, f"{uuid4()}.mp4")
- logger.info("Downloading the video locally to %s", file_path)
- self.fetch_remote(file_path)
- self.file_path = file_path
-
- vidcap = cv2.VideoCapture(self.file_path)
-
- success, frame = vidcap.read()
- count = 0
- if cache_frames:
- self.frames = {}
- while success:
- frame = frame[:, :, ::-1]
- yield count, frame
- if cache_frames:
- self.frames[count] = frame
- success, frame = vidcap.read()
- count += 1
-
- def __getitem__(self, idx: int) -> np.ndarray:
- if self.frames is None:
- raise ValueError(
- "Cannot select by index without iterating over the entire video or loading all frames."
- )
- return self.frames[idx]
-
- def set_fetch_fn(self, fn):
- object.__setattr__(self, "fetch_remote", lambda: fn(self))
-
- @retry.Retry(deadline=15.0)
- def fetch_remote(self, local_path) -> None:
- """
- Method for downloading data from self.url
-
- If url is not publicly accessible or requires another access pattern
- simply override this function
-
- Args:
- local_path: Where to save the thing too.
- """
- urllib.request.urlretrieve(self.url, local_path)
-
- @retry.Retry(deadline=15.0)
- def create_url(self, signer: Callable[[bytes], str]) -> None:
- """
- Utility for creating a url from any of the other video references.
-
- Args:
- signer: A function that accepts bytes and returns a signed url.
- Returns:
- url for the video
- """
- if self.url is not None:
- return self.url
- elif self.file_path is not None:
- with open(self.file_path, "rb") as file:
- self.url = signer(file.read())
- elif self.frames is not None:
- self.file_path = self.frames_to_video(self.frames)
- self.url = self.create_url(signer)
- else:
- raise ValueError("One of url, file_path, frames must not be None.")
- return self.url
-
- def frames_to_video(
- self, frames: Dict[int, np.ndarray], fps=20, save_dir="/tmp"
- ) -> str:
- """
- Compresses the data by converting a set of individual frames to a single video.
-
- """
- file_path = os.path.join(save_dir, f"{uuid4()}.mp4")
- out = None
- for key in frames.keys():
- frame = frames[key]
- if out is None:
- out = cv2.VideoWriter(
- file_path,
- cv2.VideoWriter_fourcc(*"MP4V"),
- fps,
- frame.shape[:2],
- )
- out.write(frame)
- if out is None:
- return
- out.release()
- return file_path
-
- @model_validator(mode="after")
- def validate_data(self):
- file_path = self.file_path
- url = self.url
- frames = self.frames
- uid = self.uid
- global_key = self.global_key
-
- if uid == file_path == frames == url == global_key == None:
- raise ValueError(
- "One of `file_path`, `frames`, `uid`, `global_key` or `url` required."
- )
- return self
-
- def __repr__(self) -> str:
- return (
- f"VideoData(file_path={self.file_path},"
- f"frames={'...' if self.frames is not None else None},"
- f"url={self.url})"
- )
diff --git a/libs/labelbox/src/labelbox/data/annotation_types/geometry/mask.py b/libs/labelbox/src/labelbox/data/annotation_types/geometry/mask.py
index 0d870f24f..03e1dd62c 100644
--- a/libs/labelbox/src/labelbox/data/annotation_types/geometry/mask.py
+++ b/libs/labelbox/src/labelbox/data/annotation_types/geometry/mask.py
@@ -1,15 +1,13 @@
-from typing import Callable, Optional, Tuple, Union, Dict, List
+from typing import Callable, Dict, List, Optional, Tuple, Union
-import numpy as np
import cv2
-
+import numpy as np
+from pydantic import field_validator
from shapely.geometry import MultiPolygon, Polygon
from ..data import MaskData
from .geometry import Geometry
-from pydantic import field_validator
-
class Mask(Geometry):
"""Mask used to represent a single class in a larger segmentation mask
@@ -91,7 +89,7 @@ def draw(
as opposed to the mask that this object references which might have multiple objects determined by colors
"""
mask = self.mask.value
- mask = np.alltrue(mask == self.color, axis=2).astype(np.uint8)
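+ # np.alltrue was removed in NumPy 2.0; np.all is the direct replacement.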
+ mask = np.all(mask == self.color, axis=2).astype(np.uint8)
if height is not None or width is not None:
mask = cv2.resize(
diff --git a/libs/labelbox/src/labelbox/data/annotation_types/label.py b/libs/labelbox/src/labelbox/data/annotation_types/label.py
index 7eef43f31..2f835b23c 100644
--- a/libs/labelbox/src/labelbox/data/annotation_types/label.py
+++ b/libs/labelbox/src/labelbox/data/annotation_types/label.py
@@ -3,10 +3,7 @@
import warnings
import labelbox
-from labelbox.data.annotation_types.data.generic_data_row_data import (
- GenericDataRowData,
-)
-from labelbox.data.annotation_types.data.tiled_image import TiledImageData
+from labelbox.data.annotation_types.data import GenericDataRowData, MaskData
from labelbox.schema import ontology
from ...annotated_types import Cuid
@@ -14,42 +11,13 @@
from .relationship import RelationshipAnnotation
from .llm_prompt_response.prompt import PromptClassificationAnnotation
from .classification import ClassificationAnswer
-from .data import (
- AudioData,
- ConversationData,
- DicomData,
- DocumentData,
- HTMLData,
- ImageData,
- TextData,
- VideoData,
- LlmPromptCreationData,
- LlmPromptResponseCreationData,
- LlmResponseCreationData,
-)
from .geometry import Mask
from .metrics import ScalarMetric, ConfusionMatrixMetric
from .video import VideoClassificationAnnotation
from .video import VideoObjectAnnotation, VideoMaskAnnotation
from .mmc import MessageEvaluationTaskAnnotation
from ..ontology import get_feature_schema_lookup
-from pydantic import BaseModel, field_validator, model_serializer
-
-DataType = Union[
- VideoData,
- ImageData,
- TextData,
- TiledImageData,
- AudioData,
- ConversationData,
- DicomData,
- DocumentData,
- HTMLData,
- LlmPromptCreationData,
- LlmPromptResponseCreationData,
- LlmResponseCreationData,
- GenericDataRowData,
-]
+from pydantic import BaseModel, field_validator
class Label(BaseModel):
@@ -67,14 +35,13 @@ class Label(BaseModel):
Args:
uid: Optional Label Id in Labelbox
- data: Data of Label, Image, Video, Text or dict with a single key uid | global_key | external_id.
- Note use of classes as data is deprecated. Use GenericDataRowData or dict with a single key instead.
+ data: GenericDataRowData or dict with a single key uid | global_key | external_id.
annotations: List of Annotations in the label
extra: additional context
"""
uid: Optional[Cuid] = None
- data: DataType
+ data: Union[GenericDataRowData, MaskData]
annotations: List[
Union[
ClassificationAnnotation,
@@ -94,13 +61,6 @@ class Label(BaseModel):
def validate_data(cls, data):
if isinstance(data, Dict):
return GenericDataRowData(**data)
- elif isinstance(data, GenericDataRowData):
- return data
- else:
- warnings.warn(
- f"Using {type(data).__name__} class for label.data is deprecated. "
- "Use a dict or an instance of GenericDataRowData instead."
- )
return data
def object_annotations(self) -> List[ObjectAnnotation]:
@@ -128,19 +88,6 @@ def frame_annotations(
frame_dict[annotation.frame].append(annotation)
return frame_dict
- def add_url_to_data(self, signer) -> "Label":
- """
- Creates signed urls for the data
- Only uploads url if one doesn't already exist.
-
- Args:
- signer: A function that accepts bytes and returns a signed url.
- Returns:
- Label with updated references to new data url
- """
- self.data.create_url(signer)
- return self
-
def add_url_to_masks(self, signer) -> "Label":
"""
Creates signed urls for all masks in the Label.
@@ -189,42 +136,6 @@ def create_data_row(
self.data.external_id = data_row.external_id
return self
- def assign_feature_schema_ids(
- self, ontology_builder: ontology.OntologyBuilder
- ) -> "Label":
- """
- Adds schema ids to all FeatureSchema objects in the Labels.
-
- Args:
- ontology_builder: The ontology that matches the feature names assigned to objects in this dataset
- Returns:
- Label. useful for chaining these modifying functions
-
- Note: You can now import annotations using names directly without having to lookup schema_ids
- """
- warnings.warn(
- "This method is deprecated and will be "
- "removed in a future release. Feature schema ids"
- " are no longer required for importing."
- )
- tool_lookup, classification_lookup = get_feature_schema_lookup(
- ontology_builder
- )
- for annotation in self.annotations:
- if isinstance(annotation, ClassificationAnnotation):
- self._assign_or_raise(annotation, classification_lookup)
- self._assign_option(annotation, classification_lookup)
- elif isinstance(annotation, ObjectAnnotation):
- self._assign_or_raise(annotation, tool_lookup)
- for classification in annotation.classifications:
- self._assign_or_raise(classification, classification_lookup)
- self._assign_option(classification, classification_lookup)
- else:
- raise TypeError(
- f"Unexpected type found for annotation. {type(annotation)}"
- )
- return self
-
def _assign_or_raise(self, annotation, lookup: Dict[str, str]) -> None:
if annotation.feature_schema_id is not None:
return
diff --git a/libs/labelbox/src/labelbox/data/annotation_types/metrics/scalar.py b/libs/labelbox/src/labelbox/data/annotation_types/metrics/scalar.py
index 13d0e9748..1434be427 100644
--- a/libs/labelbox/src/labelbox/data/annotation_types/metrics/scalar.py
+++ b/libs/labelbox/src/labelbox/data/annotation_types/metrics/scalar.py
@@ -1,11 +1,10 @@
-from typing import Dict, Optional, Union
-from typing_extensions import Annotated
from enum import Enum
+from typing import Annotated, Dict, Optional, Union
from pydantic import field_validator
from pydantic.types import confloat
-from .base import ConfidenceValue, BaseMetric
+from .base import BaseMetric, ConfidenceValue
ScalarMetricValue = Annotated[float, confloat(ge=0, le=100_000_000)]
ScalarMetricConfidenceValue = Dict[ConfidenceValue, ScalarMetricValue]
diff --git a/libs/labelbox/src/labelbox/data/annotation_types/types.py b/libs/labelbox/src/labelbox/data/annotation_types/types.py
index 9bb86a4b9..4a9198fea 100644
--- a/libs/labelbox/src/labelbox/data/annotation_types/types.py
+++ b/libs/labelbox/src/labelbox/data/annotation_types/types.py
@@ -1,43 +1,20 @@
-import sys
-from typing import Generic, TypeVar, Any
-
-from typing_extensions import Annotated
-from packaging import version
+from typing import Generic, TypeVar
import numpy as np
-
-from pydantic import StringConstraints, Field
+from pydantic_core import core_schema
DType = TypeVar("DType")
DShape = TypeVar("DShape")
-class _TypedArray(np.ndarray, Generic[DType, DShape]):
+class TypedArray(np.ndarray, Generic[DType, DShape]):
@classmethod
- def __get_validators__(cls):
- yield cls.validate
+ def __get_pydantic_core_schema__(
+ cls, _source_type: type, _model: type
+ ) -> core_schema.CoreSchema:
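+ # Pydantic v2 hook (replaces __get_validators__): register cls.validate as a plain, no-info validator.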
+ return core_schema.no_info_plain_validator_function(cls.validate)
@classmethod
- def validate(cls, val, field: Field):
+ def validate(cls, val):
if not isinstance(val, np.ndarray):
raise TypeError(f"Expected numpy array. Found {type(val)}")
return val
-
-
-if version.parse(np.__version__) >= version.parse("1.25.0"):
- from typing import GenericAlias
-
- TypedArray = GenericAlias(_TypedArray, (Any, DType))
-elif version.parse(np.__version__) >= version.parse("1.23.0"):
- from numpy._typing import _GenericAlias
-
- TypedArray = _GenericAlias(_TypedArray, (Any, DType))
-elif (
- version.parse("1.22.0")
- <= version.parse(np.__version__)
- < version.parse("1.23.0")
-):
- from numpy.typing import _GenericAlias
-
- TypedArray = _GenericAlias(_TypedArray, (Any, DType))
-else:
- TypedArray = _TypedArray[Any, DType]
diff --git a/libs/labelbox/src/labelbox/data/annotation_types/video.py b/libs/labelbox/src/labelbox/data/annotation_types/video.py
index cfebd7a1f..5a93704c8 100644
--- a/libs/labelbox/src/labelbox/data/annotation_types/video.py
+++ b/libs/labelbox/src/labelbox/data/annotation_types/video.py
@@ -125,7 +125,7 @@ class MaskFrame(_CamelCaseMixin, BaseModel):
def validate_args(self, values):
im_bytes = self.im_bytes
instance_uri = self.instance_uri
- if im_bytes == instance_uri == None:
+ if im_bytes == instance_uri is None:
raise ValueError("One of `instance_uri`, `im_bytes` required.")
return self
diff --git a/libs/labelbox/src/labelbox/data/metrics/confusion_matrix/calculation.py b/libs/labelbox/src/labelbox/data/metrics/confusion_matrix/calculation.py
index 938e17f65..83410a540 100644
--- a/libs/labelbox/src/labelbox/data/metrics/confusion_matrix/calculation.py
+++ b/libs/labelbox/src/labelbox/data/metrics/confusion_matrix/calculation.py
@@ -130,7 +130,7 @@ def classification_confusion_matrix(
prediction, ground_truth = predictions[0], ground_truths[0]
- if type(prediction) != type(ground_truth):
+ if type(prediction) is not type(ground_truth):
raise TypeError(
"Classification features must be the same type to compute agreement. "
f"Found `{type(prediction)}` and `{type(ground_truth)}`"
diff --git a/libs/labelbox/src/labelbox/data/metrics/group.py b/libs/labelbox/src/labelbox/data/metrics/group.py
index 88f4eae8b..9c4104c29 100644
--- a/libs/labelbox/src/labelbox/data/metrics/group.py
+++ b/libs/labelbox/src/labelbox/data/metrics/group.py
@@ -16,10 +16,10 @@
try:
from typing import Literal
except ImportError:
- from typing_extensions import Literal
+ from typing import Literal
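+ # Literal is part of typing on Python 3.8+, so the typing_extensions fallback is no longer needed.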
+from ..annotation_types import ClassificationAnnotation, Label, ObjectAnnotation
from ..annotation_types.feature import FeatureSchema
-from ..annotation_types import ObjectAnnotation, ClassificationAnnotation, Label
def get_identifying_key(
diff --git a/libs/labelbox/src/labelbox/data/metrics/iou/calculation.py b/libs/labelbox/src/labelbox/data/metrics/iou/calculation.py
index 2a376d3fe..d0237963c 100644
--- a/libs/labelbox/src/labelbox/data/metrics/iou/calculation.py
+++ b/libs/labelbox/src/labelbox/data/metrics/iou/calculation.py
@@ -209,7 +209,7 @@ def classification_miou(
prediction, ground_truth = predictions[0], ground_truths[0]
- if type(prediction) != type(ground_truth):
+ if type(prediction) is not type(ground_truth):
raise TypeError(
"Classification features must be the same type to compute agreement. "
f"Found `{type(prediction)}` and `{type(ground_truth)}`"
diff --git a/libs/labelbox/src/labelbox/data/mixins.py b/libs/labelbox/src/labelbox/data/mixins.py
index 4440c8a72..fd5ecaa42 100644
--- a/libs/labelbox/src/labelbox/data/mixins.py
+++ b/libs/labelbox/src/labelbox/data/mixins.py
@@ -1,13 +1,10 @@
-from typing import Optional, List
+from typing import List, Optional
-from pydantic import BaseModel, field_validator, model_serializer
-
-from labelbox.exceptions import (
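+# These exceptions were moved out of labelbox.exceptions into the shared lbox package.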
+from lbox.exceptions import (
ConfidenceNotSupportedException,
CustomMetricsNotSupportedException,
)
-
-from warnings import warn
+from pydantic import BaseModel, field_validator
class ConfidenceMixin(BaseModel):
diff --git a/libs/labelbox/src/labelbox/data/serialization/__init__.py b/libs/labelbox/src/labelbox/data/serialization/__init__.py
index 71a9b3443..38cb5edff 100644
--- a/libs/labelbox/src/labelbox/data/serialization/__init__.py
+++ b/libs/labelbox/src/labelbox/data/serialization/__init__.py
@@ -1,2 +1 @@
from .ndjson import NDJsonConverter
-from .coco import COCOConverter
diff --git a/libs/labelbox/src/labelbox/data/serialization/coco/__init__.py b/libs/labelbox/src/labelbox/data/serialization/coco/__init__.py
deleted file mode 100644
index 4511e89ee..000000000
--- a/libs/labelbox/src/labelbox/data/serialization/coco/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .converter import COCOConverter
diff --git a/libs/labelbox/src/labelbox/data/serialization/coco/annotation.py b/libs/labelbox/src/labelbox/data/serialization/coco/annotation.py
deleted file mode 100644
index e387cb7d9..000000000
--- a/libs/labelbox/src/labelbox/data/serialization/coco/annotation.py
+++ /dev/null
@@ -1,78 +0,0 @@
-from typing import Any, Tuple, List, Union
-from pathlib import Path
-from collections import defaultdict
-import warnings
-
-from ...annotation_types.relationship import RelationshipAnnotation
-from ...annotation_types.metrics.confusion_matrix import ConfusionMatrixMetric
-from ...annotation_types.metrics.scalar import ScalarMetric
-from ...annotation_types.video import VideoMaskAnnotation
-from ...annotation_types.annotation import ObjectAnnotation
-from ...annotation_types.classification.classification import (
- ClassificationAnnotation,
-)
-
-import numpy as np
-
-from .path import PathSerializerMixin
-from pydantic import BaseModel
-
-
-def rle_decoding(rle_arr: List[int], w: int, h: int) -> np.ndarray:
- indices = []
- for idx, cnt in zip(rle_arr[0::2], rle_arr[1::2]):
- indices.extend(
- list(range(idx - 1, idx + cnt - 1))
- ) # RLE is 1-based index
- mask = np.zeros(h * w, dtype=np.uint8)
- mask[indices] = 1
- return mask.reshape((w, h)).T
-
-
-def get_annotation_lookup(annotations):
- """Get annotations from Label.annotations objects
-
- Args:
- annotations (Label.annotations): Annotations attached to labelbox Label object used as private method
- """
- annotation_lookup = defaultdict(list)
- for annotation in annotations:
- # Provide a default value of None if the attribute doesn't exist
- attribute_value = getattr(annotation, "image_id", None) or getattr(
- annotation, "name", None
- )
- annotation_lookup[attribute_value].append(annotation)
- return annotation_lookup
-
-
-class SegmentInfo(BaseModel):
- id: int
- category_id: int
- area: Union[float, int]
- bbox: Tuple[float, float, float, float] # [x,y,w,h],
- iscrowd: int = 0
-
-
-class RLE(BaseModel):
- counts: List[int]
- size: Tuple[int, int] # h,w or w,h?
-
-
-class COCOObjectAnnotation(BaseModel):
- # All segmentations for a particular class in an image...
- # So each image will have one of these for each class present in the image..
- # Annotations only exist if there is data..
- id: int
- image_id: int
- category_id: int
- segmentation: Union[RLE, List[List[float]]] # [[x1,y1,x2,y2,x3,y3...]]
- area: float
- bbox: Tuple[float, float, float, float] # [x,y,w,h],
- iscrowd: int = 0
-
-
-class PanopticAnnotation(PathSerializerMixin):
- # One to one relationship between image and panoptic annotation
- image_id: int
- file_name: Path
- segments_info: List[SegmentInfo]
diff --git a/libs/labelbox/src/labelbox/data/serialization/coco/categories.py b/libs/labelbox/src/labelbox/data/serialization/coco/categories.py
deleted file mode 100644
index 60ba30fce..000000000
--- a/libs/labelbox/src/labelbox/data/serialization/coco/categories.py
+++ /dev/null
@@ -1,17 +0,0 @@
-import sys
-from hashlib import md5
-
-from pydantic import BaseModel
-
-
-class Categories(BaseModel):
- id: int
- name: str
- supercategory: str
- isthing: int = 1
-
-
-def hash_category_name(name: str) -> int:
- return int.from_bytes(
- md5(name.encode("utf-8")).hexdigest().encode("utf-8"), "little"
- )
diff --git a/libs/labelbox/src/labelbox/data/serialization/coco/converter.py b/libs/labelbox/src/labelbox/data/serialization/coco/converter.py
deleted file mode 100644
index e270b7573..000000000
--- a/libs/labelbox/src/labelbox/data/serialization/coco/converter.py
+++ /dev/null
@@ -1,170 +0,0 @@
-from typing import Dict, Any, Union
-from pathlib import Path
-import os
-import warnings
-
-from ...annotation_types.collection import LabelCollection, LabelGenerator
-from ...serialization.coco.instance_dataset import CocoInstanceDataset
-from ...serialization.coco.panoptic_dataset import CocoPanopticDataset
-
-
-def create_path_if_not_exists(
- path: Union[Path, str], ignore_existing_data=False
-):
- path = Path(path)
- if not path.exists():
- path.mkdir(parents=True, exist_ok=True)
- elif not ignore_existing_data and os.listdir(path):
- raise ValueError(
- f"Directory `{path}`` must be empty. Or set `ignore_existing_data=True`"
- )
- return path
-
-
-def validate_path(path: Union[Path, str], name: str):
- path = Path(path)
- if not path.exists():
- raise ValueError(f"{name} `{path}` must exist")
- return path
-
-
-class COCOConverter:
- """
- Class for converting between coco and labelbox formats
- Note that this class is only compatible with image data.
-
- Subclasses are currently ignored.
- To use subclasses, manually flatten them before using the converter.
- """
-
- @staticmethod
- def serialize_instances(
- labels: LabelCollection,
- image_root: Union[Path, str],
- ignore_existing_data=False,
- max_workers=8,
- ) -> Dict[str, Any]:
- """
- Convert a Labelbox LabelCollection into an mscoco dataset.
- This function will only convert masks, polygons, and rectangles.
- Masks will be converted into individual instances.
- Use deserialize_panoptic to prevent masks from being split apart.
-
- Args:
- labels: A collection of labels to convert
- image_root: Where to save images to
- ignore_existing_data: Whether or not to raise an exception if images already exist.
- This exists only to support detectons panoptic fpn model which requires two mscoco payloads for the same images.
- max_workers : Number of workers to process dataset with. A value of 0 will process all data in the main process
- Returns:
- A dictionary containing labels in the coco object format.
- """
-
- warnings.warn(
- "You are currently utilizing COCOconverter for this action, which will be deprecated in a later release.",
- DeprecationWarning,
- stacklevel=2,
- )
-
- image_root = create_path_if_not_exists(image_root, ignore_existing_data)
- return CocoInstanceDataset.from_common(
- labels=labels, image_root=image_root, max_workers=max_workers
- ).model_dump()
-
- @staticmethod
- def serialize_panoptic(
- labels: LabelCollection,
- image_root: Union[Path, str],
- mask_root: Union[Path, str],
- all_stuff: bool = False,
- ignore_existing_data=False,
- max_workers: int = 8,
- ) -> Dict[str, Any]:
- """
- Convert a Labelbox LabelCollection into an mscoco dataset.
- This function will only convert masks, polygons, and rectangles.
- Masks will be converted into individual instances.
- Use deserialize_panoptic to prevent masks from being split apart.
-
- Args:
- labels: A collection of labels to convert
- image_root: Where to save images to
- mask_root: Where to save segmentation masks to
- all_stuff: If rectangle or polygon annotations are encountered, they will be treated as instances.
- To convert them to stuff class set `all_stuff=True`.
- ignore_existing_data: Whether or not to raise an exception if images already exist.
- This exists only to support detectons panoptic fpn model which requires two mscoco payloads for the same images.
- max_workers : Number of workers to process dataset with. A value of 0 will process all data in the main process.
- Returns:
- A dictionary containing labels in the coco panoptic format.
- """
-
- warnings.warn(
- "You are currently utilizing COCOconverter for this action, which will be deprecated in a later release.",
- DeprecationWarning,
- stacklevel=2,
- )
-
- image_root = create_path_if_not_exists(image_root, ignore_existing_data)
- mask_root = create_path_if_not_exists(mask_root, ignore_existing_data)
- return CocoPanopticDataset.from_common(
- labels=labels,
- image_root=image_root,
- mask_root=mask_root,
- all_stuff=all_stuff,
- max_workers=max_workers,
- ).model_dump()
-
- @staticmethod
- def deserialize_panoptic(
- json_data: Dict[str, Any],
- image_root: Union[Path, str],
- mask_root: Union[Path, str],
- ) -> LabelGenerator:
- """
- Convert coco panoptic data into the labelbox format (as a LabelGenerator).
-
- Args:
- json_data: panoptic data as a dict
- image_root: Path to local images that are referenced by the panoptic json
- mask_root: Path to local segmentation masks that are referenced by the panoptic json
- Returns:
- LabelGenerator
- """
-
- warnings.warn(
- "You are currently utilizing COCOconverter for this action, which will be deprecated in a later release.",
- DeprecationWarning,
- stacklevel=2,
- )
-
- image_root = validate_path(image_root, "image_root")
- mask_root = validate_path(mask_root, "mask_root")
- objs = CocoPanopticDataset(**json_data)
- gen = objs.to_common(image_root, mask_root)
- return LabelGenerator(data=gen)
-
- @staticmethod
- def deserialize_instances(
- json_data: Dict[str, Any], image_root: Path
- ) -> LabelGenerator:
- """
- Convert coco object data into the labelbox format (as a LabelGenerator).
-
- Args:
- json_data: coco object data as a dict
- image_root: Path to local images that are referenced by the coco object json
- Returns:
- LabelGenerator
- """
-
- warnings.warn(
- "You are currently utilizing COCOconverter for this action, which will be deprecated in a later release.",
- DeprecationWarning,
- stacklevel=2,
- )
-
- image_root = validate_path(image_root, "image_root")
- objs = CocoInstanceDataset(**json_data)
- gen = objs.to_common(image_root)
- return LabelGenerator(data=gen)
diff --git a/libs/labelbox/src/labelbox/data/serialization/coco/image.py b/libs/labelbox/src/labelbox/data/serialization/coco/image.py
deleted file mode 100644
index cef173377..000000000
--- a/libs/labelbox/src/labelbox/data/serialization/coco/image.py
+++ /dev/null
@@ -1,52 +0,0 @@
-from pathlib import Path
-
-from typing import Optional, Tuple
-from PIL import Image
-import imagesize
-
-from .path import PathSerializerMixin
-from ...annotation_types import Label
-
-
-class CocoImage(PathSerializerMixin):
- id: int
- width: int
- height: int
- file_name: Path
- license: Optional[int] = None
- flickr_url: Optional[str] = None
- coco_url: Optional[str] = None
-
-
-def get_image_id(label: Label, idx: int) -> int:
- if label.data.file_path is not None:
- file_name = label.data.file_path.replace(".jpg", "")
- if file_name.isdecimal():
- return file_name
- return idx
-
-
-def get_image(label: Label, image_root: Path, image_id: str) -> CocoImage:
- path = Path(image_root, f"{image_id}.jpg")
- if not path.exists():
- im = Image.fromarray(label.data.value)
- im.save(path)
- w, h = im.size
- else:
- w, h = imagesize.get(str(path))
- return CocoImage(id=image_id, width=w, height=h, file_name=Path(path.name))
-
-
-def id_to_rgb(id: int) -> Tuple[int, int, int]:
- digits = []
- for _ in range(3):
- digits.append(id % 256)
- id //= 256
- return digits
-
-
-def rgb_to_id(red: int, green: int, blue: int) -> int:
- id = blue * 256 * 256
- id += green * 256
- id += red
- return id
diff --git a/libs/labelbox/src/labelbox/data/serialization/coco/instance_dataset.py b/libs/labelbox/src/labelbox/data/serialization/coco/instance_dataset.py
deleted file mode 100644
index 5241e596f..000000000
--- a/libs/labelbox/src/labelbox/data/serialization/coco/instance_dataset.py
+++ /dev/null
@@ -1,266 +0,0 @@
-# https://cocodataset.org/#format-data
-
-from concurrent.futures import ProcessPoolExecutor, as_completed
-from typing import Any, Dict, List, Tuple, Optional
-from pathlib import Path
-
-import numpy as np
-from tqdm import tqdm
-
-from ...annotation_types import (
- ImageData,
- MaskData,
- Mask,
- ObjectAnnotation,
- Label,
- Polygon,
- Point,
- Rectangle,
-)
-from ...annotation_types.collection import LabelCollection
-from .categories import Categories, hash_category_name
-from .annotation import (
- COCOObjectAnnotation,
- RLE,
- get_annotation_lookup,
- rle_decoding,
-)
-from .image import CocoImage, get_image, get_image_id
-from pydantic import BaseModel
-
-
-def mask_to_coco_object_annotation(
- annotation: ObjectAnnotation,
- annot_idx: int,
- image_id: int,
- category_id: int,
-) -> Optional[COCOObjectAnnotation]:
- # This is going to fill any holes into the multipolygon
- # If you need to support holes use the panoptic data format
- shapely = annotation.value.shapely.simplify(1).buffer(0)
- if shapely.is_empty:
- return
-
- xmin, ymin, xmax, ymax = shapely.bounds
- # Iterate over polygon once or multiple polygon for each item
- area = shapely.area
-
- return COCOObjectAnnotation(
- id=annot_idx,
- image_id=image_id,
- category_id=category_id,
- segmentation=[
- np.array(s.exterior.coords).ravel().tolist()
- for s in ([shapely] if shapely.type == "Polygon" else shapely.geoms)
- ],
- area=area,
- bbox=[xmin, ymin, xmax - xmin, ymax - ymin],
- iscrowd=0,
- )
-
-
-def vector_to_coco_object_annotation(
- annotation: ObjectAnnotation,
- annot_idx: int,
- image_id: int,
- category_id: int,
-) -> COCOObjectAnnotation:
- shapely = annotation.value.shapely
- xmin, ymin, xmax, ymax = shapely.bounds
- segmentation = []
- if isinstance(annotation.value, Polygon):
- for point in annotation.value.points:
- segmentation.extend([point.x, point.y])
- else:
- box = annotation.value
- segmentation.extend(
- [
- box.start.x,
- box.start.y,
- box.end.x,
- box.start.y,
- box.end.x,
- box.end.y,
- box.start.x,
- box.end.y,
- ]
- )
-
- return COCOObjectAnnotation(
- id=annot_idx,
- image_id=image_id,
- category_id=category_id,
- segmentation=[segmentation],
- area=shapely.area,
- bbox=[xmin, ymin, xmax - xmin, ymax - ymin],
- iscrowd=0,
- )
-
-
-def rle_to_common(
- class_annotations: COCOObjectAnnotation, class_name: str
-) -> ObjectAnnotation:
- mask = rle_decoding(
- class_annotations.segmentation.counts,
- *class_annotations.segmentation.size[::-1],
- )
- return ObjectAnnotation(
- name=class_name,
- value=Mask(mask=MaskData.from_2D_arr(mask), color=[1, 1, 1]),
- )
-
-
-def segmentations_to_common(
- class_annotations: COCOObjectAnnotation, class_name: str
-) -> List[ObjectAnnotation]:
- # Technically it is polygons. But the key in coco is called segmentations..
- annotations = []
- for points in class_annotations.segmentation:
- annotations.append(
- ObjectAnnotation(
- name=class_name,
- value=Polygon(
- points=[
- Point(x=points[i], y=points[i + 1])
- for i in range(0, len(points), 2)
- ]
- ),
- )
- )
- return annotations
-
-
-def object_annotation_to_coco(
- annotation: ObjectAnnotation,
- annot_idx: int,
- image_id: int,
- category_id: int,
-) -> Optional[COCOObjectAnnotation]:
- if isinstance(annotation.value, Mask):
- return mask_to_coco_object_annotation(
- annotation, annot_idx, image_id, category_id
- )
- elif isinstance(annotation.value, (Polygon, Rectangle)):
- return vector_to_coco_object_annotation(
- annotation, annot_idx, image_id, category_id
- )
- else:
- return None
-
-
-def process_label(
- label: Label, idx: int, image_root: str, max_annotations_per_image=10000
-) -> Tuple[np.ndarray, List[COCOObjectAnnotation], Dict[str, str]]:
- annot_idx = idx * max_annotations_per_image
- image_id = get_image_id(label, idx)
- image = get_image(label, image_root, image_id)
- coco_annotations = []
- annotation_lookup = get_annotation_lookup(label.annotations)
- categories = {}
- for class_name in annotation_lookup:
- for annotation in annotation_lookup[class_name]:
- category_id = categories.get(annotation.name) or hash_category_name(
- annotation.name
- )
- coco_annotation = object_annotation_to_coco(
- annotation, annot_idx, image_id, category_id
- )
- if coco_annotation is not None:
- coco_annotations.append(coco_annotation)
- if annotation.name not in categories:
- categories[annotation.name] = category_id
- annot_idx += 1
-
- return image, coco_annotations, categories
-
-
-class CocoInstanceDataset(BaseModel):
- info: Dict[str, Any] = {}
- images: List[CocoImage]
- annotations: List[COCOObjectAnnotation]
- categories: List[Categories]
-
- @classmethod
- def from_common(
- cls, labels: LabelCollection, image_root: Path, max_workers=8
- ):
- all_coco_annotations = []
- categories = {}
- images = []
- futures = []
- coco_categories = {}
-
- if max_workers:
- with ProcessPoolExecutor(max_workers=max_workers) as exc:
- futures = [
- exc.submit(process_label, label, idx, image_root)
- for idx, label in enumerate(labels)
- ]
- results = [
- future.result() for future in tqdm(as_completed(futures))
- ]
- else:
- results = [
- process_label(label, idx, image_root)
- for idx, label in enumerate(labels)
- ]
-
- for result in results:
- images.append(result[0])
- all_coco_annotations.extend(result[1])
- coco_categories.update(result[2])
-
- category_mapping = {
- category_id: idx + 1
- for idx, category_id in enumerate(coco_categories.values())
- }
- categories = [
- Categories(
- id=category_mapping[idx],
- name=name,
- supercategory="all",
- isthing=1,
- )
- for name, idx in coco_categories.items()
- ]
- for annot in all_coco_annotations:
- annot.category_id = category_mapping[annot.category_id]
-
- return CocoInstanceDataset(
- info={"image_root": image_root},
- images=images,
- annotations=all_coco_annotations,
- categories=categories,
- )
-
- def to_common(self, image_root):
- category_lookup = {
- category.id: category for category in self.categories
- }
- annotation_lookup = get_annotation_lookup(self.annotations)
-
- for image in self.images:
- im_path = Path(image_root, image.file_name)
- if not im_path.exists():
- raise ValueError(
- f"Cannot find file {im_path}. Make sure `image_root` is set properly"
- )
-
- data = ImageData(file_path=str(im_path))
- annotations = []
- for class_annotations in annotation_lookup[image.id]:
- if isinstance(class_annotations.segmentation, RLE):
- annotations.append(
- rle_to_common(
- class_annotations,
- category_lookup[class_annotations.category_id].name,
- )
- )
- elif isinstance(class_annotations.segmentation, list):
- annotations.extend(
- segmentations_to_common(
- class_annotations,
- category_lookup[class_annotations.category_id].name,
- )
- )
- yield Label(data=data, annotations=annotations)
diff --git a/libs/labelbox/src/labelbox/data/serialization/coco/panoptic_dataset.py b/libs/labelbox/src/labelbox/data/serialization/coco/panoptic_dataset.py
deleted file mode 100644
index cbb410548..000000000
--- a/libs/labelbox/src/labelbox/data/serialization/coco/panoptic_dataset.py
+++ /dev/null
@@ -1,242 +0,0 @@
-from concurrent.futures import ProcessPoolExecutor, as_completed
-from typing import Dict, Any, List, Union
-from pathlib import Path
-
-from tqdm import tqdm
-import numpy as np
-from PIL import Image
-
-from ...annotation_types.geometry import Polygon, Rectangle
-from ...annotation_types import Label
-from ...annotation_types.geometry.mask import Mask
-from ...annotation_types.annotation import ObjectAnnotation
-from ...annotation_types.data.raster import MaskData, ImageData
-from ...annotation_types.collection import LabelCollection
-from .categories import Categories, hash_category_name
-from .image import CocoImage, get_image, get_image_id, id_to_rgb
-from .annotation import PanopticAnnotation, SegmentInfo, get_annotation_lookup
-from pydantic import BaseModel
-
-
-def vector_to_coco_segment_info(
- canvas: np.ndarray,
- annotation: ObjectAnnotation,
- annotation_idx: int,
- image: CocoImage,
- category_id: int,
-):
- shapely = annotation.value.shapely
- if shapely.is_empty:
- return
-
- xmin, ymin, xmax, ymax = shapely.bounds
- canvas = annotation.value.draw(
- height=image.height,
- width=image.width,
- canvas=canvas,
- color=id_to_rgb(annotation_idx),
- )
-
- return SegmentInfo(
- id=annotation_idx,
- category_id=category_id,
- area=shapely.area,
- bbox=[xmin, ymin, xmax - xmin, ymax - ymin],
- ), canvas
-
-
-def mask_to_coco_segment_info(
- canvas: np.ndarray, annotation, annotation_idx: int, category_id
-):
- color = id_to_rgb(annotation_idx)
- mask = annotation.value.draw(color=color)
- shapely = annotation.value.shapely
- if shapely.is_empty:
- return
-
- xmin, ymin, xmax, ymax = shapely.bounds
- canvas = np.where(canvas == (0, 0, 0), mask, canvas)
- return SegmentInfo(
- id=annotation_idx,
- category_id=category_id,
- area=shapely.area,
- bbox=[xmin, ymin, xmax - xmin, ymax - ymin],
- ), canvas
-
-
-def process_label(
- label: Label, idx: Union[int, str], image_root, mask_root, all_stuff=False
-):
- """
- Masks become stuff
- Polygon and rectangle become thing
- """
- annotations = get_annotation_lookup(label.annotations)
- image_id = get_image_id(label, idx)
- image = get_image(label, image_root, image_id)
- canvas = np.zeros((image.height, image.width, 3))
-
- segments = []
- categories = {}
- is_thing = {}
-
- for class_idx, class_name in enumerate(annotations):
- for annotation_idx, annotation in enumerate(annotations[class_name]):
- categories[annotation.name] = hash_category_name(annotation.name)
- if isinstance(annotation.value, Mask):
- coco_segment_info = mask_to_coco_segment_info(
- canvas,
- annotation,
- class_idx + 1,
- categories[annotation.name],
- )
-
- if coco_segment_info is None:
- # Filter out empty masks
- continue
-
- segment, canvas = coco_segment_info
- segments.append(segment)
- is_thing[annotation.name] = 0
-
- elif isinstance(annotation.value, (Polygon, Rectangle)):
- coco_vector_info = vector_to_coco_segment_info(
- canvas,
- annotation,
- annotation_idx=(class_idx if all_stuff else annotation_idx)
- + 1,
- image=image,
- category_id=categories[annotation.name],
- )
-
- if coco_vector_info is None:
- # Filter out empty annotations
- continue
-
- segment, canvas = coco_vector_info
- segments.append(segment)
- is_thing[annotation.name] = 1 - int(all_stuff)
-
- mask_file = str(image.file_name).replace(".jpg", ".png")
- mask_file = Path(mask_root, mask_file)
- Image.fromarray(canvas.astype(np.uint8)).save(mask_file)
- return (
- image,
- PanopticAnnotation(
- image_id=image_id,
- file_name=Path(mask_file.name),
- segments_info=segments,
- ),
- categories,
- is_thing,
- )
-
-
-class CocoPanopticDataset(BaseModel):
- info: Dict[str, Any] = {}
- images: List[CocoImage]
- annotations: List[PanopticAnnotation]
- categories: List[Categories]
-
- @classmethod
- def from_common(
- cls,
- labels: LabelCollection,
- image_root,
- mask_root,
- all_stuff,
- max_workers=8,
- ):
- all_coco_annotations = []
- coco_categories = {}
- coco_things = {}
- images = []
-
- if max_workers:
- with ProcessPoolExecutor(max_workers=max_workers) as exc:
- futures = [
- exc.submit(
- process_label,
- label,
- idx,
- image_root,
- mask_root,
- all_stuff,
- )
- for idx, label in enumerate(labels)
- ]
- results = [
- future.result() for future in tqdm(as_completed(futures))
- ]
- else:
- results = [
- process_label(label, idx, image_root, mask_root, all_stuff)
- for idx, label in enumerate(labels)
- ]
-
- for result in results:
- images.append(result[0])
- all_coco_annotations.append(result[1])
- coco_categories.update(result[2])
- coco_things.update(result[3])
-
- category_mapping = {
- category_id: idx + 1
- for idx, category_id in enumerate(coco_categories.values())
- }
- categories = [
- Categories(
- id=category_mapping[idx],
- name=name,
- supercategory="all",
- isthing=coco_things.get(name, 1),
- )
- for name, idx in coco_categories.items()
- ]
-
- for annot in all_coco_annotations:
- for segment in annot.segments_info:
- segment.category_id = category_mapping[segment.category_id]
-
- return CocoPanopticDataset(
- info={"image_root": image_root, "mask_root": mask_root},
- images=images,
- annotations=all_coco_annotations,
- categories=categories,
- )
-
- def to_common(self, image_root: Path, mask_root: Path):
- category_lookup = {
- category.id: category for category in self.categories
- }
- annotation_lookup = {
- annotation.image_id: annotation for annotation in self.annotations
- }
- for image in self.images:
- annotations = []
- annotation = annotation_lookup[image.id]
-
- im_path = Path(image_root, image.file_name)
- if not im_path.exists():
- raise ValueError(
- f"Cannot find file {im_path}. Make sure `image_root` is set properly"
- )
- if not str(annotation.file_name).endswith(".png"):
- raise ValueError(
- f"COCO masks must be stored as png files and their extension must be `.png`. Found {annotation.file_name}"
- )
- mask = MaskData(
- file_path=str(Path(mask_root, annotation.file_name))
- )
-
- for segmentation in annotation.segments_info:
- category = category_lookup[segmentation.category_id]
- annotations.append(
- ObjectAnnotation(
- name=category.name,
- value=Mask(mask=mask, color=id_to_rgb(segmentation.id)),
- )
- )
- data = ImageData(file_path=str(im_path))
- yield Label(data=data, annotations=annotations)
- del annotation_lookup[image.id]
diff --git a/libs/labelbox/src/labelbox/data/serialization/coco/path.py b/libs/labelbox/src/labelbox/data/serialization/coco/path.py
deleted file mode 100644
index c3be84f31..000000000
--- a/libs/labelbox/src/labelbox/data/serialization/coco/path.py
+++ /dev/null
@@ -1,9 +0,0 @@
-from pathlib import Path
-from pydantic import BaseModel, model_serializer
-
-
-class PathSerializerMixin(BaseModel):
- @model_serializer(mode="wrap")
- def serialize_model(self, handler):
- res = handler(self)
- return {k: str(v) if isinstance(v, Path) else v for k, v in res.items()}
diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/base.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/base.py
index 75ebdc100..d8d8cd36f 100644
--- a/libs/labelbox/src/labelbox/data/serialization/ndjson/base.py
+++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/base.py
@@ -8,18 +8,6 @@
from ....annotated_types import Cuid
-subclass_registry = {}
-
-
-class _SubclassRegistryBase(BaseModel):
- model_config = ConfigDict(extra="allow")
-
- def __init_subclass__(cls, **kwargs):
- super().__init_subclass__(**kwargs)
- if cls.__name__ != "NDAnnotation":
- with threading.Lock():
- subclass_registry[cls.__name__] = cls
-
class DataRow(_CamelCaseMixin):
id: Optional[str] = None
diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py
index b127c4a90..fedf4d91b 100644
--- a/libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py
+++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/classification.py
@@ -1,6 +1,6 @@
from typing import Any, Dict, List, Union, Optional
-from labelbox.data.annotation_types import ImageData, TextData, VideoData
+from labelbox.data.annotation_types import GenericDataRowData
from labelbox.data.mixins import (
ConfidenceMixin,
CustomMetric,
@@ -30,7 +30,6 @@
model_serializer,
)
from pydantic.alias_generators import to_camel
-from .base import _SubclassRegistryBase
class NDAnswer(ConfidenceMixin, CustomMetricsMixin):
@@ -224,7 +223,7 @@ def from_common(
# ====== End of subclasses
-class NDText(NDAnnotation, NDTextSubclass, _SubclassRegistryBase):
+class NDText(NDAnnotation, NDTextSubclass):
@classmethod
def from_common(
cls,
@@ -233,7 +232,7 @@ def from_common(
name: str,
feature_schema_id: Cuid,
extra: Dict[str, Any],
- data: Union[TextData, ImageData],
+ data: GenericDataRowData,
message_id: str,
confidence: Optional[float] = None,
) -> "NDText":
@@ -249,9 +248,7 @@ def from_common(
)
-class NDChecklist(
- NDAnnotation, NDChecklistSubclass, VideoSupported, _SubclassRegistryBase
-):
+class NDChecklist(NDAnnotation, NDChecklistSubclass, VideoSupported):
@model_serializer(mode="wrap")
def serialize_model(self, handler):
res = handler(self)
@@ -267,7 +264,7 @@ def from_common(
name: str,
feature_schema_id: Cuid,
extra: Dict[str, Any],
- data: Union[VideoData, TextData, ImageData],
+ data: GenericDataRowData,
message_id: str,
confidence: Optional[float] = None,
custom_metrics: Optional[List[CustomMetric]] = None,
@@ -298,9 +295,7 @@ def from_common(
)
-class NDRadio(
- NDAnnotation, NDRadioSubclass, VideoSupported, _SubclassRegistryBase
-):
+class NDRadio(NDAnnotation, NDRadioSubclass, VideoSupported):
@classmethod
def from_common(
cls,
@@ -309,7 +304,7 @@ def from_common(
name: str,
feature_schema_id: Cuid,
extra: Dict[str, Any],
- data: Union[VideoData, TextData, ImageData],
+ data: GenericDataRowData,
message_id: str,
confidence: Optional[float] = None,
) -> "NDRadio":
@@ -343,7 +338,7 @@ def serialize_model(self, handler):
return res
-class NDPromptText(NDAnnotation, NDPromptTextSubclass, _SubclassRegistryBase):
+class NDPromptText(NDAnnotation, NDPromptTextSubclass):
@classmethod
def from_common(
cls,
@@ -432,7 +427,7 @@ def from_common(
annotation: Union[
ClassificationAnnotation, VideoClassificationAnnotation
],
- data: Union[VideoData, TextData, ImageData],
+ data: GenericDataRowData,
) -> Union[NDTextSubclass, NDChecklistSubclass, NDRadioSubclass]:
classify_obj = cls.lookup_classification(annotation)
if classify_obj is None:
@@ -480,7 +475,7 @@ def to_common(
def from_common(
cls,
annotation: Union[PromptClassificationAnnotation],
- data: Union[VideoData, TextData, ImageData],
+ data: GenericDataRowData,
) -> Union[NDTextSubclass, NDChecklistSubclass, NDRadioSubclass]:
return NDPromptText.from_common(
str(annotation._uuid),
diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/converter.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/converter.py
index 01ab8454a..8176d7862 100644
--- a/libs/labelbox/src/labelbox/data/serialization/ndjson/converter.py
+++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/converter.py
@@ -26,20 +26,6 @@
class NDJsonConverter:
- @staticmethod
- def deserialize(json_data: Iterable[Dict[str, Any]]) -> LabelGenerator:
- """
- Converts ndjson data (prediction import format) into the common labelbox format.
-
- Args:
- json_data: An iterable representing the ndjson data
- Returns:
- LabelGenerator containing the ndjson data.
- """
- data = NDLabel(**{"annotations": copy.copy(json_data)})
- res = data.to_common()
- return res
-
@staticmethod
def serialize(
labels: LabelCollection,
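
Reviewer note: with `deserialize()` removed, `NDJsonConverter` is serialize-only. A minimal usage sketch, assuming `labels` is an in-memory `LabelCollection` built from the annotation types in this package:

    from labelbox.data.serialization.ndjson.converter import NDJsonConverter

    # Produces the NDJSON prediction-import dicts for upload.
    ndjson_rows = list(NDJsonConverter.serialize(labels))
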
diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py
index 18134a228..ffaefb4d7 100644
--- a/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py
+++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/label.py
@@ -14,7 +14,6 @@
)
from ...annotation_types.video import VideoObjectAnnotation, VideoMaskAnnotation
from ...annotation_types.collection import LabelCollection, LabelGenerator
-from ...annotation_types.data import DicomData, ImageData, TextData, VideoData
from ...annotation_types.data.generic_data_row_data import GenericDataRowData
from ...annotation_types.label import Label
from ...annotation_types.ner import TextEntity, ConversationEntity
@@ -46,7 +45,6 @@
from .relationship import NDRelationship
from .base import DataRow
from pydantic import BaseModel, ValidationError
-from .base import subclass_registry, _SubclassRegistryBase
from pydantic_core import PydanticUndefined
from contextlib import suppress
@@ -67,68 +65,7 @@
class NDLabel(BaseModel):
- annotations: List[_SubclassRegistryBase]
-
- def __init__(self, **kwargs):
- # NOTE: Deserialization of subclasses in pydantic is difficult, see here https://blog.devgenius.io/deserialize-child-classes-with-pydantic-that-gonna-work-784230e1cf83
- # Below implements the subclass registry as mentioned in the article. The python dicts we pass in can be missing certain fields
- # we essentially have to infer the type against all sub classes that have the _SubclasssRegistryBase inheritance.
- # It works by checking if the keys of our annotations we are missing in matches any required subclass.
- # More keys are prioritized over less keys (closer match). This is used when importing json to our base models not a lot of customer workflows
- # depend on this method but this works for all our existing tests with the bonus of added validation. (no subclass found it throws an error)
-
- for index, annotation in enumerate(kwargs["annotations"]):
- if isinstance(annotation, dict):
- item_annotation_keys = annotation.keys()
- key_subclass_combos = defaultdict(list)
- for subclass in subclass_registry.values():
- # Get all required keys from subclass
- annotation_keys = []
- for k, field in subclass.model_fields.items():
- if field.default == PydanticUndefined and k != "uuid":
- if (
- hasattr(field, "alias")
- and field.alias in item_annotation_keys
- ):
- annotation_keys.append(field.alias)
- elif (
- hasattr(field, "validation_alias")
- and field.validation_alias
- in item_annotation_keys
- ):
- annotation_keys.append(field.validation_alias)
- else:
- annotation_keys.append(k)
-
- key_subclass_combos[subclass].extend(annotation_keys)
-
- # Sort by subclass that has the most keys i.e. the one with the most keys that matches is most likely our subclass
- key_subclass_combos = dict(
- sorted(
- key_subclass_combos.items(),
- key=lambda x: len(x[1]),
- reverse=True,
- )
- )
-
- for subclass, key_subclass_combo in key_subclass_combos.items():
- # Choose the keys from our dict we supplied that matches the required keys of a subclass
- check_required_keys = all(
- key in list(item_annotation_keys)
- for key in key_subclass_combo
- )
- if check_required_keys:
- # Keep trying subclasses until we find one that has valid values (does not throw an validation error)
- with suppress(ValidationError):
- annotation = subclass(**annotation)
- break
- if isinstance(annotation, dict):
- raise ValueError(
- f"Could not find subclass for fields: {item_annotation_keys}"
- )
-
- kwargs["annotations"][index] = annotation
- super().__init__(**kwargs)
+ annotations: AnnotationType
class _Relationship(BaseModel):
"""This object holds information about the relationship"""
@@ -276,46 +213,9 @@ def _generate_annotations(
yield Label(
annotations=annotations,
- data=self._infer_media_type(group.data_row, annotations),
+ data=GenericDataRowData,
)
- def _infer_media_type(
- self,
- data_row: DataRow,
- annotations: List[
- Union[
- TextEntity,
- ConversationEntity,
- VideoClassificationAnnotation,
- DICOMObjectAnnotation,
- VideoObjectAnnotation,
- ObjectAnnotation,
- ClassificationAnnotation,
- ScalarMetric,
- ConfusionMatrixMetric,
- ]
- ],
- ) -> Union[TextData, VideoData, ImageData]:
- if len(annotations) == 0:
- raise ValueError("Missing annotations while inferring media type")
-
- types = {type(annotation) for annotation in annotations}
- data = GenericDataRowData
- if (TextEntity in types) or (ConversationEntity in types):
- data = TextData
- elif (
- VideoClassificationAnnotation in types
- or VideoObjectAnnotation in types
- ):
- data = VideoData
- elif DICOMObjectAnnotation in types:
- data = DicomData
-
- if data_row.id:
- return data(uid=data_row.id)
- else:
- return data(global_key=data_row.global_key)
-
@staticmethod
def _get_consecutive_frames(
frames_indices: List[int],
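
Reviewer note: the deleted `_infer_media_type` helper chose between ImageData, TextData, VideoData and DicomData based on the annotation types present. After this change every parsed label carries a GenericDataRowData instead, keyed by data row id or global key exactly as the removed helper did (example values below are illustrative):

    from labelbox.data.annotation_types import GenericDataRowData

    data_by_id = GenericDataRowData(uid="<data-row-id>")
    data_by_key = GenericDataRowData(global_key="<global-key>")
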
diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/metric.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/metric.py
index 60d538b19..f8b522ab5 100644
--- a/libs/labelbox/src/labelbox/data/serialization/ndjson/metric.py
+++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/metric.py
@@ -1,6 +1,6 @@
from typing import Optional, Union, Type
-from labelbox.data.annotation_types.data import ImageData, TextData
+from labelbox.data.annotation_types.data import GenericDataRowData
from labelbox.data.serialization.ndjson.base import DataRow, NDJsonBase
from labelbox.data.annotation_types.metrics.scalar import (
ScalarMetric,
@@ -15,7 +15,6 @@
ConfusionMatrixMetricConfidenceValue,
)
from pydantic import ConfigDict, model_serializer
-from .base import _SubclassRegistryBase
class BaseNDMetric(NDJsonBase):
@@ -33,7 +32,7 @@ def serialize_model(self, handler):
return res
-class NDConfusionMatrixMetric(BaseNDMetric, _SubclassRegistryBase):
+class NDConfusionMatrixMetric(BaseNDMetric):
metric_value: Union[
ConfusionMatrixMetricValue, ConfusionMatrixMetricConfidenceValue
]
@@ -52,7 +51,7 @@ def to_common(self) -> ConfusionMatrixMetric:
@classmethod
def from_common(
- cls, metric: ConfusionMatrixMetric, data: Union[TextData, ImageData]
+ cls, metric: ConfusionMatrixMetric, data: GenericDataRowData
) -> "NDConfusionMatrixMetric":
return cls(
uuid=metric.extra.get("uuid"),
@@ -65,7 +64,7 @@ def from_common(
)
-class NDScalarMetric(BaseNDMetric, _SubclassRegistryBase):
+class NDScalarMetric(BaseNDMetric):
metric_value: Union[ScalarMetricValue, ScalarMetricConfidenceValue]
metric_name: Optional[str] = None
aggregation: Optional[ScalarMetricAggregation] = (
@@ -84,7 +83,7 @@ def to_common(self) -> ScalarMetric:
@classmethod
def from_common(
- cls, metric: ScalarMetric, data: Union[TextData, ImageData]
+ cls, metric: ScalarMetric, data: GenericDataRowData
) -> "NDScalarMetric":
return cls(
uuid=metric.extra.get("uuid"),
@@ -108,7 +107,7 @@ def to_common(
def from_common(
cls,
annotation: Union[ScalarMetric, ConfusionMatrixMetric],
- data: Union[TextData, ImageData],
+ data: GenericDataRowData,
) -> Union[NDScalarMetric, NDConfusionMatrixMetric]:
obj = cls.lookup_object(annotation)
return obj.from_common(annotation, data)
diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/mmc.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/mmc.py
index 4be24f683..b2dcfb5b4 100644
--- a/libs/labelbox/src/labelbox/data/serialization/ndjson/mmc.py
+++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/mmc.py
@@ -2,13 +2,14 @@
from labelbox.utils import _CamelCaseMixin
-from .base import _SubclassRegistryBase, DataRow, NDAnnotation
+from .base import DataRow, NDAnnotation
from ...annotation_types.mmc import (
MessageSingleSelectionTask,
MessageMultiSelectionTask,
MessageRankingTask,
MessageEvaluationTaskAnnotation,
)
+from ...annotation_types import GenericDataRowData
class MessageTaskData(_CamelCaseMixin):
@@ -20,7 +21,7 @@ class MessageTaskData(_CamelCaseMixin):
]
-class NDMessageTask(NDAnnotation, _SubclassRegistryBase):
+class NDMessageTask(NDAnnotation):
message_evaluation_task: MessageTaskData
def to_common(self) -> MessageEvaluationTaskAnnotation:
@@ -35,7 +36,7 @@ def to_common(self) -> MessageEvaluationTaskAnnotation:
def from_common(
cls,
annotation: MessageEvaluationTaskAnnotation,
- data: Any, # Union[ImageData, TextData],
+ data: GenericDataRowData,
) -> "NDMessageTask":
return cls(
uuid=str(annotation._uuid),
diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/objects.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/objects.py
index a1465fa06..1bcba7a89 100644
--- a/libs/labelbox/src/labelbox/data/serialization/ndjson/objects.py
+++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/objects.py
@@ -2,6 +2,7 @@
from typing import Any, Dict, List, Tuple, Union, Optional
import base64
+from labelbox.data.annotation_types.data.raster import MaskData
from labelbox.data.annotation_types.ner.conversation_entity import (
ConversationEntity,
)
@@ -21,9 +22,9 @@
from PIL import Image
from labelbox.data.annotation_types import feature
-from labelbox.data.annotation_types.data.video import VideoData
+from labelbox.data.annotation_types.data import GenericDataRowData
-from ...annotation_types.data import ImageData, TextData, MaskData
+from ...annotation_types.data import GenericDataRowData
from ...annotation_types.ner import (
DocumentEntity,
DocumentTextSelection,
@@ -52,7 +53,7 @@
NDSubclassification,
NDSubclassificationType,
)
-from .base import DataRow, NDAnnotation, NDJsonBase, _SubclassRegistryBase
+from .base import DataRow, NDAnnotation, NDJsonBase
from pydantic import BaseModel
@@ -81,9 +82,7 @@ class Bbox(BaseModel):
width: float
-class NDPoint(
- NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase
-):
+class NDPoint(NDBaseObject, ConfidenceMixin, CustomMetricsMixin):
point: _Point
def to_common(self) -> Point:
@@ -98,7 +97,7 @@ def from_common(
name: str,
feature_schema_id: Cuid,
extra: Dict[str, Any],
- data: Union[ImageData, TextData],
+ data: GenericDataRowData,
confidence: Optional[float] = None,
custom_metrics: Optional[List[CustomMetric]] = None,
) -> "NDPoint":
@@ -114,7 +113,7 @@ def from_common(
)
-class NDFramePoint(VideoSupported, _SubclassRegistryBase):
+class NDFramePoint(VideoSupported):
point: _Point
classifications: List[NDSubclassificationType] = []
@@ -148,9 +147,7 @@ def from_common(
)
-class NDLine(
- NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase
-):
+class NDLine(NDBaseObject, ConfidenceMixin, CustomMetricsMixin):
line: List[_Point]
def to_common(self) -> Line:
@@ -165,7 +162,7 @@ def from_common(
name: str,
feature_schema_id: Cuid,
extra: Dict[str, Any],
- data: Union[ImageData, TextData],
+ data: GenericDataRowData,
confidence: Optional[float] = None,
custom_metrics: Optional[List[CustomMetric]] = None,
) -> "NDLine":
@@ -181,7 +178,7 @@ def from_common(
)
-class NDFrameLine(VideoSupported, _SubclassRegistryBase):
+class NDFrameLine(VideoSupported):
line: List[_Point]
classifications: List[NDSubclassificationType] = []
@@ -215,7 +212,7 @@ def from_common(
)
-class NDDicomLine(NDFrameLine, _SubclassRegistryBase):
+class NDDicomLine(NDFrameLine):
def to_common(
self,
name: str,
@@ -234,9 +231,7 @@ def to_common(
)
-class NDPolygon(
- NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase
-):
+class NDPolygon(NDBaseObject, ConfidenceMixin, CustomMetricsMixin):
polygon: List[_Point]
def to_common(self) -> Polygon:
@@ -251,7 +246,7 @@ def from_common(
name: str,
feature_schema_id: Cuid,
extra: Dict[str, Any],
- data: Union[ImageData, TextData],
+ data: GenericDataRowData,
confidence: Optional[float] = None,
custom_metrics: Optional[List[CustomMetric]] = None,
) -> "NDPolygon":
@@ -267,9 +262,7 @@ def from_common(
)
-class NDRectangle(
- NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase
-):
+class NDRectangle(NDBaseObject, ConfidenceMixin, CustomMetricsMixin):
bbox: Bbox
def to_common(self) -> Rectangle:
@@ -290,7 +283,7 @@ def from_common(
name: str,
feature_schema_id: Cuid,
extra: Dict[str, Any],
- data: Union[ImageData, TextData],
+ data: GenericDataRowData,
confidence: Optional[float] = None,
custom_metrics: Optional[List[CustomMetric]] = None,
) -> "NDRectangle":
@@ -313,7 +306,7 @@ def from_common(
)
-class NDDocumentRectangle(NDRectangle, _SubclassRegistryBase):
+class NDDocumentRectangle(NDRectangle):
page: int
unit: str
@@ -337,7 +330,7 @@ def from_common(
name: str,
feature_schema_id: Cuid,
extra: Dict[str, Any],
- data: Union[ImageData, TextData],
+ data: GenericDataRowData,
confidence: Optional[float] = None,
custom_metrics: Optional[List[CustomMetric]] = None,
) -> "NDRectangle":
@@ -360,7 +353,7 @@ def from_common(
)
-class NDFrameRectangle(VideoSupported, _SubclassRegistryBase):
+class NDFrameRectangle(VideoSupported):
bbox: Bbox
classifications: List[NDSubclassificationType] = []
@@ -496,7 +489,7 @@ def to_common(
]
-class NDSegments(NDBaseObject, _SubclassRegistryBase):
+class NDSegments(NDBaseObject):
segments: List[NDSegment]
def to_common(self, name: str, feature_schema_id: Cuid):
@@ -516,7 +509,7 @@ def to_common(self, name: str, feature_schema_id: Cuid):
def from_common(
cls,
segments: List[VideoObjectAnnotation],
- data: VideoData,
+ data: GenericDataRowData,
name: str,
feature_schema_id: Cuid,
extra: Dict[str, Any],
@@ -532,7 +525,7 @@ def from_common(
)
-class NDDicomSegments(NDBaseObject, DicomSupported, _SubclassRegistryBase):
+class NDDicomSegments(NDBaseObject, DicomSupported):
segments: List[NDDicomSegment]
def to_common(self, name: str, feature_schema_id: Cuid):
@@ -553,7 +546,7 @@ def to_common(self, name: str, feature_schema_id: Cuid):
def from_common(
cls,
segments: List[DICOMObjectAnnotation],
- data: VideoData,
+ data: GenericDataRowData,
name: str,
feature_schema_id: Cuid,
extra: Dict[str, Any],
@@ -580,9 +573,7 @@ class _PNGMask(BaseModel):
png: str
-class NDMask(
- NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase
-):
+class NDMask(NDBaseObject, ConfidenceMixin, CustomMetricsMixin):
mask: Union[_URIMask, _PNGMask]
def to_common(self) -> Mask:
@@ -611,7 +602,7 @@ def from_common(
name: str,
feature_schema_id: Cuid,
extra: Dict[str, Any],
- data: Union[ImageData, TextData],
+ data: GenericDataRowData,
confidence: Optional[float] = None,
custom_metrics: Optional[List[CustomMetric]] = None,
) -> "NDMask":
@@ -646,7 +637,6 @@ class NDVideoMasks(
NDJsonBase,
ConfidenceMixin,
CustomMetricsNotSupportedMixin,
- _SubclassRegistryBase,
):
masks: NDVideoMasksFramesInstances
@@ -678,7 +668,7 @@ def from_common(cls, annotation, data):
)
-class NDDicomMasks(NDVideoMasks, DicomSupported, _SubclassRegistryBase):
+class NDDicomMasks(NDVideoMasks, DicomSupported):
def to_common(self) -> DICOMMaskAnnotation:
return DICOMMaskAnnotation(
frames=self.masks.frames,
@@ -702,9 +692,7 @@ class Location(BaseModel):
end: int
-class NDTextEntity(
- NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase
-):
+class NDTextEntity(NDBaseObject, ConfidenceMixin, CustomMetricsMixin):
location: Location
def to_common(self) -> TextEntity:
@@ -719,7 +707,7 @@ def from_common(
name: str,
feature_schema_id: Cuid,
extra: Dict[str, Any],
- data: Union[ImageData, TextData],
+ data: GenericDataRowData,
confidence: Optional[float] = None,
custom_metrics: Optional[List[CustomMetric]] = None,
) -> "NDTextEntity":
@@ -738,9 +726,7 @@ def from_common(
)
-class NDDocumentEntity(
- NDBaseObject, ConfidenceMixin, CustomMetricsMixin, _SubclassRegistryBase
-):
+class NDDocumentEntity(NDBaseObject, ConfidenceMixin, CustomMetricsMixin):
name: str
text_selections: List[DocumentTextSelection]
@@ -758,7 +744,7 @@ def from_common(
name: str,
feature_schema_id: Cuid,
extra: Dict[str, Any],
- data: Union[ImageData, TextData],
+ data: GenericDataRowData,
confidence: Optional[float] = None,
custom_metrics: Optional[List[CustomMetric]] = None,
) -> "NDDocumentEntity":
@@ -774,7 +760,7 @@ def from_common(
)
-class NDConversationEntity(NDTextEntity, _SubclassRegistryBase):
+class NDConversationEntity(NDTextEntity):
message_id: str
def to_common(self) -> ConversationEntity:
@@ -793,7 +779,7 @@ def from_common(
name: str,
feature_schema_id: Cuid,
extra: Dict[str, Any],
- data: Union[ImageData, TextData],
+ data: GenericDataRowData,
confidence: Optional[float] = None,
custom_metrics: Optional[List[CustomMetric]] = None,
) -> "NDConversationEntity":
@@ -851,7 +837,7 @@ def from_common(
List[List[VideoObjectAnnotation]],
VideoMaskAnnotation,
],
- data: Union[ImageData, TextData],
+ data: GenericDataRowData,
) -> Union[
NDLine,
NDPoint,
diff --git a/libs/labelbox/src/labelbox/data/serialization/ndjson/relationship.py b/libs/labelbox/src/labelbox/data/serialization/ndjson/relationship.py
index fbea7e477..d558ac244 100644
--- a/libs/labelbox/src/labelbox/data/serialization/ndjson/relationship.py
+++ b/libs/labelbox/src/labelbox/data/serialization/ndjson/relationship.py
@@ -1,11 +1,11 @@
from typing import Union
from pydantic import BaseModel
from .base import NDAnnotation, DataRow
-from ...annotation_types.data import ImageData, TextData
+from ...annotation_types.data import GenericDataRowData
from ...annotation_types.relationship import RelationshipAnnotation
from ...annotation_types.relationship import Relationship
from .objects import NDObjectType
-from .base import DataRow, _SubclassRegistryBase
+from .base import DataRow
SUPPORTED_ANNOTATIONS = NDObjectType
@@ -16,7 +16,7 @@ class _Relationship(BaseModel):
type: str
-class NDRelationship(NDAnnotation, _SubclassRegistryBase):
+class NDRelationship(NDAnnotation):
relationship: _Relationship
@staticmethod
@@ -40,7 +40,7 @@ def to_common(
def from_common(
cls,
annotation: RelationshipAnnotation,
- data: Union[ImageData, TextData],
+ data: GenericDataRowData,
) -> "NDRelationship":
relationship = annotation.value
return cls(
diff --git a/libs/labelbox/src/labelbox/orm/db_object.py b/libs/labelbox/src/labelbox/orm/db_object.py
index b210a8a5b..a1c2bde38 100644
--- a/libs/labelbox/src/labelbox/orm/db_object.py
+++ b/libs/labelbox/src/labelbox/orm/db_object.py
@@ -1,17 +1,17 @@
-from dataclasses import dataclass
+import json
+import logging
from datetime import datetime, timezone
from functools import wraps
-import logging
-import json
-from labelbox import utils
-from labelbox.exceptions import (
- InvalidQueryError,
+from lbox.exceptions import (
InvalidAttributeError,
+ InvalidQueryError,
OperationNotSupportedException,
)
+
+from labelbox import utils
from labelbox.orm import query
-from labelbox.orm.model import Field, Relationship, Entity
+from labelbox.orm.model import Entity, Field, Relationship
from labelbox.pagination import PaginatedCollection
logger = logging.getLogger(__name__)
@@ -177,7 +177,7 @@ def _to_many(self, where=None, order_by=None):
)
if rel.filter_deleted:
- not_deleted = rel.destination_type.deleted == False
+ not_deleted = rel.destination_type.deleted == False # noqa: E712 Needed for the bitwise operator used to combine comparisons
where = not_deleted if where is None else where & not_deleted
query_string, params = query.relationship(
diff --git a/libs/labelbox/src/labelbox/orm/model.py b/libs/labelbox/src/labelbox/orm/model.py
index 84dcac774..535ab0f7d 100644
--- a/libs/labelbox/src/labelbox/orm/model.py
+++ b/libs/labelbox/src/labelbox/orm/model.py
@@ -1,10 +1,11 @@
from dataclasses import dataclass
from enum import Enum, auto
-from typing import Dict, List, Union, Any, Type, TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Dict, List, Type, Union
+
+from lbox.exceptions import InvalidAttributeError
import labelbox
from labelbox import utils
-from labelbox.exceptions import InvalidAttributeError
from labelbox.orm.comparison import Comparison
""" Defines Field, Relationship and Entity. These classes are building
@@ -386,7 +387,6 @@ class Entity(metaclass=EntityMeta):
Review: Type[labelbox.Review]
User: Type[labelbox.User]
LabelingFrontend: Type[labelbox.LabelingFrontend]
- BulkImportRequest: Type[labelbox.BulkImportRequest]
Benchmark: Type[labelbox.Benchmark]
IAMIntegration: Type[labelbox.IAMIntegration]
LabelingFrontendOptions: Type[labelbox.LabelingFrontendOptions]
diff --git a/libs/labelbox/src/labelbox/orm/query.py b/libs/labelbox/src/labelbox/orm/query.py
index 8fa9fea00..8c019fb5e 100644
--- a/libs/labelbox/src/labelbox/orm/query.py
+++ b/libs/labelbox/src/labelbox/orm/query.py
@@ -1,14 +1,15 @@
from itertools import chain
from typing import Any, Dict
-from labelbox import utils
-from labelbox.exceptions import (
- InvalidQueryError,
+from lbox.exceptions import (
InvalidAttributeError,
+ InvalidQueryError,
MalformedQueryException,
)
-from labelbox.orm.comparison import LogicalExpression, Comparison
-from labelbox.orm.model import Field, Relationship, Entity
+
+from labelbox import utils
+from labelbox.orm.comparison import Comparison, LogicalExpression
+from labelbox.orm.model import Entity, Field, Relationship
""" Common query creation functionality. """
diff --git a/libs/labelbox/src/labelbox/project_validation.py b/libs/labelbox/src/labelbox/project_validation.py
new file mode 100644
index 000000000..2a6db9e2a
--- /dev/null
+++ b/libs/labelbox/src/labelbox/project_validation.py
@@ -0,0 +1,87 @@
+from typing import Annotated, Optional, Set
+
+from pydantic import BaseModel, ConfigDict, Field, model_validator
+
+from labelbox.schema.media_type import MediaType
+from labelbox.schema.ontology_kind import EditorTaskType
+from labelbox.schema.quality_mode import (
+ BENCHMARK_AUTO_AUDIT_NUMBER_OF_LABELS,
+ BENCHMARK_AUTO_AUDIT_PERCENTAGE,
+ CONSENSUS_AUTO_AUDIT_NUMBER_OF_LABELS,
+ CONSENSUS_AUTO_AUDIT_PERCENTAGE,
+ QualityMode,
+)
+
+PositiveInt = Annotated[int, Field(gt=0)]
+
+
+class _CoreProjectInput(BaseModel):
+ name: str
+ description: Optional[str] = None
+ media_type: MediaType
+ auto_audit_percentage: Optional[float] = None
+ auto_audit_number_of_labels: Optional[int] = None
+ quality_modes: Optional[Set[QualityMode]] = Field(
+ default={QualityMode.Benchmark, QualityMode.Consensus}, exclude=True
+ )
+ is_benchmark_enabled: Optional[bool] = None
+ is_consensus_enabled: Optional[bool] = None
+ dataset_name_or_id: Optional[str] = None
+ append_to_existing_dataset: Optional[bool] = None
+ data_row_count: Optional[PositiveInt] = None
+ editor_task_type: Optional[EditorTaskType] = None
+
+ model_config = ConfigDict(extra="forbid", use_enum_values=True)
+
+ @model_validator(mode="after")
+ def validate_fields(self):
+ if (
+ self.auto_audit_percentage is not None
+ or self.auto_audit_number_of_labels is not None
+ ):
+ raise ValueError(
+ "quality_modes must be set instead of auto_audit_percentage or auto_audit_number_of_labels."
+ )
+
+ if not self.name.strip():
+ raise ValueError("project name must be a valid string.")
+
+ if self.quality_modes == {
+ QualityMode.Benchmark,
+ QualityMode.Consensus,
+ }:
+ self._set_quality_mode_attributes(
+ CONSENSUS_AUTO_AUDIT_NUMBER_OF_LABELS,
+ CONSENSUS_AUTO_AUDIT_PERCENTAGE,
+ is_benchmark_enabled=True,
+ is_consensus_enabled=True,
+ )
+ elif self.quality_modes == {QualityMode.Benchmark}:
+ self._set_quality_mode_attributes(
+ BENCHMARK_AUTO_AUDIT_NUMBER_OF_LABELS,
+ BENCHMARK_AUTO_AUDIT_PERCENTAGE,
+ is_benchmark_enabled=True,
+ )
+ elif self.quality_modes == {QualityMode.Consensus}:
+ self._set_quality_mode_attributes(
+ number_of_labels=CONSENSUS_AUTO_AUDIT_NUMBER_OF_LABELS,
+ percentage=CONSENSUS_AUTO_AUDIT_PERCENTAGE,
+ is_consensus_enabled=True,
+ )
+
+ if self.data_row_count is not None and self.data_row_count < 0:
+ raise ValueError("data_row_count must be a positive integer.")
+
+ return self
+
+ def _set_quality_mode_attributes(
+ self,
+ number_of_labels,
+ percentage,
+ is_benchmark_enabled=False,
+ is_consensus_enabled=False,
+ ):
+ self.auto_audit_number_of_labels = number_of_labels
+ self.auto_audit_percentage = percentage
+ self.is_benchmark_enabled = is_benchmark_enabled
+ self.is_consensus_enabled = is_consensus_enabled
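
Reviewer note: a usage sketch for the new validation model, with the project name and quality mode chosen purely for illustration; the `model_validator` above fills in the auto-audit settings for the selected mode.

    from labelbox.project_validation import _CoreProjectInput
    from labelbox.schema.media_type import MediaType
    from labelbox.schema.quality_mode import (
        BENCHMARK_AUTO_AUDIT_NUMBER_OF_LABELS,
        QualityMode,
    )

    params = _CoreProjectInput(
        name="demo-project",
        media_type=MediaType.Image,
        quality_modes={QualityMode.Benchmark},
    )

    assert params.is_benchmark_enabled is True
    assert params.auto_audit_number_of_labels == BENCHMARK_AUTO_AUDIT_NUMBER_OF_LABELS
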
diff --git a/libs/labelbox/src/labelbox/schema/__init__.py b/libs/labelbox/src/labelbox/schema/__init__.py
index 03327e0d1..d6b74de68 100644
--- a/libs/labelbox/src/labelbox/schema/__init__.py
+++ b/libs/labelbox/src/labelbox/schema/__init__.py
@@ -1,5 +1,4 @@
import labelbox.schema.asset_attachment
-import labelbox.schema.bulk_import_request
import labelbox.schema.annotation_import
import labelbox.schema.benchmark
import labelbox.schema.data_row
diff --git a/libs/labelbox/src/labelbox/schema/annotation_import.py b/libs/labelbox/src/labelbox/schema/annotation_import.py
index df7f272a3..497ac899d 100644
--- a/libs/labelbox/src/labelbox/schema/annotation_import.py
+++ b/libs/labelbox/src/labelbox/schema/annotation_import.py
@@ -3,33 +3,34 @@
import logging
import os
import time
+from collections import defaultdict
from typing import (
+ TYPE_CHECKING,
Any,
BinaryIO,
Dict,
List,
Optional,
Union,
- TYPE_CHECKING,
cast,
)
-from collections import defaultdict
-from google.api_core import retry
-from labelbox import parser
import requests
+from google.api_core import retry
+from lbox.exceptions import ApiLimitError, NetworkError, ResourceNotFoundError
from tqdm import tqdm # type: ignore
import labelbox
+from labelbox import parser
from labelbox.orm import query
from labelbox.orm.db_object import DbObject
from labelbox.orm.model import Field, Relationship
-from labelbox.utils import is_exactly_one_set
from labelbox.schema.confidence_presence_checker import (
LabelsConfidencePresenceChecker,
)
from labelbox.schema.enums import AnnotationImportState
from labelbox.schema.serialization import serialize_labels
+from labelbox.utils import is_exactly_one_set
if TYPE_CHECKING:
from labelbox.types import Label
@@ -140,9 +141,9 @@ def wait_until_done(
@retry.Retry(
predicate=retry.if_exception_type(
- labelbox.exceptions.ApiLimitError,
- labelbox.exceptions.TimeoutError,
- labelbox.exceptions.NetworkError,
+ ApiLimitError,
+ TimeoutError,
+ NetworkError,
)
)
def __backoff_refresh(self) -> None:
@@ -435,9 +436,7 @@ def from_name(
}
response = client.execute(query_str, params)
if response is None:
- raise labelbox.exceptions.ResourceNotFoundError(
- MEAPredictionImport, params
- )
+ raise ResourceNotFoundError(MEAPredictionImport, params)
response = response["modelErrorAnalysisPredictionImport"]
if as_json:
return response
@@ -560,9 +559,7 @@ def from_name(
}
response = client.execute(query_str, params)
if response is None:
- raise labelbox.exceptions.ResourceNotFoundError(
- MALPredictionImport, params
- )
+ raise ResourceNotFoundError(MALPredictionImport, params)
response = response["meaToMalPredictionImport"]
if as_json:
return response
@@ -709,9 +706,7 @@ def from_name(
}
response = client.execute(query_str, params)
if response is None:
- raise labelbox.exceptions.ResourceNotFoundError(
- MALPredictionImport, params
- )
+ raise ResourceNotFoundError(MALPredictionImport, params)
response = response["modelAssistedLabelingPredictionImport"]
if as_json:
return response
@@ -885,7 +880,7 @@ def from_name(
}
response = client.execute(query_str, params)
if response is None:
- raise labelbox.exceptions.ResourceNotFoundError(LabelImport, params)
+ raise ResourceNotFoundError(LabelImport, params)
response = response["labelImport"]
if as_json:
return response
diff --git a/libs/labelbox/src/labelbox/schema/asset_attachment.py b/libs/labelbox/src/labelbox/schema/asset_attachment.py
index 0d5598c84..9a56dbb72 100644
--- a/libs/labelbox/src/labelbox/schema/asset_attachment.py
+++ b/libs/labelbox/src/labelbox/schema/asset_attachment.py
@@ -7,15 +7,6 @@
class AttachmentType(str, Enum):
- @classmethod
- def __missing__(cls, value: object):
- if str(value) == "TEXT":
- warnings.warn(
- "The TEXT attachment type is deprecated. Use RAW_TEXT instead."
- )
- return cls.RAW_TEXT
- return value
-
VIDEO = "VIDEO"
IMAGE = "IMAGE"
IMAGE_OVERLAY = "IMAGE_OVERLAY"
@@ -30,7 +21,7 @@ class AssetAttachment(DbObject):
"""Asset attachment provides extra context about an asset while labeling.
Attributes:
- attachment_type (str): IMAGE, VIDEO, IMAGE_OVERLAY, HTML, RAW_TEXT, TEXT_URL, or PDF_URL. TEXT attachment type is deprecated.
+ attachment_type (str): IMAGE, VIDEO, IMAGE_OVERLAY, HTML, RAW_TEXT, TEXT_URL, or PDF_URL.
attachment_value (str): URL to an external file or a string of text
attachment_name (str): The name of the attachment
"""
diff --git a/libs/labelbox/src/labelbox/schema/batch.py b/libs/labelbox/src/labelbox/schema/batch.py
index 316732a67..99e4c4908 100644
--- a/libs/labelbox/src/labelbox/schema/batch.py
+++ b/libs/labelbox/src/labelbox/schema/batch.py
@@ -1,15 +1,11 @@
-from typing import Generator, TYPE_CHECKING
+import logging
+from typing import TYPE_CHECKING
+
+from lbox.exceptions import ResourceNotFoundError
-from labelbox.orm.db_object import DbObject, experimental
from labelbox.orm import query
+from labelbox.orm.db_object import DbObject
from labelbox.orm.model import Entity, Field, Relationship
-from labelbox.exceptions import LabelboxError, ResourceNotFoundError
-from io import StringIO
-from labelbox import parser
-import requests
-import logging
-import time
-import warnings
if TYPE_CHECKING:
from labelbox import Project
diff --git a/libs/labelbox/src/labelbox/schema/bulk_import_request.py b/libs/labelbox/src/labelbox/schema/bulk_import_request.py
deleted file mode 100644
index 8e11f3261..000000000
--- a/libs/labelbox/src/labelbox/schema/bulk_import_request.py
+++ /dev/null
@@ -1,1004 +0,0 @@
-import json
-import time
-from uuid import UUID, uuid4
-import functools
-
-import logging
-from pathlib import Path
-from google.api_core import retry
-from labelbox import parser
-import requests
-from pydantic import (
- ValidationError,
- BaseModel,
- Field,
- field_validator,
- model_validator,
- ConfigDict,
- StringConstraints,
-)
-from typing_extensions import Literal, Annotated
-from typing import (
- Any,
- List,
- Optional,
- BinaryIO,
- Dict,
- Iterable,
- Tuple,
- Union,
- Type,
- Set,
- TYPE_CHECKING,
-)
-
-from labelbox import exceptions as lb_exceptions
-from labelbox import utils
-from labelbox.orm import query
-from labelbox.orm.db_object import DbObject
-from labelbox.orm.model import Relationship
-from labelbox.schema.enums import BulkImportRequestState
-from labelbox.schema.serialization import serialize_labels
-from labelbox.orm.model import Field as lb_Field
-
-if TYPE_CHECKING:
- from labelbox import Project
- from labelbox.types import Label
-
-NDJSON_MIME_TYPE = "application/x-ndjson"
-logger = logging.getLogger(__name__)
-
-# TODO: Deprecate this library in place of labelimport and malprediction import library.
-
-
-def _determinants(parent_cls: Any) -> List[str]:
- return [
- k
- for k, v in parent_cls.model_fields.items()
- if v.json_schema_extra and "determinant" in v.json_schema_extra
- ]
-
-
-def _make_file_name(project_id: str, name: str) -> str:
- return f"{project_id}__{name}.ndjson"
-
-
-# TODO(gszpak): move it to client.py
-def _make_request_data(
- project_id: str, name: str, content_length: int, file_name: str
-) -> dict:
- query_str = """mutation createBulkImportRequestFromFilePyApi(
- $projectId: ID!, $name: String!, $file: Upload!, $contentLength: Int!) {
- createBulkImportRequest(data: {
- projectId: $projectId,
- name: $name,
- filePayload: {
- file: $file,
- contentLength: $contentLength
- }
- }) {
- %s
- }
- }
- """ % query.results_query_part(BulkImportRequest)
- variables = {
- "projectId": project_id,
- "name": name,
- "file": None,
- "contentLength": content_length,
- }
- operations = json.dumps({"variables": variables, "query": query_str})
-
- return {
- "operations": operations,
- "map": (None, json.dumps({file_name: ["variables.file"]})),
- }
-
-
-def _send_create_file_command(
- client,
- request_data: dict,
- file_name: str,
- file_data: Tuple[str, Union[bytes, BinaryIO], str],
-) -> dict:
- response = client.execute(data=request_data, files={file_name: file_data})
-
- if not response.get("createBulkImportRequest", None):
- raise lb_exceptions.LabelboxError(
- "Failed to create BulkImportRequest, message: %s"
- % response.get("errors", None)
- or response.get("error", None)
- )
-
- return response
-
-
-class BulkImportRequest(DbObject):
- """Represents the import job when importing annotations.
-
- Attributes:
- name (str)
- state (Enum): FAILED, RUNNING, or FINISHED (Refers to the whole import job)
- input_file_url (str): URL to your web-hosted NDJSON file
- error_file_url (str): NDJSON that contains error messages for failed annotations
- status_file_url (str): NDJSON that contains status for each annotation
- created_at (datetime): UTC timestamp for date BulkImportRequest was created
-
- project (Relationship): `ToOne` relationship to Project
- created_by (Relationship): `ToOne` relationship to User
- """
-
- name = lb_Field.String("name")
- state = lb_Field.Enum(BulkImportRequestState, "state")
- input_file_url = lb_Field.String("input_file_url")
- error_file_url = lb_Field.String("error_file_url")
- status_file_url = lb_Field.String("status_file_url")
- created_at = lb_Field.DateTime("created_at")
-
- project = Relationship.ToOne("Project")
- created_by = Relationship.ToOne("User", False, "created_by")
-
- @property
- def inputs(self) -> List[Dict[str, Any]]:
- """
- Inputs for each individual annotation uploaded.
- This should match the ndjson annotations that you have uploaded.
-
- Returns:
- Uploaded ndjson.
-
- * This information will expire after 24 hours.
- """
- return self._fetch_remote_ndjson(self.input_file_url)
-
- @property
- def errors(self) -> List[Dict[str, Any]]:
- """
- Errors for each individual annotation uploaded. This is a subset of statuses
-
- Returns:
- List of dicts containing error messages. Empty list means there were no errors
- See `BulkImportRequest.statuses` for more details.
-
- * This information will expire after 24 hours.
- """
- self.wait_until_done()
- return self._fetch_remote_ndjson(self.error_file_url)
-
- @property
- def statuses(self) -> List[Dict[str, Any]]:
- """
- Status for each individual annotation uploaded.
-
- Returns:
- A status for each annotation if the upload is done running.
- See below table for more details
-
- .. list-table::
- :widths: 15 150
- :header-rows: 1
-
- * - Field
- - Description
- * - uuid
- - Specifies the annotation for the status row.
- * - dataRow
- - JSON object containing the Labelbox data row ID for the annotation.
- * - status
- - Indicates SUCCESS or FAILURE.
- * - errors
- - An array of error messages included when status is FAILURE. Each error has a name, message and optional (key might not exist) additional_info.
-
- * This information will expire after 24 hours.
- """
- self.wait_until_done()
- return self._fetch_remote_ndjson(self.status_file_url)
-
- @functools.lru_cache()
- def _fetch_remote_ndjson(self, url: str) -> List[Dict[str, Any]]:
- """
- Fetches the remote ndjson file and caches the results.
-
- Args:
- url (str): Can be any url pointing to an ndjson file.
- Returns:
- ndjson as a list of dicts.
- """
- response = requests.get(url)
- response.raise_for_status()
- return parser.loads(response.text)
-
- def refresh(self) -> None:
- """Synchronizes values of all fields with the database."""
- query_str, params = query.get_single(BulkImportRequest, self.uid)
- res = self.client.execute(query_str, params)
- res = res[utils.camel_case(BulkImportRequest.type_name())]
- self._set_field_values(res)
-
- def wait_till_done(self, sleep_time_seconds: int = 5) -> None:
- self.wait_until_done(sleep_time_seconds)
-
- def wait_until_done(self, sleep_time_seconds: int = 5) -> None:
- """Blocks import job until certain conditions are met.
-
- Blocks until the BulkImportRequest.state changes either to
- `BulkImportRequestState.FINISHED` or `BulkImportRequestState.FAILED`,
- periodically refreshing object's state.
-
- Args:
- sleep_time_seconds (str): a time to block between subsequent API calls
- """
- while self.state == BulkImportRequestState.RUNNING:
- logger.info(f"Sleeping for {sleep_time_seconds} seconds...")
- time.sleep(sleep_time_seconds)
- self.__exponential_backoff_refresh()
-
- @retry.Retry(
- predicate=retry.if_exception_type(
- lb_exceptions.ApiLimitError,
- lb_exceptions.TimeoutError,
- lb_exceptions.NetworkError,
- )
- )
- def __exponential_backoff_refresh(self) -> None:
- self.refresh()
-
- @classmethod
- def from_name(
- cls, client, project_id: str, name: str
- ) -> "BulkImportRequest":
- """Fetches existing BulkImportRequest.
-
- Args:
- client (Client): a Labelbox client
- project_id (str): BulkImportRequest's project id
- name (str): name of BulkImportRequest
- Returns:
- BulkImportRequest object
-
- """
- query_str = """query getBulkImportRequestPyApi(
- $projectId: ID!, $name: String!) {
- bulkImportRequest(where: {
- projectId: $projectId,
- name: $name
- }) {
- %s
- }
- }
- """ % query.results_query_part(cls)
- params = {"projectId": project_id, "name": name}
- response = client.execute(query_str, params=params)
- return cls(client, response["bulkImportRequest"])
-
- @classmethod
- def create_from_url(
- cls, client, project_id: str, name: str, url: str, validate=True
- ) -> "BulkImportRequest":
- """
- Creates a BulkImportRequest from a publicly accessible URL
- to an ndjson file with predictions.
-
- Args:
- client (Client): a Labelbox client
- project_id (str): id of project for which predictions will be imported
- name (str): name of BulkImportRequest
- url (str): publicly accessible URL pointing to ndjson file containing predictions
- validate (bool): a flag indicating if there should be a validation
- if `url` is valid ndjson
- Returns:
- BulkImportRequest object
- """
- if validate:
- logger.warn(
- "Validation is turned on. The file will be downloaded locally and processed before uploading."
- )
- res = requests.get(url)
- data = parser.loads(res.text)
- _validate_ndjson(data, client.get_project(project_id))
-
- query_str = """mutation createBulkImportRequestPyApi(
- $projectId: ID!, $name: String!, $fileUrl: String!) {
- createBulkImportRequest(data: {
- projectId: $projectId,
- name: $name,
- fileUrl: $fileUrl
- }) {
- %s
- }
- }
- """ % query.results_query_part(cls)
- params = {"projectId": project_id, "name": name, "fileUrl": url}
- bulk_import_request_response = client.execute(query_str, params=params)
- return cls(
- client, bulk_import_request_response["createBulkImportRequest"]
- )
-
- @classmethod
- def create_from_objects(
- cls,
- client,
- project_id: str,
- name: str,
- predictions: Union[Iterable[Dict], Iterable["Label"]],
- validate=True,
- ) -> "BulkImportRequest":
- """
- Creates a `BulkImportRequest` from an iterable of dictionaries.
-
- Conforms to JSON predictions format, e.g.:
- ``{
- "uuid": "9fd9a92e-2560-4e77-81d4-b2e955800092",
- "schemaId": "ckappz7d700gn0zbocmqkwd9i",
- "dataRow": {
- "id": "ck1s02fqxm8fi0757f0e6qtdc"
- },
- "bbox": {
- "top": 48,
- "left": 58,
- "height": 865,
- "width": 1512
- }
- }``
-
- Args:
- client (Client): a Labelbox client
- project_id (str): id of project for which predictions will be imported
- name (str): name of BulkImportRequest
- predictions (Iterable[dict]): iterable of dictionaries representing predictions
- validate (bool): a flag indicating if there should be a validation
- if `predictions` is valid ndjson
- Returns:
- BulkImportRequest object
- """
- if not isinstance(predictions, list):
- raise TypeError(
- f"annotations must be in a form of Iterable. Found {type(predictions)}"
- )
- ndjson_predictions = serialize_labels(predictions)
-
- if validate:
- _validate_ndjson(ndjson_predictions, client.get_project(project_id))
-
- data_str = parser.dumps(ndjson_predictions)
- if not data_str:
- raise ValueError("annotations cannot be empty")
-
- data = data_str.encode("utf-8")
- file_name = _make_file_name(project_id, name)
- request_data = _make_request_data(
- project_id, name, len(data_str), file_name
- )
- file_data = (file_name, data, NDJSON_MIME_TYPE)
- response_data = _send_create_file_command(
- client,
- request_data=request_data,
- file_name=file_name,
- file_data=file_data,
- )
-
- return cls(client, response_data["createBulkImportRequest"])
-
- @classmethod
- def create_from_local_file(
- cls, client, project_id: str, name: str, file: Path, validate_file=True
- ) -> "BulkImportRequest":
- """
- Creates a BulkImportRequest from a local ndjson file with predictions.
-
- Args:
- client (Client): a Labelbox client
- project_id (str): id of project for which predictions will be imported
- name (str): name of BulkImportRequest
- file (Path): local ndjson file with predictions
- validate_file (bool): a flag indicating if there should be a validation
- if `file` is a valid ndjson file
- Returns:
- BulkImportRequest object
-
- """
- file_name = _make_file_name(project_id, name)
- content_length = file.stat().st_size
- request_data = _make_request_data(
- project_id, name, content_length, file_name
- )
-
- with file.open("rb") as f:
- if validate_file:
- reader = parser.reader(f)
- # ensure that the underlying json load call is valid
- # https://github.com/rhgrant10/ndjson/blob/ff2f03c56b21f28f7271b27da35ca4a8bf9a05d0/ndjson/api.py#L53
- # by iterating through the file so we only store
- # each line in memory rather than the entire file
- try:
- _validate_ndjson(reader, client.get_project(project_id))
- except ValueError:
- raise ValueError(f"{file} is not a valid ndjson file")
- else:
- f.seek(0)
- file_data = (file.name, f, NDJSON_MIME_TYPE)
- response_data = _send_create_file_command(
- client, request_data, file_name, file_data
- )
- return cls(client, response_data["createBulkImportRequest"])
-
- def delete(self) -> None:
- """Deletes the import job and also any annotations created by this import.
-
- Returns:
- None
- """
- id_param = "bulk_request_id"
- query_str = """mutation deleteBulkImportRequestPyApi($%s: ID!) {
- deleteBulkImportRequest(where: {id: $%s}) {
- id
- name
- }
- }""" % (id_param, id_param)
- self.client.execute(query_str, {id_param: self.uid})
-
-
-def _validate_ndjson(
- lines: Iterable[Dict[str, Any]], project: "Project"
-) -> None:
- """
- Client side validation of an ndjson object.
-
- Does not guarentee that an upload will succeed for the following reasons:
- * We are not checking the data row types which will cause the following errors to slip through
- * Missing frame indices will not causes an error for videos
- * Uploaded annotations for the wrong data type will pass (Eg. entity on images)
- * We are not checking bounds of an asset (Eg. frame index, image height, text location)
-
- Args:
- lines (Iterable[Dict[str,Any]]): An iterable of ndjson lines
- project (Project): id of project for which predictions will be imported
-
- Raises:
- MALValidationError: Raise for invalid NDJson
- UuidError: Duplicate UUID in upload
- """
- feature_schemas_by_id, feature_schemas_by_name = get_mal_schemas(
- project.ontology()
- )
- uids: Set[str] = set()
- for idx, line in enumerate(lines):
- try:
- annotation = NDAnnotation(**line)
- annotation.validate_instance(
- feature_schemas_by_id, feature_schemas_by_name
- )
- uuid = str(annotation.uuid)
- if uuid in uids:
- raise lb_exceptions.UuidError(
- f"{uuid} already used in this import job, "
- "must be unique for the project."
- )
- uids.add(uuid)
- except (ValidationError, ValueError, TypeError, KeyError) as e:
- raise lb_exceptions.MALValidationError(
- f"Invalid NDJson on line {idx}"
- ) from e
-
-
-# The rest of this file contains objects for MAL validation
-def parse_classification(tool):
- """
- Parses a classification from an ontology. Only radio, checklist, and text are supported for mal
-
- Args:
- tool (dict)
-
- Returns:
- dict
- """
- if tool["type"] in ["radio", "checklist"]:
- option_schema_ids = [r["featureSchemaId"] for r in tool["options"]]
- option_names = [r["value"] for r in tool["options"]]
- return {
- "tool": tool["type"],
- "featureSchemaId": tool["featureSchemaId"],
- "name": tool["name"],
- "options": [*option_schema_ids, *option_names],
- }
- elif tool["type"] == "text":
- return {
- "tool": tool["type"],
- "name": tool["name"],
- "featureSchemaId": tool["featureSchemaId"],
- }
-
-
-def get_mal_schemas(ontology):
- """
- Converts a project ontology to a dict for easier lookup during ndjson validation
-
- Args:
- ontology (Ontology)
- Returns:
- Dict, Dict : Useful for looking up a tool from a given feature schema id or name
- """
-
- valid_feature_schemas_by_schema_id = {}
- valid_feature_schemas_by_name = {}
- for tool in ontology.normalized["tools"]:
- classifications = [
- parse_classification(classification_tool)
- for classification_tool in tool["classifications"]
- ]
- classifications_by_schema_id = {
- v["featureSchemaId"]: v for v in classifications
- }
- classifications_by_name = {v["name"]: v for v in classifications}
- valid_feature_schemas_by_schema_id[tool["featureSchemaId"]] = {
- "tool": tool["tool"],
- "classificationsBySchemaId": classifications_by_schema_id,
- "classificationsByName": classifications_by_name,
- "name": tool["name"],
- }
- valid_feature_schemas_by_name[tool["name"]] = {
- "tool": tool["tool"],
- "classificationsBySchemaId": classifications_by_schema_id,
- "classificationsByName": classifications_by_name,
- "name": tool["name"],
- }
- for tool in ontology.normalized["classifications"]:
- valid_feature_schemas_by_schema_id[tool["featureSchemaId"]] = (
- parse_classification(tool)
- )
- valid_feature_schemas_by_name[tool["name"]] = parse_classification(tool)
- return valid_feature_schemas_by_schema_id, valid_feature_schemas_by_name
-
-
-class Bbox(BaseModel):
- top: float
- left: float
- height: float
- width: float
-
-
-class Point(BaseModel):
- x: float
- y: float
-
-
-class FrameLocation(BaseModel):
- end: int
- start: int
-
-
-class VideoSupported(BaseModel):
- # Note that frames are only allowed as top level inferences for video
- frames: Optional[List[FrameLocation]] = None
-
-
-# Base class for a special kind of union.
-class SpecialUnion:
- def __new__(cls, **kwargs):
- return cls.build(kwargs)
-
- @classmethod
- def __get_validators__(cls):
- yield cls.build
-
- @classmethod
- def get_union_types(cls):
- if not issubclass(cls, SpecialUnion):
- raise TypeError("{} must be a subclass of SpecialUnion")
-
- union_types = [x for x in cls.__orig_bases__ if hasattr(x, "__args__")]
- if len(union_types) < 1:
- raise TypeError(
- "Class {cls} should inherit from a union of objects to build"
- )
- if len(union_types) > 1:
- raise TypeError(
- f"Class {cls} should inherit from exactly one union of objects to build. Found {union_types}"
- )
- return union_types[0].__args__[0].__args__
-
- @classmethod
- def build(cls: Any, data: Union[dict, BaseModel]) -> "NDBase":
- """
- Checks through all objects in the union to see which matches the input data.
- Args:
- data (Union[dict, BaseModel]) : The data for constructing one of the objects in the union
- raises:
- KeyError: data does not contain the determinant fields for any of the types supported by this SpecialUnion
- ValidationError: Error while trying to construct a specific object in the union
-
- """
- if isinstance(data, BaseModel):
- data = data.model_dump()
-
- top_level_fields = []
- max_match = 0
- matched = None
-
- for type_ in cls.get_union_types():
- determinate_fields = _determinants(type_)
- top_level_fields.append(determinate_fields)
- matches = sum([val in determinate_fields for val in data])
- if matches == len(determinate_fields) and matches > max_match:
- max_match = matches
- matched = type_
-
- if matched is not None:
- # These two have the exact same top level keys
- if matched in [NDRadio, NDText]:
- if isinstance(data["answer"], dict):
- matched = NDRadio
- elif isinstance(data["answer"], str):
- matched = NDText
- else:
- raise TypeError(
- f"Unexpected type for answer field. Found {data['answer']}. Expected a string or a dict"
- )
- return matched(**data)
- else:
- raise KeyError(
- f"Invalid annotation. Must have one of the following keys : {top_level_fields}. Found {data}."
- )
-
- @classmethod
- def schema(cls):
- results = {"definitions": {}}
- for cl in cls.get_union_types():
- schema = cl.schema()
- results["definitions"].update(schema.pop("definitions"))
- results[cl.__name__] = schema
- return results
-
-
-class DataRow(BaseModel):
- id: str
-
-
-class NDFeatureSchema(BaseModel):
- schemaId: Optional[str] = None
- name: Optional[str] = None
-
- @model_validator(mode="after")
- def most_set_one(self):
- if self.schemaId is None and self.name is None:
- raise ValueError(
- "Must set either schemaId or name for all feature schemas"
- )
- return self
-
-
-class NDBase(NDFeatureSchema):
- ontology_type: str
- uuid: UUID
- dataRow: DataRow
- model_config = ConfigDict(extra="forbid")
-
- def validate_feature_schemas(
- self, valid_feature_schemas_by_id, valid_feature_schemas_by_name
- ):
- if self.name:
- if self.name not in valid_feature_schemas_by_name:
- raise ValueError(
- f"Name {self.name} is not valid for the provided project's ontology."
- )
-
- if (
- self.ontology_type
- != valid_feature_schemas_by_name[self.name]["tool"]
- ):
- raise ValueError(
- f"Name {self.name} does not map to the assigned tool {valid_feature_schemas_by_name[self.name]['tool']}"
- )
-
- if self.schemaId:
- if self.schemaId not in valid_feature_schemas_by_id:
- raise ValueError(
- f"Schema id {self.schemaId} is not valid for the provided project's ontology."
- )
-
- if (
- self.ontology_type
- != valid_feature_schemas_by_id[self.schemaId]["tool"]
- ):
- raise ValueError(
- f"Schema id {self.schemaId} does not map to the assigned tool {valid_feature_schemas_by_id[self.schemaId]['tool']}"
- )
-
- def validate_instance(
- self, valid_feature_schemas_by_id, valid_feature_schemas_by_name
- ):
- self.validate_feature_schemas(
- valid_feature_schemas_by_id, valid_feature_schemas_by_name
- )
-
-
-###### Classifications ######
-
-
-class NDText(NDBase):
- ontology_type: Literal["text"] = "text"
- answer: str = Field(json_schema_extra={"determinant": True})
- # No feature schema to check
-
-
-class NDChecklist(VideoSupported, NDBase):
- ontology_type: Literal["checklist"] = "checklist"
- answers: List[NDFeatureSchema] = Field(
- json_schema_extra={"determinant": True}
- )
-
- @field_validator("answers", mode="before")
- def validate_answers(cls, value, field):
- # constr not working with mypy.
- if not len(value):
- raise ValueError("Checklist answers should not be empty")
- return value
-
- def validate_feature_schemas(
- self, valid_feature_schemas_by_id, valid_feature_schemas_by_name
- ):
- # Test top level feature schema for this tool
- super(NDChecklist, self).validate_feature_schemas(
- valid_feature_schemas_by_id, valid_feature_schemas_by_name
- )
- # Test the feature schemas provided to the answer field
- if len(
- set([answer.name or answer.schemaId for answer in self.answers])
- ) != len(self.answers):
- raise ValueError(
- f"Duplicated featureSchema found for checklist {self.uuid}"
- )
- for answer in self.answers:
- options = (
- valid_feature_schemas_by_name[self.name]["options"]
- if self.name
- else valid_feature_schemas_by_id[self.schemaId]["options"]
- )
- if answer.name not in options and answer.schemaId not in options:
- raise ValueError(
- f"Feature schema provided to {self.ontology_type} invalid. Expected on of {options}. Found {answer}"
- )
-
-
-class NDRadio(VideoSupported, NDBase):
- ontology_type: Literal["radio"] = "radio"
- answer: NDFeatureSchema = Field(json_schema_extra={"determinant": True})
-
- def validate_feature_schemas(
- self, valid_feature_schemas_by_id, valid_feature_schemas_by_name
- ):
- super(NDRadio, self).validate_feature_schemas(
- valid_feature_schemas_by_id, valid_feature_schemas_by_name
- )
- options = (
- valid_feature_schemas_by_name[self.name]["options"]
- if self.name
- else valid_feature_schemas_by_id[self.schemaId]["options"]
- )
- if (
- self.answer.name not in options
- and self.answer.schemaId not in options
- ):
- raise ValueError(
- f"Feature schema provided to {self.ontology_type} invalid. Expected on of {options}. Found {self.answer.name or self.answer.schemaId}"
- )
-
-
-# A union with custom construction logic to improve error messages
-class NDClassification(
- SpecialUnion,
- Type[Union[NDText, NDRadio, NDChecklist]], # type: ignore
-): ...
-
-
-###### Tools ######
-
-
-class NDBaseTool(NDBase):
- classifications: List[NDClassification] = []
-
- # This is indepdent of our problem
- def validate_feature_schemas(
- self, valid_feature_schemas_by_id, valid_feature_schemas_by_name
- ):
- super(NDBaseTool, self).validate_feature_schemas(
- valid_feature_schemas_by_id, valid_feature_schemas_by_name
- )
- for classification in self.classifications:
- classification.validate_feature_schemas(
- valid_feature_schemas_by_name[self.name][
- "classificationsBySchemaId"
- ]
- if self.name
- else valid_feature_schemas_by_id[self.schemaId][
- "classificationsBySchemaId"
- ],
- valid_feature_schemas_by_name[self.name][
- "classificationsByName"
- ]
- if self.name
- else valid_feature_schemas_by_id[self.schemaId][
- "classificationsByName"
- ],
- )
-
- @field_validator("classifications", mode="before")
- def validate_subclasses(cls, value, field):
- # Create uuid and datarow id so we don't have to define classification objects twice
- # This is caused by the fact that we require these ids for top level classifications but not for subclasses
- results = []
- dummy_id = "child".center(25, "_")
- for row in value:
- results.append(
- {**row, "dataRow": {"id": dummy_id}, "uuid": str(uuid4())}
- )
- return results
-
-
-class NDPolygon(NDBaseTool):
- ontology_type: Literal["polygon"] = "polygon"
- polygon: List[Point] = Field(json_schema_extra={"determinant": True})
-
- @field_validator("polygon")
- def is_geom_valid(cls, v):
- if len(v) < 3:
- raise ValueError(
- f"A polygon must have at least 3 points to be valid. Found {v}"
- )
- return v
-
-
-class NDPolyline(NDBaseTool):
- ontology_type: Literal["line"] = "line"
- line: List[Point] = Field(json_schema_extra={"determinant": True})
-
- @field_validator("line")
- def is_geom_valid(cls, v):
- if len(v) < 2:
- raise ValueError(
- f"A line must have at least 2 points to be valid. Found {v}"
- )
- return v
-
-
-class NDRectangle(NDBaseTool):
- ontology_type: Literal["rectangle"] = "rectangle"
- bbox: Bbox = Field(json_schema_extra={"determinant": True})
- # Could check if points are positive
-
-
-class NDPoint(NDBaseTool):
- ontology_type: Literal["point"] = "point"
- point: Point = Field(json_schema_extra={"determinant": True})
- # Could check if points are positive
-
-
-class EntityLocation(BaseModel):
- start: int
- end: int
-
-
-class NDTextEntity(NDBaseTool):
- ontology_type: Literal["named-entity"] = "named-entity"
- location: EntityLocation = Field(json_schema_extra={"determinant": True})
-
- @field_validator("location")
- def is_valid_location(cls, v):
- if isinstance(v, BaseModel):
- v = v.model_dump()
-
- if len(v) < 2:
- raise ValueError(
- f"A line must have at least 2 points to be valid. Found {v}"
- )
- if v["start"] < 0:
- raise ValueError(f"Text location must be positive. Found {v}")
- if v["start"] > v["end"]:
- raise ValueError(
- f"Text start location must be less or equal than end. Found {v}"
- )
- return v
-
-
-class RLEMaskFeatures(BaseModel):
- counts: List[int]
- size: List[int]
-
- @field_validator("counts")
- def validate_counts(cls, counts):
- if not all([count >= 0 for count in counts]):
- raise ValueError(
- "Found negative value for counts. They should all be zero or positive"
- )
- return counts
-
- @field_validator("size")
- def validate_size(cls, size):
- if len(size) != 2:
- raise ValueError(
- f"Mask `size` should have two ints representing height and with. Found : {size}"
- )
- if not all([count > 0 for count in size]):
- raise ValueError(
- f"Mask `size` should be a postitive int. Found : {size}"
- )
- return size
-
-
-class PNGMaskFeatures(BaseModel):
- # base64 encoded png bytes
- png: str
-
-
-class URIMaskFeatures(BaseModel):
- instanceURI: str
- colorRGB: Union[List[int], Tuple[int, int, int]]
-
- @field_validator("colorRGB")
- def validate_color(cls, colorRGB):
- # Does the dtype matter? Can it be a float?
- if not isinstance(colorRGB, (tuple, list)):
- raise ValueError(
- f"Received color that is not a list or tuple. Found : {colorRGB}"
- )
- elif len(colorRGB) != 3:
- raise ValueError(
- f"Must provide RGB values for segmentation colors. Found : {colorRGB}"
- )
- elif not all([0 <= color <= 255 for color in colorRGB]):
- raise ValueError(
- f"All rgb colors must be between 0 and 255. Found : {colorRGB}"
- )
- return colorRGB
-
-
-class NDMask(NDBaseTool):
- ontology_type: Literal["superpixel"] = "superpixel"
- mask: Union[URIMaskFeatures, PNGMaskFeatures, RLEMaskFeatures] = Field(
- json_schema_extra={"determinant": True}
- )
-
-
-# A union with custom construction logic to improve error messages
-class NDTool(
- SpecialUnion,
- Type[ # type: ignore
- Union[
- NDMask,
- NDTextEntity,
- NDPoint,
- NDRectangle,
- NDPolyline,
- NDPolygon,
- ]
- ],
-): ...
-
-
-class NDAnnotation(
- SpecialUnion,
- Type[Union[NDTool, NDClassification]], # type: ignore
-):
- @classmethod
- def build(cls: Any, data) -> "NDBase":
- if not isinstance(data, dict):
- raise ValueError("value must be dict")
- errors = []
- for cl in cls.get_union_types():
- try:
- return cl(**data)
- except KeyError as e:
- errors.append(f"{cl.__name__}: {e}")
-
- raise ValueError(
- "Unable to construct any annotation.\n{}".format("\n".join(errors))
- )
-
- @classmethod
- def schema(cls):
- data = {"definitions": {}}
- for type_ in cls.get_union_types():
- schema_ = type_.schema()
- data["definitions"].update(schema_.pop("definitions"))
- data[type_.__name__] = schema_
- return data
diff --git a/libs/labelbox/src/labelbox/schema/catalog.py b/libs/labelbox/src/labelbox/schema/catalog.py
index 567bbd777..8d9646779 100644
--- a/libs/labelbox/src/labelbox/schema/catalog.py
+++ b/libs/labelbox/src/labelbox/schema/catalog.py
@@ -1,4 +1,5 @@
from typing import Any, Dict, List, Optional, Tuple, Union
+import warnings
from labelbox.orm.db_object import experimental
from labelbox.schema.export_filters import CatalogExportFilters, build_filters
@@ -45,6 +46,13 @@ def export_v2(
>>> task.wait_till_done()
>>> task.result
"""
+
+ warnings.warn(
+ "You are currently utilizing export_v2 for this action, which will be removed in 7.0. Please refer to our docs for export alternatives. https://docs.labelbox.com/reference/export-overview#export-methods",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+
task, is_streamable = self._export(task_name, filters, params)
if is_streamable:
return ExportTask(task, True)
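The export_v2 deprecation notices added in this patch rely only on Python's standard warnings machinery; stacklevel=2 attributes the warning to the caller of export_v2 rather than to the SDK internals. A minimal, self-contained sketch of the same pattern (the export_v2 function below is a stand-in for illustration, not the SDK method):

import warnings

def export_v2(*args, **kwargs):
    """Stand-in for the SDK method; only the warning mechanics are shown here."""
    # stacklevel=2 attributes the warning to the caller's line rather than to this one
    warnings.warn(
        "export_v2 will be removed in 7.0; see the docs for export alternatives.",
        DeprecationWarning,
        stacklevel=2,
    )

with warnings.catch_warnings(record=True) as caught:
    # DeprecationWarning is ignored by default outside __main__, so opt in explicitly
    warnings.simplefilter("always")
    export_v2()
    print(caught[0].category.__name__)  # DeprecationWarning
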
diff --git a/libs/labelbox/src/labelbox/schema/data_row.py b/libs/labelbox/src/labelbox/schema/data_row.py
index 8987a00f0..cb0e99b22 100644
--- a/libs/labelbox/src/labelbox/schema/data_row.py
+++ b/libs/labelbox/src/labelbox/schema/data_row.py
@@ -2,6 +2,7 @@
from enum import Enum
from typing import TYPE_CHECKING, List, Optional, Tuple, Union, Any
import json
+import warnings
from labelbox.orm import query
from labelbox.orm.db_object import (
@@ -277,6 +278,13 @@ def export_v2(
>>> task.wait_till_done()
>>> task.result
"""
+
+ warnings.warn(
+ "You are currently utilizing export_v2 for this action, which will be removed in 7.0. Please refer to our docs for export alternatives. https://docs.labelbox.com/reference/export-overview#export-methods",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+
task, is_streamable = DataRow._export(
client, data_rows, global_keys, task_name, params
)
diff --git a/libs/labelbox/src/labelbox/schema/data_row_metadata.py b/libs/labelbox/src/labelbox/schema/data_row_metadata.py
index 288459a89..cb45ef57f 100644
--- a/libs/labelbox/src/labelbox/schema/data_row_metadata.py
+++ b/libs/labelbox/src/labelbox/schema/data_row_metadata.py
@@ -1,34 +1,35 @@
# type: ignore
-from datetime import datetime
+import warnings
from copy import deepcopy
+from datetime import datetime
from enum import Enum
from itertools import chain
-import warnings
-
from typing import (
+ Annotated,
+ Any,
+ Callable,
+ Dict,
+ Generator,
List,
Optional,
- Dict,
- Union,
- Callable,
Type,
- Any,
- Generator,
+ Union,
overload,
)
-from typing_extensions import Annotated
-from labelbox.schema.identifiables import DataRowIdentifiers, UniqueIds
-from labelbox.schema.identifiable import UniqueId, GlobalKey
from pydantic import (
BaseModel,
+ BeforeValidator,
+ ConfigDict,
Field,
StringConstraints,
conlist,
- ConfigDict,
model_serializer,
)
+from typing_extensions import Annotated
+from labelbox.schema.identifiable import GlobalKey, UniqueId
+from labelbox.schema.identifiables import DataRowIdentifiers, UniqueIds
from labelbox.schema.ontology import SchemaId
from labelbox.utils import (
_CamelCaseMixin,
@@ -36,6 +37,12 @@
format_iso_from_string,
)
+Name = Annotated[
+ str,
+ BeforeValidator(lambda x: str.strip(str(x))),
+ Field(min_length=1, max_length=100),
+]
+
class DataRowMetadataKind(Enum):
number = "CustomMetadataNumber"
@@ -49,7 +56,7 @@ class DataRowMetadataKind(Enum):
# Metadata schema
class DataRowMetadataSchema(BaseModel):
uid: SchemaId
- name: str = Field(strip_whitespace=True, min_length=1, max_length=100)
+ name: Name
reserved: bool
kind: DataRowMetadataKind
options: Optional[List["DataRowMetadataSchema"]] = None
@@ -417,7 +424,7 @@ def update_enum_option(
schema = self._validate_custom_schema_by_name(name)
if schema.kind != DataRowMetadataKind.enum:
raise ValueError(
- f"Updating Enum option is only supported for Enum metadata schema"
+ "Updating Enum option is only supported for Enum metadata schema"
)
valid_options: List[str] = [o.name for o in schema.options]
@@ -666,10 +673,8 @@ def bulk_delete(
if not len(deletes):
raise ValueError("The 'deletes' list cannot be empty.")
- passed_strings = False
for i, delete in enumerate(deletes):
if isinstance(delete.data_row_id, str):
- passed_strings = True
deletes[i] = DeleteDataRowMetadata(
data_row_id=UniqueId(delete.data_row_id),
fields=delete.fields,
@@ -683,12 +688,6 @@ def bulk_delete(
f"Invalid data row identifier type '{type(delete.data_row_id)}' for '{delete.data_row_id}'"
)
- if passed_strings:
- warnings.warn(
- "Using string for data row id will be deprecated. Please use "
- "UniqueId instead."
- )
-
def _batch_delete(
deletes: List[_DeleteBatchDataRowMetadata],
) -> List[DataRowMetadataBatchResponse]:
@@ -751,10 +750,6 @@ def bulk_export(self, data_row_ids) -> List[DataRowMetadata]:
and isinstance(data_row_ids[0], str)
):
data_row_ids = UniqueIds(data_row_ids)
- warnings.warn(
- "Using data row ids will be deprecated. Please use "
- "UniqueIds or GlobalKeys instead."
- )
def _bulk_export(
_data_row_ids: DataRowIdentifiers,
@@ -803,13 +798,13 @@ def _convert_metadata_field(metadata_field):
if isinstance(metadata_field, DataRowMetadataField):
return metadata_field
elif isinstance(metadata_field, dict):
- if not "value" in metadata_field:
+ if "value" not in metadata_field:
raise ValueError(
f"Custom metadata field '{metadata_field}' must have a 'value' key"
)
if (
- not "schema_id" in metadata_field
- and not "name" in metadata_field
+ "schema_id" not in metadata_field
+ and "name" not in metadata_field
):
raise ValueError(
f"Custom metadata field '{metadata_field}' must have either 'schema_id' or 'name' key"
@@ -954,9 +949,8 @@ def _validate_custom_schema_by_name(
def _batch_items(iterable: List[Any], size: int) -> Generator[Any, None, None]:
- l = len(iterable)
- for ndx in range(0, l, size):
- yield iterable[ndx : min(ndx + size, l)]
+ for ndx in range(0, len(iterable), size):
+ yield iterable[ndx : min(ndx + size, len(iterable))]
def _batch_operations(
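The Name alias introduced in data_row_metadata.py replaces pydantic v1's Field(strip_whitespace=True) with pydantic v2's Annotated metadata: a BeforeValidator strips the input before the length constraints run. A small standalone sketch of that pattern (the model name below is illustrative, not part of the SDK):

from typing import Annotated

from pydantic import BaseModel, BeforeValidator, Field, ValidationError

Name = Annotated[
    str,
    BeforeValidator(lambda x: str(x).strip()),
    Field(min_length=1, max_length=100),
]

class MetadataSchemaSketch(BaseModel):
    name: Name

print(MetadataSchemaSketch(name="  tags  ").name)  # "tags"

try:
    MetadataSchemaSketch(name="   ")  # stripped to "", so min_length=1 fails
except ValidationError as exc:
    print(exc.errors()[0]["type"])  # string_too_short
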
diff --git a/libs/labelbox/src/labelbox/schema/dataset.py b/libs/labelbox/src/labelbox/schema/dataset.py
index 16c993dfa..107f3f50b 100644
--- a/libs/labelbox/src/labelbox/schema/dataset.py
+++ b/libs/labelbox/src/labelbox/schema/dataset.py
@@ -7,13 +7,14 @@
from string import Template
from typing import Any, Dict, List, Optional, Tuple, Union
-import labelbox.schema.internal.data_row_uploader as data_row_uploader
-from labelbox.exceptions import (
+from lbox.exceptions import (
InvalidQueryError,
LabelboxError,
ResourceCreationError,
ResourceNotFoundError,
-)
+) # type: ignore
+
+import labelbox.schema.internal.data_row_uploader as data_row_uploader
from labelbox.orm import query
from labelbox.orm.comparison import Comparison
from labelbox.orm.db_object import DbObject, Deletable, Updateable
@@ -165,49 +166,9 @@ def create_data_row(self, items=None, **kwargs) -> "DataRow":
return self.client.get_data_row(res[0]["id"])
- def create_data_rows_sync(
- self, items, file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT
- ) -> None:
- """Synchronously bulk upload data rows.
-
- Use this instead of `Dataset.create_data_rows` for smaller batches of data rows that need to be uploaded quickly.
- Cannot use this for uploads containing more than 1000 data rows.
- Each data row is also limited to 5 attachments.
-
- Args:
- items (iterable of (dict or str)):
- See the docstring for `Dataset._create_descriptor_file` for more information.
- Returns:
- None. If the function doesn't raise an exception then the import was successful.
-
- Raises:
- ResourceCreationError: If the `items` parameter does not conform to
- the specification in Dataset._create_descriptor_file or if the server did not accept the
- DataRow creation request (unknown reason).
- InvalidAttributeError: If there are fields in `items` not valid for
- a DataRow.
- ValueError: When the upload parameters are invalid
- """
- warnings.warn(
- "This method is deprecated and will be "
- "removed in a future release. Please use create_data_rows instead."
- )
-
- self._create_data_rows_sync(
- items, file_upload_thread_count=file_upload_thread_count
- )
-
- return None # Return None if no exception is raised
-
def _create_data_rows_sync(
self, items, file_upload_thread_count=FILE_UPLOAD_THREAD_COUNT
) -> "DataUpsertTask":
- max_data_rows_supported = 1000
- if len(items) > max_data_rows_supported:
- raise ValueError(
- f"Dataset.create_data_rows_sync() supports a max of {max_data_rows_supported} data rows."
- " For larger imports use the async function Dataset.create_data_rows()"
- )
if file_upload_thread_count < 1:
raise ValueError(
"file_upload_thread_count must be a positive integer"
@@ -234,8 +195,6 @@ def create_data_rows(
) -> "DataUpsertTask":
"""Asynchronously bulk upload data rows
- Use this instead of `Dataset.create_data_rows_sync` uploads for batches that contain more than 1000 data rows.
-
Args:
items (iterable of (dict or str))
@@ -310,7 +269,7 @@ def data_rows_for_external_id(
A list of `DataRow` with the given ID.
Raises:
- labelbox.exceptions.ResourceNotFoundError: If there is no `DataRow`
+ lbox.exceptions.ResourceNotFoundError: If there is no `DataRow`
in this `DataSet` with the given external ID, or if there are
multiple `DataRows` for it.
"""
@@ -336,7 +295,7 @@ def data_row_for_external_id(self, external_id) -> "DataRow":
A single `DataRow` with the given ID.
Raises:
- labelbox.exceptions.ResourceNotFoundError: If there is no `DataRow`
+ lbox.exceptions.ResourceNotFoundError: If there is no `DataRow`
in this `DataSet` with the given external ID, or if there are
multiple `DataRows` for it.
"""
@@ -399,6 +358,13 @@ def export_v2(
>>> task.wait_till_done()
>>> task.result
"""
+
+ warnings.warn(
+ "You are currently utilizing export_v2 for this action, which will be removed in 7.0. Please refer to our docs for export alternatives. https://docs.labelbox.com/reference/export-overview#export-methods",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+
task, is_streamable = self._export(task_name, filters, params)
if is_streamable:
return ExportTask(task, True)
diff --git a/libs/labelbox/src/labelbox/schema/enums.py b/libs/labelbox/src/labelbox/schema/enums.py
index 6f8aebc58..dfc87c8a4 100644
--- a/libs/labelbox/src/labelbox/schema/enums.py
+++ b/libs/labelbox/src/labelbox/schema/enums.py
@@ -1,31 +1,6 @@
from enum import Enum
-class BulkImportRequestState(Enum):
- """State of the import job when importing annotations (RUNNING, FAILED, or FINISHED).
-
- If you are not usinig MEA continue using BulkImportRequest.
- AnnotationImports are in beta and will change soon.
-
- .. list-table::
- :widths: 15 150
- :header-rows: 1
-
- * - State
- - Description
- * - RUNNING
- - Indicates that the import job is not done yet.
- * - FAILED
- - Indicates the import job failed. Check `BulkImportRequest.errors` for more information
- * - FINISHED
- - Indicates the import job is no longer running. Check `BulkImportRequest.statuses` for more information
- """
-
- RUNNING = "RUNNING"
- FAILED = "FAILED"
- FINISHED = "FINISHED"
-
-
class AnnotationImportState(Enum):
"""State of the import job when importing annotations (RUNNING, FAILED, or FINISHED).
diff --git a/libs/labelbox/src/labelbox/schema/export_filters.py b/libs/labelbox/src/labelbox/schema/export_filters.py
index 641adc011..b0d0284c4 100644
--- a/libs/labelbox/src/labelbox/schema/export_filters.py
+++ b/libs/labelbox/src/labelbox/schema/export_filters.py
@@ -1,13 +1,5 @@
-import sys
-
from datetime import datetime, timezone
-from typing import Collection, Dict, Tuple, List, Optional
-from labelbox.typing_imports import Literal
-
-if sys.version_info >= (3, 8):
- from typing import TypedDict
-else:
- from typing_extensions import TypedDict
+from typing import Collection, Dict, List, Literal, Optional, Tuple, TypedDict
SEARCH_LIMIT_PER_EXPORT_V2 = 2_000
ISO_8061_FORMAT = "%Y-%m-%dT%H:%M:%S%z"
diff --git a/libs/labelbox/src/labelbox/schema/export_params.py b/libs/labelbox/src/labelbox/schema/export_params.py
index b15bc2828..f921cdd31 100644
--- a/libs/labelbox/src/labelbox/schema/export_params.py
+++ b/libs/labelbox/src/labelbox/schema/export_params.py
@@ -1,15 +1,8 @@
-import sys
-
-from typing import Optional, List
-
-EXPORT_LIMIT = 30
+from typing import List, Optional, TypedDict
from labelbox.schema.media_type import MediaType
-if sys.version_info >= (3, 8):
- from typing import TypedDict
-else:
- from typing_extensions import TypedDict
+EXPORT_LIMIT = 30
class DataRowParams(TypedDict):
diff --git a/libs/labelbox/src/labelbox/schema/export_task.py b/libs/labelbox/src/labelbox/schema/export_task.py
index 76fd8a739..7e78fc3e9 100644
--- a/libs/labelbox/src/labelbox/schema/export_task.py
+++ b/libs/labelbox/src/labelbox/schema/export_task.py
@@ -6,20 +6,16 @@
from dataclasses import dataclass
from enum import Enum
from functools import lru_cache
-from io import TextIOWrapper
-from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generic,
Iterator,
- List,
Optional,
Tuple,
TypeVar,
Union,
- overload,
)
import requests
@@ -100,122 +96,6 @@ def convert(self, input_args: ConverterInputArgs) -> Iterator[OutputT]:
"""
-@dataclass
-class JsonConverterOutput:
- """Output with the JSON string."""
-
- current_offset: int
- current_line: int
- json_str: str
-
-
-class JsonConverter(Converter[JsonConverterOutput]): # pylint: disable=too-few-public-methods
- """Converts JSON data.
-
- Deprecated: This converter is deprecated and will be removed in a future release.
- """
-
- def __init__(self) -> None:
- warnings.warn(
- "JSON converter is deprecated and will be removed in a future release"
- )
- super().__init__()
-
- def _find_json_object_offsets(self, data: str) -> List[Tuple[int, int]]:
- object_offsets: List[Tuple[int, int]] = []
- stack = []
- current_object_start = None
-
- for index, char in enumerate(data):
- if char == "{":
- stack.append(char)
- if len(stack) == 1:
- current_object_start = index
- # we need to account for scenarios where data lands in the middle of an object
- # and the object is not the last one in the data
- if (
- index > 0
- and data[index - 1] == "\n"
- and not object_offsets
- ):
- object_offsets.append((0, index - 1))
- elif char == "}" and stack:
- stack.pop()
- # this covers cases where the last object is either followed by a newline or
- # it is missing
- if (
- len(stack) == 0
- and (len(data) == index + 1 or data[index + 1] == "\n")
- and current_object_start is not None
- ):
- object_offsets.append((current_object_start, index + 1))
- current_object_start = None
-
- # we also need to account for scenarios where data lands in the middle of the last object
- return object_offsets if object_offsets else [(0, len(data) - 1)]
-
- def convert(
- self, input_args: Converter.ConverterInputArgs
- ) -> Iterator[JsonConverterOutput]:
- current_offset, current_line, raw_data = (
- input_args.file_info.offsets.start,
- input_args.file_info.lines.start,
- input_args.raw_data,
- )
- offsets = self._find_json_object_offsets(raw_data)
- for line, (offset_start, offset_end) in enumerate(offsets):
- yield JsonConverterOutput(
- current_offset=current_offset + offset_start,
- current_line=current_line + line,
- json_str=raw_data[offset_start : offset_end + 1].strip(),
- )
-
-
-@dataclass
-class FileConverterOutput:
- """Output with statistics about the written file."""
-
- file_path: Path
- total_size: int
- total_lines: int
- current_offset: int
- current_line: int
- bytes_written: int
-
-
-class FileConverter(Converter[FileConverterOutput]):
- """Converts data to a file."""
-
- def __init__(self, file_path: str) -> None:
- super().__init__()
- self._file: Optional[TextIOWrapper] = None
- self._file_path = file_path
-
- def __enter__(self):
- self._file = open(self._file_path, "w", encoding="utf-8")
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- if self._file:
- self._file.close()
- return False
-
- def convert(
- self, input_args: Converter.ConverterInputArgs
- ) -> Iterator[FileConverterOutput]:
- # appends data to the file
- assert self._file is not None
- self._file.write(input_args.raw_data)
- yield FileConverterOutput(
- file_path=Path(self._file_path),
- total_size=input_args.ctx.metadata_header.total_size,
- total_lines=input_args.ctx.metadata_header.total_lines,
- current_offset=input_args.file_info.offsets.start,
- current_line=input_args.file_info.lines.start,
- bytes_written=len(input_args.raw_data),
- )
-
-
class FileRetrieverStrategy(ABC): # pylint: disable=too-few-public-methods
"""Abstract class for retrieving files."""
@@ -252,147 +132,6 @@ def _get_file_content(
return file_info, response.text
-class FileRetrieverByOffset(FileRetrieverStrategy): # pylint: disable=too-few-public-methods
- """Retrieves files by offset."""
-
- def __init__(
- self,
- ctx: _TaskContext,
- offset: int,
- ) -> None:
- super().__init__(ctx)
- self._current_offset = offset
- self._current_line: Optional[int] = None
- if self._current_offset >= self._ctx.metadata_header.total_size:
- raise ValueError(
- f"offset is out of range, max offset is {self._ctx.metadata_header.total_size - 1}"
- )
-
- def _find_line_at_offset(
- self, file_content: str, target_offset: int
- ) -> int:
- # TODO: Remove this, incorrect parsing of JSON to find braces
- stack = []
- line_number = 0
-
- for index, char in enumerate(file_content):
- if char == "{":
- stack.append(char)
- if len(stack) == 1 and index > 0:
- line_number += 1
- elif char == "}" and stack:
- stack.pop()
-
- if index == target_offset:
- break
-
- return line_number
-
- def get_next_chunk(self) -> Optional[Tuple[_MetadataFileInfo, str]]:
- if self._current_offset >= self._ctx.metadata_header.total_size:
- return None
- query = (
- f"query GetExportFileFromOffsetPyApi"
- f"($where: WhereUniqueIdInput, $streamType: TaskStreamType!, $offset: UInt64!)"
- f"{{task(where: $where)"
- f"{{{'exportFileFromOffset'}(streamType: $streamType, offset: $offset)"
- f"{{offsets {{start end}} lines {{start end}} file}}"
- f"}}}}"
- )
- variables = {
- "where": {"id": self._ctx.task_id},
- "streamType": self._ctx.stream_type.value,
- "offset": str(self._current_offset),
- }
- file_info, file_content = self._get_file_content(
- query, variables, "exportFileFromOffset"
- )
- if self._current_line is None:
- self._current_line = self._find_line_at_offset(
- file_content, self._current_offset - file_info.offsets.start
- )
- self._current_line += file_info.lines.start
- file_content = file_content[
- self._current_offset - file_info.offsets.start :
- ]
- file_info.offsets.start = self._current_offset
- file_info.lines.start = self._current_line
- self._current_offset = file_info.offsets.end + 1
- self._current_line = file_info.lines.end + 1
- return file_info, file_content
-
-
-class FileRetrieverByLine(FileRetrieverStrategy): # pylint: disable=too-few-public-methods
- """Retrieves files by line."""
-
- def __init__(
- self,
- ctx: _TaskContext,
- line: int,
- ) -> None:
- super().__init__(ctx)
- self._current_line = line
- self._current_offset: Optional[int] = None
- if self._current_line >= self._ctx.metadata_header.total_lines:
- raise ValueError(
- f"line is out of range, max line is {self._ctx.metadata_header.total_lines - 1}"
- )
-
- def _find_offset_of_line(self, file_content: str, target_line: int):
- # TODO: Remove this, incorrect parsing of JSON to find braces
- start_offset = None
- stack = []
- line_number = 0
-
- for index, char in enumerate(file_content):
- if char == "{":
- stack.append(char)
- if len(stack) == 1:
- if line_number == target_line:
- start_offset = index
- line_number += 1
- elif char == "}" and stack:
- stack.pop()
-
- if line_number > target_line:
- break
-
- return start_offset
-
- def get_next_chunk(self) -> Optional[Tuple[_MetadataFileInfo, str]]:
- if self._current_line >= self._ctx.metadata_header.total_lines:
- return None
- query = (
- f"query GetExportFileFromLinePyApi"
- f"($where: WhereUniqueIdInput, $streamType: TaskStreamType!, $line: UInt64!)"
- f"{{task(where: $where)"
- f"{{{'exportFileFromLine'}(streamType: $streamType, line: $line)"
- f"{{offsets {{start end}} lines {{start end}} file}}"
- f"}}}}"
- )
- variables = {
- "where": {"id": self._ctx.task_id},
- "streamType": self._ctx.stream_type.value,
- "line": self._current_line,
- }
- file_info, file_content = self._get_file_content(
- query, variables, "exportFileFromLine"
- )
- if self._current_offset is None:
- self._current_offset = self._find_offset_of_line(
- file_content, self._current_line - file_info.lines.start
- )
- self._current_offset += file_info.offsets.start
- file_content = file_content[
- self._current_offset - file_info.offsets.start :
- ]
- file_info.offsets.start = self._current_offset
- file_info.lines.start = self._current_line
- self._current_offset = file_info.offsets.end + 1
- self._current_line = file_info.lines.end + 1
- return file_info, file_content
-
-
class _Reader(ABC): # pylint: disable=too-few-public-methods
"""Abstract class for reading data from a source."""
@@ -405,94 +144,6 @@ def read(self) -> Iterator[Tuple[_MetadataFileInfo, str]]:
"""Reads data from the source."""
-class _MultiGCSFileReader(_Reader): # pylint: disable=too-few-public-methods
- """Reads data from multiple GCS files in a seamless way.
-
- Deprecated: This reader is deprecated and will be removed in a future release.
- """
-
- def __init__(self):
- warnings.warn(
- "_MultiGCSFileReader is deprecated and will be removed in a future release"
- )
- super().__init__()
- self._retrieval_strategy = None
-
- def set_retrieval_strategy(self, strategy: FileRetrieverStrategy) -> None:
- """Sets the retrieval strategy."""
- self._retrieval_strategy = strategy
-
- def read(self) -> Iterator[Tuple[_MetadataFileInfo, str]]:
- if not self._retrieval_strategy:
- raise ValueError("retrieval strategy not set")
- result = self._retrieval_strategy.get_next_chunk()
- while result:
- file_info, raw_data = result
- yield file_info, raw_data
- result = self._retrieval_strategy.get_next_chunk()
-
-
-class Stream(Generic[OutputT]):
- """Streams data from a Reader."""
-
- def __init__(
- self,
- ctx: _TaskContext,
- reader: _Reader,
- converter: Converter,
- ):
- self._ctx = ctx
- self._reader = reader
- self._converter = converter
- # default strategy is to retrieve files by offset, starting from 0
- self.with_offset(0)
-
- def __iter__(self):
- yield from self._fetch()
-
- def _fetch(
- self,
- ) -> Iterator[OutputT]:
- """Fetches the result data.
- Returns an iterator that yields the offset and the data.
- """
- if self._ctx.metadata_header.total_size is None:
- return
-
- stream = self._reader.read()
- with self._converter as converter:
- for file_info, raw_data in stream:
- for output in converter.convert(
- Converter.ConverterInputArgs(self._ctx, file_info, raw_data)
- ):
- yield output
-
- def with_offset(self, offset: int) -> "Stream[OutputT]":
- """Sets the offset for the stream."""
- self._reader.set_retrieval_strategy(
- FileRetrieverByOffset(self._ctx, offset)
- )
- return self
-
- def with_line(self, line: int) -> "Stream[OutputT]":
- """Sets the line number for the stream."""
- self._reader.set_retrieval_strategy(
- FileRetrieverByLine(self._ctx, line)
- )
- return self
-
- def start(
- self, stream_handler: Optional[Callable[[OutputT], None]] = None
- ) -> None:
- """Starts streaming the result data.
- Calls the stream_handler for each result.
- """
- # this calls the __iter__ method, which in turn calls the _fetch method
- for output in self:
- if stream_handler:
- stream_handler(output)
-
-
class _BufferedFileRetrieverByOffset(FileRetrieverStrategy): # pylint: disable=too-few-public-methods
"""Retrieves files by offset."""
@@ -926,53 +577,6 @@ def get_buffered_stream(
),
)
- @overload
- def get_stream(
- self,
- converter: JsonConverter,
- stream_type: StreamType = StreamType.RESULT,
- ) -> Stream[JsonConverterOutput]:
- """Overload for getting the right typing hints when using a JsonConverter."""
-
- @overload
- def get_stream(
- self,
- converter: FileConverter,
- stream_type: StreamType = StreamType.RESULT,
- ) -> Stream[FileConverterOutput]:
- """Overload for getting the right typing hints when using a FileConverter."""
-
- def get_stream(
- self,
- converter: Optional[Converter] = None,
- stream_type: StreamType = StreamType.RESULT,
- ) -> Stream:
- warnings.warn(
- "get_stream is deprecated and will be removed in a future release, use get_buffered_stream"
- )
- if converter is None:
- converter = JsonConverter()
- """Returns the result of the task."""
- if self._task.status == "FAILED":
- raise ExportTask.ExportTaskException("Task failed")
- if self._task.status != "COMPLETE":
- raise ExportTask.ExportTaskException("Task is not ready yet")
-
- metadata_header = self._get_metadata_header(
- self._task.client, self._task.uid, stream_type
- )
- if metadata_header is None:
- raise ValueError(
- f"Task {self._task.uid} does not have a {stream_type.value} stream"
- )
- return Stream(
- _TaskContext(
- self._task.client, self._task.uid, stream_type, metadata_header
- ),
- _MultiGCSFileReader(),
- converter,
- )
-
@staticmethod
def get_task(client, task_id):
"""Returns the task with the given id."""
diff --git a/libs/labelbox/src/labelbox/schema/foundry/foundry_client.py b/libs/labelbox/src/labelbox/schema/foundry/foundry_client.py
index 914a363c7..315fc831c 100644
--- a/libs/labelbox/src/labelbox/schema/foundry/foundry_client.py
+++ b/libs/labelbox/src/labelbox/schema/foundry/foundry_client.py
@@ -1,6 +1,8 @@
from typing import Union
-from labelbox import exceptions
-from labelbox.schema.foundry.app import App, APP_FIELD_NAMES
+
+from lbox import exceptions # type: ignore
+
+from labelbox.schema.foundry.app import APP_FIELD_NAMES, App
from labelbox.schema.identifiables import DataRowIds, GlobalKeys, IdType
from labelbox.schema.task import Task
@@ -51,7 +53,7 @@ def _get_app(self, id: str) -> App:
try:
response = self.client.execute(query_str, params)
- except exceptions.InvalidQueryError as e:
+ except exceptions.InvalidQueryError:
raise exceptions.ResourceNotFoundError(App, params)
except Exception as e:
raise exceptions.LabelboxError(f"Unable to get app with id {id}", e)
diff --git a/libs/labelbox/src/labelbox/schema/internal/descriptor_file_creator.py b/libs/labelbox/src/labelbox/schema/internal/descriptor_file_creator.py
index ce3ce4b35..a9c1c250c 100644
--- a/libs/labelbox/src/labelbox/schema/internal/descriptor_file_creator.py
+++ b/libs/labelbox/src/labelbox/schema/internal/descriptor_file_creator.py
@@ -1,25 +1,21 @@
import json
import os
-import sys
from concurrent.futures import ThreadPoolExecutor, as_completed
+from typing import TYPE_CHECKING, Generator, Iterable, List
-from typing import Iterable, List, Generator
+from lbox.exceptions import (
+ InvalidAttributeError,
+ InvalidQueryError,
+) # type: ignore
-from labelbox.exceptions import InvalidQueryError
-from labelbox.exceptions import InvalidAttributeError
-from labelbox.exceptions import MalformedQueryException
-from labelbox.orm.model import Entity
-from labelbox.orm.model import Field
+from labelbox.orm.model import Entity, Field
from labelbox.schema.embedding import EmbeddingVector
-from labelbox.schema.internal.datarow_upload_constants import (
- FILE_UPLOAD_THREAD_COUNT,
-)
from labelbox.schema.internal.data_row_upsert_item import (
DataRowItemBase,
- DataRowUpsertItem,
)
-
-from typing import TYPE_CHECKING
+from labelbox.schema.internal.datarow_upload_constants import (
+ FILE_UPLOAD_THREAD_COUNT,
+)
if TYPE_CHECKING:
from labelbox import Client
@@ -161,7 +157,7 @@ def check_message_keys(message):
]
)
for key in message.keys():
- if not key in accepted_message_keys:
+ if key not in accepted_message_keys:
raise KeyError(
f"Invalid {key} key found! Accepted keys in messages list is {accepted_message_keys}"
)
diff --git a/libs/labelbox/src/labelbox/schema/labeling_service.py b/libs/labelbox/src/labelbox/schema/labeling_service.py
index 0fbd15bd1..0b7dab6bd 100644
--- a/libs/labelbox/src/labelbox/schema/labeling_service.py
+++ b/libs/labelbox/src/labelbox/schema/labeling_service.py
@@ -1,13 +1,12 @@
+import json
from datetime import datetime
from typing import Any
-from typing_extensions import Annotated
-from pydantic import BaseModel, Field
+from lbox.exceptions import LabelboxError, ResourceNotFoundError
-from labelbox.exceptions import ResourceNotFoundError
-from labelbox.utils import _CamelCaseMixin
from labelbox.schema.labeling_service_dashboard import LabelingServiceDashboard
from labelbox.schema.labeling_service_status import LabelingServiceStatus
+from labelbox.utils import _CamelCaseMixin
from ..annotated_types import Cuid
@@ -107,12 +106,24 @@ def request(self) -> "LabelingService":
query_str,
{"projectId": self.project_id},
raise_return_resource_not_found=True,
+ error_handlers={"MALFORMED_REQUEST": self._raise_readable_errors},
)
success = result["validateAndRequestProjectBoostWorkforce"]["success"]
if not success:
raise Exception("Failed to start labeling service")
return LabelingService.get(self.client, self.project_id)
+ def _raise_readable_errors(self, response):
+ errors = response.json().get("errors", [])
+ if errors:
+            message = errors[0].get("errors", [{"error": "Unknown error"}])
+ error_messages = [error["error"] for error in message]
+ else:
+ error_messages = ["Uknown error"]
+ raise LabelboxError(". ".join(error_messages))
+
@classmethod
def getOrCreate(cls, client, project_id: Cuid) -> "LabelingService":
"""
diff --git a/libs/labelbox/src/labelbox/schema/labeling_service_dashboard.py b/libs/labelbox/src/labelbox/schema/labeling_service_dashboard.py
index c5e1fa11e..2f91af7af 100644
--- a/libs/labelbox/src/labelbox/schema/labeling_service_dashboard.py
+++ b/libs/labelbox/src/labelbox/schema/labeling_service_dashboard.py
@@ -1,16 +1,17 @@
-from string import Template
from datetime import datetime
+from string import Template
from typing import Any, Dict, List, Optional, Union
-from labelbox.exceptions import ResourceNotFoundError
+from lbox.exceptions import ResourceNotFoundError
+from pydantic import BaseModel, Field, model_validator, model_serializer
+
from labelbox.pagination import PaginatedCollection
-from pydantic import BaseModel, model_validator, Field
+from labelbox.schema.labeling_service_status import LabelingServiceStatus
+from labelbox.schema.media_type import MediaType
from labelbox.schema.search_filters import SearchFilter, build_search_filter
-from labelbox.utils import _CamelCaseMixin
+from labelbox.utils import _CamelCaseMixin, sentence_case
+
from .ontology_kind import EditorTaskType
-from labelbox.schema.media_type import MediaType
-from labelbox.schema.labeling_service_status import LabelingServiceStatus
-from labelbox.utils import sentence_case
GRAPHQL_QUERY_SELECTIONS = """
id
@@ -49,7 +50,7 @@ class LabelingServiceDashboard(_CamelCaseMixin):
Represent labeling service data for a project
NOTE on tasks vs data rows. A task is a unit of work that is assigned to a user. A data row is a unit of data that needs to be labeled.
- In the current implementation a task reprsents a single data row. However tasks only exists when a labeler start labeling a data row.
+    In the current implementation a task represents a single data row. However, tasks only exist when a labeler starts labeling a data row.
So if a data row is not labeled, it will not have a task associated with it. Therefore the number of tasks can be less than the number of data rows.
Attributes:
@@ -220,8 +221,9 @@ def convert_boost_data(cls, data):
return data
- def dict(self, *args, **kwargs):
- row = super().dict(*args, **kwargs)
+ @model_serializer()
+ def ser_model(self):
+ row = self
row.pop("client")
row["service_type"] = self.service_type
return row
diff --git a/libs/labelbox/src/labelbox/schema/model_run.py b/libs/labelbox/src/labelbox/schema/model_run.py
index bc9971174..dcdbdf0e8 100644
--- a/libs/labelbox/src/labelbox/schema/model_run.py
+++ b/libs/labelbox/src/labelbox/schema/model_run.py
@@ -539,6 +539,12 @@ def export_v2(
>>> export_task = export_v2("my_export_task", params={"media_attributes": True})
"""
+
+ warnings.warn(
+ "You are currently utilizing export_v2 for this action, which will be removed in 7.0. Please refer to our docs for export alternatives. https://docs.labelbox.com/reference/export-overview#export-methods",
+ DeprecationWarning,
+ stacklevel=2,
+ )
task, is_streamable = self._export(task_name, params)
if is_streamable:
return ExportTask(task, True)
diff --git a/libs/labelbox/src/labelbox/schema/ontology.py b/libs/labelbox/src/labelbox/schema/ontology.py
index efe32611b..a3b388ef2 100644
--- a/libs/labelbox/src/labelbox/schema/ontology.py
+++ b/libs/labelbox/src/labelbox/schema/ontology.py
@@ -1,17 +1,17 @@
# type: ignore
import colorsys
+import json
+import warnings
from dataclasses import dataclass, field
from enum import Enum
-from typing import Any, Dict, List, Optional, Union, Type
-from typing_extensions import Annotated
-import warnings
+from typing import Annotated, Any, Dict, List, Optional, Type, Union
+
+from lbox.exceptions import InconsistentOntologyException
+from pydantic import StringConstraints
-from labelbox.exceptions import InconsistentOntologyException
from labelbox.orm.db_object import DbObject
from labelbox.orm.model import Field, Relationship
-import json
-from pydantic import StringConstraints
FeatureSchemaId: Type[str] = Annotated[
str, StringConstraints(min_length=25, max_length=25)
@@ -561,14 +561,18 @@ class OntologyBuilder:
There are no required instantiation arguments.
To create an ontology, use the asdict() method after fully building your
- ontology within this class, and inserting it into project.setup() as the
- "labeling_frontend_options" parameter.
+ ontology within this class, and inserting it into client.create_ontology() as the
+ "normalized" parameter.
Example:
- builder = OntologyBuilder()
- ...
- frontend = list(client.get_labeling_frontends())[0]
- project.setup(frontend, builder.asdict())
+ >>> builder = OntologyBuilder()
+ >>> ...
+ >>> ontology = client.create_ontology(
+ >>> "Ontology from new features",
+    >>> builder.asdict(),
+ >>> media_type=lb.MediaType.Image,
+ >>> )
+ >>> project.connect_ontology(ontology)
attributes:
tools: (list)
diff --git a/libs/labelbox/src/labelbox/schema/ontology_kind.py b/libs/labelbox/src/labelbox/schema/ontology_kind.py
index 3171b811e..79ef7d7a3 100644
--- a/libs/labelbox/src/labelbox/schema/ontology_kind.py
+++ b/libs/labelbox/src/labelbox/schema/ontology_kind.py
@@ -53,7 +53,7 @@ def evaluate_ontology_kind_with_media_type(
return media_type
-class EditorTaskType(Enum):
+class EditorTaskType(str, Enum):
ModelChatEvaluation = "MODEL_CHAT_EVALUATION"
ResponseCreation = "RESPONSE_CREATION"
OfflineModelChatEvaluation = "OFFLINE_MODEL_CHAT_EVALUATION"
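Switching EditorTaskType to a str/Enum mixin means members compare equal to, and JSON-serialize as, their raw string values, which is what GraphQL payloads expect. A quick stdlib-only illustration of the mixin behaviour (member subset shown for brevity):

import json
from enum import Enum

class EditorTaskType(str, Enum):
    ModelChatEvaluation = "MODEL_CHAT_EVALUATION"
    ResponseCreation = "RESPONSE_CREATION"

# Members are real strings, so comparisons and json.dumps use the raw value
assert EditorTaskType.ModelChatEvaluation == "MODEL_CHAT_EVALUATION"
print(json.dumps({"editorTaskType": EditorTaskType.ResponseCreation}))
# {"editorTaskType": "RESPONSE_CREATION"}
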
diff --git a/libs/labelbox/src/labelbox/schema/organization.py b/libs/labelbox/src/labelbox/schema/organization.py
index 71e715f11..bd416e997 100644
--- a/libs/labelbox/src/labelbox/schema/organization.py
+++ b/libs/labelbox/src/labelbox/schema/organization.py
@@ -1,21 +1,21 @@
-import json
-from typing import TYPE_CHECKING, List, Optional, Dict
+from typing import TYPE_CHECKING, Dict, List, Optional
+
+from lbox.exceptions import LabelboxError
-from labelbox.exceptions import LabelboxError
from labelbox import utils
-from labelbox.orm.db_object import DbObject, query, Entity
+from labelbox.orm.db_object import DbObject, Entity, query
from labelbox.orm.model import Field, Relationship
from labelbox.schema.invite import InviteLimit
from labelbox.schema.resource_tag import ResourceTag
if TYPE_CHECKING:
from labelbox import (
- Role,
- User,
- ProjectRole,
+ IAMIntegration,
Invite,
InviteLimit,
- IAMIntegration,
+ ProjectRole,
+ Role,
+ User,
)
diff --git a/libs/labelbox/src/labelbox/schema/project.py b/libs/labelbox/src/labelbox/schema/project.py
index 3d5f8ca92..a6f2dfe28 100644
--- a/libs/labelbox/src/labelbox/schema/project.py
+++ b/libs/labelbox/src/labelbox/schema/project.py
@@ -4,29 +4,27 @@
import warnings
from collections import namedtuple
from datetime import datetime, timezone
-from pathlib import Path
from string import Template
from typing import (
TYPE_CHECKING,
Any,
Dict,
- Iterable,
List,
Optional,
Tuple,
Union,
- overload,
+ get_args,
)
-from urllib.parse import urlparse
-from labelbox import utils
-from labelbox.exceptions import (
+from lbox.exceptions import (
InvalidQueryError,
LabelboxError,
ProcessingWaitTimeout,
ResourceNotFoundError,
error_message_for_unparsed_graphql_error,
-)
+) # type: ignore
+
+from labelbox import utils
from labelbox.orm import query
from labelbox.orm.db_object import DbObject, Deletable, Updateable, experimental
from labelbox.orm.model import Entity, Field, Relationship
@@ -42,7 +40,11 @@
from labelbox.schema.export_task import ExportTask
from labelbox.schema.id_type import IdType
from labelbox.schema.identifiable import DataRowIdentifier, GlobalKey, UniqueId
-from labelbox.schema.identifiables import DataRowIdentifiers, UniqueIds
+from labelbox.schema.identifiables import (
+ DataRowIdentifiers,
+ GlobalKeys,
+ UniqueIds,
+)
from labelbox.schema.labeling_service import (
LabelingService,
LabelingServiceStatus,
@@ -59,19 +61,16 @@
ProjectOverview,
ProjectOverviewDetailed,
)
-from labelbox.schema.queue_mode import QueueMode
from labelbox.schema.resource_tag import ResourceTag
from labelbox.schema.task import Task
from labelbox.schema.task_queue import TaskQueue
if TYPE_CHECKING:
- from labelbox import BulkImportRequest
+ pass
DataRowPriority = int
-LabelingParameterOverrideInput = Tuple[
- Union[DataRow, DataRowIdentifier], DataRowPriority
-]
+LabelingParameterOverrideInput = Tuple[DataRowIdentifier, DataRowPriority]
logger = logging.getLogger(__name__)
MAX_SYNC_BATCH_ROW_COUNT = 1_000
@@ -81,23 +80,18 @@ def validate_labeling_parameter_overrides(
data: List[LabelingParameterOverrideInput],
) -> None:
for idx, row in enumerate(data):
- if len(row) < 2:
- raise TypeError(
- f"Data must be a list of tuples each containing two elements: a DataRow or a DataRowIdentifier and priority (int). Found {len(row)} items. Index: {idx}"
- )
data_row_identifier = row[0]
priority = row[1]
- valid_types = (Entity.DataRow, UniqueId, GlobalKey)
- if not isinstance(data_row_identifier, valid_types):
+ if not isinstance(data_row_identifier, get_args(DataRowIdentifier)):
raise TypeError(
- f"Data row identifier should be be of type DataRow, UniqueId or GlobalKey. Found {type(data_row_identifier)} for data_row_identifier {data_row_identifier}"
+ f"Data row identifier should be of type DataRowIdentifier. Found {type(data_row_identifier)}."
+ )
+ if len(row) < 2:
+ raise TypeError(
+ f"Data must be a list of tuples each containing two elements: a DataRowIdentifier and priority (int). Found {len(row)} items. Index: {idx}"
)
-
if not isinstance(priority, int):
- if isinstance(data_row_identifier, Entity.DataRow):
- id = data_row_identifier.uid
- else:
- id = data_row_identifier
+ id = data_row_identifier.key
raise TypeError(
f"Priority must be an int. Found {type(priority)} for data_row_identifier {id}"
)
@@ -114,7 +108,6 @@ class Project(DbObject, Updateable, Deletable):
created_at (datetime)
setup_complete (datetime)
last_activity_time (datetime)
- queue_mode (string)
auto_audit_number_of_labels (int)
auto_audit_percentage (float)
is_benchmark_enabled (bool)
@@ -137,7 +130,6 @@ class Project(DbObject, Updateable, Deletable):
created_at = Field.DateTime("created_at")
setup_complete = Field.DateTime("setup_complete")
last_activity_time = Field.DateTime("last_activity_time")
- queue_mode = Field.Enum(QueueMode, "queue_mode")
auto_audit_number_of_labels = Field.Int("auto_audit_number_of_labels")
auto_audit_percentage = Field.Float("auto_audit_percentage")
# Bind data_type and allowedMediaTYpe using the GraphQL type MediaType
@@ -423,6 +415,13 @@ def export_v2(
>>> task.wait_till_done()
>>> task.result
"""
+
+ warnings.warn(
+ "You are currently utilizing export_v2 for this action, which will be removed in 7.0. Please refer to our docs for export alternatives. https://docs.labelbox.com/reference/export-overview#export-methods",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+
task, is_streamable = self._export(task_name, filters, params)
if is_streamable:
return ExportTask(task, True)
@@ -659,21 +658,11 @@ def review_metrics(self, net_score) -> int:
res = self.client.execute(query_str, {id_param: self.uid})
return res["project"]["reviewMetrics"]["labelAggregate"]["count"]
- def setup_editor(self, ontology) -> None:
- """
- Sets up the project using the Pictor editor.
-
- Args:
- ontology (Ontology): The ontology to attach to the project
- """
- warnings.warn("This method is deprecated use connect_ontology instead.")
- self.connect_ontology(ontology)
-
def connect_ontology(self, ontology) -> None:
"""
Connects the ontology to the project. If an editor is not setup, it will be connected as well.
- Note: For live chat model evaluation projects, the editor setup is skipped becase it is automatically setup when the project is created.
+        Note: For live chat model evaluation projects, the editor setup is skipped because it is automatically set up when the project is created.
Args:
ontology (Ontology): The ontology to attach to the project
@@ -696,34 +685,6 @@ def connect_ontology(self, ontology) -> None:
timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
self.update(setup_complete=timestamp)
- def setup(self, labeling_frontend, labeling_frontend_options) -> None:
- """This method will associate default labeling frontend with the project and create an ontology based on labeling_frontend_options.
-
- Args:
- labeling_frontend (LabelingFrontend): Do not use, this parameter is deprecated. We now associate the default labeling frontend with the project.
- labeling_frontend_options (dict or str): Labeling frontend options,
- a.k.a. project ontology. If given a `dict` it will be converted
- to `str` using `json.dumps`.
- """
-
- warnings.warn("This method is deprecated use connect_ontology instead.")
- if labeling_frontend is not None:
- warnings.warn(
- "labeling_frontend parameter will not be used to create a new labeling frontend."
- )
-
- if self.is_chat_evaluation() or self.is_prompt_response():
- warnings.warn("""
- This project is a live chat evaluation project or prompt and response generation project.
- Editor was setup automatically.
- """)
- return
-
- self._connect_default_labeling_front_end(labeling_frontend_options)
-
- timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
- self.update(setup_complete=timestamp)
-
def _connect_default_labeling_front_end(self, ontology_as_dict: dict):
labeling_frontend = self.labeling_frontend()
if (
@@ -775,11 +736,8 @@ def create_batch(
Returns: the created batch
Raises:
- labelbox.exceptions.ValueError if a project is not batch mode, if the project is auto data generation, if the batch exceeds 100k data rows
+            ValueError if the project is auto data generation or if the batch exceeds 100k data rows
"""
- # @TODO: make this automatic?
- if self.queue_mode != QueueMode.Batch:
- raise ValueError("Project must be in batch mode")
if (
self.is_auto_data_generation() and not self.is_chat_evaluation()
@@ -861,9 +819,6 @@ def create_batches(
Returns: a task for the created batches
"""
- if self.queue_mode != QueueMode.Batch:
- raise ValueError("Project must be in batch mode")
-
dr_ids = []
if data_rows is not None:
for dr in data_rows:
@@ -946,9 +901,6 @@ def create_batches_from_dataset(
Returns: a task for the created batches
"""
- if self.queue_mode != QueueMode.Batch:
- raise ValueError("Project must be in batch mode")
-
if consensus_settings:
consensus_settings = ConsensusSettings(
**consensus_settings
@@ -1088,57 +1040,6 @@ def _create_batch_async(
return self.client.get_batch(self.uid, batch_id)
- def _update_queue_mode(self, mode: "QueueMode") -> "QueueMode":
- """
- Updates the queueing mode of this project.
-
- Deprecation notice: This method is deprecated. Going forward, projects must
- go through a migration to have the queue mode changed. Users should specify the
- queue mode for a project during creation if a non-default mode is desired.
-
- For more information, visit https://docs.labelbox.com/reference/migrating-to-workflows#upcoming-changes
-
- Args:
- mode: the specified queue mode
-
- Returns: the updated queueing mode of this project
-
- """
-
- logger.warning(
- "Updating the queue_mode for a project will soon no longer be supported."
- )
-
- if self.queue_mode == mode:
- return mode
-
- if mode == QueueMode.Batch:
- status = "ENABLED"
- elif mode == QueueMode.Dataset:
- status = "DISABLED"
- else:
- raise ValueError(
- "Must provide either `BATCH` or `DATASET` as a mode"
- )
-
- query_str = (
- """mutation %s($projectId: ID!, $status: TagSetStatusInput!) {
- project(where: {id: $projectId}) {
- setTagSetStatus(input: {tagSetStatus: $status}) {
- tagSetStatus
- }
- }
- }
- """
- % "setTagSetStatusPyApi"
- )
-
- self.client.execute(
- query_str, {"projectId": self.uid, "status": status}
- )
-
- return mode
-
def get_label_count(self) -> int:
"""
Returns: the total number of labels in this project.
@@ -1153,46 +1054,6 @@ def get_label_count(self) -> int:
res = self.client.execute(query_str, {"projectId": self.uid})
return res["project"]["labelCount"]
- def get_queue_mode(self) -> "QueueMode":
- """
- Provides the queue mode used for this project.
-
- Deprecation notice: This method is deprecated and will be removed in
- a future version. To obtain the queue mode of a project, simply refer
- to the queue_mode attribute of a Project.
-
- For more information, visit https://docs.labelbox.com/reference/migrating-to-workflows#upcoming-changes
-
- Returns: the QueueMode for this project
-
- """
-
- logger.warning(
- "Obtaining the queue_mode for a project through this method will soon"
- " no longer be supported."
- )
-
- query_str = (
- """query %s($projectId: ID!) {
- project(where: {id: $projectId}) {
- tagSetStatus
- }
- }
- """
- % "GetTagSetStatusPyApi"
- )
-
- status = self.client.execute(query_str, {"projectId": self.uid})[
- "project"
- ]["tagSetStatus"]
-
- if status == "ENABLED":
- return QueueMode.Batch
- elif status == "DISABLED":
- return QueueMode.Dataset
- else:
- raise ValueError("Status not known")
-
def add_model_config(self, model_config_id: str) -> str:
"""Adds a model config to this project.
@@ -1285,18 +1146,13 @@ def set_labeling_parameter_overrides(
See information on priority here:
https://docs.labelbox.com/en/configure-editor/queue-system#reservation-system
- >>> project.set_labeling_parameter_overrides([
- >>> (data_row_id1, 2), (data_row_id2, 1)])
- or
>>> project.set_labeling_parameter_overrides([
>>> (data_row_gk1, 2), (data_row_gk2, 1)])
Args:
data (iterable): An iterable of tuples. Each tuple must contain
- either (DataRow, DataRowPriority)
- or (DataRowIdentifier, priority) for the new override.
+ (DataRowIdentifier, priority) for the new override.
DataRowIdentifier is an object representing a data row id or a global key. A DataIdentifier object can be a UniqueIds or GlobalKeys class.
- NOTE - passing whole DatRow is deprecated. Please use a DataRowIdentifier instead.
Priority:
* Data will be labeled in priority order.
@@ -1325,16 +1181,7 @@ def set_labeling_parameter_overrides(
data_rows_with_identifiers = ""
for data_row, priority in data:
- if isinstance(data_row, DataRow):
- data_rows_with_identifiers += f'{{dataRowIdentifier: {{id: "{data_row.uid}", idType: {IdType.DataRowId}}}, priority: {priority}}},'
- elif isinstance(data_row, UniqueId) or isinstance(
- data_row, GlobalKey
- ):
- data_rows_with_identifiers += f'{{dataRowIdentifier: {{id: "{data_row.key}", idType: {data_row.id_type}}}, priority: {priority}}},'
- else:
- raise TypeError(
- f"Data row identifier should be be of type DataRow or Data Row Identifier. Found {type(data_row)}."
- )
+ data_rows_with_identifiers += f'{{dataRowIdentifier: {{id: "{data_row.key}", idType: {data_row.id_type}}}, priority: {priority}}},'
query_str = template.substitute(
dataWithDataRowIdentifiers=data_rows_with_identifiers
@@ -1342,26 +1189,10 @@ def set_labeling_parameter_overrides(
res = self.client.execute(query_str, {"projectId": self.uid})
return res["project"]["setLabelingParameterOverrides"]["success"]
- @overload
def update_data_row_labeling_priority(
self,
data_rows: DataRowIdentifiers,
priority: int,
- ) -> bool:
- pass
-
- @overload
- def update_data_row_labeling_priority(
- self,
- data_rows: List[str],
- priority: int,
- ) -> bool:
- pass
-
- def update_data_row_labeling_priority(
- self,
- data_rows,
- priority: int,
) -> bool:
"""
Updates labeling parameter overrides to this project in bulk. This method allows up to 1 million data rows to be
@@ -1371,7 +1202,7 @@ def update_data_row_labeling_priority(
https://docs.labelbox.com/en/configure-editor/queue-system#reservation-system
Args:
- data_rows: a list of data row ids to update priorities for. This can be a list of strings or a DataRowIdentifiers object
+            data_rows: a DataRowIdentifiers object specifying the data rows whose priority should be updated.
DataRowIdentifier objects are lists of ids or global keys. A DataIdentifier object can be a UniqueIds or GlobalKeys class.
priority (int): Priority for the new override. See above for more information.
@@ -1379,12 +1210,8 @@ def update_data_row_labeling_priority(
bool, indicates if the operation was a success.
"""
- if isinstance(data_rows, list):
- data_rows = UniqueIds(data_rows)
- warnings.warn(
- "Using data row ids will be deprecated. Please use "
- "UniqueIds or GlobalKeys instead."
- )
+ if not isinstance(data_rows, get_args(DataRowIdentifiers)):
+ raise TypeError("data_rows must be a DataRowIdentifiers object")
method = "createQueuePriorityUpdateTask"
priority_param = "priority"
@@ -1482,33 +1309,6 @@ def enable_model_assisted_labeling(self, toggle: bool = True) -> bool:
"showingPredictionsToLabelers"
]
- def bulk_import_requests(self) -> PaginatedCollection:
- """Returns bulk import request objects which are used in model-assisted labeling.
- These are returned with the oldest first, and most recent last.
- """
-
- id_param = "project_id"
- query_str = """query ListAllImportRequestsPyApi($%s: ID!) {
- bulkImportRequests (
- where: { projectId: $%s }
- skip: %%d
- first: %%d
- ) {
- %s
- }
- }""" % (
- id_param,
- id_param,
- query.results_query_part(Entity.BulkImportRequest),
- )
- return PaginatedCollection(
- self.client,
- query_str,
- {id_param: str(self.uid)},
- ["bulkImportRequests"],
- Entity.BulkImportRequest,
- )
-
def batches(self) -> PaginatedCollection:
"""Fetch all batches that belong to this project
@@ -1554,25 +1354,15 @@ def task_queues(self) -> List[TaskQueue]:
for field_values in task_queue_values
]
- @overload
def move_data_rows_to_task_queue(
self, data_row_ids: DataRowIdentifiers, task_queue_id: str
):
- pass
-
- @overload
- def move_data_rows_to_task_queue(
- self, data_row_ids: List[str], task_queue_id: str
- ):
- pass
-
- def move_data_rows_to_task_queue(self, data_row_ids, task_queue_id: str):
"""
Moves data rows to the specified task queue.
Args:
- data_row_ids: a list of data row ids to be moved. This can be a list of strings or a DataRowIdentifiers object
+            data_row_ids: a DataRowIdentifiers object containing the data rows to be moved.
DataRowIdentifier objects are lists of ids or global keys. A DataIdentifier object can be a UniqueIds or GlobalKeys class.
task_queue_id: the task queue id to be moved to, or None to specify the "Done" queue
@@ -1580,12 +1370,9 @@ def move_data_rows_to_task_queue(self, data_row_ids, task_queue_id: str):
None if successful, or a raised error on failure
"""
- if isinstance(data_row_ids, list):
- data_row_ids = UniqueIds(data_row_ids)
- warnings.warn(
- "Using data row ids will be deprecated. Please use "
- "UniqueIds or GlobalKeys instead."
- )
+
+ if not isinstance(data_row_ids, get_args(DataRowIdentifiers)):
+ raise TypeError("data_rows must be a DataRowIdentifiers object")
method = "createBulkAddRowsToQueueTask"
query_str = (
@@ -1633,77 +1420,6 @@ def _wait_for_task(self, task_id: str) -> Task:
return task
- def upload_annotations(
- self,
- name: str,
- annotations: Union[str, Path, Iterable[Dict]],
- validate: bool = False,
- ) -> "BulkImportRequest": # type: ignore
- """Uploads annotations to a new Editor project.
-
- Args:
- name (str): name of the BulkImportRequest job
- annotations (str or Path or Iterable):
- url that is publicly accessible by Labelbox containing an
- ndjson file
- OR local path to an ndjson file
- OR iterable of annotation rows
- validate (bool):
- Whether or not to validate the payload before uploading.
- Returns:
- BulkImportRequest
- """
-
- if isinstance(annotations, str) or isinstance(annotations, Path):
-
- def _is_url_valid(url: Union[str, Path]) -> bool:
- """Verifies that the given string is a valid url.
-
- Args:
- url: string to be checked
- Returns:
- True if the given url is valid otherwise False
-
- """
- if isinstance(url, Path):
- return False
- parsed = urlparse(url)
- return bool(parsed.scheme) and bool(parsed.netloc)
-
- if _is_url_valid(annotations):
- return Entity.BulkImportRequest.create_from_url(
- client=self.client,
- project_id=self.uid,
- name=name,
- url=str(annotations),
- validate=validate,
- )
- else:
- path = Path(annotations)
- if not path.exists():
- raise FileNotFoundError(
- f"{annotations} is not a valid url nor existing local file"
- )
- return Entity.BulkImportRequest.create_from_local_file(
- client=self.client,
- project_id=self.uid,
- name=name,
- file=path,
- validate_file=validate,
- )
- elif isinstance(annotations, Iterable):
- return Entity.BulkImportRequest.create_from_objects(
- client=self.client,
- project_id=self.uid,
- name=name,
- predictions=annotations, # type: ignore
- validate=validate,
- )
- else:
- raise ValueError(
- f"Invalid annotations given of type: {type(annotations)}"
- )
-
def _wait_until_data_rows_are_processed(
self,
data_row_ids: Optional[List[str]] = None,
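With Project.upload_annotations removed, a sketch of the replacement path (editor's illustration, not part of this diff), assuming MALPredictionImport keeps its current create_from_objects signature and that `predictions` is an iterable of NDJSON dicts or Label objects:

    from labelbox.schema.annotation_import import MALPredictionImport

    import_job = MALPredictionImport.create_from_objects(
        client=client,
        project_id=project.uid,
        name="example-mal-import",   # placeholder job name
        predictions=predictions,
    )
    import_job.wait_until_done()
    assert import_job.errors == []   # empty when the payload validated cleanly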
diff --git a/libs/labelbox/src/labelbox/schema/project_model_config.py b/libs/labelbox/src/labelbox/schema/project_model_config.py
index 9b6d8a0bb..c8773abf9 100644
--- a/libs/labelbox/src/labelbox/schema/project_model_config.py
+++ b/libs/labelbox/src/labelbox/schema/project_model_config.py
@@ -1,10 +1,11 @@
-from labelbox.orm.db_object import DbObject
-from labelbox.orm.model import Field, Relationship
-from labelbox.exceptions import (
+from lbox.exceptions import (
LabelboxError,
error_message_for_unparsed_graphql_error,
)
+from labelbox.orm.db_object import DbObject
+from labelbox.orm.model import Field, Relationship
+
class ProjectModelConfig(DbObject):
"""A ProjectModelConfig represents an association between a project and a single model config.
diff --git a/libs/labelbox/src/labelbox/schema/project_overview.py b/libs/labelbox/src/labelbox/schema/project_overview.py
index cee195c10..8be20fbda 100644
--- a/libs/labelbox/src/labelbox/schema/project_overview.py
+++ b/libs/labelbox/src/labelbox/schema/project_overview.py
@@ -1,6 +1,7 @@
from typing import Dict, List
-from typing_extensions import TypedDict
+
from pydantic import BaseModel
+from typing_extensions import TypedDict
class ProjectOverview(BaseModel):
diff --git a/libs/labelbox/src/labelbox/schema/queue_mode.py b/libs/labelbox/src/labelbox/schema/queue_mode.py
deleted file mode 100644
index 333e92987..000000000
--- a/libs/labelbox/src/labelbox/schema/queue_mode.py
+++ /dev/null
@@ -1,12 +0,0 @@
-from enum import Enum
-
-
-class QueueMode(str, Enum):
- Batch = "BATCH"
- Dataset = "DATA_SET"
-
- @classmethod
- def _missing_(cls, value):
- # Parses the deprecated "CATALOG" value back to QueueMode.Batch.
- if value == "CATALOG":
- return QueueMode.Batch
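With QueueMode deleted, batch queueing is the only behaviour, so project creation drops the argument entirely. A minimal sketch (editor's illustration, not part of this diff; the project name is a placeholder):

    from labelbox import MediaType

    project = client.create_project(
        name="example-project",
        media_type=MediaType.Image,   # queue_mode=QueueMode.Batch is no longer passed
    )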
diff --git a/libs/labelbox/src/labelbox/schema/search_filters.py b/libs/labelbox/src/labelbox/schema/search_filters.py
index 13b158678..52330492d 100644
--- a/libs/labelbox/src/labelbox/schema/search_filters.py
+++ b/libs/labelbox/src/labelbox/schema/search_filters.py
@@ -1,11 +1,15 @@
import datetime
from enum import Enum
-from typing import List, Union
-from pydantic import PlainSerializer, BaseModel, Field
+from typing import Annotated, List, Union
-from typing_extensions import Annotated
+from pydantic import (
+ BaseModel,
+ ConfigDict,
+ Field,
+ PlainSerializer,
+ field_validator,
+)
-from pydantic import BaseModel, Field, field_validator
from labelbox.schema.labeling_service_status import LabelingServiceStatus
from labelbox.utils import format_iso_datetime
@@ -15,8 +19,7 @@ class BaseSearchFilter(BaseModel):
Shared code for all search filters
"""
- class Config:
- use_enum_values = True
+ model_config = ConfigDict(use_enum_values=True)
class OperationTypeEnum(Enum):
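The Config-class change above follows the standard pydantic v2 migration. A minimal sketch of the pattern (editor's illustration, not tied to this repo's classes):

    from pydantic import BaseModel, ConfigDict

    class ExampleFilter(BaseModel):
        # replaces the nested `class Config: use_enum_values = True`
        model_config = ConfigDict(use_enum_values=True)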
diff --git a/libs/labelbox/src/labelbox/schema/send_to_annotate_params.py b/libs/labelbox/src/labelbox/schema/send_to_annotate_params.py
index 18bd26637..fe09cf4c0 100644
--- a/libs/labelbox/src/labelbox/schema/send_to_annotate_params.py
+++ b/libs/labelbox/src/labelbox/schema/send_to_annotate_params.py
@@ -1,18 +1,11 @@
-import sys
+from typing import Dict, Optional, TypedDict
-from typing import Optional, Dict
+from pydantic import BaseModel, model_validator
from labelbox.schema.conflict_resolution_strategy import (
ConflictResolutionStrategy,
)
-if sys.version_info >= (3, 8):
- from typing import TypedDict
-else:
- from typing_extensions import TypedDict
-
-from pydantic import BaseModel, model_validator
-
class SendToAnnotateFromCatalogParams(BaseModel):
"""
diff --git a/libs/labelbox/src/labelbox/schema/slice.py b/libs/labelbox/src/labelbox/schema/slice.py
index 624731024..a640ebc1d 100644
--- a/libs/labelbox/src/labelbox/schema/slice.py
+++ b/libs/labelbox/src/labelbox/schema/slice.py
@@ -53,43 +53,6 @@ class CatalogSlice(Slice):
Represents a Slice used for filtering data rows in Catalog.
"""
- def get_data_row_ids(self) -> PaginatedCollection:
- """
- Fetches all data row ids that match this Slice
-
- Returns:
- A PaginatedCollection of mapping of data row ids to global keys
- """
-
- warnings.warn(
- "get_data_row_ids will be deprecated. Use get_data_row_identifiers instead"
- )
-
- query_str = """
- query getDataRowIdsBySavedQueryPyApi($id: ID!, $from: String, $first: Int!) {
- getDataRowIdsBySavedQuery(input: {
- savedQueryId: $id,
- after: $from
- first: $first
- }) {
- totalCount
- nodes
- pageInfo {
- endCursor
- hasNextPage
- }
- }
- }
- """
- return PaginatedCollection(
- client=self.client,
- query=query_str,
- params={"id": str(self.uid)},
- dereferencing=["getDataRowIdsBySavedQuery", "nodes"],
- obj_class=lambda _, data_row_id: data_row_id,
- cursor_path=["getDataRowIdsBySavedQuery", "pageInfo", "endCursor"],
- )
-
def get_data_row_identifiers(self) -> PaginatedCollection:
"""
Fetches all data row ids and global keys (where defined) that match this Slice
@@ -164,6 +127,13 @@ def export_v2(
>>> task.wait_till_done()
>>> task.result
"""
+
+ warnings.warn(
+ "You are currently utilizing export_v2 for this action, which will be removed in 7.0. Please refer to our docs for export alternatives. https://docs.labelbox.com/reference/export-overview#export-methods",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+
task, is_streamable = self._export(task_name, params)
if is_streamable:
return ExportTask(task, True)
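A sketch of the replacement calls for slices (editor's illustration, not part of this diff), assuming an existing catalog slice id; get_data_row_identifiers and the streamable export() stand in for the removed get_data_row_ids and the now-deprecated export_v2:

    catalog_slice = client.get_catalog_slice("SLICE_ID")   # placeholder slice id
    for identifier in catalog_slice.get_data_row_identifiers():
        print(identifier)                                   # data row id / global key entries

    export_task = catalog_slice.export()
    export_task.wait_till_done()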
diff --git a/libs/labelbox/src/labelbox/schema/task.py b/libs/labelbox/src/labelbox/schema/task.py
index 9d7a26e1d..f996ae05d 100644
--- a/libs/labelbox/src/labelbox/schema/task.py
+++ b/libs/labelbox/src/labelbox/schema/task.py
@@ -1,14 +1,14 @@
import json
import logging
-import requests
import time
-from typing import TYPE_CHECKING, Callable, Optional, Dict, Any, List, Union
-from labelbox import parser
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
-from labelbox.exceptions import ResourceNotFoundError
-from labelbox.orm.db_object import DbObject
-from labelbox.orm.model import Field, Relationship, Entity
+import requests
+from lbox.exceptions import ResourceNotFoundError
+from labelbox import parser
+from labelbox.orm.db_object import DbObject
+from labelbox.orm.model import Entity, Field, Relationship
from labelbox.pagination import PaginatedCollection
from labelbox.schema.internal.datarow_upload_constants import (
DOWNLOAD_RESULT_PAGE_SIZE,
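For reference, the import-path change recurring across these files (editor's note, not part of the diff): exception types now come from the lbox support package rather than labelbox.exceptions.

    # before: from labelbox.exceptions import LabelboxError, ResourceNotFoundError
    from lbox.exceptions import LabelboxError, ResourceNotFoundError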
diff --git a/libs/labelbox/src/labelbox/schema/user_group.py b/libs/labelbox/src/labelbox/schema/user_group.py
index 9d506bf92..2e93b4376 100644
--- a/libs/labelbox/src/labelbox/schema/user_group.py
+++ b/libs/labelbox/src/labelbox/schema/user_group.py
@@ -1,21 +1,21 @@
-from enum import Enum
-from typing import Set, Iterator
from collections import defaultdict
+from enum import Enum
+from typing import Iterator, Set
-from labelbox import Client
-from labelbox.exceptions import ResourceCreationError
-from labelbox.schema.user import User
-from labelbox.schema.project import Project
-from labelbox.exceptions import (
- UnprocessableEntityError,
+from lbox.exceptions import (
MalformedQueryException,
+ ResourceCreationError,
ResourceNotFoundError,
+ UnprocessableEntityError,
)
-from labelbox.schema.queue_mode import QueueMode
-from labelbox.schema.ontology_kind import EditorTaskType
-from labelbox.schema.media_type import MediaType
from pydantic import BaseModel, ConfigDict
+from labelbox import Client
+from labelbox.schema.media_type import MediaType
+from labelbox.schema.ontology_kind import EditorTaskType
+from labelbox.schema.project import Project
+from labelbox.schema.user import User
+
class UserGroupColor(Enum):
"""
@@ -92,9 +92,6 @@ def __init__(
color (UserGroupColor, optional): The color of the user group. Defaults to UserGroupColor.BLUE.
users (Set[User], optional): The set of users in the user group. Defaults to an empty set.
projects (Set[Project], optional): The set of projects associated with the user group. Defaults to an empty set.
-
- Raises:
- RuntimeError: If the experimental feature is not enabled in the client.
"""
super().__init__(
client=client,
@@ -104,10 +101,6 @@ def __init__(
users=users,
projects=projects,
)
- if not self.client.enable_experimental:
- raise RuntimeError(
- "Please enable experimental in client to use UserGroups"
- )
def get(self) -> "UserGroup":
"""
@@ -284,7 +277,7 @@ def create(self) -> "UserGroup":
except Exception as e:
error = e
if not result or error:
- # this is client side only, server doesn't have an equivalent error
+ # This is client side only, server doesn't have an equivalent error
raise ResourceCreationError(
f"Failed to create user group, either user group name is in use currently, or provided user or projects don't exist server error: {error}"
)
@@ -417,7 +410,6 @@ def _get_projects_set(self, project_nodes):
project_values = defaultdict(lambda: None)
project_values["id"] = project["id"]
project_values["name"] = project["name"]
- project_values["queueMode"] = QueueMode.Batch.value
project_values["editorTaskType"] = EditorTaskType.Missing.value
project_values["mediaType"] = MediaType.Image.value
projects.add(Project(self.client, project_values))
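With the experimental gate removed, user groups can be constructed and created on a default client. A minimal sketch (editor's illustration, not part of this diff; the group name is a placeholder):

    from labelbox.schema.user_group import UserGroup, UserGroupColor

    group = UserGroup(client, name="example-reviewers", color=UserGroupColor.BLUE)
    group.create()   # no longer raises RuntimeError when enable_experimental is off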
diff --git a/libs/labelbox/src/labelbox/typing_imports.py b/libs/labelbox/src/labelbox/typing_imports.py
deleted file mode 100644
index 6edfb9bef..000000000
--- a/libs/labelbox/src/labelbox/typing_imports.py
+++ /dev/null
@@ -1,11 +0,0 @@
-"""
-This module imports types that differ across python versions, so other modules
-don't have to worry about where they should be imported from.
-"""
-
-import sys
-
-if sys.version_info >= (3, 8):
- from typing import Literal
-else:
- from typing_extensions import Literal
diff --git a/libs/labelbox/src/labelbox/utils.py b/libs/labelbox/src/labelbox/utils.py
index c76ce188f..dcf51be82 100644
--- a/libs/labelbox/src/labelbox/utils.py
+++ b/libs/labelbox/src/labelbox/utils.py
@@ -87,8 +87,8 @@ class _NoCoercionMixin:
when serializing the object.
Example:
- class ConversationData(BaseData, _NoCoercionMixin):
- class_name: Literal["ConversationData"] = "ConversationData"
+ class GenericDataRowData(BaseData, _NoCoercionMixin):
+ class_name: Literal["GenericDataRowData"] = "GenericDataRowData"
"""
diff --git a/libs/labelbox/tests/conftest.py b/libs/labelbox/tests/conftest.py
index 49eab165d..3e9f0b491 100644
--- a/libs/labelbox/tests/conftest.py
+++ b/libs/labelbox/tests/conftest.py
@@ -12,6 +12,7 @@
import pytest
import requests
+from lbox.exceptions import LabelboxError
from labelbox import (
Classification,
@@ -24,7 +25,6 @@
Option,
Tool,
)
-from labelbox.exceptions import LabelboxError
from labelbox.orm import query
from labelbox.pagination import PaginatedCollection
from labelbox.schema.annotation_import import LabelImport
@@ -33,7 +33,6 @@
from labelbox.schema.ontology import Ontology
from labelbox.schema.project import Project
from labelbox.schema.quality_mode import QualityMode
-from labelbox.schema.queue_mode import QueueMode
IMG_URL = "https://picsum.photos/200/300.jpg"
MASKABLE_IMG_URL = "https://storage.googleapis.com/labelbox-datasets/image_sample_data/2560px-Kitano_Street_Kobe01s5s4110.jpeg"
@@ -445,7 +444,6 @@ def conversation_entity_data_row(client, rand_gen):
def project(client, rand_gen):
project = client.create_project(
name=rand_gen(str),
- queue_mode=QueueMode.Batch,
media_type=MediaType.Image,
)
yield project
@@ -456,8 +454,7 @@ def project(client, rand_gen):
def consensus_project(client, rand_gen):
project = client.create_project(
name=rand_gen(str),
- quality_mode=QualityMode.Consensus,
- queue_mode=QueueMode.Batch,
+ quality_modes={QualityMode.Consensus},
media_type=MediaType.Image,
)
yield project
@@ -647,7 +644,6 @@ def configured_project_with_label(
"""
project = client.create_project(
name=rand_gen(str),
- queue_mode=QueueMode.Batch,
media_type=MediaType.Image,
)
project._wait_until_data_rows_are_processed(
@@ -661,7 +657,7 @@ def configured_project_with_label(
[data_row.uid], # sample of data row objects
5, # priority between 1(Highest) - 5(lowest)
)
- ontology = _setup_ontology(project)
+ ontology = _setup_ontology(project, client)
label = _create_label(
project, data_row, ontology, wait_for_label_processing
)
@@ -704,20 +700,19 @@ def create_label():
return label
-def _setup_ontology(project):
- editor = list(
- project.client.get_labeling_frontends(
- where=LabelingFrontend.name == "editor"
- )
- )[0]
+def _setup_ontology(project: Project, client: Client):
ontology_builder = OntologyBuilder(
tools=[
Tool(tool=Tool.Type.BBOX, name="test-bbox-class"),
]
)
- project.setup(editor, ontology_builder.asdict())
- # TODO: ontology may not be synchronous after setup. remove sleep when api is more consistent
- time.sleep(2)
+ ontology = client.create_ontology(
+ name="ontology with features",
+ media_type=MediaType.Image,
+ normalized=ontology_builder.asdict(),
+ )
+ project.connect_ontology(ontology)
+
return OntologyBuilder.from_project(project)
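The fixture change above captures the general ontology-setup migration. A minimal sketch (editor's illustration, not part of this diff): the labeling-frontend lookup plus project.setup() is replaced by creating the ontology on the client and connecting it to the project.

    from labelbox import MediaType, OntologyBuilder, Tool

    builder = OntologyBuilder(tools=[Tool(tool=Tool.Type.BBOX, name="test-bbox-class")])
    ontology = client.create_ontology(
        name="ontology with features",
        media_type=MediaType.Image,
        normalized=builder.asdict(),
    )
    project.connect_ontology(ontology)   # the time.sleep(2) workaround from setup() is dropped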
@@ -750,7 +745,6 @@ def configured_batch_project_with_label(
"""
project = client.create_project(
name=rand_gen(str),
- queue_mode=QueueMode.Batch,
media_type=MediaType.Image,
)
data_rows = [dr.uid for dr in list(dataset.data_rows())]
@@ -760,7 +754,7 @@ def configured_batch_project_with_label(
project.create_batch("test-batch", data_rows)
project.data_row_ids = data_rows
- ontology = _setup_ontology(project)
+ ontology = _setup_ontology(project, client)
label = _create_label(
project, data_row, ontology, wait_for_label_processing
)
@@ -785,7 +779,6 @@ def configured_batch_project_with_multiple_datarows(
"""
project = client.create_project(
name=rand_gen(str),
- queue_mode=QueueMode.Batch,
media_type=MediaType.Image,
)
global_keys = [dr.global_key for dr in data_rows]
@@ -793,7 +786,7 @@ def configured_batch_project_with_multiple_datarows(
batch_name = f"batch {uuid.uuid4()}"
project.create_batch(batch_name, global_keys=global_keys)
- ontology = _setup_ontology(project)
+ ontology = _setup_ontology(project, client)
for datarow in data_rows:
_create_label(project, datarow, ontology, wait_for_label_processing)
@@ -1030,11 +1023,11 @@ def _upload_invalid_data_rows_for_dataset(dataset: Dataset):
@pytest.fixture
def configured_project(
- project_with_empty_ontology, initial_dataset, rand_gen, image_url
+ project_with_one_feature_ontology, initial_dataset, rand_gen, image_url
):
dataset = initial_dataset
data_row_id = dataset.create_data_row(row_data=image_url).uid
- project = project_with_empty_ontology
+ project = project_with_one_feature_ontology
batch = project.create_batch(
rand_gen(str),
@@ -1049,24 +1042,24 @@ def configured_project(
@pytest.fixture
-def project_with_empty_ontology(project):
- editor = list(
- project.client.get_labeling_frontends(
- where=LabelingFrontend.name == "editor"
- )
- )[0]
- empty_ontology = {"tools": [], "classifications": []}
- project.setup(editor, empty_ontology)
+def project_with_one_feature_ontology(project, client: Client):
+ tools = [
+ Tool(tool=Tool.Type.BBOX, name="test-bbox-class").asdict(),
+ ]
+ empty_ontology = {"tools": tools, "classifications": []}
+ ontology = client.create_ontology(
+ "empty ontology", empty_ontology, MediaType.Image
+ )
+ project.connect_ontology(ontology)
yield project
@pytest.fixture
def configured_project_with_complex_ontology(
- client, initial_dataset, rand_gen, image_url, teardown_helpers
+ client: Client, initial_dataset, rand_gen, image_url, teardown_helpers
):
project = client.create_project(
name=rand_gen(str),
- queue_mode=QueueMode.Batch,
media_type=MediaType.Image,
)
dataset = initial_dataset
@@ -1080,19 +1073,12 @@ def configured_project_with_complex_ontology(
)
project.data_row_ids = data_row_ids
- editor = list(
- project.client.get_labeling_frontends(
- where=LabelingFrontend.name == "editor"
- )
- )[0]
-
ontology = OntologyBuilder()
tools = [
Tool(tool=Tool.Type.BBOX, name="test-bbox-class"),
Tool(tool=Tool.Type.LINE, name="test-line-class"),
Tool(tool=Tool.Type.POINT, name="test-point-class"),
Tool(tool=Tool.Type.POLYGON, name="test-polygon-class"),
- Tool(tool=Tool.Type.NER, name="test-ner-class"),
]
options = [
@@ -1124,7 +1110,11 @@ def configured_project_with_complex_ontology(
for c in classifications:
ontology.add_classification(c)
- project.setup(editor, ontology.asdict())
+ ontology = client.create_ontology(
+ "complex image ontology", ontology.asdict(), MediaType.Image
+ )
+
+ project.connect_ontology(ontology)
yield [project, data_row]
teardown_helpers.teardown_project_labels_ontology_feature_schemas(project)
diff --git a/libs/labelbox/tests/data/annotation_import/test_bulk_import_request.py b/libs/labelbox/tests/data/annotation_import/test_bulk_import_request.py
deleted file mode 100644
index 9abae1422..000000000
--- a/libs/labelbox/tests/data/annotation_import/test_bulk_import_request.py
+++ /dev/null
@@ -1,258 +0,0 @@
-from unittest.mock import patch
-import uuid
-from labelbox import parser, Project
-from labelbox.data.annotation_types.data.generic_data_row_data import (
- GenericDataRowData,
-)
-import pytest
-import random
-from labelbox.data.annotation_types.annotation import ObjectAnnotation
-from labelbox.data.annotation_types.classification.classification import (
- Checklist,
- ClassificationAnnotation,
- ClassificationAnswer,
- Radio,
-)
-from labelbox.data.annotation_types.data.video import VideoData
-from labelbox.data.annotation_types.geometry.point import Point
-from labelbox.data.annotation_types.geometry.rectangle import (
- Rectangle,
- RectangleUnit,
-)
-from labelbox.data.annotation_types.label import Label
-from labelbox.data.annotation_types.data.text import TextData
-from labelbox.data.annotation_types.ner import (
- DocumentEntity,
- DocumentTextSelection,
-)
-from labelbox.data.annotation_types.video import VideoObjectAnnotation
-
-from labelbox.data.serialization import NDJsonConverter
-from labelbox.exceptions import MALValidationError, UuidError
-from labelbox.schema.bulk_import_request import BulkImportRequest
-from labelbox.schema.enums import BulkImportRequestState
-from labelbox.schema.annotation_import import LabelImport, MALPredictionImport
-from labelbox.schema.media_type import MediaType
-
-"""
-- Here we only want to check that the uploads are calling the validation
-- Then with unit tests we can check the types of errors raised
-"""
-# TODO: remove library once bulk import requests are removed
-
-
-@pytest.mark.order(1)
-def test_create_from_url(module_project):
- name = str(uuid.uuid4())
- url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson"
-
- bulk_import_request = module_project.upload_annotations(
- name=name, annotations=url, validate=False
- )
-
- assert bulk_import_request.project() == module_project
- assert bulk_import_request.name == name
- assert bulk_import_request.input_file_url == url
- assert bulk_import_request.error_file_url is None
- assert bulk_import_request.status_file_url is None
- assert bulk_import_request.state == BulkImportRequestState.RUNNING
-
-
-def test_validate_file(module_project):
- name = str(uuid.uuid4())
- url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson"
- with pytest.raises(MALValidationError):
- module_project.upload_annotations(
- name=name, annotations=url, validate=True
- )
- # Schema ids shouldn't match
-
-
-def test_create_from_objects(
- module_project: Project, predictions, annotation_import_test_helpers
-):
- name = str(uuid.uuid4())
-
- bulk_import_request = module_project.upload_annotations(
- name=name, annotations=predictions
- )
-
- assert bulk_import_request.project() == module_project
- assert bulk_import_request.name == name
- assert bulk_import_request.error_file_url is None
- assert bulk_import_request.status_file_url is None
- assert bulk_import_request.state == BulkImportRequestState.RUNNING
- annotation_import_test_helpers.assert_file_content(
- bulk_import_request.input_file_url, predictions
- )
-
-
-def test_create_from_label_objects(
- module_project, predictions, annotation_import_test_helpers
-):
- name = str(uuid.uuid4())
-
- labels = list(NDJsonConverter.deserialize(predictions))
- bulk_import_request = module_project.upload_annotations(
- name=name, annotations=labels
- )
-
- assert bulk_import_request.project() == module_project
- assert bulk_import_request.name == name
- assert bulk_import_request.error_file_url is None
- assert bulk_import_request.status_file_url is None
- assert bulk_import_request.state == BulkImportRequestState.RUNNING
- normalized_predictions = list(NDJsonConverter.serialize(labels))
- annotation_import_test_helpers.assert_file_content(
- bulk_import_request.input_file_url, normalized_predictions
- )
-
-
-def test_create_from_local_file(
- tmp_path, predictions, module_project, annotation_import_test_helpers
-):
- name = str(uuid.uuid4())
- file_name = f"{name}.ndjson"
- file_path = tmp_path / file_name
- with file_path.open("w") as f:
- parser.dump(predictions, f)
-
- bulk_import_request = module_project.upload_annotations(
- name=name, annotations=str(file_path), validate=False
- )
-
- assert bulk_import_request.project() == module_project
- assert bulk_import_request.name == name
- assert bulk_import_request.error_file_url is None
- assert bulk_import_request.status_file_url is None
- assert bulk_import_request.state == BulkImportRequestState.RUNNING
- annotation_import_test_helpers.assert_file_content(
- bulk_import_request.input_file_url, predictions
- )
-
-
-def test_get(client, module_project):
- name = str(uuid.uuid4())
- url = "https://storage.googleapis.com/labelbox-public-bucket/predictions_test_v2.ndjson"
- module_project.upload_annotations(
- name=name, annotations=url, validate=False
- )
-
- bulk_import_request = BulkImportRequest.from_name(
- client, project_id=module_project.uid, name=name
- )
-
- assert bulk_import_request.project() == module_project
- assert bulk_import_request.name == name
- assert bulk_import_request.input_file_url == url
- assert bulk_import_request.error_file_url is None
- assert bulk_import_request.status_file_url is None
- assert bulk_import_request.state == BulkImportRequestState.RUNNING
-
-
-def test_validate_ndjson(tmp_path, module_project):
- file_name = f"broken.ndjson"
- file_path = tmp_path / file_name
- with file_path.open("w") as f:
- f.write("test")
-
- with pytest.raises(ValueError):
- module_project.upload_annotations(
- name="name", validate=True, annotations=str(file_path)
- )
-
-
-def test_validate_ndjson_uuid(tmp_path, module_project, predictions):
- file_name = f"repeat_uuid.ndjson"
- file_path = tmp_path / file_name
- repeat_uuid = predictions.copy()
- uid = str(uuid.uuid4())
- repeat_uuid[0]["uuid"] = uid
- repeat_uuid[1]["uuid"] = uid
-
- with file_path.open("w") as f:
- parser.dump(repeat_uuid, f)
-
- with pytest.raises(UuidError):
- module_project.upload_annotations(
- name="name", validate=True, annotations=str(file_path)
- )
-
- with pytest.raises(UuidError):
- module_project.upload_annotations(
- name="name", validate=True, annotations=repeat_uuid
- )
-
-
-@pytest.mark.skip(
- "Slow test and uses a deprecated api endpoint for annotation imports"
-)
-def test_wait_till_done(rectangle_inference, project):
- name = str(uuid.uuid4())
- url = project.client.upload_data(
- content=parser.dumps(rectangle_inference), sign=True
- )
- bulk_import_request = project.upload_annotations(
- name=name, annotations=url, validate=False
- )
-
- assert len(bulk_import_request.inputs) == 1
- bulk_import_request.wait_until_done()
- assert bulk_import_request.state == BulkImportRequestState.FINISHED
-
- # Check that the status files are being returned as expected
- assert len(bulk_import_request.errors) == 0
- assert len(bulk_import_request.inputs) == 1
- assert bulk_import_request.inputs[0]["uuid"] == rectangle_inference["uuid"]
- assert len(bulk_import_request.statuses) == 1
- assert bulk_import_request.statuses[0]["status"] == "SUCCESS"
- assert (
- bulk_import_request.statuses[0]["uuid"] == rectangle_inference["uuid"]
- )
-
-
-def test_project_bulk_import_requests(module_project, predictions):
- result = module_project.bulk_import_requests()
- assert len(list(result)) == 0
-
- name = str(uuid.uuid4())
- bulk_import_request = module_project.upload_annotations(
- name=name, annotations=predictions
- )
- bulk_import_request.wait_until_done()
-
- name = str(uuid.uuid4())
- bulk_import_request = module_project.upload_annotations(
- name=name, annotations=predictions
- )
- bulk_import_request.wait_until_done()
-
- name = str(uuid.uuid4())
- bulk_import_request = module_project.upload_annotations(
- name=name, annotations=predictions
- )
- bulk_import_request.wait_until_done()
-
- result = module_project.bulk_import_requests()
- assert len(list(result)) == 3
-
-
-def test_delete(module_project, predictions):
- name = str(uuid.uuid4())
-
- bulk_import_requests = module_project.bulk_import_requests()
- [
- bulk_import_request.delete()
- for bulk_import_request in bulk_import_requests
- ]
-
- bulk_import_request = module_project.upload_annotations(
- name=name, annotations=predictions
- )
- bulk_import_request.wait_until_done()
- all_import_requests = module_project.bulk_import_requests()
- assert len(list(all_import_requests)) == 1
-
- bulk_import_request.delete()
- all_import_requests = module_project.bulk_import_requests()
- assert len(list(all_import_requests)) == 0
diff --git a/libs/labelbox/tests/data/annotation_import/test_data_types.py b/libs/labelbox/tests/data/annotation_import/test_data_types.py
deleted file mode 100644
index 1e45295ef..000000000
--- a/libs/labelbox/tests/data/annotation_import/test_data_types.py
+++ /dev/null
@@ -1,83 +0,0 @@
-import pytest
-
-from labelbox.data.annotation_types.data import (
- AudioData,
- ConversationData,
- DocumentData,
- HTMLData,
- ImageData,
- TextData,
-)
-from labelbox.data.serialization import NDJsonConverter
-from labelbox.data.annotation_types.data.video import VideoData
-
-import labelbox.types as lb_types
-from labelbox.schema.media_type import MediaType
-
-# Unit test for label based on data type.
-# TODO: Dicom removed it is unstable when you deserialize and serialize on label import. If we intend to keep this library this needs add generic data types tests work with this data type.
-# TODO: add MediaType.LLMPromptResponseCreation(data gen) once supported and llm human preference once media type is added
-
-
-@pytest.mark.parametrize(
- "media_type, data_type_class",
- [
- (MediaType.Audio, AudioData),
- (MediaType.Html, HTMLData),
- (MediaType.Image, ImageData),
- (MediaType.Text, TextData),
- (MediaType.Video, VideoData),
- (MediaType.Conversational, ConversationData),
- (MediaType.Document, DocumentData),
- ],
-)
-def test_data_row_type_by_data_row_id(
- media_type,
- data_type_class,
- annotations_by_media_type,
- hardcoded_datarow_id,
-):
- annotations_ndjson = annotations_by_media_type[media_type]
- annotations_ndjson = [annotation[0] for annotation in annotations_ndjson]
-
- label = list(NDJsonConverter.deserialize(annotations_ndjson))[0]
-
- data_label = lb_types.Label(
- data=data_type_class(uid=hardcoded_datarow_id()),
- annotations=label.annotations,
- )
-
- assert data_label.data.uid == label.data.uid
- assert label.annotations == data_label.annotations
-
-
-@pytest.mark.parametrize(
- "media_type, data_type_class",
- [
- (MediaType.Audio, AudioData),
- (MediaType.Html, HTMLData),
- (MediaType.Image, ImageData),
- (MediaType.Text, TextData),
- (MediaType.Video, VideoData),
- (MediaType.Conversational, ConversationData),
- (MediaType.Document, DocumentData),
- ],
-)
-def test_data_row_type_by_global_key(
- media_type,
- data_type_class,
- annotations_by_media_type,
- hardcoded_global_key,
-):
- annotations_ndjson = annotations_by_media_type[media_type]
- annotations_ndjson = [annotation[0] for annotation in annotations_ndjson]
-
- label = list(NDJsonConverter.deserialize(annotations_ndjson))[0]
-
- data_label = lb_types.Label(
- data=data_type_class(global_key=hardcoded_global_key()),
- annotations=label.annotations,
- )
-
- assert data_label.data.global_key == label.data.global_key
- assert label.annotations == data_label.annotations
diff --git a/libs/labelbox/tests/data/annotation_import/test_generic_data_types.py b/libs/labelbox/tests/data/annotation_import/test_generic_data_types.py
index 1cc5538d9..921e98c9d 100644
--- a/libs/labelbox/tests/data/annotation_import/test_generic_data_types.py
+++ b/libs/labelbox/tests/data/annotation_import/test_generic_data_types.py
@@ -29,78 +29,6 @@ def validate_iso_format(date_string: str):
assert parsed_t.second is not None
-@pytest.mark.parametrize(
- "media_type, data_type_class",
- [
- (MediaType.Audio, GenericDataRowData),
- (MediaType.Html, GenericDataRowData),
- (MediaType.Image, GenericDataRowData),
- (MediaType.Text, GenericDataRowData),
- (MediaType.Video, GenericDataRowData),
- (MediaType.Conversational, GenericDataRowData),
- (MediaType.Document, GenericDataRowData),
- (MediaType.LLMPromptResponseCreation, GenericDataRowData),
- (MediaType.LLMPromptCreation, GenericDataRowData),
- (OntologyKind.ResponseCreation, GenericDataRowData),
- (OntologyKind.ModelEvaluation, GenericDataRowData),
- ],
-)
-def test_generic_data_row_type_by_data_row_id(
- media_type,
- data_type_class,
- annotations_by_media_type,
- hardcoded_datarow_id,
-):
- annotations_ndjson = annotations_by_media_type[media_type]
- annotations_ndjson = [annotation[0] for annotation in annotations_ndjson]
-
- label = list(NDJsonConverter.deserialize(annotations_ndjson))[0]
-
- data_label = Label(
- data=data_type_class(uid=hardcoded_datarow_id()),
- annotations=label.annotations,
- )
-
- assert data_label.data.uid == label.data.uid
- assert label.annotations == data_label.annotations
-
-
-@pytest.mark.parametrize(
- "media_type, data_type_class",
- [
- (MediaType.Audio, GenericDataRowData),
- (MediaType.Html, GenericDataRowData),
- (MediaType.Image, GenericDataRowData),
- (MediaType.Text, GenericDataRowData),
- (MediaType.Video, GenericDataRowData),
- (MediaType.Conversational, GenericDataRowData),
- (MediaType.Document, GenericDataRowData),
- # (MediaType.LLMPromptResponseCreation, GenericDataRowData),
- # (MediaType.LLMPromptCreation, GenericDataRowData),
- (OntologyKind.ResponseCreation, GenericDataRowData),
- (OntologyKind.ModelEvaluation, GenericDataRowData),
- ],
-)
-def test_generic_data_row_type_by_global_key(
- media_type,
- data_type_class,
- annotations_by_media_type,
- hardcoded_global_key,
-):
- annotations_ndjson = annotations_by_media_type[media_type]
- annotations_ndjson = [annotation[0] for annotation in annotations_ndjson]
-
- label = list(NDJsonConverter.deserialize(annotations_ndjson))[0]
-
- data_label = Label(
- data=data_type_class(global_key=hardcoded_global_key()),
- annotations=label.annotations,
- )
-
- assert data_label.data.global_key == label.data.global_key
- assert label.annotations == data_label.annotations
-
-
@pytest.mark.parametrize(
"configured_project, media_type",
[
diff --git a/libs/labelbox/tests/data/annotation_import/test_mea_prediction_import.py b/libs/labelbox/tests/data/annotation_import/test_mea_prediction_import.py
index fccca2a3f..5f47975ad 100644
--- a/libs/labelbox/tests/data/annotation_import/test_mea_prediction_import.py
+++ b/libs/labelbox/tests/data/annotation_import/test_mea_prediction_import.py
@@ -1,5 +1,19 @@
import uuid
from labelbox import parser
+from labelbox.data.annotation_types.annotation import ObjectAnnotation
+from labelbox.data.annotation_types.classification.classification import (
+ ClassificationAnnotation,
+ ClassificationAnswer,
+ Radio,
+)
+from labelbox.data.annotation_types.data.generic_data_row_data import (
+ GenericDataRowData,
+)
+from labelbox.data.annotation_types.geometry.line import Line
+from labelbox.data.annotation_types.geometry.point import Point
+from labelbox.data.annotation_types.geometry.polygon import Polygon
+from labelbox.data.annotation_types.geometry.rectangle import Rectangle
+from labelbox.data.annotation_types.label import Label
import pytest
from labelbox import ModelRun
@@ -193,14 +207,60 @@ def test_create_from_label_objects(
annotation_import_test_helpers,
):
name = str(uuid.uuid4())
- use_data_row_ids = [
+ use_data_row_id = [
p["dataRow"]["id"] for p in object_predictions_for_annotation_import
]
- model_run_with_data_rows.upsert_data_rows(use_data_row_ids)
- predictions = list(
- NDJsonConverter.deserialize(object_predictions_for_annotation_import)
- )
+ model_run_with_data_rows.upsert_data_rows(use_data_row_id)
+
+ predictions = []
+ for data_row_id in use_data_row_id:
+ predictions.append(
+ Label(
+ data=GenericDataRowData(
+ uid=data_row_id,
+ ),
+ annotations=[
+ ObjectAnnotation(
+ name="polygon",
+ extra={
+ "uuid": "6d10fa30-3ea0-4e6c-bbb1-63f5c29fe3e4",
+ },
+ value=Polygon(
+ points=[
+ Point(x=147.692, y=118.154),
+ Point(x=142.769, y=104.923),
+ Point(x=57.846, y=118.769),
+ Point(x=28.308, y=169.846),
+ Point(x=147.692, y=118.154),
+ ],
+ ),
+ ),
+ ObjectAnnotation(
+ name="bbox",
+ extra={
+ "uuid": "15b7138f-4bbc-42c5-ae79-45d87b0a3b2a",
+ },
+ value=Rectangle(
+ start=Point(x=58.0, y=48.0),
+ end=Point(x=70.0, y=113.0),
+ ),
+ ),
+ ObjectAnnotation(
+ name="polyline",
+ extra={
+ "uuid": "cf4c6df9-c39c-4fbc-9541-470f6622978a",
+ },
+ value=Line(
+ points=[
+ Point(x=147.692, y=118.154),
+ Point(x=150.692, y=160.154),
+ ],
+ ),
+ ),
+ ],
+ ),
+ )
annotation_import = model_run_with_data_rows.add_predictions(
name=name, predictions=predictions
diff --git a/libs/labelbox/tests/data/annotation_import/test_model.py b/libs/labelbox/tests/data/annotation_import/test_model.py
index dcfe9ef2c..56b2e07c4 100644
--- a/libs/labelbox/tests/data/annotation_import/test_model.py
+++ b/libs/labelbox/tests/data/annotation_import/test_model.py
@@ -1,7 +1,7 @@
import pytest
+from lbox.exceptions import ResourceNotFoundError
from labelbox import Model
-from labelbox.exceptions import ResourceNotFoundError
def test_model(client, configured_project, rand_gen):
diff --git a/libs/labelbox/tests/data/annotation_import/test_ndjson_validation.py b/libs/labelbox/tests/data/annotation_import/test_ndjson_validation.py
deleted file mode 100644
index a0df559fc..000000000
--- a/libs/labelbox/tests/data/annotation_import/test_ndjson_validation.py
+++ /dev/null
@@ -1,230 +0,0 @@
-from labelbox.schema.media_type import MediaType
-from labelbox.schema.project import Project
-import pytest
-
-from labelbox import parser
-from pytest_cases import parametrize, fixture_ref
-
-from labelbox.exceptions import MALValidationError
-from labelbox.schema.bulk_import_request import (
- NDChecklist,
- NDClassification,
- NDMask,
- NDPolygon,
- NDPolyline,
- NDRadio,
- NDRectangle,
- NDText,
- NDTextEntity,
- NDTool,
- _validate_ndjson,
-)
-
-"""
-- These NDlabels are apart of bulkImportReqeust and should be removed once bulk import request is removed
-"""
-
-
-def test_classification_construction(checklist_inference, text_inference):
- checklist = NDClassification.build(checklist_inference[0])
- assert isinstance(checklist, NDChecklist)
- text = NDClassification.build(text_inference[0])
- assert isinstance(text, NDText)
-
-
-@parametrize(
- "inference, expected_type",
- [
- (fixture_ref("polygon_inference"), NDPolygon),
- (fixture_ref("rectangle_inference"), NDRectangle),
- (fixture_ref("line_inference"), NDPolyline),
- (fixture_ref("entity_inference"), NDTextEntity),
- (fixture_ref("segmentation_inference"), NDMask),
- (fixture_ref("segmentation_inference_rle"), NDMask),
- (fixture_ref("segmentation_inference_png"), NDMask),
- ],
-)
-def test_tool_construction(inference, expected_type):
- assert isinstance(NDTool.build(inference[0]), expected_type)
-
-
-def no_tool(text_inference, module_project):
- pred = text_inference[0].copy()
- # Missing key
- del pred["answer"]
- with pytest.raises(MALValidationError):
- _validate_ndjson([pred], module_project)
-
-
-@pytest.mark.parametrize("configured_project", [MediaType.Text], indirect=True)
-def test_invalid_text(text_inference, configured_project):
- # and if it is not a string
- pred = text_inference[0].copy()
- # Extra and wrong key
- del pred["answer"]
- pred["answers"] = []
- with pytest.raises(MALValidationError):
- _validate_ndjson([pred], configured_project)
- del pred["answers"]
-
- # Invalid type
- pred["answer"] = []
- with pytest.raises(MALValidationError):
- _validate_ndjson([pred], configured_project)
-
- # Invalid type
- pred["answer"] = None
- with pytest.raises(MALValidationError):
- _validate_ndjson([pred], configured_project)
-
-
-def test_invalid_checklist_item(checklist_inference, module_project):
- # Only two points
- pred = checklist_inference[0].copy()
- pred["answers"] = [pred["answers"][0], pred["answers"][0]]
- # Duplicate schema ids
- with pytest.raises(MALValidationError):
- _validate_ndjson([pred], module_project)
-
- pred["answers"] = [{"name": "asdfg"}]
- with pytest.raises(MALValidationError):
- _validate_ndjson([pred], module_project)
-
- pred["answers"] = [{"schemaId": "1232132132"}]
- with pytest.raises(MALValidationError):
- _validate_ndjson([pred], module_project)
-
- pred["answers"] = [{}]
- with pytest.raises(MALValidationError):
- _validate_ndjson([pred], module_project)
-
- pred["answers"] = []
- with pytest.raises(MALValidationError):
- _validate_ndjson([pred], module_project)
-
- del pred["answers"]
- with pytest.raises(MALValidationError):
- _validate_ndjson([pred], module_project)
-
-
-def test_invalid_polygon(polygon_inference, module_project):
- # Only two points
- pred = polygon_inference[0].copy()
- pred["polygon"] = [{"x": 100, "y": 100}, {"x": 200, "y": 200}]
- with pytest.raises(MALValidationError):
- _validate_ndjson([pred], module_project)
-
-
-@pytest.mark.parametrize("configured_project", [MediaType.Text], indirect=True)
-def test_incorrect_entity(entity_inference, configured_project):
- entity = entity_inference[0].copy()
- # Location cannot be a list
- entity["location"] = [0, 10]
- with pytest.raises(MALValidationError):
- _validate_ndjson([entity], configured_project)
-
- entity["location"] = {"start": -1, "end": 5}
- with pytest.raises(MALValidationError):
- _validate_ndjson([entity], configured_project)
-
- entity["location"] = {"start": 15, "end": 5}
- with pytest.raises(MALValidationError):
- _validate_ndjson([entity], configured_project)
-
-
-@pytest.mark.skip(
- "Test wont work/fails randomly since projects have to have a media type and could be missing features from prediction list"
-)
-def test_all_validate_json(module_project, predictions):
- # Predictions contains one of each type of prediction.
- # These should be properly formatted and pass.
- _validate_ndjson(predictions[0], module_project)
-
-
-def test_incorrect_line(line_inference, module_project):
- line = line_inference[0].copy()
- line["line"] = [line["line"][0]] # Just one point
- with pytest.raises(MALValidationError):
- _validate_ndjson([line], module_project)
-
-
-def test_incorrect_rectangle(rectangle_inference, module_project):
- del rectangle_inference[0]["bbox"]["top"]
- with pytest.raises(MALValidationError):
- _validate_ndjson([rectangle_inference], module_project)
-
-
-def test_duplicate_tools(rectangle_inference, module_project):
- pred = rectangle_inference[0].copy()
- pred["polygon"] = [{"x": 100, "y": 100}, {"x": 200, "y": 200}]
- with pytest.raises(MALValidationError):
- _validate_ndjson([pred], module_project)
-
-
-def test_invalid_feature_schema(module_project, rectangle_inference):
- pred = rectangle_inference[0].copy()
- pred["schemaId"] = "blahblah"
- with pytest.raises(MALValidationError):
- _validate_ndjson([pred], module_project)
-
-
-def test_name_only_feature_schema(module_project, rectangle_inference):
- pred = rectangle_inference[0].copy()
- _validate_ndjson([pred], module_project)
-
-
-def test_schema_id_only_feature_schema(module_project, rectangle_inference):
- pred = rectangle_inference[0].copy()
- del pred["name"]
- ontology = module_project.ontology().normalized["tools"]
- for tool in ontology:
- if tool["name"] == "bbox":
- feature_schema_id = tool["featureSchemaId"]
- pred["schemaId"] = feature_schema_id
- _validate_ndjson([pred], module_project)
-
-
-def test_missing_feature_schema(module_project, rectangle_inference):
- pred = rectangle_inference[0].copy()
- del pred["name"]
- with pytest.raises(MALValidationError):
- _validate_ndjson([pred], module_project)
-
-
-def test_validate_ndjson(tmp_path, configured_project):
- file_name = f"broken.ndjson"
- file_path = tmp_path / file_name
- with file_path.open("w") as f:
- f.write("test")
-
- with pytest.raises(ValueError):
- configured_project.upload_annotations(
- name="name", annotations=str(file_path), validate=True
- )
-
-
-def test_validate_ndjson_uuid(tmp_path, configured_project, predictions):
- file_name = f"repeat_uuid.ndjson"
- file_path = tmp_path / file_name
- repeat_uuid = predictions.copy()
- repeat_uuid[0]["uuid"] = "test_uuid"
- repeat_uuid[1]["uuid"] = "test_uuid"
-
- with file_path.open("w") as f:
- parser.dump(repeat_uuid, f)
-
- with pytest.raises(MALValidationError):
- configured_project.upload_annotations(
- name="name", validate=True, annotations=str(file_path)
- )
-
- with pytest.raises(MALValidationError):
- configured_project.upload_annotations(
- name="name", validate=True, annotations=repeat_uuid
- )
-
-
-@pytest.mark.parametrize("configured_project", [MediaType.Video], indirect=True)
-def test_video_upload(video_checklist_inference, configured_project):
- pred = video_checklist_inference[0].copy()
- _validate_ndjson([pred], configured_project)
diff --git a/libs/labelbox/tests/data/annotation_types/data/test_raster.py b/libs/labelbox/tests/data/annotation_types/data/test_raster.py
index 6bc8f2bbf..209419aed 100644
--- a/libs/labelbox/tests/data/annotation_types/data/test_raster.py
+++ b/libs/labelbox/tests/data/annotation_types/data/test_raster.py
@@ -5,34 +5,28 @@
import pytest
from PIL import Image
-from labelbox.data.annotation_types.data import ImageData
+from labelbox.data.annotation_types.data import GenericDataRowData, MaskData
from pydantic import ValidationError
def test_validate_schema():
with pytest.raises(ValidationError):
- data = ImageData()
+ MaskData()
def test_im_bytes():
data = (np.random.random((32, 32, 3)) * 255).astype(np.uint8)
im_bytes = BytesIO()
Image.fromarray(data).save(im_bytes, format="PNG")
- raster_data = ImageData(im_bytes=im_bytes.getvalue())
+ raster_data = MaskData(im_bytes=im_bytes.getvalue())
data_ = raster_data.value
assert np.all(data == data_)
def test_im_url():
- raster_data = ImageData(url="https://picsum.photos/id/829/200/300")
- data_ = raster_data.value
- assert data_.shape == (300, 200, 3)
-
-
-def test_im_path():
- img_path = "/tmp/img.jpg"
- urllib.request.urlretrieve("https://picsum.photos/id/829/200/300", img_path)
- raster_data = ImageData(file_path=img_path)
+ raster_data = MaskData(
+ uid="test", url="https://picsum.photos/id/829/200/300"
+ )
data_ = raster_data.value
assert data_.shape == (300, 200, 3)
@@ -42,14 +36,11 @@ def test_ref():
uid = "uid"
metadata = []
media_attributes = {}
- data = ImageData(
- im_bytes=b"",
- external_id=external_id,
+ data = GenericDataRowData(
uid=uid,
metadata=metadata,
media_attributes=media_attributes,
)
- assert data.external_id == external_id
assert data.uid == uid
assert data.media_attributes == media_attributes
assert data.metadata == metadata
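A minimal sketch of the data-class swap running through these tests (editor's illustration, not part of this diff): raster payloads move to MaskData and plain data-row references to GenericDataRowData.

    import numpy as np
    from labelbox.data.annotation_types.data import GenericDataRowData, MaskData

    mask = MaskData(arr=np.zeros((32, 32, 3), dtype=np.uint8), uid="test")   # pixel payload
    row_ref = GenericDataRowData(uid="ckrmd9q8g000009mg6vej7hzg")            # data-row reference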
diff --git a/libs/labelbox/tests/data/annotation_types/data/test_text.py b/libs/labelbox/tests/data/annotation_types/data/test_text.py
deleted file mode 100644
index 865f93e65..000000000
--- a/libs/labelbox/tests/data/annotation_types/data/test_text.py
+++ /dev/null
@@ -1,55 +0,0 @@
-import os
-
-import pytest
-
-from labelbox.data.annotation_types import TextData
-from pydantic import ValidationError
-
-
-def test_validate_schema():
- with pytest.raises(ValidationError):
- data = TextData()
-
-
-def test_text():
- text = "hello world"
- metadata = []
- media_attributes = {}
- text_data = TextData(
- text=text, metadata=metadata, media_attributes=media_attributes
- )
- assert text_data.text == text
-
-
-def test_url():
- url = "https://storage.googleapis.com/lb-artifacts-testing-public/sdk_integration_test/sample3.txt"
- text_data = TextData(url=url)
- text = text_data.value
- assert len(text) == 3541
-
-
-def test_file(tmpdir):
- content = "foo bar baz"
- file = "hello.txt"
- dir = tmpdir.mkdir("data")
- dir.join(file).write(content)
- text_data = TextData(file_path=os.path.join(dir.strpath, file))
- assert len(text_data.value) == len(content)
-
-
-def test_ref():
- external_id = "external_id"
- uid = "uid"
- metadata = []
- media_attributes = {}
- data = TextData(
- text="hello world",
- external_id=external_id,
- uid=uid,
- metadata=metadata,
- media_attributes=media_attributes,
- )
- assert data.external_id == external_id
- assert data.uid == uid
- assert data.media_attributes == media_attributes
- assert data.metadata == metadata
diff --git a/libs/labelbox/tests/data/annotation_types/data/test_video.py b/libs/labelbox/tests/data/annotation_types/data/test_video.py
deleted file mode 100644
index 5fd77c2c8..000000000
--- a/libs/labelbox/tests/data/annotation_types/data/test_video.py
+++ /dev/null
@@ -1,73 +0,0 @@
-import numpy as np
-import pytest
-
-from labelbox.data.annotation_types import VideoData
-from pydantic import ValidationError
-
-
-def test_validate_schema():
- with pytest.raises(ValidationError):
- data = VideoData()
-
-
-def test_frames():
- data = {
- x: (np.random.random((32, 32, 3)) * 255).astype(np.uint8)
- for x in range(5)
- }
- video_data = VideoData(frames=data)
- for idx, frame in video_data.frame_generator():
- assert idx in data
- assert np.all(frame == data[idx])
-
-
-def test_file_path():
- path = "tests/integration/media/cat.mp4"
- raster_data = VideoData(file_path=path)
-
- with pytest.raises(ValueError):
- raster_data[0]
-
- raster_data.load_frames()
- raster_data[0]
-
- frame_indices = list(raster_data.frames.keys())
- # 29 frames
- assert set(frame_indices) == set(list(range(28)))
-
-
-def test_file_url():
- url = "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/ForBiggerMeltdowns.mp4"
- raster_data = VideoData(url=url)
-
- with pytest.raises(ValueError):
- raster_data[0]
-
- raster_data.load_frames()
- raster_data[0]
-
- frame_indices = list(raster_data.frames.keys())
- # 362 frames
- assert set(frame_indices) == set(list(range(361)))
-
-
-def test_ref():
- external_id = "external_id"
- uid = "uid"
- data = {
- x: (np.random.random((32, 32, 3)) * 255).astype(np.uint8)
- for x in range(5)
- }
- metadata = []
- media_attributes = {}
- data = VideoData(
- frames=data,
- external_id=external_id,
- uid=uid,
- metadata=metadata,
- media_attributes=media_attributes,
- )
- assert data.external_id == external_id
- assert data.uid == uid
- assert data.media_attributes == media_attributes
- assert data.metadata == metadata
diff --git a/libs/labelbox/tests/data/annotation_types/test_annotation.py b/libs/labelbox/tests/data/annotation_types/test_annotation.py
index 8cdeac9ba..01547bd56 100644
--- a/libs/labelbox/tests/data/annotation_types/test_annotation.py
+++ b/libs/labelbox/tests/data/annotation_types/test_annotation.py
@@ -1,18 +1,20 @@
import pytest
+from lbox.exceptions import ConfidenceNotSupportedException
+from pydantic import ValidationError
from labelbox.data.annotation_types import (
- Text,
- Point,
- Line,
ClassificationAnnotation,
+ Line,
ObjectAnnotation,
+ Point,
+ Text,
TextEntity,
)
-from labelbox.data.annotation_types.video import VideoObjectAnnotation
from labelbox.data.annotation_types.geometry.rectangle import Rectangle
-from labelbox.data.annotation_types.video import VideoClassificationAnnotation
-from labelbox.exceptions import ConfidenceNotSupportedException
-from pydantic import ValidationError
+from labelbox.data.annotation_types.video import (
+ VideoClassificationAnnotation,
+ VideoObjectAnnotation,
+)
def test_annotation():
diff --git a/libs/labelbox/tests/data/annotation_types/test_collection.py b/libs/labelbox/tests/data/annotation_types/test_collection.py
index 9deddc3c8..e0fa7bd53 100644
--- a/libs/labelbox/tests/data/annotation_types/test_collection.py
+++ b/libs/labelbox/tests/data/annotation_types/test_collection.py
@@ -7,19 +7,21 @@
from labelbox.data.annotation_types import (
LabelGenerator,
ObjectAnnotation,
- ImageData,
- MaskData,
Line,
Mask,
Point,
Label,
+ GenericDataRowData,
+ MaskData,
)
from labelbox import OntologyBuilder, Tool
@pytest.fixture
def list_of_labels():
- return [Label(data=ImageData(url="http://someurl")) for _ in range(5)]
+ return [
+ Label(data=GenericDataRowData(uid="http://someurl")) for _ in range(5)
+ ]
@pytest.fixture
@@ -70,58 +72,9 @@ def test_conversion(list_of_labels):
assert [x for x in label_collection] == list_of_labels
-def test_adding_schema_ids():
- name = "line_feature"
- label = Label(
- data=ImageData(arr=np.ones((32, 32, 3), dtype=np.uint8)),
- annotations=[
- ObjectAnnotation(
- value=Line(points=[Point(x=1, y=2), Point(x=2, y=2)]),
- name=name,
- )
- ],
- )
- feature_schema_id = "expected_id"
- ontology = OntologyBuilder(
- tools=[
- Tool(Tool.Type.LINE, name=name, feature_schema_id=feature_schema_id)
- ]
- )
- generator = LabelGenerator([label]).assign_feature_schema_ids(ontology)
- assert next(generator).annotations[0].feature_schema_id == feature_schema_id
-
-
-def test_adding_urls(signer):
- label = Label(
- data=ImageData(arr=np.random.random((32, 32, 3)).astype(np.uint8)),
- annotations=[],
- )
- uuid = str(uuid4())
- generator = LabelGenerator([label]).add_url_to_data(signer(uuid))
- assert label.data.url != uuid
- assert next(generator).data.url == uuid
- assert label.data.url == uuid
-
-
-def test_adding_to_dataset(signer):
- dataset = FakeDataset()
- label = Label(
- data=ImageData(arr=np.random.random((32, 32, 3)).astype(np.uint8)),
- annotations=[],
- )
- uuid = str(uuid4())
- generator = LabelGenerator([label]).add_to_dataset(dataset, signer(uuid))
- assert label.data.url != uuid
- generated_label = next(generator)
- assert generated_label.data.url == uuid
- assert generated_label.data.external_id != None
- assert generated_label.data.uid == dataset.uid
- assert label.data.url == uuid
-
-
def test_adding_to_masks(signer):
label = Label(
- data=ImageData(arr=np.random.random((32, 32, 3)).astype(np.uint8)),
+ data=GenericDataRowData(uid="12345"),
annotations=[
ObjectAnnotation(
name="1234",
diff --git a/libs/labelbox/tests/data/annotation_types/test_label.py b/libs/labelbox/tests/data/annotation_types/test_label.py
index 5bdfb6bde..9cd992b0c 100644
--- a/libs/labelbox/tests/data/annotation_types/test_label.py
+++ b/libs/labelbox/tests/data/annotation_types/test_label.py
@@ -17,220 +17,16 @@
ObjectAnnotation,
Point,
Line,
- ImageData,
+ MaskData,
Label,
)
import pytest
-def test_schema_assignment_geometry():
- name = "line_feature"
- label = Label(
- data=ImageData(arr=np.ones((32, 32, 3), dtype=np.uint8)),
- annotations=[
- ObjectAnnotation(
- value=Line(points=[Point(x=1, y=2), Point(x=2, y=2)]),
- name=name,
- )
- ],
- )
- feature_schema_id = "expected_id"
- ontology = OntologyBuilder(
- tools=[
- Tool(Tool.Type.LINE, name=name, feature_schema_id=feature_schema_id)
- ]
- )
- label.assign_feature_schema_ids(ontology)
-
- assert label.annotations[0].feature_schema_id == feature_schema_id
-
-
-def test_schema_assignment_classification():
- radio_name = "radio_name"
- text_name = "text_name"
- option_name = "my_option"
-
- label = Label(
- data=ImageData(arr=np.ones((32, 32, 3), dtype=np.uint8)),
- annotations=[
- ClassificationAnnotation(
- value=Radio(answer=ClassificationAnswer(name=option_name)),
- name=radio_name,
- ),
- ClassificationAnnotation(
- value=Text(answer="some text"), name=text_name
- ),
- ],
- )
- radio_schema_id = "radio_schema_id"
- text_schema_id = "text_schema_id"
- option_schema_id = "option_schema_id"
- ontology = OntologyBuilder(
- tools=[],
- classifications=[
- OClassification(
- class_type=OClassification.Type.RADIO,
- name=radio_name,
- feature_schema_id=radio_schema_id,
- options=[
- Option(
- value=option_name, feature_schema_id=option_schema_id
- )
- ],
- ),
- OClassification(
- class_type=OClassification.Type.TEXT,
- name=text_name,
- feature_schema_id=text_schema_id,
- ),
- ],
- )
- label.assign_feature_schema_ids(ontology)
- assert label.annotations[0].feature_schema_id == radio_schema_id
- assert label.annotations[1].feature_schema_id == text_schema_id
- assert (
- label.annotations[0].value.answer.feature_schema_id == option_schema_id
- )
-
-
-def test_schema_assignment_subclass():
- name = "line_feature"
- radio_name = "radio_name"
- option_name = "my_option"
- classification = ClassificationAnnotation(
- name=radio_name,
- value=Radio(answer=ClassificationAnswer(name=option_name)),
- )
- label = Label(
- data=ImageData(arr=np.ones((32, 32, 3), dtype=np.uint8)),
- annotations=[
- ObjectAnnotation(
- value=Line(points=[Point(x=1, y=2), Point(x=2, y=2)]),
- name=name,
- classifications=[classification],
- )
- ],
- )
- feature_schema_id = "expected_id"
- classification_schema_id = "classification_id"
- option_schema_id = "option_schema_id"
- ontology = OntologyBuilder(
- tools=[
- Tool(
- Tool.Type.LINE,
- name=name,
- feature_schema_id=feature_schema_id,
- classifications=[
- OClassification(
- class_type=OClassification.Type.RADIO,
- name=radio_name,
- feature_schema_id=classification_schema_id,
- options=[
- Option(
- value=option_name,
- feature_schema_id=option_schema_id,
- )
- ],
- )
- ],
- )
- ]
- )
- label.assign_feature_schema_ids(ontology)
- assert label.annotations[0].feature_schema_id == feature_schema_id
- assert (
- label.annotations[0].classifications[0].feature_schema_id
- == classification_schema_id
- )
- assert (
- label.annotations[0].classifications[0].value.answer.feature_schema_id
- == option_schema_id
- )
-
-
-def test_highly_nested():
- name = "line_feature"
- radio_name = "radio_name"
- nested_name = "nested_name"
- option_name = "my_option"
- nested_option_name = "nested_option_name"
- classification = ClassificationAnnotation(
- name=radio_name,
- value=Radio(answer=ClassificationAnswer(name=option_name)),
- classifications=[
- ClassificationAnnotation(
- value=Radio(
- answer=ClassificationAnswer(name=nested_option_name)
- ),
- name=nested_name,
- )
- ],
- )
- label = Label(
- data=ImageData(arr=np.ones((32, 32, 3), dtype=np.uint8)),
- annotations=[
- ObjectAnnotation(
- value=Line(points=[Point(x=1, y=2), Point(x=2, y=2)]),
- name=name,
- classifications=[classification],
- )
- ],
- )
- feature_schema_id = "expected_id"
- classification_schema_id = "classification_id"
- nested_classification_schema_id = "nested_classification_schema_id"
- option_schema_id = "option_schema_id"
- ontology = OntologyBuilder(
- tools=[
- Tool(
- Tool.Type.LINE,
- name=name,
- feature_schema_id=feature_schema_id,
- classifications=[
- OClassification(
- class_type=OClassification.Type.RADIO,
- name=radio_name,
- feature_schema_id=classification_schema_id,
- options=[
- Option(
- value=option_name,
- feature_schema_id=option_schema_id,
- options=[
- OClassification(
- class_type=OClassification.Type.RADIO,
- name=nested_name,
- feature_schema_id=nested_classification_schema_id,
- options=[
- Option(
- value=nested_option_name,
- feature_schema_id=nested_classification_schema_id,
- )
- ],
- )
- ],
- )
- ],
- )
- ],
- )
- ]
- )
- label.assign_feature_schema_ids(ontology)
- assert label.annotations[0].feature_schema_id == feature_schema_id
- assert (
- label.annotations[0].classifications[0].feature_schema_id
- == classification_schema_id
- )
- assert (
- label.annotations[0].classifications[0].value.answer.feature_schema_id
- == option_schema_id
- )
-
-
def test_schema_assignment_confidence():
name = "line_feature"
label = Label(
- data=ImageData(arr=np.ones((32, 32, 3), dtype=np.uint8)),
+ data=MaskData(arr=np.ones((32, 32, 3), dtype=np.uint8), uid="test"),
annotations=[
ObjectAnnotation(
value=Line(
@@ -252,10 +48,10 @@ def test_initialize_label_no_coercion():
value=lb_types.ConversationEntity(start=0, end=8, message_id="4"),
)
label = Label(
- data=lb_types.ConversationData(global_key=global_key),
+ data=lb_types.GenericDataRowData(global_key=global_key),
annotations=[ner_annotation],
)
- assert isinstance(label.data, lb_types.ConversationData)
+ assert isinstance(label.data, lb_types.GenericDataRowData)
assert label.data.global_key == global_key
diff --git a/libs/labelbox/tests/data/annotation_types/test_metrics.py b/libs/labelbox/tests/data/annotation_types/test_metrics.py
index 94c9521a5..4e9355573 100644
--- a/libs/labelbox/tests/data/annotation_types/test_metrics.py
+++ b/libs/labelbox/tests/data/annotation_types/test_metrics.py
@@ -8,7 +8,11 @@
ConfusionMatrixMetric,
ScalarMetric,
)
-from labelbox.data.annotation_types import ScalarMetric, Label, ImageData
+from labelbox.data.annotation_types import (
+ ScalarMetric,
+ Label,
+ GenericDataRowData,
+)
from labelbox.data.annotation_types.metrics.scalar import RESERVED_METRIC_NAMES
from pydantic import ValidationError
@@ -19,7 +23,8 @@ def test_legacy_scalar_metric():
assert metric.value == value
label = Label(
- data=ImageData(uid="ckrmd9q8g000009mg6vej7hzg"), annotations=[metric]
+ data=GenericDataRowData(uid="ckrmd9q8g000009mg6vej7hzg"),
+ annotations=[metric],
)
expected = {
"data": {
@@ -72,7 +77,8 @@ def test_custom_scalar_metric(feature_name, subclass_name, aggregation, value):
assert metric.value == value
label = Label(
- data=ImageData(uid="ckrmd9q8g000009mg6vej7hzg"), annotations=[metric]
+ data=GenericDataRowData(uid="ckrmd9q8g000009mg6vej7hzg"),
+ annotations=[metric],
)
expected = {
"data": {
@@ -134,7 +140,8 @@ def test_custom_confusison_matrix_metric(
assert metric.value == value
label = Label(
- data=ImageData(uid="ckrmd9q8g000009mg6vej7hzg"), annotations=[metric]
+ data=GenericDataRowData(uid="ckrmd9q8g000009mg6vej7hzg"),
+ annotations=[metric],
)
expected = {
"data": {
diff --git a/libs/labelbox/tests/data/assets/ndjson/classification_import_global_key.json b/libs/labelbox/tests/data/assets/ndjson/classification_import_global_key.json
deleted file mode 100644
index 4de15e217..000000000
--- a/libs/labelbox/tests/data/assets/ndjson/classification_import_global_key.json
+++ /dev/null
@@ -1,54 +0,0 @@
-[
- {
- "answer": {
- "schemaId": "ckrb1sfl8099g0y91cxbd5ftb",
- "confidence": 0.8,
- "customMetrics": [
- {
- "name": "customMetric1",
- "value": 0.5
- },
- {
- "name": "customMetric2",
- "value": 0.3
- }
- ]
- },
- "schemaId": "ckrb1sfl8099g0y91cxbd5ftb",
- "dataRow": {
- "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d"
- },
- "uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673"
- },
- {
- "answer": [
- {
- "schemaId": "ckrb1sfl8099e0y919v260awv",
- "confidence": 0.82,
- "customMetrics": [
- {
- "name": "customMetric1",
- "value": 0.5
- },
- {
- "name": "customMetric2",
- "value": 0.3
- }
- ]
- }
- ],
- "schemaId": "ckrb1sfkn099c0y910wbo0p1a",
- "dataRow": {
- "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d"
- },
- "uuid": "d009925d-91a3-4f67-abd9-753453f5a584"
- },
- {
- "answer": "a value",
- "schemaId": "ckrb1sfkn099c0y910wbo0p1a",
- "dataRow": {
- "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d"
- },
- "uuid": "ee70fd88-9f88-48dd-b760-7469ff479b71"
- }
-]
\ No newline at end of file
diff --git a/libs/labelbox/tests/data/assets/ndjson/conversation_entity_import_global_key.json b/libs/labelbox/tests/data/assets/ndjson/conversation_entity_import_global_key.json
deleted file mode 100644
index 83a95e5bf..000000000
--- a/libs/labelbox/tests/data/assets/ndjson/conversation_entity_import_global_key.json
+++ /dev/null
@@ -1,25 +0,0 @@
-[{
- "location": {
- "start": 67,
- "end": 128
- },
- "messageId": "some-message-id",
- "uuid": "5ad9c52f-058d-49c8-a749-3f20b84f8cd4",
- "dataRow": {
- "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d"
- },
- "name": "some-text-entity",
- "schemaId": "cl6xnuwt95lqq07330tbb3mfd",
- "classifications": [],
- "confidence": 0.53,
- "customMetrics": [
- {
- "name": "customMetric1",
- "value": 0.5
- },
- {
- "name": "customMetric2",
- "value": 0.3
- }
- ]
-}]
diff --git a/libs/labelbox/tests/data/assets/ndjson/image_import.json b/libs/labelbox/tests/data/assets/ndjson/image_import.json
index 91563b8ae..75fe36e44 100644
--- a/libs/labelbox/tests/data/assets/ndjson/image_import.json
+++ b/libs/labelbox/tests/data/assets/ndjson/image_import.json
@@ -8,16 +8,17 @@
"confidence": 0.851,
"customMetrics": [
{
- "name": "customMetric1",
- "value": 0.4
+ "name": "customMetric1",
+ "value": 0.4
}
],
"bbox": {
- "top": 1352,
- "left": 2275,
- "height": 350,
- "width": 139
- }
+ "top": 1352.0,
+ "left": 2275.0,
+ "height": 350.0,
+ "width": 139.0
+ },
+ "classifications": []
},
{
"uuid": "751fc725-f7b6-48ed-89b0-dd7d94d08af6",
@@ -28,20 +29,17 @@
"confidence": 0.834,
"customMetrics": [
{
- "name": "customMetric1",
- "value": 0.3
+ "name": "customMetric1",
+ "value": 0.3
}
],
"mask": {
- "instanceURI": "https://storage.labelbox.com/ckqcx1czn06830y61gh9v02cs%2F3e729327-f038-f66c-186e-45e921ef9717-1?Expires=1626806874672&KeyName=labelbox-assets-key-3&Signature=YsUOGKrsqmAZ68vT9BlPJOaRyLY",
- "colorRGB": [
- 255,
- 0,
- 0
- ]
- }
+ "instanceURI": "https://storage.labelbox.com/ckqcx1czn06830y61gh9v02cs%2F3e729327-f038-f66c-186e-45e921ef9717-1?Expires=1626806874672&KeyName=labelbox-assets-key-3&Signature=YsUOGKrsqmAZ68vT9BlPJOaRyLY"
+ },
+ "classifications": []
},
{
+ "classifications": [],
"uuid": "43d719ac-5d7f-4aea-be00-2ebfca0900fd",
"schemaId": "ckrazcuec16oi0z66dzrd8pfl",
"dataRow": {
@@ -50,762 +48,39 @@
"confidence": 0.986,
"customMetrics": [
{
- "name": "customMetric1",
- "value": 0.9
+ "name": "customMetric1",
+ "value": 0.9
}
],
"polygon": [
{
- "x": 1118,
- "y": 935
- },
- {
- "x": 1117,
- "y": 935
- },
- {
- "x": 1116,
- "y": 935
- },
- {
- "x": 1115,
- "y": 935
- },
- {
- "x": 1114,
- "y": 935
- },
- {
- "x": 1113,
- "y": 935
- },
- {
- "x": 1112,
- "y": 935
- },
- {
- "x": 1111,
- "y": 935
- },
- {
- "x": 1110,
- "y": 935
- },
- {
- "x": 1109,
- "y": 935
- },
- {
- "x": 1108,
- "y": 935
- },
- {
- "x": 1108,
- "y": 934
- },
- {
- "x": 1107,
- "y": 934
- },
- {
- "x": 1106,
- "y": 934
- },
- {
- "x": 1105,
- "y": 934
- },
- {
- "x": 1105,
- "y": 933
- },
- {
- "x": 1104,
- "y": 933
- },
- {
- "x": 1103,
- "y": 933
- },
- {
- "x": 1103,
- "y": 932
- },
- {
- "x": 1102,
- "y": 932
- },
- {
- "x": 1101,
- "y": 932
- },
- {
- "x": 1100,
- "y": 932
- },
- {
- "x": 1099,
- "y": 932
- },
- {
- "x": 1098,
- "y": 932
- },
- {
- "x": 1097,
- "y": 932
- },
- {
- "x": 1097,
- "y": 931
- },
- {
- "x": 1096,
- "y": 931
- },
- {
- "x": 1095,
- "y": 931
- },
- {
- "x": 1094,
- "y": 931
- },
- {
- "x": 1093,
- "y": 931
- },
- {
- "x": 1092,
- "y": 931
- },
- {
- "x": 1091,
- "y": 931
- },
- {
- "x": 1090,
- "y": 931
- },
- {
- "x": 1090,
- "y": 930
- },
- {
- "x": 1089,
- "y": 930
- },
- {
- "x": 1088,
- "y": 930
- },
- {
- "x": 1087,
- "y": 930
- },
- {
- "x": 1087,
- "y": 929
- },
- {
- "x": 1086,
- "y": 929
- },
- {
- "x": 1085,
- "y": 929
- },
- {
- "x": 1084,
- "y": 929
- },
- {
- "x": 1084,
- "y": 928
- },
- {
- "x": 1083,
- "y": 928
- },
- {
- "x": 1083,
- "y": 927
- },
- {
- "x": 1082,
- "y": 927
- },
- {
- "x": 1081,
- "y": 927
- },
- {
- "x": 1081,
- "y": 926
- },
- {
- "x": 1080,
- "y": 926
- },
- {
- "x": 1080,
- "y": 925
- },
- {
- "x": 1079,
- "y": 925
- },
- {
- "x": 1078,
- "y": 925
- },
- {
- "x": 1078,
- "y": 924
- },
- {
- "x": 1077,
- "y": 924
- },
- {
- "x": 1076,
- "y": 924
- },
- {
- "x": 1076,
- "y": 923
- },
- {
- "x": 1075,
- "y": 923
- },
- {
- "x": 1074,
- "y": 923
- },
- {
- "x": 1073,
- "y": 923
- },
- {
- "x": 1073,
- "y": 922
- },
- {
- "x": 1072,
- "y": 922
- },
- {
- "x": 1071,
- "y": 922
- },
- {
- "x": 1070,
- "y": 922
- },
- {
- "x": 1070,
- "y": 921
- },
- {
- "x": 1069,
- "y": 921
- },
- {
- "x": 1068,
- "y": 921
- },
- {
- "x": 1067,
- "y": 921
- },
- {
- "x": 1066,
- "y": 921
- },
- {
- "x": 1065,
- "y": 921
- },
- {
- "x": 1064,
- "y": 921
- },
- {
- "x": 1063,
- "y": 921
- },
- {
- "x": 1062,
- "y": 921
- },
- {
- "x": 1061,
- "y": 921
- },
- {
- "x": 1060,
- "y": 921
- },
- {
- "x": 1059,
- "y": 921
- },
- {
- "x": 1058,
- "y": 921
- },
- {
- "x": 1058,
- "y": 920
- },
- {
- "x": 1057,
- "y": 920
- },
- {
- "x": 1057,
- "y": 919
- },
- {
- "x": 1056,
- "y": 919
- },
- {
- "x": 1057,
- "y": 918
- },
- {
- "x": 1057,
- "y": 918
- },
- {
- "x": 1057,
- "y": 917
- },
- {
- "x": 1058,
- "y": 916
- },
- {
- "x": 1058,
- "y": 916
- },
- {
- "x": 1059,
- "y": 915
- },
- {
- "x": 1059,
- "y": 915
- },
- {
- "x": 1060,
- "y": 914
- },
- {
- "x": 1060,
- "y": 914
- },
- {
- "x": 1061,
- "y": 913
- },
- {
- "x": 1061,
- "y": 913
- },
- {
- "x": 1062,
- "y": 912
- },
- {
- "x": 1063,
- "y": 912
- },
- {
- "x": 1063,
- "y": 912
- },
- {
- "x": 1064,
- "y": 911
- },
- {
- "x": 1064,
- "y": 911
- },
- {
- "x": 1065,
- "y": 910
- },
- {
- "x": 1066,
- "y": 910
- },
- {
- "x": 1066,
- "y": 910
- },
- {
- "x": 1067,
- "y": 909
- },
- {
- "x": 1068,
- "y": 909
- },
- {
- "x": 1068,
- "y": 909
- },
- {
- "x": 1069,
- "y": 908
- },
- {
- "x": 1070,
- "y": 908
- },
- {
- "x": 1071,
- "y": 908
- },
- {
- "x": 1072,
- "y": 908
- },
- {
- "x": 1072,
- "y": 908
- },
- {
- "x": 1073,
- "y": 907
- },
- {
- "x": 1074,
- "y": 907
- },
- {
- "x": 1075,
- "y": 907
- },
- {
- "x": 1076,
- "y": 907
- },
- {
- "x": 1077,
- "y": 907
- },
- {
- "x": 1078,
- "y": 907
- },
- {
- "x": 1079,
- "y": 907
- },
- {
- "x": 1080,
- "y": 907
- },
- {
- "x": 1081,
- "y": 907
- },
- {
- "x": 1082,
- "y": 907
- },
- {
- "x": 1083,
- "y": 907
- },
- {
- "x": 1084,
- "y": 907
- },
- {
- "x": 1085,
- "y": 907
- },
- {
- "x": 1086,
- "y": 907
- },
- {
- "x": 1087,
- "y": 907
- },
- {
- "x": 1088,
- "y": 907
- },
- {
- "x": 1089,
- "y": 907
- },
- {
- "x": 1090,
- "y": 907
- },
- {
- "x": 1091,
- "y": 907
- },
- {
- "x": 1091,
- "y": 908
- },
- {
- "x": 1092,
- "y": 908
- },
- {
- "x": 1093,
- "y": 908
- },
- {
- "x": 1094,
- "y": 908
- },
- {
- "x": 1095,
- "y": 908
- },
- {
- "x": 1095,
- "y": 909
- },
- {
- "x": 1096,
- "y": 909
- },
- {
- "x": 1097,
- "y": 909
- },
- {
- "x": 1097,
- "y": 910
- },
- {
- "x": 1098,
- "y": 910
- },
- {
- "x": 1099,
- "y": 910
- },
- {
- "x": 1099,
- "y": 911
- },
- {
- "x": 1100,
- "y": 911
- },
- {
- "x": 1101,
- "y": 911
- },
- {
- "x": 1101,
- "y": 912
- },
- {
- "x": 1102,
- "y": 912
- },
- {
- "x": 1103,
- "y": 912
- },
- {
- "x": 1103,
- "y": 913
- },
- {
- "x": 1104,
- "y": 913
- },
- {
- "x": 1104,
- "y": 914
- },
- {
- "x": 1105,
- "y": 914
- },
- {
- "x": 1105,
- "y": 915
- },
- {
- "x": 1106,
- "y": 915
- },
- {
- "x": 1107,
- "y": 915
- },
- {
- "x": 1107,
- "y": 916
- },
- {
- "x": 1108,
- "y": 916
- },
- {
- "x": 1108,
- "y": 917
- },
- {
- "x": 1109,
- "y": 917
- },
- {
- "x": 1109,
- "y": 918
- },
- {
- "x": 1110,
- "y": 918
- },
- {
- "x": 1110,
- "y": 919
- },
- {
- "x": 1111,
- "y": 919
- },
- {
- "x": 1111,
- "y": 920
- },
- {
- "x": 1112,
- "y": 920
- },
- {
- "x": 1112,
- "y": 921
- },
- {
- "x": 1113,
- "y": 921
- },
- {
- "x": 1113,
- "y": 922
- },
- {
- "x": 1114,
- "y": 922
- },
- {
- "x": 1114,
- "y": 923
- },
- {
- "x": 1115,
- "y": 923
- },
- {
- "x": 1115,
- "y": 924
- },
- {
- "x": 1115,
- "y": 925
- },
- {
- "x": 1116,
- "y": 925
- },
- {
- "x": 1116,
- "y": 926
- },
- {
- "x": 1117,
- "y": 926
- },
- {
- "x": 1117,
- "y": 927
- },
- {
- "x": 1117,
- "y": 928
- },
- {
- "x": 1118,
- "y": 928
- },
- {
- "x": 1118,
- "y": 929
- },
- {
- "x": 1119,
- "y": 929
- },
- {
- "x": 1119,
- "y": 930
- },
- {
- "x": 1120,
- "y": 930
- },
- {
- "x": 1120,
- "y": 931
- },
- {
- "x": 1120,
- "y": 932
- },
- {
- "x": 1120,
- "y": 932
- },
- {
- "x": 1119,
- "y": 933
- },
- {
- "x": 1119,
- "y": 934
+ "x": 10.0,
+ "y": 20.0
},
{
- "x": 1119,
- "y": 934
+ "x": 15.0,
+ "y": 20.0
},
{
- "x": 1118,
- "y": 935
+ "x": 20.0,
+ "y": 25.0
},
{
- "x": 1118,
- "y": 935
+ "x": 10.0,
+ "y": 20.0
}
]
},
{
+ "classifications": [],
"uuid": "b98f3a45-3328-41a0-9077-373a8177ebf2",
"schemaId": "ckrazcuec16om0z66bhhh4tp7",
"dataRow": {
"id": "ckrazctum0z8a0ybc0b0o0g0v"
},
"point": {
- "x": 2122,
- "y": 1457
+ "x": 2122.0,
+ "y": 1457.0
}
}
]
\ No newline at end of file
diff --git a/libs/labelbox/tests/data/assets/ndjson/image_import_global_key.json b/libs/labelbox/tests/data/assets/ndjson/image_import_global_key.json
deleted file mode 100644
index 591e40cf6..000000000
--- a/libs/labelbox/tests/data/assets/ndjson/image_import_global_key.json
+++ /dev/null
@@ -1,823 +0,0 @@
-[
- {
- "uuid": "b862c586-8614-483c-b5e6-82810f70cac0",
- "schemaId": "ckrazcueb16og0z6609jj7y3y",
- "dataRow": {
- "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d"
- },
- "confidence": 0.851,
- "bbox": {
- "top": 1352,
- "left": 2275,
- "height": 350,
- "width": 139
- },
- "customMetrics": [
- {
- "name": "customMetric1",
- "value": 0.5
- },
- {
- "name": "customMetric2",
- "value": 0.3
- }
- ]
- },
- {
- "uuid": "751fc725-f7b6-48ed-89b0-dd7d94d08af6",
- "schemaId": "ckrazcuec16ok0z66f956apb7",
- "dataRow": {
- "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d"
- },
- "confidence": 0.834,
- "customMetrics": [
- {
- "name": "customMetric1",
- "value": 0.5
- },
- {
- "name": "customMetric2",
- "value": 0.3
- }
- ],
- "mask": {
- "instanceURI": "https://storage.labelbox.com/ckqcx1czn06830y61gh9v02cs%2F3e729327-f038-f66c-186e-45e921ef9717-1?Expires=1626806874672&KeyName=labelbox-assets-key-3&Signature=YsUOGKrsqmAZ68vT9BlPJOaRyLY",
- "colorRGB": [
- 255,
- 0,
- 0
- ]
- }
- },
- {
- "uuid": "43d719ac-5d7f-4aea-be00-2ebfca0900fd",
- "schemaId": "ckrazcuec16oi0z66dzrd8pfl",
- "dataRow": {
- "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d"
- },
- "confidence": 0.986,
- "customMetrics": [
- {
- "name": "customMetric1",
- "value": 0.5
- },
- {
- "name": "customMetric2",
- "value": 0.3
- }
- ],
- "polygon": [
- {
- "x": 1118,
- "y": 935
- },
- {
- "x": 1117,
- "y": 935
- },
- {
- "x": 1116,
- "y": 935
- },
- {
- "x": 1115,
- "y": 935
- },
- {
- "x": 1114,
- "y": 935
- },
- {
- "x": 1113,
- "y": 935
- },
- {
- "x": 1112,
- "y": 935
- },
- {
- "x": 1111,
- "y": 935
- },
- {
- "x": 1110,
- "y": 935
- },
- {
- "x": 1109,
- "y": 935
- },
- {
- "x": 1108,
- "y": 935
- },
- {
- "x": 1108,
- "y": 934
- },
- {
- "x": 1107,
- "y": 934
- },
- {
- "x": 1106,
- "y": 934
- },
- {
- "x": 1105,
- "y": 934
- },
- {
- "x": 1105,
- "y": 933
- },
- {
- "x": 1104,
- "y": 933
- },
- {
- "x": 1103,
- "y": 933
- },
- {
- "x": 1103,
- "y": 932
- },
- {
- "x": 1102,
- "y": 932
- },
- {
- "x": 1101,
- "y": 932
- },
- {
- "x": 1100,
- "y": 932
- },
- {
- "x": 1099,
- "y": 932
- },
- {
- "x": 1098,
- "y": 932
- },
- {
- "x": 1097,
- "y": 932
- },
- {
- "x": 1097,
- "y": 931
- },
- {
- "x": 1096,
- "y": 931
- },
- {
- "x": 1095,
- "y": 931
- },
- {
- "x": 1094,
- "y": 931
- },
- {
- "x": 1093,
- "y": 931
- },
- {
- "x": 1092,
- "y": 931
- },
- {
- "x": 1091,
- "y": 931
- },
- {
- "x": 1090,
- "y": 931
- },
- {
- "x": 1090,
- "y": 930
- },
- {
- "x": 1089,
- "y": 930
- },
- {
- "x": 1088,
- "y": 930
- },
- {
- "x": 1087,
- "y": 930
- },
- {
- "x": 1087,
- "y": 929
- },
- {
- "x": 1086,
- "y": 929
- },
- {
- "x": 1085,
- "y": 929
- },
- {
- "x": 1084,
- "y": 929
- },
- {
- "x": 1084,
- "y": 928
- },
- {
- "x": 1083,
- "y": 928
- },
- {
- "x": 1083,
- "y": 927
- },
- {
- "x": 1082,
- "y": 927
- },
- {
- "x": 1081,
- "y": 927
- },
- {
- "x": 1081,
- "y": 926
- },
- {
- "x": 1080,
- "y": 926
- },
- {
- "x": 1080,
- "y": 925
- },
- {
- "x": 1079,
- "y": 925
- },
- {
- "x": 1078,
- "y": 925
- },
- {
- "x": 1078,
- "y": 924
- },
- {
- "x": 1077,
- "y": 924
- },
- {
- "x": 1076,
- "y": 924
- },
- {
- "x": 1076,
- "y": 923
- },
- {
- "x": 1075,
- "y": 923
- },
- {
- "x": 1074,
- "y": 923
- },
- {
- "x": 1073,
- "y": 923
- },
- {
- "x": 1073,
- "y": 922
- },
- {
- "x": 1072,
- "y": 922
- },
- {
- "x": 1071,
- "y": 922
- },
- {
- "x": 1070,
- "y": 922
- },
- {
- "x": 1070,
- "y": 921
- },
- {
- "x": 1069,
- "y": 921
- },
- {
- "x": 1068,
- "y": 921
- },
- {
- "x": 1067,
- "y": 921
- },
- {
- "x": 1066,
- "y": 921
- },
- {
- "x": 1065,
- "y": 921
- },
- {
- "x": 1064,
- "y": 921
- },
- {
- "x": 1063,
- "y": 921
- },
- {
- "x": 1062,
- "y": 921
- },
- {
- "x": 1061,
- "y": 921
- },
- {
- "x": 1060,
- "y": 921
- },
- {
- "x": 1059,
- "y": 921
- },
- {
- "x": 1058,
- "y": 921
- },
- {
- "x": 1058,
- "y": 920
- },
- {
- "x": 1057,
- "y": 920
- },
- {
- "x": 1057,
- "y": 919
- },
- {
- "x": 1056,
- "y": 919
- },
- {
- "x": 1057,
- "y": 918
- },
- {
- "x": 1057,
- "y": 918
- },
- {
- "x": 1057,
- "y": 917
- },
- {
- "x": 1058,
- "y": 916
- },
- {
- "x": 1058,
- "y": 916
- },
- {
- "x": 1059,
- "y": 915
- },
- {
- "x": 1059,
- "y": 915
- },
- {
- "x": 1060,
- "y": 914
- },
- {
- "x": 1060,
- "y": 914
- },
- {
- "x": 1061,
- "y": 913
- },
- {
- "x": 1061,
- "y": 913
- },
- {
- "x": 1062,
- "y": 912
- },
- {
- "x": 1063,
- "y": 912
- },
- {
- "x": 1063,
- "y": 912
- },
- {
- "x": 1064,
- "y": 911
- },
- {
- "x": 1064,
- "y": 911
- },
- {
- "x": 1065,
- "y": 910
- },
- {
- "x": 1066,
- "y": 910
- },
- {
- "x": 1066,
- "y": 910
- },
- {
- "x": 1067,
- "y": 909
- },
- {
- "x": 1068,
- "y": 909
- },
- {
- "x": 1068,
- "y": 909
- },
- {
- "x": 1069,
- "y": 908
- },
- {
- "x": 1070,
- "y": 908
- },
- {
- "x": 1071,
- "y": 908
- },
- {
- "x": 1072,
- "y": 908
- },
- {
- "x": 1072,
- "y": 908
- },
- {
- "x": 1073,
- "y": 907
- },
- {
- "x": 1074,
- "y": 907
- },
- {
- "x": 1075,
- "y": 907
- },
- {
- "x": 1076,
- "y": 907
- },
- {
- "x": 1077,
- "y": 907
- },
- {
- "x": 1078,
- "y": 907
- },
- {
- "x": 1079,
- "y": 907
- },
- {
- "x": 1080,
- "y": 907
- },
- {
- "x": 1081,
- "y": 907
- },
- {
- "x": 1082,
- "y": 907
- },
- {
- "x": 1083,
- "y": 907
- },
- {
- "x": 1084,
- "y": 907
- },
- {
- "x": 1085,
- "y": 907
- },
- {
- "x": 1086,
- "y": 907
- },
- {
- "x": 1087,
- "y": 907
- },
- {
- "x": 1088,
- "y": 907
- },
- {
- "x": 1089,
- "y": 907
- },
- {
- "x": 1090,
- "y": 907
- },
- {
- "x": 1091,
- "y": 907
- },
- {
- "x": 1091,
- "y": 908
- },
- {
- "x": 1092,
- "y": 908
- },
- {
- "x": 1093,
- "y": 908
- },
- {
- "x": 1094,
- "y": 908
- },
- {
- "x": 1095,
- "y": 908
- },
- {
- "x": 1095,
- "y": 909
- },
- {
- "x": 1096,
- "y": 909
- },
- {
- "x": 1097,
- "y": 909
- },
- {
- "x": 1097,
- "y": 910
- },
- {
- "x": 1098,
- "y": 910
- },
- {
- "x": 1099,
- "y": 910
- },
- {
- "x": 1099,
- "y": 911
- },
- {
- "x": 1100,
- "y": 911
- },
- {
- "x": 1101,
- "y": 911
- },
- {
- "x": 1101,
- "y": 912
- },
- {
- "x": 1102,
- "y": 912
- },
- {
- "x": 1103,
- "y": 912
- },
- {
- "x": 1103,
- "y": 913
- },
- {
- "x": 1104,
- "y": 913
- },
- {
- "x": 1104,
- "y": 914
- },
- {
- "x": 1105,
- "y": 914
- },
- {
- "x": 1105,
- "y": 915
- },
- {
- "x": 1106,
- "y": 915
- },
- {
- "x": 1107,
- "y": 915
- },
- {
- "x": 1107,
- "y": 916
- },
- {
- "x": 1108,
- "y": 916
- },
- {
- "x": 1108,
- "y": 917
- },
- {
- "x": 1109,
- "y": 917
- },
- {
- "x": 1109,
- "y": 918
- },
- {
- "x": 1110,
- "y": 918
- },
- {
- "x": 1110,
- "y": 919
- },
- {
- "x": 1111,
- "y": 919
- },
- {
- "x": 1111,
- "y": 920
- },
- {
- "x": 1112,
- "y": 920
- },
- {
- "x": 1112,
- "y": 921
- },
- {
- "x": 1113,
- "y": 921
- },
- {
- "x": 1113,
- "y": 922
- },
- {
- "x": 1114,
- "y": 922
- },
- {
- "x": 1114,
- "y": 923
- },
- {
- "x": 1115,
- "y": 923
- },
- {
- "x": 1115,
- "y": 924
- },
- {
- "x": 1115,
- "y": 925
- },
- {
- "x": 1116,
- "y": 925
- },
- {
- "x": 1116,
- "y": 926
- },
- {
- "x": 1117,
- "y": 926
- },
- {
- "x": 1117,
- "y": 927
- },
- {
- "x": 1117,
- "y": 928
- },
- {
- "x": 1118,
- "y": 928
- },
- {
- "x": 1118,
- "y": 929
- },
- {
- "x": 1119,
- "y": 929
- },
- {
- "x": 1119,
- "y": 930
- },
- {
- "x": 1120,
- "y": 930
- },
- {
- "x": 1120,
- "y": 931
- },
- {
- "x": 1120,
- "y": 932
- },
- {
- "x": 1120,
- "y": 932
- },
- {
- "x": 1119,
- "y": 933
- },
- {
- "x": 1119,
- "y": 934
- },
- {
- "x": 1119,
- "y": 934
- },
- {
- "x": 1118,
- "y": 935
- },
- {
- "x": 1118,
- "y": 935
- }
- ]
- },
- {
- "uuid": "b98f3a45-3328-41a0-9077-373a8177ebf2",
- "schemaId": "ckrazcuec16om0z66bhhh4tp7",
- "dataRow": {
- "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d"
- },
- "point": {
- "x": 2122,
- "y": 1457
- }
- }
-]
\ No newline at end of file
diff --git a/libs/labelbox/tests/data/assets/ndjson/image_import_name_only.json b/libs/labelbox/tests/data/assets/ndjson/image_import_name_only.json
index 82be4cdab..466a03594 100644
--- a/libs/labelbox/tests/data/assets/ndjson/image_import_name_only.json
+++ b/libs/labelbox/tests/data/assets/ndjson/image_import_name_only.json
@@ -1,826 +1,86 @@
[
{
"uuid": "b862c586-8614-483c-b5e6-82810f70cac0",
- "name": "box a",
+ "name": "ckrazcueb16og0z6609jj7y3y",
"dataRow": {
"id": "ckrazctum0z8a0ybc0b0o0g0v"
},
- "bbox": {
- "top": 1352,
- "left": 2275,
- "height": 350,
- "width": 139
- },
- "confidence": 0.854,
+ "classifications": [],
+ "confidence": 0.851,
"customMetrics": [
{
"name": "customMetric1",
- "value": 0.5
- },
- {
- "name": "customMetric2",
- "value": 0.7
+ "value": 0.4
}
- ]
+ ],
+ "bbox": {
+ "top": 1352.0,
+ "left": 2275.0,
+ "height": 350.0,
+ "width": 139.0
+ }
},
{
"uuid": "751fc725-f7b6-48ed-89b0-dd7d94d08af6",
- "name": "mask a",
+ "name": "ckrazcuec16ok0z66f956apb7",
"dataRow": {
"id": "ckrazctum0z8a0ybc0b0o0g0v"
},
- "mask": {
- "instanceURI": "https://storage.labelbox.com/ckqcx1czn06830y61gh9v02cs%2F3e729327-f038-f66c-186e-45e921ef9717-1?Expires=1626806874672&KeyName=labelbox-assets-key-3&Signature=YsUOGKrsqmAZ68vT9BlPJOaRyLY",
- "colorRGB": [
- 255,
- 0,
- 0
- ]
- },
- "confidence": 0.685,
+ "classifications": [],
+ "confidence": 0.834,
"customMetrics": [
{
"name": "customMetric1",
- "value": 0.4
- },
- {
- "name": "customMetric2",
- "value": 0.9
+ "value": 0.3
}
- ]
+ ],
+ "mask": {
+ "instanceURI": "https://storage.labelbox.com/ckqcx1czn06830y61gh9v02cs%2F3e729327-f038-f66c-186e-45e921ef9717-1?Expires=1626806874672&KeyName=labelbox-assets-key-3&Signature=YsUOGKrsqmAZ68vT9BlPJOaRyLY"
+ }
},
{
+ "classifications": [],
"uuid": "43d719ac-5d7f-4aea-be00-2ebfca0900fd",
- "name": "polygon a",
+ "name": "ckrazcuec16oi0z66dzrd8pfl",
"dataRow": {
"id": "ckrazctum0z8a0ybc0b0o0g0v"
},
- "confidence": 0.71,
+ "confidence": 0.986,
"customMetrics": [
{
"name": "customMetric1",
- "value": 0.1
+ "value": 0.9
}
],
"polygon": [
{
- "x": 1118,
- "y": 935
- },
- {
- "x": 1117,
- "y": 935
- },
- {
- "x": 1116,
- "y": 935
- },
- {
- "x": 1115,
- "y": 935
- },
- {
- "x": 1114,
- "y": 935
- },
- {
- "x": 1113,
- "y": 935
- },
- {
- "x": 1112,
- "y": 935
- },
- {
- "x": 1111,
- "y": 935
- },
- {
- "x": 1110,
- "y": 935
- },
- {
- "x": 1109,
- "y": 935
- },
- {
- "x": 1108,
- "y": 935
- },
- {
- "x": 1108,
- "y": 934
- },
- {
- "x": 1107,
- "y": 934
- },
- {
- "x": 1106,
- "y": 934
- },
- {
- "x": 1105,
- "y": 934
- },
- {
- "x": 1105,
- "y": 933
- },
- {
- "x": 1104,
- "y": 933
- },
- {
- "x": 1103,
- "y": 933
- },
- {
- "x": 1103,
- "y": 932
- },
- {
- "x": 1102,
- "y": 932
- },
- {
- "x": 1101,
- "y": 932
- },
- {
- "x": 1100,
- "y": 932
- },
- {
- "x": 1099,
- "y": 932
- },
- {
- "x": 1098,
- "y": 932
- },
- {
- "x": 1097,
- "y": 932
- },
- {
- "x": 1097,
- "y": 931
- },
- {
- "x": 1096,
- "y": 931
- },
- {
- "x": 1095,
- "y": 931
- },
- {
- "x": 1094,
- "y": 931
- },
- {
- "x": 1093,
- "y": 931
- },
- {
- "x": 1092,
- "y": 931
- },
- {
- "x": 1091,
- "y": 931
- },
- {
- "x": 1090,
- "y": 931
- },
- {
- "x": 1090,
- "y": 930
- },
- {
- "x": 1089,
- "y": 930
- },
- {
- "x": 1088,
- "y": 930
- },
- {
- "x": 1087,
- "y": 930
- },
- {
- "x": 1087,
- "y": 929
- },
- {
- "x": 1086,
- "y": 929
- },
- {
- "x": 1085,
- "y": 929
- },
- {
- "x": 1084,
- "y": 929
- },
- {
- "x": 1084,
- "y": 928
- },
- {
- "x": 1083,
- "y": 928
- },
- {
- "x": 1083,
- "y": 927
- },
- {
- "x": 1082,
- "y": 927
- },
- {
- "x": 1081,
- "y": 927
- },
- {
- "x": 1081,
- "y": 926
- },
- {
- "x": 1080,
- "y": 926
- },
- {
- "x": 1080,
- "y": 925
- },
- {
- "x": 1079,
- "y": 925
- },
- {
- "x": 1078,
- "y": 925
- },
- {
- "x": 1078,
- "y": 924
- },
- {
- "x": 1077,
- "y": 924
- },
- {
- "x": 1076,
- "y": 924
- },
- {
- "x": 1076,
- "y": 923
- },
- {
- "x": 1075,
- "y": 923
- },
- {
- "x": 1074,
- "y": 923
- },
- {
- "x": 1073,
- "y": 923
- },
- {
- "x": 1073,
- "y": 922
- },
- {
- "x": 1072,
- "y": 922
- },
- {
- "x": 1071,
- "y": 922
- },
- {
- "x": 1070,
- "y": 922
- },
- {
- "x": 1070,
- "y": 921
- },
- {
- "x": 1069,
- "y": 921
- },
- {
- "x": 1068,
- "y": 921
- },
- {
- "x": 1067,
- "y": 921
- },
- {
- "x": 1066,
- "y": 921
- },
- {
- "x": 1065,
- "y": 921
- },
- {
- "x": 1064,
- "y": 921
- },
- {
- "x": 1063,
- "y": 921
- },
- {
- "x": 1062,
- "y": 921
- },
- {
- "x": 1061,
- "y": 921
- },
- {
- "x": 1060,
- "y": 921
- },
- {
- "x": 1059,
- "y": 921
- },
- {
- "x": 1058,
- "y": 921
- },
- {
- "x": 1058,
- "y": 920
- },
- {
- "x": 1057,
- "y": 920
- },
- {
- "x": 1057,
- "y": 919
- },
- {
- "x": 1056,
- "y": 919
- },
- {
- "x": 1057,
- "y": 918
- },
- {
- "x": 1057,
- "y": 918
- },
- {
- "x": 1057,
- "y": 917
- },
- {
- "x": 1058,
- "y": 916
- },
- {
- "x": 1058,
- "y": 916
- },
- {
- "x": 1059,
- "y": 915
- },
- {
- "x": 1059,
- "y": 915
- },
- {
- "x": 1060,
- "y": 914
- },
- {
- "x": 1060,
- "y": 914
- },
- {
- "x": 1061,
- "y": 913
- },
- {
- "x": 1061,
- "y": 913
- },
- {
- "x": 1062,
- "y": 912
- },
- {
- "x": 1063,
- "y": 912
- },
- {
- "x": 1063,
- "y": 912
- },
- {
- "x": 1064,
- "y": 911
- },
- {
- "x": 1064,
- "y": 911
- },
- {
- "x": 1065,
- "y": 910
- },
- {
- "x": 1066,
- "y": 910
- },
- {
- "x": 1066,
- "y": 910
- },
- {
- "x": 1067,
- "y": 909
- },
- {
- "x": 1068,
- "y": 909
- },
- {
- "x": 1068,
- "y": 909
- },
- {
- "x": 1069,
- "y": 908
- },
- {
- "x": 1070,
- "y": 908
- },
- {
- "x": 1071,
- "y": 908
- },
- {
- "x": 1072,
- "y": 908
- },
- {
- "x": 1072,
- "y": 908
- },
- {
- "x": 1073,
- "y": 907
- },
- {
- "x": 1074,
- "y": 907
- },
- {
- "x": 1075,
- "y": 907
- },
- {
- "x": 1076,
- "y": 907
- },
- {
- "x": 1077,
- "y": 907
- },
- {
- "x": 1078,
- "y": 907
- },
- {
- "x": 1079,
- "y": 907
- },
- {
- "x": 1080,
- "y": 907
- },
- {
- "x": 1081,
- "y": 907
- },
- {
- "x": 1082,
- "y": 907
- },
- {
- "x": 1083,
- "y": 907
- },
- {
- "x": 1084,
- "y": 907
- },
- {
- "x": 1085,
- "y": 907
- },
- {
- "x": 1086,
- "y": 907
- },
- {
- "x": 1087,
- "y": 907
- },
- {
- "x": 1088,
- "y": 907
- },
- {
- "x": 1089,
- "y": 907
- },
- {
- "x": 1090,
- "y": 907
- },
- {
- "x": 1091,
- "y": 907
- },
- {
- "x": 1091,
- "y": 908
- },
- {
- "x": 1092,
- "y": 908
- },
- {
- "x": 1093,
- "y": 908
- },
- {
- "x": 1094,
- "y": 908
- },
- {
- "x": 1095,
- "y": 908
- },
- {
- "x": 1095,
- "y": 909
- },
- {
- "x": 1096,
- "y": 909
- },
- {
- "x": 1097,
- "y": 909
- },
- {
- "x": 1097,
- "y": 910
- },
- {
- "x": 1098,
- "y": 910
- },
- {
- "x": 1099,
- "y": 910
+ "x": 10.0,
+ "y": 20.0
},
{
- "x": 1099,
- "y": 911
+ "x": 15.0,
+ "y": 20.0
},
{
- "x": 1100,
- "y": 911
+ "x": 20.0,
+ "y": 25.0
},
{
- "x": 1101,
- "y": 911
- },
- {
- "x": 1101,
- "y": 912
- },
- {
- "x": 1102,
- "y": 912
- },
- {
- "x": 1103,
- "y": 912
- },
- {
- "x": 1103,
- "y": 913
- },
- {
- "x": 1104,
- "y": 913
- },
- {
- "x": 1104,
- "y": 914
- },
- {
- "x": 1105,
- "y": 914
- },
- {
- "x": 1105,
- "y": 915
- },
- {
- "x": 1106,
- "y": 915
- },
- {
- "x": 1107,
- "y": 915
- },
- {
- "x": 1107,
- "y": 916
- },
- {
- "x": 1108,
- "y": 916
- },
- {
- "x": 1108,
- "y": 917
- },
- {
- "x": 1109,
- "y": 917
- },
- {
- "x": 1109,
- "y": 918
- },
- {
- "x": 1110,
- "y": 918
- },
- {
- "x": 1110,
- "y": 919
- },
- {
- "x": 1111,
- "y": 919
- },
- {
- "x": 1111,
- "y": 920
- },
- {
- "x": 1112,
- "y": 920
- },
- {
- "x": 1112,
- "y": 921
- },
- {
- "x": 1113,
- "y": 921
- },
- {
- "x": 1113,
- "y": 922
- },
- {
- "x": 1114,
- "y": 922
- },
- {
- "x": 1114,
- "y": 923
- },
- {
- "x": 1115,
- "y": 923
- },
- {
- "x": 1115,
- "y": 924
- },
- {
- "x": 1115,
- "y": 925
- },
- {
- "x": 1116,
- "y": 925
- },
- {
- "x": 1116,
- "y": 926
- },
- {
- "x": 1117,
- "y": 926
- },
- {
- "x": 1117,
- "y": 927
- },
- {
- "x": 1117,
- "y": 928
- },
- {
- "x": 1118,
- "y": 928
- },
- {
- "x": 1118,
- "y": 929
- },
- {
- "x": 1119,
- "y": 929
- },
- {
- "x": 1119,
- "y": 930
- },
- {
- "x": 1120,
- "y": 930
- },
- {
- "x": 1120,
- "y": 931
- },
- {
- "x": 1120,
- "y": 932
- },
- {
- "x": 1120,
- "y": 932
- },
- {
- "x": 1119,
- "y": 933
- },
- {
- "x": 1119,
- "y": 934
- },
- {
- "x": 1119,
- "y": 934
- },
- {
- "x": 1118,
- "y": 935
- },
- {
- "x": 1118,
- "y": 935
+ "x": 10.0,
+ "y": 20.0
}
]
},
{
+ "classifications": [],
"uuid": "b98f3a45-3328-41a0-9077-373a8177ebf2",
- "name": "point a",
+ "name": "ckrazcuec16om0z66bhhh4tp7",
"dataRow": {
"id": "ckrazctum0z8a0ybc0b0o0g0v"
},
- "confidence": 0.77,
- "customMetrics": [
- {
- "name": "customMetric2",
- "value": 1.2
- }
- ],
"point": {
- "x": 2122,
- "y": 1457
+ "x": 2122.0,
+ "y": 1457.0
}
}
]
\ No newline at end of file
diff --git a/libs/labelbox/tests/data/assets/ndjson/metric_import_global_key.json b/libs/labelbox/tests/data/assets/ndjson/metric_import_global_key.json
deleted file mode 100644
index 31be5a4c7..000000000
--- a/libs/labelbox/tests/data/assets/ndjson/metric_import_global_key.json
+++ /dev/null
@@ -1,10 +0,0 @@
-[
- {
- "uuid": "a22bbf6e-b2da-4abe-9a11-df84759f7672",
- "aggregation": "ARITHMETIC_MEAN",
- "dataRow": {
- "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d"
- },
- "metricValue": 0.1
- }
-]
\ No newline at end of file
diff --git a/libs/labelbox/tests/data/assets/ndjson/pdf_import_global_key.json b/libs/labelbox/tests/data/assets/ndjson/pdf_import_global_key.json
deleted file mode 100644
index f4b4894f6..000000000
--- a/libs/labelbox/tests/data/assets/ndjson/pdf_import_global_key.json
+++ /dev/null
@@ -1,155 +0,0 @@
-[{
- "uuid": "5ad9c52f-058d-49c8-a749-3f20b84f8cd4",
- "dataRow": {
- "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d"
- },
- "name": "boxy",
- "schemaId": "cl6xnuwt95lqq07330tbb3mfd",
- "classifications": [],
- "page": 4,
- "unit": "POINTS",
- "confidence": 0.53,
- "customMetrics": [
- {
- "name": "customMetric1",
- "value": 0.5
- },
- {
- "name": "customMetric2",
- "value": 0.3
- }
- ],
- "bbox": {
- "top": 162.73,
- "left": 32.45,
- "height": 388.16999999999996,
- "width": 101.66000000000001
- }
-}, {
- "uuid": "20eeef88-0294-49b4-a815-86588476bc6f",
- "dataRow": {
- "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d"
- },
- "name": "boxy",
- "schemaId": "cl6xnuwt95lqq07330tbb3mfd",
- "classifications": [],
- "page": 7,
- "unit": "POINTS",
- "bbox": {
- "top": 223.26,
- "left": 251.42,
- "height": 457.03999999999996,
- "width": 186.78
- }
-}, {
- "uuid": "641a8944-3938-409c-b4eb-dea354ed06e5",
- "dataRow": {
- "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d"
- },
- "name": "boxy",
- "schemaId": "cl6xnuwt95lqq07330tbb3mfd",
- "classifications": [],
- "page": 6,
- "unit": "POINTS",
- "confidence": 0.99,
- "customMetrics": [
- {
- "name": "customMetric1",
- "value": 0.5
- },
- {
- "name": "customMetric2",
- "value": 0.3
- }
- ],
- "bbox": {
- "top": 32.52,
- "left": 218.17,
- "height": 231.73,
- "width": 110.56000000000003
- }
-}, {
- "uuid": "ebe4da7d-08b3-480a-8d15-26552b7f011c",
- "dataRow": {
- "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d"
- },
- "name": "boxy",
- "schemaId": "cl6xnuwt95lqq07330tbb3mfd",
- "classifications": [],
- "page": 7,
- "unit": "POINTS",
- "confidence": 0.89,
- "customMetrics": [
- {
- "name": "customMetric1",
- "value": 0.5
- },
- {
- "name": "customMetric2",
- "value": 0.3
- }
- ],
- "bbox": {
- "top": 117.39,
- "left": 4.25,
- "height": 456.9200000000001,
- "width": 164.83
- }
-}, {
- "uuid": "35c41855-575f-42cc-a2f9-1f06237e9b63",
- "dataRow": {
- "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d"
- },
- "name": "boxy",
- "schemaId": "cl6xnuwt95lqq07330tbb3mfd",
- "classifications": [],
- "page": 8,
- "unit": "POINTS",
- "bbox": {
- "top": 82.13,
- "left": 217.28,
- "height": 279.76,
- "width": 82.43000000000004
- }
-}, {
- "uuid": "1b009654-bc17-42a2-8a71-160e7808c403",
- "dataRow": {
- "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d"
- },
- "name": "boxy",
- "schemaId": "cl6xnuwt95lqq07330tbb3mfd",
- "classifications": [],
- "page": 3,
- "unit": "POINTS",
- "bbox": {
- "top": 298.12,
- "left": 83.34,
- "height": 203.83000000000004,
- "width": 0.37999999999999545
- }
-},
-{
- "uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673",
- "dataRow": {
- "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d"
- },
- "name": "named_entity",
- "classifications": [],
- "textSelections": [
- {
- "groupId": "2f4336f4-a07e-4e0a-a9e1-5629b03b719b",
- "tokenIds": [
- "3f984bf3-1d61-44f5-b59a-9658a2e3440f",
- "3bf00b56-ff12-4e52-8cc1-08dbddb3c3b8",
- "6e1c3420-d4b7-4c5a-8fd6-ead43bf73d80",
- "87a43d32-af76-4a1d-b262-5c5f4d5ace3a",
- "e8606e8a-dfd9-4c49-a635-ad5c879c75d0",
- "67c7c19e-4654-425d-bf17-2adb8cf02c30",
- "149c5e80-3e07-49a7-ab2d-29ddfe6a38fa",
- "b0e94071-2187-461e-8e76-96c58738a52c"
- ],
- "page": 1
- }
- ]
-}
-]
\ No newline at end of file
diff --git a/libs/labelbox/tests/data/assets/ndjson/polyline_import_global_key.json b/libs/labelbox/tests/data/assets/ndjson/polyline_import_global_key.json
deleted file mode 100644
index d6a9eecbd..000000000
--- a/libs/labelbox/tests/data/assets/ndjson/polyline_import_global_key.json
+++ /dev/null
@@ -1,36 +0,0 @@
-[
- {
- "line": [
- {
- "x": 2534.353,
- "y": 249.471
- },
- {
- "x": 2429.492,
- "y": 182.092
- },
- {
- "x": 2294.322,
- "y": 221.962
- }
- ],
- "uuid": "5ad9c52f-058d-49c8-a749-3f20b84f8cd4",
- "dataRow": {
- "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d"
- },
- "name": "some-line",
- "schemaId": "cl6xnuwt95lqq07330tbb3mfd",
- "classifications": [],
- "confidence": 0.58,
- "customMetrics": [
- {
- "name": "customMetric1",
- "value": 0.5
- },
- {
- "name": "customMetric2",
- "value": 0.3
- }
- ]
- }
-]
\ No newline at end of file
diff --git a/libs/labelbox/tests/data/assets/ndjson/text_entity_import_global_key.json b/libs/labelbox/tests/data/assets/ndjson/text_entity_import_global_key.json
deleted file mode 100644
index 1f26d8dc8..000000000
--- a/libs/labelbox/tests/data/assets/ndjson/text_entity_import_global_key.json
+++ /dev/null
@@ -1,26 +0,0 @@
-[
- {
- "location": {
- "start": 67,
- "end": 128
- },
- "uuid": "5ad9c52f-058d-49c8-a749-3f20b84f8cd4",
- "dataRow": {
- "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d"
- },
- "name": "some-text-entity",
- "schemaId": "cl6xnuwt95lqq07330tbb3mfd",
- "classifications": [],
- "confidence": 0.53,
- "customMetrics": [
- {
- "name": "customMetric1",
- "value": 0.5
- },
- {
- "name": "customMetric2",
- "value": 0.3
- }
- ]
- }
-]
\ No newline at end of file
diff --git a/libs/labelbox/tests/data/assets/ndjson/video_import_global_key.json b/libs/labelbox/tests/data/assets/ndjson/video_import_global_key.json
deleted file mode 100644
index 11e0753d9..000000000
--- a/libs/labelbox/tests/data/assets/ndjson/video_import_global_key.json
+++ /dev/null
@@ -1,166 +0,0 @@
-[{
- "answer": {
- "schemaId": "ckrb1sfl8099g0y91cxbd5ftb"
- },
- "schemaId": "ckrb1sfjx099a0y914hl319ie",
- "dataRow": {
- "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d"
- },
- "uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673",
- "frames": [{
- "start": 30,
- "end": 35
- }, {
- "start": 50,
- "end": 51
- }]
-}, {
- "answer": [{
- "schemaId": "ckrb1sfl8099e0y919v260awv"
- }],
- "schemaId": "ckrb1sfkn099c0y910wbo0p1a",
- "dataRow": {
- "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d"
- },
- "uuid": "d009925d-91a3-4f67-abd9-753453f5a584",
- "frames": [{
- "start": 0,
- "end": 5
- }]
-}, {
- "answer": "a value",
- "schemaId": "ckrb1sfkn099c0y910wbo0p1a",
- "dataRow": {
- "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d"
- },
- "uuid": "3b302706-37ec-4f72-ab2e-757d8bd302b9"
-}, {
- "classifications": [],
- "schemaId":
- "cl5islwg200gfci6g0oitaypu",
- "dataRow": {
- "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d"
- },
- "uuid":
- "6f7c835a-0139-4896-b73f-66a6baa89e94",
- "segments": [{
- "keyframes": [{
- "frame": 1,
- "line": [{
- "x": 10.0,
- "y": 10.0
- }, {
- "x": 100.0,
- "y": 100.0
- }, {
- "x": 50.0,
- "y": 30.0
- }],
- "classifications": []
- }, {
- "frame": 5,
- "line": [{
- "x": 15.0,
- "y": 10.0
- }, {
- "x": 50.0,
- "y": 100.0
- }, {
- "x": 50.0,
- "y": 30.0
- }],
- "classifications": []
- }]
- }, {
- "keyframes": [{
- "frame": 8,
- "line": [{
- "x": 100.0,
- "y": 10.0
- }, {
- "x": 50.0,
- "y": 100.0
- }, {
- "x": 50.0,
- "y": 30.0
- }],
- "classifications": []
- }]
- }]
-}, {
- "classifications": [],
- "schemaId":
- "cl5it7ktp00i5ci6gf80b1ysd",
- "dataRow": {
- "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d"
- },
- "uuid":
- "f963be22-227b-4efe-9be4-2738ed822216",
- "segments": [{
- "keyframes": [{
- "frame": 1,
- "point": {
- "x": 10.0,
- "y": 10.0
- },
- "classifications": []
- }]
- }, {
- "keyframes": [{
- "frame": 5,
- "point": {
- "x": 50.0,
- "y": 50.0
- },
- "classifications": []
- }, {
- "frame": 10,
- "point": {
- "x": 10.0,
- "y": 50.0
- },
- "classifications": []
- }]
- }]
-}, {
- "classifications": [],
- "schemaId":
- "cl5iw0roz00lwci6g5jni62vs",
- "dataRow": {
- "globalKey": "05e8ee85-072e-4eb2-b30a-501dee9b0d9d"
- },
- "uuid":
- "13b2ee0e-2355-4336-8b83-d74d09e3b1e7",
- "segments": [{
- "keyframes": [{
- "frame": 1,
- "bbox": {
- "top": 10.0,
- "left": 5.0,
- "height": 100.0,
- "width": 150.0
- },
- "classifications": []
- }, {
- "frame": 5,
- "bbox": {
- "top": 30.0,
- "left": 5.0,
- "height": 50.0,
- "width": 150.0
- },
- "classifications": []
- }]
- }, {
- "keyframes": [{
- "frame": 10,
- "bbox": {
- "top": 300.0,
- "left": 200.0,
- "height": 400.0,
- "width": 150.0
- },
- "classifications": []
- }]
- }]
-}]
diff --git a/libs/labelbox/tests/data/export/conftest.py b/libs/labelbox/tests/data/export/conftest.py
index 0a62f39c8..4a59b6966 100644
--- a/libs/labelbox/tests/data/export/conftest.py
+++ b/libs/labelbox/tests/data/export/conftest.py
@@ -1,13 +1,16 @@
-import uuid
import time
+import uuid
+
import pytest
-from labelbox.schema.queue_mode import QueueMode
+
+from labelbox import Client, MediaType
+from labelbox.schema.annotation_import import AnnotationImportState, LabelImport
from labelbox.schema.labeling_frontend import LabelingFrontend
-from labelbox.schema.annotation_import import LabelImport, AnnotationImportState
+from labelbox.schema.media_type import MediaType
@pytest.fixture
-def ontology():
+def ontology(client: Client):
bbox_tool_with_nested_text = {
"required": False,
"name": "bbox_tool_with_nested_text",
@@ -116,18 +119,174 @@ def ontology():
"color": "#008941",
"classifications": [],
}
- entity_tool = {
+ raster_segmentation_tool = {
"required": False,
- "name": "entity--",
- "tool": "named-entity",
- "color": "#006FA6",
+ "name": "segmentation_mask",
+ "tool": "raster-segmentation",
+ "color": "#ff0000",
"classifications": [],
}
- segmentation_tool = {
+ checklist = {
+ "required": False,
+ "instructions": "checklist",
+ "name": "checklist",
+ "type": "checklist",
+ "options": [
+ {"label": "option1", "value": "option1"},
+ {"label": "option2", "value": "option2"},
+ {"label": "optionN", "value": "optionn"},
+ ],
+ }
+
+ free_form_text = {
+ "required": False,
+ "instructions": "text",
+ "name": "text",
+ "type": "text",
+ "options": [],
+ }
+
+ radio = {
"required": False,
- "name": "segmentation--",
- "tool": "superpixel",
- "color": "#A30059",
+ "instructions": "radio",
+ "name": "radio",
+ "type": "radio",
+ "options": [
+ {
+ "label": "first_radio_answer",
+ "value": "first_radio_answer",
+ "options": [],
+ },
+ {
+ "label": "second_radio_answer",
+ "value": "second_radio_answer",
+ "options": [],
+ },
+ ],
+ }
+
+ tools = [
+ bbox_tool,
+ bbox_tool_with_nested_text,
+ polygon_tool,
+ polyline_tool,
+ point_tool,
+ raster_segmentation_tool,
+ ]
+ classifications = [
+ checklist,
+ free_form_text,
+ radio,
+ ]
+ ontology = client.create_ontology(
+ "image ontology",
+ {"tools": tools, "classifications": classifications},
+ MediaType.Image,
+ )
+ return ontology
+
+
+@pytest.fixture
+def video_ontology(client: Client):
+ bbox_tool_with_nested_text = {
+ "required": False,
+ "name": "bbox_tool_with_nested_text",
+ "tool": "rectangle",
+ "color": "#a23030",
+ "classifications": [
+ {
+ "required": False,
+ "instructions": "nested",
+ "name": "nested",
+ "type": "radio",
+ "options": [
+ {
+ "label": "radio_option_1",
+ "value": "radio_value_1",
+ "options": [
+ {
+ "required": False,
+ "instructions": "nested_checkbox",
+ "name": "nested_checkbox",
+ "type": "checklist",
+ "options": [
+ {
+ "label": "nested_checkbox_option_1",
+ "value": "nested_checkbox_value_1",
+ "options": [],
+ },
+ {
+ "label": "nested_checkbox_option_2",
+ "value": "nested_checkbox_value_2",
+ },
+ ],
+ },
+ {
+ "required": False,
+ "instructions": "nested_text",
+ "name": "nested_text",
+ "type": "text",
+ "options": [],
+ },
+ ],
+ },
+ ],
+ }
+ ],
+ }
+
+ bbox_tool = {
+ "required": False,
+ "name": "bbox",
+ "tool": "rectangle",
+ "color": "#a23030",
+ "classifications": [
+ {
+ "required": False,
+ "instructions": "nested",
+ "name": "nested",
+ "type": "radio",
+ "options": [
+ {
+ "label": "radio_option_1",
+ "value": "radio_value_1",
+ "options": [
+ {
+ "required": False,
+ "instructions": "nested_checkbox",
+ "name": "nested_checkbox",
+ "type": "checklist",
+ "options": [
+ {
+ "label": "nested_checkbox_option_1",
+ "value": "nested_checkbox_value_1",
+ "options": [],
+ },
+ {
+ "label": "nested_checkbox_option_2",
+ "value": "nested_checkbox_value_2",
+ },
+ ],
+ }
+ ],
+ },
+ ],
+ }
+ ],
+ }
+
+ polyline_tool = {
+ "required": False,
+ "name": "polyline",
+ "tool": "line",
+ "color": "#FF4A46",
+ "classifications": [],
+ }
+ point_tool = {
+ "required": False,
+ "name": "point--",
+ "tool": "point",
+ "color": "#008941",
"classifications": [],
}
raster_segmentation_tool = {
@@ -160,6 +319,7 @@ def ontology():
{"label": "optionN_index", "value": "optionn_index"},
],
}
+
free_form_text = {
"required": False,
"instructions": "text",
@@ -167,14 +327,7 @@ def ontology():
"type": "text",
"options": [],
}
- free_form_text_index = {
- "required": False,
- "instructions": "text_index",
- "name": "text_index",
- "type": "text",
- "scope": "index",
- "options": [],
- }
+
radio = {
"required": False,
"instructions": "radio",
@@ -193,33 +346,26 @@ def ontology():
},
],
}
- named_entity = {
- "tool": "named-entity",
- "name": "named-entity",
- "required": False,
- "color": "#A30059",
- "classifications": [],
- }
tools = [
bbox_tool,
bbox_tool_with_nested_text,
- polygon_tool,
polyline_tool,
point_tool,
- entity_tool,
- segmentation_tool,
raster_segmentation_tool,
- named_entity,
]
classifications = [
- checklist,
checklist_index,
+ checklist,
free_form_text,
- free_form_text_index,
radio,
]
- return {"tools": tools, "classifications": classifications}
+ ontology = client.create_ontology(
+ "image ontology",
+ {"tools": tools, "classifications": classifications},
+ MediaType.Video,
+ )
+ return ontology
@pytest.fixture
@@ -246,15 +392,17 @@ def configured_project_with_ontology(
dataset = initial_dataset
project = client.create_project(
name=rand_gen(str),
- queue_mode=QueueMode.Batch,
+ media_type=MediaType.Image,
)
- editor = list(
- client.get_labeling_frontends(where=LabelingFrontend.name == "editor")
- )[0]
- project.setup(editor, ontology)
+ project.connect_ontology(ontology)
data_row_ids = []
- for _ in range(len(ontology["tools"]) + len(ontology["classifications"])):
+ normalized_ontology = ontology.normalized
+
+ for _ in range(
+ len(normalized_ontology["tools"])
+ + len(normalized_ontology["classifications"])
+ ):
data_row_ids.append(dataset.create_data_row(row_data=image_url).uid)
project.create_batch(
rand_gen(str),
@@ -273,12 +421,25 @@ def configured_project_without_data_rows(
project = client.create_project(
name=rand_gen(str),
description=rand_gen(str),
- queue_mode=QueueMode.Batch,
+ media_type=MediaType.Image,
)
- editor = list(
- client.get_labeling_frontends(where=LabelingFrontend.name == "editor")
- )[0]
- project.setup(editor, ontology)
+
+ project.connect_ontology(ontology)
+ yield project
+ teardown_helpers.teardown_project_labels_ontology_feature_schemas(project)
+
+
+@pytest.fixture
+def configured_video_project_without_data_rows(
+ client, video_ontology, rand_gen, teardown_helpers
+):
+ project = client.create_project(
+ name=rand_gen(str),
+ description=rand_gen(str),
+ media_type=MediaType.Video,
+ )
+
+ project.connect_ontology(video_ontology)
yield project
teardown_helpers.teardown_project_labels_ontology_feature_schemas(project)
diff --git a/libs/labelbox/tests/data/export/streamable/test_export_data_rows_streamable.py b/libs/labelbox/tests/data/export/streamable/test_export_data_rows_streamable.py
index 3e4efbc46..8dfa1df72 100644
--- a/libs/labelbox/tests/data/export/streamable/test_export_data_rows_streamable.py
+++ b/libs/labelbox/tests/data/export/streamable/test_export_data_rows_streamable.py
@@ -1,7 +1,5 @@
-import json
import time
-import pytest
from labelbox import DataRow, ExportTask, StreamType
@@ -27,9 +25,7 @@ def test_with_data_row_object(
)
assert export_task.get_total_lines(stream_type=StreamType.RESULT) == 1
assert (
- json.loads(list(export_task.get_stream())[0].json_str)["data_row"][
- "id"
- ]
+ list(export_task.get_buffered_stream())[0].json["data_row"]["id"]
== data_row.uid
)
@@ -75,9 +71,7 @@ def test_with_id(self, client, data_row, wait_for_data_row_processing):
)
assert export_task.get_total_lines(stream_type=StreamType.RESULT) == 1
assert (
- json.loads(list(export_task.get_stream())[0].json_str)["data_row"][
- "id"
- ]
+ list(export_task.get_buffered_stream())[0].json["data_row"]["id"]
== data_row.uid
)
@@ -101,9 +95,7 @@ def test_with_global_key(
)
assert export_task.get_total_lines(stream_type=StreamType.RESULT) == 1
assert (
- json.loads(list(export_task.get_stream())[0].json_str)["data_row"][
- "id"
- ]
+ list(export_task.get_buffered_stream())[0].json["data_row"]["id"]
== data_row.uid
)
diff --git a/libs/labelbox/tests/data/export/streamable/test_export_dataset_streamable.py b/libs/labelbox/tests/data/export/streamable/test_export_dataset_streamable.py
index 57f617a00..0d34a40b1 100644
--- a/libs/labelbox/tests/data/export/streamable/test_export_dataset_streamable.py
+++ b/libs/labelbox/tests/data/export/streamable/test_export_dataset_streamable.py
@@ -25,8 +25,8 @@ def test_export(self, dataset, data_rows):
) == len(expected_data_row_ids)
data_row_ids = list(
map(
- lambda x: json.loads(x.json_str)["data_row"]["id"],
- export_task.get_stream(),
+ lambda x: x.json["data_row"]["id"],
+ export_task.get_buffered_stream(),
)
)
assert data_row_ids.sort() == expected_data_row_ids.sort()
@@ -58,8 +58,8 @@ def test_with_data_row_filter(self, dataset, data_rows):
)
data_row_ids = list(
map(
- lambda x: json.loads(x.json_str)["data_row"]["id"],
- export_task.get_stream(),
+ lambda x: x.json["data_row"]["id"],
+ export_task.get_buffered_stream(),
)
)
assert data_row_ids.sort() == expected_data_row_ids.sort()
@@ -91,8 +91,8 @@ def test_with_global_key_filter(self, dataset, data_rows):
)
global_keys = list(
map(
- lambda x: json.loads(x.json_str)["data_row"]["global_key"],
- export_task.get_stream(),
+ lambda x: x.json["data_row"]["global_key"],
+ export_task.get_buffered_stream(),
)
)
assert global_keys.sort() == expected_global_keys.sort()
diff --git a/libs/labelbox/tests/data/export/streamable/test_export_embeddings_streamable.py b/libs/labelbox/tests/data/export/streamable/test_export_embeddings_streamable.py
index 071acbb5b..803b5994a 100644
--- a/libs/labelbox/tests/data/export/streamable/test_export_embeddings_streamable.py
+++ b/libs/labelbox/tests/data/export/streamable/test_export_embeddings_streamable.py
@@ -1,7 +1,7 @@
import json
import random
-from labelbox import StreamType, JsonConverter
+from labelbox import StreamType
class TestExportEmbeddings:
@@ -23,12 +23,8 @@ def test_export_embeddings_precomputed(
assert export_task.has_errors() is False
results = []
- export_task.get_stream(
- converter=JsonConverter(), stream_type=StreamType.RESULT
- ).start(
- stream_handler=lambda output: results.append(
- json.loads(output.json_str)
- )
+ export_task.get_buffered_stream(stream_type=StreamType.RESULT).start(
+ stream_handler=lambda output: results.append(output.json)
)
assert len(results) == len(data_row_specs)
@@ -69,12 +65,8 @@ def test_export_embeddings_custom(
assert export_task.has_errors() is False
results = []
- export_task.get_stream(
- converter=JsonConverter(), stream_type=StreamType.RESULT
- ).start(
- stream_handler=lambda output: results.append(
- json.loads(output.json_str)
- )
+ export_task.get_buffered_stream(stream_type=StreamType.RESULT).start(
+ stream_handler=lambda output: results.append(output.json)
)
assert len(results) == 1
@@ -86,6 +78,6 @@ def test_export_embeddings_custom(
if emb["id"] == embedding.id:
assert emb["name"] == embedding.name
assert emb["dimensions"] == embedding.dims
- assert emb["is_custom"] == True
+ assert emb["is_custom"] is True
assert len(emb["values"]) == 1
assert emb["values"][0]["value"] == vector
diff --git a/libs/labelbox/tests/data/export/streamable/test_export_model_run_streamable.py b/libs/labelbox/tests/data/export/streamable/test_export_model_run_streamable.py
index ada493fc3..7a583198b 100644
--- a/libs/labelbox/tests/data/export/streamable/test_export_model_run_streamable.py
+++ b/libs/labelbox/tests/data/export/streamable/test_export_model_run_streamable.py
@@ -27,8 +27,8 @@ def test_export(self, model_run_with_data_rows):
stream_type=StreamType.RESULT
) == len(expected_data_rows)
- for data in export_task.get_stream():
- obj = json.loads(data.json_str)
+ for data in export_task.get_buffered_stream():
+ obj = data.json
assert (
"media_attributes" in obj
and obj["media_attributes"] is not None
diff --git a/libs/labelbox/tests/data/export/streamable/test_export_project_streamable.py b/libs/labelbox/tests/data/export/streamable/test_export_project_streamable.py
index 818a0178c..597c529aa 100644
--- a/libs/labelbox/tests/data/export/streamable/test_export_project_streamable.py
+++ b/libs/labelbox/tests/data/export/streamable/test_export_project_streamable.py
@@ -9,6 +9,7 @@
from labelbox import Project, Dataset
from labelbox.schema.data_row import DataRow
from labelbox.schema.label import Label
+from labelbox import UniqueIds
IMAGE_URL = "https://storage.googleapis.com/lb-artifacts-testing-public/sdk_integration_test/potato.jpeg"
@@ -56,8 +57,8 @@ def test_export(
)
assert export_task.get_total_lines(stream_type=StreamType.RESULT) > 0
- for data in export_task.get_stream():
- obj = json.loads(data.json_str)
+ for data in export_task.get_buffered_stream():
+ obj = data.json
task_media_attributes = obj["media_attributes"]
task_project = obj["projects"][project.uid]
task_project_label_ids_set = set(
@@ -128,7 +129,9 @@ def test_with_date_filters(
review_queue = next(
tq for tq in task_queues if tq.queue_type == "MANUAL_REVIEW_QUEUE"
)
- project.move_data_rows_to_task_queue([data_row.uid], review_queue.uid)
+ project.move_data_rows_to_task_queue(
+ UniqueIds([data_row.uid]), review_queue.uid
+ )
export_task = project_export(
project, task_name, filters=filters, params=params
)
@@ -139,8 +142,8 @@ def test_with_date_filters(
)
assert export_task.get_total_lines(stream_type=StreamType.RESULT) > 0
- for data in export_task.get_stream():
- obj = json.loads(data.json_str)
+ for data in export_task.get_buffered_stream():
+ obj = data.json
task_project = obj["projects"][project.uid]
task_project_label_ids_set = set(
map(lambda prediction: prediction["id"], task_project["labels"])
@@ -181,9 +184,9 @@ def test_with_iso_date_filters(
assert export_task.get_total_lines(stream_type=StreamType.RESULT) > 0
assert (
label_id
- == json.loads(list(export_task.get_stream())[0].json_str)[
- "projects"
- ][project.uid]["labels"][0]["id"]
+ == list(export_task.get_buffered_stream())[0].json["projects"][
+ project.uid
+ ]["labels"][0]["id"]
)
def test_with_iso_date_filters_no_start_date(
@@ -207,9 +210,9 @@ def test_with_iso_date_filters_no_start_date(
assert export_task.get_total_lines(stream_type=StreamType.RESULT) > 0
assert (
label_id
- == json.loads(list(export_task.get_stream())[0].json_str)[
- "projects"
- ][project.uid]["labels"][0]["id"]
+ == list(export_task.get_buffered_stream())[0].json["projects"][
+ project.uid
+ ]["labels"][0]["id"]
)
def test_with_iso_date_filters_and_future_start_date(
@@ -270,8 +273,8 @@ def test_with_data_row_filter(
)
data_row_ids = list(
map(
- lambda x: json.loads(x.json_str)["data_row"]["id"],
- export_task.get_stream(),
+ lambda x: x.json["data_row"]["id"],
+ export_task.get_buffered_stream(),
)
)
assert data_row_ids.sort() == expected_data_row_ids.sort()
@@ -310,8 +313,8 @@ def test_with_global_key_filter(
)
global_keys = list(
map(
- lambda x: json.loads(x.json_str)["data_row"]["global_key"],
- export_task.get_stream(),
+ lambda x: x.json["data_row"]["global_key"],
+ export_task.get_buffered_stream(),
)
)
assert global_keys.sort() == expected_global_keys.sort()
diff --git a/libs/labelbox/tests/data/export/streamable/test_export_video_streamable.py b/libs/labelbox/tests/data/export/streamable/test_export_video_streamable.py
index 115194a58..da99414ee 100644
--- a/libs/labelbox/tests/data/export/streamable/test_export_video_streamable.py
+++ b/libs/labelbox/tests/data/export/streamable/test_export_video_streamable.py
@@ -4,7 +4,7 @@
import labelbox as lb
import labelbox.types as lb_types
-from labelbox.data.annotation_types.data.video import VideoData
+from labelbox.data.annotation_types.data import GenericDataRowData
from labelbox.schema.annotation_import import AnnotationImportState
from labelbox.schema.export_task import ExportTask, StreamType
@@ -21,13 +21,13 @@ def org_id(self, client):
def test_export(
self,
client,
- configured_project_without_data_rows,
+ configured_video_project_without_data_rows,
video_data,
video_data_row,
bbox_video_annotation_objects,
rand_gen,
):
- project = configured_project_without_data_rows
+ project = configured_video_project_without_data_rows
project_id = project.uid
labels = []
@@ -41,7 +41,7 @@ def test_export(
for data_row_uid in data_row_uids:
labels = [
lb_types.Label(
- data=VideoData(uid=data_row_uid),
+ data=GenericDataRowData(uid=data_row_uid),
annotations=bbox_video_annotation_objects,
)
]
@@ -71,7 +71,8 @@ def test_export(
export_task.get_total_file_size(stream_type=StreamType.RESULT) > 0
)
- export_data = json.loads(list(export_task.get_stream())[0].json_str)
+ export_data = list(export_task.get_buffered_stream())[0].json
+
data_row_export = export_data["data_row"]
assert data_row_export["global_key"] == video_data_row["global_key"]
assert data_row_export["row_data"] == video_data_row["row_data"]
diff --git a/libs/labelbox/tests/data/serialization/coco/test_coco.py b/libs/labelbox/tests/data/serialization/coco/test_coco.py
deleted file mode 100644
index a7c733ce5..000000000
--- a/libs/labelbox/tests/data/serialization/coco/test_coco.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import json
-from pathlib import Path
-
-from labelbox.data.serialization.coco import COCOConverter
-
-COCO_ASSETS_DIR = "tests/data/assets/coco"
-
-
-def run_instances(tmpdir):
- instance_json = json.load(open(Path(COCO_ASSETS_DIR, "instances.json")))
- res = COCOConverter.deserialize_instances(
- instance_json, Path(COCO_ASSETS_DIR, "images")
- )
- back = COCOConverter.serialize_instances(
- res,
- Path(tmpdir),
- )
-
-
-def test_rle_objects(tmpdir):
- rle_json = json.load(open(Path(COCO_ASSETS_DIR, "rle.json")))
- res = COCOConverter.deserialize_instances(
- rle_json, Path(COCO_ASSETS_DIR, "images")
- )
- back = COCOConverter.serialize_instances(res, tmpdir)
-
-
-def test_panoptic(tmpdir):
- panoptic_json = json.load(open(Path(COCO_ASSETS_DIR, "panoptic.json")))
- image_dir, mask_dir = [
- Path(COCO_ASSETS_DIR, dir_name) for dir_name in ["images", "masks"]
- ]
- res = COCOConverter.deserialize_panoptic(panoptic_json, image_dir, mask_dir)
- back = COCOConverter.serialize_panoptic(
- res,
- Path(f"/{tmpdir}/images_panoptic"),
- Path(f"/{tmpdir}/masks_panoptic"),
- )
diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_checklist.py b/libs/labelbox/tests/data/serialization/ndjson/test_checklist.py
index 0bc3c8924..fb78916f4 100644
--- a/libs/labelbox/tests/data/serialization/ndjson/test_checklist.py
+++ b/libs/labelbox/tests/data/serialization/ndjson/test_checklist.py
@@ -4,7 +4,7 @@
ClassificationAnswer,
Radio,
)
-from labelbox.data.annotation_types.data.text import TextData
+from labelbox.data.annotation_types.data import GenericDataRowData
from labelbox.data.annotation_types.label import Label
from labelbox.data.serialization.ndjson.converter import NDJsonConverter
@@ -13,9 +13,8 @@
def test_serialization_min():
label = Label(
uid="ckj7z2q0b0000jx6x0q2q7q0d",
- data=TextData(
+ data=GenericDataRowData(
uid="bkj7z2q0b0000jx6x0q2q7q0d",
- text="This is a test",
),
annotations=[
ClassificationAnnotation(
@@ -37,20 +36,12 @@ def test_serialization_min():
res.pop("uuid")
assert res == expected
- deserialized = NDJsonConverter.deserialize([res])
- res = next(deserialized)
- for i, annotation in enumerate(res.annotations):
- annotation.extra.pop("uuid")
- assert annotation.value == label.annotations[i].value
- assert annotation.name == label.annotations[i].name
-
def test_serialization_with_classification():
label = Label(
uid="ckj7z2q0b0000jx6x0q2q7q0d",
- data=TextData(
+ data=GenericDataRowData(
uid="bkj7z2q0b0000jx6x0q2q7q0d",
- text="This is a test",
),
annotations=[
ClassificationAnnotation(
@@ -134,19 +125,12 @@ def test_serialization_with_classification():
res.pop("uuid")
assert res == expected
- deserialized = NDJsonConverter.deserialize([res])
- res = next(deserialized)
- assert label.model_dump(exclude_none=True) == label.model_dump(
- exclude_none=True
- )
-
def test_serialization_with_classification_double_nested():
label = Label(
uid="ckj7z2q0b0000jx6x0q2q7q0d",
- data=TextData(
+ data=GenericDataRowData(
uid="bkj7z2q0b0000jx6x0q2q7q0d",
- text="This is a test",
),
annotations=[
ClassificationAnnotation(
@@ -233,20 +217,12 @@ def test_serialization_with_classification_double_nested():
res.pop("uuid")
assert res == expected
- deserialized = NDJsonConverter.deserialize([res])
- res = next(deserialized)
- res.annotations[0].extra.pop("uuid")
- assert label.model_dump(exclude_none=True) == label.model_dump(
- exclude_none=True
- )
-
def test_serialization_with_classification_double_nested_2():
label = Label(
uid="ckj7z2q0b0000jx6x0q2q7q0d",
- data=TextData(
+ data=GenericDataRowData(
uid="bkj7z2q0b0000jx6x0q2q7q0d",
- text="This is a test",
),
annotations=[
ClassificationAnnotation(
@@ -330,9 +306,3 @@ def test_serialization_with_classification_double_nested_2():
res = next(serialized)
res.pop("uuid")
assert res == expected
-
- deserialized = NDJsonConverter.deserialize([res])
- res = next(deserialized)
- assert label.model_dump(exclude_none=True) == label.model_dump(
- exclude_none=True
- )
diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_classification.py b/libs/labelbox/tests/data/serialization/ndjson/test_classification.py
index 8dcb17f0b..82adce99c 100644
--- a/libs/labelbox/tests/data/serialization/ndjson/test_classification.py
+++ b/libs/labelbox/tests/data/serialization/ndjson/test_classification.py
@@ -1,15 +1,73 @@
import json
+from labelbox.data.annotation_types.classification.classification import (
+ Checklist,
+ Radio,
+ Text,
+)
+from labelbox.data.annotation_types.data.generic_data_row_data import (
+ GenericDataRowData,
+)
from labelbox.data.serialization.ndjson.converter import NDJsonConverter
+from labelbox.types import (
+ Label,
+ ClassificationAnnotation,
+ ClassificationAnswer,
+)
+from labelbox.data.mixins import CustomMetric
+
def test_classification():
with open(
"tests/data/assets/ndjson/classification_import.json", "r"
) as file:
data = json.load(file)
- res = list(NDJsonConverter.deserialize(data))
- res = list(NDJsonConverter.serialize(res))
+
+ label = Label(
+ data=GenericDataRowData(
+ uid="ckrb1sf1i1g7i0ybcdc6oc8ct",
+ ),
+ annotations=[
+ ClassificationAnnotation(
+ feature_schema_id="ckrb1sfjx099a0y914hl319ie",
+ extra={"uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673"},
+ value=Radio(
+ answer=ClassificationAnswer(
+ custom_metrics=[
+ CustomMetric(name="customMetric1", value=0.5),
+ CustomMetric(name="customMetric2", value=0.3),
+ ],
+ confidence=0.8,
+ feature_schema_id="ckrb1sfl8099g0y91cxbd5ftb",
+ ),
+ ),
+ ),
+ ClassificationAnnotation(
+ feature_schema_id="ckrb1sfkn099c0y910wbo0p1a",
+ extra={"uuid": "d009925d-91a3-4f67-abd9-753453f5a584"},
+ value=Checklist(
+ answer=[
+ ClassificationAnswer(
+ custom_metrics=[
+ CustomMetric(name="customMetric1", value=0.5),
+ CustomMetric(name="customMetric2", value=0.3),
+ ],
+ confidence=0.82,
+ feature_schema_id="ckrb1sfl8099e0y919v260awv",
+ )
+ ],
+ ),
+ ),
+ ClassificationAnnotation(
+ feature_schema_id="ckrb1sfkn099c0y910wbo0p1a",
+ extra={"uuid": "78ff6a23-bebe-475c-8f67-4c456909648f"},
+ value=Text(answer="a value"),
+ ),
+ ],
+ )
+
+ res = list(NDJsonConverter.serialize([label]))
assert res == data
@@ -18,6 +76,48 @@ def test_classification_with_name():
"tests/data/assets/ndjson/classification_import_name_only.json", "r"
) as file:
data = json.load(file)
- res = list(NDJsonConverter.deserialize(data))
- res = list(NDJsonConverter.serialize(res))
+ label = Label(
+ data=GenericDataRowData(
+ uid="ckrb1sf1i1g7i0ybcdc6oc8ct",
+ ),
+ annotations=[
+ ClassificationAnnotation(
+ name="classification a",
+ extra={"uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673"},
+ value=Radio(
+ answer=ClassificationAnswer(
+ custom_metrics=[
+ CustomMetric(name="customMetric1", value=0.5),
+ CustomMetric(name="customMetric2", value=0.3),
+ ],
+ confidence=0.99,
+ name="choice 1",
+ ),
+ ),
+ ),
+ ClassificationAnnotation(
+ name="classification b",
+ extra={"uuid": "d009925d-91a3-4f67-abd9-753453f5a584"},
+ value=Checklist(
+ answer=[
+ ClassificationAnswer(
+ custom_metrics=[
+ CustomMetric(name="customMetric1", value=0.5),
+ CustomMetric(name="customMetric2", value=0.3),
+ ],
+ confidence=0.945,
+ name="choice 2",
+ )
+ ],
+ ),
+ ),
+ ClassificationAnnotation(
+ name="classification c",
+ extra={"uuid": "150d60de-30af-44e4-be20-55201c533312"},
+ value=Text(answer="a value"),
+ ),
+ ],
+ )
+
+ res = list(NDJsonConverter.serialize([label]))
assert res == data
diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_conversation.py b/libs/labelbox/tests/data/serialization/ndjson/test_conversation.py
index f7da9181b..5aa7285e2 100644
--- a/libs/labelbox/tests/data/serialization/ndjson/test_conversation.py
+++ b/libs/labelbox/tests/data/serialization/ndjson/test_conversation.py
@@ -1,8 +1,12 @@
import json
+from labelbox.data.annotation_types.data.generic_data_row_data import (
+ GenericDataRowData,
+)
import pytest
import labelbox.types as lb_types
from labelbox.data.serialization.ndjson.converter import NDJsonConverter
+from labelbox.data.mixins import CustomMetric
radio_ndjson = [
{
@@ -15,7 +19,7 @@
radio_label = [
lb_types.Label(
- data=lb_types.ConversationData(global_key="my_global_key"),
+ data=lb_types.GenericDataRowData(global_key="my_global_key"),
annotations=[
lb_types.ClassificationAnnotation(
name="radio",
@@ -44,7 +48,7 @@
checklist_label = [
lb_types.Label(
- data=lb_types.ConversationData(global_key="my_global_key"),
+ data=lb_types.GenericDataRowData(global_key="my_global_key"),
annotations=[
lb_types.ClassificationAnnotation(
name="checklist",
@@ -74,7 +78,7 @@
]
free_text_label = [
lb_types.Label(
- data=lb_types.ConversationData(global_key="my_global_key"),
+ data=lb_types.GenericDataRowData(global_key="my_global_key"),
annotations=[
lb_types.ClassificationAnnotation(
name="free_text",
@@ -99,31 +103,68 @@ def test_message_based_radio_classification(label, ndjson):
serialized_label[0].pop("uuid")
assert serialized_label == ndjson
- deserialized_label = list(NDJsonConverter().deserialize(ndjson))
- deserialized_label[0].annotations[0].extra.pop("uuid")
- assert deserialized_label[0].model_dump(exclude_none=True) == label[
- 0
- ].model_dump(exclude_none=True)
+def test_conversation_entity_import():
+ with open(
+ "tests/data/assets/ndjson/conversation_entity_import.json", "r"
+ ) as file:
+ data = json.load(file)
-@pytest.mark.parametrize(
- "filename",
- [
- "tests/data/assets/ndjson/conversation_entity_import.json",
+ label = lb_types.Label(
+ data=GenericDataRowData(
+ uid="cl6xnv9h61fv0085yhtoq06ht",
+ ),
+ annotations=[
+ lb_types.ObjectAnnotation(
+ custom_metrics=[
+ CustomMetric(name="customMetric1", value=0.5),
+ CustomMetric(name="customMetric2", value=0.3),
+ ],
+ confidence=0.53,
+ name="some-text-entity",
+ feature_schema_id="cl6xnuwt95lqq07330tbb3mfd",
+ extra={"uuid": "5ad9c52f-058d-49c8-a749-3f20b84f8cd4"},
+ value=lb_types.ConversationEntity(
+ start=67, end=128, message_id="some-message-id"
+ ),
+ )
+ ],
+ )
+
+ res = list(NDJsonConverter.serialize([label]))
+ assert res == data
+
+
+def test_conversation_entity_import_without_confidence():
+ with open(
"tests/data/assets/ndjson/conversation_entity_without_confidence_import.json",
- ],
-)
-def test_conversation_entity_import(filename: str):
- with open(filename, "r") as file:
+ "r",
+ ) as file:
data = json.load(file)
- res = list(NDJsonConverter.deserialize(data))
- res = list(NDJsonConverter.serialize(res))
+ label = lb_types.Label(
+ uid=None,
+ data=GenericDataRowData(
+ uid="cl6xnv9h61fv0085yhtoq06ht",
+ ),
+ annotations=[
+ lb_types.ObjectAnnotation(
+ name="some-text-entity",
+ feature_schema_id="cl6xnuwt95lqq07330tbb3mfd",
+ extra={"uuid": "5ad9c52f-058d-49c8-a749-3f20b84f8cd4"},
+ value=lb_types.ConversationEntity(
+ start=67, end=128, extra={}, message_id="some-message-id"
+ ),
+ )
+ ],
+ )
+
+ res = list(NDJsonConverter.serialize([label]))
assert res == data
def test_benchmark_reference_label_flag_enabled():
label = lb_types.Label(
- data=lb_types.ConversationData(global_key="my_global_key"),
+ data=lb_types.GenericDataRowData(global_key="my_global_key"),
annotations=[
lb_types.ClassificationAnnotation(
name="free_text",
@@ -140,7 +181,7 @@ def test_benchmark_reference_label_flag_enabled():
def test_benchmark_reference_label_flag_disabled():
label = lb_types.Label(
- data=lb_types.ConversationData(global_key="my_global_key"),
+ data=lb_types.GenericDataRowData(global_key="my_global_key"),
annotations=[
lb_types.ClassificationAnnotation(
name="free_text",
diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_data_gen.py b/libs/labelbox/tests/data/serialization/ndjson/test_data_gen.py
index 333c00250..999e1bda5 100644
--- a/libs/labelbox/tests/data/serialization/ndjson/test_data_gen.py
+++ b/libs/labelbox/tests/data/serialization/ndjson/test_data_gen.py
@@ -1,67 +1,29 @@
-from copy import copy
-import pytest
import labelbox.types as lb_types
from labelbox.data.serialization import NDJsonConverter
-from labelbox.data.serialization.ndjson.objects import (
- NDDicomSegments,
- NDDicomSegment,
- NDDicomLine,
-)
-
-"""
-Data gen prompt test data
-"""
-
-prompt_text_annotation = lb_types.PromptClassificationAnnotation(
- feature_schema_id="ckrb1sfkn099c0y910wbo0p1a",
- name="test",
- value=lb_types.PromptText(
- answer="the answer to the text questions right here"
- ),
-)
-
-prompt_text_ndjson = {
- "answer": "the answer to the text questions right here",
- "name": "test",
- "schemaId": "ckrb1sfkn099c0y910wbo0p1a",
- "dataRow": {"id": "ckrb1sf1i1g7i0ybcdc6oc8ct"},
-}
-
-data_gen_label = lb_types.Label(
- data={"uid": "ckrb1sf1i1g7i0ybcdc6oc8ct"},
- annotations=[prompt_text_annotation],
-)
-
-"""
-Prompt annotation test
-"""
def test_serialize_label():
- serialized_label = next(NDJsonConverter().serialize([data_gen_label]))
- # Remove uuid field since this is a random value that can not be specified also meant for relationships
- del serialized_label["uuid"]
- assert serialized_label == prompt_text_ndjson
-
-
-def test_deserialize_label():
- deserialized_label = next(
- NDJsonConverter().deserialize([prompt_text_ndjson])
+ prompt_text_annotation = lb_types.PromptClassificationAnnotation(
+ feature_schema_id="ckrb1sfkn099c0y910wbo0p1a",
+ name="test",
+ extra={"uuid": "test"},
+ value=lb_types.PromptText(
+ answer="the answer to the text questions right here"
+ ),
)
- if hasattr(deserialized_label.annotations[0], "extra"):
- # Extra fields are added to deserialized label by default need removed to match
- deserialized_label.annotations[0].extra = {}
- assert deserialized_label.model_dump(
- exclude_none=True
- ) == data_gen_label.model_dump(exclude_none=True)
+ prompt_text_ndjson = {
+ "answer": "the answer to the text questions right here",
+ "name": "test",
+ "schemaId": "ckrb1sfkn099c0y910wbo0p1a",
+ "dataRow": {"id": "ckrb1sf1i1g7i0ybcdc6oc8ct"},
+ "uuid": "test",
+ }
+
+ data_gen_label = lb_types.Label(
+ data={"uid": "ckrb1sf1i1g7i0ybcdc6oc8ct"},
+ annotations=[prompt_text_annotation],
+ )
+ serialized_label = next(NDJsonConverter().serialize([data_gen_label]))
-def test_serialize_deserialize_label():
- serialized = list(NDJsonConverter.serialize([data_gen_label]))
- deserialized = next(NDJsonConverter.deserialize(serialized))
- if hasattr(deserialized.annotations[0], "extra"):
- # Extra fields are added to deserialized label by default need removed to match
- deserialized.annotations[0].extra = {}
- assert deserialized.model_dump(
- exclude_none=True
- ) == data_gen_label.model_dump(exclude_none=True)
+ assert serialized_label == prompt_text_ndjson
diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_dicom.py b/libs/labelbox/tests/data/serialization/ndjson/test_dicom.py
index 633214367..6a00fa871 100644
--- a/libs/labelbox/tests/data/serialization/ndjson/test_dicom.py
+++ b/libs/labelbox/tests/data/serialization/ndjson/test_dicom.py
@@ -1,6 +1,5 @@
from copy import copy
import pytest
-import base64
import labelbox.types as lb_types
from labelbox.data.serialization import NDJsonConverter
from labelbox.data.serialization.ndjson.objects import (
@@ -32,7 +31,7 @@
]
polyline_label = lb_types.Label(
- data=lb_types.DicomData(uid="test-uid"),
+ data=lb_types.GenericDataRowData(uid="test-uid"),
annotations=dicom_polyline_annotations,
)
@@ -59,7 +58,7 @@
}
polyline_with_global_key = lb_types.Label(
- data=lb_types.DicomData(global_key="test-global-key"),
+ data=lb_types.GenericDataRowData(global_key="test-global-key"),
annotations=dicom_polyline_annotations,
)
@@ -110,11 +109,12 @@
}
video_mask_label = lb_types.Label(
- data=lb_types.VideoData(uid="test-uid"), annotations=[video_mask_annotation]
+ data=lb_types.GenericDataRowData(uid="test-uid"),
+ annotations=[video_mask_annotation],
)
video_mask_label_with_global_key = lb_types.Label(
- data=lb_types.VideoData(global_key="test-global-key"),
+ data=lb_types.GenericDataRowData(global_key="test-global-key"),
annotations=[video_mask_annotation],
)
"""
@@ -129,11 +129,12 @@
)
dicom_mask_label = lb_types.Label(
- data=lb_types.DicomData(uid="test-uid"), annotations=[dicom_mask_annotation]
+ data=lb_types.GenericDataRowData(uid="test-uid"),
+ annotations=[dicom_mask_annotation],
)
dicom_mask_label_with_global_key = lb_types.Label(
- data=lb_types.DicomData(global_key="test-global-key"),
+ data=lb_types.GenericDataRowData(global_key="test-global-key"),
annotations=[dicom_mask_annotation],
)
@@ -181,28 +182,3 @@ def test_serialize_label(label, ndjson):
if "uuid" in serialized_label:
serialized_label.pop("uuid")
assert serialized_label == ndjson
-
-
-@pytest.mark.parametrize("label, ndjson", labels_ndjsons)
-def test_deserialize_label(label, ndjson):
- deserialized_label = next(NDJsonConverter().deserialize([ndjson]))
- if hasattr(deserialized_label.annotations[0], "extra"):
- deserialized_label.annotations[0].extra = {}
- for i, annotation in enumerate(deserialized_label.annotations):
- if hasattr(annotation, "frames"):
- assert annotation.frames == label.annotations[i].frames
- if hasattr(annotation, "value"):
- assert annotation.value == label.annotations[i].value
-
-
-@pytest.mark.parametrize("label", labels)
-def test_serialize_deserialize_label(label):
- serialized = list(NDJsonConverter.serialize([label]))
- deserialized = list(NDJsonConverter.deserialize(serialized))
- if hasattr(deserialized[0].annotations[0], "extra"):
- deserialized[0].annotations[0].extra = {}
- for i, annotation in enumerate(deserialized[0].annotations):
- if hasattr(annotation, "frames"):
- assert annotation.frames == label.annotations[i].frames
- if hasattr(annotation, "value"):
- assert annotation.value == label.annotations[i].value
diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_document.py b/libs/labelbox/tests/data/serialization/ndjson/test_document.py
index 5fe6a9789..fcdf4368b 100644
--- a/libs/labelbox/tests/data/serialization/ndjson/test_document.py
+++ b/libs/labelbox/tests/data/serialization/ndjson/test_document.py
@@ -1,6 +1,19 @@
import json
+from labelbox.data.annotation_types.data.generic_data_row_data import (
+ GenericDataRowData,
+)
+from labelbox.data.mixins import CustomMetric
import labelbox.types as lb_types
from labelbox.data.serialization.ndjson.converter import NDJsonConverter
+from labelbox.types import (
+ Label,
+ ObjectAnnotation,
+ RectangleUnit,
+ Point,
+ DocumentRectangle,
+ DocumentEntity,
+ DocumentTextSelection,
+)
bbox_annotation = lb_types.ObjectAnnotation(
name="bounding_box", # must match your ontology feature's name
@@ -13,7 +26,7 @@
)
bbox_labels = [
lb_types.Label(
- data=lb_types.DocumentData(global_key="test-global-key"),
+ data=lb_types.GenericDataRowData(global_key="test-global-key"),
annotations=[bbox_annotation],
)
]
@@ -53,10 +66,144 @@ def test_pdf():
"""
with open("tests/data/assets/ndjson/pdf_import.json", "r") as f:
data = json.load(f)
- res = list(NDJsonConverter.deserialize(data))
- res = list(NDJsonConverter.serialize(res))
+ labels = [
+ Label(
+ uid=None,
+ data=GenericDataRowData(
+ uid="cl6xnv9h61fv0085yhtoq06ht",
+ ),
+ annotations=[
+ ObjectAnnotation(
+ custom_metrics=[
+ CustomMetric(name="customMetric1", value=0.5),
+ CustomMetric(name="customMetric2", value=0.3),
+ ],
+ confidence=0.53,
+ name="boxy",
+ feature_schema_id="cl6xnuwt95lqq07330tbb3mfd",
+ extra={
+ "uuid": "5ad9c52f-058d-49c8-a749-3f20b84f8cd4",
+ },
+ value=DocumentRectangle(
+ start=Point(x=32.45, y=162.73),
+ end=Point(x=134.11, y=550.9),
+ page=4,
+ unit=RectangleUnit.POINTS,
+ ),
+ ),
+ ObjectAnnotation(
+ name="boxy",
+ feature_schema_id="cl6xnuwt95lqq07330tbb3mfd",
+ extra={
+ "uuid": "20eeef88-0294-49b4-a815-86588476bc6f",
+ },
+ value=DocumentRectangle(
+ start=Point(x=251.42, y=223.26),
+ end=Point(x=438.2, y=680.3),
+ page=7,
+ unit=RectangleUnit.POINTS,
+ ),
+ ),
+ ObjectAnnotation(
+ custom_metrics=[
+ CustomMetric(name="customMetric1", value=0.5),
+ CustomMetric(name="customMetric2", value=0.3),
+ ],
+ confidence=0.99,
+ name="boxy",
+ feature_schema_id="cl6xnuwt95lqq07330tbb3mfd",
+ extra={
+ "uuid": "641a8944-3938-409c-b4eb-dea354ed06e5",
+ },
+ value=DocumentRectangle(
+ start=Point(x=218.17, y=32.52),
+ end=Point(x=328.73, y=264.25),
+ page=6,
+ unit=RectangleUnit.POINTS,
+ ),
+ ),
+ ObjectAnnotation(
+ custom_metrics=[
+ CustomMetric(name="customMetric1", value=0.5),
+ CustomMetric(name="customMetric2", value=0.3),
+ ],
+ confidence=0.89,
+ name="boxy",
+ feature_schema_id="cl6xnuwt95lqq07330tbb3mfd",
+ extra={
+ "uuid": "ebe4da7d-08b3-480a-8d15-26552b7f011c",
+ },
+ value=DocumentRectangle(
+ start=Point(x=4.25, y=117.39),
+ end=Point(x=169.08, y=574.3100000000001),
+ page=7,
+ unit=RectangleUnit.POINTS,
+ ),
+ ),
+ ObjectAnnotation(
+ name="boxy",
+ feature_schema_id="cl6xnuwt95lqq07330tbb3mfd",
+ extra={
+ "uuid": "35c41855-575f-42cc-a2f9-1f06237e9b63",
+ },
+ value=DocumentRectangle(
+ start=Point(x=217.28, y=82.13),
+ end=Point(x=299.71000000000004, y=361.89),
+ page=8,
+ unit=RectangleUnit.POINTS,
+ ),
+ ),
+ ObjectAnnotation(
+ name="boxy",
+ feature_schema_id="cl6xnuwt95lqq07330tbb3mfd",
+ extra={
+ "uuid": "1b009654-bc17-42a2-8a71-160e7808c403",
+ },
+ value=DocumentRectangle(
+ start=Point(x=83.34, y=298.12),
+ end=Point(x=83.72, y=501.95000000000005),
+ page=3,
+ unit=RectangleUnit.POINTS,
+ ),
+ ),
+ ],
+ ),
+ Label(
+ data=GenericDataRowData(
+ uid="ckrb1sf1i1g7i0ybcdc6oc8ct",
+ ),
+ annotations=[
+ ObjectAnnotation(
+ name="named_entity",
+ feature_schema_id="cl6xnuwt95lqq07330tbb3mfd",
+ extra={
+ "uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673",
+ },
+ value=DocumentEntity(
+ text_selections=[
+ DocumentTextSelection(
+ token_ids=[
+ "3f984bf3-1d61-44f5-b59a-9658a2e3440f",
+ "3bf00b56-ff12-4e52-8cc1-08dbddb3c3b8",
+ "6e1c3420-d4b7-4c5a-8fd6-ead43bf73d80",
+ "87a43d32-af76-4a1d-b262-5c5f4d5ace3a",
+ "e8606e8a-dfd9-4c49-a635-ad5c879c75d0",
+ "67c7c19e-4654-425d-bf17-2adb8cf02c30",
+ "149c5e80-3e07-49a7-ab2d-29ddfe6a38fa",
+ "b0e94071-2187-461e-8e76-96c58738a52c",
+ ],
+ group_id="2f4336f4-a07e-4e0a-a9e1-5629b03b719b",
+ page=1,
+ )
+ ]
+ ),
+ )
+ ],
+ ),
+ ]
+
+ res = list(NDJsonConverter.serialize(labels))
assert [round_dict(x) for x in res] == [round_dict(x) for x in data]
- f.close()
def test_pdf_with_name_only():
@@ -65,26 +212,135 @@ def test_pdf_with_name_only():
"""
with open("tests/data/assets/ndjson/pdf_import_name_only.json", "r") as f:
data = json.load(f)
- res = list(NDJsonConverter.deserialize(data))
- res = list(NDJsonConverter.serialize(res))
+
+ labels = [
+ Label(
+ data=GenericDataRowData(
+ uid="cl6xnv9h61fv0085yhtoq06ht",
+ ),
+ annotations=[
+ ObjectAnnotation(
+ custom_metrics=[
+ CustomMetric(name="customMetric1", value=0.5),
+ CustomMetric(name="customMetric2", value=0.3),
+ ],
+ confidence=0.99,
+ name="boxy",
+ feature_schema_id=None,
+ extra={
+ "uuid": "5ad9c52f-058d-49c8-a749-3f20b84f8cd4",
+ },
+ value=DocumentRectangle(
+ start=Point(x=32.45, y=162.73),
+ end=Point(x=134.11, y=550.9),
+ page=4,
+ unit=RectangleUnit.POINTS,
+ ),
+ ),
+ ObjectAnnotation(
+ name="boxy",
+ extra={
+ "uuid": "20eeef88-0294-49b4-a815-86588476bc6f",
+ },
+ value=DocumentRectangle(
+ start=Point(x=251.42, y=223.26),
+ end=Point(x=438.2, y=680.3),
+ page=7,
+ unit=RectangleUnit.POINTS,
+ ),
+ ),
+ ObjectAnnotation(
+ name="boxy",
+ extra={
+ "uuid": "641a8944-3938-409c-b4eb-dea354ed06e5",
+ },
+ value=DocumentRectangle(
+ start=Point(x=218.17, y=32.52),
+ end=Point(x=328.73, y=264.25),
+ page=6,
+ unit=RectangleUnit.POINTS,
+ ),
+ ),
+ ObjectAnnotation(
+ custom_metrics=[
+ CustomMetric(name="customMetric1", value=0.5),
+ CustomMetric(name="customMetric2", value=0.3),
+ ],
+ confidence=0.74,
+ name="boxy",
+ extra={
+ "uuid": "ebe4da7d-08b3-480a-8d15-26552b7f011c",
+ },
+ value=DocumentRectangle(
+ start=Point(x=4.25, y=117.39),
+ end=Point(x=169.08, y=574.3100000000001),
+ page=7,
+ unit=RectangleUnit.POINTS,
+ ),
+ ),
+ ObjectAnnotation(
+ name="boxy",
+ extra={
+ "uuid": "35c41855-575f-42cc-a2f9-1f06237e9b63",
+ },
+ value=DocumentRectangle(
+ start=Point(x=217.28, y=82.13),
+ end=Point(x=299.71000000000004, y=361.89),
+ page=8,
+ unit=RectangleUnit.POINTS,
+ ),
+ ),
+ ObjectAnnotation(
+ name="boxy",
+ extra={
+ "uuid": "1b009654-bc17-42a2-8a71-160e7808c403",
+ },
+ value=DocumentRectangle(
+ start=Point(x=83.34, y=298.12),
+ end=Point(x=83.72, y=501.95000000000005),
+ page=3,
+ unit=RectangleUnit.POINTS,
+ ),
+ ),
+ ],
+ ),
+ Label(
+ data=GenericDataRowData(
+ uid="ckrb1sf1i1g7i0ybcdc6oc8ct",
+ ),
+ annotations=[
+ ObjectAnnotation(
+ name="named_entity",
+ extra={
+ "uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673",
+ },
+ value=DocumentEntity(
+ text_selections=[
+ DocumentTextSelection(
+ token_ids=[
+ "3f984bf3-1d61-44f5-b59a-9658a2e3440f",
+ "3bf00b56-ff12-4e52-8cc1-08dbddb3c3b8",
+ "6e1c3420-d4b7-4c5a-8fd6-ead43bf73d80",
+ "87a43d32-af76-4a1d-b262-5c5f4d5ace3a",
+ "e8606e8a-dfd9-4c49-a635-ad5c879c75d0",
+ "67c7c19e-4654-425d-bf17-2adb8cf02c30",
+ "149c5e80-3e07-49a7-ab2d-29ddfe6a38fa",
+ "b0e94071-2187-461e-8e76-96c58738a52c",
+ ],
+ group_id="2f4336f4-a07e-4e0a-a9e1-5629b03b719b",
+ page=1,
+ )
+ ]
+ ),
+ )
+ ],
+ ),
+ ]
+ res = list(NDJsonConverter.serialize(labels))
assert [round_dict(x) for x in res] == [round_dict(x) for x in data]
- f.close()
def test_pdf_bbox_serialize():
serialized = list(NDJsonConverter.serialize(bbox_labels))
serialized[0].pop("uuid")
assert serialized == bbox_ndjson
-
-
-def test_pdf_bbox_deserialize():
- deserialized = list(NDJsonConverter.deserialize(bbox_ndjson))
- deserialized[0].annotations[0].extra = {}
- assert (
- deserialized[0].annotations[0].value
- == bbox_labels[0].annotations[0].value
- )
- assert (
- deserialized[0].annotations[0].name
- == bbox_labels[0].annotations[0].name
- )
diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_export_video_objects.py b/libs/labelbox/tests/data/serialization/ndjson/test_export_video_objects.py
index 4adcd9935..a0cd13e81 100644
--- a/libs/labelbox/tests/data/serialization/ndjson/test_export_video_objects.py
+++ b/libs/labelbox/tests/data/serialization/ndjson/test_export_video_objects.py
@@ -1,16 +1,14 @@
from labelbox.data.annotation_types import Label, VideoObjectAnnotation
from labelbox.data.serialization.ndjson.converter import NDJsonConverter
from labelbox.data.annotation_types.geometry import Rectangle, Point
-from labelbox.data.annotation_types import VideoData
+from labelbox.data.annotation_types.data import GenericDataRowData
def video_bbox_label():
return Label(
uid="cl1z52xwh00050fhcmfgczqvn",
- data=VideoData(
+ data=GenericDataRowData(
uid="cklr9mr4m5iao0rb6cvxu4qbn",
- file_path=None,
- frames=None,
url="https://storage.labelbox.com/ckcz6bubudyfi0855o1dt1g9s%2F26403a22-604a-a38c-eeff-c2ed481fb40a-cat.mp4?Expires=1651677421050&KeyName=labelbox-assets-key-3&Signature=vF7gMyfHzgZdfbB8BHgd88Ws-Ms",
),
annotations=[
@@ -22,6 +20,7 @@ def video_bbox_label():
"instanceURI": None,
"color": "#1CE6FF",
"feature_id": "cl1z52xw700000fhcayaqy0ev",
+ "uuid": "b24e672b-8f79-4d96-bf5e-b552ca0820d5",
},
value=Rectangle(
extra={},
@@ -588,31 +587,4 @@ def test_serialize_video_objects():
serialized_labels = NDJsonConverter.serialize([label])
label = next(serialized_labels)
- manual_label = video_serialized_bbox_label()
-
- for key in label.keys():
- # ignore uuid because we randomize if there was none
- if key != "uuid":
- assert label[key] == manual_label[key]
-
- assert len(label["segments"]) == 2
- assert len(label["segments"][0]["keyframes"]) == 2
- assert len(label["segments"][1]["keyframes"]) == 4
-
- # #converts back only the keyframes. should be the sum of all prev segments
- deserialized_labels = NDJsonConverter.deserialize([label])
- label = next(deserialized_labels)
- assert len(label.annotations) == 6
-
-
-def test_confidence_is_ignored():
- label = video_bbox_label()
- serialized_labels = NDJsonConverter.serialize([label])
- label = next(serialized_labels)
- label["confidence"] = 0.453
- label["segments"][0]["confidence"] = 0.453
-
- deserialized_labels = NDJsonConverter.deserialize([label])
- label = next(deserialized_labels)
- for annotation in label.annotations:
- assert annotation.confidence is None
+ assert label == video_serialized_bbox_label()
diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_free_text.py b/libs/labelbox/tests/data/serialization/ndjson/test_free_text.py
index 84c017497..7b03a8447 100644
--- a/libs/labelbox/tests/data/serialization/ndjson/test_free_text.py
+++ b/libs/labelbox/tests/data/serialization/ndjson/test_free_text.py
@@ -5,7 +5,7 @@
Radio,
Text,
)
-from labelbox.data.annotation_types.data.text import TextData
+from labelbox.data.annotation_types.data import GenericDataRowData
from labelbox.data.annotation_types.label import Label
from labelbox.data.serialization.ndjson.converter import NDJsonConverter
@@ -14,7 +14,7 @@
def test_serialization():
label = Label(
uid="ckj7z2q0b0000jx6x0q2q7q0d",
- data=TextData(
+ data=GenericDataRowData(
uid="bkj7z2q0b0000jx6x0q2q7q0d",
text="This is a test",
),
@@ -34,21 +34,11 @@ def test_serialization():
assert res["answer"] == "text_answer"
assert res["dataRow"]["id"] == "bkj7z2q0b0000jx6x0q2q7q0d"
- deserialized = NDJsonConverter.deserialize([res])
- res = next(deserialized)
-
- annotation = res.annotations[0]
-
- annotation_value = annotation.value
- assert type(annotation_value) is Text
- assert annotation_value.answer == "text_answer"
- assert annotation_value.confidence == 0.5
-
def test_nested_serialization():
label = Label(
uid="ckj7z2q0b0000jx6x0q2q7q0d",
- data=TextData(
+ data=GenericDataRowData(
uid="bkj7z2q0b0000jx6x0q2q7q0d",
text="This is a test",
),
@@ -102,19 +92,3 @@ def test_nested_serialization():
assert sub_classification["name"] == "nested answer"
assert sub_classification["answer"] == "nested answer"
assert sub_classification["confidence"] == 0.7
-
- deserialized = NDJsonConverter.deserialize([res])
- res = next(deserialized)
- annotation = res.annotations[0]
- answer = annotation.value.answer[0]
- assert answer.confidence == 0.9
- assert answer.name == "first_answer"
-
- classification_answer = answer.classifications[0].value.answer
- assert classification_answer.confidence == 0.8
- assert classification_answer.name == "first_sub_radio_answer"
-
- sub_classification_answer = classification_answer.classifications[0].value
- assert type(sub_classification_answer) is Text
- assert sub_classification_answer.answer == "nested answer"
- assert sub_classification_answer.confidence == 0.7
diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_generic_data_row_data.py b/libs/labelbox/tests/data/serialization/ndjson/test_generic_data_row_data.py
new file mode 100644
index 000000000..0dc4c21c0
--- /dev/null
+++ b/libs/labelbox/tests/data/serialization/ndjson/test_generic_data_row_data.py
@@ -0,0 +1,79 @@
+from labelbox.data.annotation_types.data.generic_data_row_data import (
+ GenericDataRowData,
+)
+from labelbox.data.serialization.ndjson.converter import NDJsonConverter
+from labelbox.types import Label, ClassificationAnnotation, Text
+
+
+def test_generic_data_row_global_key():
+ label_1 = Label(
+ data=GenericDataRowData(global_key="test"),
+ annotations=[
+ ClassificationAnnotation(
+ name="free_text",
+ value=Text(answer="sample text"),
+ extra={"uuid": "141c3592-e5f0-4866-9943-d4a21fd47eb0"},
+ )
+ ],
+ )
+ label_2 = Label(
+ data={"global_key": "test"},
+ annotations=[
+ ClassificationAnnotation(
+ name="free_text",
+ value=Text(answer="sample text"),
+ extra={"uuid": "141c3592-e5f0-4866-9943-d4a21fd47eb0"},
+ )
+ ],
+ )
+
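+    # Both the typed GenericDataRowData and the plain-dict data row should serialize to the same NDJSON.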
+ expected_result = [
+ {
+ "answer": "sample text",
+ "dataRow": {"globalKey": "test"},
+ "name": "free_text",
+ "uuid": "141c3592-e5f0-4866-9943-d4a21fd47eb0",
+ }
+ ]
+ assert (
+ list(NDJsonConverter.serialize([label_1]))
+ == list(NDJsonConverter.serialize([label_2]))
+ == expected_result
+ )
+
+
+def test_generic_data_row_id():
+ label_1 = Label(
+ data=GenericDataRowData(uid="test"),
+ annotations=[
+ ClassificationAnnotation(
+ name="free_text",
+ value=Text(answer="sample text"),
+ extra={"uuid": "141c3592-e5f0-4866-9943-d4a21fd47eb0"},
+ )
+ ],
+ )
+ label_2 = Label(
+ data={"uid": "test"},
+ annotations=[
+ ClassificationAnnotation(
+ name="free_text",
+ value=Text(answer="sample text"),
+ extra={"uuid": "141c3592-e5f0-4866-9943-d4a21fd47eb0"},
+ )
+ ],
+ )
+
+ expected_result = [
+ {
+ "answer": "sample text",
+ "dataRow": {"id": "test"},
+ "name": "free_text",
+ "uuid": "141c3592-e5f0-4866-9943-d4a21fd47eb0",
+ }
+ ]
+ assert (
+ list(NDJsonConverter.serialize([label_1]))
+ == list(NDJsonConverter.serialize([label_2]))
+ == expected_result
+ )
diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_global_key.py b/libs/labelbox/tests/data/serialization/ndjson/test_global_key.py
index 2b3fa7f8c..d104a691e 100644
--- a/libs/labelbox/tests/data/serialization/ndjson/test_global_key.py
+++ b/libs/labelbox/tests/data/serialization/ndjson/test_global_key.py
@@ -1,73 +1,74 @@
-import json
-import pytest
-
-from labelbox.data.serialization.ndjson.classification import NDRadio
-
+from labelbox.data.annotation_types.data.generic_data_row_data import (
+ GenericDataRowData,
+)
from labelbox.data.serialization.ndjson.converter import NDJsonConverter
-from labelbox.data.serialization.ndjson.objects import NDLine
-
-
-def round_dict(data):
- if isinstance(data, dict):
- for key in data:
- if isinstance(data[key], float):
- data[key] = int(data[key])
- elif isinstance(data[key], dict):
- data[key] = round_dict(data[key])
- elif isinstance(data[key], (list, tuple)):
- data[key] = [round_dict(r) for r in data[key]]
+from labelbox.types import (
+ Label,
+ ClassificationAnnotation,
+ Radio,
+ ClassificationAnswer,
+)
- return data
+def test_generic_data_row_global_key_included():
+ expected = [
+ {
+ "answer": {"schemaId": "ckrb1sfl8099g0y91cxbd5ftb"},
+ "dataRow": {"globalKey": "ckrb1sf1i1g7i0ybcdc6oc8ct"},
+ "schemaId": "ckrb1sfjx099a0y914hl319ie",
+ "uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673",
+ }
+ ]
-@pytest.mark.parametrize(
- "filename",
- [
- "tests/data/assets/ndjson/classification_import_global_key.json",
- "tests/data/assets/ndjson/metric_import_global_key.json",
- "tests/data/assets/ndjson/polyline_import_global_key.json",
- "tests/data/assets/ndjson/text_entity_import_global_key.json",
- "tests/data/assets/ndjson/conversation_entity_import_global_key.json",
- ],
-)
-def test_many_types(filename: str):
- with open(filename, "r") as f:
- data = json.load(f)
- res = list(NDJsonConverter.deserialize(data))
- res = list(NDJsonConverter.serialize(res))
- assert res == data
- f.close()
+ label = Label(
+ data=GenericDataRowData(
+ global_key="ckrb1sf1i1g7i0ybcdc6oc8ct",
+ ),
+ annotations=[
+ ClassificationAnnotation(
+ feature_schema_id="ckrb1sfjx099a0y914hl319ie",
+ extra={"uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673"},
+ value=Radio(
+ answer=ClassificationAnswer(
+ feature_schema_id="ckrb1sfl8099g0y91cxbd5ftb",
+ ),
+ ),
+ )
+ ],
+ )
+ res = list(NDJsonConverter.serialize([label]))
-def test_image():
- with open(
- "tests/data/assets/ndjson/image_import_global_key.json", "r"
- ) as f:
- data = json.load(f)
- res = list(NDJsonConverter.deserialize(data))
- res = list(NDJsonConverter.serialize(res))
- for r in res:
- r.pop("classifications", None)
- assert [round_dict(x) for x in res] == [round_dict(x) for x in data]
- f.close()
+ assert res == expected
-def test_pdf():
- with open("tests/data/assets/ndjson/pdf_import_global_key.json", "r") as f:
- data = json.load(f)
- res = list(NDJsonConverter.deserialize(data))
- res = list(NDJsonConverter.serialize(res))
- assert [round_dict(x) for x in res] == [round_dict(x) for x in data]
- f.close()
+def test_dict_data_row_global_key_included():
+ expected = [
+ {
+ "answer": {"schemaId": "ckrb1sfl8099g0y91cxbd5ftb"},
+ "dataRow": {"globalKey": "ckrb1sf1i1g7i0ybcdc6oc8ct"},
+ "schemaId": "ckrb1sfjx099a0y914hl319ie",
+ "uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673",
+ }
+ ]
+ label = Label(
+ data={
+ "global_key": "ckrb1sf1i1g7i0ybcdc6oc8ct",
+ },
+ annotations=[
+ ClassificationAnnotation(
+ feature_schema_id="ckrb1sfjx099a0y914hl319ie",
+ extra={"uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673"},
+ value=Radio(
+ answer=ClassificationAnswer(
+ feature_schema_id="ckrb1sfl8099g0y91cxbd5ftb",
+ ),
+ ),
+ )
+ ],
+ )
-def test_video():
- with open(
- "tests/data/assets/ndjson/video_import_global_key.json", "r"
- ) as f:
- data = json.load(f)
+ res = list(NDJsonConverter.serialize([label]))
- res = list(NDJsonConverter.deserialize(data))
- res = list(NDJsonConverter.serialize(res))
- assert res == [data[2], data[0], data[1], data[3], data[4], data[5]]
- f.close()
+ assert res == expected
diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_image.py b/libs/labelbox/tests/data/serialization/ndjson/test_image.py
index 1729e1f46..4d615658c 100644
--- a/libs/labelbox/tests/data/serialization/ndjson/test_image.py
+++ b/libs/labelbox/tests/data/serialization/ndjson/test_image.py
@@ -1,4 +1,8 @@
import json
+from labelbox.data.annotation_types.data.generic_data_row_data import (
+ GenericDataRowData,
+)
+from labelbox.data.mixins import CustomMetric
import numpy as np
import cv2
@@ -7,9 +11,9 @@
Mask,
Label,
ObjectAnnotation,
- ImageData,
MaskData,
)
+from labelbox.types import Rectangle, Polygon, Point
def round_dict(data):
@@ -29,12 +33,74 @@ def test_image():
with open("tests/data/assets/ndjson/image_import.json", "r") as file:
data = json.load(file)
- res = list(NDJsonConverter.deserialize(data))
- res = list(NDJsonConverter.serialize(res))
+ labels = [
+ Label(
+ data=GenericDataRowData(
+ uid="ckrazctum0z8a0ybc0b0o0g0v",
+ ),
+ annotations=[
+ ObjectAnnotation(
+ custom_metrics=[
+ CustomMetric(name="customMetric1", value=0.4)
+ ],
+ confidence=0.851,
+ feature_schema_id="ckrazcueb16og0z6609jj7y3y",
+ extra={
+ "uuid": "b862c586-8614-483c-b5e6-82810f70cac0",
+ },
+ value=Rectangle(
+ start=Point(extra={}, x=2275.0, y=1352.0),
+ end=Point(extra={}, x=2414.0, y=1702.0),
+ ),
+ ),
+ ObjectAnnotation(
+ custom_metrics=[
+ CustomMetric(name="customMetric1", value=0.3)
+ ],
+ confidence=0.834,
+ feature_schema_id="ckrazcuec16ok0z66f956apb7",
+ extra={
+ "uuid": "751fc725-f7b6-48ed-89b0-dd7d94d08af6",
+ },
+ value=Mask(
+ mask=MaskData(
+ url="https://storage.labelbox.com/ckqcx1czn06830y61gh9v02cs%2F3e729327-f038-f66c-186e-45e921ef9717-1?Expires=1626806874672&KeyName=labelbox-assets-key-3&Signature=YsUOGKrsqmAZ68vT9BlPJOaRyLY",
+ ),
+ color=[255, 0, 0],
+ ),
+ ),
+ ObjectAnnotation(
+ custom_metrics=[
+ CustomMetric(name="customMetric1", value=0.9)
+ ],
+ confidence=0.986,
+ feature_schema_id="ckrazcuec16oi0z66dzrd8pfl",
+ extra={
+ "uuid": "43d719ac-5d7f-4aea-be00-2ebfca0900fd",
+ },
+ value=Polygon(
+ points=[
+ Point(x=10.0, y=20.0),
+ Point(x=15.0, y=20.0),
+ Point(x=20.0, y=25.0),
+ Point(x=10.0, y=20.0),
+ ],
+ ),
+ ),
+ ObjectAnnotation(
+ feature_schema_id="ckrazcuec16om0z66bhhh4tp7",
+ extra={
+ "uuid": "b98f3a45-3328-41a0-9077-373a8177ebf2",
+ },
+ value=Point(x=2122.0, y=1457.0),
+ ),
+ ],
+ )
+ ]
- for r in res:
- r.pop("classifications", None)
- assert [round_dict(x) for x in res] == [round_dict(x) for x in data]
+ res = list(NDJsonConverter.serialize(labels))
+ del res[1]["mask"]["colorRGB"] # JSON does not support tuples
+ assert res == data
def test_image_with_name_only():
@@ -43,11 +109,74 @@ def test_image_with_name_only():
) as file:
data = json.load(file)
- res = list(NDJsonConverter.deserialize(data))
- res = list(NDJsonConverter.serialize(res))
- for r in res:
- r.pop("classifications", None)
- assert [round_dict(x) for x in res] == [round_dict(x) for x in data]
+ labels = [
+ Label(
+ data=GenericDataRowData(
+ uid="ckrazctum0z8a0ybc0b0o0g0v",
+ ),
+ annotations=[
+ ObjectAnnotation(
+ custom_metrics=[
+ CustomMetric(name="customMetric1", value=0.4)
+ ],
+ confidence=0.851,
+ name="ckrazcueb16og0z6609jj7y3y",
+ extra={
+ "uuid": "b862c586-8614-483c-b5e6-82810f70cac0",
+ },
+ value=Rectangle(
+ start=Point(extra={}, x=2275.0, y=1352.0),
+ end=Point(extra={}, x=2414.0, y=1702.0),
+ ),
+ ),
+ ObjectAnnotation(
+ custom_metrics=[
+ CustomMetric(name="customMetric1", value=0.3)
+ ],
+ confidence=0.834,
+ name="ckrazcuec16ok0z66f956apb7",
+ extra={
+ "uuid": "751fc725-f7b6-48ed-89b0-dd7d94d08af6",
+ },
+ value=Mask(
+ mask=MaskData(
+ url="https://storage.labelbox.com/ckqcx1czn06830y61gh9v02cs%2F3e729327-f038-f66c-186e-45e921ef9717-1?Expires=1626806874672&KeyName=labelbox-assets-key-3&Signature=YsUOGKrsqmAZ68vT9BlPJOaRyLY",
+ ),
+ color=[255, 0, 0],
+ ),
+ ),
+ ObjectAnnotation(
+ custom_metrics=[
+ CustomMetric(name="customMetric1", value=0.9)
+ ],
+ confidence=0.986,
+ name="ckrazcuec16oi0z66dzrd8pfl",
+ extra={
+ "uuid": "43d719ac-5d7f-4aea-be00-2ebfca0900fd",
+ },
+ value=Polygon(
+ points=[
+ Point(x=10.0, y=20.0),
+ Point(x=15.0, y=20.0),
+ Point(x=20.0, y=25.0),
+ Point(x=10.0, y=20.0),
+ ],
+ ),
+ ),
+ ObjectAnnotation(
+ name="ckrazcuec16om0z66bhhh4tp7",
+ extra={
+ "uuid": "b98f3a45-3328-41a0-9077-373a8177ebf2",
+ },
+ value=Point(x=2122.0, y=1457.0),
+ ),
+ ],
+ )
+ ]
+
+ res = list(NDJsonConverter.serialize(labels))
+ del res[1]["mask"]["colorRGB"] # JSON does not support tuples
+ assert res == data
def test_mask():
@@ -57,10 +186,11 @@ def test_mask():
"schemaId": "ckrazcueb16og0z6609jj7y3y",
"dataRow": {"id": "ckrazctum0z8a0ybc0b0o0g0v"},
"mask": {
- "png": "iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAAAAACoWZBhAAAAMklEQVR4nD3MuQ3AQADDMOqQ/Vd2ijytaSiZLAcYuyLEYYYl9cvrlGftTHvsYl+u/3EDv0QLI8Z7FlwAAAAASUVORK5CYII="
+ "png": "iVBORw0KGgoAAAANSUhEUgAAAAIAAAACCAAAAABX3VL4AAAADklEQVR4nGNgYGBkZAAAAAsAA+RRQXwAAAAASUVORK5CYII="
},
"confidence": 0.8,
"customMetrics": [{"name": "customMetric1", "value": 0.4}],
+ "classifications": [],
},
{
"uuid": "751fc725-f7b6-48ed-89b0-dd7d94d08af6",
@@ -68,16 +198,54 @@ def test_mask():
"dataRow": {"id": "ckrazctum0z8a0ybc0b0o0g0v"},
"mask": {
"instanceURI": "https://storage.labelbox.com/ckqcx1czn06830y61gh9v02cs%2F3e729327-f038-f66c-186e-45e921ef9717-1?Expires=1626806874672&KeyName=labelbox-assets-key-3&Signature=YsUOGKrsqmAZ68vT9BlPJOaRyLY",
- "colorRGB": [255, 0, 0],
+ "colorRGB": (255, 0, 0),
},
+ "classifications": [],
},
]
- res = list(NDJsonConverter.deserialize(data))
- res = list(NDJsonConverter.serialize(res))
- for r in res:
- r.pop("classifications", None)
- assert [round_dict(x) for x in res] == [round_dict(x) for x in data]
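+    # Assumed minimal 2x2 mask; pixels matching color (1, 1, 1) should encode to the small base64 PNG expected above.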
+ mask_numpy = np.array([[[1, 1, 0], [1, 0, 1]], [[1, 1, 1], [1, 1, 1]]])
+ mask_numpy = mask_numpy.astype(np.uint8)
+
+ labels = [
+ Label(
+ data=GenericDataRowData(
+ uid="ckrazctum0z8a0ybc0b0o0g0v",
+ ),
+ annotations=[
+ ObjectAnnotation(
+ custom_metrics=[
+ CustomMetric(name="customMetric1", value=0.4)
+ ],
+ confidence=0.8,
+ feature_schema_id="ckrazcueb16og0z6609jj7y3y",
+ extra={
+ "uuid": "b862c586-8614-483c-b5e6-82810f70cac0",
+ },
+ value=Mask(
+ mask=MaskData(arr=mask_numpy),
+ color=(1, 1, 1),
+ ),
+ ),
+ ObjectAnnotation(
+ feature_schema_id="ckrazcuec16ok0z66f956apb7",
+ extra={
+ "uuid": "751fc725-f7b6-48ed-89b0-dd7d94d08af6",
+ },
+ value=Mask(
+ extra={},
+ mask=MaskData(
+ url="https://storage.labelbox.com/ckqcx1czn06830y61gh9v02cs%2F3e729327-f038-f66c-186e-45e921ef9717-1?Expires=1626806874672&KeyName=labelbox-assets-key-3&Signature=YsUOGKrsqmAZ68vT9BlPJOaRyLY",
+ ),
+ color=(255, 0, 0),
+ ),
+ ),
+ ],
+ )
+ ]
+ res = list(NDJsonConverter.serialize(labels))
+
+ assert res == data
def test_mask_from_arr():
@@ -93,7 +261,7 @@ def test_mask_from_arr():
),
)
],
- data=ImageData(uid="0" * 25),
+ data=GenericDataRowData(uid="0" * 25),
)
res = next(NDJsonConverter.serialize([label]))
res.pop("uuid")
diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_metric.py b/libs/labelbox/tests/data/serialization/ndjson/test_metric.py
index 45c5c67bf..40e098405 100644
--- a/libs/labelbox/tests/data/serialization/ndjson/test_metric.py
+++ b/libs/labelbox/tests/data/serialization/ndjson/test_metric.py
@@ -1,38 +1,166 @@
import json
+from labelbox.data.annotation_types.data.generic_data_row_data import (
+ GenericDataRowData,
+)
+from labelbox.data.annotation_types.metrics.confusion_matrix import (
+ ConfusionMatrixMetric,
+)
from labelbox.data.serialization.ndjson.converter import NDJsonConverter
+from labelbox.types import (
+ Label,
+ ScalarMetric,
+ ScalarMetricAggregation,
+ ConfusionMatrixAggregation,
+)
def test_metric():
with open("tests/data/assets/ndjson/metric_import.json", "r") as file:
data = json.load(file)
- label_list = list(NDJsonConverter.deserialize(data))
- reserialized = list(NDJsonConverter.serialize(label_list))
- assert reserialized == data
+ labels = [
+ Label(
+ data=GenericDataRowData(
+ uid="ckrmdnqj4000007msh9p2a27r",
+ ),
+ annotations=[
+ ScalarMetric(
+ value=0.1,
+ extra={"uuid": "a22bbf6e-b2da-4abe-9a11-df84759f7672"},
+ aggregation=ScalarMetricAggregation.ARITHMETIC_MEAN,
+ )
+ ],
+ )
+ ]
+
+ res = list(NDJsonConverter.serialize(labels))
+ assert res == data
def test_custom_scalar_metric():
- with open(
- "tests/data/assets/ndjson/custom_scalar_import.json", "r"
- ) as file:
- data = json.load(file)
+ data = [
+ {
+ "uuid": "a22bbf6e-b2da-4abe-9a11-df84759f7672",
+ "dataRow": {"id": "ckrmdnqj4000007msh9p2a27r"},
+ "metricValue": 0.1,
+ "metricName": "custom_iou",
+ "featureName": "sample_class",
+ "subclassName": "sample_subclass",
+ "aggregation": "SUM",
+ },
+ {
+ "uuid": "a22bbf6e-b2da-4abe-9a11-df84759f7673",
+ "dataRow": {"id": "ckrmdnqj4000007msh9p2a27r"},
+ "metricValue": 0.1,
+ "metricName": "custom_iou",
+ "featureName": "sample_class",
+ "aggregation": "SUM",
+ },
+ {
+ "uuid": "a22bbf6e-b2da-4abe-9a11-df84759f7674",
+ "dataRow": {"id": "ckrmdnqj4000007msh9p2a27r"},
+ "metricValue": {0.1: 0.1, 0.2: 0.5},
+ "metricName": "custom_iou",
+ "aggregation": "SUM",
+ },
+ ]
+
+ labels = [
+ Label(
+ data=GenericDataRowData(
+ uid="ckrmdnqj4000007msh9p2a27r",
+ ),
+ annotations=[
+ ScalarMetric(
+ value=0.1,
+ feature_name="sample_class",
+ subclass_name="sample_subclass",
+ extra={"uuid": "a22bbf6e-b2da-4abe-9a11-df84759f7672"},
+ metric_name="custom_iou",
+ aggregation=ScalarMetricAggregation.SUM,
+ ),
+ ScalarMetric(
+ value=0.1,
+ feature_name="sample_class",
+ extra={"uuid": "a22bbf6e-b2da-4abe-9a11-df84759f7673"},
+ metric_name="custom_iou",
+ aggregation=ScalarMetricAggregation.SUM,
+ ),
+ ScalarMetric(
+ value={"0.1": 0.1, "0.2": 0.5},
+ extra={"uuid": "a22bbf6e-b2da-4abe-9a11-df84759f7674"},
+ metric_name="custom_iou",
+ aggregation=ScalarMetricAggregation.SUM,
+ ),
+ ],
+ )
+ ]
+
+ res = list(NDJsonConverter.serialize(labels))
- label_list = list(NDJsonConverter.deserialize(data))
- reserialized = list(NDJsonConverter.serialize(label_list))
- assert json.dumps(reserialized, sort_keys=True) == json.dumps(
- data, sort_keys=True
- )
+ assert res == data
def test_custom_confusion_matrix_metric():
- with open(
- "tests/data/assets/ndjson/custom_confusion_matrix_import.json", "r"
- ) as file:
- data = json.load(file)
+ data = [
+ {
+ "uuid": "a22bbf6e-b2da-4abe-9a11-df84759f7672",
+ "dataRow": {"id": "ckrmdnqj4000007msh9p2a27r"},
+ "metricValue": (1, 1, 2, 3),
+ "metricName": "50%_iou",
+ "featureName": "sample_class",
+ "subclassName": "sample_subclass",
+ "aggregation": "CONFUSION_MATRIX",
+ },
+ {
+ "uuid": "a22bbf6e-b2da-4abe-9a11-df84759f7673",
+ "dataRow": {"id": "ckrmdnqj4000007msh9p2a27r"},
+ "metricValue": (0, 1, 2, 5),
+ "metricName": "50%_iou",
+ "featureName": "sample_class",
+ "aggregation": "CONFUSION_MATRIX",
+ },
+ {
+ "uuid": "a22bbf6e-b2da-4abe-9a11-df84759f7674",
+ "dataRow": {"id": "ckrmdnqj4000007msh9p2a27r"},
+ "metricValue": {0.1: (0, 1, 2, 3), 0.2: (5, 3, 4, 3)},
+ "metricName": "50%_iou",
+ "aggregation": "CONFUSION_MATRIX",
+ },
+ ]
+
+ labels = [
+ Label(
+ data=GenericDataRowData(
+ uid="ckrmdnqj4000007msh9p2a27r",
+ ),
+ annotations=[
+ ConfusionMatrixMetric(
+ value=(1, 1, 2, 3),
+ feature_name="sample_class",
+ subclass_name="sample_subclass",
+ extra={"uuid": "a22bbf6e-b2da-4abe-9a11-df84759f7672"},
+ metric_name="50%_iou",
+ aggregation=ConfusionMatrixAggregation.CONFUSION_MATRIX,
+ ),
+ ConfusionMatrixMetric(
+ value=(0, 1, 2, 5),
+ feature_name="sample_class",
+ extra={"uuid": "a22bbf6e-b2da-4abe-9a11-df84759f7673"},
+ metric_name="50%_iou",
+ aggregation=ConfusionMatrixAggregation.CONFUSION_MATRIX,
+ ),
+ ConfusionMatrixMetric(
+ value={0.1: (0, 1, 2, 3), 0.2: (5, 3, 4, 3)},
+ extra={"uuid": "a22bbf6e-b2da-4abe-9a11-df84759f7674"},
+ metric_name="50%_iou",
+ aggregation=ConfusionMatrixAggregation.CONFUSION_MATRIX,
+ ),
+ ],
+ )
+ ]
+
+ res = list(NDJsonConverter.serialize(labels))
- label_list = list(NDJsonConverter.deserialize(data))
- reserialized = list(NDJsonConverter.serialize(label_list))
- assert json.dumps(reserialized, sort_keys=True) == json.dumps(
- data, sort_keys=True
- )
+ assert data == res
diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_mmc.py b/libs/labelbox/tests/data/serialization/ndjson/test_mmc.py
index 69594ff73..202f793fe 100644
--- a/libs/labelbox/tests/data/serialization/ndjson/test_mmc.py
+++ b/libs/labelbox/tests/data/serialization/ndjson/test_mmc.py
@@ -1,32 +1,125 @@
import json
+from labelbox.data.annotation_types.data.generic_data_row_data import (
+ GenericDataRowData,
+)
import pytest
from labelbox.data.serialization import NDJsonConverter
+from labelbox.types import (
+ Label,
+ MessageEvaluationTaskAnnotation,
+ MessageSingleSelectionTask,
+ MessageMultiSelectionTask,
+ MessageInfo,
+ OrderedMessageInfo,
+ MessageRankingTask,
+)
def test_message_task_annotation_serialization():
with open("tests/data/assets/ndjson/mmc_import.json", "r") as file:
data = json.load(file)
- deserialized = list(NDJsonConverter.deserialize(data))
- reserialized = list(NDJsonConverter.serialize(deserialized))
+ labels = [
+ Label(
+ data=GenericDataRowData(
+ uid="cnjencjencjfencvj",
+ ),
+ annotations=[
+ MessageEvaluationTaskAnnotation(
+ name="single-selection",
+ extra={"uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72"},
+ value=MessageSingleSelectionTask(
+ message_id="clxfzocbm00083b6v8vczsept",
+ model_config_name="GPT 5",
+ parent_message_id="clxfznjb800073b6v43ppx9ca",
+ ),
+ )
+ ],
+ ),
+ Label(
+ data=GenericDataRowData(
+ uid="cfcerfvergerfefj",
+ ),
+ annotations=[
+ MessageEvaluationTaskAnnotation(
+ name="multi-selection",
+ extra={"uuid": "gferf3a57-597e-48cb-8d8d-a8526fefe72"},
+ value=MessageMultiSelectionTask(
+ parent_message_id="clxfznjb800073b6v43ppx9ca",
+ selected_messages=[
+ MessageInfo(
+ message_id="clxfzocbm00083b6v8vczsept",
+ model_config_name="GPT 5",
+ )
+ ],
+ ),
+ )
+ ],
+ ),
+ Label(
+ data=GenericDataRowData(
+ uid="cwefgtrgrthveferfferffr",
+ ),
+ annotations=[
+ MessageEvaluationTaskAnnotation(
+ name="ranking",
+ extra={"uuid": "hybe3a57-5gt7e-48tgrb-8d8d-a852dswqde72"},
+ value=MessageRankingTask(
+ parent_message_id="clxfznjb800073b6v43ppx9ca",
+ ranked_messages=[
+ OrderedMessageInfo(
+ message_id="clxfzocbm00083b6v8vczsept",
+ model_config_name="GPT 4 with temperature 0.7",
+ order=1,
+ ),
+ OrderedMessageInfo(
+ message_id="clxfzocbm00093b6vx4ndisub",
+ model_config_name="GPT 5",
+ order=2,
+ ),
+ ],
+ ),
+ )
+ ],
+ ),
+ ]
- assert data == reserialized
+ res = list(NDJsonConverter.serialize(labels))
+ assert res == data
-def test_mesage_ranking_task_wrong_order_serialization():
- with open("tests/data/assets/ndjson/mmc_import.json", "r") as file:
- data = json.load(file)
-
- some_ranking_task = next(
- task
- for task in data
- if task["messageEvaluationTask"]["format"] == "message-ranking"
- )
- some_ranking_task["messageEvaluationTask"]["data"]["rankedMessages"][0][
- "order"
- ] = 3
+def test_message_ranking_task_wrong_order_serialization():
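+    # Both ranked messages below reuse order=1; MessageRankingTask validation is expected to reject duplicate orders.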
with pytest.raises(ValueError):
- list(NDJsonConverter.deserialize([some_ranking_task]))
+ (
+ Label(
+ data=GenericDataRowData(
+ uid="cwefgtrgrthveferfferffr",
+ ),
+ annotations=[
+ MessageEvaluationTaskAnnotation(
+ name="ranking",
+ extra={
+ "uuid": "hybe3a57-5gt7e-48tgrb-8d8d-a852dswqde72"
+ },
+ value=MessageRankingTask(
+ parent_message_id="clxfznjb800073b6v43ppx9ca",
+ ranked_messages=[
+ OrderedMessageInfo(
+ message_id="clxfzocbm00093b6vx4ndisub",
+ model_config_name="GPT 5",
+ order=1,
+ ),
+ OrderedMessageInfo(
+ message_id="clxfzocbm00083b6v8vczsept",
+ model_config_name="GPT 4 with temperature 0.7",
+ order=1,
+ ),
+ ],
+ ),
+ )
+ ],
+ ),
+ )
diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_ndlabel_subclass_matching.py b/libs/labelbox/tests/data/serialization/ndjson/test_ndlabel_subclass_matching.py
deleted file mode 100644
index 790bd87b3..000000000
--- a/libs/labelbox/tests/data/serialization/ndjson/test_ndlabel_subclass_matching.py
+++ /dev/null
@@ -1,19 +0,0 @@
-import json
-from labelbox.data.serialization.ndjson.label import NDLabel
-from labelbox.data.serialization.ndjson.objects import NDDocumentRectangle
-import pytest
-
-
-def test_bad_annotation_input():
- data = [{"test": 3}]
- with pytest.raises(ValueError):
- NDLabel(**{"annotations": data})
-
-
-def test_correct_annotation_input():
- with open("tests/data/assets/ndjson/pdf_import_name_only.json", "r") as f:
- data = json.load(f)
- assert isinstance(
- NDLabel(**{"annotations": [data[0]]}).annotations[0],
- NDDocumentRectangle,
- )
diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_nested.py b/libs/labelbox/tests/data/serialization/ndjson/test_nested.py
index e0f0df0e6..3633c9cbe 100644
--- a/libs/labelbox/tests/data/serialization/ndjson/test_nested.py
+++ b/libs/labelbox/tests/data/serialization/ndjson/test_nested.py
@@ -1,13 +1,135 @@
import json
+from labelbox.data.annotation_types.data.generic_data_row_data import (
+ GenericDataRowData,
+)
+from labelbox.data.mixins import CustomMetric
from labelbox.data.serialization.ndjson.converter import NDJsonConverter
+from labelbox.types import (
+ Label,
+ ObjectAnnotation,
+ Rectangle,
+ Point,
+ ClassificationAnnotation,
+ Radio,
+ ClassificationAnswer,
+ Text,
+ Checklist,
+)
def test_nested():
with open("tests/data/assets/ndjson/nested_import.json", "r") as file:
data = json.load(file)
- res = list(NDJsonConverter.deserialize(data))
- res = list(NDJsonConverter.serialize(res))
+ labels = [
+ Label(
+ data=GenericDataRowData(
+ uid="ckrb1sf1i1g7i0ybcdc6oc8ct",
+ ),
+ annotations=[
+ ObjectAnnotation(
+ feature_schema_id="ckrb1sfjx099a0y914hl319ie",
+ extra={
+ "uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673",
+ },
+ value=Rectangle(
+ start=Point(x=2275.0, y=1352.0),
+ end=Point(x=2414.0, y=1702.0),
+ ),
+ classifications=[
+ ClassificationAnnotation(
+ feature_schema_id="ckrb1sfkn099c0y910wbo0p1a",
+ value=Radio(
+ answer=ClassificationAnswer(
+ custom_metrics=[
+ CustomMetric(
+ name="customMetric1", value=0.5
+ ),
+ CustomMetric(
+ name="customMetric2", value=0.3
+ ),
+ ],
+ confidence=0.34,
+ feature_schema_id="ckrb1sfl8099g0y91cxbd5ftb",
+ ),
+ ),
+ )
+ ],
+ ),
+ ObjectAnnotation(
+ feature_schema_id="ckrb1sfjx099a0y914hl319ie",
+ extra={
+ "uuid": "d009925d-91a3-4f67-abd9-753453f5a584",
+ },
+ value=Rectangle(
+ start=Point(x=2089.0, y=1251.0),
+ end=Point(x=2247.0, y=1679.0),
+ ),
+ classifications=[
+ ClassificationAnnotation(
+ feature_schema_id="ckrb1sfkn099c0y910wbo0p1a",
+ value=Radio(
+ answer=ClassificationAnswer(
+ feature_schema_id="ckrb1sfl8099e0y919v260awv",
+ ),
+ ),
+ )
+ ],
+ ),
+ ObjectAnnotation(
+ feature_schema_id="ckrb1sfjx099a0y914hl319ie",
+ extra={
+ "uuid": "5d03213e-4408-456c-9eca-cf0723202961",
+ },
+ value=Rectangle(
+ start=Point(x=2089.0, y=1251.0),
+ end=Point(x=2247.0, y=1679.0),
+ ),
+ classifications=[
+ ClassificationAnnotation(
+ feature_schema_id="ckrb1sfkn099c0y910wbo0p1a",
+ value=Checklist(
+ answer=[
+ ClassificationAnswer(
+ custom_metrics=[
+ CustomMetric(
+ name="customMetric1", value=0.5
+ ),
+ CustomMetric(
+ name="customMetric2", value=0.3
+ ),
+ ],
+ confidence=0.894,
+ feature_schema_id="ckrb1sfl8099e0y919v260awv",
+ )
+ ],
+ ),
+ )
+ ],
+ ),
+ ObjectAnnotation(
+ feature_schema_id="ckrb1sfjx099a0y914hl319ie",
+ extra={
+ "uuid": "d50812f6-34eb-4f12-b3cb-bbde51a31d83",
+ },
+ value=Rectangle(
+ start=Point(x=2089.0, y=1251.0),
+ end=Point(x=2247.0, y=1679.0),
+ ),
+ classifications=[
+ ClassificationAnnotation(
+ feature_schema_id="ckrb1sfkn099c0y910wbo0p1a",
+ extra={},
+ value=Text(
+ answer="a string",
+ ),
+ )
+ ],
+ ),
+ ],
+ )
+ ]
+ res = list(NDJsonConverter.serialize(labels))
assert res == data
@@ -16,6 +138,112 @@ def test_nested_name_only():
"tests/data/assets/ndjson/nested_import_name_only.json", "r"
) as file:
data = json.load(file)
- res = list(NDJsonConverter.deserialize(data))
- res = list(NDJsonConverter.serialize(res))
+ labels = [
+ Label(
+ data=GenericDataRowData(
+ uid="ckrb1sf1i1g7i0ybcdc6oc8ct",
+ ),
+ annotations=[
+ ObjectAnnotation(
+ name="box a",
+ extra={
+ "uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673",
+ },
+ value=Rectangle(
+ start=Point(x=2275.0, y=1352.0),
+ end=Point(x=2414.0, y=1702.0),
+ ),
+ classifications=[
+ ClassificationAnnotation(
+ name="classification a",
+ value=Radio(
+ answer=ClassificationAnswer(
+ custom_metrics=[
+ CustomMetric(
+ name="customMetric1", value=0.5
+ ),
+ CustomMetric(
+ name="customMetric2", value=0.3
+ ),
+ ],
+ confidence=0.811,
+ name="first answer",
+ ),
+ ),
+ )
+ ],
+ ),
+ ObjectAnnotation(
+ name="box b",
+ extra={
+ "uuid": "d009925d-91a3-4f67-abd9-753453f5a584",
+ },
+ value=Rectangle(
+ start=Point(x=2089.0, y=1251.0),
+ end=Point(x=2247.0, y=1679.0),
+ ),
+ classifications=[
+ ClassificationAnnotation(
+ name="classification b",
+ value=Radio(
+ answer=ClassificationAnswer(
+ custom_metrics=[
+ CustomMetric(
+ name="customMetric1", value=0.5
+ ),
+ CustomMetric(
+ name="customMetric2", value=0.3
+ ),
+ ],
+ confidence=0.815,
+ name="second answer",
+ ),
+ ),
+ )
+ ],
+ ),
+ ObjectAnnotation(
+ name="box c",
+ extra={
+ "uuid": "8a2b2c43-f0a1-4763-ba96-e322d986ced6",
+ },
+ value=Rectangle(
+ start=Point(x=2089.0, y=1251.0),
+ end=Point(x=2247.0, y=1679.0),
+ ),
+ classifications=[
+ ClassificationAnnotation(
+ name="classification c",
+ value=Checklist(
+ answer=[
+ ClassificationAnswer(
+ name="third answer",
+ )
+ ],
+ ),
+ )
+ ],
+ ),
+ ObjectAnnotation(
+ name="box c",
+ extra={
+ "uuid": "456dd2c6-9fa0-42f9-9809-acc27b9886a7",
+ },
+ value=Rectangle(
+ start=Point(x=2089.0, y=1251.0),
+ end=Point(x=2247.0, y=1679.0),
+ ),
+ classifications=[
+ ClassificationAnnotation(
+ name="a string",
+ value=Text(
+ answer="a string",
+ ),
+ )
+ ],
+ ),
+ ],
+ )
+ ]
+ res = list(NDJsonConverter.serialize(labels))
assert res == data
diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_polyline.py b/libs/labelbox/tests/data/serialization/ndjson/test_polyline.py
index 97d48a14e..cd11d97fe 100644
--- a/libs/labelbox/tests/data/serialization/ndjson/test_polyline.py
+++ b/libs/labelbox/tests/data/serialization/ndjson/test_polyline.py
@@ -1,18 +1,76 @@
import json
-import pytest
+from labelbox.data.annotation_types.data.generic_data_row_data import (
+ GenericDataRowData,
+)
+from labelbox.data.mixins import CustomMetric
from labelbox.data.serialization.ndjson.converter import NDJsonConverter
+from labelbox.types import ObjectAnnotation, Point, Line, Label
-@pytest.mark.parametrize(
- "filename",
- [
- "tests/data/assets/ndjson/polyline_without_confidence_import.json",
- "tests/data/assets/ndjson/polyline_import.json",
- ],
-)
-def test_polyline_import(filename: str):
- with open(filename, "r") as file:
+def test_polyline_import_without_confidence():
+ with open(
+ "tests/data/assets/ndjson/polyline_without_confidence_import.json", "r"
+ ) as file:
+ data = json.load(file)
+ labels = [
+ Label(
+ data=GenericDataRowData(
+ uid="cl6xnv9h61fv0085yhtoq06ht",
+ ),
+ annotations=[
+ ObjectAnnotation(
+ name="some-line",
+ feature_schema_id="cl6xnuwt95lqq07330tbb3mfd",
+ extra={
+ "uuid": "5ad9c52f-058d-49c8-a749-3f20b84f8cd4",
+ },
+ value=Line(
+ points=[
+ Point(x=2534.353, y=249.471),
+ Point(x=2429.492, y=182.092),
+ Point(x=2294.322, y=221.962),
+ ],
+ ),
+ )
+ ],
+ )
+ ]
+ res = list(NDJsonConverter.serialize(labels))
+ assert res == data
+
+
+def test_polyline_import_with_confidence():
+ with open("tests/data/assets/ndjson/polyline_import.json", "r") as file:
data = json.load(file)
- res = list(NDJsonConverter.deserialize(data))
- res = list(NDJsonConverter.serialize(res))
+
+ labels = [
+ Label(
+ data=GenericDataRowData(
+ uid="cl6xnv9h61fv0085yhtoq06ht",
+ ),
+ annotations=[
+ ObjectAnnotation(
+ custom_metrics=[
+ CustomMetric(name="customMetric1", value=0.5),
+ CustomMetric(name="customMetric2", value=0.3),
+ ],
+ confidence=0.58,
+ name="some-line",
+ feature_schema_id="cl6xnuwt95lqq07330tbb3mfd",
+ extra={
+ "uuid": "5ad9c52f-058d-49c8-a749-3f20b84f8cd4",
+ },
+ value=Line(
+ points=[
+ Point(x=2534.353, y=249.471),
+ Point(x=2429.492, y=182.092),
+ Point(x=2294.322, y=221.962),
+ ],
+ ),
+ )
+ ],
+ )
+ ]
+
+ res = list(NDJsonConverter.serialize(labels))
assert res == data
diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_radio.py b/libs/labelbox/tests/data/serialization/ndjson/test_radio.py
index bd80f9267..ec57f0528 100644
--- a/libs/labelbox/tests/data/serialization/ndjson/test_radio.py
+++ b/libs/labelbox/tests/data/serialization/ndjson/test_radio.py
@@ -1,10 +1,9 @@
-import json
from labelbox.data.annotation_types.annotation import ClassificationAnnotation
from labelbox.data.annotation_types.classification.classification import (
ClassificationAnswer,
)
from labelbox.data.annotation_types.classification.classification import Radio
-from labelbox.data.annotation_types.data.text import TextData
+from labelbox.data.annotation_types.data import GenericDataRowData
from labelbox.data.annotation_types.label import Label
from labelbox.data.serialization.ndjson.converter import NDJsonConverter
@@ -13,9 +12,8 @@
def test_serialization_with_radio_min():
label = Label(
uid="ckj7z2q0b0000jx6x0q2q7q0d",
- data=TextData(
+ data=GenericDataRowData(
uid="bkj7z2q0b0000jx6x0q2q7q0d",
- text="This is a test",
),
annotations=[
ClassificationAnnotation(
@@ -40,21 +38,12 @@ def test_serialization_with_radio_min():
res.pop("uuid")
assert res == expected
- deserialized = NDJsonConverter.deserialize([res])
- res = next(deserialized)
-
- for i, annotation in enumerate(res.annotations):
- annotation.extra.pop("uuid")
- assert annotation.value == label.annotations[i].value
- assert annotation.name == label.annotations[i].name
-
def test_serialization_with_radio_classification():
label = Label(
uid="ckj7z2q0b0000jx6x0q2q7q0d",
- data=TextData(
+ data=GenericDataRowData(
uid="bkj7z2q0b0000jx6x0q2q7q0d",
- text="This is a test",
),
annotations=[
ClassificationAnnotation(
@@ -101,10 +90,3 @@ def test_serialization_with_radio_classification():
res = next(serialized)
res.pop("uuid")
assert res == expected
-
- deserialized = NDJsonConverter.deserialize([res])
- res = next(deserialized)
- res.annotations[0].extra.pop("uuid")
- assert res.annotations[0].model_dump(
- exclude_none=True
- ) == label.annotations[0].model_dump(exclude_none=True)
diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_rectangle.py b/libs/labelbox/tests/data/serialization/ndjson/test_rectangle.py
index 66630dbb5..0e42ab152 100644
--- a/libs/labelbox/tests/data/serialization/ndjson/test_rectangle.py
+++ b/libs/labelbox/tests/data/serialization/ndjson/test_rectangle.py
@@ -1,6 +1,10 @@
import json
+from labelbox.data.annotation_types.data.generic_data_row_data import (
+ GenericDataRowData,
+)
import labelbox.types as lb_types
from labelbox.data.serialization.ndjson.converter import NDJsonConverter
+from labelbox.types import Label, ObjectAnnotation, Rectangle, Point
DATAROW_ID = "ckrb1sf1i1g7i0ybcdc6oc8ct"
@@ -8,8 +12,26 @@
def test_rectangle():
with open("tests/data/assets/ndjson/rectangle_import.json", "r") as file:
data = json.load(file)
- res = list(NDJsonConverter.deserialize(data))
- res = list(NDJsonConverter.serialize(res))
+ labels = [
+ Label(
+ data=GenericDataRowData(
+ uid="ckrb1sf1i1g7i0ybcdc6oc8ct",
+ ),
+ annotations=[
+ ObjectAnnotation(
+ name="bbox",
+ extra={
+ "uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72",
+ },
+ value=Rectangle(
+ start=Point(x=38.0, y=28.0),
+ end=Point(x=81.0, y=69.0),
+ ),
+ )
+ ],
+ )
+ ]
+ res = list(NDJsonConverter.serialize(labels))
assert res == data
@@ -39,8 +61,6 @@ def test_rectangle_inverted_start_end_points():
),
extra={
"uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72",
- "page": None,
- "unit": None,
},
)
@@ -48,8 +68,9 @@ def test_rectangle_inverted_start_end_points():
data={"uid": DATAROW_ID}, annotations=[expected_bbox]
)
- res = list(NDJsonConverter.deserialize(res))
- assert res == [label]
+ data = list(NDJsonConverter.serialize([label]))
+
+ assert res == data
def test_rectangle_mixed_start_end_points():
@@ -76,17 +97,13 @@ def test_rectangle_mixed_start_end_points():
start=lb_types.Point(x=38, y=28),
end=lb_types.Point(x=81, y=69),
),
- extra={
- "uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72",
- "page": None,
- "unit": None,
- },
+ extra={"uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72"},
)
label = lb_types.Label(data={"uid": DATAROW_ID}, annotations=[bbox])
- res = list(NDJsonConverter.deserialize(res))
- assert res == [label]
+ data = list(NDJsonConverter.serialize([label]))
+ assert res == data
def test_benchmark_reference_label_flag_enabled():
diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_relationship.py b/libs/labelbox/tests/data/serialization/ndjson/test_relationship.py
index f33719035..235b66957 100644
--- a/libs/labelbox/tests/data/serialization/ndjson/test_relationship.py
+++ b/libs/labelbox/tests/data/serialization/ndjson/test_relationship.py
@@ -1,16 +1,135 @@
import json
-from uuid import uuid4
-import pytest
+from labelbox.data.annotation_types.data.generic_data_row_data import (
+ GenericDataRowData,
+)
from labelbox.data.serialization.ndjson.converter import NDJsonConverter
+from labelbox.types import (
+ Label,
+ ObjectAnnotation,
+ Point,
+ Rectangle,
+ RelationshipAnnotation,
+ Relationship,
+)
def test_relationship():
with open("tests/data/assets/ndjson/relationship_import.json", "r") as file:
data = json.load(file)
- res = list(NDJsonConverter.deserialize(data))
+ res = [
+ Label(
+ data=GenericDataRowData(
+ uid="clf98gj90000qp38ka34yhptl",
+ ),
+ annotations=[
+ ObjectAnnotation(
+ name="cat",
+ extra={
+ "uuid": "d8813907-b15d-4374-bbe6-b9877fb42ccd",
+ },
+ value=Rectangle(
+ start=Point(x=100.0, y=200.0),
+ end=Point(x=200.0, y=300.0),
+ ),
+ ),
+ ObjectAnnotation(
+ name="dog",
+ extra={
+ "uuid": "9b1e1249-36b4-4665-b60a-9060e0d18660",
+ },
+ value=Rectangle(
+ start=Point(x=400.0, y=500.0),
+ end=Point(x=600.0, y=700.0),
+ ),
+ ),
+ RelationshipAnnotation(
+ name="is chasing",
+ extra={"uuid": "0e6354eb-9adb-47e5-8e52-217ed016d948"},
+ value=Relationship(
+ source=ObjectAnnotation(
+ name="dog",
+ extra={
+ "uuid": "9b1e1249-36b4-4665-b60a-9060e0d18660",
+ },
+ value=Rectangle(
+ start=Point(x=400.0, y=500.0),
+ end=Point(x=600.0, y=700.0),
+ ),
+ ),
+ target=ObjectAnnotation(
+ name="cat",
+ extra={
+ "uuid": "d8813907-b15d-4374-bbe6-b9877fb42ccd",
+ },
+ value=Rectangle(
+ extra={},
+ start=Point(x=100.0, y=200.0),
+ end=Point(x=200.0, y=300.0),
+ ),
+ ),
+ type=Relationship.Type.UNIDIRECTIONAL,
+ ),
+ ),
+ ],
+ ),
+ Label(
+ data=GenericDataRowData(
+ uid="clf98gj90000qp38ka34yhptl-DIFFERENT",
+ ),
+ annotations=[
+ ObjectAnnotation(
+ name="cat",
+ extra={
+ "uuid": "d8813907-b15d-4374-bbe6-b9877fb42ccd",
+ },
+ value=Rectangle(
+ start=Point(x=100.0, y=200.0),
+ end=Point(x=200.0, y=300.0),
+ ),
+ ),
+ ObjectAnnotation(
+ name="dog",
+ extra={
+ "uuid": "9b1e1249-36b4-4665-b60a-9060e0d18660",
+ },
+ value=Rectangle(
+ start=Point(x=400.0, y=500.0),
+ end=Point(x=600.0, y=700.0),
+ ),
+ ),
+ RelationshipAnnotation(
+ name="is chasing",
+ extra={"uuid": "0e6354eb-9adb-47e5-8e52-217ed016d948"},
+ value=Relationship(
+ source=ObjectAnnotation(
+ name="dog",
+ extra={
+ "uuid": "9b1e1249-36b4-4665-b60a-9060e0d18660",
+ },
+ value=Rectangle(
+ start=Point(x=400.0, y=500.0),
+ end=Point(x=600.0, y=700.0),
+ ),
+ ),
+ target=ObjectAnnotation(
+ name="cat",
+ extra={
+ "uuid": "d8813907-b15d-4374-bbe6-b9877fb42ccd",
+ },
+ value=Rectangle(
+ start=Point(x=100.0, y=200.0),
+ end=Point(x=200.0, y=300.0),
+ ),
+ ),
+ type=Relationship.Type.UNIDIRECTIONAL,
+ ),
+ ),
+ ],
+ ),
+ ]
res = list(NDJsonConverter.serialize(res))
assert len(res) == len(data)
@@ -44,29 +163,3 @@ def test_relationship():
assert res_relationship_second_annotation["relationship"]["target"] in [
annot["uuid"] for annot in res_source_and_target
]
-
-
-def test_relationship_nonexistent_object():
- with open("tests/data/assets/ndjson/relationship_import.json", "r") as file:
- data = json.load(file)
-
- relationship_annotation = data[2]
- source_uuid = relationship_annotation["relationship"]["source"]
- target_uuid = str(uuid4())
- relationship_annotation["relationship"]["target"] = target_uuid
- error_msg = f"Relationship object refers to nonexistent object with UUID '{source_uuid}' and/or '{target_uuid}'"
-
- with pytest.raises(ValueError, match=error_msg):
- list(NDJsonConverter.deserialize(data))
-
-
-def test_relationship_duplicate_uuids():
- with open("tests/data/assets/ndjson/relationship_import.json", "r") as file:
- data = json.load(file)
-
- source, target = data[0], data[1]
- target["uuid"] = source["uuid"]
- error_msg = f"UUID '{source['uuid']}' is not unique"
-
- with pytest.raises(AssertionError, match=error_msg):
- list(NDJsonConverter.deserialize(data))
diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_text.py b/libs/labelbox/tests/data/serialization/ndjson/test_text.py
index d5e81c51a..28eba07bd 100644
--- a/libs/labelbox/tests/data/serialization/ndjson/test_text.py
+++ b/libs/labelbox/tests/data/serialization/ndjson/test_text.py
@@ -1,10 +1,8 @@
from labelbox.data.annotation_types.annotation import ClassificationAnnotation
from labelbox.data.annotation_types.classification.classification import (
- ClassificationAnswer,
- Radio,
Text,
)
-from labelbox.data.annotation_types.data.text import TextData
+from labelbox.data.annotation_types.data import GenericDataRowData
from labelbox.data.annotation_types.label import Label
from labelbox.data.serialization.ndjson.converter import NDJsonConverter
@@ -13,9 +11,8 @@
def test_serialization():
label = Label(
uid="ckj7z2q0b0000jx6x0q2q7q0d",
- data=TextData(
+ data=GenericDataRowData(
uid="bkj7z2q0b0000jx6x0q2q7q0d",
- text="This is a test",
),
annotations=[
ClassificationAnnotation(
@@ -34,11 +31,3 @@ def test_serialization():
assert res["name"] == "radio_question_geo"
assert res["answer"] == "first_radio_answer"
assert res["dataRow"]["id"] == "bkj7z2q0b0000jx6x0q2q7q0d"
-
- deserialized = NDJsonConverter.deserialize([res])
- res = next(deserialized)
- annotation = res.annotations[0]
-
- annotation_value = annotation.value
- assert type(annotation_value) is Text
- assert annotation_value.answer == "first_radio_answer"
diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_text_entity.py b/libs/labelbox/tests/data/serialization/ndjson/test_text_entity.py
index 3e856f001..fb93f15d4 100644
--- a/libs/labelbox/tests/data/serialization/ndjson/test_text_entity.py
+++ b/libs/labelbox/tests/data/serialization/ndjson/test_text_entity.py
@@ -1,21 +1,68 @@
import json
-import pytest
+from labelbox.data.annotation_types.data.generic_data_row_data import (
+ GenericDataRowData,
+)
+from labelbox.data.mixins import CustomMetric
from labelbox.data.serialization.ndjson.converter import NDJsonConverter
+from labelbox.types import Label, ObjectAnnotation, TextEntity
+
+
+def test_text_entity_import():
+ with open("tests/data/assets/ndjson/text_entity_import.json", "r") as file:
+ data = json.load(file)
+
+ labels = [
+ Label(
+ data=GenericDataRowData(
+ uid="cl6xnv9h61fv0085yhtoq06ht",
+ ),
+ annotations=[
+ ObjectAnnotation(
+ custom_metrics=[
+ CustomMetric(name="customMetric1", value=0.5),
+ CustomMetric(name="customMetric2", value=0.3),
+ ],
+ confidence=0.53,
+ name="some-text-entity",
+ feature_schema_id="cl6xnuwt95lqq07330tbb3mfd",
+ extra={
+ "uuid": "5ad9c52f-058d-49c8-a749-3f20b84f8cd4",
+ },
+ value=TextEntity(start=67, end=128, extra={}),
+ )
+ ],
+ )
+ ]
+ res = list(NDJsonConverter.serialize(labels))
+ assert res == data
-@pytest.mark.parametrize(
- "filename",
- [
- "tests/data/assets/ndjson/text_entity_import.json",
+def test_text_entity_import_without_confidence():
+ with open(
"tests/data/assets/ndjson/text_entity_without_confidence_import.json",
- ],
-)
-def test_text_entity_import(filename: str):
- with open(filename, "r") as file:
+ "r",
+ ) as file:
data = json.load(file)
- res = list(NDJsonConverter.deserialize(data))
- res = list(NDJsonConverter.serialize(res))
+ labels = [
+ Label(
+ data=GenericDataRowData(
+ uid="cl6xnv9h61fv0085yhtoq06ht",
+ ),
+ annotations=[
+ ObjectAnnotation(
+ name="some-text-entity",
+ feature_schema_id="cl6xnuwt95lqq07330tbb3mfd",
+ extra={
+ "uuid": "5ad9c52f-058d-49c8-a749-3f20b84f8cd4",
+ },
+ value=TextEntity(start=67, end=128, extra={}),
+ )
+ ],
+ )
+ ]
+
+ res = list(NDJsonConverter.serialize(labels))
assert res == data
diff --git a/libs/labelbox/tests/data/serialization/ndjson/test_video.py b/libs/labelbox/tests/data/serialization/ndjson/test_video.py
index c7a6535c4..6c14343a4 100644
--- a/libs/labelbox/tests/data/serialization/ndjson/test_video.py
+++ b/libs/labelbox/tests/data/serialization/ndjson/test_video.py
@@ -1,20 +1,21 @@
import json
-from labelbox.client import Client
from labelbox.data.annotation_types.classification.classification import (
Checklist,
ClassificationAnnotation,
ClassificationAnswer,
Radio,
+ Text,
)
-from labelbox.data.annotation_types.data.video import VideoData
+from labelbox.data.annotation_types.data import GenericDataRowData
from labelbox.data.annotation_types.geometry.line import Line
from labelbox.data.annotation_types.geometry.point import Point
from labelbox.data.annotation_types.geometry.rectangle import Rectangle
-from labelbox.data.annotation_types.geometry.point import Point
from labelbox.data.annotation_types.label import Label
-from labelbox.data.annotation_types.video import VideoObjectAnnotation
-from labelbox import parser
+from labelbox.data.annotation_types.video import (
+ VideoClassificationAnnotation,
+ VideoObjectAnnotation,
+)
from labelbox.data.serialization.ndjson.converter import NDJsonConverter
from operator import itemgetter
@@ -24,15 +25,275 @@ def test_video():
with open("tests/data/assets/ndjson/video_import.json", "r") as file:
data = json.load(file)
- res = list(NDJsonConverter.deserialize(data))
- res = list(NDJsonConverter.serialize(res))
+ labels = [
+ Label(
+ data=GenericDataRowData(uid="ckrb1sf1i1g7i0ybcdc6oc8ct"),
+ annotations=[
+ VideoClassificationAnnotation(
+ feature_schema_id="ckrb1sfjx099a0y914hl319ie",
+ extra={"uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673"},
+ value=Radio(
+ answer=ClassificationAnswer(
+ feature_schema_id="ckrb1sfl8099g0y91cxbd5ftb",
+ ),
+ ),
+ frame=30,
+ ),
+ VideoClassificationAnnotation(
+ feature_schema_id="ckrb1sfjx099a0y914hl319ie",
+ extra={"uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673"},
+ value=Radio(
+ answer=ClassificationAnswer(
+ feature_schema_id="ckrb1sfl8099g0y91cxbd5ftb",
+ ),
+ ),
+ frame=31,
+ ),
+ VideoClassificationAnnotation(
+ feature_schema_id="ckrb1sfjx099a0y914hl319ie",
+ extra={"uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673"},
+ value=Radio(
+ answer=ClassificationAnswer(
+ feature_schema_id="ckrb1sfl8099g0y91cxbd5ftb",
+ ),
+ ),
+ frame=32,
+ ),
+ VideoClassificationAnnotation(
+ feature_schema_id="ckrb1sfjx099a0y914hl319ie",
+ extra={"uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673"},
+ value=Radio(
+ answer=ClassificationAnswer(
+ feature_schema_id="ckrb1sfl8099g0y91cxbd5ftb",
+ ),
+ ),
+ frame=33,
+ ),
+ VideoClassificationAnnotation(
+ feature_schema_id="ckrb1sfjx099a0y914hl319ie",
+ extra={"uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673"},
+ value=Radio(
+ answer=ClassificationAnswer(
+ feature_schema_id="ckrb1sfl8099g0y91cxbd5ftb",
+ ),
+ ),
+ frame=34,
+ ),
+ VideoClassificationAnnotation(
+ feature_schema_id="ckrb1sfjx099a0y914hl319ie",
+ extra={"uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673"},
+ value=Radio(
+ answer=ClassificationAnswer(
+ feature_schema_id="ckrb1sfl8099g0y91cxbd5ftb",
+ ),
+ ),
+ frame=35,
+ ),
+ VideoClassificationAnnotation(
+ feature_schema_id="ckrb1sfjx099a0y914hl319ie",
+ extra={"uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673"},
+ value=Radio(
+ answer=ClassificationAnswer(
+ feature_schema_id="ckrb1sfl8099g0y91cxbd5ftb",
+ ),
+ ),
+ frame=50,
+ ),
+ VideoClassificationAnnotation(
+ feature_schema_id="ckrb1sfjx099a0y914hl319ie",
+ extra={"uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673"},
+ value=Radio(
+ answer=ClassificationAnswer(
+ feature_schema_id="ckrb1sfl8099g0y91cxbd5ftb",
+ ),
+ ),
+ frame=51,
+ ),
+ VideoClassificationAnnotation(
+ feature_schema_id="ckrb1sfkn099c0y910wbo0p1a",
+ extra={"uuid": "d009925d-91a3-4f67-abd9-753453f5a584"},
+ value=Checklist(
+ answer=[
+ ClassificationAnswer(
+ feature_schema_id="ckrb1sfl8099e0y919v260awv",
+ )
+ ],
+ ),
+ frame=0,
+ ),
+ VideoClassificationAnnotation(
+ feature_schema_id="ckrb1sfkn099c0y910wbo0p1a",
+ extra={"uuid": "d009925d-91a3-4f67-abd9-753453f5a584"},
+ value=Checklist(
+ answer=[
+ ClassificationAnswer(
+ feature_schema_id="ckrb1sfl8099e0y919v260awv",
+ )
+ ],
+ ),
+ frame=1,
+ ),
+ VideoClassificationAnnotation(
+ feature_schema_id="ckrb1sfkn099c0y910wbo0p1a",
+ extra={"uuid": "d009925d-91a3-4f67-abd9-753453f5a584"},
+ value=Checklist(
+ answer=[
+ ClassificationAnswer(
+ feature_schema_id="ckrb1sfl8099e0y919v260awv",
+ )
+ ],
+ ),
+ frame=2,
+ ),
+ VideoClassificationAnnotation(
+ feature_schema_id="ckrb1sfkn099c0y910wbo0p1a",
+ extra={"uuid": "d009925d-91a3-4f67-abd9-753453f5a584"},
+ value=Checklist(
+ answer=[
+ ClassificationAnswer(
+ feature_schema_id="ckrb1sfl8099e0y919v260awv",
+ )
+ ],
+ ),
+ frame=3,
+ ),
+ VideoClassificationAnnotation(
+ feature_schema_id="ckrb1sfkn099c0y910wbo0p1a",
+ extra={"uuid": "d009925d-91a3-4f67-abd9-753453f5a584"},
+ value=Checklist(
+ answer=[
+ ClassificationAnswer(
+ feature_schema_id="ckrb1sfl8099e0y919v260awv",
+ )
+ ],
+ ),
+ frame=4,
+ ),
+ VideoClassificationAnnotation(
+ feature_schema_id="ckrb1sfkn099c0y910wbo0p1a",
+ extra={"uuid": "d009925d-91a3-4f67-abd9-753453f5a584"},
+ value=Checklist(
+ answer=[
+ ClassificationAnswer(
+ feature_schema_id="ckrb1sfl8099e0y919v260awv",
+ )
+ ],
+ ),
+ frame=5,
+ ),
+ ClassificationAnnotation(
+ feature_schema_id="ckrb1sfkn099c0y910wbo0p1a",
+ extra={"uuid": "90e2ecf7-c19c-47e6-8cdb-8867e1b9d88c"},
+ value=Text(answer="a value"),
+ ),
+ VideoObjectAnnotation(
+ feature_schema_id="cl5islwg200gfci6g0oitaypu",
+ extra={"uuid": "6f7c835a-0139-4896-b73f-66a6baa89e94"},
+ value=Line(
+ points=[
+ Point(x=10.0, y=10.0),
+ Point(x=100.0, y=100.0),
+ Point(x=50.0, y=30.0),
+ ],
+ ),
+ frame=1,
+ keyframe=True,
+ segment_index=0,
+ ),
+ VideoObjectAnnotation(
+ feature_schema_id="cl5islwg200gfci6g0oitaypu",
+ extra={"uuid": "6f7c835a-0139-4896-b73f-66a6baa89e94"},
+ value=Line(
+ points=[
+ Point(x=15.0, y=10.0),
+ Point(x=50.0, y=100.0),
+ Point(x=50.0, y=30.0),
+ ],
+ ),
+ frame=5,
+ keyframe=True,
+ segment_index=0,
+ ),
+ VideoObjectAnnotation(
+ feature_schema_id="cl5islwg200gfci6g0oitaypu",
+ extra={"uuid": "6f7c835a-0139-4896-b73f-66a6baa89e94"},
+ value=Line(
+ points=[
+ Point(x=100.0, y=10.0),
+ Point(x=50.0, y=100.0),
+ Point(x=50.0, y=30.0),
+ ],
+ ),
+ frame=8,
+ keyframe=True,
+ segment_index=1,
+ ),
+ VideoObjectAnnotation(
+ feature_schema_id="cl5it7ktp00i5ci6gf80b1ysd",
+ extra={"uuid": "f963be22-227b-4efe-9be4-2738ed822216"},
+ value=Point(x=10.0, y=10.0),
+ frame=1,
+ keyframe=True,
+ segment_index=0,
+ ),
+ VideoObjectAnnotation(
+ feature_schema_id="cl5it7ktp00i5ci6gf80b1ysd",
+ extra={"uuid": "f963be22-227b-4efe-9be4-2738ed822216"},
+ value=Point(x=50.0, y=50.0),
+ frame=5,
+ keyframe=True,
+ segment_index=1,
+ ),
+ VideoObjectAnnotation(
+ feature_schema_id="cl5it7ktp00i5ci6gf80b1ysd",
+ extra={"uuid": "f963be22-227b-4efe-9be4-2738ed822216"},
+ value=Point(x=10.0, y=50.0),
+ frame=10,
+ keyframe=True,
+ segment_index=1,
+ ),
+ VideoObjectAnnotation(
+ feature_schema_id="cl5iw0roz00lwci6g5jni62vs",
+ extra={"uuid": "13b2ee0e-2355-4336-8b83-d74d09e3b1e7"},
+ value=Rectangle(
+ start=Point(x=5.0, y=10.0),
+ end=Point(x=155.0, y=110.0),
+ ),
+ frame=1,
+ keyframe=True,
+ segment_index=0,
+ ),
+ VideoObjectAnnotation(
+ feature_schema_id="cl5iw0roz00lwci6g5jni62vs",
+ extra={"uuid": "13b2ee0e-2355-4336-8b83-d74d09e3b1e7"},
+ value=Rectangle(
+ start=Point(x=5.0, y=30.0),
+ end=Point(x=155.0, y=80.0),
+ ),
+ frame=5,
+ keyframe=True,
+ segment_index=0,
+ ),
+ VideoObjectAnnotation(
+ feature_schema_id="cl5iw0roz00lwci6g5jni62vs",
+ extra={"uuid": "13b2ee0e-2355-4336-8b83-d74d09e3b1e7"},
+ value=Rectangle(
+ start=Point(x=200.0, y=300.0),
+ end=Point(x=350.0, y=700.0),
+ ),
+ frame=10,
+ keyframe=True,
+ segment_index=1,
+ ),
+ ],
+ )
+ ]
+
+ res = list(NDJsonConverter.serialize(labels))
data = sorted(data, key=itemgetter("uuid"))
res = sorted(res, key=itemgetter("uuid"))
-
- pairs = zip(data, res)
- for data, res in pairs:
- assert data == res
+ assert data == res
def test_video_name_only():
@@ -40,21 +301,279 @@ def test_video_name_only():
"tests/data/assets/ndjson/video_import_name_only.json", "r"
) as file:
data = json.load(file)
-
- res = list(NDJsonConverter.deserialize(data))
- res = list(NDJsonConverter.serialize(res))
-
+ labels = [
+ Label(
+ data=GenericDataRowData(uid="ckrb1sf1i1g7i0ybcdc6oc8ct"),
+ annotations=[
+ VideoClassificationAnnotation(
+ name="question 1",
+ extra={"uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673"},
+ value=Radio(
+ answer=ClassificationAnswer(
+ name="answer 1",
+ ),
+ ),
+ frame=30,
+ ),
+ VideoClassificationAnnotation(
+ name="question 1",
+ extra={"uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673"},
+ value=Radio(
+ answer=ClassificationAnswer(
+ name="answer 1",
+ ),
+ ),
+ frame=31,
+ ),
+ VideoClassificationAnnotation(
+ name="question 1",
+ extra={"uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673"},
+ value=Radio(
+ answer=ClassificationAnswer(
+ name="answer 1",
+ ),
+ ),
+ frame=32,
+ ),
+ VideoClassificationAnnotation(
+ name="question 1",
+ extra={"uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673"},
+ value=Radio(
+ answer=ClassificationAnswer(
+ name="answer 1",
+ ),
+ ),
+ frame=33,
+ ),
+ VideoClassificationAnnotation(
+ name="question 1",
+ extra={"uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673"},
+ value=Radio(
+ answer=ClassificationAnswer(
+ name="answer 1",
+ ),
+ ),
+ frame=34,
+ ),
+ VideoClassificationAnnotation(
+ name="question 1",
+ extra={"uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673"},
+ value=Radio(
+ answer=ClassificationAnswer(
+ name="answer 1",
+ ),
+ ),
+ frame=35,
+ ),
+ VideoClassificationAnnotation(
+ name="question 1",
+ extra={"uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673"},
+ value=Radio(
+ answer=ClassificationAnswer(
+ name="answer 1",
+ ),
+ ),
+ frame=50,
+ ),
+ VideoClassificationAnnotation(
+ name="question 1",
+ extra={"uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673"},
+ value=Radio(
+ answer=ClassificationAnswer(
+ name="answer 1",
+ ),
+ ),
+ frame=51,
+ ),
+ VideoClassificationAnnotation(
+ name="question 2",
+ extra={"uuid": "d009925d-91a3-4f67-abd9-753453f5a584"},
+ value=Checklist(
+ answer=[
+ ClassificationAnswer(
+ name="answer 2",
+ )
+ ],
+ ),
+ frame=0,
+ ),
+ VideoClassificationAnnotation(
+ name="question 2",
+ extra={"uuid": "d009925d-91a3-4f67-abd9-753453f5a584"},
+ value=Checklist(
+ answer=[
+ ClassificationAnswer(
+ name="answer 2",
+ )
+ ],
+ ),
+ frame=1,
+ ),
+ VideoClassificationAnnotation(
+ name="question 2",
+ extra={"uuid": "d009925d-91a3-4f67-abd9-753453f5a584"},
+ value=Checklist(
+ answer=[
+ ClassificationAnswer(
+ name="answer 2",
+ )
+ ],
+ ),
+ frame=2,
+ ),
+ VideoClassificationAnnotation(
+ name="question 2",
+ extra={"uuid": "d009925d-91a3-4f67-abd9-753453f5a584"},
+ value=Checklist(
+ answer=[
+ ClassificationAnswer(
+ name="answer 2",
+ )
+ ],
+ ),
+ frame=3,
+ ),
+ VideoClassificationAnnotation(
+ name="question 2",
+ extra={"uuid": "d009925d-91a3-4f67-abd9-753453f5a584"},
+ value=Checklist(
+ answer=[
+ ClassificationAnswer(
+ name="answer 2",
+ )
+ ],
+ ),
+ frame=4,
+ ),
+ VideoClassificationAnnotation(
+ name="question 2",
+ extra={"uuid": "d009925d-91a3-4f67-abd9-753453f5a584"},
+ value=Checklist(
+ answer=[
+ ClassificationAnswer(
+ name="answer 2",
+ )
+ ],
+ ),
+ frame=5,
+ ),
+ ClassificationAnnotation(
+ name="question 3",
+ extra={"uuid": "e5f32456-bd67-4520-8d3b-cbeb2204bad3"},
+ value=Text(answer="a value"),
+ ),
+ VideoObjectAnnotation(
+ name="segment 1",
+ extra={"uuid": "6f7c835a-0139-4896-b73f-66a6baa89e94"},
+ value=Line(
+ points=[
+ Point(x=10.0, y=10.0),
+ Point(x=100.0, y=100.0),
+ Point(x=50.0, y=30.0),
+ ],
+ ),
+ frame=1,
+ keyframe=True,
+ segment_index=0,
+ ),
+ VideoObjectAnnotation(
+ name="segment 1",
+ extra={"uuid": "6f7c835a-0139-4896-b73f-66a6baa89e94"},
+ value=Line(
+ points=[
+ Point(x=15.0, y=10.0),
+ Point(x=50.0, y=100.0),
+ Point(x=50.0, y=30.0),
+ ],
+ ),
+ frame=5,
+ keyframe=True,
+ segment_index=0,
+ ),
+ VideoObjectAnnotation(
+ name="segment 1",
+ extra={"uuid": "6f7c835a-0139-4896-b73f-66a6baa89e94"},
+ value=Line(
+ points=[
+ Point(x=100.0, y=10.0),
+ Point(x=50.0, y=100.0),
+ Point(x=50.0, y=30.0),
+ ],
+ ),
+ frame=8,
+ keyframe=True,
+ segment_index=1,
+ ),
+ VideoObjectAnnotation(
+ name="segment 2",
+ extra={"uuid": "f963be22-227b-4efe-9be4-2738ed822216"},
+ value=Point(x=10.0, y=10.0),
+ frame=1,
+ keyframe=True,
+ segment_index=0,
+ ),
+ VideoObjectAnnotation(
+ name="segment 2",
+ extra={"uuid": "f963be22-227b-4efe-9be4-2738ed822216"},
+ value=Point(x=50.0, y=50.0),
+ frame=5,
+ keyframe=True,
+ segment_index=1,
+ ),
+ VideoObjectAnnotation(
+ name="segment 2",
+ extra={"uuid": "f963be22-227b-4efe-9be4-2738ed822216"},
+ value=Point(x=10.0, y=50.0),
+ frame=10,
+ keyframe=True,
+ segment_index=1,
+ ),
+ VideoObjectAnnotation(
+ name="segment 3",
+ extra={"uuid": "13b2ee0e-2355-4336-8b83-d74d09e3b1e7"},
+ value=Rectangle(
+ start=Point(x=5.0, y=10.0),
+ end=Point(x=155.0, y=110.0),
+ ),
+ frame=1,
+ keyframe=True,
+ segment_index=0,
+ ),
+ VideoObjectAnnotation(
+ name="segment 3",
+ extra={"uuid": "13b2ee0e-2355-4336-8b83-d74d09e3b1e7"},
+ value=Rectangle(
+ start=Point(x=5.0, y=30.0),
+ end=Point(x=155.0, y=80.0),
+ ),
+ frame=5,
+ keyframe=True,
+ segment_index=0,
+ ),
+ VideoObjectAnnotation(
+ name="segment 3",
+ extra={"uuid": "13b2ee0e-2355-4336-8b83-d74d09e3b1e7"},
+ value=Rectangle(
+ start=Point(x=200.0, y=300.0),
+ end=Point(x=350.0, y=700.0),
+ ),
+ frame=10,
+ keyframe=True,
+ segment_index=1,
+ ),
+ ],
+ )
+ ]
+ res = list(NDJsonConverter.serialize(labels))
data = sorted(data, key=itemgetter("uuid"))
res = sorted(res, key=itemgetter("uuid"))
- pairs = zip(data, res)
- for data, res in pairs:
- assert data == res
+ assert data == res
def test_video_classification_global_subclassifications():
label = Label(
- data=VideoData(
+ data=GenericDataRowData(
global_key="sample-video-4.mp4",
),
annotations=[
@@ -67,7 +586,6 @@ def test_video_classification_global_subclassifications():
ClassificationAnnotation(
name="nested_checklist_question",
value=Checklist(
- name="checklist",
answer=[
ClassificationAnswer(
name="first_checklist_answer",
@@ -94,7 +612,7 @@ def test_video_classification_global_subclassifications():
"dataRow": {"globalKey": "sample-video-4.mp4"},
}
- expected_second_annotation = nested_checklist_annotation_ndjson = {
+ expected_second_annotation = {
"name": "nested_checklist_question",
"answer": [
{
@@ -116,12 +634,6 @@ def test_video_classification_global_subclassifications():
annotations.pop("uuid")
assert res == [expected_first_annotation, expected_second_annotation]
- deserialized = NDJsonConverter.deserialize(res)
- res = next(deserialized)
- annotations = res.annotations
- for i, annotation in enumerate(annotations):
- assert annotation.name == label.annotations[i].name
-
def test_video_classification_nesting_bbox():
bbox_annotation = [
@@ -277,7 +789,7 @@ def test_video_classification_nesting_bbox():
]
label = Label(
- data=VideoData(
+ data=GenericDataRowData(
global_key="sample-video-4.mp4",
),
annotations=bbox_annotation,
@@ -287,14 +799,6 @@ def test_video_classification_nesting_bbox():
res = [x for x in serialized]
assert res == expected
- deserialized = NDJsonConverter.deserialize(res)
- res = next(deserialized)
- annotations = res.annotations
- for i, annotation in enumerate(annotations):
- annotation.extra.pop("uuid")
- assert annotation.value == label.annotations[i].value
- assert annotation.name == label.annotations[i].name
-
def test_video_classification_point():
bbox_annotation = [
@@ -435,7 +939,7 @@ def test_video_classification_point():
]
label = Label(
- data=VideoData(
+ data=GenericDataRowData(
global_key="sample-video-4.mp4",
),
annotations=bbox_annotation,
@@ -445,13 +949,6 @@ def test_video_classification_point():
res = [x for x in serialized]
assert res == expected
- deserialized = NDJsonConverter.deserialize(res)
- res = next(deserialized)
- annotations = res.annotations
- for i, annotation in enumerate(annotations):
- annotation.extra.pop("uuid")
- assert annotation.value == label.annotations[i].value
-
def test_video_classification_frameline():
bbox_annotation = [
@@ -610,7 +1107,7 @@ def test_video_classification_frameline():
]
label = Label(
- data=VideoData(
+ data=GenericDataRowData(
global_key="sample-video-4.mp4",
),
annotations=bbox_annotation,
@@ -619,9 +1116,289 @@ def test_video_classification_frameline():
res = [x for x in serialized]
assert res == expected
- deserialized = NDJsonConverter.deserialize(res)
- res = next(deserialized)
- annotations = res.annotations
- for i, annotation in enumerate(annotations):
- annotation.extra.pop("uuid")
- assert annotation.value == label.annotations[i].value
+
+[
+ {
+ "answer": "a value",
+ "dataRow": {"id": "ckrb1sf1i1g7i0ybcdc6oc8ct"},
+ "schemaId": "ckrb1sfkn099c0y910wbo0p1a",
+ "uuid": "90e2ecf7-c19c-47e6-8cdb-8867e1b9d88c",
+ },
+ {
+ "answer": {"schemaId": "ckrb1sfl8099g0y91cxbd5ftb"},
+ "dataRow": {"id": "ckrb1sf1i1g7i0ybcdc6oc8ct"},
+ "frames": [{"end": 35, "start": 30}, {"end": 51, "start": 50}],
+ "schemaId": "ckrb1sfjx099a0y914hl319ie",
+ "uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673",
+ },
+ {
+ "answer": [{"schemaId": "ckrb1sfl8099e0y919v260awv"}],
+ "dataRow": {"id": "ckrb1sf1i1g7i0ybcdc6oc8ct"},
+ "frames": [{"end": 5, "start": 0}],
+ "schemaId": "ckrb1sfkn099c0y910wbo0p1a",
+ "uuid": "d009925d-91a3-4f67-abd9-753453f5a584",
+ },
+ {
+ "classifications": [],
+ "dataRow": {"id": "ckrb1sf1i1g7i0ybcdc6oc8ct"},
+ "schemaId": "cl5islwg200gfci6g0oitaypu",
+ "segments": [
+ {
+ "keyframes": [
+ {
+ "classifications": [],
+ "frame": 1,
+ "line": [
+ {"x": 10.0, "y": 10.0},
+ {"x": 100.0, "y": 100.0},
+ {"x": 50.0, "y": 30.0},
+ ],
+ },
+ {
+ "classifications": [],
+ "frame": 5,
+ "line": [
+ {"x": 15.0, "y": 10.0},
+ {"x": 50.0, "y": 100.0},
+ {"x": 50.0, "y": 30.0},
+ ],
+ },
+ ]
+ },
+ {
+ "keyframes": [
+ {
+ "classifications": [],
+ "frame": 8,
+ "line": [
+ {"x": 100.0, "y": 10.0},
+ {"x": 50.0, "y": 100.0},
+ {"x": 50.0, "y": 30.0},
+ ],
+ }
+ ]
+ },
+ ],
+ "uuid": "6f7c835a-0139-4896-b73f-66a6baa89e94",
+ },
+ {
+ "classifications": [],
+ "dataRow": {"id": "ckrb1sf1i1g7i0ybcdc6oc8ct"},
+ "schemaId": "cl5it7ktp00i5ci6gf80b1ysd",
+ "segments": [
+ {
+ "keyframes": [
+ {
+ "classifications": [],
+ "frame": 1,
+ "point": {"x": 10.0, "y": 10.0},
+ }
+ ]
+ },
+ {
+ "keyframes": [
+ {
+ "classifications": [],
+ "frame": 5,
+ "point": {"x": 50.0, "y": 50.0},
+ },
+ {
+ "classifications": [],
+ "frame": 10,
+ "point": {"x": 10.0, "y": 50.0},
+ },
+ ]
+ },
+ ],
+ "uuid": "f963be22-227b-4efe-9be4-2738ed822216",
+ },
+ {
+ "classifications": [],
+ "dataRow": {"id": "ckrb1sf1i1g7i0ybcdc6oc8ct"},
+ "schemaId": "cl5iw0roz00lwci6g5jni62vs",
+ "segments": [
+ {
+ "keyframes": [
+ {
+ "bbox": {
+ "height": 100.0,
+ "left": 5.0,
+ "top": 10.0,
+ "width": 150.0,
+ },
+ "classifications": [],
+ "frame": 1,
+ },
+ {
+ "bbox": {
+ "height": 50.0,
+ "left": 5.0,
+ "top": 30.0,
+ "width": 150.0,
+ },
+ "classifications": [],
+ "frame": 5,
+ },
+ ]
+ },
+ {
+ "keyframes": [
+ {
+ "bbox": {
+ "height": 400.0,
+ "left": 200.0,
+ "top": 300.0,
+ "width": 150.0,
+ },
+ "classifications": [],
+ "frame": 10,
+ }
+ ]
+ },
+ ],
+ "uuid": "13b2ee0e-2355-4336-8b83-d74d09e3b1e7",
+ },
+]
+
+[
+ {
+ "answer": {"schemaId": "ckrb1sfl8099g0y91cxbd5ftb"},
+ "schemaId": "ckrb1sfjx099a0y914hl319ie",
+ "dataRow": {"id": "ckrb1sf1i1g7i0ybcdc6oc8ct"},
+ "uuid": "f6879f59-d2b5-49c2-aceb-d9e8dc478673",
+ "frames": [{"start": 30, "end": 35}, {"start": 50, "end": 51}],
+ },
+ {
+ "answer": [{"schemaId": "ckrb1sfl8099e0y919v260awv"}],
+ "schemaId": "ckrb1sfkn099c0y910wbo0p1a",
+ "dataRow": {"id": "ckrb1sf1i1g7i0ybcdc6oc8ct"},
+ "uuid": "d009925d-91a3-4f67-abd9-753453f5a584",
+ "frames": [{"start": 0, "end": 5}],
+ },
+ {
+ "answer": "a value",
+ "schemaId": "ckrb1sfkn099c0y910wbo0p1a",
+ "dataRow": {"id": "ckrb1sf1i1g7i0ybcdc6oc8ct"},
+ "uuid": "90e2ecf7-c19c-47e6-8cdb-8867e1b9d88c",
+ },
+ {
+ "classifications": [],
+ "schemaId": "cl5islwg200gfci6g0oitaypu",
+ "dataRow": {"id": "ckrb1sf1i1g7i0ybcdc6oc8ct"},
+ "uuid": "6f7c835a-0139-4896-b73f-66a6baa89e94",
+ "segments": [
+ {
+ "keyframes": [
+ {
+ "frame": 1,
+ "line": [
+ {"x": 10.0, "y": 10.0},
+ {"x": 100.0, "y": 100.0},
+ {"x": 50.0, "y": 30.0},
+ ],
+ "classifications": [],
+ },
+ {
+ "frame": 5,
+ "line": [
+ {"x": 15.0, "y": 10.0},
+ {"x": 50.0, "y": 100.0},
+ {"x": 50.0, "y": 30.0},
+ ],
+ "classifications": [],
+ },
+ ]
+ },
+ {
+ "keyframes": [
+ {
+ "frame": 8,
+ "line": [
+ {"x": 100.0, "y": 10.0},
+ {"x": 50.0, "y": 100.0},
+ {"x": 50.0, "y": 30.0},
+ ],
+ "classifications": [],
+ }
+ ]
+ },
+ ],
+ },
+ {
+ "classifications": [],
+ "schemaId": "cl5it7ktp00i5ci6gf80b1ysd",
+ "dataRow": {"id": "ckrb1sf1i1g7i0ybcdc6oc8ct"},
+ "uuid": "f963be22-227b-4efe-9be4-2738ed822216",
+ "segments": [
+ {
+ "keyframes": [
+ {
+ "frame": 1,
+ "point": {"x": 10.0, "y": 10.0},
+ "classifications": [],
+ }
+ ]
+ },
+ {
+ "keyframes": [
+ {
+ "frame": 5,
+ "point": {"x": 50.0, "y": 50.0},
+ "classifications": [],
+ },
+ {
+ "frame": 10,
+ "point": {"x": 10.0, "y": 50.0},
+ "classifications": [],
+ },
+ ]
+ },
+ ],
+ },
+ {
+ "classifications": [],
+ "schemaId": "cl5iw0roz00lwci6g5jni62vs",
+ "dataRow": {"id": "ckrb1sf1i1g7i0ybcdc6oc8ct"},
+ "uuid": "13b2ee0e-2355-4336-8b83-d74d09e3b1e7",
+ "segments": [
+ {
+ "keyframes": [
+ {
+ "frame": 1,
+ "bbox": {
+ "top": 10.0,
+ "left": 5.0,
+ "height": 100.0,
+ "width": 150.0,
+ },
+ "classifications": [],
+ },
+ {
+ "frame": 5,
+ "bbox": {
+ "top": 30.0,
+ "left": 5.0,
+ "height": 50.0,
+ "width": 150.0,
+ },
+ "classifications": [],
+ },
+ ]
+ },
+ {
+ "keyframes": [
+ {
+ "frame": 10,
+ "bbox": {
+ "top": 300.0,
+ "left": 200.0,
+ "height": 400.0,
+ "width": 150.0,
+ },
+ "classifications": [],
+ }
+ ]
+ },
+ ],
+ },
+]
diff --git a/libs/labelbox/tests/data/test_data_row_metadata.py b/libs/labelbox/tests/data/test_data_row_metadata.py
index 891cab9be..2a455efce 100644
--- a/libs/labelbox/tests/data/test_data_row_metadata.py
+++ b/libs/labelbox/tests/data/test_data_row_metadata.py
@@ -1,18 +1,18 @@
-from datetime import datetime
+import uuid
+from datetime import datetime, timezone
import pytest
-import uuid
+from lbox.exceptions import MalformedQueryException
from labelbox import Dataset
-from labelbox.exceptions import MalformedQueryException
-from labelbox.schema.identifiables import GlobalKeys, UniqueIds
from labelbox.schema.data_row_metadata import (
- DataRowMetadataField,
DataRowMetadata,
+ DataRowMetadataField,
DataRowMetadataKind,
DataRowMetadataOntology,
_parse_metadata_schema,
)
+from labelbox.schema.identifiables import GlobalKeys, UniqueIds
INVALID_SCHEMA_ID = "1" * 25
FAKE_SCHEMA_ID = "0" * 25
@@ -61,7 +61,7 @@ def big_dataset(dataset: Dataset, image_url):
def make_metadata(dr_id: str = None, gk: str = None) -> DataRowMetadata:
msg = "A message"
- time = datetime.utcnow()
+ time = datetime.now(timezone.utc)
metadata = DataRowMetadata(
global_key=gk,
@@ -79,7 +79,7 @@ def make_metadata(dr_id: str = None, gk: str = None) -> DataRowMetadata:
def make_named_metadata(dr_id) -> DataRowMetadata:
msg = "A message"
- time = datetime.utcnow()
+ time = datetime.now(timezone.utc)
metadata = DataRowMetadata(
data_row_id=dr_id,
@@ -124,7 +124,7 @@ def test_get_datarow_metadata_ontology(mdo):
fields=[
DataRowMetadataField(
schema_id=mdo.reserved_by_name["captureDateTime"].uid,
- value=datetime.utcnow(),
+ value=datetime.now(timezone.utc),
),
DataRowMetadataField(schema_id=split.parent, value=split.uid),
DataRowMetadataField(
diff --git a/libs/labelbox/tests/integration/conftest.py b/libs/labelbox/tests/integration/conftest.py
index e73fef920..a0c2499a5 100644
--- a/libs/labelbox/tests/integration/conftest.py
+++ b/libs/labelbox/tests/integration/conftest.py
@@ -1,16 +1,22 @@
+import json
import os
+import re
import sys
import time
+import uuid
from collections import defaultdict
from datetime import datetime, timezone
+from enum import Enum
from itertools import islice
-from typing import Type
+from types import SimpleNamespace
+from typing import List, Tuple, Type
import pytest
from labelbox import (
Classification,
Client,
+ DataRow,
Dataset,
LabelingFrontend,
MediaType,
@@ -20,9 +26,14 @@
ResponseOption,
Tool,
)
+from labelbox.orm import query
+from labelbox.pagination import PaginatedCollection
+from labelbox.schema.annotation_import import LabelImport
+from labelbox.schema.catalog import Catalog
from labelbox.schema.data_row import DataRowMetadataField
+from labelbox.schema.enums import AnnotationImportState
+from labelbox.schema.invite import Invite
from labelbox.schema.ontology_kind import OntologyKind
-from labelbox.schema.queue_mode import QueueMode
from labelbox.schema.user import User
@@ -80,7 +91,6 @@ def project_pack(client):
projects = [
client.create_project(
name=f"user-proj-{idx}",
- queue_mode=QueueMode.Batch,
media_type=MediaType.Image,
)
for idx in range(2)
@@ -91,24 +101,27 @@ def project_pack(client):
@pytest.fixture
-def project_with_empty_ontology(project):
- editor = list(
- project.client.get_labeling_frontends(
- where=LabelingFrontend.name == "editor"
- )
- )[0]
- empty_ontology = {"tools": [], "classifications": []}
- project.setup(editor, empty_ontology)
+def project_with_one_feature_ontology(project, client: Client):
+ tools = [
+ Tool(tool=Tool.Type.BBOX, name="test-bbox-class").asdict(),
+ ]
+ ontology_spec = {"tools": tools, "classifications": []}
+ ontology = client.create_ontology(
+ "one feature ontology",
+ ontology_spec,
+ MediaType.Image,
+ )
+ project.connect_ontology(ontology)
yield project
@pytest.fixture
def configured_project(
- project_with_empty_ontology, initial_dataset, rand_gen, image_url
+ project_with_one_feature_ontology, initial_dataset, rand_gen, image_url
):
dataset = initial_dataset
data_row_id = dataset.create_data_row(row_data=image_url).uid
- project = project_with_empty_ontology
+ project = project_with_one_feature_ontology
batch = project.create_batch(
rand_gen(str),
@@ -124,11 +137,10 @@ def configured_project(
@pytest.fixture
def configured_project_with_complex_ontology(
- client, initial_dataset, rand_gen, image_url, teardown_helpers
+ client: Client, initial_dataset, rand_gen, image_url, teardown_helpers
):
project = client.create_project(
name=rand_gen(str),
- queue_mode=QueueMode.Batch,
media_type=MediaType.Image,
)
dataset = initial_dataset
@@ -142,19 +154,12 @@ def configured_project_with_complex_ontology(
)
project.data_row_ids = data_row_ids
- editor = list(
- project.client.get_labeling_frontends(
- where=LabelingFrontend.name == "editor"
- )
- )[0]
-
ontology = OntologyBuilder()
tools = [
Tool(tool=Tool.Type.BBOX, name="test-bbox-class"),
Tool(tool=Tool.Type.LINE, name="test-line-class"),
Tool(tool=Tool.Type.POINT, name="test-point-class"),
Tool(tool=Tool.Type.POLYGON, name="test-polygon-class"),
- Tool(tool=Tool.Type.NER, name="test-ner-class"),
]
options = [
@@ -186,7 +191,12 @@ def configured_project_with_complex_ontology(
for c in classifications:
ontology.add_classification(c)
- project.setup(editor, ontology.asdict())
+ ontology = client.create_ontology(
+ "image ontology",
+ ontology.asdict(),
+ MediaType.Image,
+ )
+ project.connect_ontology(ontology)
yield [project, data_row]
teardown_helpers.teardown_project_labels_ontology_feature_schemas(project)
diff --git a/libs/labelbox/tests/integration/schema/test_user_group.py b/libs/labelbox/tests/integration/schema/test_user_group.py
index 50aaad4a7..b1443b7e7 100644
--- a/libs/labelbox/tests/integration/schema/test_user_group.py
+++ b/libs/labelbox/tests/integration/schema/test_user_group.py
@@ -2,11 +2,11 @@
import faker
import pytest
-
-from labelbox.exceptions import (
+from lbox.exceptions import (
ResourceCreationError,
ResourceNotFoundError,
)
+
from labelbox.schema.user_group import UserGroup, UserGroupColor
data = faker.Faker()
diff --git a/libs/labelbox/tests/integration/test_batch.py b/libs/labelbox/tests/integration/test_batch.py
index 8a8707175..f63b4c3d9 100644
--- a/libs/labelbox/tests/integration/test_batch.py
+++ b/libs/labelbox/tests/integration/test_batch.py
@@ -1,16 +1,16 @@
-import time
from typing import List
from uuid import uuid4
-import pytest
-from labelbox import Dataset, Project
-from labelbox.exceptions import (
- ProcessingWaitTimeout,
+import pytest
+from lbox.exceptions import (
+ LabelboxError,
MalformedQueryException,
+ ProcessingWaitTimeout,
ResourceConflict,
- LabelboxError,
)
+from labelbox import Dataset, Project
+
def get_data_row_ids(ds: Dataset):
return [dr.uid for dr in list(ds.data_rows())]
diff --git a/libs/labelbox/tests/integration/test_benchmark.py b/libs/labelbox/tests/integration/test_benchmark.py
index c10542bda..661f83bd7 100644
--- a/libs/labelbox/tests/integration/test_benchmark.py
+++ b/libs/labelbox/tests/integration/test_benchmark.py
@@ -1,17 +1,17 @@
def test_benchmark(configured_project_with_label):
project, _, data_row, label = configured_project_with_label
assert set(project.benchmarks()) == set()
- assert label.is_benchmark_reference == False
+ assert label.is_benchmark_reference is False
benchmark = label.create_benchmark()
assert set(project.benchmarks()) == {benchmark}
assert benchmark.reference_label() == label
# Refresh label data to check its benchmark reference
label = list(data_row.labels())[0]
- assert label.is_benchmark_reference == True
+ assert label.is_benchmark_reference is True
benchmark.delete()
assert set(project.benchmarks()) == set()
# Refresh label data to check its benchmark reference
label = list(data_row.labels())[0]
- assert label.is_benchmark_reference == False
+ assert label.is_benchmark_reference is False
diff --git a/libs/labelbox/tests/integration/test_chat_evaluation_ontology_project.py b/libs/labelbox/tests/integration/test_chat_evaluation_ontology_project.py
index 2c02b77ac..796cd9859 100644
--- a/libs/labelbox/tests/integration/test_chat_evaluation_ontology_project.py
+++ b/libs/labelbox/tests/integration/test_chat_evaluation_ontology_project.py
@@ -1,3 +1,5 @@
+from unittest.mock import patch
+
import pytest
from labelbox import MediaType
diff --git a/libs/labelbox/tests/integration/test_client_errors.py b/libs/labelbox/tests/integration/test_client_errors.py
index 64b8fb626..38022d1d2 100644
--- a/libs/labelbox/tests/integration/test_client_errors.py
+++ b/libs/labelbox/tests/integration/test_client_errors.py
@@ -1,44 +1,46 @@
-from multiprocessing.dummy import Pool
import os
import time
+from multiprocessing.dummy import Pool
+
+import lbox.exceptions
import pytest
from google.api_core.exceptions import RetryError
-from labelbox import Project, Dataset, User
import labelbox.client
-import labelbox.exceptions
+from labelbox import Project, User
+from labelbox.schema.media_type import MediaType
def test_missing_api_key():
- key = os.environ.get(labelbox.client._LABELBOX_API_KEY, None)
+ key = os.environ.get(lbox.request_client._LABELBOX_API_KEY, None)
if key is not None:
- del os.environ[labelbox.client._LABELBOX_API_KEY]
+ del os.environ[lbox.request_client._LABELBOX_API_KEY]
- with pytest.raises(labelbox.exceptions.AuthenticationError) as excinfo:
+ with pytest.raises(lbox.exceptions.AuthenticationError) as excinfo:
labelbox.client.Client()
assert excinfo.value.message == "Labelbox API key not provided"
if key is not None:
- os.environ[labelbox.client._LABELBOX_API_KEY] = key
+ os.environ[lbox.request_client._LABELBOX_API_KEY] = key
def test_bad_key(rand_gen):
bad_key = "BAD_KEY_" + rand_gen(str)
client = labelbox.client.Client(api_key=bad_key)
- with pytest.raises(labelbox.exceptions.AuthenticationError) as excinfo:
- client.create_project(name=rand_gen(str))
+ with pytest.raises(lbox.exceptions.AuthenticationError) as excinfo:
+ client.create_project(name=rand_gen(str), media_type=MediaType.Image)
def test_syntax_error(client):
- with pytest.raises(labelbox.exceptions.InvalidQueryError) as excinfo:
+ with pytest.raises(lbox.exceptions.InvalidQueryError) as excinfo:
client.execute("asda", check_naming=False)
assert excinfo.value.message.startswith("Syntax Error:")
def test_semantic_error(client):
- with pytest.raises(labelbox.exceptions.InvalidQueryError) as excinfo:
+ with pytest.raises(lbox.exceptions.InvalidQueryError) as excinfo:
client.execute("query {bbb {id}}", check_naming=False)
assert excinfo.value.message.startswith('Cannot query field "bbb"')
@@ -58,7 +60,7 @@ def test_timeout_error(client, project):
def test_query_complexity_error(client):
- with pytest.raises(labelbox.exceptions.ValidationFailedError) as excinfo:
+ with pytest.raises(lbox.exceptions.ValidationFailedError) as excinfo:
client.execute(
"{projects {datasets {dataRows {labels {id}}}}}", check_naming=False
)
@@ -66,41 +68,17 @@ def test_query_complexity_error(client):
def test_resource_not_found_error(client):
- with pytest.raises(labelbox.exceptions.ResourceNotFoundError):
+ with pytest.raises(lbox.exceptions.ResourceNotFoundError):
client.get_project("invalid project ID")
def test_network_error(client):
client = labelbox.client.Client(
- api_key=client.api_key, endpoint="not_a_valid_URL"
+ api_key=client._request_client.api_key, endpoint="not_a_valid_URL"
)
- with pytest.raises(labelbox.exceptions.NetworkError) as excinfo:
- client.create_project(name="Project name")
-
-
-def test_invalid_attribute_error(
- client,
- rand_gen,
-):
- # Creation
- with pytest.raises(labelbox.exceptions.InvalidAttributeError) as excinfo:
- client.create_project(name="Name", invalid_field="Whatever")
- assert excinfo.value.db_object_type == Project
- assert excinfo.value.field == "invalid_field"
-
- # Update
- project = client.create_project(name=rand_gen(str))
- with pytest.raises(labelbox.exceptions.InvalidAttributeError) as excinfo:
- project.update(invalid_field="Whatever")
- assert excinfo.value.db_object_type == Project
- assert excinfo.value.field == "invalid_field"
-
- # Top-level-fetch
- with pytest.raises(labelbox.exceptions.InvalidAttributeError) as excinfo:
- client.get_projects(where=User.email == "email")
- assert excinfo.value.db_object_type == Project
- assert excinfo.value.field == {User.email}
+ with pytest.raises(lbox.exceptions.NetworkError) as excinfo:
+ client.create_project(name="Project name", media_type=MediaType.Image)
@pytest.mark.skip("timeouts cause failure before rate limit")
@@ -108,7 +86,7 @@ def test_api_limit_error(client):
def get(arg):
try:
return client.get_user()
- except labelbox.exceptions.ApiLimitError as e:
+ except lbox.exceptions.ApiLimitError as e:
return e
# Rate limited at 1500 + buffer
@@ -120,7 +98,7 @@ def get(arg):
elapsed = time.time() - start
assert elapsed < 60, "Didn't finish fast enough"
- assert labelbox.exceptions.ApiLimitError in {type(r) for r in results}
+ assert lbox.exceptions.ApiLimitError in {type(r) for r in results}
# Sleep at the end of this test to allow other tests to execute.
time.sleep(60)
diff --git a/libs/labelbox/tests/integration/test_data_row_delete_metadata.py b/libs/labelbox/tests/integration/test_data_row_delete_metadata.py
index 2df860181..a2ffd31ba 100644
--- a/libs/labelbox/tests/integration/test_data_row_delete_metadata.py
+++ b/libs/labelbox/tests/integration/test_data_row_delete_metadata.py
@@ -1,13 +1,13 @@
-from datetime import datetime, timezone
import uuid
+from datetime import datetime, timezone
import pytest
+from lbox.exceptions import MalformedQueryException
-from labelbox import DataRow, Dataset, Client, DataRowMetadataOntology
-from labelbox.exceptions import MalformedQueryException
+from labelbox import Client, DataRow, DataRowMetadataOntology, Dataset
from labelbox.schema.data_row_metadata import (
- DataRowMetadataField,
DataRowMetadata,
+ DataRowMetadataField,
DataRowMetadataKind,
DeleteDataRowMetadata,
)
diff --git a/libs/labelbox/tests/integration/test_data_rows.py b/libs/labelbox/tests/integration/test_data_rows.py
index baa65db69..485719575 100644
--- a/libs/labelbox/tests/integration/test_data_rows.py
+++ b/libs/labelbox/tests/integration/test_data_rows.py
@@ -7,13 +7,13 @@
import pytest
import requests
-
-from labelbox import AssetAttachment, DataRow
-from labelbox.exceptions import (
+from lbox.exceptions import (
InvalidQueryError,
MalformedQueryException,
ResourceCreationError,
)
+
+from labelbox import AssetAttachment, DataRow
from labelbox.schema.data_row_metadata import (
DataRowMetadataField,
DataRowMetadataKind,
@@ -105,7 +105,7 @@ def make_metadata_fields_dict(constants):
def test_get_data_row_by_global_key(data_row_and_global_key, client, rand_gen):
_, global_key = data_row_and_global_key
data_row = client.get_data_row_by_global_key(global_key)
- assert type(data_row) == DataRow
+ assert type(data_row) is DataRow
assert data_row.global_key == global_key
@@ -505,8 +505,6 @@ def test_create_data_rows_with_metadata(
[
("create_data_rows", "class"),
("create_data_rows", "dict"),
- ("create_data_rows_sync", "class"),
- ("create_data_rows_sync", "dict"),
("create_data_row", "class"),
("create_data_row", "dict"),
],
@@ -546,7 +544,6 @@ def create_data_row(data_rows):
CREATION_FUNCTION = {
"create_data_rows": dataset.create_data_rows,
- "create_data_rows_sync": dataset.create_data_rows_sync,
"create_data_row": create_data_row,
}
data_rows = [METADATA_FIELDS[metadata_obj_type]]
@@ -695,9 +692,10 @@ def test_data_row_update(
pdf_url = "https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483.pdf"
tileLayerUrl = "https://storage.googleapis.com/labelbox-datasets/arxiv-pdf/data/99-word-token-pdfs/0801.3483-lb-textlayer.json"
data_row.update(row_data={"pdfUrl": pdf_url, "tileLayerUrl": tileLayerUrl})
- custom_check = (
- lambda data_row: data_row.row_data and "pdfUrl" not in data_row.row_data
- )
+
+ def custom_check(data_row):
+ return data_row.row_data and "pdfUrl" not in data_row.row_data
+
data_row = wait_for_data_row_processing(
client, data_row, custom_check=custom_check
)
@@ -821,49 +819,6 @@ def test_data_row_attachments(dataset, image_url):
)
-def test_create_data_rows_sync_attachments(dataset, image_url):
- attachments = [
- ("IMAGE", image_url, "image URL"),
- ("RAW_TEXT", "test-text", None),
- ("IMAGE_OVERLAY", image_url, "Overlay"),
- ("HTML", image_url, None),
- ]
- attachments_per_data_row = 3
- dataset.create_data_rows_sync(
- [
- {
- "row_data": image_url,
- "external_id": "test-id",
- "attachments": [
- {
- "type": attachment_type,
- "value": attachment_value,
- "name": attachment_name,
- }
- for _ in range(attachments_per_data_row)
- ],
- }
- for attachment_type, attachment_value, attachment_name in attachments
- ]
- )
- data_rows = list(dataset.data_rows())
- assert len(data_rows) == len(attachments)
- for data_row in data_rows:
- assert len(list(data_row.attachments())) == attachments_per_data_row
-
-
-def test_create_data_rows_sync_mixed_upload(dataset, image_url):
- n_local = 100
- n_urls = 100
- with NamedTemporaryFile() as fp:
- fp.write("Test data".encode())
- fp.flush()
- dataset.create_data_rows_sync(
- [{DataRow.row_data: image_url}] * n_urls + [fp.name] * n_local
- )
- assert len(list(dataset.data_rows())) == n_local + n_urls
-
-
def test_create_data_row_attachment(data_row):
att = data_row.create_attachment(
"IMAGE", "https://example.com/image.jpg", "name"
@@ -1047,9 +1002,9 @@ def test_data_row_bulk_creation_with_same_global_keys(
task.wait_till_done()
assert task.status == "COMPLETE"
- assert type(task.failed_data_rows) is list
+ assert isinstance(task.failed_data_rows, list)
assert len(task.failed_data_rows) == 1
- assert type(task.created_data_rows) is list
+ assert isinstance(task.created_data_rows, list)
assert len(task.created_data_rows) == 1
assert (
task.failed_data_rows[0]["message"]
@@ -1109,53 +1064,6 @@ def test_data_row_delete_and_create_with_same_global_key(
assert task.result[0]["global_key"] == global_key_1
-def test_data_row_bulk_creation_sync_with_unique_global_keys(
- dataset, sample_image
-):
- global_key_1 = str(uuid.uuid4())
- global_key_2 = str(uuid.uuid4())
- global_key_3 = str(uuid.uuid4())
-
- dataset.create_data_rows_sync(
- [
- {DataRow.row_data: sample_image, DataRow.global_key: global_key_1},
- {DataRow.row_data: sample_image, DataRow.global_key: global_key_2},
- {DataRow.row_data: sample_image, DataRow.global_key: global_key_3},
- ]
- )
-
- assert {row.global_key for row in dataset.data_rows()} == {
- global_key_1,
- global_key_2,
- global_key_3,
- }
-
-
-def test_data_row_bulk_creation_sync_with_same_global_keys(
- dataset, sample_image
-):
- global_key_1 = str(uuid.uuid4())
-
- with pytest.raises(ResourceCreationError) as exc_info:
- dataset.create_data_rows_sync(
- [
- {
- DataRow.row_data: sample_image,
- DataRow.global_key: global_key_1,
- },
- {
- DataRow.row_data: sample_image,
- DataRow.global_key: global_key_1,
- },
- ]
- )
-
- assert len(list(dataset.data_rows())) == 1
- assert list(dataset.data_rows())[0].global_key == global_key_1
- assert "Duplicate global key" in str(exc_info.value)
- assert exc_info.value.args[1] # task id
-
-
@pytest.fixture
def conversational_data_rows(dataset, conversational_content):
examples = [
@@ -1197,7 +1105,7 @@ def test_invalid_media_type(dataset, conversational_content):
# TODO: What error kind should this be? It looks like for global key we are
# using malformed query. But for invalid contents in FileUploads we use InvalidQueryError
with pytest.raises(ResourceCreationError):
- dataset.create_data_rows_sync(
+ dataset._create_data_rows_sync(
[{**conversational_content, "media_type": "IMAGE"}]
)
@@ -1207,7 +1115,8 @@ def test_create_tiled_layer(dataset, tile_content):
{**tile_content, "media_type": "TMS_GEO"},
tile_content,
]
- dataset.create_data_rows_sync(examples)
+ task = dataset.create_data_rows(examples)
+ task.wait_till_done()
data_rows = list(dataset.data_rows())
assert len(data_rows) == len(examples)
for data_row in data_rows:
diff --git a/libs/labelbox/tests/integration/test_dataset.py b/libs/labelbox/tests/integration/test_dataset.py
index 89210d6c9..a32c5541d 100644
--- a/libs/labelbox/tests/integration/test_dataset.py
+++ b/libs/labelbox/tests/integration/test_dataset.py
@@ -1,9 +1,10 @@
+from unittest.mock import MagicMock
+
import pytest
import requests
-from unittest.mock import MagicMock
-from labelbox import Dataset
-from labelbox.exceptions import ResourceNotFoundError, ResourceCreationError
+from lbox.exceptions import ResourceCreationError, ResourceNotFoundError
+from labelbox import Dataset
from labelbox.schema.internal.descriptor_file_creator import (
DescriptorFileCreator,
)
diff --git a/libs/labelbox/tests/integration/test_dates.py b/libs/labelbox/tests/integration/test_dates.py
index 7bde0d666..f4e2364c7 100644
--- a/libs/labelbox/tests/integration/test_dates.py
+++ b/libs/labelbox/tests/integration/test_dates.py
@@ -22,6 +22,8 @@ def test_utc_conversion(project):
# Update with a datetime with TZ info
tz = timezone(timedelta(hours=6)) # +6 timezone
- project.update(setup_complete=datetime.utcnow().replace(tzinfo=tz))
- diff = datetime.utcnow() - project.setup_complete.replace(tzinfo=None)
+ project.update(setup_complete=datetime.now(timezone.utc).replace(tzinfo=tz))
+ diff = datetime.now(timezone.utc) - project.setup_complete.replace(
+ tzinfo=timezone.utc
+ )
assert diff > timedelta(hours=5, minutes=58)
diff --git a/libs/labelbox/tests/integration/test_embedding.py b/libs/labelbox/tests/integration/test_embedding.py
index 1b54ab81c..41f7ed3de 100644
--- a/libs/labelbox/tests/integration/test_embedding.py
+++ b/libs/labelbox/tests/integration/test_embedding.py
@@ -2,12 +2,12 @@
import random
import threading
from tempfile import NamedTemporaryFile
-from typing import List, Dict, Any
+from typing import Any, Dict, List
+import lbox.exceptions
import pytest
-import labelbox.exceptions
-from labelbox import Client, Dataset, DataRow
+from labelbox import Client, DataRow, Dataset
from labelbox.schema.embedding import Embedding
@@ -23,7 +23,7 @@ def test_get_embedding_by_id(client: Client, embedding: Embedding):
def test_get_embedding_by_name_not_found(client: Client):
- with pytest.raises(labelbox.exceptions.ResourceNotFoundError):
+ with pytest.raises(lbox.exceptions.ResourceNotFoundError):
client.get_embedding_by_name("does-not-exist")
diff --git a/libs/labelbox/tests/integration/test_ephemeral.py b/libs/labelbox/tests/integration/test_ephemeral.py
index a23572fdf..3c4fc62e4 100644
--- a/libs/labelbox/tests/integration/test_ephemeral.py
+++ b/libs/labelbox/tests/integration/test_ephemeral.py
@@ -7,7 +7,7 @@
reason="This test only runs in EPHEMERAL environment",
)
def test_org_and_user_setup(client, ephmeral_client):
- assert type(client) == ephmeral_client
+ assert type(client) is ephmeral_client
assert client.admin_client
assert client.api_key != client.admin_client.api_key
@@ -22,4 +22,4 @@ def test_org_and_user_setup(client, ephmeral_client):
reason="This test does not run in EPHEMERAL environment",
)
def test_integration_client(client, integration_client):
- assert type(client) == integration_client
+ assert type(client) is integration_client
diff --git a/libs/labelbox/tests/integration/test_filtering.py b/libs/labelbox/tests/integration/test_filtering.py
index 1c37227b2..bba483b19 100644
--- a/libs/labelbox/tests/integration/test_filtering.py
+++ b/libs/labelbox/tests/integration/test_filtering.py
@@ -1,8 +1,8 @@
import pytest
+from lbox.exceptions import InvalidQueryError
from labelbox import Project
-from labelbox.exceptions import InvalidQueryError
-from labelbox.schema.queue_mode import QueueMode
+from labelbox.schema.media_type import MediaType
@pytest.fixture
@@ -11,9 +11,9 @@ def project_to_test_where(client, rand_gen):
p_b_name = f"b-{rand_gen(str)}"
p_c_name = f"c-{rand_gen(str)}"
- p_a = client.create_project(name=p_a_name, queue_mode=QueueMode.Batch)
- p_b = client.create_project(name=p_b_name, queue_mode=QueueMode.Batch)
- p_c = client.create_project(name=p_c_name, queue_mode=QueueMode.Batch)
+ p_a = client.create_project(name=p_a_name, media_type=MediaType.Image)
+ p_b = client.create_project(name=p_b_name, media_type=MediaType.Image)
+ p_c = client.create_project(name=p_c_name, media_type=MediaType.Image)
yield p_a, p_b, p_c
diff --git a/libs/labelbox/tests/integration/test_foundry.py b/libs/labelbox/tests/integration/test_foundry.py
index 83c4effc5..b9fd1b6f3 100644
--- a/libs/labelbox/tests/integration/test_foundry.py
+++ b/libs/labelbox/tests/integration/test_foundry.py
@@ -1,7 +1,8 @@
-import labelbox as lb
import pytest
-from labelbox.schema.foundry.app import App
+from lbox.exceptions import LabelboxError, ResourceNotFoundError
+import labelbox as lb
+from labelbox.schema.foundry.app import App
from labelbox.schema.foundry.foundry_client import FoundryClient
# Yolo object detection model id
@@ -97,7 +98,7 @@ def test_get_app(foundry_client, app):
def test_get_app_with_invalid_id(foundry_client):
- with pytest.raises(lb.exceptions.ResourceNotFoundError):
+ with pytest.raises(ResourceNotFoundError):
foundry_client._get_app("invalid-id")
@@ -144,7 +145,7 @@ def test_run_foundry_app_returns_model_run_id(
def test_run_foundry_with_invalid_data_row_id(foundry_client, app, random_str):
invalid_datarow_id = "invalid-global-key"
data_rows = lb.GlobalKeys([invalid_datarow_id])
- with pytest.raises(lb.exceptions.LabelboxError) as exception:
+ with pytest.raises(LabelboxError) as exception:
foundry_client.run_app(
model_run_name=f"test-app-with-invalid-datarow-id-{random_str}",
data_rows=data_rows,
@@ -156,7 +157,7 @@ def test_run_foundry_with_invalid_data_row_id(foundry_client, app, random_str):
def test_run_foundry_with_invalid_global_key(foundry_client, app, random_str):
invalid_global_key = "invalid-global-key"
data_rows = lb.GlobalKeys([invalid_global_key])
- with pytest.raises(lb.exceptions.LabelboxError) as exception:
+ with pytest.raises(LabelboxError) as exception:
foundry_client.run_app(
model_run_name=f"test-app-with-invalid-global-key-{random_str}",
data_rows=data_rows,
diff --git a/libs/labelbox/tests/integration/test_labeling_frontend.py b/libs/labelbox/tests/integration/test_labeling_frontend.py
index d6ea1aac9..9a72fed47 100644
--- a/libs/labelbox/tests/integration/test_labeling_frontend.py
+++ b/libs/labelbox/tests/integration/test_labeling_frontend.py
@@ -1,7 +1,7 @@
import pytest
+from lbox.exceptions import OperationNotSupportedException
from labelbox import LabelingFrontend
-from labelbox.exceptions import OperationNotSupportedException
def test_get_labeling_frontends(client):
diff --git a/libs/labelbox/tests/integration/test_labeling_parameter_overrides.py b/libs/labelbox/tests/integration/test_labeling_parameter_overrides.py
index bd14040de..afa038482 100644
--- a/libs/labelbox/tests/integration/test_labeling_parameter_overrides.py
+++ b/libs/labelbox/tests/integration/test_labeling_parameter_overrides.py
@@ -23,7 +23,7 @@ def test_labeling_parameter_overrides(consensus_project_with_batch):
data_rows[2].uid,
}
- data = [(data_rows[0], 4, 2), (data_rows[1], 3)]
+ data = [(UniqueId(data_rows[0].uid), 4, 2), (UniqueId(data_rows[1].uid), 3)]
success = project.set_labeling_parameter_overrides(data)
assert success
@@ -60,7 +60,7 @@ def test_labeling_parameter_overrides(consensus_project_with_batch):
assert {o.priority for o in updated_overrides} == {2, 3, 4}
with pytest.raises(TypeError) as exc_info:
- data = [(data_rows[2], "a_string", 3)]
+ data = [(UniqueId(data_rows[2].uid), "a_string", 3)]
project.set_labeling_parameter_overrides(data)
assert (
str(exc_info.value)
@@ -72,7 +72,7 @@ def test_labeling_parameter_overrides(consensus_project_with_batch):
project.set_labeling_parameter_overrides(data)
assert (
str(exc_info.value)
- == f"Data row identifier should be be of type DataRow, UniqueId or GlobalKey. Found for data_row_identifier {data_rows[2].uid}"
+ == "Data row identifier should be of type DataRowIdentifier. Found ."
)
@@ -85,13 +85,6 @@ def test_set_labeling_priority(consensus_project_with_batch):
assert len(init_labeling_parameter_overrides) == 3
assert {o.priority for o in init_labeling_parameter_overrides} == {5, 5, 5}
- data = [data_row.uid for data_row in data_rows]
- success = project.update_data_row_labeling_priority(data, 1)
- lo = list(project.labeling_parameter_overrides())
- assert success
- assert len(lo) == 3
- assert {o.priority for o in lo} == {1, 1, 1}
-
data = [data_row.uid for data_row in data_rows]
success = project.update_data_row_labeling_priority(UniqueIds(data), 2)
lo = list(project.labeling_parameter_overrides())
diff --git a/libs/labelbox/tests/integration/test_labeling_service.py b/libs/labelbox/tests/integration/test_labeling_service.py
index 03a5694a7..e8b3d4cdc 100644
--- a/libs/labelbox/tests/integration/test_labeling_service.py
+++ b/libs/labelbox/tests/integration/test_labeling_service.py
@@ -1,9 +1,9 @@
import pytest
-
-from labelbox.exceptions import (
- MalformedQueryException,
+from lbox.exceptions import (
+ LabelboxError,
ResourceNotFoundError,
)
+
from labelbox.schema.labeling_service import LabelingServiceStatus
@@ -53,7 +53,7 @@ def test_request_labeling_service_moe_project(
labeling_service = project.get_labeling_service()
with pytest.raises(
- MalformedQueryException,
+ LabelboxError,
match='[{"errorType":"PROJECT_MODEL_CONFIG","errorMessage":"Project model config is not completed"}]',
):
labeling_service.request()
@@ -75,5 +75,5 @@ def test_request_labeling_service_incomplete_requirements(ontology, project):
): # No labeling service by default
labeling_service.request()
project.connect_ontology(ontology)
- with pytest.raises(MalformedQueryException):
+ with pytest.raises(LabelboxError):
labeling_service.request()
diff --git a/libs/labelbox/tests/integration/test_legacy_project.py b/libs/labelbox/tests/integration/test_legacy_project.py
index 320a2191d..3e652f333 100644
--- a/libs/labelbox/tests/integration/test_legacy_project.py
+++ b/libs/labelbox/tests/integration/test_legacy_project.py
@@ -1,40 +1,13 @@
-import pytest
-
-from labelbox.schema.queue_mode import QueueMode
-
-
-def test_project_dataset(client, rand_gen):
- with pytest.raises(
- ValueError,
- match="Dataset queue mode is deprecated. Please prefer Batch queue mode.",
- ):
- client.create_project(
- name=rand_gen(str),
- queue_mode=QueueMode.Dataset,
- )
+from os import name
+import pytest
+from pydantic import ValidationError
-def test_project_auto_audit_parameters(client, rand_gen):
- with pytest.raises(
- ValueError,
- match="quality_modes must be set instead of auto_audit_percentage or auto_audit_number_of_labels.",
- ):
- client.create_project(name=rand_gen(str), auto_audit_percentage=0.5)
-
- with pytest.raises(
- ValueError,
- match="quality_modes must be set instead of auto_audit_percentage or auto_audit_number_of_labels.",
- ):
- client.create_project(name=rand_gen(str), auto_audit_number_of_labels=2)
+from labelbox.schema.media_type import MediaType
def test_project_name_parameter(client, rand_gen):
with pytest.raises(
- ValueError, match="project name must be a valid string."
- ):
- client.create_project()
-
- with pytest.raises(
- ValueError, match="project name must be a valid string."
+ ValidationError, match="project name must be a valid string"
):
- client.create_project(name=" ")
+ client.create_project(name=" ", media_type=MediaType.Image)
diff --git a/libs/labelbox/tests/integration/test_model_config.py b/libs/labelbox/tests/integration/test_model_config.py
index 7a060b917..66912e8d9 100644
--- a/libs/labelbox/tests/integration/test_model_config.py
+++ b/libs/labelbox/tests/integration/test_model_config.py
@@ -1,5 +1,5 @@
import pytest
-from labelbox.exceptions import ResourceNotFoundError
+from lbox.exceptions import ResourceNotFoundError
def test_create_model_config(client, valid_model_id):
diff --git a/libs/labelbox/tests/integration/test_ontology.py b/libs/labelbox/tests/integration/test_ontology.py
index 91ef74a39..c7c7c270c 100644
--- a/libs/labelbox/tests/integration/test_ontology.py
+++ b/libs/labelbox/tests/integration/test_ontology.py
@@ -1,11 +1,10 @@
-import pytest
-
-from labelbox import OntologyBuilder, MediaType, Tool
-from labelbox.orm.model import Entity
import json
import time
-from labelbox.schema.queue_mode import QueueMode
+import pytest
+
+from labelbox import MediaType, OntologyBuilder, Tool
+from labelbox.orm.model import Entity
def test_feature_schema_is_not_archived(client, ontology):
@@ -13,7 +12,7 @@ def test_feature_schema_is_not_archived(client, ontology):
result = client.is_feature_schema_archived(
ontology.uid, feature_schema_to_check["featureSchemaId"]
)
- assert result == False
+ assert result is False
def test_feature_schema_is_archived(client, configured_project_with_label):
@@ -23,10 +22,10 @@ def test_feature_schema_is_archived(client, configured_project_with_label):
result = client.delete_feature_schema_from_ontology(
ontology.uid, feature_schema_id
)
- assert result.archived == True and result.deleted == False
+ assert result.archived is True and result.deleted is False
assert (
client.is_feature_schema_archived(ontology.uid, feature_schema_id)
- == True
+ is True
)
@@ -58,8 +57,8 @@ def test_delete_tool_feature_from_ontology(client, ontology):
result = client.delete_feature_schema_from_ontology(
ontology.uid, feature_schema_to_delete["featureSchemaId"]
)
- assert result.deleted == True
- assert result.archived == False
+ assert result.deleted is True
+ assert result.archived is False
updatedOntology = client.get_ontology(ontology.uid)
assert len(updatedOntology.normalized["tools"]) == 1
@@ -99,7 +98,6 @@ def test_deletes_an_ontology(client):
def test_cant_delete_an_ontology_with_project(client):
project = client.create_project(
name="test project",
- queue_mode=QueueMode.Batch,
media_type=MediaType.Image,
)
tool = client.upsert_feature_schema(point.asdict())
@@ -187,7 +185,6 @@ def test_does_not_include_used_ontologies(client):
)
project = client.create_project(
name="test project",
- queue_mode=QueueMode.Batch,
media_type=MediaType.Image,
)
project.connect_ontology(ontology_with_project)
@@ -300,7 +297,7 @@ def test_unarchive_feature_schema_node(client, ontology):
result = client.unarchive_feature_schema_node(
ontology.uid, feature_schema_to_unarchive["featureSchemaId"]
)
- assert result == None
+ assert result is None
def test_unarchive_feature_schema_node_for_non_existing_feature_schema(
diff --git a/libs/labelbox/tests/integration/test_project.py b/libs/labelbox/tests/integration/test_project.py
index a38fa2b5d..ea995c6f6 100644
--- a/libs/labelbox/tests/integration/test_project.py
+++ b/libs/labelbox/tests/integration/test_project.py
@@ -1,21 +1,22 @@
-import time
import os
+import time
import uuid
+
+from labelbox.schema.ontology import OntologyBuilder, Tool
import pytest
import requests
+from lbox.exceptions import InvalidQueryError
-from labelbox import Project, LabelingFrontend, Dataset
-from labelbox.exceptions import InvalidQueryError
+from labelbox import Dataset, LabelingFrontend, Project
+from labelbox.schema import media_type
from labelbox.schema.media_type import MediaType
from labelbox.schema.quality_mode import QualityMode
-from labelbox.schema.queue_mode import QueueMode
def test_project(client, rand_gen):
data = {
"name": rand_gen(str),
"description": rand_gen(str),
- "queue_mode": QueueMode.Batch.Batch,
"media_type": MediaType.Image,
}
project = client.create_project(**data)
@@ -50,7 +51,7 @@ def data_for_project_test(client, rand_gen):
def _create_project(name: str = None):
if name is None:
name = rand_gen(str)
- project = client.create_project(name=name)
+ project = client.create_project(name=name, media_type=MediaType.Image)
projects.append(project)
return project
@@ -139,10 +140,6 @@ def test_extend_reservations(project):
project.extend_reservations("InvalidQueueType")
-@pytest.mark.skipif(
- condition=os.environ["LABELBOX_TEST_ENVIRON"] == "onprem",
- reason="new mutation does not work for onprem",
-)
def test_attach_instructions(client, project):
with pytest.raises(ValueError) as execinfo:
project.upsert_instructions("tests/integration/media/sample_pdf.pdf")
@@ -150,11 +147,17 @@ def test_attach_instructions(client, project):
str(execinfo.value)
== "Cannot attach instructions to a project that has not been set up."
)
- editor = list(
- client.get_labeling_frontends(where=LabelingFrontend.name == "editor")
- )[0]
- empty_ontology = {"tools": [], "classifications": []}
- project.setup(editor, empty_ontology)
+ ontology_builder = OntologyBuilder(
+ tools=[
+ Tool(tool=Tool.Type.BBOX, name="test-bbox-class"),
+ ]
+ )
+ ontology = client.create_ontology(
+ name="ontology with features",
+ media_type=MediaType.Image,
+ normalized=ontology_builder.asdict(),
+ )
+ project.connect_ontology(ontology)
project.upsert_instructions("tests/integration/media/sample_pdf.pdf")
time.sleep(3)
@@ -171,24 +174,20 @@ def test_attach_instructions(client, project):
condition=os.environ["LABELBOX_TEST_ENVIRON"] == "onprem",
reason="new mutation does not work for onprem",
)
-def test_html_instructions(project_with_empty_ontology):
+def test_html_instructions(project_with_one_feature_ontology):
html_file_path = "/tmp/instructions.html"
sample_html_str = ""
with open(html_file_path, "w") as file:
file.write(sample_html_str)
- project_with_empty_ontology.upsert_instructions(html_file_path)
- updated_ontology = project_with_empty_ontology.ontology().normalized
+ project_with_one_feature_ontology.upsert_instructions(html_file_path)
+ updated_ontology = project_with_one_feature_ontology.ontology().normalized
instructions = updated_ontology.pop("projectInstructions")
assert requests.get(instructions).text == sample_html_str
-@pytest.mark.skipif(
- condition=os.environ["LABELBOX_TEST_ENVIRON"] == "onprem",
- reason="new mutation does not work for onprem",
-)
def test_same_ontology_after_instructions(
configured_project_with_complex_ontology,
):
@@ -247,9 +246,11 @@ def test_media_type(client, project: Project, rand_gen):
assert isinstance(project.media_type, MediaType)
# Update test
- project = client.create_project(name=rand_gen(str))
- project.update(media_type=MediaType.Image)
- assert project.media_type == MediaType.Image
+ project = client.create_project(
+ name=rand_gen(str), media_type=MediaType.Image
+ )
+ project.update(media_type=MediaType.Text)
+ assert project.media_type == MediaType.Text
project.delete()
for media_type in MediaType.get_supported_members():
@@ -270,13 +271,16 @@ def test_media_type(client, project: Project, rand_gen):
def test_queue_mode(client, rand_gen):
project = client.create_project(
- name=rand_gen(str)
+ name=rand_gen(str),
+ media_type=MediaType.Image,
) # defaults to benchmark and consensus
assert project.auto_audit_number_of_labels == 3
assert project.auto_audit_percentage == 0
project = client.create_project(
- name=rand_gen(str), quality_modes=[QualityMode.Benchmark]
+ name=rand_gen(str),
+ quality_modes=[QualityMode.Benchmark],
+ media_type=MediaType.Image,
)
assert project.auto_audit_number_of_labels == 1
assert project.auto_audit_percentage == 1
@@ -284,13 +288,16 @@ def test_queue_mode(client, rand_gen):
project = client.create_project(
name=rand_gen(str),
quality_modes=[QualityMode.Benchmark, QualityMode.Consensus],
+ media_type=MediaType.Image,
)
assert project.auto_audit_number_of_labels == 3
assert project.auto_audit_percentage == 0
def test_label_count(client, configured_batch_project_with_label):
- project = client.create_project(name="test label count")
+ project = client.create_project(
+ name="test label count", media_type=MediaType.Image
+ )
assert project.get_label_count() == 0
project.delete()
@@ -308,7 +315,6 @@ def test_clone(client, project, rand_gen):
assert cloned_project.description == project.description
assert cloned_project.media_type == project.media_type
- assert cloned_project.queue_mode == project.queue_mode
assert (
cloned_project.auto_audit_number_of_labels
== project.auto_audit_number_of_labels
diff --git a/libs/labelbox/tests/integration/test_project_model_config.py b/libs/labelbox/tests/integration/test_project_model_config.py
index 975a39afe..f1646dfc0 100644
--- a/libs/labelbox/tests/integration/test_project_model_config.py
+++ b/libs/labelbox/tests/integration/test_project_model_config.py
@@ -1,6 +1,5 @@
import pytest
-
-from labelbox.exceptions import ResourceNotFoundError
+from lbox.exceptions import ResourceNotFoundError
def test_add_single_model_config(live_chat_evaluation_project, model_config):
diff --git a/libs/labelbox/tests/integration/test_project_set_model_setup_complete.py b/libs/labelbox/tests/integration/test_project_set_model_setup_complete.py
index 16a124945..30e179028 100644
--- a/libs/labelbox/tests/integration/test_project_set_model_setup_complete.py
+++ b/libs/labelbox/tests/integration/test_project_set_model_setup_complete.py
@@ -1,6 +1,5 @@
import pytest
-
-from labelbox.exceptions import LabelboxError, OperationNotAllowedException
+from lbox.exceptions import LabelboxError, OperationNotAllowedException
def test_live_chat_evaluation_project(
diff --git a/libs/labelbox/tests/integration/test_project_setup.py b/libs/labelbox/tests/integration/test_project_setup.py
index faadea228..a6fba03e0 100644
--- a/libs/labelbox/tests/integration/test_project_setup.py
+++ b/libs/labelbox/tests/integration/test_project_setup.py
@@ -1,11 +1,11 @@
-from datetime import datetime, timedelta, timezone
import json
import time
+from datetime import datetime, timedelta, timezone
import pytest
+from lbox.exceptions import InvalidQueryError
from labelbox import LabelingFrontend
-from labelbox.exceptions import InvalidQueryError, ResourceConflict
def simple_ontology():
@@ -24,35 +24,6 @@ def simple_ontology():
return {"tools": [], "classifications": classifications}
-def test_project_setup(project) -> None:
- client = project.client
- labeling_frontends = list(
- client.get_labeling_frontends(where=LabelingFrontend.name == "Editor")
- )
- assert len(labeling_frontends)
- labeling_frontend = labeling_frontends[0]
-
- time.sleep(3)
- now = datetime.now().astimezone(timezone.utc)
-
- project.setup(labeling_frontend, simple_ontology())
- assert now - project.setup_complete <= timedelta(seconds=3)
- assert now - project.last_activity_time <= timedelta(seconds=3)
-
- assert project.labeling_frontend() == labeling_frontend
- options = list(project.labeling_frontend_options())
- assert len(options) == 1
- options = options[0]
- # TODO ensure that LabelingFrontendOptions can be obtaind by ID
- with pytest.raises(InvalidQueryError):
- assert options.labeling_frontend() == labeling_frontend
- assert options.project() == project
- assert options.organization() == client.get_organization()
- assert options.customization_options == json.dumps(simple_ontology())
- assert project.organization() == client.get_organization()
- assert project.created_by() == client.get_user()
-
-
def test_project_editor_setup(client, project, rand_gen):
ontology_name = f"test_project_editor_setup_ontology_name-{rand_gen(str)}"
ontology = client.create_ontology(ontology_name, simple_ontology())
diff --git a/libs/labelbox/tests/integration/test_prompt_response_generation_project.py b/libs/labelbox/tests/integration/test_prompt_response_generation_project.py
index 1373ee470..f5003f061 100644
--- a/libs/labelbox/tests/integration/test_prompt_response_generation_project.py
+++ b/libs/labelbox/tests/integration/test_prompt_response_generation_project.py
@@ -1,9 +1,8 @@
-import pytest
from unittest.mock import patch
+import pytest
+
from labelbox import MediaType
-from labelbox.schema.ontology_kind import OntologyKind
-from labelbox.exceptions import MalformedQueryException
@pytest.mark.parametrize(
diff --git a/libs/labelbox/tests/integration/test_task_queue.py b/libs/labelbox/tests/integration/test_task_queue.py
index 835f67219..0cd66cb62 100644
--- a/libs/labelbox/tests/integration/test_task_queue.py
+++ b/libs/labelbox/tests/integration/test_task_queue.py
@@ -68,7 +68,9 @@ def test_move_to_task(configured_batch_project_with_label):
review_queue = next(
tq for tq in task_queues if tq.queue_type == "MANUAL_REVIEW_QUEUE"
)
- project.move_data_rows_to_task_queue([data_row.uid], review_queue.uid)
+ project.move_data_rows_to_task_queue(
+ UniqueIds([data_row.uid]), review_queue.uid
+ )
_validate_moved(project, "MANUAL_REVIEW_QUEUE", 1)
review_queue = next(
diff --git a/libs/labelbox/tests/integration/test_toggle_mal.py b/libs/labelbox/tests/integration/test_toggle_mal.py
index 41dfbe395..566c4210c 100644
--- a/libs/labelbox/tests/integration/test_toggle_mal.py
+++ b/libs/labelbox/tests/integration/test_toggle_mal.py
@@ -1,9 +1,9 @@
def test_enable_model_assisted_labeling(project):
response = project.enable_model_assisted_labeling()
- assert response == True
+ assert response is True
response = project.enable_model_assisted_labeling(True)
- assert response == True
+ assert response is True
response = project.enable_model_assisted_labeling(False)
- assert response == False
+ assert response is False
diff --git a/libs/labelbox/tests/unit/export_task/test_unit_file_converter.py b/libs/labelbox/tests/unit/export_task/test_unit_file_converter.py
deleted file mode 100644
index 81e9eb60f..000000000
--- a/libs/labelbox/tests/unit/export_task/test_unit_file_converter.py
+++ /dev/null
@@ -1,77 +0,0 @@
-from unittest.mock import MagicMock
-
-from labelbox.schema.export_task import (
- Converter,
- FileConverter,
- Range,
- StreamType,
- _MetadataFileInfo,
- _MetadataHeader,
- _TaskContext,
-)
-
-
-class TestFileConverter:
- def test_with_correct_ndjson(self, tmp_path, generate_random_ndjson):
- directory = tmp_path / "file-converter"
- directory.mkdir()
- line_count = 10
- ndjson = generate_random_ndjson(line_count)
- file_content = "\n".join(ndjson) + "\n"
- input_args = Converter.ConverterInputArgs(
- ctx=_TaskContext(
- client=MagicMock(),
- task_id="task-id",
- stream_type=StreamType.RESULT,
- metadata_header=_MetadataHeader(
- total_size=len(file_content), total_lines=line_count
- ),
- ),
- file_info=_MetadataFileInfo(
- offsets=Range(start=0, end=len(file_content) - 1),
- lines=Range(start=0, end=line_count - 1),
- file="file.ndjson",
- ),
- raw_data=file_content,
- )
- path = directory / "output.ndjson"
- with FileConverter(file_path=path) as converter:
- for output in converter.convert(input_args):
- assert output.current_line == 0
- assert output.current_offset == 0
- assert output.file_path == path
- assert output.total_lines == line_count
- assert output.total_size == len(file_content)
- assert output.bytes_written == len(file_content)
-
- def test_with_no_newline_at_end(self, tmp_path, generate_random_ndjson):
- directory = tmp_path / "file-converter"
- directory.mkdir()
- line_count = 10
- ndjson = generate_random_ndjson(line_count)
- file_content = "\n".join(ndjson)
- input_args = Converter.ConverterInputArgs(
- ctx=_TaskContext(
- client=MagicMock(),
- task_id="task-id",
- stream_type=StreamType.RESULT,
- metadata_header=_MetadataHeader(
- total_size=len(file_content), total_lines=line_count
- ),
- ),
- file_info=_MetadataFileInfo(
- offsets=Range(start=0, end=len(file_content) - 1),
- lines=Range(start=0, end=line_count - 1),
- file="file.ndjson",
- ),
- raw_data=file_content,
- )
- path = directory / "output.ndjson"
- with FileConverter(file_path=path) as converter:
- for output in converter.convert(input_args):
- assert output.current_line == 0
- assert output.current_offset == 0
- assert output.file_path == path
- assert output.total_lines == line_count
- assert output.total_size == len(file_content)
- assert output.bytes_written == len(file_content)
diff --git a/libs/labelbox/tests/unit/export_task/test_unit_file_retriever_by_line.py b/libs/labelbox/tests/unit/export_task/test_unit_file_retriever_by_line.py
deleted file mode 100644
index 37c93647e..000000000
--- a/libs/labelbox/tests/unit/export_task/test_unit_file_retriever_by_line.py
+++ /dev/null
@@ -1,126 +0,0 @@
-from unittest.mock import MagicMock, patch
-from labelbox.schema.export_task import (
- FileRetrieverByLine,
- _TaskContext,
- _MetadataHeader,
- StreamType,
-)
-
-
-class TestFileRetrieverByLine:
- def test_by_line_from_start(self, generate_random_ndjson, mock_response):
- line_count = 10
- ndjson = generate_random_ndjson(line_count)
- file_content = "\n".join(ndjson) + "\n"
-
- mock_client = MagicMock()
- mock_client.execute = MagicMock(
- return_value={
- "task": {
- "exportFileFromLine": {
- "offsets": {"start": "0", "end": len(file_content) - 1},
- "lines": {"start": "0", "end": str(line_count - 1)},
- "file": "http://some-url.com/file.ndjson",
- }
- }
- }
- )
-
- mock_ctx = _TaskContext(
- client=mock_client,
- task_id="task-id",
- stream_type=StreamType.RESULT,
- metadata_header=_MetadataHeader(
- total_size=len(file_content), total_lines=line_count
- ),
- )
-
- with patch("requests.get", return_value=mock_response(file_content)):
- retriever = FileRetrieverByLine(mock_ctx, 0)
- info, content = retriever.get_next_chunk()
- assert info.offsets.start == 0
- assert info.offsets.end == len(file_content) - 1
- assert info.lines.start == 0
- assert info.lines.end == line_count - 1
- assert info.file == "http://some-url.com/file.ndjson"
- assert content == file_content
-
- def test_by_line_from_middle(self, generate_random_ndjson, mock_response):
- line_count = 10
- ndjson = generate_random_ndjson(line_count)
- file_content = "\n".join(ndjson) + "\n"
-
- mock_client = MagicMock()
- mock_client.execute = MagicMock(
- return_value={
- "task": {
- "exportFileFromLine": {
- "offsets": {"start": "0", "end": len(file_content) - 1},
- "lines": {"start": "0", "end": str(line_count - 1)},
- "file": "http://some-url.com/file.ndjson",
- }
- }
- }
- )
-
- mock_ctx = _TaskContext(
- client=mock_client,
- task_id="task-id",
- stream_type=StreamType.RESULT,
- metadata_header=_MetadataHeader(
- total_size=len(file_content), total_lines=line_count
- ),
- )
-
- line_start = 5
- current_offset = file_content.find(ndjson[line_start])
-
- with patch("requests.get", return_value=mock_response(file_content)):
- retriever = FileRetrieverByLine(mock_ctx, line_start)
- info, content = retriever.get_next_chunk()
- assert info.offsets.start == current_offset
- assert info.offsets.end == len(file_content) - 1
- assert info.lines.start == line_start
- assert info.lines.end == line_count - 1
- assert info.file == "http://some-url.com/file.ndjson"
- assert content == file_content[current_offset:]
-
- def test_by_line_from_last(self, generate_random_ndjson, mock_response):
- line_count = 10
- ndjson = generate_random_ndjson(line_count)
- file_content = "\n".join(ndjson) + "\n"
-
- mock_client = MagicMock()
- mock_client.execute = MagicMock(
- return_value={
- "task": {
- "exportFileFromLine": {
- "offsets": {"start": "0", "end": len(file_content) - 1},
- "lines": {"start": "0", "end": str(line_count - 1)},
- "file": "http://some-url.com/file.ndjson",
- }
- }
- }
- )
-
- mock_ctx = _TaskContext(
- client=mock_client,
- task_id="task-id",
- stream_type=StreamType.RESULT,
- metadata_header=_MetadataHeader(
- total_size=len(file_content), total_lines=line_count
- ),
- )
-
- line_start = 9
- current_offset = file_content.find(ndjson[line_start])
-
- with patch("requests.get", return_value=mock_response(file_content)):
- retriever = FileRetrieverByLine(mock_ctx, line_start)
- info, content = retriever.get_next_chunk()
- assert info.offsets.start == current_offset
- assert info.offsets.end == len(file_content) - 1
- assert info.lines.start == line_start
- assert info.lines.end == line_count - 1
- assert info.file == "http://some-url.com/file.ndjson"
- assert content == file_content[current_offset:]
diff --git a/libs/labelbox/tests/unit/export_task/test_unit_file_retriever_by_offset.py b/libs/labelbox/tests/unit/export_task/test_unit_file_retriever_by_offset.py
deleted file mode 100644
index 870e03307..000000000
--- a/libs/labelbox/tests/unit/export_task/test_unit_file_retriever_by_offset.py
+++ /dev/null
@@ -1,87 +0,0 @@
-from unittest.mock import MagicMock, patch
-from labelbox.schema.export_task import (
- FileRetrieverByOffset,
- _TaskContext,
- _MetadataHeader,
- StreamType,
-)
-
-
-class TestFileRetrieverByOffset:
- def test_by_offset_from_start(self, generate_random_ndjson, mock_response):
- line_count = 10
- ndjson = generate_random_ndjson(line_count)
- file_content = "\n".join(ndjson) + "\n"
-
- mock_client = MagicMock()
- mock_client.execute = MagicMock(
- return_value={
- "task": {
- "exportFileFromOffset": {
- "offsets": {"start": "0", "end": len(file_content) - 1},
- "lines": {"start": "0", "end": str(line_count - 1)},
- "file": "http://some-url.com/file.ndjson",
- }
- }
- }
- )
-
- mock_ctx = _TaskContext(
- client=mock_client,
- task_id="task-id",
- stream_type=StreamType.RESULT,
- metadata_header=_MetadataHeader(
- total_size=len(file_content), total_lines=line_count
- ),
- )
-
- with patch("requests.get", return_value=mock_response(file_content)):
- retriever = FileRetrieverByOffset(mock_ctx, 0)
- info, content = retriever.get_next_chunk()
- assert info.offsets.start == 0
- assert info.offsets.end == len(file_content) - 1
- assert info.lines.start == 0
- assert info.lines.end == line_count - 1
- assert info.file == "http://some-url.com/file.ndjson"
- assert content == file_content
-
- def test_by_offset_from_middle(self, generate_random_ndjson, mock_response):
- line_count = 10
- ndjson = generate_random_ndjson(line_count)
- file_content = "\n".join(ndjson) + "\n"
-
- mock_client = MagicMock()
- mock_client.execute = MagicMock(
- return_value={
- "task": {
- "exportFileFromOffset": {
- "offsets": {"start": "0", "end": len(file_content) - 1},
- "lines": {"start": "0", "end": str(line_count - 1)},
- "file": "http://some-url.com/file.ndjson",
- }
- }
- }
- )
-
- mock_ctx = _TaskContext(
- client=mock_client,
- task_id="task-id",
- stream_type=StreamType.RESULT,
- metadata_header=_MetadataHeader(
- total_size=len(file_content), total_lines=line_count
- ),
- )
-
- line_start = 5
- skipped_bytes = 15
- current_offset = file_content.find(ndjson[line_start]) + skipped_bytes
-
- with patch("requests.get", return_value=mock_response(file_content)):
- retriever = FileRetrieverByOffset(mock_ctx, current_offset)
- info, content = retriever.get_next_chunk()
- assert info.offsets.start == current_offset
- assert info.offsets.end == len(file_content) - 1
- assert info.lines.start == 5
- assert info.lines.end == line_count - 1
- assert info.file == "http://some-url.com/file.ndjson"
- assert content == file_content[current_offset:]
diff --git a/libs/labelbox/tests/unit/export_task/test_unit_json_converter.py b/libs/labelbox/tests/unit/export_task/test_unit_json_converter.py
deleted file mode 100644
index f5ccf26fb..000000000
--- a/libs/labelbox/tests/unit/export_task/test_unit_json_converter.py
+++ /dev/null
@@ -1,112 +0,0 @@
-from unittest.mock import MagicMock
-
-from labelbox.schema.export_task import (
- Converter,
- JsonConverter,
- Range,
- _MetadataFileInfo,
-)
-
-
-class TestJsonConverter:
- def test_with_correct_ndjson(self, generate_random_ndjson):
- line_count = 10
- ndjson = generate_random_ndjson(line_count)
- file_content = "\n".join(ndjson) + "\n"
- input_args = Converter.ConverterInputArgs(
- ctx=MagicMock(),
- file_info=_MetadataFileInfo(
- offsets=Range(start=0, end=len(file_content) - 1),
- lines=Range(start=0, end=line_count - 1),
- file="file.ndjson",
- ),
- raw_data=file_content,
- )
- with JsonConverter() as converter:
- current_offset = 0
- for idx, output in enumerate(converter.convert(input_args)):
- assert output.current_line == idx
- assert output.current_offset == current_offset
- assert output.json_str == ndjson[idx]
- current_offset += len(output.json_str) + 1
-
- def test_with_no_newline_at_end(self, generate_random_ndjson):
- line_count = 10
- ndjson = generate_random_ndjson(line_count)
- file_content = "\n".join(ndjson)
- input_args = Converter.ConverterInputArgs(
- ctx=MagicMock(),
- file_info=_MetadataFileInfo(
- offsets=Range(start=0, end=len(file_content) - 1),
- lines=Range(start=0, end=line_count - 1),
- file="file.ndjson",
- ),
- raw_data=file_content,
- )
- with JsonConverter() as converter:
- current_offset = 0
- for idx, output in enumerate(converter.convert(input_args)):
- assert output.current_line == idx
- assert output.current_offset == current_offset
- assert output.json_str == ndjson[idx]
- current_offset += len(output.json_str) + 1
-
- def test_from_offset(self, generate_random_ndjson):
- # testing middle of a JSON string, but not the last line
- line_count = 10
- line_start = 5
- ndjson = generate_random_ndjson(line_count)
- file_content = "\n".join(ndjson) + "\n"
- offset_end = len(file_content)
- skipped_bytes = 15
- current_offset = file_content.find(ndjson[line_start]) + skipped_bytes
- file_content = file_content[current_offset:]
-
- input_args = Converter.ConverterInputArgs(
- ctx=MagicMock(),
- file_info=_MetadataFileInfo(
- offsets=Range(start=current_offset, end=offset_end),
- lines=Range(start=line_start, end=line_count - 1),
- file="file.ndjson",
- ),
- raw_data=file_content,
- )
- with JsonConverter() as converter:
- for idx, output in enumerate(converter.convert(input_args)):
- assert output.current_line == line_start + idx
- assert output.current_offset == current_offset
- assert (
- output.json_str == ndjson[line_start + idx][skipped_bytes:]
- )
- current_offset += len(output.json_str) + 1
- skipped_bytes = 0
-
- def test_from_offset_last_line(self, generate_random_ndjson):
- # testing middle of a JSON string, but not the last line
- line_count = 10
- line_start = 9
- ndjson = generate_random_ndjson(line_count)
- file_content = "\n".join(ndjson) + "\n"
- offset_end = len(file_content)
- skipped_bytes = 15
- current_offset = file_content.find(ndjson[line_start]) + skipped_bytes
- file_content = file_content[current_offset:]
-
- input_args = Converter.ConverterInputArgs(
- ctx=MagicMock(),
- file_info=_MetadataFileInfo(
- offsets=Range(start=current_offset, end=offset_end),
- lines=Range(start=line_start, end=line_count - 1),
- file="file.ndjson",
- ),
- raw_data=file_content,
- )
- with JsonConverter() as converter:
- for idx, output in enumerate(converter.convert(input_args)):
- assert output.current_line == line_start + idx
- assert output.current_offset == current_offset
- assert (
- output.json_str == ndjson[line_start + idx][skipped_bytes:]
- )
- current_offset += len(output.json_str) + 1
- skipped_bytes = 0
diff --git a/libs/labelbox/tests/unit/schema/test_user_group.py b/libs/labelbox/tests/unit/schema/test_user_group.py
index 65584f8ef..6bc29048d 100644
--- a/libs/labelbox/tests/unit/schema/test_user_group.py
+++ b/libs/labelbox/tests/unit/schema/test_user_group.py
@@ -1,20 +1,21 @@
-import pytest
from collections import defaultdict
from unittest.mock import MagicMock
-from labelbox import Client
-from labelbox.exceptions import (
+
+import pytest
+from lbox.exceptions import (
+ MalformedQueryException,
ResourceConflict,
ResourceCreationError,
ResourceNotFoundError,
- MalformedQueryException,
UnprocessableEntityError,
)
+
+from labelbox import Client
+from labelbox.schema.media_type import MediaType
+from labelbox.schema.ontology_kind import EditorTaskType
from labelbox.schema.project import Project
from labelbox.schema.user import User
from labelbox.schema.user_group import UserGroup, UserGroupColor
-from labelbox.schema.queue_mode import QueueMode
-from labelbox.schema.ontology_kind import EditorTaskType
-from labelbox.schema.media_type import MediaType
@pytest.fixture
@@ -30,7 +31,6 @@ def group_project():
project_values = defaultdict(lambda: None)
project_values["id"] = "project_id"
project_values["name"] = "Test Project"
- project_values["queueMode"] = QueueMode.Batch.value
project_values["editorTaskType"] = EditorTaskType.Missing.value
project_values["mediaType"] = MediaType.Image.value
return Project(MagicMock(Client), project_values)
@@ -55,12 +55,6 @@ def setup_method(self):
self.client.enable_experimental = True
self.group = UserGroup(client=self.client)
- def test_constructor_experimental_needed(self):
- client = MagicMock(Client)
- client.enable_experimental = False
- with pytest.raises(RuntimeError):
- group = UserGroup(client)
-
def test_constructor(self):
group = UserGroup(self.client)
diff --git a/libs/labelbox/tests/unit/test_exceptions.py b/libs/labelbox/tests/unit/test_exceptions.py
index 4602fb984..074a735f2 100644
--- a/libs/labelbox/tests/unit/test_exceptions.py
+++ b/libs/labelbox/tests/unit/test_exceptions.py
@@ -1,6 +1,5 @@
import pytest
-
-from labelbox.exceptions import error_message_for_unparsed_graphql_error
+from lbox.exceptions import error_message_for_unparsed_graphql_error
@pytest.mark.parametrize(
diff --git a/libs/labelbox/tests/unit/test_label_data_type.py b/libs/labelbox/tests/unit/test_label_data_type.py
index 7bc32e37c..611324f78 100644
--- a/libs/labelbox/tests/unit/test_label_data_type.py
+++ b/libs/labelbox/tests/unit/test_label_data_type.py
@@ -1,11 +1,7 @@
-from email import message
import pytest
-from pydantic import ValidationError
-
from labelbox.data.annotation_types.data.generic_data_row_data import (
GenericDataRowData,
)
-from labelbox.data.annotation_types.data.video import VideoData
from labelbox.data.annotation_types.label import Label
@@ -37,20 +33,6 @@ def test_generic_data_type_validations():
Label(data=data)
-def test_video_data_type():
- data = {
- "global_key": "https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-BEidMVWRmyXjVCnr",
- }
- with pytest.warns(UserWarning, match="Use a dict"):
- label = Label(data=VideoData(**data))
- data = label.data
- assert isinstance(data, VideoData)
- assert (
- data.global_key
- == "https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-BEidMVWRmyXjVCnr"
- )
-
-
def test_generic_data_row():
data = {
"global_key": "https://lb-test-data.s3.us-west-1.amazonaws.com/image-samples/sample-image-1.jpg-BEidMVWRmyXjVCnr",
diff --git a/libs/labelbox/tests/unit/test_project.py b/libs/labelbox/tests/unit/test_project.py
index 5e5f99c57..1bc6fa840 100644
--- a/libs/labelbox/tests/unit/test_project.py
+++ b/libs/labelbox/tests/unit/test_project.py
@@ -21,7 +21,6 @@ def project_entity():
"editorTaskType": "MODEL_CHAT_EVALUATION",
"lastActivityTime": "2021-06-01T00:00:00.000Z",
"allowedMediaType": "IMAGE",
- "queueMode": "BATCH",
"setupComplete": "2021-06-01T00:00:00.000Z",
"modelSetupComplete": None,
"uploadType": "Auto",
@@ -62,7 +61,6 @@ def test_project_editor_task_type(
"editorTaskType": api_editor_task_type,
"lastActivityTime": "2021-06-01T00:00:00.000Z",
"allowedMediaType": "IMAGE",
- "queueMode": "BATCH",
"setupComplete": "2021-06-01T00:00:00.000Z",
"modelSetupComplete": None,
"uploadType": "Auto",
@@ -72,13 +70,3 @@ def test_project_editor_task_type(
)
assert project.editor_task_type == expected_editor_task_type
-
-
-def test_setup_editor_using_connect_ontology(project_entity):
- project = project_entity
- ontology = MagicMock()
- project.connect_ontology = MagicMock()
- with patch("warnings.warn") as warn:
- project.setup_editor(ontology)
- warn.assert_called_once()
- project.connect_ontology.assert_called_once_with(ontology)
diff --git a/libs/labelbox/tests/unit/test_queue_mode.py b/libs/labelbox/tests/unit/test_queue_mode.py
deleted file mode 100644
index a07b14a54..000000000
--- a/libs/labelbox/tests/unit/test_queue_mode.py
+++ /dev/null
@@ -1,20 +0,0 @@
-import pytest
-
-from labelbox.schema.queue_mode import QueueMode
-
-
-def test_parse_deprecated_catalog():
- assert QueueMode("CATALOG") == QueueMode.Batch
-
-
-def test_parse_batch():
- assert QueueMode("BATCH") == QueueMode.Batch
-
-
-def test_parse_data_set():
- assert QueueMode("DATA_SET") == QueueMode.Dataset
-
-
-def test_fails_for_unknown():
- with pytest.raises(ValueError):
- QueueMode("foo")
diff --git a/libs/labelbox/tests/unit/test_unit_ontology.py b/libs/labelbox/tests/unit/test_unit_ontology.py
index 0566ad623..61c9f523a 100644
--- a/libs/labelbox/tests/unit/test_unit_ontology.py
+++ b/libs/labelbox/tests/unit/test_unit_ontology.py
@@ -1,8 +1,9 @@
+from itertools import product
+
import pytest
+from lbox.exceptions import InconsistentOntologyException
-from labelbox.exceptions import InconsistentOntologyException
-from labelbox import Tool, Classification, Option, OntologyBuilder
-from itertools import product
+from labelbox import Classification, OntologyBuilder, Option, Tool
_SAMPLE_ONTOLOGY = {
"tools": [
@@ -197,7 +198,7 @@ def test_add_ontology_tool() -> None:
assert len(o.tools) == 2
for tool in o.tools:
- assert type(tool) == Tool
+ assert type(tool) is Tool
with pytest.raises(InconsistentOntologyException) as exc:
o.add_tool(Tool(tool=Tool.Type.BBOX, name="bounding box"))
@@ -217,7 +218,7 @@ def test_add_ontology_classification() -> None:
assert len(o.classifications) == 2
for classification in o.classifications:
- assert type(classification) == Classification
+ assert type(classification) is Classification
with pytest.raises(InconsistentOntologyException) as exc:
o.add_classification(
diff --git a/libs/labelbox/tests/unit/test_unit_project_validate_labeling_parameter_overrides.py b/libs/labelbox/tests/unit/test_unit_project_validate_labeling_parameter_overrides.py
index 7f6d29d5a..8e6c38559 100644
--- a/libs/labelbox/tests/unit/test_unit_project_validate_labeling_parameter_overrides.py
+++ b/libs/labelbox/tests/unit/test_unit_project_validate_labeling_parameter_overrides.py
@@ -6,19 +6,6 @@
from labelbox.schema.project import validate_labeling_parameter_overrides
-def test_validate_labeling_parameter_overrides_valid_data():
- mock_data_row = MagicMock(spec=DataRow)
- mock_data_row.uid = "abc"
- data = [(mock_data_row, 1), (UniqueId("efg"), 2), (GlobalKey("hij"), 3)]
- validate_labeling_parameter_overrides(data)
-
-
-def test_validate_labeling_parameter_overrides_invalid_data():
- data = [("abc", 1), (UniqueId("efg"), 2), (GlobalKey("hij"), 3)]
- with pytest.raises(TypeError):
- validate_labeling_parameter_overrides(data)
-
-
def test_validate_labeling_parameter_overrides_invalid_priority():
mock_data_row = MagicMock(spec=DataRow)
mock_data_row.uid = "abc"
diff --git a/libs/lbox-clients/.gitignore b/libs/lbox-clients/.gitignore
new file mode 100644
index 000000000..ae8554dec
--- /dev/null
+++ b/libs/lbox-clients/.gitignore
@@ -0,0 +1,10 @@
+# python generated files
+__pycache__/
+*.py[oc]
+build/
+dist/
+wheels/
+*.egg-info
+
+# venv
+.venv
diff --git a/libs/lbox-clients/.python-version b/libs/lbox-clients/.python-version
new file mode 100644
index 000000000..43077b246
--- /dev/null
+++ b/libs/lbox-clients/.python-version
@@ -0,0 +1 @@
+3.9.18
diff --git a/libs/lbox-clients/Dockerfile b/libs/lbox-clients/Dockerfile
new file mode 100644
index 000000000..dd3b5147d
--- /dev/null
+++ b/libs/lbox-clients/Dockerfile
@@ -0,0 +1,44 @@
+# https://github.com/ucyo/python-package-template/blob/master/Dockerfile
+FROM python:3.9-slim as rye
+
+ENV LANG="C.UTF-8" \
+ LC_ALL="C.UTF-8" \
+ PATH="/home/python/.local/bin:/home/python/.rye/shims:$PATH" \
+ PIP_NO_CACHE_DIR="false" \
+ RYE_VERSION="0.34.0" \
+ RYE_INSTALL_OPTION="--yes" \
+ LABELBOX_TEST_ENVIRON="prod"
+
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+ ca-certificates \
+ curl \
+ inotify-tools \
+ make \
+ # cv2
+ libsm6 \
+ libxext6 \
+ ffmpeg \
+ libfontconfig1 \
+ libxrender1 \
+ libgl1-mesa-glx \
+ libgeos-dev \
+ gcc \
+ && rm -rf /var/lib/apt/lists/*
+
+RUN groupadd --gid 1000 python && \
+ useradd --uid 1000 --gid python --shell /bin/bash --create-home python
+
+USER 1000
+WORKDIR /home/python/
+
+RUN curl -sSf https://rye.astral.sh/get | bash -
+
+COPY --chown=python:python . /home/python/labelbox-python/
+WORKDIR /home/python/labelbox-python
+
+RUN rye config --set-bool behavior.global-python=true && \
+ rye config --set-bool behavior.use-uv=true && \
+ rye pin 3.9 && \
+ rye sync
+
+CMD rye run unit && rye run integration
\ No newline at end of file
diff --git a/libs/lbox-clients/README.md b/libs/lbox-clients/README.md
new file mode 100644
index 000000000..a7bf3e94f
--- /dev/null
+++ b/libs/lbox-clients/README.md
@@ -0,0 +1,9 @@
+# lbox-clients
+
+This module contains the client SDK used to connect to the Labelbox API and backends, published under the `lbox` namespace.
+
+## Module Status: Experimental
+
+**TLDR: This module may be removed or altered at any given time and there is no official support.**
+
+Please see [here](https://docs.labelbox.com/docs/product-release-phases) for the formal definition of `Experimental`.
\ No newline at end of file
diff --git a/libs/lbox-clients/pyproject.toml b/libs/lbox-clients/pyproject.toml
new file mode 100644
index 000000000..744cf7ee5
--- /dev/null
+++ b/libs/lbox-clients/pyproject.toml
@@ -0,0 +1,61 @@
+[project]
+name = "lbox-clients"
+version = "1.1.0"
+description = "This module contains the client SDK used to connect to the Labelbox API and backends"
+authors = [
+ { name = "Labelbox", email = "engineering@labelbox.com" }
+]
+dependencies = [
+ "requests>=2.22.0",
+ "google-api-core>=1.22.1",
+]
+readme = "README.md"
+requires-python = ">= 3.9"
+
+classifiers=[
+ # How mature is this project?
+ "Development Status :: 5 - Production/Stable",
+ # Indicate who your project is intended for
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
+ "Topic :: Software Development :: Libraries",
+ "Intended Audience :: Developers",
+ "Intended Audience :: Science/Research",
+ "Intended Audience :: Education",
+ # Pick your license as you wish
+ "License :: OSI Approved :: Apache Software License",
+ # Specify the Python versions you support here.
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: 3.12",
+]
+keywords = ["ml", "ai", "labelbox", "labeling", "llm", "machinelearning", "edu"]
+
+[project.urls]
+Homepage = "https://labelbox.com/"
+Documentation = "https://labelbox-python.readthedocs.io/en/latest/"
+Repository = "https://github.com/Labelbox/labelbox-python"
+Issues = "https://github.com/Labelbox/labelbox-python/issues"
+Changelog = "https://github.com/Labelbox/labelbox-python/blob/develop/libs/labelbox/CHANGELOG.md"
+
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[tool.rye]
+managed = true
+dev-dependencies = []
+
+[tool.rye.scripts]
+unit = "pytest tests/unit"
+integration = "python -c \"import sys; sys.exit(0)\""
+
+[tool.hatch.metadata]
+allow-direct-references = true
+
+[tool.hatch.build.targets.wheel]
+packages = ["src/lbox"]
+
+[tool.pytest.ini_options]
+addopts = "-rP -vvv --durations=20 --cov=lbox.example --import-mode=importlib"
\ No newline at end of file
diff --git a/libs/labelbox/src/labelbox/exceptions.py b/libs/lbox-clients/src/lbox/exceptions.py
similarity index 96%
rename from libs/labelbox/src/labelbox/exceptions.py
rename to libs/lbox-clients/src/lbox/exceptions.py
index 34cfeaf4d..493e9cb09 100644
--- a/libs/labelbox/src/labelbox/exceptions.py
+++ b/libs/lbox-clients/src/lbox/exceptions.py
@@ -16,7 +16,10 @@ def __init__(self, message, cause=None):
self.cause = cause
def __str__(self):
- return self.message + str(self.args)
+ exception_message = self.message
+ if self.cause is not None:
+ exception_message += " (caused by: %s)" % self.cause
+ return exception_message
class AuthenticationError(LabelboxError):
diff --git a/libs/lbox-clients/src/lbox/request_client.py b/libs/lbox-clients/src/lbox/request_client.py
new file mode 100644
index 000000000..c855413bd
--- /dev/null
+++ b/libs/lbox-clients/src/lbox/request_client.py
@@ -0,0 +1,426 @@
+import inspect
+import json
+import logging
+import os
+import re
+import sys
+from datetime import datetime, timezone
+from types import MappingProxyType
+from typing import Callable, Dict, Optional, TypedDict
+
+import requests
+import requests.exceptions
+from google.api_core import retry
+from lbox import exceptions # type: ignore
+
+logger = logging.getLogger(__name__)
+
+_LABELBOX_API_KEY = "LABELBOX_API_KEY"
+
+
+def python_version_info():
+ version_info = sys.version_info
+
+ return f"{version_info.major}.{version_info.minor}.{version_info.micro}-{version_info.releaselevel}"
+
+
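+# Heuristics used by call_info() below: frames whose filename contains
+# "/labelbox/" are treated as SDK-originated, and callers living in test
+# modules are reported with a "test:" prefix.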
+LABELBOX_CALL_PATTERN = re.compile(r"/labelbox/")
+TEST_FILE_PATTERN = re.compile(r".*test.*\.py$")
+
+
+class _RequestInfo(TypedDict):
+ prefix: str
+ class_name: str
+ method_name: str
+
+
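+# Walks the call stack from the outermost frame inward to find the SDK
+# class/method that triggered the request, skipping wrapper and pagination
+# frames; the result feeds the X-SDK-Method header sent with each request.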
+def call_info():
+ method_name = "Unknown"
+ prefix = ""
+ class_name = ""
+ skip_methods = ["wrapper", "__init__"]
+ skip_classes = ["PaginatedCollection", "_CursorPagination", "_OffsetPagination"]
+
+ try:
+ call_info = None
+ for stack in reversed(inspect.stack()):
+ if LABELBOX_CALL_PATTERN.search(stack.filename):
+ call_info = stack
+ method_name = call_info.function
+ class_name = call_info.frame.f_locals.get(
+ "self", None
+ ).__class__.__name__
+
+ if method_name not in skip_methods and class_name not in skip_classes:
+ if TEST_FILE_PATTERN.search(call_info.filename):
+ prefix = "test:"
+ else:
+ if class_name == "NoneType":
+ class_name = ""
+ break
+
+ except Exception:
+ pass
+ return _RequestInfo(prefix=prefix, class_name=class_name, method_name=method_name)
+
+
+def call_info_as_str():
+ info = call_info()
+ return f"{info['prefix']}{info['class_name']}:{info['method_name']}"
+
+
+class RequestClient:
+ """A Labelbox request client.
+
+ Contains info necessary for connecting to a Labelbox server (URL,
+ authentication key).
+ """
+
+ def __init__(
+ self,
+ sdk_version,
+ api_key=None,
+ endpoint="https://api.labelbox.com/graphql",
+ enable_experimental=False,
+ app_url="https://app.labelbox.com",
+ rest_endpoint="https://api.labelbox.com/api/v1",
+ ):
+ """Creates and initializes a RequestClient.
+ This class executes graphql and rest requests to the Labelbox server.
+
+ Args:
+ api_key (str): API key. If None, the key is obtained from the "LABELBOX_API_KEY" environment variable.
+ endpoint (str): URL of the Labelbox server to connect to.
+ enable_experimental (bool): Indicates whether or not to use experimental features.
+ app_url (str): Host URL for all links to the web app.
+ Raises:
+ exceptions.AuthenticationError: If no `api_key`
+ is provided as an argument or via the environment
+ variable.
+ """
+ if api_key is None:
+ if _LABELBOX_API_KEY not in os.environ:
+ raise exceptions.AuthenticationError("Labelbox API key not provided")
+ api_key = os.environ[_LABELBOX_API_KEY]
+ self.api_key = api_key
+
+ self.enable_experimental = enable_experimental
+ if enable_experimental:
+ logger.info("Experimental features have been enabled")
+
+ logger.info("Initializing Labelbox client at '%s'", endpoint)
+ self.app_url = app_url
+ self.endpoint = endpoint
+ self.rest_endpoint = rest_endpoint
+ self.sdk_version = sdk_version
+ self._connection: requests.Session = self._init_connection()
+
+ def _init_connection(self) -> requests.Session:
+ connection = requests.Session() # using default connection pool size of 10
+ connection.headers.update(self._default_headers())
+
+ return connection
+
+ @property
+ def headers(self) -> MappingProxyType:
+ return self._connection.headers
+
+ def _default_headers(self):
+ return {
+ "Authorization": "Bearer %s" % self.api_key,
+ "Accept": "application/json",
+ "Content-Type": "application/json",
+ "X-User-Agent": f"python-sdk {self.sdk_version}",
+ "X-Python-Version": f"{python_version_info()}",
+ }
+
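+ # Transient failures (responses mapped to InternalServerError and request
+ # timeouts) are retried automatically via google.api_core's retry decorator.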
+ @retry.Retry(
+ predicate=retry.if_exception_type(
+ exceptions.InternalServerError,
+ exceptions.TimeoutError,
+ )
+ )
+ def execute(
+ self,
+ query=None,
+ params=None,
+ data=None,
+ files=None,
+ timeout=60.0,
+ experimental=False,
+ error_log_key="message",
+ raise_return_resource_not_found=False,
+ error_handlers: Optional[
+ Dict[str, Callable[[requests.models.Response], None]]
+ ] = None,
+ ):
+ """Sends a request to the server for the execution of the
+ given query.
+
+ Checks the response for errors and wraps errors
+ in appropriate `exceptions.LabelboxError` subtypes.
+
+ Args:
+ query (str): The query to execute.
+ params (dict): Query parameters referenced within the query.
+ data (str): json string containing the query to execute
+ files (dict): file arguments for request
+ timeout (float): Max allowed time for query execution,
+ in seconds.
+ raise_return_resource_not_found: By default the client relies on the caller to raise the correct exception when a resource is not found.
+ If this is set to True, the client will raise a ResourceNotFoundError exception automatically.
+ This simplifies processing.
+ We recommend using it only if the API returns a clear and well-formed error when a resource is not found for a given query.
+ error_handlers (dict): A dictionary mapping graphql error code to handler functions.
+ Allows a caller to handle specific errors reporting in a custom way or produce more user-friendly readable messages.
+
+ Example - custom error handler:
+ >>> def _raise_readable_errors(self, response):
+ >>> errors = response.json().get('errors', [])
+ >>> if errors:
+ >>> message = errors[0].get(
+ >>> 'message', json.dumps([{
+ >>> "errorMessage": "Unknown error"
+ >>> }]))
+ >>> errors = json.loads(message)
+ >>> error_messages = [error['errorMessage'] for error in errors]
+ >>> else:
+ >>> error_messages = ["Unknown error"]
+ >>> raise LabelboxError(". ".join(error_messages))
+
+ Returns:
+ dict, parsed JSON response.
+ Raises:
+ exceptions.AuthenticationError: If authentication
+ failed.
+ exceptions.InvalidQueryError: If `query` is not
+ syntactically or semantically valid (checked server-side).
+ exceptions.ApiLimitError: If the server API limit was
+ exceeded. See "How to import data" in the online documentation
+ for API limits.
+ exceptions.TimeoutError: If response was not received
+ in `timeout` seconds.
+ exceptions.NetworkError: If an unknown error occurred,
+ most likely due to connection issues.
+ exceptions.LabelboxError: If an unknown error of any
+ kind occurred.
+ ValueError: If query and data are both None.
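+
+ Example - executing a query (a minimal sketch; the query string and
+ variables are illustrative, not a guaranteed schema):
+ >>> data = client.execute(
+ >>>     "query project($projectId: ID!) { project(where: {id: $projectId}) { name } }",
+ >>>     {"projectId": "<project-id>"},
+ >>> )
+ >>> data["project"]["name"]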
+ """
+ logger.debug("Query: %s, params: %r, data %r", query, params, data)
+
+ # Convert datetimes to UTC strings.
+ def convert_value(value):
+ if isinstance(value, datetime):
+ value = value.astimezone(timezone.utc)
+ value = value.strftime("%Y-%m-%dT%H:%M:%SZ")
+ return value
+
+ if query is not None:
+ if params is not None:
+ params = {key: convert_value(value) for key, value in params.items()}
+ data = json.dumps({"query": query, "variables": params}).encode("utf-8")
+ elif data is None:
+ raise ValueError("query and data cannot both be None")
+
+ endpoint = (
+ self.endpoint
+ if not experimental
+ else self.endpoint.replace("/graphql", "/_gql")
+ )
+
+ try:
+ headers = self._connection.headers.copy()
+ if files:
+ del headers["Content-Type"]
+ del headers["Accept"]
+ headers["X-SDK-Method"] = call_info_as_str()
+
+ request = requests.Request(
+ "POST",
+ endpoint,
+ headers=headers,
+ data=data,
+ files=files if files else None,
+ )
+
+ prepped: requests.PreparedRequest = request.prepare()
+
+ settings = self._connection.merge_environment_settings(
+ prepped.url, {}, None, None, None
+ )
+ response = self._connection.send(prepped, timeout=timeout, **settings)
+ logger.debug("Response: %s", response.text)
+ except requests.exceptions.Timeout as e:
+ raise exceptions.TimeoutError(str(e))
+ except requests.exceptions.RequestException as e:
+ logger.error("Unknown error: %s", str(e))
+ raise exceptions.NetworkError(e)
+ except Exception as e:
+ raise exceptions.LabelboxError(
+ "Unknown error during Client.query(): " + str(e), e
+ )
+
+ if (
+ 200 <= response.status_code < 300
+ or response.status_code < 500
+ or response.status_code >= 600
+ ):
+ try:
+ r_json = response.json()
+ except Exception:
+ raise exceptions.LabelboxError(
+ "Failed to parse response as JSON: %s" % response.text
+ )
+ else:
+ if (
+ "upstream connect error or disconnect/reset before headers"
+ in response.text
+ ):
+ raise exceptions.InternalServerError("Connection reset")
+ elif response.status_code == 502:
+ error_502 = "502 Bad Gateway"
+ raise exceptions.InternalServerError(error_502)
+ elif 500 <= response.status_code < 600:
+ error_500 = f"Internal server http error {response.status_code}"
+ raise exceptions.InternalServerError(error_500)
+
+ errors = r_json.get("errors", [])
+
+ def check_errors(keywords, *path):
+ """Helper that looks for any of the given `keywords` in any of
+ current errors on paths (like error[path][component][to][keyword]).
+ """
+ for error in errors:
+ obj = error
+ for path_elem in path:
+ obj = obj.get(path_elem, {})
+ if obj in keywords:
+ return error
+ return None
+
+ def get_error_status_code(error: dict) -> int:
+ try:
+ return int(error["extensions"].get("exception").get("status"))
+ except Exception:
+ return 500
+
+ if check_errors(["AUTHENTICATION_ERROR"], "extensions", "code") is not None:
+ raise exceptions.AuthenticationError("Invalid API key")
+
+ authorization_error = check_errors(
+ ["AUTHORIZATION_ERROR"], "extensions", "code"
+ )
+ if authorization_error is not None:
+ raise exceptions.AuthorizationError(authorization_error["message"])
+
+ validation_error = check_errors(
+ ["GRAPHQL_VALIDATION_FAILED"], "extensions", "code"
+ )
+
+ if validation_error is not None:
+ message = validation_error["message"]
+ if message == "Query complexity limit exceeded":
+ raise exceptions.ValidationFailedError(message)
+ else:
+ raise exceptions.InvalidQueryError(message)
+
+ graphql_error = check_errors(["GRAPHQL_PARSE_FAILED"], "extensions", "code")
+ if graphql_error is not None:
+ raise exceptions.InvalidQueryError(graphql_error["message"])
+
+ # Check if API limit was exceeded
+ response_msg = r_json.get("message", "")
+
+ if response_msg.startswith("You have exceeded"):
+ raise exceptions.ApiLimitError(response_msg)
+
+ resource_not_found_error = check_errors(
+ ["RESOURCE_NOT_FOUND"], "extensions", "code"
+ )
+ if resource_not_found_error is not None:
+ if raise_return_resource_not_found:
+ raise exceptions.ResourceNotFoundError(
+ message=resource_not_found_error["message"]
+ )
+ else:
+ # Return None and let the caller methods raise an exception
+ # as they already know which resource type and ID was requested
+ return None
+
+ resource_conflict_error = check_errors(
+ ["RESOURCE_CONFLICT"], "extensions", "code"
+ )
+ if resource_conflict_error is not None:
+ raise exceptions.ResourceConflict(resource_conflict_error["message"])
+
+ malformed_request_error = check_errors(
+ ["MALFORMED_REQUEST"], "extensions", "code"
+ )
+
+ error_code = "MALFORMED_REQUEST"
+ if malformed_request_error is not None:
+ if error_handlers and error_code in error_handlers:
+ handler = error_handlers[error_code]
+ handler(response)
+ return None
+ raise exceptions.MalformedQueryException(
+ malformed_request_error[error_log_key]
+ )
+
+ # A lot of different error situations are now labeled server-side
+ # as INTERNAL_SERVER_ERROR, when they are actually client errors.
+ # TODO: fix this in the server API
+ internal_server_error = check_errors(
+ ["INTERNAL_SERVER_ERROR"], "extensions", "code"
+ )
+ error_code = "INTERNAL_SERVER_ERROR"
+
+ if internal_server_error is not None:
+ if error_handlers and error_code in error_handlers:
+ handler = error_handlers[error_code]
+ handler(response)
+ return None
+ message = internal_server_error.get("message")
+ error_status_code = get_error_status_code(internal_server_error)
+ if error_status_code == 400:
+ raise exceptions.InvalidQueryError(message)
+ elif error_status_code == 422:
+ raise exceptions.UnprocessableEntityError(message)
+ elif error_status_code == 426:
+ raise exceptions.OperationNotAllowedException(message)
+ elif error_status_code == 500:
+ raise exceptions.LabelboxError(message)
+ else:
+ raise exceptions.InternalServerError(message)
+
+ not_allowed_error = check_errors(
+ ["OPERATION_NOT_ALLOWED"], "extensions", "code"
+ )
+ if not_allowed_error is not None:
+ message = not_allowed_error.get("message")
+ raise exceptions.OperationNotAllowedException(message)
+
+ if len(errors) > 0:
+ logger.warning("Unparsed errors on query execution: %r", errors)
+ messages = list(
+ map(
+ lambda x: {
+ "message": x["message"],
+ "code": x["extensions"]["code"],
+ },
+ errors,
+ )
+ )
+ raise exceptions.LabelboxError("Unknown error: %s" % str(messages))
+
+ # If the server returned an error status code that wasn't handled
+ # above, re-raise here. This mainly catches a 401 when API access is
+ # disabled for the free tier.
+ # TODO: need to unify API errors to handle things more uniformly
+ # in the SDK
+ if response.status_code != requests.codes.ok:
+ message = f"{response.status_code} {response.reason}"
+ cause = r_json.get("message")
+ raise exceptions.LabelboxError(message, cause)
+
+ return r_json["data"]
diff --git a/libs/lbox-clients/tests/unit/lbox/test_client.py b/libs/lbox-clients/tests/unit/lbox/test_client.py
new file mode 100644
index 000000000..42b141d33
--- /dev/null
+++ b/libs/lbox-clients/tests/unit/lbox/test_client.py
@@ -0,0 +1,46 @@
+from unittest.mock import MagicMock
+
+from lbox.request_client import RequestClient
+
+
+# @patch.dict(os.environ, {'LABELBOX_API_KEY': 'bar'})
+def test_headers():
+ client = RequestClient(
+ sdk_version="foo", api_key="api_key", endpoint="http://localhost:8080/_gql"
+ )
+ assert client.headers
+ assert client.headers["Authorization"] == "Bearer api_key"
+ assert client.headers["Content-Type"] == "application/json"
+ assert client.headers["User-Agent"]
+ assert client.headers["X-Python-Version"]
+
+
+def test_custom_error_handling():
+ mock_raise_error = MagicMock()
+
+ response_dict = {
+ "errors": [
+ {
+ "message": "Internal server error",
+ "extensions": {"code": "INTERNAL_SERVER_ERROR"},
+ }
+ ],
+ }
+ response = MagicMock()
+ response.json.return_value = response_dict
+ response.status_code = 200
+
+ client = RequestClient(
+ sdk_version="foo", api_key="api_key", endpoint="http://localhost:8080/_gql"
+ )
+ connection_mock = MagicMock()
+ connection_mock.send.return_value = response
+ client._connection = connection_mock
+
+ client.execute(
+ "query_str",
+ {"projectId": "project_id"},
+ raise_return_resource_not_found=True,
+ error_handlers={"INTERNAL_SERVER_ERROR": mock_raise_error},
+ )
+ mock_raise_error.assert_called_once_with(response)
diff --git a/libs/lbox-example/Dockerfile b/libs/lbox-example/Dockerfile
index 2ee61ab7e..dd3b5147d 100644
--- a/libs/lbox-example/Dockerfile
+++ b/libs/lbox-example/Dockerfile
@@ -1,5 +1,5 @@
# https://github.com/ucyo/python-package-template/blob/master/Dockerfile
-FROM python:3.8-slim as rye
+FROM python:3.9-slim as rye
ENV LANG="C.UTF-8" \
LC_ALL="C.UTF-8" \
@@ -38,7 +38,7 @@ WORKDIR /home/python/labelbox-python
RUN rye config --set-bool behavior.global-python=true && \
rye config --set-bool behavior.use-uv=true && \
- rye pin 3.8 && \
+ rye pin 3.9 && \
rye sync
CMD rye run unit && rye integration
\ No newline at end of file
diff --git a/libs/lbox-example/pyproject.toml b/libs/lbox-example/pyproject.toml
index a14f5a34f..a3fe0a13e 100644
--- a/libs/lbox-example/pyproject.toml
+++ b/libs/lbox-example/pyproject.toml
@@ -9,7 +9,7 @@ dependencies = [
"art>=6.2",
]
readme = "README.md"
-requires-python = ">= 3.8"
+requires-python = ">= 3.9"
classifiers=[
# How mature is this project?
@@ -24,7 +24,6 @@ classifiers=[
"License :: OSI Approved :: Apache Software License",
# Specify the Python versions you support here.
"Programming Language :: Python :: 3",
- "Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
diff --git a/pyproject.toml b/pyproject.toml
index ebce059f5..7ab0a0b79 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,7 +9,7 @@ dependencies = [
"sphinx-rtd-theme>=2.0.0",
]
readme = "README.md"
-requires-python = ">= 3.8"
+requires-python = ">= 3.9"
[tool.rye]
managed = true
@@ -28,6 +28,7 @@ dev-dependencies = [
"pytest-timestamper>=0.0.10",
"pytest-timeout>=2.3.1",
"pytest-order>=1.2.1",
+ "pyjwt>=2.9.0",
]
[tool.rye.workspace]
@@ -35,7 +36,7 @@ members = ["libs/*", "examples"]
[tool.pytest.ini_options]
# https://github.com/pytest-dev/pytest-rerunfailures/issues/99
-addopts = "-rP -vvv --reruns 1 --reruns-delay 5 --durations=20 -n auto --cov=labelbox --import-mode=importlib --order-group-scope=module"
+addopts = "-rP -vvv"
markers = """
slow: marks tests as slow (deselect with '-m "not slow"')
"""
diff --git a/requirements-dev.lock b/requirements-dev.lock
index 05ca1683a..7616ff075 100644
--- a/requirements-dev.lock
+++ b/requirements-dev.lock
@@ -6,10 +6,10 @@
# features: []
# all-features: true
# with-sources: false
-# generate-hashes: false
-# universal: false
-e file:libs/labelbox
+-e file:libs/lbox-clients
+ # via labelbox
-e file:libs/lbox-example
alabaster==0.7.13
# via sphinx
@@ -74,6 +74,7 @@ gitpython==3.1.43
# via databooks
google-api-core==2.19.1
# via labelbox
+ # via lbox-clients
google-auth==2.31.0
# via google-api-core
googleapis-common-protos==1.63.2
@@ -89,9 +90,6 @@ importlib-metadata==8.0.0
# via sphinx
# via typeguard
# via yapf
-importlib-resources==6.4.0
- # via jsonschema
- # via jsonschema-specifications
iniconfig==2.0.0
# via pytest
ipython==8.12.3
@@ -125,6 +123,7 @@ matplotlib-inline==0.1.7
mistune==3.0.2
# via nbconvert
mypy==1.10.1
+ # via labelbox
mypy-extensions==1.0.0
# via black
# via mypy
@@ -161,8 +160,6 @@ pickleshare==0.7.5
# via ipython
pillow==10.4.0
# via labelbox
-pkgutil-resolve-name==1.3.10
- # via jsonschema
platformdirs==4.2.2
# via black
# via jupyter-core
@@ -198,6 +195,7 @@ pygments==2.18.0
# via nbconvert
# via rich
# via sphinx
+pyjwt==2.9.0
pyproj==3.5.0
# via labelbox
pytest==8.2.2
@@ -221,7 +219,6 @@ python-dateutil==2.9.0.post0
# via labelbox
# via pandas
pytz==2024.1
- # via babel
# via pandas
pyzmq==26.0.3
# via jupyter-client
@@ -233,6 +230,7 @@ regex==2024.5.15
requests==2.32.3
# via google-api-core
# via labelbox
+ # via lbox-clients
# via sphinx
rich==12.6.0
# via databooks
@@ -315,7 +313,6 @@ types-python-dateutil==2.9.0.20240316
types-requests==2.32.0.20240622
types-tqdm==4.66.0.20240417
typing-extensions==4.12.2
- # via annotated-types
# via black
# via databooks
# via ipython
@@ -323,7 +320,6 @@ typing-extensions==4.12.2
# via mypy
# via pydantic
# via pydantic-core
- # via rich
# via typeguard
# via typer
tzdata==2024.1
@@ -339,4 +335,3 @@ webencodings==0.5.1
yapf==0.40.2
zipp==3.19.2
# via importlib-metadata
- # via importlib-resources
diff --git a/requirements.lock b/requirements.lock
index 07e026d59..36871f22b 100644
--- a/requirements.lock
+++ b/requirements.lock
@@ -6,10 +6,10 @@
# features: []
# all-features: true
# with-sources: false
-# generate-hashes: false
-# universal: false
-e file:libs/labelbox
+-e file:libs/lbox-clients
+ # via labelbox
-e file:libs/lbox-example
alabaster==0.7.13
# via sphinx
@@ -33,6 +33,7 @@ geojson==3.1.0
# via labelbox
google-api-core==2.19.1
# via labelbox
+ # via lbox-clients
google-auth==2.31.0
# via google-api-core
googleapis-common-protos==1.63.2
@@ -49,6 +50,10 @@ jinja2==3.1.4
# via sphinx
markupsafe==2.1.5
# via jinja2
+mypy==1.10.1
+ # via labelbox
+mypy-extensions==1.0.0
+ # via mypy
numpy==1.24.4
# via labelbox
# via opencv-python-headless
@@ -82,11 +87,10 @@ pyproj==3.5.0
# via labelbox
python-dateutil==2.9.0.post0
# via labelbox
-pytz==2024.1
- # via babel
requests==2.32.3
# via google-api-core
# via labelbox
+ # via lbox-clients
# via sphinx
rsa==4.9
# via google-auth
@@ -117,13 +121,15 @@ sphinxcontrib-serializinghtml==1.1.5
# via sphinx
strenum==0.4.15
# via labelbox
+tomli==2.0.1
+ # via mypy
tqdm==4.66.4
# via labelbox
typeguard==4.3.0
# via labelbox
typing-extensions==4.12.2
- # via annotated-types
# via labelbox
+ # via mypy
# via pydantic
# via pydantic-core
# via typeguard